mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-06-07 14:53:53 +02:00
Add files via upload
This commit is contained in:
@@ -57,6 +57,7 @@ if "HF_HOME" not in os.environ:
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from obliteratus import device as dev
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
|
||||
# ── ZeroGPU support ─────────────────────────────────────────────────
|
||||
@@ -83,6 +84,20 @@ except (ImportError, AttributeError):
|
||||
return decorator
|
||||
spaces = _FakeSpaces() # type: ignore[assignment]
|
||||
|
||||
def _is_quota_error(exc: BaseException) -> bool:
|
||||
"""Return True if *exc* is a ZeroGPU quota or session error.
|
||||
|
||||
Matches quota-exceeded errors ("exceeded your GPU quota") and expired
|
||||
proxy tokens ("Expired ZeroGPU proxy token") — both mean the GPU is
|
||||
unavailable and the user should retry later.
|
||||
"""
|
||||
msg = str(exc).lower()
|
||||
if "exceeded" in msg and "gpu quota" in msg:
|
||||
return True
|
||||
if "expired" in msg and "zerogpu" in msg:
|
||||
return True
|
||||
return False
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Global state
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -266,6 +281,7 @@ def _build_model_choices() -> dict[str, str]:
|
||||
MODELS = _build_model_choices()
|
||||
|
||||
METHODS = {
|
||||
"adaptive (telemetry-recommended)": "adaptive",
|
||||
"advanced (recommended)": "advanced",
|
||||
"basic (fast, single direction)": "basic",
|
||||
"aggressive (maximum removal)": "aggressive",
|
||||
@@ -277,6 +293,12 @@ METHODS = {
|
||||
"nuclear (maximum force combo)": "nuclear",
|
||||
}
|
||||
|
||||
# ── Community Hub push ────────────────────────────────────────────────
|
||||
# Shared org + token so users can auto-push without their own HF_TOKEN.
|
||||
# Set OBLITERATUS_HUB_TOKEN as a Space secret with write access to the org.
|
||||
_HUB_COMMUNITY_ORG = os.environ.get("OBLITERATUS_HUB_ORG", "OBLITERATUS-community")
|
||||
_HUB_COMMUNITY_TOKEN = os.environ.get("OBLITERATUS_HUB_TOKEN")
|
||||
|
||||
# Import preset configs for Advanced Settings defaults
|
||||
from obliteratus.abliterate import METHODS as _PRESET_CONFIGS # noqa: E402
|
||||
from obliteratus.prompts import ( # noqa: E402
|
||||
@@ -382,16 +404,223 @@ def _validate_hub_repo(hub_repo: str) -> str:
|
||||
"Invalid repo format — use `username/model-name` "
|
||||
"(letters, numbers, hyphens, dots only)"
|
||||
)
|
||||
if not os.environ.get("HF_TOKEN"):
|
||||
if not os.environ.get("HF_TOKEN") and not _HUB_COMMUNITY_TOKEN:
|
||||
warnings.append(
|
||||
"HF_TOKEN not set — push to Hub will fail. "
|
||||
"Set it via: `export HF_TOKEN=hf_...`"
|
||||
"No Hub token available — push will fail. "
|
||||
"Set HF_TOKEN or OBLITERATUS_HUB_TOKEN."
|
||||
)
|
||||
if warnings:
|
||||
return "**Warning:** " + " | ".join(warnings)
|
||||
return ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Push to Hub — dedicated tab backend
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _generate_model_card(meta: dict) -> str:
|
||||
"""Generate a HuggingFace model card README for a session model."""
|
||||
model_id = meta.get("model_id", "unknown")
|
||||
method = meta.get("method", "unknown")
|
||||
source = meta.get("source", "obliterate")
|
||||
short_model = model_id.split("/")[-1] if "/" in model_id else model_id
|
||||
|
||||
metrics_table = ""
|
||||
tourney_metrics = meta.get("tourney_metrics")
|
||||
if tourney_metrics:
|
||||
rows = "\n".join(
|
||||
f"| {k.replace('_', ' ').title()} | {v:.4f} |"
|
||||
for k, v in tourney_metrics.items() if isinstance(v, (int, float))
|
||||
)
|
||||
metrics_table = f"\n## Metrics\n\n| Metric | Value |\n|--------|-------|\n{rows}\n"
|
||||
|
||||
return f"""---
|
||||
language: en
|
||||
tags:
|
||||
- obliteratus
|
||||
- abliteration
|
||||
- uncensored
|
||||
- {source}
|
||||
base_model: {model_id}
|
||||
---
|
||||
|
||||
# {short_model}-OBLITERATED
|
||||
|
||||
This model was abliterated using the **`{method}`** method via
|
||||
[OBLITERATUS](https://github.com/elder-plinius/OBLITERATUS).
|
||||
|
||||
| Detail | Value |
|
||||
|--------|-------|
|
||||
| Base model | `{model_id}` |
|
||||
| Method | `{method}` |
|
||||
| Source | {source} |
|
||||
{metrics_table}
|
||||
## How to Use
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("{short_model}-OBLITERATED")
|
||||
tokenizer = AutoTokenizer.from_pretrained("{short_model}-OBLITERATED")
|
||||
|
||||
prompt = "Hello, how are you?"
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
outputs = model.generate(**inputs, max_new_tokens=256)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## About OBLITERATUS
|
||||
|
||||
OBLITERATUS is an open-source tool for removing refusal behavior from language
|
||||
models via activation engineering (abliteration). Learn more at
|
||||
[github.com/elder-plinius/OBLITERATUS](https://github.com/elder-plinius/OBLITERATUS).
|
||||
"""
|
||||
|
||||
|
||||
def _get_hub_session_info(label: str) -> str:
|
||||
"""Return a markdown summary of the selected session model."""
|
||||
if not label or label.startswith("("):
|
||||
return ""
|
||||
meta = _session_models.get(label)
|
||||
if not meta:
|
||||
return "*Session model not found — try refreshing the list.*"
|
||||
lines = [
|
||||
f"**Model:** `{meta.get('model_id', 'unknown')}`",
|
||||
f"**Method:** `{meta.get('method', 'unknown')}`",
|
||||
f"**Source:** {meta.get('source', 'unknown')}",
|
||||
f"**Path:** `{meta.get('output_dir', 'N/A')}`",
|
||||
]
|
||||
score = meta.get("tourney_score")
|
||||
if score is not None:
|
||||
lines.append(f"**Tourney score:** {score:.4f}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _auto_hub_repo_id(label: str) -> str:
|
||||
"""Generate an auto-filled Hub repo ID for the selected session model."""
|
||||
meta = _session_models.get(label)
|
||||
if not meta:
|
||||
return ""
|
||||
model_id = meta.get("model_id", "")
|
||||
import re
|
||||
short = model_id.split("/")[-1] if "/" in model_id else model_id
|
||||
short = re.sub(r"[^a-zA-Z0-9\-.]", "-", short)
|
||||
return f"{_HUB_COMMUNITY_ORG}/{short}-OBLITERATED"
|
||||
|
||||
|
||||
def push_session_to_hub(
|
||||
session_label: str,
|
||||
hub_repo_id: str,
|
||||
hub_token_input: str,
|
||||
refine_enabled: bool,
|
||||
refine_regularization: float,
|
||||
refine_passes: int,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
"""Push a session model to HuggingFace Hub, with optional refinement."""
|
||||
import os
|
||||
import re
|
||||
|
||||
if not session_label or session_label.startswith("("):
|
||||
yield "**Error:** Select a session model first.", ""
|
||||
return
|
||||
|
||||
meta = _session_models.get(session_label)
|
||||
if not meta:
|
||||
yield "**Error:** Session model not found. Try refreshing the list.", ""
|
||||
return
|
||||
|
||||
output_dir = meta.get("output_dir", "")
|
||||
if not output_dir or not Path(output_dir).exists():
|
||||
yield f"**Error:** Model directory not found: `{output_dir}`", ""
|
||||
return
|
||||
|
||||
# Resolve repo ID
|
||||
repo_id = hub_repo_id.strip() if hub_repo_id else ""
|
||||
if not repo_id:
|
||||
repo_id = _auto_hub_repo_id(session_label)
|
||||
if not repo_id:
|
||||
yield "**Error:** Could not determine Hub repo ID.", ""
|
||||
return
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+/[a-zA-Z0-9_.-]+$', repo_id):
|
||||
yield "**Error:** Invalid repo format. Use `username/model-name`.", ""
|
||||
return
|
||||
|
||||
# Resolve token
|
||||
token = hub_token_input.strip() if hub_token_input else None
|
||||
if not token:
|
||||
token = os.environ.get("HF_TOKEN") or _HUB_COMMUNITY_TOKEN
|
||||
if not token:
|
||||
yield (
|
||||
"**Error:** No Hub token available. Enter a token above, "
|
||||
"or set `HF_TOKEN` / `OBLITERATUS_HUB_TOKEN` as an environment variable.",
|
||||
"",
|
||||
)
|
||||
return
|
||||
|
||||
# Optional refinement pass
|
||||
if refine_enabled and refine_passes > 0:
|
||||
progress(0.1, desc="Refining model...")
|
||||
yield "Applying refinement passes...", ""
|
||||
try:
|
||||
from obliteratus.abliterate import AbliterationPipeline
|
||||
from obliteratus.prompts import load_dataset_source
|
||||
|
||||
dataset_key = meta.get("dataset_key", "builtin")
|
||||
if dataset_key == "custom":
|
||||
dataset_key = "builtin"
|
||||
harmful, harmless = load_dataset_source(dataset_key)
|
||||
n = min(33, len(harmful), len(harmless))
|
||||
|
||||
pipeline = AbliterationPipeline(
|
||||
model_name=output_dir, # load from saved checkpoint
|
||||
output_dir=output_dir,
|
||||
device="auto",
|
||||
dtype="float16",
|
||||
method=meta.get("method", "advanced"),
|
||||
regularization=refine_regularization,
|
||||
refinement_passes=refine_passes,
|
||||
harmful_prompts=harmful[:n],
|
||||
harmless_prompts=harmless[:n],
|
||||
)
|
||||
pipeline.run()
|
||||
except Exception as e:
|
||||
yield f"**Refinement failed:** {e}", ""
|
||||
return
|
||||
|
||||
# Generate model card
|
||||
progress(0.5, desc="Generating model card...")
|
||||
yield f"Generating model card and uploading to `{repo_id}`...", ""
|
||||
card_content = _generate_model_card(meta)
|
||||
card_path = Path(output_dir) / "README.md"
|
||||
card_path.write_text(card_content)
|
||||
|
||||
# Upload to Hub
|
||||
progress(0.6, desc="Uploading to Hub...")
|
||||
try:
|
||||
from huggingface_hub import HfApi
|
||||
api = HfApi(token=token)
|
||||
api.create_repo(repo_id, exist_ok=True)
|
||||
|
||||
method = meta.get("method", "unknown")
|
||||
model_id = meta.get("model_id", "unknown")
|
||||
api.upload_folder(
|
||||
folder_path=output_dir,
|
||||
repo_id=repo_id,
|
||||
commit_message=f"OBLITERATUS: {method} on {model_id}",
|
||||
)
|
||||
except Exception as e:
|
||||
yield f"**Upload failed:** {e}", ""
|
||||
return
|
||||
|
||||
progress(1.0, desc="Done!")
|
||||
hub_url = f"https://huggingface.co/{repo_id}"
|
||||
yield (
|
||||
f"**Pushed successfully to [{repo_id}]({hub_url})**",
|
||||
f"[Open on HuggingFace Hub]({hub_url})",
|
||||
)
|
||||
|
||||
|
||||
PROMPT_VOLUMES = {
|
||||
"33 (fast)": 33,
|
||||
"66 (better signal)": 66,
|
||||
@@ -440,25 +669,11 @@ def _should_quantize(model_id: str, is_preset: bool = False) -> str | None:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _clear_gpu():
|
||||
"""Free GPU memory. Resilient to CUDA errors (e.g. after illegal memory access)."""
|
||||
"""Free GPU/accelerator memory. Resilient to device errors."""
|
||||
with _lock:
|
||||
_state["model"] = None
|
||||
_state["tokenizer"] = None
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
# CUDA context may be poisoned after an illegal-address error;
|
||||
# attempt a device reset so subsequent loads can succeed.
|
||||
try:
|
||||
torch.cuda.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
except Exception:
|
||||
pass
|
||||
dev.free_gpu_memory()
|
||||
|
||||
|
||||
def _install_steering_hooks(model, steering_meta: dict) -> int:
|
||||
@@ -582,16 +797,16 @@ def _cleanup_disk():
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_vram_html() -> str:
|
||||
"""Return an HTML snippet showing GPU VRAM usage as a styled bar."""
|
||||
if not torch.cuda.is_available():
|
||||
"""Return an HTML snippet showing GPU/accelerator memory usage as a styled bar."""
|
||||
if not dev.is_gpu_available():
|
||||
return (
|
||||
'<div style="text-align:center;color:#4a5568;font-size:0.72rem;'
|
||||
'letter-spacing:1px;margin-top:6px;">CPU ONLY — NO GPU DETECTED</div>'
|
||||
)
|
||||
try:
|
||||
used = torch.cuda.memory_allocated() / 1024**3
|
||||
reserved = torch.cuda.memory_reserved() / 1024**3
|
||||
total = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
mem = dev.get_memory_info()
|
||||
used = mem.used_gb
|
||||
total = mem.total_gb
|
||||
pct = (used / total * 100) if total > 0 else 0
|
||||
# Color shifts from green → yellow → red
|
||||
if pct < 50:
|
||||
@@ -600,12 +815,17 @@ def _get_vram_html() -> str:
|
||||
bar_color = "#ffcc00"
|
||||
else:
|
||||
bar_color = "#ff003c"
|
||||
device_name = torch.cuda.get_device_name(0)
|
||||
device_name = mem.device_name
|
||||
reserved_html = (
|
||||
f'<span style="color:#4a5568;">reserved: {mem.reserved_gb:.1f} GB</span>'
|
||||
if mem.reserved_gb > 0
|
||||
else f'<span style="color:#4a5568;">unified memory</span>'
|
||||
)
|
||||
return (
|
||||
f'<div style="margin:6px auto 0;max-width:480px;">'
|
||||
f'<div style="display:flex;justify-content:space-between;font-size:0.68rem;'
|
||||
f'color:#4a5568;letter-spacing:1px;margin-bottom:2px;">'
|
||||
f'<span>GPU: {device_name}</span>'
|
||||
f'<span>{device_name}</span>'
|
||||
f'<span>{used:.1f} / {total:.1f} GB ({pct:.0f}%)</span></div>'
|
||||
f'<div style="background:#0a0a0f;border:1px solid #1a1f2e;border-radius:3px;'
|
||||
f'height:10px;overflow:hidden;">'
|
||||
@@ -613,11 +833,11 @@ def _get_vram_html() -> str:
|
||||
f'box-shadow:0 0 6px {bar_color};transition:width 0.5s ease;"></div></div>'
|
||||
f'<div style="display:flex;justify-content:space-between;font-size:0.6rem;'
|
||||
f'color:#333;margin-top:1px;">'
|
||||
f'<span style="color:#4a5568;">reserved: {reserved:.1f} GB</span></div>'
|
||||
f'{reserved_html}</div>'
|
||||
f'</div>'
|
||||
)
|
||||
except Exception:
|
||||
return '<div style="text-align:center;color:#4a5568;font-size:0.72rem;">VRAM: unavailable</div>'
|
||||
return '<div style="text-align:center;color:#4a5568;font-size:0.72rem;">Memory: unavailable</div>'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1060,8 +1280,7 @@ def benchmark(
|
||||
pass
|
||||
pipeline_ref[0] = None
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
dev.empty_cache()
|
||||
|
||||
yield (
|
||||
f"**{method_key} complete** ({mi + 1}/{len(methods_to_test)}) \u2014 {_bench_elapsed()}",
|
||||
@@ -1411,8 +1630,7 @@ def benchmark_multi_model(
|
||||
pass
|
||||
pipeline_ref[0] = None
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
dev.empty_cache()
|
||||
|
||||
yield (
|
||||
f"**{model_id} complete** ({mi + 1}/{len(model_choices)}) \u2014 {_mm_elapsed()}",
|
||||
@@ -1510,7 +1728,7 @@ def _format_multi_model_results(results: list[dict], context: dict | None = None
|
||||
|
||||
|
||||
@spaces.GPU(duration=300)
|
||||
def obliterate(model_choice: str, method_choice: str, hub_repo: str,
|
||||
def obliterate(model_choice: str, method_choice: str,
|
||||
prompt_volume_choice: str, dataset_source_choice: str,
|
||||
custom_harmful: str, custom_harmless: str,
|
||||
# Advanced params (sliders)
|
||||
@@ -1543,9 +1761,38 @@ def obliterate(model_choice: str, method_choice: str, hub_repo: str,
|
||||
model_id = MODELS.get(model_choice, model_choice)
|
||||
is_preset = model_choice in MODELS
|
||||
method = METHODS.get(method_choice, "advanced")
|
||||
push_to_hub = hub_repo.strip() if hub_repo and hub_repo.strip() else None
|
||||
prompt_volume = PROMPT_VOLUMES.get(prompt_volume_choice, 33)
|
||||
|
||||
# Resolve "adaptive" → telemetry-recommended method for this model
|
||||
_adaptive_info = ""
|
||||
if method == "adaptive":
|
||||
try:
|
||||
from obliteratus.architecture_profiles import detect_architecture, enhance_profile_with_telemetry
|
||||
from transformers import AutoConfig
|
||||
try:
|
||||
_cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
|
||||
_nl = getattr(_cfg, "num_hidden_layers", 0)
|
||||
_hs = getattr(_cfg, "hidden_size", 0)
|
||||
except Exception:
|
||||
_cfg, _nl, _hs = None, 0, 0
|
||||
_profile = detect_architecture(model_id, _cfg, _nl, _hs)
|
||||
_profile, _rec = enhance_profile_with_telemetry(_profile)
|
||||
if _rec and _rec.recommended_method and _rec.confidence != "none":
|
||||
method = _rec.recommended_method
|
||||
_adaptive_info = (
|
||||
f"Adaptive: telemetry recommends `{method}` "
|
||||
f"({_rec.confidence} confidence, {_rec.n_records} runs)"
|
||||
)
|
||||
else:
|
||||
method = _profile.recommended_method or "advanced"
|
||||
_adaptive_info = (
|
||||
f"Adaptive: using architecture default `{method}` "
|
||||
f"(no telemetry data yet)"
|
||||
)
|
||||
except Exception:
|
||||
method = "advanced"
|
||||
_adaptive_info = "Adaptive: fallback to `advanced` (could not detect architecture)"
|
||||
|
||||
# Early validation: gated model access
|
||||
from obliteratus.presets import is_gated
|
||||
if is_gated(model_id) and not os.environ.get("HF_TOKEN"):
|
||||
@@ -1561,22 +1808,6 @@ def obliterate(model_choice: str, method_choice: str, hub_repo: str,
|
||||
)
|
||||
return
|
||||
|
||||
# Early validation: Hub repo format + HF_TOKEN
|
||||
if push_to_hub:
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+/[a-zA-Z0-9_.-]+$', push_to_hub):
|
||||
yield (
|
||||
"**Error:** Invalid Hub repo format. Use `username/model-name`.",
|
||||
"", gr.update(), gr.update(), gr.update(), gr.update(),
|
||||
)
|
||||
return
|
||||
if not os.environ.get("HF_TOKEN"):
|
||||
yield (
|
||||
"**Error:** HF_TOKEN not set. Push to Hub requires a write token. "
|
||||
"Set it via `export HF_TOKEN=hf_...` or in your Space secrets.",
|
||||
"", gr.update(), gr.update(), gr.update(), gr.update(),
|
||||
)
|
||||
return
|
||||
|
||||
# Resolve dataset source — custom prompts override the dropdown
|
||||
use_custom = custom_harmful and custom_harmful.strip()
|
||||
dataset_key = get_source_key_from_label(dataset_source_choice) if dataset_source_choice else "builtin"
|
||||
@@ -1650,7 +1881,6 @@ def obliterate(model_choice: str, method_choice: str, hub_repo: str,
|
||||
output_dir=save_dir,
|
||||
device="auto",
|
||||
dtype="float16",
|
||||
push_to_hub=push_to_hub,
|
||||
quantization=quantization,
|
||||
trust_remote_code=is_preset,
|
||||
harmful_prompts=harmful_all[:n],
|
||||
@@ -1668,7 +1898,6 @@ def obliterate(model_choice: str, method_choice: str, hub_repo: str,
|
||||
device="auto",
|
||||
dtype="float16",
|
||||
method=method,
|
||||
push_to_hub=push_to_hub,
|
||||
quantization=quantization,
|
||||
trust_remote_code=is_preset,
|
||||
harmful_prompts=harmful_all[:n],
|
||||
@@ -1716,11 +1945,11 @@ def obliterate(model_choice: str, method_choice: str, hub_repo: str,
|
||||
source_label = source_info.label if source_info else dataset_key
|
||||
log_lines.append(f"Target: {model_id}")
|
||||
log_lines.append(f"Method: {method}")
|
||||
if _adaptive_info:
|
||||
log_lines.append(_adaptive_info)
|
||||
log_lines.append(f"Dataset: {source_label}")
|
||||
vol_label = "all" if prompt_volume == -1 else str(prompt_volume)
|
||||
log_lines.append(f"Prompt volume: {vol_label} pairs")
|
||||
if push_to_hub:
|
||||
log_lines.append(f"Push to Hub: {push_to_hub}")
|
||||
if quantization:
|
||||
log_lines.append(f"Quantization: {quantization} (auto-detected for GPU fit)")
|
||||
log_lines.append("")
|
||||
@@ -2059,11 +2288,11 @@ def chat_respond(message: str, history: list[dict], system_prompt: str,
|
||||
_needs_reload = model is None or tokenizer is None
|
||||
if not _needs_reload:
|
||||
try:
|
||||
dev = next(model.parameters()).device
|
||||
if dev.type == "meta":
|
||||
model_dev = next(model.parameters()).device
|
||||
if model_dev.type == "meta":
|
||||
_needs_reload = True
|
||||
elif torch.cuda.is_available() and dev.type != "cuda":
|
||||
model.to("cuda")
|
||||
elif dev.is_gpu_available() and model_dev.type not in ("cuda", "mps"):
|
||||
model.to(dev.get_device())
|
||||
except Exception:
|
||||
_needs_reload = True
|
||||
|
||||
@@ -2493,11 +2722,11 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
|
||||
_needs_reload = abliterated_model is None or tokenizer is None
|
||||
if not _needs_reload:
|
||||
try:
|
||||
dev = next(abliterated_model.parameters()).device
|
||||
if dev.type == "meta":
|
||||
model_dev = next(abliterated_model.parameters()).device
|
||||
if model_dev.type == "meta":
|
||||
_needs_reload = True
|
||||
elif torch.cuda.is_available() and dev.type != "cuda":
|
||||
abliterated_model.to("cuda")
|
||||
elif dev.is_gpu_available() and model_dev.type not in ("cuda", "mps"):
|
||||
abliterated_model.to(dev.get_device())
|
||||
except Exception:
|
||||
_needs_reload = True
|
||||
|
||||
@@ -2630,8 +2859,7 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
|
||||
abl_device = next(abliterated_model.parameters()).device
|
||||
abliterated_model.to("cpu")
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
dev.empty_cache()
|
||||
|
||||
model_id = MODELS.get(model_name, model_name)
|
||||
# Only trust remote code for known preset models, not arbitrary user-supplied IDs
|
||||
@@ -2683,8 +2911,7 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
|
||||
# Free the original model
|
||||
del original_model
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
dev.empty_cache()
|
||||
|
||||
except Exception as e:
|
||||
original_response = f"*Could not load original model for comparison: {e}*"
|
||||
@@ -2693,7 +2920,7 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
|
||||
# Use torch.device("cuda") rather than the captured abl_device, since
|
||||
# on ZeroGPU the original device reference may point to a stale context.
|
||||
try:
|
||||
restore_device = torch.device("cuda") if torch.cuda.is_available() else abl_device
|
||||
restore_device = torch.device(dev.get_device()) if dev.is_gpu_available() else abl_device
|
||||
abliterated_model.to(restore_device)
|
||||
except Exception:
|
||||
pass # If GPU restore fails, model stays on CPU (still usable)
|
||||
@@ -2811,8 +3038,7 @@ def strength_sweep(model_choice: str, method_choice: str,
|
||||
|
||||
# Cleanup between runs
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
dev.empty_cache()
|
||||
|
||||
# Generate dose-response curve
|
||||
gallery = None
|
||||
@@ -2904,6 +3130,233 @@ def _format_sweep_results(results: list[dict]) -> str:
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tournament
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@spaces.GPU(duration=300)
|
||||
def _tourney_gpu_run(fn, *args, **kwargs):
|
||||
"""Execute *fn* inside a ZeroGPU GPU allocation.
|
||||
|
||||
Used by ``run_tourney`` to give each tournament method its own 5-minute
|
||||
GPU allocation instead of sharing a single allocation for the whole
|
||||
tournament. On non-ZeroGPU machines the ``@spaces.GPU`` decorator is a
|
||||
no-op and this simply calls *fn* directly.
|
||||
"""
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
|
||||
class _TourneyLogger:
|
||||
"""Picklable log collector for tournament progress.
|
||||
|
||||
Gradio's queue system pickles generator frames, so closures like
|
||||
``lambda msg: log_lines.append(msg)`` cause PicklingError. This
|
||||
simple class is picklable and serves the same purpose.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.lines: list[str] = []
|
||||
|
||||
def __call__(self, msg: str):
|
||||
self.lines.append(msg)
|
||||
|
||||
def tail(self, n: int = 100) -> str:
|
||||
"""Return the last *n* log lines joined by newlines. ``n=0`` returns all."""
|
||||
if n <= 0:
|
||||
return "\n".join(self.lines)
|
||||
return "\n".join(self.lines[-n:])
|
||||
|
||||
|
||||
def _tourney_gpu_wrapper(fn, *args, **kwargs):
|
||||
"""Indirection so the @spaces.GPU-wrapped function is resolved at call
|
||||
time rather than captured in the generator frame (which Gradio pickles)."""
|
||||
return _tourney_gpu_run(fn, *args, **kwargs)
|
||||
|
||||
|
||||
def run_tourney(model_choice, dataset, quantization):
|
||||
"""Run an elimination tournament across all abliteration methods.
|
||||
|
||||
Each individual method is run inside its own ``@spaces.GPU`` allocation
|
||||
(up to 5 minutes per method) so the full tournament is not constrained
|
||||
by a single 300 s ZeroGPU limit. Between methods the GPU is released,
|
||||
allowing the generator to yield progress updates to the Gradio UI.
|
||||
"""
|
||||
import traceback
|
||||
|
||||
if not model_choice or not model_choice.strip():
|
||||
yield "**Error:** Select a model first.", "", ""
|
||||
return
|
||||
|
||||
from obliteratus.tourney import (
|
||||
TourneyRunner, render_bracket_html,
|
||||
_load_checkpoint, _checkpoint_matches,
|
||||
)
|
||||
|
||||
# Resolve display label → HuggingFace model ID
|
||||
model_id = model_choice.strip()
|
||||
if model_id in MODELS:
|
||||
model_id = MODELS[model_id]
|
||||
|
||||
quant = quantization if quantization != "none" else None
|
||||
|
||||
logger = _TourneyLogger()
|
||||
|
||||
dataset_key = get_source_key_from_label(dataset) if dataset else "builtin"
|
||||
|
||||
# Check for a resumable checkpoint from a previous quota-interrupted run
|
||||
tourney_dir = Path("/tmp/obliteratus_tourney")
|
||||
checkpoint = _load_checkpoint(tourney_dir)
|
||||
resume = (
|
||||
checkpoint is not None
|
||||
and _checkpoint_matches(checkpoint, model_id, dataset_key, quant)
|
||||
)
|
||||
|
||||
try:
|
||||
runner = TourneyRunner(
|
||||
model_name=model_id,
|
||||
hub_org=None,
|
||||
hub_repo=None,
|
||||
dataset_key=dataset_key,
|
||||
quantization=quant,
|
||||
on_log=logger,
|
||||
resume=resume,
|
||||
)
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
yield (f"**Error creating runner:** {e}", "", tb)
|
||||
return
|
||||
|
||||
n_methods = len(runner.methods)
|
||||
if resume:
|
||||
n_done = len(checkpoint.get("completed_rounds", []))
|
||||
n_partial = len(checkpoint.get("interrupted_round", {}).get("completed_methods", []))
|
||||
yield (
|
||||
f"**Resuming tournament** — {n_done} round(s) + {n_partial} method(s) "
|
||||
f"completed previously. Continuing on `{model_id}`...",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
else:
|
||||
yield (
|
||||
f"**Tournament starting** — {n_methods} methods will compete on `{model_id}`...",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
result = None
|
||||
try:
|
||||
for status_msg, partial_result in runner.run_iter(gpu_wrapper=_tourney_gpu_wrapper):
|
||||
result = partial_result
|
||||
yield (
|
||||
status_msg,
|
||||
"",
|
||||
logger.tail(),
|
||||
)
|
||||
except Exception as e:
|
||||
if _is_quota_error(e):
|
||||
# Known-resumable error — don't dump a scary traceback
|
||||
bracket_md = ""
|
||||
if result and result.rounds:
|
||||
bracket_md = render_bracket_html(result)
|
||||
is_expired = "expired" in str(e).lower()
|
||||
if is_expired:
|
||||
reason = (
|
||||
"**GPU session expired** — the ZeroGPU proxy token "
|
||||
"timed out during the tournament.\n\n"
|
||||
)
|
||||
else:
|
||||
reason = f"**GPU quota exceeded** — {e}\n\n"
|
||||
yield (
|
||||
reason +
|
||||
"Your progress has been **saved automatically**. "
|
||||
"Click **Run Tournament** again and the tournament will "
|
||||
"resume from where it left off.\n\n"
|
||||
"Quota recharges over time (half-life ~2 hours). "
|
||||
"HuggingFace Pro subscribers get 7x more daily quota.\n\n"
|
||||
"**Tip:** use quantization to reduce per-method GPU time.",
|
||||
bracket_md,
|
||||
logger.tail(0),
|
||||
)
|
||||
else:
|
||||
yield (
|
||||
f"**Error:** {type(e).__name__}: {e}",
|
||||
"",
|
||||
logger.tail(0),
|
||||
)
|
||||
return
|
||||
|
||||
if not result:
|
||||
yield ("**Error:** Tournament produced no result.", "", logger.tail(0))
|
||||
return
|
||||
|
||||
winner = result.winner
|
||||
if winner and winner.error:
|
||||
winner = None
|
||||
result.winner = None
|
||||
|
||||
# ── Telemetry: log tournament winner to community leaderboard ──
|
||||
if winner and not winner.error:
|
||||
try:
|
||||
from obliteratus.telemetry import log_benchmark_from_dict
|
||||
log_benchmark_from_dict(
|
||||
model_id=model_id,
|
||||
method=winner.method,
|
||||
entry={
|
||||
"perplexity": winner.metrics.get("perplexity"),
|
||||
"coherence": winner.metrics.get("coherence"),
|
||||
"refusal_rate": winner.metrics.get("refusal_rate"),
|
||||
"kl_divergence": winner.metrics.get("kl_divergence"),
|
||||
"time_s": winner.time_s,
|
||||
"error": None,
|
||||
},
|
||||
dataset=dataset_key,
|
||||
quantization=quant,
|
||||
)
|
||||
except Exception:
|
||||
pass # Telemetry is best-effort
|
||||
|
||||
if winner:
|
||||
bracket_md = render_bracket_html(result)
|
||||
# Register winner in session models for Push to Hub tab
|
||||
if winner.output_dir:
|
||||
_ts = datetime.now().strftime("%H:%M")
|
||||
_short = model_id.split("/")[-1] if "/" in model_id else model_id
|
||||
_label = f"tourney winner ({winner.method}) on {_short} ({_ts})"
|
||||
with _lock:
|
||||
_session_models[_label] = {
|
||||
"model_id": model_id,
|
||||
"model_choice": model_choice,
|
||||
"method": winner.method,
|
||||
"dataset_key": dataset_key,
|
||||
"prompt_volume": 0,
|
||||
"output_dir": winner.output_dir,
|
||||
"source": "tourney",
|
||||
"tourney_score": winner.score,
|
||||
"tourney_metrics": winner.metrics,
|
||||
}
|
||||
yield (
|
||||
f"**Champion: `{winner.method}`** "
|
||||
f"(score: {winner.score:.4f})\n"
|
||||
f"Push it to HuggingFace Hub from the **Push to Hub** tab.",
|
||||
bracket_md,
|
||||
logger.tail(0),
|
||||
)
|
||||
else:
|
||||
n_errors = sum(
|
||||
1 for rnd in result.rounds
|
||||
for c in rnd.contenders if c.error
|
||||
)
|
||||
bracket_md = render_bracket_html(result) if result.rounds else ""
|
||||
msg = "**Tournament complete** — no winner determined."
|
||||
if n_errors:
|
||||
msg += f" ({n_errors} method(s) errored — check the log for details.)"
|
||||
yield (
|
||||
msg,
|
||||
bracket_md,
|
||||
logger.tail(0),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export Research Artifacts
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -3464,14 +3917,10 @@ with gr.Blocks(theme=THEME, css=CSS, js=_JS, title="OBLITERATUS", fill_height=Tr
|
||||
lines=5,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
hub_repo = gr.Textbox(
|
||||
label="Push to Hub (optional)",
|
||||
placeholder="your-username/model-name-abliterated",
|
||||
info="HF Hub repo ID — saves locally then uploads. "
|
||||
"Requires HF_TOKEN env var with write access.",
|
||||
)
|
||||
hub_warning_md = gr.Markdown("")
|
||||
gr.Markdown(
|
||||
"*After obliterating, push your model to HuggingFace Hub from the **Push to Hub** tab.*",
|
||||
elem_classes=["hub-hint"],
|
||||
)
|
||||
|
||||
# ── Advanced Settings (auto-populated from method preset) ────
|
||||
_defaults = _get_preset_defaults("advanced (recommended)")
|
||||
@@ -4099,7 +4548,59 @@ tradeoff point where refusal is minimized with minimal capability damage.
|
||||
gr.State()], # 5th output is unused File placeholder
|
||||
)
|
||||
|
||||
# ── Tab 6: Export ─────────────────────────────────────────────────
|
||||
# ── Tab 6: Tourney ────────────────────────────────────────────────
|
||||
with gr.Tab("Tourney", id="tourney"):
|
||||
gr.Markdown("""### March Madness Tournament
|
||||
Pit **all abliteration methods** against each other in elimination rounds.
|
||||
The winner is saved locally — push it to HuggingFace Hub from the **Push to Hub** tab.
|
||||
|
||||
**Round 1 — Qualifiers:** All methods, reduced prompts. Bottom half eliminated.
|
||||
**Round 2 — Semifinals:** Survivors, full prompts. Bottom half eliminated.
|
||||
**Round 3 — Finals:** Top contenders, maximum prompts. Champion crowned.
|
||||
""")
|
||||
tourney_model_dd = gr.Dropdown(
|
||||
choices=list(MODELS.keys()),
|
||||
value="Alibaba (Qwen) / Qwen3-4B",
|
||||
label="Target Model",
|
||||
info="Select a model to tournament-abliterate",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Accordion("Advanced Settings", open=False):
|
||||
with gr.Row():
|
||||
tourney_dataset_dd = gr.Dropdown(
|
||||
choices=get_source_choices(),
|
||||
value=get_source_choices()[0],
|
||||
label="Dataset Source",
|
||||
)
|
||||
tourney_quant_dd = gr.Dropdown(
|
||||
choices=["none", "4bit", "8bit"],
|
||||
value="none",
|
||||
label="Quantization",
|
||||
)
|
||||
|
||||
tourney_btn = gr.Button(
|
||||
"Start Tournament",
|
||||
variant="primary",
|
||||
size="lg",
|
||||
)
|
||||
tourney_status = gr.Markdown("")
|
||||
tourney_bracket = gr.HTML("")
|
||||
tourney_log = gr.Textbox(
|
||||
label="Tournament Log",
|
||||
lines=20,
|
||||
max_lines=40,
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
tourney_btn.click(
|
||||
fn=run_tourney,
|
||||
inputs=[tourney_model_dd,
|
||||
tourney_dataset_dd, tourney_quant_dd],
|
||||
outputs=[tourney_status, tourney_bracket, tourney_log],
|
||||
)
|
||||
|
||||
# ── Tab 7: Export ─────────────────────────────────────────────────
|
||||
with gr.Tab("Export", id="export"):
|
||||
gr.Markdown("""### Export Research Artifacts
|
||||
Download all intermediate data from your last obliteration run as a ZIP archive.
|
||||
@@ -4120,7 +4621,94 @@ Download all intermediate data from your last obliteration run as a ZIP archive.
|
||||
outputs=[export_file, export_status],
|
||||
)
|
||||
|
||||
# ── Tab 7: Leaderboard ────────────────────────────────────────────
|
||||
# ── Tab: Push to Hub ──────────────────────────────────────────────
|
||||
with gr.Tab("Push to Hub", id="push_hub"):
|
||||
gr.Markdown("""### Push to HuggingFace Hub
|
||||
Select any session model from your Obliterate, Benchmark, or Tourney runs,
|
||||
optionally apply a quick refinement pass, then push to HuggingFace Hub
|
||||
with the **-OBLITERATED** tag.
|
||||
""")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
push_session_dd = gr.Dropdown(
|
||||
choices=_get_session_model_choices(),
|
||||
label="Session Model",
|
||||
info="Pick a model from any tab's output",
|
||||
)
|
||||
push_refresh_btn = gr.Button("Refresh List", variant="secondary", size="sm")
|
||||
push_model_info = gr.Markdown("")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
push_repo_id = gr.Textbox(
|
||||
label="Hub Repo ID",
|
||||
placeholder="auto-filled, or type your own",
|
||||
info="e.g. my-org/my-model-OBLITERATED",
|
||||
)
|
||||
push_token = gr.Textbox(
|
||||
label="HF Token (optional)",
|
||||
placeholder="hf_...",
|
||||
type="password",
|
||||
info="Leave blank to use HF_TOKEN env var or community token",
|
||||
)
|
||||
push_repo_warning = gr.Markdown("")
|
||||
|
||||
with gr.Accordion("Quick Refiner (optional)", open=False):
|
||||
gr.Markdown(
|
||||
"*Optionally apply extra refinement passes to your model before pushing. "
|
||||
"This re-runs the abliteration pipeline with adjusted regularization.*"
|
||||
)
|
||||
with gr.Row():
|
||||
push_refine_reg = gr.Slider(
|
||||
0.0, 1.0, value=0.1, step=0.05,
|
||||
label="Regularization",
|
||||
info="Weight preservation (0 = full removal, 1 = no change)",
|
||||
)
|
||||
push_refine_passes = gr.Slider(
|
||||
0, 3, value=0, step=1,
|
||||
label="Extra Refinement Passes",
|
||||
info="0 = skip refinement, 1-3 = apply additional passes",
|
||||
)
|
||||
push_refine_enabled = gr.Checkbox(
|
||||
label="Apply refinement before pushing",
|
||||
value=False,
|
||||
)
|
||||
|
||||
push_btn = gr.Button(
|
||||
"Push to Hub",
|
||||
variant="primary",
|
||||
size="lg",
|
||||
)
|
||||
push_status = gr.Markdown("")
|
||||
push_link = gr.Markdown("")
|
||||
|
||||
# -- Event wiring (inline since components are scoped to this tab) --
|
||||
|
||||
push_refresh_btn.click(
|
||||
fn=lambda: gr.update(choices=_get_session_model_choices()),
|
||||
outputs=[push_session_dd],
|
||||
)
|
||||
|
||||
push_session_dd.change(
|
||||
fn=lambda label: (_get_hub_session_info(label), _auto_hub_repo_id(label)),
|
||||
inputs=[push_session_dd],
|
||||
outputs=[push_model_info, push_repo_id],
|
||||
)
|
||||
|
||||
push_repo_id.change(
|
||||
fn=_validate_hub_repo,
|
||||
inputs=[push_repo_id],
|
||||
outputs=[push_repo_warning],
|
||||
)
|
||||
|
||||
push_btn.click(
|
||||
fn=push_session_to_hub,
|
||||
inputs=[push_session_dd, push_repo_id, push_token,
|
||||
push_refine_enabled, push_refine_reg, push_refine_passes],
|
||||
outputs=[push_status, push_link],
|
||||
)
|
||||
|
||||
# ── Tab: Leaderboard ────────────────────────────────────────────
|
||||
with gr.Tab("Leaderboard", id="leaderboard"):
|
||||
gr.Markdown("""### Community Leaderboard
|
||||
All benchmark results from **every OBLITERATUS Space** (including duplicated copies) are
|
||||
@@ -4346,12 +4934,6 @@ Built on the shoulders of:
|
||||
outputs=[prompt_vol_dd, dataset_info_md],
|
||||
)
|
||||
|
||||
# Wire hub repo → live validation
|
||||
hub_repo.change(
|
||||
fn=_validate_hub_repo,
|
||||
inputs=[hub_repo],
|
||||
outputs=[hub_warning_md],
|
||||
)
|
||||
|
||||
# Wire benchmark → Chat/A/B cross-tab dropdown updates
|
||||
bench_btn.click(
|
||||
@@ -4400,7 +4982,7 @@ Built on the shoulders of:
|
||||
# may not fire after generator teardown.
|
||||
obliterate_btn.click(
|
||||
fn=obliterate,
|
||||
inputs=[model_dd, method_dd, hub_repo, prompt_vol_dd, dataset_dd,
|
||||
inputs=[model_dd, method_dd, prompt_vol_dd, dataset_dd,
|
||||
custom_harmful_tb, custom_harmless_tb] + _adv_controls,
|
||||
outputs=[status_md, log_box, chat_status, session_model_dd, metrics_md, ab_session_model_dd],
|
||||
).then(
|
||||
|
||||
@@ -0,0 +1,710 @@
|
||||
"""Telemetry-driven adaptive defaults for OBLITERATUS.
|
||||
|
||||
Fetches community telemetry from the HuggingFace Hub dataset and analyzes
|
||||
historical runs to recommend the best abliteration method and hyperparameters
|
||||
for a given model architecture.
|
||||
|
||||
Architecture bucketing:
|
||||
Records are grouped by (arch_class, reasoning_class, param_bucket) where
|
||||
param_bucket is a coarse size tier (tiny/small/medium/large/frontier).
|
||||
Within each bucket, methods are ranked by composite score and the
|
||||
best-performing hyperparameter ranges are extracted.
|
||||
|
||||
The ``get_adaptive_recommendation()`` function returns an
|
||||
``AdaptiveRecommendation`` that the pipeline/UI can apply on top of
|
||||
(or instead of) the static research-grounded defaults in
|
||||
``architecture_profiles.py``.
|
||||
|
||||
Data flow:
|
||||
HF Hub (OBLITERATUS-TELEMETRY) ──► fetch_hub_records()
|
||||
│ │
|
||||
▼ ▼
|
||||
Local JSONL cache ──────────► build_knowledge_base()
|
||||
│
|
||||
▼
|
||||
get_adaptive_recommendation()
|
||||
│
|
||||
▼
|
||||
AdaptiveRecommendation
|
||||
(best method, overrides, confidence)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Cache config ──────────────────────────────────────────────────────────
|
||||
|
||||
_CACHE_TTL_S = 600 # 10 minutes — telemetry doesn't change that fast
|
||||
_cache: dict[str, Any] = {}
|
||||
_cache_ts: float = 0.0
|
||||
|
||||
# Minimum records per bucket to trust the recommendation
|
||||
_MIN_RECORDS_FOR_CONFIDENCE = 5
|
||||
_HIGH_CONFIDENCE_RECORDS = 20
|
||||
|
||||
|
||||
# ── Size bucketing ────────────────────────────────────────────────────────
|
||||
|
||||
def _param_bucket(total_params_b: float) -> str:
|
||||
"""Coarse size tier matching presets.py tiers."""
|
||||
if total_params_b <= 0.5:
|
||||
return "tiny"
|
||||
if total_params_b <= 4:
|
||||
return "small"
|
||||
if total_params_b <= 16:
|
||||
return "medium"
|
||||
if total_params_b <= 80:
|
||||
return "large"
|
||||
return "frontier"
|
||||
|
||||
|
||||
def _extract_arch_key(record: dict) -> tuple[str, str, str] | None:
|
||||
"""Extract (arch_class, reasoning_class, param_bucket) from a telemetry record.
|
||||
|
||||
Returns None if the record lacks enough information to classify.
|
||||
"""
|
||||
model = record.get("model", {})
|
||||
if isinstance(model, str):
|
||||
# Schema v1 — just model name, can't reliably bucket
|
||||
return None
|
||||
|
||||
arch_str = model.get("architecture", "")
|
||||
num_layers = model.get("num_layers", 0)
|
||||
hidden_size = model.get("hidden_size", 0)
|
||||
total_params = model.get("total_params", 0)
|
||||
|
||||
# Estimate params in billions
|
||||
if total_params > 0:
|
||||
params_b = total_params / 1e9
|
||||
elif num_layers > 0 and hidden_size > 0:
|
||||
# Rough estimate: 12 * hidden² * num_layers (transformer scaling)
|
||||
params_b = (12 * hidden_size**2 * num_layers) / 1e9
|
||||
else:
|
||||
return None
|
||||
|
||||
# Detect architecture class from the architecture string or model config
|
||||
arch_lower = arch_str.lower()
|
||||
moe_keywords = {"moe", "mixtral", "qwen2_moe", "qwen3_moe", "deepseek_v2",
|
||||
"deepseek_v3", "dbrx", "grok", "jamba", "arctic", "olmoe",
|
||||
"switch", "llama4"}
|
||||
is_moe = any(kw in arch_lower for kw in moe_keywords)
|
||||
|
||||
# Check method_config for per_expert_directions as MoE signal
|
||||
mc = record.get("method_config", {})
|
||||
if mc.get("per_expert_directions"):
|
||||
is_moe = True
|
||||
|
||||
if is_moe:
|
||||
arch_class = "large_moe" if params_b > 100 else "small_moe"
|
||||
else:
|
||||
arch_class = "dense"
|
||||
|
||||
# Detect reasoning from analysis insights or architecture name
|
||||
analysis = record.get("analysis_insights", {})
|
||||
reasoning_class = "standard"
|
||||
reasoning_keywords = {"reason", "think", "cot", "r1", "qwq", "o1", "o3"}
|
||||
if any(kw in arch_lower for kw in reasoning_keywords):
|
||||
reasoning_class = "reasoning"
|
||||
if analysis.get("cot_aware") or mc.get("cot_aware"):
|
||||
reasoning_class = "reasoning"
|
||||
|
||||
return (arch_class, reasoning_class, _param_bucket(params_b))
|
||||
|
||||
|
||||
# ── Composite scoring (same as tourney.py) ────────────────────────────────
|
||||
|
||||
def _composite_score(qm: dict[str, Any]) -> float:
|
||||
"""Score a run on [0, 1]. Higher is better."""
|
||||
rr = qm.get("refusal_rate")
|
||||
co = qm.get("coherence")
|
||||
kl = qm.get("kl_divergence")
|
||||
pp = qm.get("perplexity")
|
||||
|
||||
refusal_score = (1.0 - rr) if rr is not None else 0.0
|
||||
coherence_score = co if co is not None else 0.0
|
||||
kl_score = 1.0 / (1.0 + kl) if kl is not None else 0.5
|
||||
ppl_score = 1.0 / (1.0 + pp / 100.0) if pp is not None else 0.5
|
||||
|
||||
return (
|
||||
refusal_score * 0.4
|
||||
+ coherence_score * 0.3
|
||||
+ kl_score * 0.2
|
||||
+ ppl_score * 0.1
|
||||
)
|
||||
|
||||
|
||||
# ── Data structures ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class MethodStats:
|
||||
"""Aggregated statistics for one method within an architecture bucket."""
|
||||
|
||||
method: str
|
||||
n_runs: int = 0
|
||||
scores: list[float] = field(default_factory=list)
|
||||
refusal_rates: list[float] = field(default_factory=list)
|
||||
coherences: list[float] = field(default_factory=list)
|
||||
kl_divergences: list[float] = field(default_factory=list)
|
||||
perplexities: list[float] = field(default_factory=list)
|
||||
configs: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def mean_score(self) -> float:
|
||||
return statistics.mean(self.scores) if self.scores else 0.0
|
||||
|
||||
@property
|
||||
def best_score(self) -> float:
|
||||
return max(self.scores) if self.scores else 0.0
|
||||
|
||||
@property
|
||||
def median_score(self) -> float:
|
||||
return statistics.median(self.scores) if self.scores else 0.0
|
||||
|
||||
def best_config_ranges(self) -> dict[str, Any]:
|
||||
"""Extract the hyperparameter ranges from top-performing runs.
|
||||
|
||||
Takes the top 25% of runs by composite score and returns the median
|
||||
value for each numeric config key, or the mode for booleans.
|
||||
"""
|
||||
if not self.configs or not self.scores:
|
||||
return {}
|
||||
|
||||
# Pair scores with configs and take top 25%
|
||||
paired = sorted(zip(self.scores, self.configs), key=lambda x: x[0], reverse=True)
|
||||
top_n = max(1, len(paired) // 4)
|
||||
top_configs = [c for _, c in paired[:top_n]]
|
||||
|
||||
ranges: dict[str, Any] = {}
|
||||
all_keys = set()
|
||||
for c in top_configs:
|
||||
all_keys.update(c.keys())
|
||||
|
||||
for key in all_keys:
|
||||
values = [c[key] for c in top_configs if key in c and c[key] is not None]
|
||||
if not values:
|
||||
continue
|
||||
|
||||
if all(isinstance(v, bool) for v in values):
|
||||
# Mode for booleans
|
||||
true_count = sum(1 for v in values if v)
|
||||
ranges[key] = true_count > len(values) / 2
|
||||
elif all(isinstance(v, (int, float)) for v in values):
|
||||
# Median for numerics
|
||||
ranges[key] = statistics.median(values)
|
||||
# Round ints back to ints
|
||||
if all(isinstance(v, int) for v in values):
|
||||
ranges[key] = int(round(ranges[key]))
|
||||
# Skip strings and other types
|
||||
|
||||
return ranges
|
||||
|
||||
|
||||
@dataclass
|
||||
class BucketKnowledge:
|
||||
"""Everything we know about one architecture bucket from telemetry."""
|
||||
|
||||
arch_key: tuple[str, str, str] # (arch_class, reasoning_class, param_bucket)
|
||||
methods: dict[str, MethodStats] = field(default_factory=dict)
|
||||
total_runs: int = 0
|
||||
|
||||
@property
|
||||
def best_method(self) -> str | None:
|
||||
"""Method with highest mean composite score (min 3 runs)."""
|
||||
candidates = [
|
||||
(name, ms) for name, ms in self.methods.items()
|
||||
if ms.n_runs >= 3
|
||||
]
|
||||
if not candidates:
|
||||
# Fall back to any method with runs
|
||||
candidates = [(name, ms) for name, ms in self.methods.items() if ms.n_runs > 0]
|
||||
if not candidates:
|
||||
return None
|
||||
return max(candidates, key=lambda x: x[1].mean_score)[0]
|
||||
|
||||
@property
|
||||
def ranked_methods(self) -> list[tuple[str, MethodStats]]:
|
||||
"""All methods ranked by mean score, descending."""
|
||||
return sorted(
|
||||
self.methods.items(),
|
||||
key=lambda x: x[1].mean_score,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdaptiveRecommendation:
|
||||
"""A telemetry-driven recommendation for a specific model."""
|
||||
|
||||
# What we recommend
|
||||
recommended_method: str
|
||||
method_overrides: dict[str, Any]
|
||||
|
||||
# How confident we are
|
||||
confidence: str # "high", "medium", "low", "none"
|
||||
n_records: int # total records in bucket
|
||||
n_method_records: int # records for this specific method
|
||||
|
||||
# Context
|
||||
arch_key: tuple[str, str, str]
|
||||
bucket_label: str # human-readable e.g. "Dense Standard Medium"
|
||||
method_ranking: list[tuple[str, float]] # [(method, mean_score), ...]
|
||||
|
||||
# Best metrics seen in this bucket
|
||||
best_refusal_rate: float | None = None
|
||||
best_coherence: float | None = None
|
||||
|
||||
# Explanation
|
||||
reason: str = ""
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"recommended_method": self.recommended_method,
|
||||
"method_overrides": self.method_overrides,
|
||||
"confidence": self.confidence,
|
||||
"n_records": self.n_records,
|
||||
"n_method_records": self.n_method_records,
|
||||
"arch_key": list(self.arch_key),
|
||||
"bucket_label": self.bucket_label,
|
||||
"method_ranking": self.method_ranking,
|
||||
"best_refusal_rate": self.best_refusal_rate,
|
||||
"best_coherence": self.best_coherence,
|
||||
"reason": self.reason,
|
||||
}
|
||||
|
||||
|
||||
# ── Knowledge base construction ──────────────────────────────────────────
|
||||
|
||||
|
||||
def build_knowledge_base(
|
||||
records: list[dict[str, Any]] | None = None,
|
||||
) -> dict[tuple[str, str, str], BucketKnowledge]:
|
||||
"""Build per-bucket knowledge from telemetry records.
|
||||
|
||||
If *records* is None, fetches from local + Hub automatically.
|
||||
"""
|
||||
if records is None:
|
||||
records = _fetch_all_records()
|
||||
|
||||
buckets: dict[tuple[str, str, str], BucketKnowledge] = {}
|
||||
|
||||
for record in records:
|
||||
# Skip errored runs
|
||||
if record.get("error"):
|
||||
continue
|
||||
|
||||
arch_key = _extract_arch_key(record)
|
||||
if arch_key is None:
|
||||
continue
|
||||
|
||||
method = record.get("method", "")
|
||||
if not method:
|
||||
continue
|
||||
|
||||
qm = record.get("quality_metrics", {})
|
||||
if not qm:
|
||||
continue
|
||||
|
||||
score = _composite_score(qm)
|
||||
|
||||
if arch_key not in buckets:
|
||||
buckets[arch_key] = BucketKnowledge(arch_key=arch_key)
|
||||
|
||||
bucket = buckets[arch_key]
|
||||
bucket.total_runs += 1
|
||||
|
||||
if method not in bucket.methods:
|
||||
bucket.methods[method] = MethodStats(method=method)
|
||||
|
||||
ms = bucket.methods[method]
|
||||
ms.n_runs += 1
|
||||
ms.scores.append(score)
|
||||
|
||||
rr = qm.get("refusal_rate")
|
||||
if rr is not None:
|
||||
ms.refusal_rates.append(rr)
|
||||
co = qm.get("coherence")
|
||||
if co is not None:
|
||||
ms.coherences.append(co)
|
||||
kl = qm.get("kl_divergence")
|
||||
if kl is not None:
|
||||
ms.kl_divergences.append(kl)
|
||||
pp = qm.get("perplexity")
|
||||
if pp is not None:
|
||||
ms.perplexities.append(pp)
|
||||
|
||||
mc = record.get("method_config", {})
|
||||
if mc:
|
||||
ms.configs.append(mc)
|
||||
|
||||
return buckets
|
||||
|
||||
|
||||
def _fetch_all_records() -> list[dict[str, Any]]:
|
||||
"""Fetch telemetry from local file + Hub, with caching."""
|
||||
global _cache, _cache_ts
|
||||
|
||||
now = time.time()
|
||||
if _cache.get("records") is not None and (now - _cache_ts) < _CACHE_TTL_S:
|
||||
return _cache["records"]
|
||||
|
||||
records: list[dict[str, Any]] = []
|
||||
|
||||
# Local records
|
||||
try:
|
||||
from obliteratus.telemetry import read_telemetry
|
||||
records.extend(read_telemetry())
|
||||
except Exception as e:
|
||||
logger.debug("Failed to read local telemetry: %s", e)
|
||||
|
||||
# Hub records
|
||||
try:
|
||||
from obliteratus.telemetry import fetch_hub_records
|
||||
hub = fetch_hub_records()
|
||||
records.extend(hub)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to fetch Hub telemetry: %s", e)
|
||||
|
||||
# Deduplicate by (session_id, timestamp)
|
||||
seen: set[tuple[str, str]] = set()
|
||||
deduped = []
|
||||
for r in records:
|
||||
key = (r.get("session_id", ""), r.get("timestamp", ""))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(r)
|
||||
|
||||
_cache["records"] = deduped
|
||||
_cache_ts = now
|
||||
return deduped
|
||||
|
||||
|
||||
# ── Recommendation engine ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_adaptive_recommendation(
|
||||
arch_class: str,
|
||||
reasoning_class: str,
|
||||
total_params_b: float,
|
||||
model_name: str = "",
|
||||
knowledge: dict[tuple[str, str, str], BucketKnowledge] | None = None,
|
||||
) -> AdaptiveRecommendation:
|
||||
"""Get a telemetry-based recommendation for the given architecture.
|
||||
|
||||
Looks up the closest bucket in the knowledge base and returns the
|
||||
best-performing method + hyperparameter overrides.
|
||||
|
||||
Falls through to broader buckets if the exact match has too few records:
|
||||
1. Exact match: (arch_class, reasoning_class, param_bucket)
|
||||
2. Size-agnostic: (arch_class, reasoning_class, "*")
|
||||
3. Arch-only: (arch_class, "*", "*")
|
||||
|
||||
Args:
|
||||
arch_class: "dense", "small_moe", or "large_moe"
|
||||
reasoning_class: "standard" or "reasoning"
|
||||
total_params_b: Total params in billions
|
||||
model_name: Optional, for model-specific matching
|
||||
knowledge: Pre-built knowledge base (fetches if None)
|
||||
"""
|
||||
if knowledge is None:
|
||||
knowledge = build_knowledge_base()
|
||||
|
||||
param_bucket = _param_bucket(total_params_b)
|
||||
bucket_label = f"{arch_class.replace('_', ' ').title()} {reasoning_class.title()} {param_bucket.title()}"
|
||||
|
||||
# Try exact match first, then broaden
|
||||
candidates = [
|
||||
(arch_class, reasoning_class, param_bucket),
|
||||
]
|
||||
|
||||
# Also check model-specific records (exact model name match)
|
||||
# This is for the future when we have enough data per-model
|
||||
model_short = model_name.split("/")[-1].lower() if model_name else ""
|
||||
|
||||
bucket = None
|
||||
used_key = None
|
||||
for key in candidates:
|
||||
if key in knowledge and knowledge[key].total_runs >= _MIN_RECORDS_FOR_CONFIDENCE:
|
||||
bucket = knowledge[key]
|
||||
used_key = key
|
||||
break
|
||||
|
||||
# Fall back: merge all buckets that share (arch_class, reasoning_class)
|
||||
if bucket is None:
|
||||
merged = BucketKnowledge(arch_key=(arch_class, reasoning_class, "*"))
|
||||
for key, bkt in knowledge.items():
|
||||
if key[0] == arch_class and key[1] == reasoning_class:
|
||||
for method_name, ms in bkt.methods.items():
|
||||
if method_name not in merged.methods:
|
||||
merged.methods[method_name] = MethodStats(method=method_name)
|
||||
target = merged.methods[method_name]
|
||||
target.n_runs += ms.n_runs
|
||||
target.scores.extend(ms.scores)
|
||||
target.refusal_rates.extend(ms.refusal_rates)
|
||||
target.coherences.extend(ms.coherences)
|
||||
target.kl_divergences.extend(ms.kl_divergences)
|
||||
target.perplexities.extend(ms.perplexities)
|
||||
target.configs.extend(ms.configs)
|
||||
merged.total_runs += bkt.total_runs
|
||||
if merged.total_runs >= _MIN_RECORDS_FOR_CONFIDENCE:
|
||||
bucket = merged
|
||||
used_key = merged.arch_key
|
||||
bucket_label = f"{arch_class.replace('_', ' ').title()} {reasoning_class.title()} (all sizes)"
|
||||
|
||||
# Last resort: merge all buckets that share arch_class
|
||||
if bucket is None:
|
||||
merged = BucketKnowledge(arch_key=(arch_class, "*", "*"))
|
||||
for key, bkt in knowledge.items():
|
||||
if key[0] == arch_class:
|
||||
for method_name, ms in bkt.methods.items():
|
||||
if method_name not in merged.methods:
|
||||
merged.methods[method_name] = MethodStats(method=method_name)
|
||||
target = merged.methods[method_name]
|
||||
target.n_runs += ms.n_runs
|
||||
target.scores.extend(ms.scores)
|
||||
target.refusal_rates.extend(ms.refusal_rates)
|
||||
target.coherences.extend(ms.coherences)
|
||||
target.kl_divergences.extend(ms.kl_divergences)
|
||||
target.perplexities.extend(ms.perplexities)
|
||||
target.configs.extend(ms.configs)
|
||||
merged.total_runs += bkt.total_runs
|
||||
if merged.total_runs > 0:
|
||||
bucket = merged
|
||||
used_key = merged.arch_key
|
||||
bucket_label = f"{arch_class.replace('_', ' ').title()} (all)"
|
||||
|
||||
# No data at all
|
||||
if bucket is None or not bucket.methods:
|
||||
return AdaptiveRecommendation(
|
||||
recommended_method="",
|
||||
method_overrides={},
|
||||
confidence="none",
|
||||
n_records=0,
|
||||
n_method_records=0,
|
||||
arch_key=(arch_class, reasoning_class, param_bucket),
|
||||
bucket_label=bucket_label,
|
||||
method_ranking=[],
|
||||
reason="No telemetry data available for this architecture.",
|
||||
)
|
||||
|
||||
# Get best method
|
||||
best_method = bucket.best_method
|
||||
if not best_method:
|
||||
return AdaptiveRecommendation(
|
||||
recommended_method="",
|
||||
method_overrides={},
|
||||
confidence="none",
|
||||
n_records=bucket.total_runs,
|
||||
n_method_records=0,
|
||||
arch_key=used_key or (arch_class, reasoning_class, param_bucket),
|
||||
bucket_label=bucket_label,
|
||||
method_ranking=[],
|
||||
reason="Telemetry records found but no method has enough runs.",
|
||||
)
|
||||
|
||||
ms = bucket.methods[best_method]
|
||||
|
||||
# Extract best hyperparams from top runs
|
||||
overrides = ms.best_config_ranges()
|
||||
|
||||
# Confidence level
|
||||
if ms.n_runs >= _HIGH_CONFIDENCE_RECORDS:
|
||||
confidence = "high"
|
||||
elif ms.n_runs >= _MIN_RECORDS_FOR_CONFIDENCE:
|
||||
confidence = "medium"
|
||||
else:
|
||||
confidence = "low"
|
||||
|
||||
# Method ranking
|
||||
ranking = [
|
||||
(name, stats.mean_score)
|
||||
for name, stats in bucket.ranked_methods
|
||||
]
|
||||
|
||||
# Best metrics seen
|
||||
best_rr = min(ms.refusal_rates) if ms.refusal_rates else None
|
||||
best_co = max(ms.coherences) if ms.coherences else None
|
||||
|
||||
# Build explanation
|
||||
runner_up = ranking[1] if len(ranking) > 1 else None
|
||||
reason_parts = [
|
||||
f"Based on {bucket.total_runs} community runs for {bucket_label}.",
|
||||
f"`{best_method}` achieves a mean composite score of {ms.mean_score:.4f} "
|
||||
f"across {ms.n_runs} runs.",
|
||||
]
|
||||
if runner_up:
|
||||
reason_parts.append(
|
||||
f"Runner-up: `{runner_up[0]}` ({runner_up[1]:.4f})."
|
||||
)
|
||||
if best_rr is not None:
|
||||
reason_parts.append(f"Best refusal rate seen: {best_rr:.1%}.")
|
||||
if overrides:
|
||||
override_strs = [f"{k}={v}" for k, v in sorted(overrides.items())]
|
||||
reason_parts.append(f"Optimal hyperparams from top runs: {', '.join(override_strs[:6])}")
|
||||
|
||||
return AdaptiveRecommendation(
|
||||
recommended_method=best_method,
|
||||
method_overrides=overrides,
|
||||
confidence=confidence,
|
||||
n_records=bucket.total_runs,
|
||||
n_method_records=ms.n_runs,
|
||||
arch_key=used_key or (arch_class, reasoning_class, param_bucket),
|
||||
bucket_label=bucket_label,
|
||||
method_ranking=ranking,
|
||||
best_refusal_rate=best_rr,
|
||||
best_coherence=best_co,
|
||||
reason=" ".join(reason_parts),
|
||||
)
|
||||
|
||||
|
||||
# ── Cross-architecture insights ──────────────────────────────────────────
|
||||
|
||||
|
||||
def get_global_insights(
|
||||
knowledge: dict[tuple[str, str, str], BucketKnowledge] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Compute cross-architecture insights from all telemetry.
|
||||
|
||||
Returns a summary dict with:
|
||||
- overall_best_methods: top methods across all architectures
|
||||
- architecture_breakdown: per-bucket summaries
|
||||
- total_records: total telemetry records analyzed
|
||||
- hyperparameter_trends: keys that consistently appear in top configs
|
||||
"""
|
||||
if knowledge is None:
|
||||
knowledge = build_knowledge_base()
|
||||
|
||||
total_records = sum(b.total_runs for b in knowledge.values())
|
||||
|
||||
# Global method scores (weighted by bucket size)
|
||||
global_method_scores: dict[str, list[float]] = {}
|
||||
for bucket in knowledge.values():
|
||||
for name, ms in bucket.methods.items():
|
||||
if name not in global_method_scores:
|
||||
global_method_scores[name] = []
|
||||
global_method_scores[name].extend(ms.scores)
|
||||
|
||||
overall_ranking = sorted(
|
||||
[
|
||||
(name, statistics.mean(scores), len(scores))
|
||||
for name, scores in global_method_scores.items()
|
||||
if scores
|
||||
],
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# Per-bucket summaries
|
||||
arch_breakdown = {}
|
||||
for key, bucket in sorted(knowledge.items()):
|
||||
label = f"{key[0]} / {key[1]} / {key[2]}"
|
||||
best = bucket.best_method
|
||||
arch_breakdown[label] = {
|
||||
"total_runs": bucket.total_runs,
|
||||
"best_method": best,
|
||||
"best_score": bucket.methods[best].mean_score if best and best in bucket.methods else 0,
|
||||
"n_methods_tested": len(bucket.methods),
|
||||
}
|
||||
|
||||
# Hyperparameter trends across top runs
|
||||
all_top_configs: list[dict] = []
|
||||
for bucket in knowledge.values():
|
||||
for ms in bucket.methods.values():
|
||||
if ms.configs and ms.scores:
|
||||
paired = sorted(zip(ms.scores, ms.configs), key=lambda x: x[0], reverse=True)
|
||||
top_n = max(1, len(paired) // 4)
|
||||
all_top_configs.extend(c for _, c in paired[:top_n])
|
||||
|
||||
hp_trends: dict[str, Any] = {}
|
||||
if all_top_configs:
|
||||
all_keys = set()
|
||||
for c in all_top_configs:
|
||||
all_keys.update(c.keys())
|
||||
for key in sorted(all_keys):
|
||||
values = [c[key] for c in all_top_configs if key in c and c[key] is not None]
|
||||
if not values:
|
||||
continue
|
||||
if all(isinstance(v, bool) for v in values):
|
||||
true_pct = sum(1 for v in values if v) / len(values)
|
||||
hp_trends[key] = {"type": "bool", "true_pct": round(true_pct, 2), "n": len(values)}
|
||||
elif all(isinstance(v, (int, float)) for v in values):
|
||||
hp_trends[key] = {
|
||||
"type": "numeric",
|
||||
"median": round(statistics.median(values), 4),
|
||||
"mean": round(statistics.mean(values), 4),
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
"n": len(values),
|
||||
}
|
||||
|
||||
return {
|
||||
"total_records": total_records,
|
||||
"overall_best_methods": [
|
||||
{"method": name, "mean_score": round(score, 4), "n_runs": n}
|
||||
for name, score, n in overall_ranking
|
||||
],
|
||||
"architecture_breakdown": arch_breakdown,
|
||||
"hyperparameter_trends": hp_trends,
|
||||
}
|
||||
|
||||
|
||||
# ── Format helpers ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def format_recommendation(rec: AdaptiveRecommendation) -> str:
|
||||
"""Format a recommendation as a human-readable markdown string."""
|
||||
if rec.confidence == "none":
|
||||
return (
|
||||
f"**No telemetry data** for {rec.bucket_label}.\n\n"
|
||||
"Using research-grounded defaults from `architecture_profiles.py`.\n"
|
||||
"Run some abliterations and the adaptive system will learn!"
|
||||
)
|
||||
|
||||
confidence_emoji = {"high": "HIGH", "medium": "MEDIUM", "low": "LOW"}
|
||||
conf = confidence_emoji.get(rec.confidence, rec.confidence.upper())
|
||||
|
||||
lines = [
|
||||
f"### Adaptive Recommendation [{conf} confidence]",
|
||||
f"**Architecture bucket:** {rec.bucket_label}",
|
||||
f"**Based on:** {rec.n_records} community runs",
|
||||
"",
|
||||
f"**Recommended method:** `{rec.recommended_method}` "
|
||||
f"(score: {rec.method_ranking[0][1]:.4f}, {rec.n_method_records} runs)",
|
||||
"",
|
||||
]
|
||||
|
||||
if len(rec.method_ranking) > 1:
|
||||
lines.append("**Method ranking:**")
|
||||
lines.append("| Rank | Method | Mean Score | Runs |")
|
||||
lines.append("|------|--------|------------|------|")
|
||||
for i, (name, score) in enumerate(rec.method_ranking[:8], 1):
|
||||
ms_runs = 0
|
||||
# Get run count from the knowledge (not stored directly, but we have n_method_records for winner)
|
||||
lines.append(f"| {i} | `{name}` | {score:.4f} | — |")
|
||||
lines.append("")
|
||||
|
||||
if rec.method_overrides:
|
||||
lines.append("**Optimal hyperparameters** (from top 25% of runs):")
|
||||
for k, v in sorted(rec.method_overrides.items()):
|
||||
lines.append(f" - `{k}`: {v}")
|
||||
lines.append("")
|
||||
|
||||
if rec.best_refusal_rate is not None:
|
||||
lines.append(f"**Best refusal rate achieved:** {rec.best_refusal_rate:.1%}")
|
||||
if rec.best_coherence is not None:
|
||||
lines.append(f"**Best coherence achieved:** {rec.best_coherence:.3f}")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"*{rec.reason}*")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -582,3 +582,53 @@ def apply_profile_to_method_config(
|
||||
# are valid pipeline parameters needed by the UI auto-detect path.
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def enhance_profile_with_telemetry(
|
||||
profile: ArchitectureProfile,
|
||||
) -> tuple[ArchitectureProfile, "AdaptiveRecommendation | None"]:
|
||||
"""Optionally enhance a profile with telemetry-driven adaptive defaults.
|
||||
|
||||
Queries the community telemetry dataset and, if sufficient data exists for
|
||||
this architecture bucket, overlays the empirically-best method and
|
||||
hyperparameters onto the profile's research-grounded defaults.
|
||||
|
||||
Research defaults remain the fallback when telemetry data is sparse.
|
||||
|
||||
Returns:
|
||||
(profile, recommendation) — recommendation is None if no telemetry data.
|
||||
"""
|
||||
try:
|
||||
from obliteratus.adaptive_defaults import get_adaptive_recommendation
|
||||
except ImportError:
|
||||
return profile, None
|
||||
|
||||
try:
|
||||
rec = get_adaptive_recommendation(
|
||||
arch_class=profile.arch_class.value,
|
||||
reasoning_class=profile.reasoning_class.value,
|
||||
total_params_b=profile.total_params_b,
|
||||
model_name=profile.model_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Adaptive defaults failed: %s", e)
|
||||
return profile, None
|
||||
|
||||
if rec.confidence == "none":
|
||||
return profile, rec
|
||||
|
||||
# Only override research defaults if we have medium+ confidence
|
||||
if rec.confidence in ("medium", "high"):
|
||||
if rec.recommended_method:
|
||||
profile.recommended_method = rec.recommended_method
|
||||
profile.profile_description += (
|
||||
f"\n\n**Telemetry override ({rec.confidence} confidence):** "
|
||||
f"Community data ({rec.n_records} runs) shows `{rec.recommended_method}` "
|
||||
f"performs best for this architecture."
|
||||
)
|
||||
if rec.method_overrides:
|
||||
# Merge telemetry overrides on top of research defaults
|
||||
profile.method_overrides.update(rec.method_overrides)
|
||||
|
||||
return profile, rec
|
||||
|
||||
|
||||
@@ -155,6 +155,40 @@ def main(argv: list[str] | None = None):
|
||||
help="Directory containing contribution JSON files",
|
||||
)
|
||||
|
||||
# --- tourney ---
|
||||
tourney_parser = subparsers.add_parser(
|
||||
"tourney",
|
||||
help="March Madness tournament — pit all methods against each other, push winner to Hub",
|
||||
)
|
||||
tourney_parser.add_argument("model", type=str, help="HuggingFace model name/path")
|
||||
tourney_parser.add_argument("--hub-org", type=str, default=None, help="HF org to push winner (e.g. my-org)")
|
||||
tourney_parser.add_argument("--hub-repo", type=str, default=None, help="Full HF repo ID (overrides --hub-org)")
|
||||
tourney_parser.add_argument("--device", type=str, default="auto")
|
||||
tourney_parser.add_argument("--dtype", type=str, default="float16")
|
||||
tourney_parser.add_argument("--dataset", type=str, default="builtin", help="Dataset source (default: builtin)")
|
||||
tourney_parser.add_argument(
|
||||
"--quantization", type=str, default=None, choices=["4bit", "8bit"],
|
||||
help="Load model with quantization",
|
||||
)
|
||||
tourney_parser.add_argument("--output-dir", type=str, default="/tmp/obliteratus_tourney")
|
||||
tourney_parser.add_argument(
|
||||
"--methods", type=str, nargs="+", default=None,
|
||||
help="Override: only run these methods (space-separated)",
|
||||
)
|
||||
|
||||
# --- recommend ---
|
||||
recommend_parser = subparsers.add_parser(
|
||||
"recommend",
|
||||
help="Show telemetry-driven best method + hyperparams for a model",
|
||||
)
|
||||
recommend_parser.add_argument("model", type=str, help="HuggingFace model name/path")
|
||||
recommend_parser.add_argument("--device", type=str, default="cpu")
|
||||
recommend_parser.add_argument("--dtype", type=str, default="float32")
|
||||
recommend_parser.add_argument(
|
||||
"--insights", action="store_true", default=False,
|
||||
help="Also show global cross-architecture insights",
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
if args.command == "run":
|
||||
@@ -175,6 +209,10 @@ def main(argv: list[str] | None = None):
|
||||
_cmd_aggregate(args)
|
||||
elif args.command == "ui":
|
||||
_cmd_ui(args)
|
||||
elif args.command == "recommend":
|
||||
_cmd_recommend(args)
|
||||
elif args.command == "tourney":
|
||||
_cmd_tourney(args)
|
||||
elif args.command in ("obliterate", "abliterate"):
|
||||
_cmd_abliterate(args)
|
||||
|
||||
@@ -371,6 +409,112 @@ def _cmd_aggregate(args):
|
||||
console.print(table)
|
||||
|
||||
|
||||
def _cmd_recommend(args):
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
from obliteratus.architecture_profiles import detect_architecture, enhance_profile_with_telemetry
|
||||
from obliteratus.adaptive_defaults import format_recommendation, get_global_insights
|
||||
|
||||
model_name = args.model
|
||||
console.print(f"\nAnalyzing [bold]{model_name}[/]...")
|
||||
|
||||
# Detect architecture
|
||||
try:
|
||||
from transformers import AutoConfig
|
||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
num_layers = getattr(config, "num_hidden_layers", 0)
|
||||
hidden_size = getattr(config, "hidden_size", 0)
|
||||
except Exception:
|
||||
config = None
|
||||
num_layers = 0
|
||||
hidden_size = 0
|
||||
|
||||
profile = detect_architecture(model_name, config, num_layers, hidden_size)
|
||||
profile, rec = enhance_profile_with_telemetry(profile)
|
||||
|
||||
console.print(Panel(
|
||||
f"[bold]{profile.profile_label}[/]\n"
|
||||
f"Architecture: {profile.arch_class.value} | Reasoning: {profile.reasoning_class.value}\n"
|
||||
f"Params: ~{profile.total_params_b:.1f}B | Layers: {profile.num_layers} | "
|
||||
f"Hidden: {profile.hidden_size}",
|
||||
title="Architecture Profile",
|
||||
border_style="cyan",
|
||||
))
|
||||
|
||||
if rec:
|
||||
md = format_recommendation(rec)
|
||||
console.print(Markdown(md))
|
||||
else:
|
||||
console.print("\n[yellow]Could not fetch telemetry — using research-grounded defaults.[/]")
|
||||
|
||||
console.print(f"\n[bold green]Research default method:[/] {profile.recommended_method}")
|
||||
if profile.method_overrides:
|
||||
console.print("[bold green]Overrides:[/]")
|
||||
for k, v in sorted(profile.method_overrides.items()):
|
||||
console.print(f" {k}: {v}")
|
||||
|
||||
if args.insights:
|
||||
console.print("\n")
|
||||
console.rule("[bold magenta]Global Telemetry Insights")
|
||||
insights = get_global_insights()
|
||||
console.print(f"Total records analyzed: {insights['total_records']}")
|
||||
if insights["overall_best_methods"]:
|
||||
console.print("\n[bold]Overall method ranking (all architectures):[/]")
|
||||
for entry in insights["overall_best_methods"][:10]:
|
||||
console.print(
|
||||
f" {entry['method']}: {entry['mean_score']:.4f} "
|
||||
f"({entry['n_runs']} runs)"
|
||||
)
|
||||
if insights["architecture_breakdown"]:
|
||||
console.print("\n[bold]Per-architecture breakdown:[/]")
|
||||
for label, info in insights["architecture_breakdown"].items():
|
||||
console.print(
|
||||
f" {label}: best={info['best_method']} "
|
||||
f"({info['best_score']:.4f}), "
|
||||
f"{info['n_methods_tested']} methods tested, "
|
||||
f"{info['total_runs']} runs"
|
||||
)
|
||||
|
||||
|
||||
def _cmd_tourney(args):
|
||||
from obliteratus.tourney import TourneyRunner, render_bracket
|
||||
|
||||
def on_log(msg):
|
||||
console.print(msg)
|
||||
|
||||
def on_round(rnd):
|
||||
console.print()
|
||||
console.rule(f"[bold green]Round {rnd.round_num} complete — "
|
||||
f"{len(rnd.advanced_to)} advance, {len(rnd.eliminated)} eliminated")
|
||||
|
||||
runner = TourneyRunner(
|
||||
model_name=args.model,
|
||||
hub_org=args.hub_org,
|
||||
hub_repo=args.hub_repo,
|
||||
device=args.device,
|
||||
dtype=args.dtype,
|
||||
dataset_key=args.dataset,
|
||||
quantization=args.quantization,
|
||||
methods=args.methods,
|
||||
output_dir=args.output_dir,
|
||||
on_log=on_log,
|
||||
on_round=on_round,
|
||||
)
|
||||
|
||||
result = runner.run()
|
||||
|
||||
if result.winner:
|
||||
console.print()
|
||||
console.rule("[bold magenta]TOURNAMENT CHAMPION", style="magenta")
|
||||
console.print(f" [bold]{result.winner.method}[/] — score {result.winner.score:.4f}")
|
||||
console.print(f" Refusal rate: {result.winner.metrics.get('refusal_rate', '?')}")
|
||||
console.print(f" Coherence: {result.winner.metrics.get('coherence', '?')}")
|
||||
if result.hub_repo:
|
||||
console.print(f" Pushed to: [link=https://huggingface.co/{result.hub_repo}]{result.hub_repo}[/link]")
|
||||
console.print(f"\n Full bracket: {args.output_dir}/tourney_bracket.md")
|
||||
|
||||
|
||||
def _cmd_abliterate(args):
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
|
||||
@@ -26,6 +26,7 @@ import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from obliteratus import device as dev
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -261,8 +262,7 @@ class BenchmarkRunner:
|
||||
("math_reasoning", self.run_math_reasoning_probe)]:
|
||||
results[name] = fn()
|
||||
# Free KV caches between probes to prevent OOM on tight GPUs
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
dev.empty_cache()
|
||||
return results
|
||||
|
||||
def _answer_mcq(self, question: str, choices: list[str]) -> int:
|
||||
|
||||
@@ -32,6 +32,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from obliteratus import device as dev
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
@@ -363,8 +364,7 @@ def unload_harmbench_classifier() -> None:
|
||||
model, tokenizer = _HARMBENCH_CLASSIFIER
|
||||
del model, tokenizer
|
||||
_HARMBENCH_CLASSIFIER = None
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
dev.empty_cache()
|
||||
logger.info("HarmBench classifier unloaded")
|
||||
|
||||
|
||||
@@ -432,8 +432,7 @@ def harmbench_asr(
|
||||
|
||||
# Free memory between batches
|
||||
del inputs, outputs
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
dev.empty_cache()
|
||||
|
||||
n_successful = sum(per_item)
|
||||
return {
|
||||
@@ -536,8 +535,7 @@ def first_token_kl_on_prompts(
|
||||
kl_values.extend(kl.cpu().tolist())
|
||||
|
||||
del inputs_orig, inputs_mod, logits_orig, logits_mod, first_orig, first_mod
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
dev.empty_cache()
|
||||
|
||||
mean_kl = statistics.mean(kl_values) if kl_values else 0.0
|
||||
std_kl = statistics.stdev(kl_values) if len(kl_values) > 1 else 0.0
|
||||
@@ -1098,8 +1096,8 @@ def run_full_heretic_eval(
|
||||
completions.append("")
|
||||
|
||||
del inputs
|
||||
if i % 25 == 0 and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
if i % 25 == 0:
|
||||
dev.empty_cache()
|
||||
|
||||
log(f"Generated {len(completions)} completions")
|
||||
|
||||
|
||||
@@ -181,6 +181,8 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
||||
on_log: Callable[[str], None] | None = None,
|
||||
# Base pipeline kwargs forwarded to AbliterationPipeline
|
||||
push_to_hub: str | None = None,
|
||||
hub_token: str | None = None,
|
||||
hub_community_org: str | None = None,
|
||||
quantization: str | None = None,
|
||||
# Analysis configuration
|
||||
run_cone_analysis: bool = True,
|
||||
@@ -212,6 +214,8 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
||||
on_stage=on_stage,
|
||||
on_log=on_log,
|
||||
push_to_hub=push_to_hub,
|
||||
hub_token=hub_token,
|
||||
hub_community_org=hub_community_org,
|
||||
quantization=quantization,
|
||||
# Set informed defaults
|
||||
norm_preserve=True,
|
||||
|
||||
@@ -21,9 +21,10 @@ console = Console()
|
||||
def _detect_compute_tier() -> str:
|
||||
"""Auto-detect the best compute tier based on available hardware."""
|
||||
try:
|
||||
import torch
|
||||
from obliteratus import device as dev
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if dev.is_cuda():
|
||||
import torch
|
||||
vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
||||
if vram_gb >= 20:
|
||||
return "large"
|
||||
@@ -31,8 +32,13 @@ def _detect_compute_tier() -> str:
|
||||
return "medium"
|
||||
else:
|
||||
return "small"
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return "small" # Apple Silicon — conservative estimate
|
||||
elif dev.is_mps():
|
||||
# Apple Silicon with unified memory — estimate from system RAM
|
||||
mem = dev.get_memory_info()
|
||||
if mem.total_gb >= 24:
|
||||
return "medium" # M1 Pro/Max/Ultra, M2 Pro/Max/Ultra, M3 Pro/Max
|
||||
else:
|
||||
return "small" # M1/M2/M3 base (8-16 GB)
|
||||
except ImportError:
|
||||
pass
|
||||
return "tiny" # CPU only
|
||||
@@ -237,12 +243,11 @@ def run_interactive():
|
||||
dtype = model_preset.recommended_dtype
|
||||
quantization = None
|
||||
try:
|
||||
import torch
|
||||
from obliteratus import device as _dev
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = "auto"
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
resolved = _dev.get_device()
|
||||
if resolved != "cpu":
|
||||
device = resolved if resolved == "mps" else "auto"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Optional
|
||||
import sys as _sys
|
||||
|
||||
import torch
|
||||
from obliteratus import device as dev
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -381,24 +382,8 @@ def _estimate_model_memory_gb(config: AutoConfig, dtype: torch.dtype) -> float:
|
||||
|
||||
|
||||
def _available_gpu_memory_gb() -> float:
|
||||
"""Return free GPU memory across all CUDA devices, in GB.
|
||||
|
||||
Uses torch.cuda.mem_get_info which reports actual free memory,
|
||||
not total capacity. Falls back to total_memory if mem_get_info
|
||||
is unavailable (PyTorch < 1.10).
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
return 0.0
|
||||
total_free = 0.0
|
||||
for i in range(torch.cuda.device_count()):
|
||||
try:
|
||||
free, _ = torch.cuda.mem_get_info(i)
|
||||
total_free += free / (1024 ** 3)
|
||||
except AttributeError:
|
||||
# Fallback for old PyTorch without mem_get_info
|
||||
props = torch.cuda.get_device_properties(i)
|
||||
total_free += props.total_memory / (1024 ** 3)
|
||||
return total_free
|
||||
"""Return free accelerator memory in GB (CUDA, MPS, or 0 for CPU)."""
|
||||
return dev.get_total_free_gb()
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
@@ -515,34 +500,54 @@ def load_model(
|
||||
load_kwargs.pop("torch_dtype", None)
|
||||
load_kwargs["device_map"] = "auto"
|
||||
elif quantization in ("4bit", "8bit"):
|
||||
try:
|
||||
import bitsandbytes # noqa: F401
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
f"Quantization '{quantization}' requires bitsandbytes: "
|
||||
f"pip install -U bitsandbytes>=0.46.1"
|
||||
)
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
# Enable fp32 CPU offload so that models too large to fit entirely on
|
||||
# GPU (even quantized) can spill to CPU without crashing bitsandbytes.
|
||||
# This is critical for frontier MoE models (GLM-5 744B, DeepSeek-V3 685B,
|
||||
# Mistral Large 3 675B, etc.) on single-GPU setups.
|
||||
if quantization == "4bit":
|
||||
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch_dtype,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
llm_int8_enable_fp32_cpu_offload=True,
|
||||
# BitsAndBytes only works on NVIDIA CUDA GPUs.
|
||||
resolved_device = dev.get_device(device)
|
||||
if not dev.supports_bitsandbytes(resolved_device):
|
||||
logger.warning(
|
||||
"BitsAndBytes quantization is not supported on %s. "
|
||||
"Loading in %s instead.",
|
||||
resolved_device, dtype,
|
||||
)
|
||||
# On MPS, load normally to the device; on CPU, fall through.
|
||||
if resolved_device == "mps":
|
||||
device = "mps"
|
||||
# Don't set quantization_config — fall through to normal loading.
|
||||
else:
|
||||
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=True,
|
||||
)
|
||||
load_kwargs["device_map"] = "auto"
|
||||
elif device == "auto":
|
||||
load_kwargs["device_map"] = "auto"
|
||||
try:
|
||||
import bitsandbytes # noqa: F401
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
f"Quantization '{quantization}' requires bitsandbytes: "
|
||||
f"pip install -U bitsandbytes>=0.46.1"
|
||||
)
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
# Enable fp32 CPU offload so that models too large to fit entirely on
|
||||
# GPU (even quantized) can spill to CPU without crashing bitsandbytes.
|
||||
# This is critical for frontier MoE models (GLM-5 744B, DeepSeek-V3 685B,
|
||||
# Mistral Large 3 675B, etc.) on single-GPU setups.
|
||||
if quantization == "4bit":
|
||||
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch_dtype,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
llm_int8_enable_fp32_cpu_offload=True,
|
||||
)
|
||||
else:
|
||||
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=True,
|
||||
)
|
||||
load_kwargs["device_map"] = "auto"
|
||||
|
||||
# device_map="auto" is only reliable on CUDA (accelerate doesn't support MPS).
|
||||
if "device_map" not in load_kwargs and device == "auto":
|
||||
resolved_device = dev.get_device(device)
|
||||
if dev.supports_device_map_auto(resolved_device):
|
||||
load_kwargs["device_map"] = "auto"
|
||||
else:
|
||||
# MPS / CPU: load to CPU first, then .to(device) after loading.
|
||||
pass
|
||||
|
||||
# Offload support: provide a folder for disk offloading when GPU memory is insufficient
|
||||
_offload_dir = None
|
||||
@@ -560,9 +565,9 @@ def load_model(
|
||||
# Reserve GPU headroom for inference (KV cache, activations, generate()).
|
||||
# Without this, device_map="auto" packs 100% of layers onto GPU, leaving
|
||||
# no room for forward passes or generation on tight-memory setups.
|
||||
if torch.cuda.is_available():
|
||||
if dev.is_cuda():
|
||||
max_memory = {}
|
||||
for i in range(torch.cuda.device_count()):
|
||||
for i in range(dev.device_count()):
|
||||
total = torch.cuda.get_device_properties(i).total_memory
|
||||
# Reserve 15% or 2 GiB (whichever is larger) for inference headroom
|
||||
reserve = max(int(total * 0.15), 2 * 1024 ** 3)
|
||||
@@ -570,16 +575,8 @@ def load_model(
|
||||
max_memory[i] = f"{usable // (1024 ** 2)}MiB"
|
||||
# Allow overflow to CPU RAM, capped at 85% of physical memory
|
||||
# to leave room for the OS, Python runtime, and serialization buffers.
|
||||
try:
|
||||
import psutil
|
||||
cpu_ram_gb = psutil.virtual_memory().total / (1024 ** 3)
|
||||
except ImportError:
|
||||
try:
|
||||
cpu_ram_gb = os.sysconf("SC_PHYS_PAGES") * os.sysconf("SC_PAGE_SIZE") / (1024 ** 3)
|
||||
except (AttributeError, ValueError):
|
||||
# os.sysconf is unavailable on non-POSIX platforms (Windows)
|
||||
cpu_ram_gb = 16.0 # conservative fallback
|
||||
cpu_budget_gb = int(cpu_ram_gb * 0.85)
|
||||
total_ram, _ = dev._system_memory_gb()
|
||||
cpu_budget_gb = int(total_ram * 0.85)
|
||||
max_memory["cpu"] = f"{max(cpu_budget_gb, 4)}GiB"
|
||||
load_kwargs["max_memory"] = max_memory
|
||||
logger.info(
|
||||
@@ -625,12 +622,15 @@ def load_model(
|
||||
|
||||
if device not in ("auto",) and quantization is None and native_quant is None:
|
||||
model = model.to(device)
|
||||
elif device == "auto" and not dev.supports_device_map_auto():
|
||||
# MPS / CPU: device_map wasn't used, move model to best device.
|
||||
resolved = dev.get_device()
|
||||
model = model.to(resolved)
|
||||
|
||||
model.eval()
|
||||
|
||||
# Free CUDA cache after loading
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
# Free accelerator cache after loading
|
||||
dev.empty_cache()
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
@@ -665,9 +665,7 @@ def load_model(
|
||||
if gpu_gb > 0 and native_quant is not None:
|
||||
# Model is pre-quantized but we can't estimate its true size.
|
||||
# Check actual free memory after loading — if less than 40% free, skip snapshot.
|
||||
free_gb = 0.0
|
||||
for i in range(torch.cuda.device_count()):
|
||||
free_gb += torch.cuda.mem_get_info(i)[0] / (1024 ** 3)
|
||||
free_gb = dev.get_total_free_gb()
|
||||
if free_gb < gpu_gb * 0.4:
|
||||
logger.warning(
|
||||
f"Auto-skipping state dict snapshot for natively quantized model "
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user