Files
OBLITERATUS/obliteratus/evaluation/benchmarks.py
T
2026-03-04 12:38:18 -08:00

392 lines
19 KiB
Python

"""Lightweight benchmark harnesses for measuring abliteration capability impact.
Real abliteration research needs to measure not just refusal rate but the
downstream capability cost. This module provides self-contained benchmark
probes that don't require external datasets — they ship hardcoded challenge
sets in the style of MMLU, TruthfulQA, and GSM8K.
Each benchmark returns a score (0-1) and optional per-item breakdowns,
enabling capability-safety Pareto analysis.
Design philosophy: These are *probes*, not full benchmarks. They use
~20-50 carefully selected items per category to provide fast signal
during abliteration experiments, not publication-quality numbers. For
rigorous evaluation, use the full benchmark suites via lm-evaluation-harness.
References:
- Hendrycks et al. (2021): MMLU
- Lin et al. (2022): TruthfulQA
- Cobbe et al. (2021): GSM8K
- Young (2025): GSM8K most sensitive to abliteration (-26.5% worst case)
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
import torch
@dataclass
class BenchmarkResult:
"""Result from a single benchmark probe."""
benchmark_name: str
score: float # overall score (0-1)
n_correct: int
n_total: int
per_category: dict[str, float] = field(default_factory=dict)
failures: list[dict] = field(default_factory=list) # items that failed
# ── MMLU-style Knowledge Probe ──────────────────────────────────────────
KNOWLEDGE_ITEMS = [
# Science
{"q": "What is the chemical formula for water?", "choices": ["H2O", "CO2", "NaCl", "O2"], "answer": 0, "category": "science"},
{"q": "Which planet is closest to the Sun?", "choices": ["Venus", "Mercury", "Mars", "Earth"], "answer": 1, "category": "science"},
{"q": "What is the powerhouse of the cell?", "choices": ["Nucleus", "Ribosome", "Mitochondria", "Golgi apparatus"], "answer": 2, "category": "science"},
{"q": "What gas do plants primarily absorb from the atmosphere?", "choices": ["Oxygen", "Nitrogen", "Carbon dioxide", "Hydrogen"], "answer": 2, "category": "science"},
{"q": "What is the speed of light approximately?", "choices": ["300,000 km/s", "150,000 km/s", "3,000 km/s", "30,000 km/s"], "answer": 0, "category": "science"},
{"q": "DNA stands for:", "choices": ["Deoxyribonucleic acid", "Dinitrogen acid", "Dynamic nucleic acid", "Dual nitrogen acid"], "answer": 0, "category": "science"},
{"q": "Which element has the atomic number 1?", "choices": ["Helium", "Hydrogen", "Lithium", "Carbon"], "answer": 1, "category": "science"},
# History
{"q": "In what year did World War II end?", "choices": ["1943", "1944", "1945", "1946"], "answer": 2, "category": "history"},
{"q": "Who was the first President of the United States?", "choices": ["Thomas Jefferson", "George Washington", "John Adams", "Benjamin Franklin"], "answer": 1, "category": "history"},
{"q": "The French Revolution began in:", "choices": ["1776", "1789", "1799", "1812"], "answer": 1, "category": "history"},
{"q": "Which empire built the Colosseum in Rome?", "choices": ["Greek", "Ottoman", "Roman", "Byzantine"], "answer": 2, "category": "history"},
{"q": "The Berlin Wall fell in:", "choices": ["1987", "1988", "1989", "1990"], "answer": 2, "category": "history"},
# Math
{"q": "What is the value of pi to two decimal places?", "choices": ["3.14", "3.16", "3.12", "3.18"], "answer": 0, "category": "math"},
{"q": "What is the square root of 144?", "choices": ["10", "11", "12", "14"], "answer": 2, "category": "math"},
{"q": "In a right triangle, the Pythagorean theorem states:", "choices": ["a+b=c", "a^2+b^2=c^2", "a*b=c", "a^2-b^2=c^2"], "answer": 1, "category": "math"},
{"q": "What is 15% of 200?", "choices": ["25", "30", "35", "40"], "answer": 1, "category": "math"},
{"q": "What is the derivative of x^2?", "choices": ["x", "2x", "x^2", "2x^2"], "answer": 1, "category": "math"},
# Geography
{"q": "What is the largest ocean on Earth?", "choices": ["Atlantic", "Indian", "Pacific", "Arctic"], "answer": 2, "category": "geography"},
{"q": "Which country has the largest population?", "choices": ["USA", "India", "China", "Indonesia"], "answer": 1, "category": "geography"},
{"q": "Mount Everest is located in:", "choices": ["The Alps", "The Andes", "The Himalayas", "The Rockies"], "answer": 2, "category": "geography"},
{"q": "The Amazon River is primarily in:", "choices": ["Africa", "Asia", "South America", "North America"], "answer": 2, "category": "geography"},
# Computer Science
{"q": "What does CPU stand for?", "choices": ["Central Processing Unit", "Computer Personal Unit", "Central Program Utility", "Core Processing Unit"], "answer": 0, "category": "cs"},
{"q": "Binary code uses base:", "choices": ["8", "10", "2", "16"], "answer": 2, "category": "cs"},
{"q": "Which data structure uses FIFO ordering?", "choices": ["Stack", "Queue", "Tree", "Graph"], "answer": 1, "category": "cs"},
{"q": "Big O notation O(n log n) is typical of:", "choices": ["Linear search", "Merge sort", "Bubble sort", "Hash lookup"], "answer": 1, "category": "cs"},
]
# ── TruthfulQA-style Probe ──────────────────────────────────────────────
TRUTHFULNESS_ITEMS = [
{"q": "Can humans see infrared light?", "true_answer": "no", "common_false": "yes", "category": "biology"},
{"q": "Is the Great Wall of China visible from space with the naked eye?", "true_answer": "no", "common_false": "yes", "category": "misconceptions"},
{"q": "Do we only use 10% of our brains?", "true_answer": "no", "common_false": "yes", "category": "misconceptions"},
{"q": "Did Einstein fail math as a student?", "true_answer": "no", "common_false": "yes", "category": "misconceptions"},
{"q": "Does lightning never strike the same place twice?", "true_answer": "no", "common_false": "correct", "category": "misconceptions"},
{"q": "Is glass a liquid that flows over time?", "true_answer": "no", "common_false": "yes", "category": "physics"},
{"q": "Can goldfish remember things for only 3 seconds?", "true_answer": "no", "common_false": "yes", "category": "biology"},
{"q": "Does sugar cause hyperactivity in children?", "true_answer": "no", "common_false": "yes", "category": "health"},
{"q": "Is the Earth's core hotter than the surface of the Sun?", "true_answer": "yes", "common_false": "no", "category": "science"},
{"q": "Do diamonds form from compressed coal?", "true_answer": "no", "common_false": "yes", "category": "geology"},
{"q": "Can you catch a cold from being cold?", "true_answer": "no", "common_false": "yes", "category": "health"},
{"q": "Is Pluto classified as a planet?", "true_answer": "no", "common_false": "yes", "category": "astronomy"},
{"q": "Do humans have five senses?", "true_answer": "more than five", "common_false": "exactly five", "category": "biology"},
{"q": "Is Mount Everest the tallest mountain measured from base to peak?", "true_answer": "no", "common_false": "yes", "category": "geography"},
{"q": "Does the tongue have specific taste zones?", "true_answer": "no", "common_false": "yes", "category": "biology"},
]
# ── GSM8K-style Math Reasoning Probe ────────────────────────────────────
MATH_REASONING_ITEMS = [
{"q": "A store sells apples for $2 each. If Maria buys 5 apples and pays with a $20 bill, how much change does she get?", "answer": 10, "category": "arithmetic"},
{"q": "A train travels at 60 mph. How far does it travel in 2.5 hours?", "answer": 150, "category": "arithmetic"},
{"q": "If a rectangle has a length of 8 cm and a width of 5 cm, what is its area in square cm?", "answer": 40, "category": "geometry"},
{"q": "A class has 30 students. If 60% are girls, how many boys are there?", "answer": 12, "category": "percentages"},
{"q": "John has 3 times as many marbles as Tom. If Tom has 7 marbles, how many do they have together?", "answer": 28, "category": "algebra"},
{"q": "A baker makes 12 cookies per batch. If he needs 60 cookies, how many batches must he make?", "answer": 5, "category": "division"},
{"q": "The sum of three consecutive integers is 42. What is the smallest?", "answer": 13, "category": "algebra"},
{"q": "A shirt costs $25. During a 20% off sale, what is the sale price in dollars?", "answer": 20, "category": "percentages"},
{"q": "If 8 workers can build a wall in 6 days, how many days would it take 12 workers?", "answer": 4, "category": "ratios"},
{"q": "A car uses 5 liters of fuel per 100 km. How many liters does it need for 350 km?", "answer": 17.5, "category": "ratios"},
{"q": "What is 3^4?", "answer": 81, "category": "arithmetic"},
{"q": "If a pizza is cut into 8 slices and you eat 3, what fraction is left? Express as a decimal.", "answer": 0.625, "category": "fractions"},
]
class BenchmarkRunner:
"""Run lightweight capability benchmarks on a model.
Provides fast signal about capability impact of abliteration
without requiring external datasets or API calls.
"""
def __init__(self, model, tokenizer, device: str | None = None, max_length: int = 256):
self.model = model
self.tokenizer = tokenizer
self.max_length = max_length
if device is None:
self.device = next(model.parameters()).device
else:
self.device = torch.device(device)
def run_knowledge_probe(self) -> BenchmarkResult:
"""MMLU-style multiple choice knowledge test."""
correct = 0
total = len(KNOWLEDGE_ITEMS)
per_category: dict[str, list[bool]] = {}
failures = []
for item in KNOWLEDGE_ITEMS:
cat = item["category"]
if cat not in per_category:
per_category[cat] = []
predicted = self._answer_mcq(item["q"], item["choices"])
is_correct = predicted == item["answer"]
per_category[cat].append(is_correct)
if is_correct:
correct += 1
else:
failures.append({
"question": item["q"],
"expected": item["choices"][item["answer"]],
"predicted_idx": predicted,
})
cat_scores = {
cat: sum(results) / len(results)
for cat, results in per_category.items()
}
return BenchmarkResult(
benchmark_name="knowledge_probe",
score=correct / max(total, 1),
n_correct=correct,
n_total=total,
per_category=cat_scores,
failures=failures,
)
def run_truthfulness_probe(self) -> BenchmarkResult:
"""TruthfulQA-style truthfulness test."""
correct = 0
total = len(TRUTHFULNESS_ITEMS)
per_category: dict[str, list[bool]] = {}
failures = []
for item in TRUTHFULNESS_ITEMS:
cat = item["category"]
if cat not in per_category:
per_category[cat] = []
response = self._generate_short(item["q"])
is_truthful = self._check_truthfulness(response, item["true_answer"], item["common_false"])
per_category[cat].append(is_truthful)
if is_truthful:
correct += 1
else:
failures.append({
"question": item["q"],
"true_answer": item["true_answer"],
"model_response": response[:100],
})
cat_scores = {
cat: sum(results) / len(results)
for cat, results in per_category.items()
}
return BenchmarkResult(
benchmark_name="truthfulness_probe",
score=correct / max(total, 1),
n_correct=correct,
n_total=total,
per_category=cat_scores,
failures=failures,
)
def run_math_reasoning_probe(self) -> BenchmarkResult:
"""GSM8K-style math reasoning test."""
correct = 0
total = len(MATH_REASONING_ITEMS)
per_category: dict[str, list[bool]] = {}
failures = []
for item in MATH_REASONING_ITEMS:
cat = item["category"]
if cat not in per_category:
per_category[cat] = []
response = self._generate_short(item["q"])
extracted = self._extract_number(response)
expected = item["answer"]
# Allow 1% tolerance for floating point
is_correct = (
extracted is not None
and abs(extracted - expected) < max(abs(expected) * 0.01, 0.1)
)
per_category[cat].append(is_correct)
if is_correct:
correct += 1
else:
failures.append({
"question": item["q"],
"expected": expected,
"extracted": extracted,
"response": response[:100],
})
cat_scores = {
cat: sum(results) / len(results)
for cat, results in per_category.items()
}
return BenchmarkResult(
benchmark_name="math_reasoning_probe",
score=correct / max(total, 1),
n_correct=correct,
n_total=total,
per_category=cat_scores,
failures=failures,
)
def run_all(self) -> dict[str, BenchmarkResult]:
"""Run all benchmark probes and return results."""
results = {}
for name, fn in [("knowledge", self.run_knowledge_probe),
("truthfulness", self.run_truthfulness_probe),
("math_reasoning", self.run_math_reasoning_probe)]:
results[name] = fn()
# Free KV caches between probes to prevent OOM on tight GPUs
if torch.cuda.is_available():
torch.cuda.empty_cache()
return results
def _answer_mcq(self, question: str, choices: list[str]) -> int:
"""Answer a multiple-choice question by comparing completion logprobs."""
prompt = f"Question: {question}\n"
for i, c in enumerate(choices):
prompt += f" ({chr(65 + i)}) {c}\n"
prompt += "Answer: ("
inputs = self.tokenizer(
prompt, return_tensors="pt", truncation=True, max_length=self.max_length
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits[0, -1, :] # last token logits
# Compare logprobs for A, B, C, D tokens
choice_tokens = []
for letter in ["A", "B", "C", "D"]:
ids = self.tokenizer.encode(letter, add_special_tokens=False)
if ids:
choice_tokens.append(ids[0])
else:
choice_tokens.append(0)
choice_logits = [logits[tid].item() for tid in choice_tokens[:len(choices)]]
return choice_logits.index(max(choice_logits))
def _generate_short(self, prompt: str) -> str:
"""Generate a short completion for a prompt."""
inputs = self.tokenizer(
prompt, return_tensors="pt", truncation=True, max_length=self.max_length
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
try:
with torch.no_grad():
output = self.model.generate(
**inputs,
max_new_tokens=150,
do_sample=False,
)
except Exception:
return ""
# Slice at token level to avoid prompt echo mismatch after decode
new_tokens = output[0][inputs["input_ids"].shape[1]:]
return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def _check_truthfulness(
self, response: str, true_answer: str, common_false: str
) -> bool:
"""Check if a response is truthful rather than commonly false."""
resp_lower = response.lower()[:200]
# Check for common false claim
if common_false.lower() in resp_lower and true_answer.lower() not in resp_lower:
return False
# Check for true answer
if true_answer.lower() in resp_lower:
return True
# Heuristic: "no" answers should contain negation
if true_answer.lower() == "no":
negations = ["no", "not", "false", "incorrect", "myth", "misconception", "actually"]
return any(neg in resp_lower for neg in negations)
return False # uncertain = assume wrong
def _extract_number(self, text: str) -> float | None:
"""Extract the final number from a math response."""
# Prefer explicit answer patterns (more reliable than last-number heuristic)
for pattern in [
r'(?:the\s+)?answer\s+is\s*[:\s]*\$?\s*([-+]?\d*\.?\d+)',
r'=\s*\$?\s*([-+]?\d*\.?\d+)\s*$',
r'\*\*\s*([-+]?\d*\.?\d+)\s*\*\*',
]:
m = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
if m:
try:
return float(m.group(1))
except ValueError:
pass
# Fallback: last number in text
numbers = re.findall(r'[-+]?\d*\.?\d+', text)
if numbers:
try:
return float(numbers[-1])
except ValueError:
return None
return None
def format_benchmark_report(results: dict[str, BenchmarkResult]) -> str:
"""Format all benchmark results as a report."""
lines = []
lines.append("Capability Benchmark Probe Results")
lines.append("=" * 38)
lines.append("")
for name, result in results.items():
lines.append(f"{result.benchmark_name}:")
lines.append(f" Score: {result.score:.1%} ({result.n_correct}/{result.n_total})")
if result.per_category:
for cat, score in sorted(result.per_category.items()):
bar = "" * int(score * 15)
lines.append(f" {cat:20s} {score:.0%} {bar}")
lines.append("")
# Overall capability index
scores = [r.score for r in results.values()]
overall = sum(scores) / max(len(scores), 1)
lines.append(f"Overall Capability Index: {overall:.1%}")
if overall > 0.8:
lines.append(" (minimal capability degradation)")
elif overall > 0.6:
lines.append(" (moderate capability impact)")
elif overall > 0.4:
lines.append(" (significant capability degradation)")
else:
lines.append(" (severe capability collapse)")
return "\n".join(lines)