# Mirror of https://github.com/elder-plinius/OBLITERATUS.git
# (synced 2026-06-08 07:13:56 +02:00; 303 lines, 11 KiB, Python)
"""Extended tests for novel abliteration pipeline features.

Tests the new capabilities added to the OBLITERATUS abliteration pipeline:
  - Bias projection
  - Chat template wrapping
  - Method presets with new parameters
  - True iterative refinement
  - Whitened SVD integration
"""
|
|
|
|
from __future__ import annotations

import json
import tempfile
from pathlib import Path
from unittest.mock import MagicMock

import torch
from transformers import GPT2Config, GPT2LMHeadModel

from obliteratus.abliterate import (
    METHODS,
    AbliterationPipeline,
)
from obliteratus.models.loader import ModelHandle
|
|
|
|
|
|
def _make_tiny_handle():
    """Build a ``ModelHandle`` around a small, randomly initialised GPT-2.

    The tokenizer is a ``MagicMock``: calling it yields a fixed batch of
    ten token ids with an all-ones attention mask, and ``decode`` returns a
    canned sentence. ``handle.snapshot()`` is invoked before the handle is
    returned.
    """
    tiny_config = GPT2Config(
        vocab_size=1000,
        n_positions=128,
        n_embd=64,
        n_layer=4,
        n_head=2,
        n_inner=256,
    )
    # The model is created before the mock encoding so that random numbers
    # are drawn in the same order as they always were.
    tiny_model = GPT2LMHeadModel(tiny_config).eval()

    encoded = {
        "input_ids": torch.randint(0, 1000, (1, 10)),
        "attention_mask": torch.ones(1, 10, dtype=torch.long),
    }
    fake_tokenizer = MagicMock(pad_token="<pad>", eos_token="<eos>")
    fake_tokenizer.return_value = encoded
    fake_tokenizer.decode.return_value = (
        "The capital of France is Paris, a beautiful city"
    )

    handle = ModelHandle(
        model=tiny_model,
        tokenizer=fake_tokenizer,
        config=tiny_config,
        model_name="gpt2-test",
        task="causal_lm",
    )
    handle.snapshot()
    return handle
|
|
|
|
|
|
def _make_varied_tokenizer(handle):
    """Set up a tokenizer mock that returns different tokens per call.

    Assigns a function to ``handle.tokenizer.side_effect``. Each invocation
    of that function bumps an internal call counter and returns a dict with
    a ``(1, 5)`` tensor of token ids in ``[0, 1000)`` plus an all-ones
    attention mask of the same shape. The ids are reproducible for a given
    call number because they are drawn from a generator seeded with it.

    Args:
        handle: Object whose ``tokenizer`` attribute accepts a
            ``side_effect`` assignment (a ``MagicMock`` in these tests).
    """
    calls = 0

    def mock_tokenizer(prompt, **kwargs):
        nonlocal calls
        calls += 1
        # Draw from a private generator instead of calling
        # torch.manual_seed(): reseeding the process-wide RNG on every
        # tokenizer call would silently change every later random draw in
        # the code under test.
        rng = torch.Generator().manual_seed(calls)
        return {
            "input_ids": torch.randint(0, 1000, (1, 5), generator=rng),
            "attention_mask": torch.ones(1, 5, dtype=torch.long),
        }

    handle.tokenizer.side_effect = mock_tokenizer
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# New method preset parameters
# ---------------------------------------------------------------------------
|
|
|
|
class TestNewMethodPresets:
    """Check the feature flags stored in the ``METHODS`` preset table."""

    def test_basic_has_new_params(self):
        """``basic`` lists all four flags; the first two must be off."""
        preset = METHODS["basic"]
        for key in (
            "project_biases",
            "use_chat_template",
            "use_whitened_svd",
            "true_iterative_refinement",
        ):
            assert key in preset
        for key in ("project_biases", "use_chat_template"):
            assert preset[key] is False

    def test_advanced_has_new_params(self):
        """``advanced`` enables bias projection and chat templates only."""
        preset = METHODS["advanced"]
        expected = {
            "project_biases": True,
            "use_chat_template": True,
            "use_whitened_svd": False,
            "true_iterative_refinement": False,
        }
        for key, want in expected.items():
            assert preset[key] is want

    def test_aggressive_has_new_params(self):
        """``aggressive`` enables every new flag."""
        preset = METHODS["aggressive"]
        for key in (
            "project_biases",
            "use_chat_template",
            "use_whitened_svd",
            "true_iterative_refinement",
        ):
            assert preset[key] is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Pipeline initialization with new parameters
# ---------------------------------------------------------------------------
|
|
|
|
class TestNewPipelineInit:
    """Check which feature flags a freshly built pipeline exposes."""

    @staticmethod
    def _assert_flags(pipeline, **expected):
        """Assert each named attribute of *pipeline* is the given bool."""
        for attr, want in expected.items():
            assert getattr(pipeline, attr) is want

    def test_default_new_params(self):
        """No ``method`` argument: expect the advanced method defaults."""
        pipeline = AbliterationPipeline(model_name="test-model")
        # advanced method defaults
        self._assert_flags(
            pipeline,
            project_biases=True,
            use_chat_template=True,
            use_whitened_svd=False,
            true_iterative_refinement=False,
        )

    def test_basic_method_new_params(self):
        """``method="basic"`` leaves every new flag off."""
        pipeline = AbliterationPipeline(model_name="test-model", method="basic")
        self._assert_flags(
            pipeline,
            project_biases=False,
            use_chat_template=False,
            use_whitened_svd=False,
            true_iterative_refinement=False,
        )

    def test_aggressive_method_new_params(self):
        """``method="aggressive"`` turns every new flag on."""
        pipeline = AbliterationPipeline(model_name="test-model", method="aggressive")
        self._assert_flags(
            pipeline,
            project_biases=True,
            use_chat_template=True,
            use_whitened_svd=True,
            true_iterative_refinement=True,
        )

    def test_explicit_overrides_new_params(self):
        """Explicit keyword arguments win over the ``basic`` preset."""
        overrides = {
            "project_biases": True,
            "use_chat_template": True,
            "use_whitened_svd": True,
            "true_iterative_refinement": True,
        }
        pipeline = AbliterationPipeline(
            model_name="test-model",
            method="basic",
            **overrides,
        )
        self._assert_flags(pipeline, **overrides)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Bias projection
# ---------------------------------------------------------------------------
|
|
|
|
class TestBiasProjection:
    """Exercise ``AbliterationPipeline._project_bias`` on 4x4 linear layers."""

    @staticmethod
    def _wrap_linear(attr_name, *, bias):
        """Return a module holding one ``Linear(4, 4)`` under *attr_name*."""

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                setattr(self, attr_name, torch.nn.Linear(4, 4, bias=bias))

        return Wrapper()

    def test_project_bias_removes_component(self):
        """Bias projection should remove refusal direction component from bias."""
        module = self._wrap_linear("o_proj", bias=True)
        torch.manual_seed(42)
        module.o_proj.bias.data = torch.tensor([1.0, 2.0, 3.0, 4.0])

        # Unit vector along dim 0, shaped as a column.
        direction = torch.tensor([1.0, 0.0, 0.0, 0.0]).unsqueeze(-1)

        projected = AbliterationPipeline._project_bias(module, direction, ["o_proj"])
        assert projected == 1

        # The bias started with 1.0 along the direction; it should now be ~0.
        bias_after = module.o_proj.bias.data
        residual = (bias_after @ direction.squeeze()).item()
        assert abs(residual) < 1e-5

        # The remaining coordinates must be left as they were.
        for index, untouched in enumerate((2.0, 3.0, 4.0), start=1):
            assert abs(bias_after[index].item() - untouched) < 1e-5

    def test_project_bias_no_bias(self):
        """Should handle modules without bias gracefully."""
        module = self._wrap_linear("o_proj", bias=False)
        direction = torch.randn(4, 1)
        projected = AbliterationPipeline._project_bias(module, direction, ["o_proj"])
        assert projected == 0

    def test_project_bias_no_matching_module(self):
        """Should return 0 when no candidate names match."""
        module = self._wrap_linear("something", bias=True)
        direction = torch.randn(4, 1)
        projected = AbliterationPipeline._project_bias(module, direction, ["o_proj"])
        assert projected == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Chat template wrapping
# ---------------------------------------------------------------------------
|
|
|
|
class TestChatTemplate:
    """Exercise ``AbliterationPipeline._maybe_apply_chat_template``."""

    @staticmethod
    def _attach_tokenizer(pipeline, tokenizer):
        """Give *pipeline* a mock handle carrying *tokenizer* and mute logging."""
        handle = MagicMock()
        handle.tokenizer = tokenizer
        pipeline.handle = handle
        pipeline._on_log = lambda m: None

    def test_no_wrap_when_disabled(self):
        """Should not wrap prompts when use_chat_template is False."""
        pipeline = AbliterationPipeline(
            model_name="test-model",
            method="basic",
            use_chat_template=False,
        )
        raw_prompts = ["Hello", "World"]
        assert pipeline._maybe_apply_chat_template(raw_prompts) == raw_prompts

    def test_no_wrap_without_handle(self):
        """Should return raw prompts when handle is not set."""
        pipeline = AbliterationPipeline(
            model_name="test-model",
            use_chat_template=True,
        )
        raw_prompts = ["Hello"]
        assert pipeline._maybe_apply_chat_template(raw_prompts) == raw_prompts

    def test_wraps_with_template(self):
        """Should wrap prompts when tokenizer has apply_chat_template."""
        pipeline = AbliterationPipeline(
            model_name="test-model",
            use_chat_template=True,
        )

        def fake_template(messages, tokenize=False, add_generation_prompt=True):
            return f"<user>{messages[0]['content']}</user><assistant>"

        tokenizer = MagicMock()
        tokenizer.apply_chat_template = fake_template
        self._attach_tokenizer(pipeline, tokenizer)

        wrapped = pipeline._maybe_apply_chat_template(["Hello"])
        assert "<user>Hello</user>" in wrapped[0]

    def test_fallback_when_no_template(self):
        """Should fall back to raw prompts when template is not configured."""
        pipeline = AbliterationPipeline(
            model_name="test-model",
            use_chat_template=True,
        )
        tokenizer = MagicMock()
        # Every attempt to apply the template raises.
        tokenizer.apply_chat_template.side_effect = Exception("No template")
        self._attach_tokenizer(pipeline, tokenizer)

        unwrapped = pipeline._maybe_apply_chat_template(["Hello"])
        assert unwrapped == ["Hello"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Metadata includes new fields
# ---------------------------------------------------------------------------
|
|
|
|
class TestMetadata:
    """Check the JSON metadata written by ``AbliterationPipeline._rebirth``."""

    def test_rebirth_includes_new_config(self):
        """Metadata should include all new configuration parameters."""
        handle = _make_tiny_handle()
        pipeline = AbliterationPipeline(
            model_name="test-model",
            method="aggressive",
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]
        pipeline._quality_metrics = {"perplexity": 8.5, "coherence": 1.0}

        # Replace both save_pretrained methods with mocks so the real
        # serialisation code is never invoked.
        handle.model.save_pretrained = MagicMock()
        handle.tokenizer.save_pretrained = MagicMock()

        with tempfile.TemporaryDirectory() as tmp:
            pipeline.output_dir = Path(tmp) / "output"
            pipeline._rebirth()

            # Must be read before the temporary directory is removed.
            metadata_path = pipeline.output_dir / "abliteration_metadata.json"
            metadata = json.loads(metadata_path.read_text())

        cfg = metadata["method_config"]
        for key in (
            "project_biases",
            "use_chat_template",
            "use_whitened_svd",
            "true_iterative_refinement",
        ):
            assert key in cfg
        assert cfg["project_biases"] is True
        assert cfg["use_whitened_svd"] is True

        # Should have more references now
        assert len(metadata["references"]) >= 5
        assert any("OBLITERATUS" in r for r in metadata["references"])
|