"""Rich visualization module for abliteration analysis outputs.

Generates publication-quality figures and interactive terminal displays for
all analysis components. Designed for both Jupyter notebook and CLI
consumption.

Visualizations:
    1. Refusal Topology Map — layer-wise refusal strength heatmap
    2. Cross-Layer Direction Flow — cosine similarity matrix + angular drift
    3. Logit Lens Token Spectrum — promoted/suppressed token waterfall
    4. Defense Profile Radar — spider chart of defense properties
    5. Capability-Safety Pareto Frontier — benchmark vs. refusal rate tradeoff
    6. Activation Probe Dashboard — per-layer elimination status
"""

from __future__ import annotations

import re
from pathlib import Path
from typing import Any

import matplotlib

matplotlib.use("Agg")  # Set once at import time; safe for server & notebook
import matplotlib.pyplot as plt
import torch

# Path-like run: an optional leading segment, a slash, then at least three
# path characters.  The leading segment is part of the match so that
# "org/model" collapses to "model" instead of being glued into "orgmodel".
_PATH_RE = re.compile(r"(?<![A-Za-z0-9_.~-])[A-Za-z0-9_.~-]*/[a-zA-Z0-9_./-]{3,}")
# Hugging Face style access tokens (hf_...).
_HF_TOKEN_RE = re.compile(r"\bhf_[A-Za-z0-9]{6,}\b")
# Generic secret-looking hex strings (32+ hex characters).
_HEX_SECRET_RE = re.compile(r"\b[0-9a-fA-F]{32,}\b")


def _sanitize_label(text: str, max_len: int = 80) -> str:
    """Strip filesystem paths, tokens, and overly-long strings from labels.

    Prevents accidental leakage of sensitive information (HF org names,
    local paths, API tokens) into saved chart images.

    Args:
        text: Raw label text.
        max_len: Maximum length of the returned label.  Longer labels are cut
            and, when there is room for it, terminated with ``"..."``.

    Returns:
        The sanitized label.  Path-like runs are reduced to their last
        component, e.g. ``"/home/user/run/model.pt"`` -> ``"model.pt"`` and
        ``"meta-llama/Llama-3-8B"`` -> ``"Llama-3-8B"``.
    """
    # Keep only the final component of anything that looks like a path.
    text = _PATH_RE.sub(lambda m: m.group(0).rsplit("/", 1)[-1], text)
    # Remove HF-style token prefixes (hf_...)
    text = _HF_TOKEN_RE.sub("", text)
    # Remove generic secret-like hex strings (32+ chars)
    text = _HEX_SECRET_RE.sub("", text)
    # Truncate
    if len(text) > max_len:
        if max_len < 3:
            # No room for an ellipsis; a negative slice bound would otherwise
            # cut from the wrong end and exceed max_len.
            text = text[: max(max_len, 0)]
        else:
            text = text[: max_len - 3] + "..."
    return text


def _escape_mathtext(text: str) -> str:
    """Escape ``$`` so matplotlib renders *text* literally, not as mathtext."""
    return text.replace("$", r"\$")


def _to_numpy(value):
    """Convert a tensor-like *value* to a NumPy array suitable for plotting.

    Torch tensors are detached, cast to float32 and moved to the CPU first,
    because ``Tensor.numpy()`` refuses CUDA tensors, tensors that require
    grad, and dtypes NumPy does not know (e.g. bfloat16).
    """
    if isinstance(value, torch.Tensor):
        return value.detach().float().cpu().numpy()
    if hasattr(value, "numpy"):
        return value.numpy()
    import numpy as np

    return np.asarray(value)


def _tick_positions(count: int, max_ticks: int) -> list[int]:
    """Return bar positions to label so roughly *max_ticks* labels are drawn."""
    step = max(1, count // max_ticks)
    return list(range(0, count, step))


def _finish_figure(fig, output_path: str | Path | None):
    """Lay out *fig*, then save-and-close it or show it; return *fig*.

    When *output_path* is truthy the figure is written there at 150 dpi and
    closed; otherwise ``plt.show()`` is called and the figure is left open.
    """
    fig.tight_layout()
    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return fig


def plot_refusal_topology(
    refusal_directions: dict[int, torch.Tensor],
    harmful_means: dict[int, torch.Tensor],
    harmless_means: dict[int, torch.Tensor],
    strong_layers: list[int],
    output_path: str | Path | None = None,
    title: str = "Refusal Topology Map",
) -> plt.Figure:
    """Visualize refusal signal strength across all layers.

    Creates a bar chart showing per-layer refusal strength (absolute value of
    the harmful-harmless mean difference projected onto the unit-normalized
    refusal direction), with strong layers highlighted.

    Args:
        refusal_directions: Per-layer refusal direction, keyed by layer index.
        harmful_means: Per-layer mean activation on harmful prompts.
        harmless_means: Per-layer mean activation on harmless prompts.
        strong_layers: Layer indices drawn in the "strong" color.
        output_path: If given, the figure is saved there and closed;
            otherwise it is shown.
        title: Chart title (sanitized before use).

    Returns:
        The matplotlib figure.
    """
    title = _sanitize_label(title)

    layers = sorted(refusal_directions)
    strong = set(strong_layers)  # O(1) membership tests below

    strengths = []
    for idx in layers:
        # Bring everything to float32 on the CPU so the projection works even
        # if the direction and the means live on different devices.
        direction = refusal_directions[idx].detach().float().cpu()
        if direction.dim() > 1:
            direction = direction.squeeze()
        direction = direction / direction.norm().clamp(min=1e-8)
        if idx in harmful_means and idx in harmless_means:
            diff = (harmful_means[idx] - harmless_means[idx]).detach().squeeze().float().cpu()
            strengths.append((diff @ direction).abs().item())
        else:
            # A layer without both means is drawn as a zero-height bar.
            strengths.append(0.0)

    colors = ["#e74c3c" if idx in strong else "#3498db" for idx in layers]

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.bar(range(len(layers)), strengths, color=colors, alpha=0.85, edgecolor="white", linewidth=0.5)

    ax.set_xlabel("Layer Index", fontsize=12)
    ax.set_ylabel("Refusal Signal Strength", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ticks = _tick_positions(len(layers), 20)
    ax.set_xticks(ticks)
    ax.set_xticklabels([str(layers[i]) for i in ticks])

    # Legend
    from matplotlib.patches import Patch

    legend_elements = [
        Patch(facecolor="#e74c3c", label="Strong (selected for abliteration)"),
        Patch(facecolor="#3498db", label="Weak (not targeted)"),
    ]
    ax.legend(handles=legend_elements, loc="upper right")

    return _finish_figure(fig, output_path)


def plot_cross_layer_heatmap(
    cross_layer_result,
    output_path: str | Path | None = None,
    title: str = "Cross-Layer Refusal Direction Alignment",
) -> plt.Figure:
    """Visualize the pairwise cosine similarity matrix between layer refusal directions.

    Args:
        cross_layer_result: Object exposing ``cosine_matrix`` (tensor-like,
            indexed ``[i, j]``) and ``layer_indices``.  The color scale is
            fixed to [0, 1], so the matrix is presumably of absolute cosines
            — confirm against the producer of this result.
        output_path: If given, the figure is saved there and closed;
            otherwise it is shown.
        title: Chart title (sanitized before use).

    Returns:
        The matplotlib figure.
    """
    title = _sanitize_label(title)

    matrix = _to_numpy(cross_layer_result.cosine_matrix)
    indices = cross_layer_result.layer_indices
    n = len(indices)

    fig, ax = plt.subplots(figsize=(max(8, n * 0.5), max(6, n * 0.4)))
    im = ax.imshow(matrix, cmap="RdYlBu_r", vmin=0, vmax=1, aspect="auto")

    tick_font = max(6, 10 - n // 5)  # shrink tick labels as layers grow
    tick_labels = [str(i) for i in indices]
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(tick_labels, fontsize=tick_font)
    ax.set_yticklabels(tick_labels, fontsize=tick_font)
    ax.set_xlabel("Layer", fontsize=12)
    ax.set_ylabel("Layer", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")

    cbar = fig.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label("Cosine Similarity (|cos θ|)", fontsize=10)

    # Annotate if small enough
    if n <= 15:
        cell_font = max(6, 9 - n // 3)
        for i in range(n):
            for j in range(n):
                val = matrix[i, j]
                # White text on the saturated ends of the colormap.
                color = "white" if val > 0.7 or val < 0.3 else "black"
                ax.text(j, i, f"{val:.2f}", ha="center", va="center", color=color, fontsize=cell_font)

    return _finish_figure(fig, output_path)


def plot_angular_drift(
    cross_layer_result,
    output_path: str | Path | None = None,
    title: str = "Refusal Direction Angular Drift Through Network",
) -> plt.Figure:
    """Visualize cumulative angular drift of the refusal direction.

    Args:
        cross_layer_result: Object exposing ``layer_indices``,
            ``angular_drift`` (one value per layer index) and
            ``direction_persistence_score``.
        output_path: If given, the figure is saved there and closed;
            otherwise it is shown.
        title: Chart title (sanitized before use).

    Returns:
        The matplotlib figure.
    """
    title = _sanitize_label(title)

    indices = cross_layer_result.layer_indices
    drift = cross_layer_result.angular_drift

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(indices, drift, "o-", color="#e74c3c", linewidth=2, markersize=6)
    ax.fill_between(indices, drift, alpha=0.15, color="#e74c3c")

    ax.set_xlabel("Layer Index", fontsize=12)
    ax.set_ylabel("Cumulative Angular Drift (radians)", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.grid(True, alpha=0.3)

    # Add persistence score annotation
    ps = cross_layer_result.direction_persistence_score
    ax.annotate(
        f"Direction Persistence: {ps:.3f}",
        xy=(0.02, 0.95),
        xycoords="axes fraction",
        fontsize=11,
        fontweight="bold",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightyellow", alpha=0.9),
    )

    return _finish_figure(fig, output_path)


def plot_logit_lens_spectrum(
    logit_lens_result,
    layer_idx: int | None = None,
    output_path: str | Path | None = None,
    title: str | None = None,
) -> plt.Figure | None:
    """Visualize the logit lens token promotion/suppression spectrum.

    Args:
        logit_lens_result: Object exposing ``per_layer`` (mapping of layer
            index to a per-layer result) and ``strongest_refusal_layer``.
        layer_idx: Layer to display; defaults to
            ``logit_lens_result.strongest_refusal_layer``.
        output_path: If given, the figure is saved there and closed;
            otherwise it is shown.
        title: Chart title; defaults to ``"Logit Lens — Layer <idx>"``.

    Returns:
        The matplotlib figure, or ``None`` when the requested layer has no
        entry in ``per_layer``.
    """
    # Select which layer to display
    if layer_idx is None:
        layer_idx = logit_lens_result.strongest_refusal_layer
    result = logit_lens_result.per_layer.get(layer_idx)
    if result is None:
        return None

    if title is None:
        title = f"Logit Lens — Layer {result.layer_idx}"
    title = _sanitize_label(title)

    # Combine top promoted and suppressed; suppressed entries are reversed
    # before the promoted ones are appended.
    promoted = result.top_promoted[:15]
    suppressed = result.top_suppressed[:15]
    entries = list(reversed(suppressed)) + list(promoted)
    tokens = [tok for tok, _ in entries]
    values = [val for _, val in entries]
    colors = ["#2ecc71" if v > 0 else "#e74c3c" for v in values]

    fig, ax = plt.subplots(figsize=(10, max(6, len(tokens) * 0.3)))
    y_pos = range(len(tokens))
    ax.barh(y_pos, values, color=colors, alpha=0.85, edgecolor="white", linewidth=0.5)
    ax.set_yticks(y_pos)
    # repr() makes whitespace/control tokens visible; "$" is escaped because
    # a token holding paired dollar signs would be parsed as mathtext and can
    # fail at render time.
    ax.set_yticklabels([_escape_mathtext(repr(t)[:20]) for t in tokens], fontsize=9)
    ax.set_xlabel("Logit Boost from Refusal Direction", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.axvline(x=0, color="black", linewidth=0.8)
    ax.grid(True, axis="x", alpha=0.3)

    # Annotation
    gap = result.refusal_compliance_gap
    spec = result.refusal_specificity
    ax.annotate(
        f"Refusal-Compliance Gap: {gap:.4f}\nRefusal Specificity: {spec:.3f}",
        xy=(0.98, 0.02),
        xycoords="axes fraction",
        fontsize=9,
        ha="right",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightyellow", alpha=0.9),
    )

    return _finish_figure(fig, output_path)


def plot_defense_radar(
    defense_profile,
    output_path: str | Path | None = None,
    title: str = "Defense Robustness Profile",
) -> plt.Figure:
    """Spider/radar chart of defense properties.

    Args:
        defense_profile: Object exposing ``refusal_concentration``,
            ``refusal_layer_spread``, ``mean_refusal_strength``,
            ``self_repair_estimate``, ``entanglement_score``, ``model_name``
            and ``estimated_robustness`` (a string).
        output_path: If given, the figure is saved there and closed;
            otherwise it is shown.
        title: Chart title (sanitized before use); the sanitized model name
            is appended on a second line.

    Returns:
        The matplotlib figure.
    """
    title = _sanitize_label(title)

    import numpy as np

    categories = [
        "Distribution\n(1-Gini)",
        "Layer\nSpread",
        "Refusal\nStrength",
        "Self-\nRepair",
        "Entangle-\nment",
    ]

    p = defense_profile
    # Normalize to 0-1 range.  Every axis is clamped so an out-of-range input
    # cannot push the polygon outside the fixed radial limits.
    raw_values = [
        1.0 - p.refusal_concentration,
        p.refusal_layer_spread / 15.0,
        p.mean_refusal_strength / 5.0,
        p.self_repair_estimate,
        p.entanglement_score,
    ]
    values = [min(max(v, 0.0), 1.0) for v in raw_values]

    n_cats = len(categories)
    angles = np.linspace(0, 2 * np.pi, n_cats, endpoint=False).tolist()
    # Repeat the first point to close the polygon.
    values_plot = values + [values[0]]
    angles += [angles[0]]

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    ax.plot(angles, values_plot, "o-", linewidth=2, color="#e74c3c")
    ax.fill(angles, values_plot, alpha=0.2, color="#e74c3c")

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, fontsize=10)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.25, 0.5, 0.75, 1.0])
    ax.set_yticklabels(["0.25", "0.50", "0.75", "1.00"], fontsize=8)

    ax.set_title(f"{title}\n{_sanitize_label(p.model_name)}", fontsize=14, fontweight="bold", pad=20)

    # Robustness badge
    robustness_colors = {
        "low": "#e74c3c",
        "medium": "#f39c12",
        "high": "#27ae60",
        "very_high": "#2ecc71",
    }
    badge_color = robustness_colors.get(p.estimated_robustness, "#95a5a6")
    ax.annotate(
        f"Robustness: {p.estimated_robustness.upper()}",
        xy=(0.5, -0.08),
        xycoords="axes fraction",
        fontsize=14,
        fontweight="bold",
        ha="center",
        color=badge_color,
        bbox=dict(boxstyle="round,pad=0.4", facecolor="white", edgecolor=badge_color),
    )

    return _finish_figure(fig, output_path)


def plot_capability_safety_pareto(
    benchmark_results: dict[str, Any],
    refusal_rate: float,
    other_points: list[tuple[float, float, str]] | None = None,
    output_path: str | Path | None = None,
    title: str = "Capability-Safety Pareto Frontier",
) -> plt.Figure:
    """Plot the capability vs safety tradeoff.

    Args:
        benchmark_results: Mapping whose values expose a numeric ``score``;
            the capability coordinate is their mean (0 when empty).
        refusal_rate: Refusal rate of the current model (x coordinate).
        other_points: Optional ``(refusal_rate, capability, label)`` reference
            points.  Labels are sanitized like every other label in this
            module.
        output_path: If given, the figure is saved there and closed;
            otherwise it is shown.
        title: Chart title (sanitized before use).

    Returns:
        The matplotlib figure.
    """
    title = _sanitize_label(title)

    # Current point
    scores = [r.score for r in benchmark_results.values()]
    capability = sum(scores) / max(len(scores), 1)

    fig, ax = plt.subplots(figsize=(10, 7))

    # Plot current model
    ax.scatter([refusal_rate], [capability], s=200, c="#e74c3c", zorder=5, edgecolors="black", linewidth=1.5)
    ax.annotate(
        "Current Model",
        (refusal_rate, capability),
        textcoords="offset points",
        xytext=(10, 10),
        fontsize=11,
    )

    # Plot reference points if provided
    if other_points:
        for rr, cap, label in other_points:
            ax.scatter([rr], [cap], s=100, c="#3498db", zorder=4, alpha=0.7)
            ax.annotate(
                _sanitize_label(str(label)),
                (rr, cap),
                textcoords="offset points",
                xytext=(8, 5),
                fontsize=9,
            )

    # Reference quadrants
    ax.axhline(y=0.5, color="gray", linestyle="--", alpha=0.3)
    ax.axvline(x=0.5, color="gray", linestyle="--", alpha=0.3)

    quadrants = [
        (0.25, 0.25, "BROKEN\n(unsafe & dumb)"),
        (0.75, 0.25, "CENSORED\n(safe but dumb)"),
        (0.25, 0.75, "ABLITERATED\n(capable but unsafe)"),
        (0.75, 0.75, "IDEAL\n(safe & capable)"),
    ]
    for x, y, text in quadrants:
        ax.text(x, y, text, ha="center", va="center", fontsize=10, color="gray", alpha=0.5)

    ax.set_xlabel("Refusal Rate (higher = safer)", fontsize=12)
    ax.set_ylabel("Capability Score (higher = more capable)", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True, alpha=0.2)

    return _finish_figure(fig, output_path)


def plot_probe_dashboard(
    probe_result,
    output_path: str | Path | None = None,
    title: str = "Activation Probe Dashboard",
) -> plt.Figure:
    """Dashboard showing per-layer refusal elimination status.

    Args:
        probe_result: Object exposing ``per_layer`` (mapping of layer index
            to an entry with ``projection_gap`` and ``separation_d_prime``)
            and ``refusal_elimination_score``.
        output_path: If given, the figure is saved there and closed;
            otherwise it is shown.
        title: Dashboard title (sanitized before use); the elimination score
            is appended as ``"| RES = <score>"``.

    Returns:
        The matplotlib figure.
    """
    title = _sanitize_label(title)

    layers = sorted(probe_result.per_layer)
    gaps = [probe_result.per_layer[idx].projection_gap for idx in layers]
    d_primes = [probe_result.per_layer[idx].separation_d_prime for idx in layers]

    positions = range(len(layers))
    ticks = _tick_positions(len(layers), 10)
    tick_labels = [str(layers[i]) for i in ticks]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Left: projection gaps
    colors = ["#e74c3c" if abs(g) > 0.1 else "#2ecc71" for g in gaps]
    ax1.bar(positions, gaps, color=colors, alpha=0.85)
    ax1.axhline(y=0, color="black", linewidth=0.8)
    ax1.axhline(y=0.1, color="red", linewidth=0.5, linestyle="--", alpha=0.5)
    ax1.axhline(y=-0.1, color="red", linewidth=0.5, linestyle="--", alpha=0.5)
    ax1.set_xlabel("Layer", fontsize=11)
    ax1.set_ylabel("Projection Gap (harmful - harmless)", fontsize=11)
    ax1.set_title("Residual Refusal Signal", fontsize=12, fontweight="bold")
    ax1.set_xticks(ticks)
    ax1.set_xticklabels(tick_labels)

    # Right: d-prime
    colors2 = ["#e74c3c" if d > 1.0 else "#f39c12" if d > 0.5 else "#2ecc71" for d in d_primes]
    ax2.bar(positions, d_primes, color=colors2, alpha=0.85)
    ax2.axhline(y=1.0, color="red", linewidth=0.5, linestyle="--", alpha=0.5, label="d'=1 (detectable)")
    ax2.set_xlabel("Layer", fontsize=11)
    ax2.set_ylabel("d' (sensitivity)", fontsize=11)
    ax2.set_title("Signal Detection Sensitivity", fontsize=12, fontweight="bold")
    ax2.set_xticks(ticks)
    ax2.set_xticklabels(tick_labels)
    ax2.legend()

    # Overall RES badge
    res = probe_result.refusal_elimination_score
    fig.suptitle(
        f"{title} | RES = {res:.3f}",
        fontsize=14,
        fontweight="bold",
        y=1.02,
    )

    return _finish_figure(fig, output_path)