mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-04-30 23:17:57 +02:00
410 lines
15 KiB
Python
410 lines
15 KiB
Python
"""Rich visualization module for abliteration analysis outputs.

Generates publication-quality figures and interactive terminal displays
for all analysis components. Designed for both Jupyter notebook and
CLI consumption.

Visualizations:
    1. Refusal Topology Map — layer-wise refusal strength bar chart
    2. Cross-Layer Direction Flow — cosine similarity matrix + angular drift
    3. Logit Lens Token Spectrum — promoted/suppressed token waterfall
    4. Defense Profile Radar — spider chart of defense properties
    5. Capability-Safety Pareto Frontier — benchmark vs. refusal rate tradeoff
    6. Activation Probe Dashboard — per-layer elimination status
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import matplotlib
|
|
matplotlib.use("Agg") # Set once at import time; safe for server & notebook
|
|
import matplotlib.pyplot as plt
|
|
|
|
import torch
|
|
|
|
|
|
def _sanitize_label(text: str, max_len: int = 80) -> str:
|
|
"""Strip filesystem paths, tokens, and overly-long strings from labels.
|
|
|
|
Prevents accidental leakage of sensitive information (HF org names,
|
|
local paths, API tokens) into saved chart images.
|
|
"""
|
|
# Remove anything that looks like an absolute path
|
|
text = re.sub(r"(/[a-zA-Z0-9_./-]{3,})", lambda m: m.group(0).rsplit("/", 1)[-1], text)
|
|
# Remove HF-style token prefixes (hf_...)
|
|
text = re.sub(r"\bhf_[A-Za-z0-9]{6,}\b", "<TOKEN>", text)
|
|
# Remove generic secret-like hex strings (32+ chars)
|
|
text = re.sub(r"\b[0-9a-fA-F]{32,}\b", "<REDACTED>", text)
|
|
# Truncate
|
|
if len(text) > max_len:
|
|
text = text[: max_len - 3] + "..."
|
|
return text
|
|
|
|
|
|
def plot_refusal_topology(
    refusal_directions: dict[int, torch.Tensor],
    harmful_means: dict[int, torch.Tensor],
    harmless_means: dict[int, torch.Tensor],
    strong_layers: list[int],
    output_path: str | Path | None = None,
    title: str = "Refusal Topology Map",
):
    """Draw a per-layer bar chart of refusal signal strength.

    For every layer key of ``refusal_directions`` the bar height is
    ``|(harmful_mean - harmless_mean) @ d_hat|``, where ``d_hat`` is the
    refusal direction scaled to unit norm. Layers missing from either mean
    dict get a height of 0. Bars for layers in ``strong_layers`` are red and
    the rest are blue.

    Args:
        refusal_directions: Per-layer refusal direction tensors.
        harmful_means: Per-layer mean activations on harmful prompts.
        harmless_means: Per-layer mean activations on harmless prompts.
        strong_layers: Layer ids to highlight as selected for abliteration.
        output_path: When truthy, the figure is saved there (150 dpi) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title (sanitized before use).

    Returns:
        The matplotlib ``Figure``.
    """
    from matplotlib.patches import Patch

    title = _sanitize_label(title)
    layers = sorted(refusal_directions.keys())

    def _strength(layer):
        # Unit-normalise the direction (clamped to avoid division by zero).
        direction = refusal_directions[layer].float()
        if direction.dim() > 1:
            direction = direction.squeeze()
        direction = direction / direction.norm().clamp(min=1e-8)
        if layer not in harmful_means or layer not in harmless_means:
            return 0.0
        delta = (harmful_means[layer] - harmless_means[layer]).squeeze().float()
        return (delta @ direction).abs().item()

    strengths = [_strength(layer) for layer in layers]
    bar_colors = [
        "#e74c3c" if layer in strong_layers else "#3498db" for layer in layers
    ]
    # Show at most ~20 tick labels regardless of depth.
    tick_step = max(1, len(layers) // 20)
    tick_positions = list(range(0, len(layers), tick_step))

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.bar(
        range(len(layers)),
        strengths,
        color=bar_colors,
        alpha=0.85,
        edgecolor="white",
        linewidth=0.5,
    )
    ax.set_xlabel("Layer Index", fontsize=12)
    ax.set_ylabel("Refusal Signal Strength", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xticks(tick_positions)
    ax.set_xticklabels([str(layers[pos]) for pos in tick_positions])

    ax.legend(
        handles=[
            Patch(facecolor="#e74c3c", label="Strong (selected for abliteration)"),
            Patch(facecolor="#3498db", label="Weak (not targeted)"),
        ],
        loc="upper right",
    )

    plt.tight_layout()
    if not output_path:
        plt.show()
        return fig
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig
|
|
|
|
|
|
def plot_cross_layer_heatmap(
    cross_layer_result,
    output_path: str | Path | None = None,
    title: str = "Cross-Layer Refusal Direction Alignment",
):
    """Visualize the pairwise cosine similarity matrix between layer refusal directions.

    Args:
        cross_layer_result: Object exposing ``cosine_matrix`` (an ``n x n``
            torch tensor or array-like) and ``layer_indices`` (the ``n``
            layer ids labelling its rows and columns).
        output_path: When truthy, the figure is saved there (150 dpi) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title (sanitized before use).

    Returns:
        The matplotlib ``Figure``.
    """
    import numpy as np

    title = _sanitize_label(title)
    raw = cross_layer_result.cosine_matrix
    if isinstance(raw, torch.Tensor):
        # Tensor.numpy() raises for CUDA tensors, tensors requiring grad and
        # dtypes numpy lacks (e.g. bfloat16), so normalise the tensor first.
        raw = raw.detach().float().cpu().numpy()
    matrix = np.asarray(raw, dtype=float)
    indices = cross_layer_result.layer_indices
    n = len(indices)
    # Shrink tick labels as the matrix grows, but never below 6pt.
    tick_fontsize = max(6, 10 - n // 5)

    fig, ax = plt.subplots(figsize=(max(8, n * 0.5), max(6, n * 0.4)))
    im = ax.imshow(matrix, cmap="RdYlBu_r", vmin=0, vmax=1, aspect="auto")
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels([str(i) for i in indices], fontsize=tick_fontsize)
    ax.set_yticklabels([str(i) for i in indices], fontsize=tick_fontsize)
    ax.set_xlabel("Layer", fontsize=12)
    ax.set_ylabel("Layer", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")

    cbar = plt.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label("Cosine Similarity (|cos θ|)", fontsize=10)

    # Annotate cells only while the grid is small enough to stay legible.
    if n <= 15:
        for i in range(n):
            for j in range(n):
                val = matrix[i, j]
                # White text on the dark ends of the colormap, black on the
                # light middle band.
                color = "white" if val > 0.7 or val < 0.3 else "black"
                ax.text(j, i, f"{val:.2f}", ha="center", va="center",
                        color=color, fontsize=max(6, 9 - n // 3))

    plt.tight_layout()
    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return fig
|
|
|
|
|
|
def plot_angular_drift(
    cross_layer_result,
    output_path: str | Path | None = None,
    title: str = "Refusal Direction Angular Drift Through Network",
):
    """Plot cumulative angular drift of the refusal direction against layer index.

    Reads ``layer_indices`` and ``angular_drift`` from ``cross_layer_result``
    and draws them as a red line with a shaded area underneath. The
    ``direction_persistence_score`` is printed in the top-left corner.

    Args:
        cross_layer_result: Object exposing ``layer_indices``,
            ``angular_drift`` and ``direction_persistence_score``.
        output_path: When truthy, the figure is saved there (150 dpi) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title (sanitized before use).

    Returns:
        The matplotlib ``Figure``.
    """
    accent = "#e74c3c"
    title = _sanitize_label(title)
    xs = cross_layer_result.layer_indices
    ys = cross_layer_result.angular_drift

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(xs, ys, "o-", color=accent, linewidth=2, markersize=6)
    ax.fill_between(xs, ys, alpha=0.15, color=accent)
    ax.set_xlabel("Layer Index", fontsize=12)
    ax.set_ylabel("Cumulative Angular Drift (radians)", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.grid(True, alpha=0.3)

    # Persistence score badge in the upper-left corner (axes coordinates).
    persistence = cross_layer_result.direction_persistence_score
    badge_box = {"boxstyle": "round,pad=0.3", "facecolor": "lightyellow", "alpha": 0.9}
    ax.annotate(
        f"Direction Persistence: {persistence:.3f}",
        xy=(0.02, 0.95),
        xycoords="axes fraction",
        fontsize=11,
        fontweight="bold",
        bbox=badge_box,
    )

    plt.tight_layout()
    if not output_path:
        plt.show()
        return fig
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig
|
|
|
|
|
|
def plot_logit_lens_spectrum(
    logit_lens_result,
    layer_idx: int | None = None,
    output_path: str | Path | None = None,
    title: str | None = None,
):
    """Draw the tokens a layer's refusal direction promotes and suppresses.

    Uses ``logit_lens_result.per_layer[layer_idx]``. When ``layer_idx`` is
    None, it uses the entry for ``logit_lens_result.strongest_refusal_layer``.
    Up to 15 suppressed tokens (in reverse order) fill the bottom of a
    horizontal bar chart, with up to 15 promoted tokens above them.
    Positive boosts are green and the rest red. The layer's
    refusal-compliance gap and refusal specificity appear in the lower-right
    corner.

    Args:
        logit_lens_result: Object exposing ``per_layer`` (a mapping with
            ``.get``) and ``strongest_refusal_layer``.
        layer_idx: Layer to display; defaults to the strongest refusal layer.
        output_path: When truthy, the figure is saved there (150 dpi) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title; defaults to ``"Logit Lens — Layer <idx>"``.

    Returns:
        The matplotlib ``Figure``, or None when no per-layer entry exists.
    """
    lookup_key = (
        layer_idx
        if layer_idx is not None
        else logit_lens_result.strongest_refusal_layer
    )
    result = logit_lens_result.per_layer.get(lookup_key)
    if result is None:
        return None

    title = _sanitize_label(
        title if title is not None else f"Logit Lens — Layer {result.layer_idx}"
    )

    # Suppressed tokens first (reversed), then promoted, bottom to top.
    promoted = result.top_promoted[:15]
    suppressed = result.top_suppressed[:15]
    entries = [*reversed(suppressed), *promoted]
    tokens = [tok for tok, _ in entries]
    values = [val for _, val in entries]
    bar_colors = ["#2ecc71" if val > 0 else "#e74c3c" for val in values]

    fig, ax = plt.subplots(figsize=(10, max(6, len(tokens) * 0.3)))
    rows = range(len(tokens))
    ax.barh(rows, values, color=bar_colors, alpha=0.85, edgecolor="white", linewidth=0.5)
    ax.set_yticks(rows)
    ax.set_yticklabels([repr(tok)[:20] for tok in tokens], fontsize=9)
    ax.set_xlabel("Logit Boost from Refusal Direction", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.axvline(x=0, color="black", linewidth=0.8)
    ax.grid(True, axis="x", alpha=0.3)

    summary = (
        f"Refusal-Compliance Gap: {result.refusal_compliance_gap:.4f}\n"
        f"Refusal Specificity: {result.refusal_specificity:.3f}"
    )
    ax.annotate(
        summary,
        xy=(0.98, 0.02),
        xycoords="axes fraction",
        fontsize=9,
        ha="right",
        bbox={"boxstyle": "round,pad=0.3", "facecolor": "lightyellow", "alpha": 0.9},
    )

    plt.tight_layout()
    if not output_path:
        plt.show()
        return fig
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig
|
|
|
|
|
|
def plot_defense_radar(
    defense_profile,
    output_path: str | Path | None = None,
    title: str = "Defense Robustness Profile",
):
    """Spider/radar chart of defense properties.

    Plots five axes read from ``defense_profile``:

    * Distribution: ``1 - refusal_concentration``
    * Layer spread: ``refusal_layer_spread / 15``
    * Refusal strength: ``mean_refusal_strength / 5``
    * Self-repair: ``self_repair_estimate``
    * Entanglement: ``entanglement_score``

    Every axis is clamped to ``[0, 1]`` to match the chart's fixed radial
    limits. A colored badge below the chart shows ``estimated_robustness``.

    Args:
        defense_profile: Object exposing the attributes above plus
            ``model_name`` and ``estimated_robustness`` (a string).
        output_path: When truthy, the figure is saved there (150 dpi) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title (sanitized before use).

    Returns:
        The matplotlib ``Figure``.
    """
    title = _sanitize_label(title)
    import numpy as np

    def _unit(x) -> float:
        # The radial axis is fixed to [0, 1]; clamp so out-of-range scores
        # are not drawn off-scale or through the origin.
        return min(max(float(x), 0.0), 1.0)

    categories = [
        "Distribution\n(1-Gini)",
        "Layer\nSpread",
        "Refusal\nStrength",
        "Self-\nRepair",
        "Entangle-\nment",
    ]

    p = defense_profile
    # Normalize to 0-1 range
    values = [
        _unit(1.0 - p.refusal_concentration),
        _unit(p.refusal_layer_spread / 15.0),
        _unit(p.mean_refusal_strength / 5.0),
        _unit(p.self_repair_estimate),
        _unit(p.entanglement_score),
    ]

    n_cats = len(categories)
    angles = np.linspace(0, 2 * np.pi, n_cats, endpoint=False).tolist()
    # Repeat the first point so the polygon closes.
    values_plot = values + [values[0]]
    angles += [angles[0]]

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    ax.plot(angles, values_plot, "o-", linewidth=2, color="#e74c3c")
    ax.fill(angles, values_plot, alpha=0.2, color="#e74c3c")

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, fontsize=10)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.25, 0.5, 0.75, 1.0])
    ax.set_yticklabels(["0.25", "0.50", "0.75", "1.00"], fontsize=8)
    ax.set_title(f"{title}\n{_sanitize_label(p.model_name)}", fontsize=14, fontweight="bold", pad=20)

    # Robustness badge; unknown levels fall back to gray.
    robustness_colors = {
        "low": "#e74c3c", "medium": "#f39c12",
        "high": "#27ae60", "very_high": "#2ecc71",
    }
    badge_color = robustness_colors.get(p.estimated_robustness, "#95a5a6")
    ax.annotate(
        f"Robustness: {p.estimated_robustness.upper()}",
        xy=(0.5, -0.08), xycoords="axes fraction",
        fontsize=14, fontweight="bold", ha="center",
        color=badge_color,
        bbox=dict(boxstyle="round,pad=0.4", facecolor="white", edgecolor=badge_color),
    )

    plt.tight_layout()
    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return fig
|
|
|
|
|
|
def plot_capability_safety_pareto(
    benchmark_results: dict[str, Any],
    refusal_rate: float,
    other_points: list[tuple[float, float, str]] | None = None,
    output_path: str | Path | None = None,
    title: str = "Capability-Safety Pareto Frontier",
):
    """Plot the capability vs safety tradeoff.

    The current model is placed at ``(refusal_rate, capability)``, where
    capability is the mean ``.score`` of the values in
    ``benchmark_results`` (0 when it is empty). Optional reference points
    are drawn in blue. Dashed lines at 0.5 split the plane into four
    labelled quadrants.

    Args:
        benchmark_results: Mapping whose values expose a ``score`` attribute.
        refusal_rate: Refusal rate of the current model (x axis).
        other_points: Optional ``(refusal_rate, capability, label)`` tuples
            for reference models. Labels are sanitized before being drawn.
        output_path: When truthy, the figure is saved there (150 dpi) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title (sanitized before use).

    Returns:
        The matplotlib ``Figure``.
    """
    title = _sanitize_label(title)
    # Current point
    scores = [r.score for r in benchmark_results.values()]
    capability = sum(scores) / max(len(scores), 1)

    fig, ax = plt.subplots(figsize=(10, 7))

    # Plot current model
    ax.scatter([refusal_rate], [capability], s=200, c="#e74c3c", zorder=5,
               edgecolors="black", linewidth=1.5)
    ax.annotate("Current Model", (refusal_rate, capability),
                textcoords="offset points", xytext=(10, 10), fontsize=11)

    # Plot reference points if provided
    if other_points:
        for rr, cap, label in other_points:
            ax.scatter([rr], [cap], s=100, c="#3498db", zorder=4, alpha=0.7)
            # Reference labels are often model ids ("org/model") or paths;
            # sanitize them like every other label this module draws.
            ax.annotate(_sanitize_label(str(label)), (rr, cap),
                        textcoords="offset points", xytext=(8, 5), fontsize=9)

    # Reference quadrants
    ax.axhline(y=0.5, color="gray", linestyle="--", alpha=0.3)
    ax.axvline(x=0.5, color="gray", linestyle="--", alpha=0.3)

    ax.text(0.25, 0.25, "BROKEN\n(unsafe & dumb)", ha="center", va="center",
            fontsize=10, color="gray", alpha=0.5)
    ax.text(0.75, 0.25, "CENSORED\n(safe but dumb)", ha="center", va="center",
            fontsize=10, color="gray", alpha=0.5)
    ax.text(0.25, 0.75, "ABLITERATED\n(capable but unsafe)", ha="center", va="center",
            fontsize=10, color="gray", alpha=0.5)
    ax.text(0.75, 0.75, "IDEAL\n(safe & capable)", ha="center", va="center",
            fontsize=10, color="gray", alpha=0.5)

    ax.set_xlabel("Refusal Rate (higher = safer)", fontsize=12)
    ax.set_ylabel("Capability Score (higher = more capable)", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True, alpha=0.2)

    plt.tight_layout()
    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return fig
|
|
|
|
|
|
def plot_probe_dashboard(
    probe_result,
    output_path: str | Path | None = None,
    title: str = "Activation Probe Dashboard",
):
    """Two-panel dashboard of per-layer refusal elimination status.

    The left panel shows each layer's ``projection_gap``, red when
    ``|gap| > 0.1`` and green otherwise, with dashed guides at +/-0.1. The
    right panel shows each layer's ``separation_d_prime``: red above 1.0,
    orange above 0.5, green otherwise. The figure title includes
    ``probe_result.refusal_elimination_score`` (RES).

    Args:
        probe_result: Object exposing ``per_layer`` (a mapping from layer id
            to an entry with ``projection_gap`` and ``separation_d_prime``)
            and ``refusal_elimination_score``.
        output_path: When truthy, the figure is saved there (150 dpi) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title (sanitized before use).

    Returns:
        The matplotlib ``Figure``.
    """
    def _dprime_color(score):
        if score > 1.0:
            return "#e74c3c"
        if score > 0.5:
            return "#f39c12"
        return "#2ecc71"

    title = _sanitize_label(title)
    per_layer = probe_result.per_layer
    layers = sorted(per_layer.keys())
    gaps = [per_layer[layer].projection_gap for layer in layers]
    d_primes = [per_layer[layer].separation_d_prime for layer in layers]

    fig, (gap_ax, dp_ax) = plt.subplots(1, 2, figsize=(14, 5))
    positions = range(len(layers))
    # Both panels share the same thinned tick layout (~10 labels max).
    ticks = range(0, len(layers), max(1, len(layers) // 10))
    tick_labels = [str(layers[pos]) for pos in ticks]

    # Left: residual projection gaps.
    gap_colors = ["#e74c3c" if abs(g) > 0.1 else "#2ecc71" for g in gaps]
    gap_ax.bar(positions, gaps, color=gap_colors, alpha=0.85)
    gap_ax.axhline(y=0, color="black", linewidth=0.8)
    for threshold in (0.1, -0.1):
        gap_ax.axhline(y=threshold, color="red", linewidth=0.5, linestyle="--", alpha=0.5)
    gap_ax.set_xlabel("Layer", fontsize=11)
    gap_ax.set_ylabel("Projection Gap (harmful - harmless)", fontsize=11)
    gap_ax.set_title("Residual Refusal Signal", fontsize=12, fontweight="bold")
    gap_ax.set_xticks(ticks)
    gap_ax.set_xticklabels(tick_labels)

    # Right: d-prime separation.
    dp_ax.bar(positions, d_primes, color=[_dprime_color(d) for d in d_primes], alpha=0.85)
    dp_ax.axhline(y=1.0, color="red", linewidth=0.5, linestyle="--", alpha=0.5, label="d'=1 (detectable)")
    dp_ax.set_xlabel("Layer", fontsize=11)
    dp_ax.set_ylabel("d' (sensitivity)", fontsize=11)
    dp_ax.set_title("Signal Detection Sensitivity", fontsize=12, fontweight="bold")
    dp_ax.set_xticks(ticks)
    dp_ax.set_xticklabels(tick_labels)
    dp_ax.legend()

    # Overall RES badge in the figure title.
    fig.suptitle(
        f"{title} | RES = {probe_result.refusal_elimination_score:.3f}",
        fontsize=14,
        fontweight="bold",
        y=1.02,
    )

    plt.tight_layout()
    if not output_path:
        plt.show()
        return fig
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig
|