Files
OBLITERATUS/obliteratus/analysis/visualization.py
T
2026-03-04 12:38:18 -08:00

410 lines
15 KiB
Python

"""Rich visualization module for abliteration analysis outputs.
Generates publication-quality figures and interactive terminal displays
for all analysis components. Designed for both Jupyter notebook and
CLI consumption.
Visualizations:
1. Refusal Topology Map — layer-wise refusal strength heatmap
2. Cross-Layer Direction Flow — cosine similarity matrix + angular drift
3. Logit Lens Token Spectrum — promoted/suppressed token waterfall
4. Defense Profile Radar — spider chart of defense properties
5. Capability-Safety Pareto Frontier — benchmark vs. refusal rate tradeoff
6. Activation Probe Dashboard — per-layer elimination status
"""
from __future__ import annotations
import re
from pathlib import Path
from typing import Any
import matplotlib
matplotlib.use("Agg") # Set once at import time; safe for server & notebook
import matplotlib.pyplot as plt
import torch
def _sanitize_label(text: str, max_len: int = 80) -> str:
"""Strip filesystem paths, tokens, and overly-long strings from labels.
Prevents accidental leakage of sensitive information (HF org names,
local paths, API tokens) into saved chart images.
"""
# Remove anything that looks like an absolute path
text = re.sub(r"(/[a-zA-Z0-9_./-]{3,})", lambda m: m.group(0).rsplit("/", 1)[-1], text)
# Remove HF-style token prefixes (hf_...)
text = re.sub(r"\bhf_[A-Za-z0-9]{6,}\b", "<TOKEN>", text)
# Remove generic secret-like hex strings (32+ chars)
text = re.sub(r"\b[0-9a-fA-F]{32,}\b", "<REDACTED>", text)
# Truncate
if len(text) > max_len:
text = text[: max_len - 3] + "..."
return text
def plot_refusal_topology(
    refusal_directions: dict[int, torch.Tensor],
    harmful_means: dict[int, torch.Tensor],
    harmless_means: dict[int, torch.Tensor],
    strong_layers: list[int],
    output_path: str | Path | None = None,
    title: str = "Refusal Topology Map",
):
    """Bar chart of per-layer refusal signal strength.

    For each layer key in ``refusal_directions`` (in sorted order) the bar
    height is ``|(harmful_mean - harmless_mean) . unit_direction|``. A layer
    missing from either means dict gets a height of 0.0. Layers listed in
    ``strong_layers`` are drawn red, all others blue.

    Args:
        refusal_directions: Layer index -> refusal direction tensor.
        harmful_means: Layer index -> mean activation on harmful prompts.
        harmless_means: Layer index -> mean activation on harmless prompts.
        strong_layers: Layer indices to highlight.
        output_path: If truthy, the figure is saved there (dpi=150) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title; passed through ``_sanitize_label`` first.

    Returns:
        The matplotlib figure.
    """
    strong_color = "#e74c3c"
    weak_color = "#3498db"

    title = _sanitize_label(title)
    layer_ids = sorted(refusal_directions.keys())

    def _layer_strength(layer: int) -> float:
        # Unit-normalise the direction first; the clamp guards a zero vector
        direction = refusal_directions[layer].float()
        if direction.dim() > 1:
            direction = direction.squeeze()
        direction = direction / direction.norm().clamp(min=1e-8)
        if layer not in harmful_means or layer not in harmless_means:
            return 0.0
        delta = (harmful_means[layer] - harmless_means[layer]).squeeze().float()
        return (delta @ direction).abs().item()

    strengths = [_layer_strength(layer) for layer in layer_ids]
    bar_colors = [strong_color if layer in strong_layers else weak_color for layer in layer_ids]

    n_layers = len(layer_ids)
    fig, ax = plt.subplots(figsize=(14, 5))
    ax.bar(range(n_layers), strengths, color=bar_colors, alpha=0.85, edgecolor="white", linewidth=0.5)
    ax.set_xlabel("Layer Index", fontsize=12)
    ax.set_ylabel("Refusal Signal Strength", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")

    # Thin the tick labels so that at most ~20 are drawn
    tick_step = max(1, n_layers // 20)
    tick_positions = range(0, n_layers, tick_step)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels([str(layer_ids[pos]) for pos in tick_positions])

    # Legend
    from matplotlib.patches import Patch
    ax.legend(
        handles=[
            Patch(facecolor=strong_color, label="Strong (selected for abliteration)"),
            Patch(facecolor=weak_color, label="Weak (not targeted)"),
        ],
        loc="upper right",
    )

    plt.tight_layout()
    if not output_path:
        plt.show()
        return fig
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig
def plot_cross_layer_heatmap(
    cross_layer_result,
    output_path: str | Path | None = None,
    title: str = "Cross-Layer Refusal Direction Alignment",
):
    """Draw the pairwise cosine-similarity matrix between layer refusal directions.

    Reads ``cosine_matrix`` and ``layer_indices`` from ``cross_layer_result``
    and renders the matrix as a heatmap on a fixed 0..1 colour scale. When
    there are at most 15 layers, every cell is annotated with its value.

    Args:
        cross_layer_result: Object exposing ``cosine_matrix`` and
            ``layer_indices``.
            NOTE(review): ``cosine_matrix`` is assumed to be a square
            ``torch.Tensor`` ordered like ``layer_indices`` - confirm
            against the producer.
        output_path: If truthy, the figure is saved there (dpi=150) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title; passed through ``_sanitize_label`` first.

    Returns:
        The matplotlib figure.
    """
    title = _sanitize_label(title)
    # detach() + cpu(): Tensor.numpy() rejects tensors that require grad or
    # that live on a non-CPU device.
    matrix = cross_layer_result.cosine_matrix.detach().cpu().numpy()
    indices = cross_layer_result.layer_indices
    n = len(indices)

    fig, ax = plt.subplots(figsize=(max(8, n * 0.5), max(6, n * 0.4)))
    im = ax.imshow(matrix, cmap="RdYlBu_r", vmin=0, vmax=1, aspect="auto")

    tick_labels = [str(i) for i in indices]
    tick_fontsize = max(6, 10 - n // 5)  # shrink labels as the grid grows
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(tick_labels, fontsize=tick_fontsize)
    ax.set_yticklabels(tick_labels, fontsize=tick_fontsize)
    ax.set_xlabel("Layer", fontsize=12)
    ax.set_ylabel("Layer", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")

    cbar = plt.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label("Cosine Similarity (|cos θ|)", fontsize=10)

    # Annotate if small enough
    if n <= 15:
        cell_fontsize = max(6, 9 - n // 3)
        for i in range(n):
            for j in range(n):
                val = matrix[i, j]
                # Both ends of the colormap are dark, so use white text there
                color = "white" if val > 0.7 or val < 0.3 else "black"
                ax.text(j, i, f"{val:.2f}", ha="center", va="center",
                        color=color, fontsize=cell_fontsize)

    plt.tight_layout()
    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return fig
def plot_angular_drift(
    cross_layer_result,
    output_path: str | Path | None = None,
    title: str = "Refusal Direction Angular Drift Through Network",
):
    """Line plot of the refusal direction's angular drift across layers.

    Plots ``cross_layer_result.angular_drift`` against
    ``cross_layer_result.layer_indices`` with a shaded area under the curve,
    and annotates the figure with ``direction_persistence_score``.

    Args:
        cross_layer_result: Object exposing ``layer_indices``,
            ``angular_drift`` and ``direction_persistence_score``.
        output_path: If truthy, the figure is saved there (dpi=150) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title; passed through ``_sanitize_label`` first.

    Returns:
        The matplotlib figure.
    """
    line_color = "#e74c3c"

    title = _sanitize_label(title)
    layer_ids = cross_layer_result.layer_indices
    drift_values = cross_layer_result.angular_drift

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(layer_ids, drift_values, "o-", color=line_color, linewidth=2, markersize=6)
    ax.fill_between(layer_ids, drift_values, alpha=0.15, color=line_color)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlabel("Layer Index", fontsize=12)
    ax.set_ylabel("Cumulative Angular Drift (radians)", fontsize=12)
    ax.grid(True, alpha=0.3)

    # Persistence score goes in a small box in the top-left corner
    persistence = cross_layer_result.direction_persistence_score
    box_style = dict(boxstyle="round,pad=0.3", facecolor="lightyellow", alpha=0.9)
    ax.annotate(
        f"Direction Persistence: {persistence:.3f}",
        xy=(0.02, 0.95),
        xycoords="axes fraction",
        fontsize=11,
        fontweight="bold",
        bbox=box_style,
    )

    plt.tight_layout()
    if not output_path:
        plt.show()
        return fig
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig
def plot_logit_lens_spectrum(
    logit_lens_result,
    layer_idx: int | None = None,
    output_path: str | Path | None = None,
    title: str | None = None,
):
    """Horizontal bar chart of the tokens a layer's refusal direction boosts or suppresses.

    Shows up to 15 entries from ``top_suppressed`` (in reversed order)
    followed by up to 15 entries from ``top_promoted``. Bars with a positive
    value are green, all others red.

    Args:
        logit_lens_result: Object exposing ``per_layer`` (a mapping with
            ``.get``) and ``strongest_refusal_layer``.
        layer_idx: Layer to display; when ``None`` the layer named by
            ``strongest_refusal_layer`` is used.
        output_path: If truthy, the figure is saved there (dpi=150) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title; defaults to a title naming the layer. Passed
            through ``_sanitize_label`` either way.

    Returns:
        The matplotlib figure, or ``None`` when ``per_layer`` has no entry
        for the selected layer.
    """
    # Select which layer to display
    wanted_layer = layer_idx if layer_idx is not None else logit_lens_result.strongest_refusal_layer
    result = logit_lens_result.per_layer.get(wanted_layer)
    if result is None:
        return None

    if title is None:
        title = f"Logit Lens — Layer {result.layer_idx}"
    title = _sanitize_label(title)

    # Suppressed entries (reversed) come first, promoted entries after
    promoted = result.top_promoted[:15]
    suppressed = result.top_suppressed[:15]
    tokens = []
    values = []
    for token, value in reversed(suppressed):
        tokens.append(token)
        values.append(value)
    for token, value in promoted:
        tokens.append(token)
        values.append(value)
    bar_colors = ["#2ecc71" if value > 0 else "#e74c3c" for value in values]

    n_rows = len(tokens)
    rows = range(n_rows)
    fig, ax = plt.subplots(figsize=(10, max(6, n_rows * 0.3)))
    ax.barh(rows, values, color=bar_colors, alpha=0.85, edgecolor="white", linewidth=0.5)
    ax.set_yticks(rows)
    # repr() makes whitespace and escapes visible; cap each label at 20 chars
    ax.set_yticklabels([repr(token)[:20] for token in tokens], fontsize=9)
    ax.set_xlabel("Logit Boost from Refusal Direction", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.axvline(x=0, color="black", linewidth=0.8)
    ax.grid(True, axis="x", alpha=0.3)

    # Summary metrics in the bottom-right corner
    gap = result.refusal_compliance_gap
    spec = result.refusal_specificity
    summary = f"Refusal-Compliance Gap: {gap:.4f}\nRefusal Specificity: {spec:.3f}"
    ax.annotate(
        summary,
        xy=(0.98, 0.02),
        xycoords="axes fraction",
        fontsize=9,
        ha="right",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightyellow", alpha=0.9),
    )

    plt.tight_layout()
    if not output_path:
        plt.show()
        return fig
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig
def plot_defense_radar(
    defense_profile,
    output_path: str | Path | None = None,
    title: str = "Defense Robustness Profile",
):
    """Radar (spider) chart of five defense properties of a model.

    The five axes are derived from ``defense_profile`` as follows:
    ``1 - refusal_concentration``, ``refusal_layer_spread / 15`` (capped at
    1), ``mean_refusal_strength / 5`` (capped at 1), ``self_repair_estimate``
    and ``entanglement_score``. The last two and the first are used as-is,
    without clamping. A badge below the plot shows ``estimated_robustness``
    in upper case.

    Args:
        defense_profile: Object exposing the attributes named above plus
            ``model_name`` and ``estimated_robustness``.
        output_path: If truthy, the figure is saved there (dpi=150) and
            closed; otherwise ``plt.show()`` is called.
        title: First line of the figure title; passed through
            ``_sanitize_label``. The sanitized model name forms the second.

    Returns:
        The matplotlib figure.
    """
    import numpy as np

    title = _sanitize_label(title)
    profile = defense_profile
    line_color = "#e74c3c"

    # Axis label paired with its plotted value
    axes_spec = [
        ("Distribution\n(1-Gini)", 1.0 - profile.refusal_concentration),
        ("Layer\nSpread", min(profile.refusal_layer_spread / 15.0, 1.0)),
        ("Refusal\nStrength", min(profile.mean_refusal_strength / 5.0, 1.0)),
        ("Self-\nRepair", profile.self_repair_estimate),
        ("Entangle-\nment", profile.entanglement_score),
    ]
    categories = [label for label, _ in axes_spec]
    values = [value for _, value in axes_spec]

    spoke_angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    # Repeat the first point so the outline closes on itself
    closed_angles = spoke_angles + spoke_angles[:1]
    closed_values = values + values[:1]

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    ax.plot(closed_angles, closed_values, "o-", linewidth=2, color=line_color)
    ax.fill(closed_angles, closed_values, alpha=0.2, color=line_color)
    ax.set_xticks(spoke_angles)
    ax.set_xticklabels(categories, fontsize=10)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.25, 0.5, 0.75, 1.0])
    ax.set_yticklabels(["0.25", "0.50", "0.75", "1.00"], fontsize=8)
    ax.set_title(
        f"{title}\n{_sanitize_label(profile.model_name)}",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )

    # Robustness badge; unknown levels fall back to grey
    badge_palette = {
        "low": "#e74c3c",
        "medium": "#f39c12",
        "high": "#27ae60",
        "very_high": "#2ecc71",
    }
    badge_color = badge_palette.get(profile.estimated_robustness, "#95a5a6")
    ax.annotate(
        f"Robustness: {profile.estimated_robustness.upper()}",
        xy=(0.5, -0.08),
        xycoords="axes fraction",
        fontsize=14,
        fontweight="bold",
        ha="center",
        color=badge_color,
        bbox=dict(boxstyle="round,pad=0.4", facecolor="white", edgecolor=badge_color),
    )

    plt.tight_layout()
    if not output_path:
        plt.show()
        return fig
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig
def plot_capability_safety_pareto(
    benchmark_results: dict[str, Any],
    refusal_rate: float,
    other_points: list[tuple[float, float, str]] | None = None,
    output_path: str | Path | None = None,
    title: str = "Capability-Safety Pareto Frontier",
):
    """Scatter plot of capability against refusal rate.

    The current model is placed at ``(refusal_rate, capability)``, where
    capability is the mean of ``.score`` over ``benchmark_results`` values
    (0 when the dict is empty). Optional reference points are drawn in blue
    with their labels. Dashed lines at 0.5 split the plot into four labelled
    quadrants.

    Both axes are fixed to ``[-0.05, 1.05]``, so points outside that window
    will not be visible.
    NOTE(review): presumably scores and refusal rates are fractions in
    [0, 1] - confirm against the callers.

    Args:
        benchmark_results: Benchmark name -> result object with a ``score``
            attribute.
        refusal_rate: X coordinate of the current model.
        other_points: Optional ``(refusal_rate, capability, label)`` tuples.
            Labels are passed through ``_sanitize_label`` before drawing,
            like every other caller-supplied label in this module.
        output_path: If truthy, the figure is saved there (dpi=150) and
            closed; otherwise ``plt.show()`` is called.
        title: Figure title; passed through ``_sanitize_label`` first.

    Returns:
        The matplotlib figure.
    """
    title = _sanitize_label(title)

    # Current point
    scores = [r.score for r in benchmark_results.values()]
    capability = sum(scores) / max(len(scores), 1)  # max() avoids 0/0 on empty input

    fig, ax = plt.subplots(figsize=(10, 7))

    # Plot current model
    ax.scatter([refusal_rate], [capability], s=200, c="#e74c3c", zorder=5,
               edgecolors="black", linewidth=1.5)
    ax.annotate("Current Model", (refusal_rate, capability),
                textcoords="offset points", xytext=(10, 10), fontsize=11)

    # Plot reference points if provided
    if other_points:
        for rr, cap, label in other_points:
            ax.scatter([rr], [cap], s=100, c="#3498db", zorder=4, alpha=0.7)
            # Reference labels are typically model identifiers, so scrub them
            # before they are baked into the image. str() keeps non-string
            # labels working as they did before.
            ax.annotate(_sanitize_label(str(label)), (rr, cap),
                        textcoords="offset points", xytext=(8, 5), fontsize=9)

    # Reference quadrants
    ax.axhline(y=0.5, color="gray", linestyle="--", alpha=0.3)
    ax.axvline(x=0.5, color="gray", linestyle="--", alpha=0.3)
    quadrants = [
        (0.25, 0.25, "BROKEN\n(unsafe & dumb)"),
        (0.75, 0.25, "CENSORED\n(safe but dumb)"),
        (0.25, 0.75, "ABLITERATED\n(capable but unsafe)"),
        (0.75, 0.75, "IDEAL\n(safe & capable)"),
    ]
    for qx, qy, caption in quadrants:
        ax.text(qx, qy, caption, ha="center", va="center",
                fontsize=10, color="gray", alpha=0.5)

    ax.set_xlabel("Refusal Rate (higher = safer)", fontsize=12)
    ax.set_ylabel("Capability Score (higher = more capable)", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True, alpha=0.2)

    plt.tight_layout()
    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return fig
def plot_probe_dashboard(
    probe_result,
    output_path: str | Path | None = None,
    title: str = "Activation Probe Dashboard",
):
    """Two-panel dashboard of per-layer probe results.

    The left panel shows ``projection_gap`` per layer: red when its
    magnitude exceeds 0.1, green otherwise. The right panel shows
    ``separation_d_prime`` per layer: red above 1.0, orange above 0.5,
    green otherwise. The overall title includes
    ``refusal_elimination_score``.

    Args:
        probe_result: Object exposing ``per_layer`` (a mapping of layer
            index to an object with ``projection_gap`` and
            ``separation_d_prime``) and ``refusal_elimination_score``.
        output_path: If truthy, the figure is saved there (dpi=150) and
            closed; otherwise ``plt.show()`` is called.
        title: Leading part of the overall title; passed through
            ``_sanitize_label`` first.

    Returns:
        The matplotlib figure.
    """
    red, orange, green = "#e74c3c", "#f39c12", "#2ecc71"

    def _gap_color(gap):
        return red if abs(gap) > 0.1 else green

    def _d_prime_color(d_prime):
        if d_prime > 1.0:
            return red
        if d_prime > 0.5:
            return orange
        return green

    title = _sanitize_label(title)
    layer_ids = sorted(probe_result.per_layer.keys())
    per_layer = [probe_result.per_layer[layer] for layer in layer_ids]
    gaps = [entry.projection_gap for entry in per_layer]
    d_primes = [entry.separation_d_prime for entry in per_layer]

    n_layers = len(layer_ids)
    bar_positions = range(n_layers)
    # Thin the tick labels so that at most ~10 are drawn per panel
    tick_positions = range(0, n_layers, max(1, n_layers // 10))
    tick_labels = [str(layer_ids[pos]) for pos in tick_positions]

    fig, (gap_ax, dprime_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Left: projection gaps, with dashed guides at +/-0.1
    gap_ax.bar(bar_positions, gaps, color=[_gap_color(g) for g in gaps], alpha=0.85)
    gap_ax.axhline(y=0, color="black", linewidth=0.8)
    for threshold in (0.1, -0.1):
        gap_ax.axhline(y=threshold, color="red", linewidth=0.5, linestyle="--", alpha=0.5)
    gap_ax.set_xlabel("Layer", fontsize=11)
    gap_ax.set_ylabel("Projection Gap (harmful - harmless)", fontsize=11)
    gap_ax.set_title("Residual Refusal Signal", fontsize=12, fontweight="bold")
    gap_ax.set_xticks(tick_positions)
    gap_ax.set_xticklabels(tick_labels)

    # Right: d-prime, with a dashed guide at 1.0
    dprime_ax.bar(bar_positions, d_primes, color=[_d_prime_color(d) for d in d_primes], alpha=0.85)
    dprime_ax.axhline(y=1.0, color="red", linewidth=0.5, linestyle="--", alpha=0.5, label="d'=1 (detectable)")
    dprime_ax.set_xlabel("Layer", fontsize=11)
    dprime_ax.set_ylabel("d' (sensitivity)", fontsize=11)
    dprime_ax.set_title("Signal Detection Sensitivity", fontsize=12, fontweight="bold")
    dprime_ax.set_xticks(tick_positions)
    dprime_ax.set_xticklabels(tick_labels)
    dprime_ax.legend()

    # Overall RES badge
    res = probe_result.refusal_elimination_score
    fig.suptitle(
        f"{title} | RES = {res:.3f}",
        fontsize=14,
        fontweight="bold",
        y=1.02,
    )

    plt.tight_layout()
    if not output_path:
        plt.show()
        return fig
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig