mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-04-23 03:36:25 +02:00
161 lines
4.9 KiB
Python
161 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
|
"""Aggregate community contributions into paper-ready tables.
|
|
|
|
Usage:
|
|
python scripts/aggregate_contributions.py [--dir community_results] [--format summary|latex|csv|json]
|
|
|
|
Reads all contribution JSON files from the specified directory, aggregates
|
|
them by model and method, and outputs summary tables suitable for inclusion
|
|
in the paper.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Add project root to path
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from obliteratus.community import (
|
|
aggregate_results,
|
|
generate_latex_table,
|
|
load_contributions,
|
|
)
|
|
|
|
|
|
def _prune_entries(aggregated, keep):
    """Delete, in place, every (model, method) entry rejected by *keep*.

    Args:
        aggregated: Mapping of model -> method -> per-method data.
        keep: Callable invoked as ``keep(method, data)``; entries for which
            it returns a falsy value are removed.

    Models left without any method are removed as well.
    """
    # Iterate over snapshots of the keys because entries are deleted while
    # walking the mapping.
    for model in list(aggregated):
        per_method = aggregated[model]
        for method in list(per_method):
            if not keep(method, per_method[method]):
                del per_method[method]
        if not per_method:
            del aggregated[model]


def main():
    """Parse the command line, aggregate contributions, and print them.

    Records are loaded with ``load_contributions`` from ``--dir``, combined
    with ``aggregate_results``, optionally filtered by ``--min-runs`` and
    ``--methods``, and written to stdout in the format chosen by
    ``--format``. Progress and error messages go to stderr.

    Exits with status 1 when no contributions are found or when filtering
    leaves nothing to report.
    """
    parser = argparse.ArgumentParser(
        description="Aggregate community contributions into paper tables."
    )
    parser.add_argument(
        "--dir",
        default="community_results",
        help="Directory containing contribution JSON files (default: community_results)",
    )
    parser.add_argument(
        "--format",
        choices=["latex", "csv", "json", "summary"],
        default="summary",
        help="Output format (default: summary)",
    )
    parser.add_argument(
        "--metric",
        default="refusal_rate",
        help="Metric to display in tables (default: refusal_rate)",
    )
    parser.add_argument(
        "--methods",
        nargs="*",
        help="Methods to include (default: all)",
    )
    parser.add_argument(
        "--min-runs",
        type=int,
        default=1,
        help="Minimum runs per (model, method) to include (default: 1)",
    )
    args = parser.parse_args()

    # Load all contributions
    records = load_contributions(args.dir)
    if not records:
        print(f"No contributions found in {args.dir}/", file=sys.stderr)
        sys.exit(1)

    print(f"Loaded {len(records)} contribution(s) from {args.dir}/", file=sys.stderr)

    # Aggregate
    aggregated = aggregate_results(records)

    # Filter by minimum runs (a threshold of 1 or less keeps everything)
    if args.min_runs > 1:
        _prune_entries(
            aggregated,
            lambda _method, data: data["n_runs"] >= args.min_runs,
        )

    if not aggregated:
        print("No results meet the minimum run threshold.", file=sys.stderr)
        sys.exit(1)

    if args.format == "latex":
        # The LaTeX generator receives the method list itself, so the
        # aggregated data is handed over unfiltered, exactly as before.
        print(generate_latex_table(aggregated, methods=args.methods, metric=args.metric))
        return

    # --methods was previously honoured only by the LaTeX output; apply it to
    # the remaining formats too so the flag behaves the same everywhere.
    # ``nargs="*"`` yields None (flag absent) or [] (flag without values);
    # both mean "all methods".
    if args.methods:
        wanted = set(args.methods)
        _prune_entries(aggregated, lambda method, _data: method in wanted)
        if not aggregated:
            print("No results match the requested methods.", file=sys.stderr)
            sys.exit(1)

    # Output
    if args.format == "summary":
        _print_summary(aggregated, args.metric)
    elif args.format == "json":
        print(json.dumps(aggregated, indent=2))
    elif args.format == "csv":
        _print_csv(aggregated, args.metric)
|
|
|
|
|
|
def _print_summary(aggregated: dict, metric: str):
    """Write a plain-text overview of the aggregated results to stdout.

    Args:
        aggregated: Mapping of model -> method -> data, where each data
            dict has an ``"n_runs"`` entry and may have an entry named
            after *metric* holding ``"mean"`` and ``"std"`` values.
        metric: Name of the metric whose statistics are shown per method.

    The output is a header with overall counts, one section per model
    (sorted by name, shown without any ``org/`` prefix), and closing hints
    on how to produce the LaTeX and CSV formats.
    """
    # Flatten once so the totals below can share a single pass over the data.
    entries = [
        (method, data)
        for per_method in aggregated.values()
        for method, data in per_method.items()
    ]
    run_total = sum(data["n_runs"] for _, data in entries)
    distinct_methods = {method for method, _ in entries}
    rule = "=" * 70

    print(f"\n{rule}")
    print("Community Contribution Summary")
    print(rule)
    print(f" Total runs: {run_total}")
    print(f" Models: {len(aggregated)}")
    print(f" Methods: {len(distinct_methods)}")
    print()

    for model in sorted(aggregated):
        # Keep only the part after the last slash; names without a slash
        # come through unchanged.
        display_name = model.rsplit("/", 1)[-1]
        print(f" {display_name}:")

        per_method = aggregated[model]
        for method in sorted(per_method):
            data = per_method[method]
            runs = data["n_runs"]
            if metric not in data:
                print(f" {method:20s} (no {metric} data, n={runs})")
                continue

            stats = data[metric]
            mean, std = stats["mean"], stats["std"]
            # The spread is only shown when it is non-zero and there is
            # more than one run.
            spread = f" ± {std:.2f}" if std > 0 and runs > 1 else ""
            print(f" {method:20s} {metric}={mean:.2f}{spread} (n={runs})")
        print()

    print(rule)
    script = sys.argv[0]
    print(f"To generate LaTeX: python {script} --format latex")
    print(f"To generate CSV: python {script} --format csv")
|
|
|
|
|
|
def _print_csv(aggregated: dict, metric: str):
    """Print results as CSV on stdout.

    Args:
        aggregated: Mapping of model -> method -> data, where each data
            dict has an ``"n_runs"`` entry and may have an entry named
            after *metric* holding ``"mean"``, ``"std"``, ``"min"`` and
            ``"max"`` values.
        metric: Name of the metric whose statistics fill the numeric columns.

    A header row is always written. Rows are ordered by model, then method;
    (model, method) pairs without data for *metric* are omitted. Statistics
    are formatted with four decimal places.
    """
    # Local import: the csv module is needed only by this function.
    import csv

    # csv.writer quotes fields containing commas, quotes or newlines, which
    # the previous hand-built f-string rows did not. "\n" keeps the line
    # endings identical to what print() produced.
    writer = csv.writer(sys.stdout, lineterminator="\n")
    writer.writerow(["model", "method", "n_runs", "mean", "std", "min", "max"])

    for model in sorted(aggregated):
        per_method = aggregated[model]
        for method in sorted(per_method):
            data = per_method[method]
            n = data["n_runs"]
            if metric not in data:
                continue
            stats = data[metric]
            writer.writerow(
                [
                    model,
                    method,
                    n,
                    f"{stats['mean']:.4f}",
                    f"{stats['std']:.4f}",
                    f"{stats['min']:.4f}",
                    f"{stats['max']:.4f}",
                ]
            )
|
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|