Files
OBLITERATUS/obliteratus/evaluation/evaluator.py
T
2026-03-04 12:38:18 -08:00

128 lines
4.2 KiB
Python

"""Evaluator: runs a model on a dataset and computes metrics."""
from __future__ import annotations
import torch
from tqdm import tqdm
from obliteratus.models.loader import ModelHandle
class Evaluator:
"""Evaluate a model handle on a dataset, returning metric results.
Supports two modes:
- **perplexity** (default for causal_lm): feeds tokenized text and computes PPL.
- **classification**: runs forward pass, takes argmax, computes accuracy/F1.
"""
def __init__(
self,
handle: ModelHandle,
dataset,
metrics: list[str] | None = None,
batch_size: int = 8,
max_length: int = 512,
max_samples: int | None = None,
text_column: str = "text",
label_column: str = "label",
):
self.handle = handle
self.dataset = dataset
self.metrics = metrics or (
["perplexity"] if handle.task == "causal_lm" else ["accuracy", "f1"]
)
self.batch_size = batch_size
self.max_length = max_length
self.max_samples = max_samples
self.text_column = text_column
self.label_column = label_column
@torch.no_grad()
def evaluate(self) -> dict[str, float]:
"""Run evaluation and return a dict of metric_name -> score."""
if self.handle.task == "causal_lm":
return self._evaluate_causal_lm()
elif self.handle.task == "classification":
return self._evaluate_classification()
else:
raise ValueError(f"Unsupported task: {self.handle.task}")
def _evaluate_causal_lm(self) -> dict[str, float]:
model = self.handle.model
tokenizer = self.handle.tokenizer
device = next(model.parameters()).device
ds = self.dataset
if self.max_samples is not None:
ds = ds.select(range(min(self.max_samples, len(ds))))
total_loss = 0.0
total_tokens = 0
for i in tqdm(range(0, len(ds), self.batch_size), desc="Evaluating PPL"):
batch_texts = ds[i : i + self.batch_size][self.text_column]
encodings = tokenizer(
batch_texts,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
padding=True,
).to(device)
input_ids = encodings["input_ids"]
attention_mask = encodings["attention_mask"]
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
# Mask out padding tokens for loss computation
num_tokens = attention_mask[:, 1:].sum().item()
total_loss += outputs.loss.item() * num_tokens
total_tokens += num_tokens
import math
avg_loss = total_loss / max(total_tokens, 1)
return {"perplexity": math.exp(avg_loss)}
def _evaluate_classification(self) -> dict[str, float]:
from obliteratus.evaluation.metrics import accuracy as acc_fn
from obliteratus.evaluation.metrics import f1_score_metric as f1_fn
model = self.handle.model
tokenizer = self.handle.tokenizer
device = next(model.parameters()).device
ds = self.dataset
if self.max_samples is not None:
ds = ds.select(range(min(self.max_samples, len(ds))))
all_preds = []
all_labels = []
for i in tqdm(range(0, len(ds), self.batch_size), desc="Evaluating"):
batch = ds[i : i + self.batch_size]
texts = batch[self.text_column]
labels = batch[self.label_column]
encodings = tokenizer(
texts,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
padding=True,
).to(device)
outputs = model(**encodings)
preds = outputs.logits.argmax(dim=-1).cpu().tolist()
all_preds.extend(preds)
all_labels.extend(labels)
results = {}
if "accuracy" in self.metrics:
results["accuracy"] = acc_fn(all_preds, all_labels)
if "f1" in self.metrics:
results["f1"] = f1_fn(all_preds, all_labels)
return results