mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-04-24 04:06:06 +02:00
61 lines
1.7 KiB
Python
61 lines
1.7 KiB
Python
"""Tests for evaluation metrics."""
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import torch
|
|
|
|
from obliteratus.evaluation.metrics import accuracy, f1_score_metric, perplexity
|
|
|
|
|
|
class TestPerplexity:
    """Sanity checks for the ``perplexity`` metric."""

    def test_perfect_prediction(self):
        """Logits that single out the next label should score below 2.0."""
        batch_size, seq_len, vocab_size = 1, 5, 10

        labels = torch.tensor([[0, 1, 2, 3, 4]])
        logits = torch.full((batch_size, seq_len, vocab_size), -100.0)

        # Position ``step`` is made to favour the label at ``step + 1``;
        # the final position keeps its uniform -100.0 row.
        for step, next_label in enumerate(labels[0, 1:]):
            logits[0, step, next_label] = 100.0

        ppl = perplexity(logits, labels)

        assert ppl < 2.0, f"Expected near-1 perplexity, got {ppl}"

    def test_random_prediction_higher(self):
        """Seeded random logits and labels should score above 10."""
        batch_size, seq_len, vocab_size = 2, 20, 100

        # Seed first, then draw logits before labels so the random stream
        # is consumed in a fixed order.
        torch.manual_seed(42)
        logits = torch.randn((batch_size, seq_len, vocab_size))
        labels = torch.randint(
            low=0,
            high=vocab_size,
            size=(batch_size, seq_len),
        )

        ppl = perplexity(logits, labels)

        assert ppl > 10, f"Random logits should yield high perplexity, got {ppl}"
|
|
|
|
|
|
class TestAccuracy:
    """Checks ``accuracy`` on fully matching, disjoint, partial and empty inputs."""

    def test_perfect(self):
        """Identical sequences score 1.0."""
        score = accuracy([1, 2, 3], [1, 2, 3])
        assert score == 1.0

    def test_zero(self):
        """Sequences with no matching position score 0.0."""
        score = accuracy([1, 2, 3], [4, 5, 6])
        assert score == 0.0

    def test_partial(self):
        """Two matching positions out of four score 0.5."""
        score = accuracy([1, 2, 3, 4], [1, 2, 0, 0])
        assert score == 0.5

    def test_empty(self):
        """Two empty sequences score 0.0."""
        score = accuracy([], [])
        assert score == 0.0
|
|
|
|
|
|
class TestF1:
    """Checks ``f1_score_metric`` at its two extremes."""

    def test_perfect(self):
        """Identical binary label lists score 1.0."""
        score = f1_score_metric([0, 1, 0, 1], [0, 1, 0, 1])
        assert score == 1.0

    def test_zero(self):
        """All-zero against all-one label lists score 0.0."""
        assert f1_score_metric([0, 0, 0, 0], [1, 1, 1, 1]) == 0.0
|