SurfSense/surfsense_evals/scripts/peek_disagreements.py
DESKTOP-RTLN3BA\$punk 3737118050 chore: evals
2026-05-13 14:02:26 -07:00

64 lines
2.7 KiB
Python

"""Show questions where SurfSense was wrong but long-context was right (and vice versa)."""
from __future__ import annotations
import glob
import json
from collections import defaultdict
# Arms reported for every example, in display order.
_ARM_NAMES = ("bare_llm", "long_context", "surfsense")


def _load_by_question(raw_path: str) -> dict[str, dict[str, dict]]:
    """Read a JSONL file and index its rows as ``{qid: {arm: row}}``.

    Blank lines are skipped. Every row must carry ``"qid"`` and ``"arm"``
    keys (a missing key raises ``KeyError``). If a (qid, arm) pair occurs
    more than once, the last row wins.
    """
    by_q: dict[str, dict[str, dict]] = defaultdict(dict)
    # Context manager so the handle is closed even if a line fails to parse.
    with open(raw_path, encoding="utf-8") as fh:
        for line in fh:
            if line.strip():
                row = json.loads(line)
                by_q[row["qid"]][row["arm"]] = row
    return by_q


def _graded(arms: dict[str, dict], arm_name: str) -> dict:
    """Return the ``graded`` mapping of *arm_name*, or ``{}`` if unavailable.

    ``or {}`` (rather than a ``.get`` default) also covers an explicit JSON
    ``null`` for the arm or its ``graded`` field, not just a missing key.
    """
    return (arms.get(arm_name) or {}).get("graded") or {}


def _print_examples(
    heading: str,
    qids: list[str],
    by_q: dict[str, dict[str, dict]],
    limit: int = 5,
) -> None:
    """Print up to *limit* questions from *qids* with every arm's verdict.

    For each question this prints the domain, question type and gold answer,
    then one line per arm with its grade, grading method and the tail of its
    raw response.
    """
    print(f"\n=== {heading} (top {limit}) ===")
    for qid in qids[:limit]:
        arms = by_q[qid]
        # Question-level metadata is read from the bare_llm row only; it
        # prints as None when that arm is absent for this question.
        meta = arms.get("bare_llm", {})
        print(f"\n[{qid}] domain={meta.get('domain')} qtype={meta.get('question_type')}")
        print(f" GOLD: {meta.get('gold')!r}")
        for arm_name in _ARM_NAMES:
            arm = arms.get(arm_name, {})
            # Show only the last 180 characters of the stripped response;
            # a missing or null raw_text is shown as an empty string.
            tail = (arm.get("raw_text") or "").strip()[-180:]
            graded = _graded(arms, arm_name)
            print(f" [{arm_name}] {graded.get('grade')} ({graded.get('method')}): {tail!r}")


def main() -> None:
    """Report questions on which the SurfSense and long-context arms disagree.

    Picks the lexicographically last path matching
    ``data/research/runs/*/crag/raw.jsonl`` (relative to the current working
    directory), prints how many questions fall into each disagreement
    bucket, and then prints up to five examples for each direction of the
    SurfSense / long-context disagreement.

    Raises:
        SystemExit: if no file matches the pattern.
    """
    pattern = "data/research/runs/*/crag/raw.jsonl"
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        # Exit with a readable message instead of a bare IndexError.
        raise SystemExit(
            f"No files match {pattern!r} (relative to the current working directory)."
        )
    # NOTE(review): "last after sorting" is the newest run only if run
    # directory names sort chronologically; confirm the naming scheme.
    raw_path = candidates[-1]
    print(f"Reading: {raw_path}")
    by_q = _load_by_question(raw_path)

    surf_wrong_lc_right: list[str] = []
    lc_wrong_surf_right: list[str] = []
    surf_wrong_bare_right: list[str] = []
    for qid, arms in by_q.items():
        bare = _graded(arms, "bare_llm").get("grade")
        long_ctx = _graded(arms, "long_context").get("grade")
        surf = _graded(arms, "surfsense").get("grade")
        # The buckets are independent checks: a question may land in both
        # the long-context and the bare bucket.
        if surf == "incorrect" and long_ctx == "correct":
            surf_wrong_lc_right.append(qid)
        if long_ctx == "incorrect" and surf == "correct":
            lc_wrong_surf_right.append(qid)
        if surf == "incorrect" and bare == "correct":
            surf_wrong_bare_right.append(qid)

    print(f"\nSurfSense INCORRECT but Long-Context CORRECT: {len(surf_wrong_lc_right)}")
    print(f"Long-Context INCORRECT but SurfSense CORRECT: {len(lc_wrong_surf_right)}")
    print(f"SurfSense INCORRECT but Bare CORRECT: {len(surf_wrong_bare_right)}")

    _print_examples(
        "Where SurfSense is wrong but long-context is right",
        surf_wrong_lc_right,
        by_q,
    )
    _print_examples(
        "Where Long-Context is wrong but SurfSense is right",
        lc_wrong_surf_right,
        by_q,
    )
# Run the report only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()