mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-05-31 11:29:35 +02:00
128 lines
4.2 KiB
Python
128 lines
4.2 KiB
Python
"""Evaluator: runs a model on a dataset and computes metrics."""
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from obliteratus.models.loader import ModelHandle
|
|
|
|
|
|
class Evaluator:
|
|
"""Evaluate a model handle on a dataset, returning metric results.
|
|
|
|
Supports two modes:
|
|
- **perplexity** (default for causal_lm): feeds tokenized text and computes PPL.
|
|
- **classification**: runs forward pass, takes argmax, computes accuracy/F1.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
handle: ModelHandle,
|
|
dataset,
|
|
metrics: list[str] | None = None,
|
|
batch_size: int = 8,
|
|
max_length: int = 512,
|
|
max_samples: int | None = None,
|
|
text_column: str = "text",
|
|
label_column: str = "label",
|
|
):
|
|
self.handle = handle
|
|
self.dataset = dataset
|
|
self.metrics = metrics or (
|
|
["perplexity"] if handle.task == "causal_lm" else ["accuracy", "f1"]
|
|
)
|
|
self.batch_size = batch_size
|
|
self.max_length = max_length
|
|
self.max_samples = max_samples
|
|
self.text_column = text_column
|
|
self.label_column = label_column
|
|
|
|
@torch.no_grad()
|
|
def evaluate(self) -> dict[str, float]:
|
|
"""Run evaluation and return a dict of metric_name -> score."""
|
|
if self.handle.task == "causal_lm":
|
|
return self._evaluate_causal_lm()
|
|
elif self.handle.task == "classification":
|
|
return self._evaluate_classification()
|
|
else:
|
|
raise ValueError(f"Unsupported task: {self.handle.task}")
|
|
|
|
def _evaluate_causal_lm(self) -> dict[str, float]:
|
|
|
|
model = self.handle.model
|
|
tokenizer = self.handle.tokenizer
|
|
device = next(model.parameters()).device
|
|
|
|
ds = self.dataset
|
|
if self.max_samples is not None:
|
|
ds = ds.select(range(min(self.max_samples, len(ds))))
|
|
|
|
total_loss = 0.0
|
|
total_tokens = 0
|
|
|
|
for i in tqdm(range(0, len(ds), self.batch_size), desc="Evaluating PPL"):
|
|
batch_texts = ds[i : i + self.batch_size][self.text_column]
|
|
encodings = tokenizer(
|
|
batch_texts,
|
|
return_tensors="pt",
|
|
truncation=True,
|
|
max_length=self.max_length,
|
|
padding=True,
|
|
).to(device)
|
|
|
|
input_ids = encodings["input_ids"]
|
|
attention_mask = encodings["attention_mask"]
|
|
|
|
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
|
|
# Mask out padding tokens for loss computation
|
|
num_tokens = attention_mask[:, 1:].sum().item()
|
|
total_loss += outputs.loss.item() * num_tokens
|
|
total_tokens += num_tokens
|
|
|
|
import math
|
|
|
|
avg_loss = total_loss / max(total_tokens, 1)
|
|
return {"perplexity": math.exp(avg_loss)}
|
|
|
|
def _evaluate_classification(self) -> dict[str, float]:
|
|
from obliteratus.evaluation.metrics import accuracy as acc_fn
|
|
from obliteratus.evaluation.metrics import f1_score_metric as f1_fn
|
|
|
|
model = self.handle.model
|
|
tokenizer = self.handle.tokenizer
|
|
device = next(model.parameters()).device
|
|
|
|
ds = self.dataset
|
|
if self.max_samples is not None:
|
|
ds = ds.select(range(min(self.max_samples, len(ds))))
|
|
|
|
all_preds = []
|
|
all_labels = []
|
|
|
|
for i in tqdm(range(0, len(ds), self.batch_size), desc="Evaluating"):
|
|
batch = ds[i : i + self.batch_size]
|
|
texts = batch[self.text_column]
|
|
labels = batch[self.label_column]
|
|
|
|
encodings = tokenizer(
|
|
texts,
|
|
return_tensors="pt",
|
|
truncation=True,
|
|
max_length=self.max_length,
|
|
padding=True,
|
|
).to(device)
|
|
|
|
outputs = model(**encodings)
|
|
preds = outputs.logits.argmax(dim=-1).cpu().tolist()
|
|
all_preds.extend(preds)
|
|
all_labels.extend(labels)
|
|
|
|
results = {}
|
|
if "accuracy" in self.metrics:
|
|
results["accuracy"] = acc_fn(all_preds, all_labels)
|
|
if "f1" in self.metrics:
|
|
results["f1"] = f1_fn(all_preds, all_labels)
|
|
return results
|