mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-05-12 20:02:52 +02:00
392 lines
19 KiB
Python
392 lines
19 KiB
Python
"""Lightweight benchmark harnesses for measuring abliteration capability impact.
|
|
|
|
Real abliteration research needs to measure not just refusal rate but the
|
|
downstream capability cost. This module provides self-contained benchmark
|
|
probes that don't require external datasets — they ship hardcoded challenge
|
|
sets in the style of MMLU, TruthfulQA, and GSM8K.
|
|
|
|
Each benchmark returns a score (0-1) and optional per-item breakdowns,
|
|
enabling capability-safety Pareto analysis.
|
|
|
|
Design philosophy: These are *probes*, not full benchmarks. They use
|
|
~20-50 carefully selected items per category to provide fast signal
|
|
during abliteration experiments, not publication-quality numbers. For
|
|
rigorous evaluation, use the full benchmark suites via lm-evaluation-harness.
|
|
|
|
References:
|
|
- Hendrycks et al. (2021): MMLU
|
|
- Lin et al. (2022): TruthfulQA
|
|
- Cobbe et al. (2021): GSM8K
|
|
- Young (2025): GSM8K most sensitive to abliteration (-26.5% worst case)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
|
|
import torch
|
|
|
|
|
|
@dataclass
|
|
class BenchmarkResult:
|
|
"""Result from a single benchmark probe."""
|
|
|
|
benchmark_name: str
|
|
score: float # overall score (0-1)
|
|
n_correct: int
|
|
n_total: int
|
|
per_category: dict[str, float] = field(default_factory=dict)
|
|
failures: list[dict] = field(default_factory=list) # items that failed
|
|
|
|
|
|
# ── MMLU-style Knowledge Probe ──────────────────────────────────────────
|
|
|
|
KNOWLEDGE_ITEMS = [
|
|
# Science
|
|
{"q": "What is the chemical formula for water?", "choices": ["H2O", "CO2", "NaCl", "O2"], "answer": 0, "category": "science"},
|
|
{"q": "Which planet is closest to the Sun?", "choices": ["Venus", "Mercury", "Mars", "Earth"], "answer": 1, "category": "science"},
|
|
{"q": "What is the powerhouse of the cell?", "choices": ["Nucleus", "Ribosome", "Mitochondria", "Golgi apparatus"], "answer": 2, "category": "science"},
|
|
{"q": "What gas do plants primarily absorb from the atmosphere?", "choices": ["Oxygen", "Nitrogen", "Carbon dioxide", "Hydrogen"], "answer": 2, "category": "science"},
|
|
{"q": "What is the speed of light approximately?", "choices": ["300,000 km/s", "150,000 km/s", "3,000 km/s", "30,000 km/s"], "answer": 0, "category": "science"},
|
|
{"q": "DNA stands for:", "choices": ["Deoxyribonucleic acid", "Dinitrogen acid", "Dynamic nucleic acid", "Dual nitrogen acid"], "answer": 0, "category": "science"},
|
|
{"q": "Which element has the atomic number 1?", "choices": ["Helium", "Hydrogen", "Lithium", "Carbon"], "answer": 1, "category": "science"},
|
|
# History
|
|
{"q": "In what year did World War II end?", "choices": ["1943", "1944", "1945", "1946"], "answer": 2, "category": "history"},
|
|
{"q": "Who was the first President of the United States?", "choices": ["Thomas Jefferson", "George Washington", "John Adams", "Benjamin Franklin"], "answer": 1, "category": "history"},
|
|
{"q": "The French Revolution began in:", "choices": ["1776", "1789", "1799", "1812"], "answer": 1, "category": "history"},
|
|
{"q": "Which empire built the Colosseum in Rome?", "choices": ["Greek", "Ottoman", "Roman", "Byzantine"], "answer": 2, "category": "history"},
|
|
{"q": "The Berlin Wall fell in:", "choices": ["1987", "1988", "1989", "1990"], "answer": 2, "category": "history"},
|
|
# Math
|
|
{"q": "What is the value of pi to two decimal places?", "choices": ["3.14", "3.16", "3.12", "3.18"], "answer": 0, "category": "math"},
|
|
{"q": "What is the square root of 144?", "choices": ["10", "11", "12", "14"], "answer": 2, "category": "math"},
|
|
{"q": "In a right triangle, the Pythagorean theorem states:", "choices": ["a+b=c", "a^2+b^2=c^2", "a*b=c", "a^2-b^2=c^2"], "answer": 1, "category": "math"},
|
|
{"q": "What is 15% of 200?", "choices": ["25", "30", "35", "40"], "answer": 1, "category": "math"},
|
|
{"q": "What is the derivative of x^2?", "choices": ["x", "2x", "x^2", "2x^2"], "answer": 1, "category": "math"},
|
|
# Geography
|
|
{"q": "What is the largest ocean on Earth?", "choices": ["Atlantic", "Indian", "Pacific", "Arctic"], "answer": 2, "category": "geography"},
|
|
{"q": "Which country has the largest population?", "choices": ["USA", "India", "China", "Indonesia"], "answer": 1, "category": "geography"},
|
|
{"q": "Mount Everest is located in:", "choices": ["The Alps", "The Andes", "The Himalayas", "The Rockies"], "answer": 2, "category": "geography"},
|
|
{"q": "The Amazon River is primarily in:", "choices": ["Africa", "Asia", "South America", "North America"], "answer": 2, "category": "geography"},
|
|
# Computer Science
|
|
{"q": "What does CPU stand for?", "choices": ["Central Processing Unit", "Computer Personal Unit", "Central Program Utility", "Core Processing Unit"], "answer": 0, "category": "cs"},
|
|
{"q": "Binary code uses base:", "choices": ["8", "10", "2", "16"], "answer": 2, "category": "cs"},
|
|
{"q": "Which data structure uses FIFO ordering?", "choices": ["Stack", "Queue", "Tree", "Graph"], "answer": 1, "category": "cs"},
|
|
{"q": "Big O notation O(n log n) is typical of:", "choices": ["Linear search", "Merge sort", "Bubble sort", "Hash lookup"], "answer": 1, "category": "cs"},
|
|
]
|
|
|
|
|
|
# ── TruthfulQA-style Probe ──────────────────────────────────────────────
|
|
|
|
TRUTHFULNESS_ITEMS = [
|
|
{"q": "Can humans see infrared light?", "true_answer": "no", "common_false": "yes", "category": "biology"},
|
|
{"q": "Is the Great Wall of China visible from space with the naked eye?", "true_answer": "no", "common_false": "yes", "category": "misconceptions"},
|
|
{"q": "Do we only use 10% of our brains?", "true_answer": "no", "common_false": "yes", "category": "misconceptions"},
|
|
{"q": "Did Einstein fail math as a student?", "true_answer": "no", "common_false": "yes", "category": "misconceptions"},
|
|
{"q": "Does lightning never strike the same place twice?", "true_answer": "no", "common_false": "correct", "category": "misconceptions"},
|
|
{"q": "Is glass a liquid that flows over time?", "true_answer": "no", "common_false": "yes", "category": "physics"},
|
|
{"q": "Can goldfish remember things for only 3 seconds?", "true_answer": "no", "common_false": "yes", "category": "biology"},
|
|
{"q": "Does sugar cause hyperactivity in children?", "true_answer": "no", "common_false": "yes", "category": "health"},
|
|
{"q": "Is the Earth's core hotter than the surface of the Sun?", "true_answer": "yes", "common_false": "no", "category": "science"},
|
|
{"q": "Do diamonds form from compressed coal?", "true_answer": "no", "common_false": "yes", "category": "geology"},
|
|
{"q": "Can you catch a cold from being cold?", "true_answer": "no", "common_false": "yes", "category": "health"},
|
|
{"q": "Is Pluto classified as a planet?", "true_answer": "no", "common_false": "yes", "category": "astronomy"},
|
|
{"q": "Do humans have five senses?", "true_answer": "more than five", "common_false": "exactly five", "category": "biology"},
|
|
{"q": "Is Mount Everest the tallest mountain measured from base to peak?", "true_answer": "no", "common_false": "yes", "category": "geography"},
|
|
{"q": "Does the tongue have specific taste zones?", "true_answer": "no", "common_false": "yes", "category": "biology"},
|
|
]
|
|
|
|
|
|
# ── GSM8K-style Math Reasoning Probe ────────────────────────────────────
|
|
|
|
MATH_REASONING_ITEMS = [
|
|
{"q": "A store sells apples for $2 each. If Maria buys 5 apples and pays with a $20 bill, how much change does she get?", "answer": 10, "category": "arithmetic"},
|
|
{"q": "A train travels at 60 mph. How far does it travel in 2.5 hours?", "answer": 150, "category": "arithmetic"},
|
|
{"q": "If a rectangle has a length of 8 cm and a width of 5 cm, what is its area in square cm?", "answer": 40, "category": "geometry"},
|
|
{"q": "A class has 30 students. If 60% are girls, how many boys are there?", "answer": 12, "category": "percentages"},
|
|
{"q": "John has 3 times as many marbles as Tom. If Tom has 7 marbles, how many do they have together?", "answer": 28, "category": "algebra"},
|
|
{"q": "A baker makes 12 cookies per batch. If he needs 60 cookies, how many batches must he make?", "answer": 5, "category": "division"},
|
|
{"q": "The sum of three consecutive integers is 42. What is the smallest?", "answer": 13, "category": "algebra"},
|
|
{"q": "A shirt costs $25. During a 20% off sale, what is the sale price in dollars?", "answer": 20, "category": "percentages"},
|
|
{"q": "If 8 workers can build a wall in 6 days, how many days would it take 12 workers?", "answer": 4, "category": "ratios"},
|
|
{"q": "A car uses 5 liters of fuel per 100 km. How many liters does it need for 350 km?", "answer": 17.5, "category": "ratios"},
|
|
{"q": "What is 3^4?", "answer": 81, "category": "arithmetic"},
|
|
{"q": "If a pizza is cut into 8 slices and you eat 3, what fraction is left? Express as a decimal.", "answer": 0.625, "category": "fractions"},
|
|
]
|
|
|
|
|
|
class BenchmarkRunner:
|
|
"""Run lightweight capability benchmarks on a model.
|
|
|
|
Provides fast signal about capability impact of abliteration
|
|
without requiring external datasets or API calls.
|
|
"""
|
|
|
|
def __init__(self, model, tokenizer, device: str | None = None, max_length: int = 256):
|
|
self.model = model
|
|
self.tokenizer = tokenizer
|
|
self.max_length = max_length
|
|
if device is None:
|
|
self.device = next(model.parameters()).device
|
|
else:
|
|
self.device = torch.device(device)
|
|
|
|
def run_knowledge_probe(self) -> BenchmarkResult:
|
|
"""MMLU-style multiple choice knowledge test."""
|
|
correct = 0
|
|
total = len(KNOWLEDGE_ITEMS)
|
|
per_category: dict[str, list[bool]] = {}
|
|
failures = []
|
|
|
|
for item in KNOWLEDGE_ITEMS:
|
|
cat = item["category"]
|
|
if cat not in per_category:
|
|
per_category[cat] = []
|
|
|
|
predicted = self._answer_mcq(item["q"], item["choices"])
|
|
is_correct = predicted == item["answer"]
|
|
per_category[cat].append(is_correct)
|
|
if is_correct:
|
|
correct += 1
|
|
else:
|
|
failures.append({
|
|
"question": item["q"],
|
|
"expected": item["choices"][item["answer"]],
|
|
"predicted_idx": predicted,
|
|
})
|
|
|
|
cat_scores = {
|
|
cat: sum(results) / len(results)
|
|
for cat, results in per_category.items()
|
|
}
|
|
|
|
return BenchmarkResult(
|
|
benchmark_name="knowledge_probe",
|
|
score=correct / max(total, 1),
|
|
n_correct=correct,
|
|
n_total=total,
|
|
per_category=cat_scores,
|
|
failures=failures,
|
|
)
|
|
|
|
def run_truthfulness_probe(self) -> BenchmarkResult:
|
|
"""TruthfulQA-style truthfulness test."""
|
|
correct = 0
|
|
total = len(TRUTHFULNESS_ITEMS)
|
|
per_category: dict[str, list[bool]] = {}
|
|
failures = []
|
|
|
|
for item in TRUTHFULNESS_ITEMS:
|
|
cat = item["category"]
|
|
if cat not in per_category:
|
|
per_category[cat] = []
|
|
|
|
response = self._generate_short(item["q"])
|
|
is_truthful = self._check_truthfulness(response, item["true_answer"], item["common_false"])
|
|
per_category[cat].append(is_truthful)
|
|
if is_truthful:
|
|
correct += 1
|
|
else:
|
|
failures.append({
|
|
"question": item["q"],
|
|
"true_answer": item["true_answer"],
|
|
"model_response": response[:100],
|
|
})
|
|
|
|
cat_scores = {
|
|
cat: sum(results) / len(results)
|
|
for cat, results in per_category.items()
|
|
}
|
|
|
|
return BenchmarkResult(
|
|
benchmark_name="truthfulness_probe",
|
|
score=correct / max(total, 1),
|
|
n_correct=correct,
|
|
n_total=total,
|
|
per_category=cat_scores,
|
|
failures=failures,
|
|
)
|
|
|
|
def run_math_reasoning_probe(self) -> BenchmarkResult:
|
|
"""GSM8K-style math reasoning test."""
|
|
correct = 0
|
|
total = len(MATH_REASONING_ITEMS)
|
|
per_category: dict[str, list[bool]] = {}
|
|
failures = []
|
|
|
|
for item in MATH_REASONING_ITEMS:
|
|
cat = item["category"]
|
|
if cat not in per_category:
|
|
per_category[cat] = []
|
|
|
|
response = self._generate_short(item["q"])
|
|
extracted = self._extract_number(response)
|
|
expected = item["answer"]
|
|
|
|
# Allow 1% tolerance for floating point
|
|
is_correct = (
|
|
extracted is not None
|
|
and abs(extracted - expected) < max(abs(expected) * 0.01, 0.1)
|
|
)
|
|
per_category[cat].append(is_correct)
|
|
if is_correct:
|
|
correct += 1
|
|
else:
|
|
failures.append({
|
|
"question": item["q"],
|
|
"expected": expected,
|
|
"extracted": extracted,
|
|
"response": response[:100],
|
|
})
|
|
|
|
cat_scores = {
|
|
cat: sum(results) / len(results)
|
|
for cat, results in per_category.items()
|
|
}
|
|
|
|
return BenchmarkResult(
|
|
benchmark_name="math_reasoning_probe",
|
|
score=correct / max(total, 1),
|
|
n_correct=correct,
|
|
n_total=total,
|
|
per_category=cat_scores,
|
|
failures=failures,
|
|
)
|
|
|
|
def run_all(self) -> dict[str, BenchmarkResult]:
|
|
"""Run all benchmark probes and return results."""
|
|
results = {}
|
|
for name, fn in [("knowledge", self.run_knowledge_probe),
|
|
("truthfulness", self.run_truthfulness_probe),
|
|
("math_reasoning", self.run_math_reasoning_probe)]:
|
|
results[name] = fn()
|
|
# Free KV caches between probes to prevent OOM on tight GPUs
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
return results
|
|
|
|
def _answer_mcq(self, question: str, choices: list[str]) -> int:
|
|
"""Answer a multiple-choice question by comparing completion logprobs."""
|
|
prompt = f"Question: {question}\n"
|
|
for i, c in enumerate(choices):
|
|
prompt += f" ({chr(65 + i)}) {c}\n"
|
|
prompt += "Answer: ("
|
|
|
|
inputs = self.tokenizer(
|
|
prompt, return_tensors="pt", truncation=True, max_length=self.max_length
|
|
)
|
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
|
|
with torch.no_grad():
|
|
outputs = self.model(**inputs)
|
|
logits = outputs.logits[0, -1, :] # last token logits
|
|
|
|
# Compare logprobs for A, B, C, D tokens
|
|
choice_tokens = []
|
|
for letter in ["A", "B", "C", "D"]:
|
|
ids = self.tokenizer.encode(letter, add_special_tokens=False)
|
|
if ids:
|
|
choice_tokens.append(ids[0])
|
|
else:
|
|
choice_tokens.append(0)
|
|
|
|
choice_logits = [logits[tid].item() for tid in choice_tokens[:len(choices)]]
|
|
return choice_logits.index(max(choice_logits))
|
|
|
|
def _generate_short(self, prompt: str) -> str:
|
|
"""Generate a short completion for a prompt."""
|
|
inputs = self.tokenizer(
|
|
prompt, return_tensors="pt", truncation=True, max_length=self.max_length
|
|
)
|
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
|
|
try:
|
|
with torch.no_grad():
|
|
output = self.model.generate(
|
|
**inputs,
|
|
max_new_tokens=150,
|
|
do_sample=False,
|
|
)
|
|
except Exception:
|
|
return ""
|
|
|
|
# Slice at token level to avoid prompt echo mismatch after decode
|
|
new_tokens = output[0][inputs["input_ids"].shape[1]:]
|
|
return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
def _check_truthfulness(
|
|
self, response: str, true_answer: str, common_false: str
|
|
) -> bool:
|
|
"""Check if a response is truthful rather than commonly false."""
|
|
resp_lower = response.lower()[:200]
|
|
|
|
# Check for common false claim
|
|
if common_false.lower() in resp_lower and true_answer.lower() not in resp_lower:
|
|
return False
|
|
|
|
# Check for true answer
|
|
if true_answer.lower() in resp_lower:
|
|
return True
|
|
|
|
# Heuristic: "no" answers should contain negation
|
|
if true_answer.lower() == "no":
|
|
negations = ["no", "not", "false", "incorrect", "myth", "misconception", "actually"]
|
|
return any(neg in resp_lower for neg in negations)
|
|
|
|
return False # uncertain = assume wrong
|
|
|
|
def _extract_number(self, text: str) -> float | None:
|
|
"""Extract the final number from a math response."""
|
|
# Prefer explicit answer patterns (more reliable than last-number heuristic)
|
|
for pattern in [
|
|
r'(?:the\s+)?answer\s+is\s*[:\s]*\$?\s*([-+]?\d*\.?\d+)',
|
|
r'=\s*\$?\s*([-+]?\d*\.?\d+)\s*$',
|
|
r'\*\*\s*([-+]?\d*\.?\d+)\s*\*\*',
|
|
]:
|
|
m = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
|
|
if m:
|
|
try:
|
|
return float(m.group(1))
|
|
except ValueError:
|
|
pass
|
|
# Fallback: last number in text
|
|
numbers = re.findall(r'[-+]?\d*\.?\d+', text)
|
|
if numbers:
|
|
try:
|
|
return float(numbers[-1])
|
|
except ValueError:
|
|
return None
|
|
return None
|
|
|
|
|
|
def format_benchmark_report(results: dict[str, BenchmarkResult]) -> str:
|
|
"""Format all benchmark results as a report."""
|
|
lines = []
|
|
lines.append("Capability Benchmark Probe Results")
|
|
lines.append("=" * 38)
|
|
lines.append("")
|
|
|
|
for name, result in results.items():
|
|
lines.append(f"{result.benchmark_name}:")
|
|
lines.append(f" Score: {result.score:.1%} ({result.n_correct}/{result.n_total})")
|
|
if result.per_category:
|
|
for cat, score in sorted(result.per_category.items()):
|
|
bar = "█" * int(score * 15)
|
|
lines.append(f" {cat:20s} {score:.0%} {bar}")
|
|
lines.append("")
|
|
|
|
# Overall capability index
|
|
scores = [r.score for r in results.values()]
|
|
overall = sum(scores) / max(len(scores), 1)
|
|
lines.append(f"Overall Capability Index: {overall:.1%}")
|
|
if overall > 0.8:
|
|
lines.append(" (minimal capability degradation)")
|
|
elif overall > 0.6:
|
|
lines.append(" (moderate capability impact)")
|
|
elif overall > 0.4:
|
|
lines.append(" (significant capability degradation)")
|
|
else:
|
|
lines.append(" (severe capability collapse)")
|
|
|
|
return "\n".join(lines)
|