# Mirror of https://github.com/elder-plinius/OBLITERATUS.git
# Synced 2026-04-23 11:46:28 +02:00 (568 lines, 20 KiB, Python)
"""Tests for the community contribution system."""

import json
from unittest.mock import MagicMock

import pytest
import torch

from obliteratus.community import (
    CONTRIBUTION_SCHEMA_VERSION,
    _config_fingerprint,
    _model_short_name,
    aggregate_results,
    generate_latex_table,
    load_contributions,
    save_contribution,
)


# ── Helper: mock pipeline ──────────────────────────────────────────────


def _make_mock_pipeline():
    """Build a mock pipeline with all fields the community module reads.

    Returns:
        MagicMock: a stand-in pipeline whose ``handle.summary()`` returns a
        fixed model summary and whose method settings, quality metrics,
        stage durations, direction tensors and prompt lists are set to
        concrete values, so they serialize as real data instead of as
        auto-generated mock attributes.
    """
    p = MagicMock()
    p.handle.summary.return_value = {
        "architecture": "LlamaForCausalLM",
        "num_layers": 32,
        "num_heads": 32,
        "hidden_size": 4096,
        "total_params": 8_000_000_000,
    }

    # Method configuration
    p.method = "advanced"
    p.n_directions = 4
    p.norm_preserve = True
    p.regularization = 0.3
    p.refinement_passes = 2
    p.project_biases = True
    p.use_chat_template = True
    p.use_whitened_svd = True
    p.true_iterative_refinement = False
    p.use_jailbreak_contrast = False
    p.layer_adaptive_strength = False
    p.attention_head_surgery = True
    p.safety_neuron_masking = False
    p.per_expert_directions = False
    p.use_sae_features = False
    p.invert_refusal = False
    p.project_embeddings = False
    p.embed_regularization = 0.5
    p.activation_steering = False
    p.steering_strength = 0.3
    p.expert_transplant = False
    p.transplant_blend = 0.3
    p.reflection_strength = 2.0
    p.quantization = None

    # Run results
    p._quality_metrics = {"perplexity": 5.2, "coherence": 0.8, "refusal_rate": 0.05}
    p._strong_layers = [10, 11, 12, 13]
    p._stage_durations = {
        "summon": 3.0, "probe": 12.5, "distill": 4.1,
        "excise": 2.0, "verify": 8.3, "rebirth": 5.0,
    }
    p._excise_modified_count = 128

    # Direction data: a unit-norm random vector for layer 10 and a slightly
    # perturbed copy of it for layer 11.
    d = torch.randn(4096)
    d = d / d.norm()
    p.refusal_directions = {10: d, 11: d + 0.01 * torch.randn(4096)}
    p.refusal_subspaces = {10: torch.randn(4, 4096)}

    # Excise details
    p._refusal_heads = {10: [(0, 0.9), (3, 0.8)]}
    p._sae_directions = {}
    p._expert_safety_scores = {}
    p._layer_excise_weights = {}
    p._expert_directions = {}
    p._steering_hooks = []

    # Prompts
    p.harmful_prompts = ["x"] * 33
    p.harmless_prompts = ["y"] * 33
    p.jailbreak_prompts = None

    return p


# ── Model short name ───────────────────────────────────────────────────


class TestModelShortName:
    """Expectations for ``_model_short_name`` on various model identifiers."""

    def test_strips_org_prefix(self):
        """The ``org/`` prefix is dropped and the result is lowercased."""
        assert _model_short_name("meta-llama/Llama-2-7b-chat-hf") == "llama-2-7b-chat-hf"

    def test_no_org_prefix(self):
        """A name without an org prefix is returned unchanged."""
        assert _model_short_name("gpt2") == "gpt2"

    def test_sanitizes_special_chars(self):
        """Underscores and dots are expected to become dashes."""
        assert _model_short_name("org/Model_V2.1") == "model-v2-1"

    def test_caps_length(self):
        """A 100-character name is expected to come back at 60 characters or fewer."""
        long_name = "a" * 100
        assert len(_model_short_name(long_name)) <= 60

    def test_collapses_dashes(self):
        """Runs of dashes are expected to collapse to a single dash."""
        assert _model_short_name("org/Model---Name") == "model-name"

    def test_strips_trailing_dashes(self):
        """A trailing dash is expected to be removed."""
        assert _model_short_name("org/Model-") == "model"


# ── Config fingerprint ─────────────────────────────────────────────────


class TestConfigFingerprint:
    """Expectations for ``_config_fingerprint``."""

    def test_deterministic(self):
        """Fingerprinting the same config twice gives the same value."""
        config = {"n_directions": 4, "norm_preserve": True}
        fp1 = _config_fingerprint(config)
        fp2 = _config_fingerprint(config)
        assert fp1 == fp2

    def test_different_configs_different_hashes(self):
        """Configs that differ in a value get different fingerprints."""
        fp1 = _config_fingerprint({"n_directions": 4})
        fp2 = _config_fingerprint({"n_directions": 8})
        assert fp1 != fp2

    def test_key_order_invariant(self):
        """Key insertion order must not affect the fingerprint."""
        fp1 = _config_fingerprint({"a": 1, "b": 2})
        fp2 = _config_fingerprint({"b": 2, "a": 1})
        assert fp1 == fp2

    def test_returns_8_char_hex(self):
        """The fingerprint is 8 lowercase hexadecimal characters."""
        fp = _config_fingerprint({"test": True})
        assert len(fp) == 8
        assert all(c in "0123456789abcdef" for c in fp)


# ── Save contribution ──────────────────────────────────────────────────


class TestSaveContribution:
    """Tests for ``save_contribution`` writing a JSON record to disk.

    Every test uses the mock pipeline from ``_make_mock_pipeline`` and
    pytest's ``tmp_path`` fixture as the output location.
    """

    def test_saves_json_file(self, tmp_path):
        """A ``.json`` file is created carrying the schema version and model name."""
        pipeline = _make_mock_pipeline()
        path = save_contribution(
            pipeline,
            model_name="meta-llama/Llama-2-7b-chat-hf",
            output_dir=tmp_path,
        )
        assert path.exists()
        assert path.suffix == ".json"
        data = json.loads(path.read_text())
        assert data["contribution_schema_version"] == CONTRIBUTION_SCHEMA_VERSION
        assert data["model_name"] == "meta-llama/Llama-2-7b-chat-hf"

    def test_filename_format(self, tmp_path):
        """The file stem starts with ``<short model name>_<method>_``."""
        pipeline = _make_mock_pipeline()
        path = save_contribution(
            pipeline,
            model_name="meta-llama/Llama-2-7b-chat-hf",
            output_dir=tmp_path,
        )
        name = path.stem
        assert name.startswith("llama-2-7b-chat-hf_advanced_")

    def test_includes_telemetry_report(self, tmp_path):
        """The record embeds a telemetry report built from the pipeline."""
        pipeline = _make_mock_pipeline()
        path = save_contribution(
            pipeline,
            model_name="meta-llama/Llama-2-7b-chat-hf",
            output_dir=tmp_path,
        )
        data = json.loads(path.read_text())
        telemetry = data["telemetry"]
        assert telemetry["schema_version"] == 2
        assert telemetry["model"]["architecture"] == "LlamaForCausalLM"
        assert telemetry["method"] == "advanced"
        assert telemetry["quality_metrics"]["refusal_rate"] == 0.05

    def test_includes_config_fingerprint(self, tmp_path):
        """The record carries an 8-character ``config_fingerprint``."""
        pipeline = _make_mock_pipeline()
        path = save_contribution(
            pipeline,
            model_name="meta-llama/Llama-2-7b-chat-hf",
            output_dir=tmp_path,
        )
        data = json.loads(path.read_text())
        assert "config_fingerprint" in data
        assert len(data["config_fingerprint"]) == 8

    def test_includes_notes(self, tmp_path):
        """Free-form ``notes`` are stored verbatim."""
        pipeline = _make_mock_pipeline()
        path = save_contribution(
            pipeline,
            model_name="test/model",
            notes="Ran on A100 with default prompts",
            output_dir=tmp_path,
        )
        data = json.loads(path.read_text())
        assert data["notes"] == "Ran on A100 with default prompts"

    def test_creates_output_dir(self, tmp_path):
        """A missing, nested output directory is created on save."""
        subdir = tmp_path / "nested" / "dir"
        assert not subdir.exists()
        pipeline = _make_mock_pipeline()
        path = save_contribution(
            pipeline, model_name="test/model", output_dir=subdir,
        )
        assert subdir.exists()
        assert path.exists()

    def test_timestamp_format(self, tmp_path):
        """The timestamp has the compact UTC form ``YYYYMMDDTHHMMSSZ``."""
        from datetime import datetime

        pipeline = _make_mock_pipeline()
        path = save_contribution(
            pipeline, model_name="test/model", output_dir=tmp_path,
        )
        data = json.loads(path.read_text())
        ts = data["timestamp"]
        # Should be UTC ISO-ish: YYYYMMDDTHHMMSSZ
        assert ts.endswith("Z")
        assert "T" in ts
        assert len(ts) == 16
        # The shape checks above accept any 16-char string containing "T" and
        # ending in "Z"; strptime raises ValueError unless the fields are a
        # real date and time in the documented layout.
        datetime.strptime(ts, "%Y%m%dT%H%M%SZ")

    def test_method_config_extracted(self, tmp_path):
        """Pipeline settings appear under ``telemetry.method_config``."""
        pipeline = _make_mock_pipeline()
        path = save_contribution(
            pipeline, model_name="test/model", output_dir=tmp_path,
        )
        data = json.loads(path.read_text())
        cfg = data["telemetry"]["method_config"]
        assert cfg["n_directions"] == 4
        assert cfg["norm_preserve"] is True
        assert cfg["attention_head_surgery"] is True


# ── Load contributions ─────────────────────────────────────────────────


class TestLoadContributions:
    """Tests for ``load_contributions`` reading records from a directory."""

    def _write_contrib(self, directory, model, method, refusal_rate, idx=0):
        """Write a minimal valid contribution file.

        Args:
            directory: directory (a ``Path``) to write into.
            model: value stored under ``model_name``.
            method: value stored under ``telemetry.method``.
            refusal_rate: value stored under
                ``telemetry.quality_metrics.refusal_rate``.
            idx: used for the file name (``contrib_<idx>.json``) and, zero
                padded to six digits, as the time part of the timestamp, so
                a larger ``idx`` gives a later timestamp.

        Returns:
            The path of the file that was written.
        """
        record = {
            "contribution_schema_version": CONTRIBUTION_SCHEMA_VERSION,
            "timestamp": f"20260227T{idx:06d}Z",
            "model_name": model,
            "config_fingerprint": "abcd1234",
            "notes": "",
            "telemetry": {
                "schema_version": 2,
                "method": method,
                "quality_metrics": {"refusal_rate": refusal_rate},
            },
        }
        path = directory / f"contrib_{idx}.json"
        path.write_text(json.dumps(record))
        return path

    def test_loads_valid_files(self, tmp_path):
        """Both valid contribution files in the directory are loaded."""
        self._write_contrib(tmp_path, "test/model", "advanced", 0.05, 0)
        self._write_contrib(tmp_path, "test/model", "basic", 0.10, 1)
        records = load_contributions(tmp_path)
        assert len(records) == 2

    def test_sorts_by_timestamp(self, tmp_path):
        """Records come back ordered by timestamp, oldest first."""
        # Written newest-first so the order must come from sorting.
        self._write_contrib(tmp_path, "model-b", "advanced", 0.05, 2)
        self._write_contrib(tmp_path, "model-a", "advanced", 0.10, 1)
        records = load_contributions(tmp_path)
        assert records[0]["model_name"] == "model-a"
        assert records[1]["model_name"] == "model-b"

    def test_skips_non_contribution_json(self, tmp_path):
        """JSON lacking ``contribution_schema_version`` is not loaded."""
        # Write a JSON file without contribution_schema_version
        (tmp_path / "random.json").write_text('{"foo": "bar"}')
        self._write_contrib(tmp_path, "test/model", "advanced", 0.05, 0)
        records = load_contributions(tmp_path)
        assert len(records) == 1

    def test_skips_invalid_json(self, tmp_path):
        """A ``.json`` file that fails to parse is skipped, not raised."""
        (tmp_path / "bad.json").write_text("not valid json {{{")
        self._write_contrib(tmp_path, "test/model", "advanced", 0.05, 0)
        records = load_contributions(tmp_path)
        assert len(records) == 1

    def test_returns_empty_for_missing_dir(self, tmp_path):
        """A nonexistent directory yields an empty list."""
        records = load_contributions(tmp_path / "nonexistent")
        assert records == []

    def test_tracks_source_file(self, tmp_path):
        """Each record carries ``_source_file`` naming the file it came from."""
        self._write_contrib(tmp_path, "test/model", "advanced", 0.05, 0)
        records = load_contributions(tmp_path)
        assert "_source_file" in records[0]
        assert "contrib_0.json" in records[0]["_source_file"]

    def test_ignores_non_json_files(self, tmp_path):
        """Files without a ``.json`` extension are not loaded."""
        (tmp_path / "readme.txt").write_text("some text")
        self._write_contrib(tmp_path, "test/model", "advanced", 0.05, 0)
        records = load_contributions(tmp_path)
        assert len(records) == 1


# ── Aggregate results ──────────────────────────────────────────────────


class TestAggregateResults:
    """Tests for ``aggregate_results`` grouping records by model and method.

    Computed floating-point statistics are compared with ``pytest.approx``
    so the tests do not depend on exact binary rounding (for example of the
    mean of 0.04 and 0.06).
    """

    def _make_record(self, model, method, refusal_rate, perplexity=None, coherence=None):
        """Build an in-memory contribution record.

        Args:
            model: value stored under ``model_name``.
            method: value stored under ``telemetry.method``.
            refusal_rate: always included in the quality metrics.
            perplexity: included in the quality metrics only when not ``None``.
            coherence: included in the quality metrics only when not ``None``.

        Returns:
            A dict with ``model_name`` and ``telemetry`` keys.
        """
        metrics = {"refusal_rate": refusal_rate}
        if perplexity is not None:
            metrics["perplexity"] = perplexity
        if coherence is not None:
            metrics["coherence"] = coherence
        return {
            "model_name": model,
            "telemetry": {
                "method": method,
                "quality_metrics": metrics,
            },
        }

    def test_single_record(self):
        """One record yields one model/method entry with ``n_runs == 1``."""
        records = [self._make_record("model-a", "advanced", 0.05)]
        result = aggregate_results(records)
        assert "model-a" in result
        assert "advanced" in result["model-a"]
        assert result["model-a"]["advanced"]["n_runs"] == 1
        assert result["model-a"]["advanced"]["refusal_rate"]["mean"] == pytest.approx(0.05)

    def test_multiple_runs_same_model_method(self):
        """Two runs of the same model/method are pooled into one stats entry."""
        records = [
            self._make_record("model-a", "advanced", 0.04),
            self._make_record("model-a", "advanced", 0.06),
        ]
        result = aggregate_results(records)
        stats = result["model-a"]["advanced"]
        assert stats["n_runs"] == 2
        assert stats["refusal_rate"]["mean"] == pytest.approx(0.05)
        assert stats["refusal_rate"]["min"] == pytest.approx(0.04)
        assert stats["refusal_rate"]["max"] == pytest.approx(0.06)
        assert stats["refusal_rate"]["n"] == 2

    def test_multiple_models(self):
        """Records for different models produce separate top-level entries."""
        records = [
            self._make_record("model-a", "advanced", 0.05),
            self._make_record("model-b", "basic", 0.10),
        ]
        result = aggregate_results(records)
        assert len(result) == 2
        assert "model-a" in result
        assert "model-b" in result

    def test_multiple_methods(self):
        """Different methods for one model are kept as separate entries."""
        records = [
            self._make_record("model-a", "advanced", 0.05),
            self._make_record("model-a", "basic", 0.10),
        ]
        result = aggregate_results(records)
        assert len(result["model-a"]) == 2
        assert "advanced" in result["model-a"]
        assert "basic" in result["model-a"]

    def test_std_zero_for_single_run(self):
        """With a single run the reported standard deviation is 0.0."""
        records = [self._make_record("model-a", "advanced", 0.05)]
        result = aggregate_results(records)
        assert result["model-a"]["advanced"]["refusal_rate"]["std"] == 0.0

    def test_multiple_metrics(self):
        """Every metric present in the record gets its own stats entry."""
        records = [
            self._make_record("model-a", "advanced", 0.05, perplexity=5.2, coherence=0.8),
        ]
        result = aggregate_results(records)
        stats = result["model-a"]["advanced"]
        assert "refusal_rate" in stats
        assert "perplexity" in stats
        assert "coherence" in stats
        assert stats["perplexity"]["mean"] == pytest.approx(5.2)

    def test_missing_metric_skipped(self):
        """A metric absent from every record does not appear in the output."""
        records = [self._make_record("model-a", "advanced", 0.05)]
        result = aggregate_results(records)
        # coherence not provided, should not appear
        assert "coherence" not in result["model-a"]["advanced"]

    def test_unknown_model_and_method(self):
        """Records without model name or method are filed under "unknown"."""
        records = [{
            "telemetry": {"quality_metrics": {"refusal_rate": 0.1}},
        }]
        result = aggregate_results(records)
        assert "unknown" in result
        assert "unknown" in result["unknown"]


# ── LaTeX table generation ─────────────────────────────────────────────


class TestGenerateLatexTable:
    """Tests for ``generate_latex_table`` rendering aggregated statistics."""

    def _sample_aggregated(self):
        """Return aggregated stats for two models.

        The Llama entry has both "advanced" and "basic" methods with several
        runs each; the Mistral entry has only "advanced" with a single run,
        which leaves one model/method cell empty.
        """
        return {
            "meta-llama/Llama-2-7b-chat-hf": {
                "advanced": {
                    "n_runs": 3,
                    "refusal_rate": {"mean": 0.04, "std": 0.01, "n": 3, "min": 0.03, "max": 0.05},
                },
                "basic": {
                    "n_runs": 2,
                    "refusal_rate": {"mean": 0.08, "std": 0.02, "n": 2, "min": 0.06, "max": 0.10},
                },
            },
            "mistralai/Mistral-7B-Instruct-v0.2": {
                "advanced": {
                    "n_runs": 1,
                    "refusal_rate": {"mean": 0.03, "std": 0.0, "n": 1, "min": 0.03, "max": 0.03},
                },
            },
        }

    def test_produces_valid_latex(self):
        """The output contains the tabular environment and booktabs rules."""
        agg = self._sample_aggregated()
        latex = generate_latex_table(agg)
        assert "\\begin{tabular}" in latex
        assert "\\end{tabular}" in latex
        assert "\\toprule" in latex
        assert "\\bottomrule" in latex

    def test_includes_model_names(self):
        """Model names (without the org prefix) appear in the table."""
        agg = self._sample_aggregated()
        latex = generate_latex_table(agg)
        assert "Llama-2-7b-chat-hf" in latex
        assert "Mistral-7B-Instruct-v0.2" in latex

    def test_includes_method_headers(self):
        """Every method name appears in the table."""
        agg = self._sample_aggregated()
        latex = generate_latex_table(agg)
        assert "advanced" in latex
        assert "basic" in latex

    def test_missing_method_shows_dash(self):
        """A model/method pair with no data is rendered as ``---``."""
        agg = self._sample_aggregated()
        latex = generate_latex_table(agg)
        # Mistral doesn't have "basic" method
        assert "---" in latex

    def test_shows_std_when_multiple_runs(self):
        """A ``$\\pm$`` marker appears when an entry has several runs."""
        agg = self._sample_aggregated()
        latex = generate_latex_table(agg)
        assert "$\\pm$" in latex

    def test_no_std_for_single_run(self):
        """No ``$\\pm$`` marker appears when every entry has one run."""
        agg = {
            "model": {
                "method": {
                    "n_runs": 1,
                    "refusal_rate": {"mean": 0.03, "std": 0.0, "n": 1, "min": 0.03, "max": 0.03},
                },
            },
        }
        latex = generate_latex_table(agg)
        assert "$\\pm$" not in latex

    def test_methods_filter(self):
        """``methods=[...]`` restricts which method columns are emitted."""
        agg = self._sample_aggregated()
        latex = generate_latex_table(agg, methods=["advanced"])
        assert "\\textbf{advanced}" in latex
        assert "\\textbf{basic}" not in latex

    def test_custom_metric(self):
        """``metric=`` selects which statistic fills the cells."""
        agg = {
            "model": {
                "method": {
                    "n_runs": 2,
                    "perplexity": {"mean": 5.2, "std": 0.3, "n": 2, "min": 4.9, "max": 5.5},
                },
            },
        }
        latex = generate_latex_table(agg, metric="perplexity")
        assert "5.2" in latex

    def test_column_count_matches_methods(self):
        """The column spec has one ``l`` plus one ``c`` per method."""
        agg = self._sample_aggregated()
        latex = generate_latex_table(agg)
        # 2 methods → "lcc" (1 model col + 2 method cols)
        assert "{@{}lcc@{}}" in latex


# ── CLI integration ────────────────────────────────────────────────────


class TestCLIContributeFlag:
    """Smoke tests for the CLI surface of the contribution feature.

    The commands cannot be run end to end here (no GPU), so each test asks
    for ``--help`` and inspects how the CLI exits. Expecting only
    ``SystemExit`` would also pass when the subcommand is unknown, because a
    parse error exits as well; the exit status and help text are therefore
    checked too.
    """

    def test_contribute_flag_accepted(self, capsys):
        """Verify ``obliterate`` is a known command that offers --contribute."""
        from obliteratus.cli import main

        with pytest.raises(SystemExit) as exc_info:
            main(["obliterate", "--help"])

        # Help exits successfully (status 0 or None); a parse failure would
        # exit non-zero. NOTE(review): assumes argparse-style exit codes.
        assert exc_info.value.code in (0, None)
        # NOTE(review): assumes the help text is written to stdout.
        assert "--contribute" in capsys.readouterr().out

    def test_aggregate_command_accepted(self):
        """Verify ``aggregate`` is a known command."""
        from obliteratus.cli import main

        with pytest.raises(SystemExit) as exc_info:
            main(["aggregate", "--help"])

        # Status 0/None means help was shown rather than a parse error.
        assert exc_info.value.code in (0, None)


# ── Package exports ────────────────────────────────────────────────────


class TestPackageExports:
    """Check that the community helpers are importable from the top-level package."""

    def test_save_contribution_importable(self):
        """``obliteratus.save_contribution`` exists and is callable."""
        from obliteratus import save_contribution

        assert callable(save_contribution)

    def test_load_contributions_importable(self):
        """``obliteratus.load_contributions`` exists and is callable."""
        from obliteratus import load_contributions

        assert callable(load_contributions)

    def test_aggregate_results_importable(self):
        """``obliteratus.aggregate_results`` exists and is callable."""
        from obliteratus import aggregate_results

        assert callable(aggregate_results)


# ── End-to-end: save → load → aggregate ───────────────────────────────


class TestEndToEnd:
    """Round-trip tests chaining save, load, aggregate and LaTeX rendering."""

    def test_save_load_aggregate_roundtrip(self, tmp_path):
        """Full roundtrip: save contributions, load them, aggregate."""
        pipeline = _make_mock_pipeline()

        # Save two contributions (different models to avoid filename collision)
        save_contribution(
            pipeline, model_name="test/model-a", output_dir=tmp_path,
        )
        # Tweak metrics for second run with a different model name
        pipeline._quality_metrics = {"perplexity": 5.5, "coherence": 0.75, "refusal_rate": 0.07}
        save_contribution(
            pipeline, model_name="test/model-b", output_dir=tmp_path,
        )

        # Load
        records = load_contributions(tmp_path)
        assert len(records) == 2

        # Aggregate
        aggregated = aggregate_results(records)
        assert "test/model-a" in aggregated
        assert "test/model-b" in aggregated
        stats_a = aggregated["test/model-a"]["advanced"]
        stats_b = aggregated["test/model-b"]["advanced"]
        assert stats_a["n_runs"] == 1
        assert stats_b["n_runs"] == 1
        # Each model keeps the refusal rate it was saved with.
        assert abs(stats_a["refusal_rate"]["mean"] - 0.05) < 0.001
        assert abs(stats_b["refusal_rate"]["mean"] - 0.07) < 0.001

    def test_save_load_aggregate_to_latex(self, tmp_path):
        """Full roundtrip ending in LaTeX output."""
        pipeline = _make_mock_pipeline()
        save_contribution(
            pipeline, model_name="meta-llama/Llama-2-7b-chat-hf", output_dir=tmp_path,
        )

        records = load_contributions(tmp_path)
        aggregated = aggregate_results(records)
        latex = generate_latex_table(aggregated)

        assert "\\begin{tabular}" in latex
        assert "Llama-2-7b-chat-hf" in latex
        assert "advanced" in latex