352 lines
11 KiB
Python
352 lines
11 KiB
Python
|
|
"""bench/lme500/aggregate.py — post-process LongMemEval-S blind-run output.
|
|||
|
|
|
|||
|
|
Usage:
|
|||
|
|
python bench/lme500/aggregate.py \
|
|||
|
|
--in bench/lme500/output/lme500-v1.json \
|
|||
|
|
--report bench/lme500/output/lme500-v1-report.md \
|
|||
|
|
--summary bench/lme500/output/lme500-v1-summary.json
|
|||
|
|
|
|||
|
|
The --in path may be:
|
|||
|
|
- the final summary JSON ({"per_row": [...], ...} schema), or
|
|||
|
|
- the per-row JSONL checkpoint (one JSON dict per line — works on
|
|||
|
|
partial runs while the bench is still in progress).
|
|||
|
|
|
|||
|
|
Computes:
|
|||
|
|
- Overall R@5 / R@10 per prong (X = retrieve_recall, Y = recall_for_benchmark)
|
|||
|
|
- Architecture lift Y - X
|
|||
|
|
- Per-question-type stratification with n per bin (low-power flag if n<30)
|
|||
|
|
- Bootstrap 95% CI via percentile method (10000 resamples, seed=42)
|
|||
|
|
- Errors counted as miss for both prongs
|
|||
|
|
|
|||
|
|
Output:
|
|||
|
|
- Markdown report (--report)
|
|||
|
|
- Aggregated JSON summary (--summary)
|
|||
|
|
- One-line stderr summary at end
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
import json
|
|||
|
|
import random
|
|||
|
|
import statistics
|
|||
|
|
import sys
|
|||
|
|
from collections import defaultdict
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _dict_rows_only(items: Any, what: str) -> list[dict[str, Any]]:
    """Return the dict entries of *items*, warning on stderr about the rest."""
    if not isinstance(items, list):
        print(
            f"[aggregate] WARN: {what} is not a list; no rows taken from it",
            file=sys.stderr,
        )
        return []
    kept: list[dict[str, Any]] = []
    for idx, item in enumerate(items):
        if isinstance(item, dict):
            kept.append(item)
        else:
            print(
                f"[aggregate] WARN: skipping non-dict {what} entry {idx}",
                file=sys.stderr,
            )
    return kept


def load_rows(input_path: Path) -> list[dict[str, Any]]:
    """Load per-row dicts from a summary JSON, a JSON list, or a JSONL file.

    Order of detection:
      1. Text starting with "{" that parses as a single JSON object with a
         "per_row" key → its per_row entries.
      2. Text starting with "[" that parses as JSON → its entries.
      3. Otherwise JSONL: each non-empty line is parsed on its own; lines
         that are not valid JSON are skipped with a warning on stderr.

    In every case, entries that are not JSON objects are dropped with a
    warning on stderr, so the result always honours the ``list[dict]``
    return type.

    Args:
        input_path: File to read; decoded as UTF-8.

    Returns:
        The row dicts in file order (possibly empty).

    Raises:
        OSError: If the file cannot be read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    text = input_path.read_text(encoding="utf-8")
    stripped = text.strip()

    if stripped.startswith("{"):
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # A multi-line JSONL file also starts with "{"; handled below.
            pass
        else:
            if isinstance(data, dict) and "per_row" in data:
                return _dict_rows_only(data["per_row"], "per_row")
            # A lone object without "per_row" is treated as one JSONL row.

    if stripped.startswith("["):
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            pass
        else:
            return _dict_rows_only(data, "top-level list")

    # Fall back to JSONL.
    rows: list[dict[str, Any]] = []
    for lineno, raw_line in enumerate(text.splitlines(), 1):
        line = raw_line.strip()
        if not line:
            continue
        try:
            parsed = json.loads(line)
        except json.JSONDecodeError as exc:
            print(
                f"[aggregate] WARN: skipping corrupt line {lineno}: {exc}",
                file=sys.stderr,
            )
            continue
        if isinstance(parsed, dict):
            rows.append(parsed)
        else:
            print(
                f"[aggregate] WARN: skipping non-dict line {lineno}",
                file=sys.stderr,
            )
    return rows
|
|||
|
|
|
|||
|
|
|
|||
|
|
def bootstrap_ci(
    values: list[float],
    n_resamples: int = 10000,
    seed: int = 42,
) -> tuple[float, float, float]:
    """Bootstrap mean + 95% percentile CI.

    Draws ``n_resamples`` resamples of ``len(values)`` items with
    replacement, takes the mean of each, and reads the interval bounds off
    the sorted resample means at positions ``int(0.025 * n_resamples)`` and
    ``int(0.975 * n_resamples)`` (the latter clamped to the last index).

    Args:
        values: Observations to resample.
        n_resamples: Number of bootstrap resamples; must be >= 1 when
            ``values`` is non-empty.
        seed: Seed for a private ``random.Random`` instance, so results are
            reproducible and the module-level RNG is not touched.

    Returns:
        ``(mean, ci_lo, ci_hi)`` where ``mean`` is the plain mean of
        ``values``. Empty input → ``(0.0, 0.0, 0.0)``.

    Raises:
        ValueError: If ``values`` is non-empty and ``n_resamples`` < 1.
    """
    if not values:
        return 0.0, 0.0, 0.0
    if n_resamples < 1:
        # Without this guard the percentile lookup below indexes an empty
        # list and fails with an unhelpful IndexError.
        raise ValueError(f"n_resamples must be >= 1, got {n_resamples}")

    rng = random.Random(seed)
    n = len(values)
    draw = rng.randrange
    means: list[float] = []
    for _ in range(n_resamples):
        # Plain left-to-right accumulation with one randrange() call per
        # draw: keeps the numbers identical for a given seed.
        total = 0.0
        for _draw_idx in range(n):
            total += values[draw(n)]
        means.append(total / n)
    means.sort()

    lo_idx = max(0, int(0.025 * n_resamples))
    hi_idx = min(n_resamples - 1, int(0.975 * n_resamples))
    return statistics.fmean(values), means[lo_idx], means[hi_idx]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _get_prong_value(row: dict[str, Any], prong: str, k: int) -> float:
|
|||
|
|
"""Extract r_at_<k>_<prong> from a row, treating error rows as 0."""
|
|||
|
|
if "error" in row and isinstance(row.get("error"), dict):
|
|||
|
|
return 0.0
|
|||
|
|
return float(row.get(f"r_at_{k}_{prong}", 0.0))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def aggregate(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarise rows overall and per question type, with bootstrap CIs.

    Args:
        rows: Per-row result dicts.

    Returns:
        ``{"overall": ..., "per_type": ...}``. For non-empty input the
        overall block holds ``n``, ``n_errors``, ``X_retrieve``,
        ``Y_pipeline`` and ``lift_Y_minus_X``; each per-type block (keys in
        sorted order) holds the same entries except ``n_errors``. For empty
        input the overall block holds only ``n`` and ``n_errors``.
    """
    if not rows:
        return {"overall": {"n": 0, "n_errors": 0}, "per_type": {}}

    # (bucket key, prong suffix in the row's metric key, cut-off k)
    metric_specs = (
        ("x5", "retrieve", 5),
        ("x10", "retrieve", 10),
        ("y5", "pipeline", 5),
        ("y10", "pipeline", 10),
    )

    def _empty_buckets() -> dict[str, list[float]]:
        return {bucket_key: [] for bucket_key, _, _ in metric_specs}

    pooled: dict[str, list[float]] = _empty_buckets()
    per_qtype: dict[str, dict[str, list[float]]] = defaultdict(_empty_buckets)
    error_count = 0

    for row in rows:
        if "error" in row and isinstance(row["error"], dict):
            error_count += 1
        type_buckets = per_qtype[str(row.get("question_type", "unknown"))]
        for bucket_key, prong, k in metric_specs:
            score = _get_prong_value(row, prong, k)
            pooled[bucket_key].append(score)
            type_buckets[bucket_key].append(score)

    def _summarise(buckets: dict[str, list[float]]) -> dict[str, Any]:
        """Per-prong mean/CI at k=5 and k=10, plus the Y - X lift of means."""
        stats: dict[str, Any] = {}
        for label, prefix in (("X_retrieve", "x"), ("Y_pipeline", "y")):
            prong_stats = {}
            for k in (5, 10):
                mean, ci_lo, ci_hi = bootstrap_ci(buckets[f"{prefix}{k}"])
                prong_stats[f"r_at_{k}"] = {
                    "mean": mean,
                    "ci_lo": ci_lo,
                    "ci_hi": ci_hi,
                }
            stats[label] = prong_stats
        stats["lift_Y_minus_X"] = {
            f"r_at_{k}": (
                stats["Y_pipeline"][f"r_at_{k}"]["mean"]
                - stats["X_retrieve"][f"r_at_{k}"]["mean"]
            )
            for k in (5, 10)
        }
        return stats

    overall_block: dict[str, Any] = {
        "n": len(rows),
        "n_errors": error_count,
        **_summarise(pooled),
    }
    per_type_out: dict[str, dict[str, Any]] = {
        qtype: {"n": len(per_qtype[qtype]["x5"]), **_summarise(per_qtype[qtype])}
        for qtype in sorted(per_qtype)
    }
    return {"overall": overall_block, "per_type": per_type_out}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _overall_prong_row(label: str, prong: dict[str, Any]) -> str:
    """Render one overall-table row: mean and CI for R@5 and R@10."""
    r5, r10 = prong["r_at_5"], prong["r_at_10"]
    return (
        f"| {label} "
        f"| {r5['mean']:.3f} "
        f"| [{r5['ci_lo']:.3f}, {r5['ci_hi']:.3f}] "
        f"| {r10['mean']:.3f} "
        f"| [{r10['ci_lo']:.3f}, {r10['ci_hi']:.3f}] |"
    )


def format_markdown_report(
    agg: dict[str, Any],
    source_path: Path,
    *,
    low_power_threshold: int = 30,
) -> str:
    """Render the aggregate as a Markdown report.

    Args:
        agg: Mapping with "overall" and "per_type" entries. The overall
            block must hold "n" and "n_errors"; when n > 0 it and every
            per-type block must also hold "X_retrieve", "Y_pipeline" and
            "lift_Y_minus_X".
        source_path: Input path, shown in the report header.
        low_power_threshold: Per-type bins with n below this value are
            flagged; the legend line quotes the same value.

    Returns:
        The report text, ending with a newline. When the overall n is 0
        the report stops after a "No rows loaded" line.
    """
    overall = agg["overall"]
    lines: list[str] = [
        "# LongMemEval-S Aggregate Report",
        "",
        f"- Source: `{source_path}`",
        f"- n = {overall['n']}, errors = {overall['n_errors']}",
        "- 95% CI via bootstrap percentile method (10000 resamples, seed=42)",
        "",
    ]

    if overall["n"] == 0:
        lines.append("**No rows loaded.**")
        return "\n".join(lines) + "\n"

    lift = overall["lift_Y_minus_X"]
    lines += [
        "## Overall",
        "",
        "| Prong | R@5 | R@5 95% CI | R@10 | R@10 95% CI |",
        "|---|---|---|---|---|",
        _overall_prong_row(
            "X (retrieve_recall — flat-cosine baseline)",
            overall["X_retrieve"],
        ),
        _overall_prong_row(
            "Y (recall_for_benchmark — full graph-native pipeline)",
            overall["Y_pipeline"],
        ),
        (
            f"| **Architecture lift Y − X** "
            f"| **{lift['r_at_5']:+.3f}** "
            f"| — "
            f"| **{lift['r_at_10']:+.3f}** "
            f"| — |"
        ),
        "",
        "## Per question type",
        "",
        (
            "| Type | n | X R@5 | Y R@5 | Lift R@5 "
            "| X R@10 | Y R@10 | Lift R@10 |"
        ),
        "|---|---|---|---|---|---|---|---|",
    ]

    for qt, block in agg["per_type"].items():
        n = block["n"]
        flag = " ⚠️" if n < low_power_threshold else ""
        # A raw "|" would be read as a cell separator, even inside a code
        # span, so it has to be backslash-escaped.
        qt_cell = str(qt).replace("|", "\\|")
        x = block["X_retrieve"]
        y = block["Y_pipeline"]
        type_lift = block["lift_Y_minus_X"]
        lines.append(
            f"| `{qt_cell}`{flag} | {n} "
            f"| {x['r_at_5']['mean']:.3f} | {y['r_at_5']['mean']:.3f} "
            f"| {type_lift['r_at_5']:+.3f} "
            f"| {x['r_at_10']['mean']:.3f} | {y['r_at_10']['mean']:.3f} "
            f"| {type_lift['r_at_10']:+.3f} |"
        )

    lines += [
        "",
        f"⚠️ = n < {low_power_threshold}, low statistical power for that bin.",
        "",
        "## Notes",
        "",
        (
            "- Errors (graph-build failures, malformed rows, etc.) are counted "
            "as miss for **both** prongs (R@k = 0)."
        ),
        "- Mean is the unweighted row average; CI is bootstrap percentile.",
        (
            "- Architecture lift = mean(Y) − mean(X). The CI of the lift "
            "itself is not computed here (would require paired bootstrap on "
            "the (Y_i, X_i) tuples — TODO if needed)."
        ),
    ]
    return "\n".join(lines) + "\n"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: load rows, aggregate, write both outputs.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``main()`` callers are
            unaffected.

    Returns:
        0 on success; 1 when the input path does not exist, is not a
        regular file, or yields no rows.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks (usage example, bullet
        # lists) instead of letting argparse reflow them into a paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--in",
        dest="input",
        required=True,
        help="Path to per-row JSON / JSONL file",
    )
    parser.add_argument(
        "--report",
        default=None,
        help="Output path for markdown report; default: <input>-report.md",
    )
    parser.add_argument(
        "--summary",
        default=None,
        help="Output path for aggregated JSON; default: <input>-summary.json",
    )
    args = parser.parse_args(argv)

    input_path = Path(args.input)
    if not input_path.exists():
        print(f"[aggregate] ERROR: {input_path} does not exist", file=sys.stderr)
        return 1
    if not input_path.is_file():
        # A directory passes exists() but cannot be read as text.
        print(
            f"[aggregate] ERROR: {input_path} is not a regular file",
            file=sys.stderr,
        )
        return 1
    rows = load_rows(input_path)
    if not rows:
        print(f"[aggregate] WARN: 0 rows loaded from {input_path}", file=sys.stderr)
        return 1

    agg = aggregate(rows)

    summary_path = (
        Path(args.summary)
        if args.summary
        else input_path.with_name(input_path.stem + "-summary.json")
    )
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(agg, f, indent=2)

    report_path = (
        Path(args.report)
        if args.report
        else input_path.with_name(input_path.stem + "-report.md")
    )
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(format_markdown_report(agg, input_path), encoding="utf-8")

    overall = agg["overall"]
    x = overall["X_retrieve"]
    y = overall["Y_pipeline"]
    lift = overall["lift_Y_minus_X"]
    print(
        f"[aggregate] n={overall['n']} errors={overall['n_errors']}",
        file=sys.stderr,
    )
    print(
        f"[aggregate] X (retrieve) R@5={x['r_at_5']['mean']:.3f} "
        f"[{x['r_at_5']['ci_lo']:.3f},{x['r_at_5']['ci_hi']:.3f}] "
        f"R@10={x['r_at_10']['mean']:.3f}",
        file=sys.stderr,
    )
    print(
        f"[aggregate] Y (pipeline) R@5={y['r_at_5']['mean']:.3f} "
        f"[{y['r_at_5']['ci_lo']:.3f},{y['r_at_5']['ci_hi']:.3f}] "
        f"R@10={y['r_at_10']['mean']:.3f}",
        file=sys.stderr,
    )
    print(
        f"[aggregate] Lift Y − X R@5={lift['r_at_5']:+.3f} "
        f"R@10={lift['r_at_10']:+.3f}",
        file=sys.stderr,
    )
    print(f"[aggregate] -> {summary_path}", file=sys.stderr)
    print(f"[aggregate] -> {report_path}", file=sys.stderr)
    return 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    sys.exit(main())
|