chore: evals

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-13 14:02:26 -07:00
parent 2402b730fa
commit 3737118050
122 changed files with 22598 additions and 13 deletions

View file

@ -0,0 +1,65 @@
# surfsense_evals — environment template.
#
# Copy this file to `.env` (in the surfsense_evals/ project root or your
# CWD) and fill in the values. `python-dotenv` loads it automatically
# the first time `core.config` is imported, so every CLI subcommand
# (`setup`, `ingest`, `run`, `report`, `teardown`, `models list`, …)
# will pick the values up.
#
# cp .env.example .env
# # then edit .env with your values
#
# `.env` is gitignored — never commit real secrets.
# ---------------------------------------------------------------------------
# 1. Backend target — REQUIRED (default works for a local dev backend)
# ---------------------------------------------------------------------------
SURFSENSE_API_BASE=http://localhost:8000
# ---------------------------------------------------------------------------
# 2. OpenRouter — REQUIRED for any `run` invocation
# ---------------------------------------------------------------------------
# The `native_pdf` arm calls OpenRouter directly; the `surfsense` arm
# routes through SurfSense which uses the same key under the hood.
OPENROUTER_API_KEY=sk-or-...
# Override only if you proxy OpenRouter through a private gateway:
# OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
# Multimodal benchmarks (medxpertqa, mmlongbench) require a vision-capable
# slug. Recommended (verify in your catalog with `models list --grep ...`):
# anthropic/claude-sonnet-4.5 (default recommendation)
# anthropic/claude-opus-4.7 (strongest)
# openai/gpt-5 (top-tier vision)
# google/gemini-2.5-pro (1M-token context, best for long PDFs)
# DO NOT use openai/gpt-5.4-mini for image-bearing benchmarks — it's
# text-only on PDF content and the runner emits a warning if pinned.
# ---------------------------------------------------------------------------
# 3. Auth — pick EXACTLY ONE of the two modes below
# ---------------------------------------------------------------------------
# --- Mode A: LOCAL (backend started with AUTH_TYPE=LOCAL)
# The harness POSTs these to /auth/jwt/login automatically.
# SURFSENSE_USER_EMAIL=you@example.com
# SURFSENSE_USER_PASSWORD=...
# --- Mode B: GOOGLE OAuth (or any pre-issued JWT)
# Open the SurfSense web UI in your browser, log in via Google, then in
# DevTools → Application → Local Storage copy:
# surfsense_bearer_token → SURFSENSE_JWT
# surfsense_refresh_token → SURFSENSE_REFRESH_TOKEN (optional, enables
# auto-refresh on 401)
# SURFSENSE_JWT=eyJhbGciOi...
# SURFSENSE_REFRESH_TOKEN=eyJhbGciOi...
# ---------------------------------------------------------------------------
# 4. Filesystem paths — OPTIONAL (defaults below)
# ---------------------------------------------------------------------------
# Where datasets, rendered PDFs, ingestion id maps, run outputs, and
# state.json live. Default: <surfsense_evals>/data/
# EVAL_DATA_DIR=./data
# Where generated reports (summary.md / summary.json) get written.
# Default: <surfsense_evals>/reports/
# EVAL_REPORTS_DIR=./reports

29
surfsense_evals/.gitignore vendored Normal file
View file

@ -0,0 +1,29 @@
# Python bytecode + caches
__pycache__/
*.py[cod]
*.pyo
# Editable-install / build artifacts
*.egg-info/
build/
dist/
.eggs/
# Virtual envs (uv venv default + common alternates)
.venv/
venv/
env/
# Tooling caches
.pytest_cache/
.ruff_cache/
.mypy_cache/
.coverage
.coverage.*
htmlcov/
# Local secrets — keep `.env.example` tracked, never the real `.env`.
.env
.env.local
.env.*.local
!.env.example

228
surfsense_evals/README.md Normal file
View file

@ -0,0 +1,228 @@
# SurfSense Evals
Domain-agnostic eval harness for SurfSense. Each benchmark is a Python subpackage under `suites/<domain>/<benchmark>/` that self-registers with the CLI; `core/` is the shared infrastructure (HTTP clients, arms, parsers, metrics, report writer, registry). The harness talks to SurfSense over HTTP only — it does **not** import any backend Python module — so it ships in its own venv and never bloats the FastAPI runtime image.
## Benchmarks
| Benchmark | Shape | Vision required? | Default ingest |
|---------------------------------|--------------------------------------------------|------------------|----------------------------|
| `medical/medxpertqa` (headline) | Native PDF vs SurfSense head-to-head, MCQ | yes | `vision=on, mode=basic` |
| `medical/mirage` | SurfSense single-arm, MCQ | no | `vision=off, mode=basic` |
| `medical/cure` | SurfSense single-arm retrieval (Recall/MRR/nDCG) | no | `vision=off, mode=basic` |
| `multimodal_doc/mmlongbench` | Native PDF vs SurfSense head-to-head, open-ended | yes | `vision=on, mode=basic` |
Future domains (`legal/`, `finance/`, `code/`, `scientific/`) drop into `suites/` without touching `core/` or the CLI.
## Install + auth
```bash
uv pip install -e ./surfsense_evals
cp surfsense_evals/.env.example surfsense_evals/.env
# Edit .env: SURFSENSE_API_BASE, OPENROUTER_API_KEY, and ONE of:
# LOCAL → SURFSENSE_USER_EMAIL + SURFSENSE_USER_PASSWORD
# GOOGLE → SURFSENSE_JWT (+ optional SURFSENSE_REFRESH_TOKEN)
# (lift both from browser localStorage after a normal Google login)
```
## Step-by-step: run all four benchmarks
The medical and multimodal_doc suites each get their own SearchSpace and pinned model, so they're independent — run them in any order. Both head-to-head benchmarks (`medxpertqa`, `mmlongbench`) require a **vision-capable** OpenRouter slug; pinning a text-only one (e.g. `openai/gpt-5.4-mini`) silently drops images and the runner emits a warning.
Recommended vision slugs (use `models list --grep <name>` to confirm one): `anthropic/claude-sonnet-4.5` (balanced cost), `anthropic/claude-opus-4.7` (strongest reasoning), `openai/gpt-5` (top-tier vision), `google/gemini-2.5-pro` (best for long PDFs, 1M-token context).
```bash
# 0. (optional) discover what's registered
python -m surfsense_evals suites list
python -m surfsense_evals benchmarks list
# 1. MEDICAL SUITE — one SearchSpace, three benchmarks
python -m surfsense_evals setup --suite medical --provider-model anthropic/claude-sonnet-4.5
# 1a. headline head-to-head: Native PDF (vision) vs SurfSense (vision RAG)
# Downloads dev+test JSONL + images.zip, renders one PDF per question
# (case + table + images + 5 options), uploads with use_vision_llm=True.
python -m surfsense_evals ingest medical medxpertqa --split test
python -m surfsense_evals run medical medxpertqa --concurrency 4
# 1b. MIRAGE — single-arm SurfSense MCQ accuracy
# (MMLU-Med / MedQA-US / MedMCQA / PubMedQA / BioASQ)
python -m surfsense_evals ingest medical mirage
python -m surfsense_evals run medical mirage
# 1c. CUREv1 — single-arm SurfSense retrieval (Recall@k / MRR / nDCG@10)
python -m surfsense_evals ingest medical cure --lang en
python -m surfsense_evals run medical cure --lang en
# 1d. write reports/medical/<UTC-ts>/summary.{md,json}
python -m surfsense_evals report --suite medical
# 2. MULTIMODAL_DOC SUITE — long PDFs with embedded images, charts, tables
python -m surfsense_evals setup --suite multimodal_doc --provider-model google/gemini-2.5-pro
python -m surfsense_evals ingest multimodal_doc mmlongbench # ~660MB, resumable
python -m surfsense_evals run multimodal_doc mmlongbench --concurrency 4
python -m surfsense_evals report --suite multimodal_doc
# 3. CLEANUP — soft-deletes the SearchSpaces; rendered PDFs stay cached
python -m surfsense_evals teardown --suite medical
python -m surfsense_evals teardown --suite multimodal_doc
```
## Asymmetric scenarios — the "vision-extract once, answer cheap" play
The walkthrough above is `--scenario head-to-head` (default): both arms answer with the same vision-capable slug. SurfSense's actual architectural value-prop is that the **ingestion-time vision LLM and the runtime LLM are completely independent** — you can pay a vision LLM *once*, at ingest, to convert every embedded image into text (per-image OCR **and** semantic description, inlined where the image actually appears in the document — see [What `--use-vision-llm` produces](#what---use-vision-llm-produces) below). Then every query is served by a cheap text-only model that sees that extracted text natively. Two extra scenarios make this explicit:
| `--scenario` | Native arm answers with | SurfSense arm answers with | Question being measured |
|--------------------|----------------------------------------|--------------------------------|------------------------------------------------------------------------------------------|
| `head-to-head` | `--provider-model` (vision) | `--provider-model` (vision) | Pure RAG quality at parity. (Default.) |
| `symmetric-cheap` | `--provider-model` (cheap, text-only) | `--provider-model` (same) | Does pre-extracted image context let a non-vision LLM reason over image-heavy docs? |
| `cost-arbitrage` | `--native-arm-model` (vision) | `--provider-model` (cheap) | How close does SurfSense get to a vision-native baseline at a fraction of per-query cost?|
In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_llm_config_id` (auto-picked from the strongest registered global OpenRouter vision config — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm <slug>`). What changes is which slug the *answering* models hit per arm.
### Ingest with vision, evaluate with a non-vision LLM (`symmetric-cheap`)
This is the answer to *"does SurfSense give a non-vision LLM enough context to reason over image-heavy docs?"*. Both arms hit the same cheap text-only slug. The native arm is structurally blind to images (text-only LLM + raw PDFs). The SurfSense arm reads chunks that already contain the per-image OCR and visual descriptions, written there by the vision LLM at ingest time.
```bash
python -m surfsense_evals setup --suite medical \
--scenario symmetric-cheap \
--provider-model openai/gpt-5.4-mini
# vision LLM at ingest = auto-picked (claude-sonnet-4.5 by default)
# answer LLM for BOTH arms = openai/gpt-5.4-mini (text-only)
python -m surfsense_evals ingest medical medxpertqa --split test # vision=on by default
python -m surfsense_evals run medical medxpertqa --concurrency 4
python -m surfsense_evals report --suite medical
# Δ accuracy on image-required MCQs is the headline number; native arm
# baseline is "what a text-only LLM gets without seeing the images".
```
### Cheap SurfSense vs vision-native baseline (`cost-arbitrage`)
```bash
python -m surfsense_evals setup --suite medical \
--scenario cost-arbitrage \
--provider-model openai/gpt-5.4-mini \
--native-arm-model anthropic/claude-sonnet-4.5
# vision LLM at ingest = auto-picked claude-sonnet-4.5
# native arm = sonnet (vision); SurfSense arm = gpt-5.4-mini (text-only)
python -m surfsense_evals ingest medical medxpertqa --split test
python -m surfsense_evals run medical medxpertqa --concurrency 4
python -m surfsense_evals report --suite medical
# Report header reads:
# Scenario: cost-arbitrage — native arm answers with `anthropic/claude-sonnet-4.5`
# (vision); SurfSense answers with `openai/gpt-5.4-mini` over chunks vision-extracted
# at ingest by `anthropic/claude-sonnet-4.5`.
```
Notes:
- `cost-arbitrage` requires both `--provider-model` (the cheap SurfSense slug) AND `--native-arm-model <vision slug>`.
- `--vision-llm <slug>` is optional; if omitted the harness queries `GET /api/v1/global-vision-llm-configs` and auto-picks the strongest registered one. Pass `--no-vision-llm-setup` if you want to keep whatever vision config is already attached to the SearchSpace.
- The runner's "looks text-only" warning is suppressed (or relabelled as informational) for `symmetric-cheap` so intentional asymmetry doesn't read as a misconfiguration.
- All three scenario fields (`scenario`, `provider_model`, `native_arm_model`, `vision_provider_model`) are persisted to `state.json` and recorded in `run_artifact.extra` + the report header — no need to retrace what was set.
## Per-benchmark useful flags
`medical/medxpertqa` (`run`):
- `--split {test,dev,all}` — pick a subset (default `test`)
- `--task "Diagnosis"` / `--body-system "Cardiovascular"` — slice the report
- `--require-images` — drop rare rows where every image filename failed to resolve
- `--n 100` — quick smoke run
- `--no-mentions` — let SurfSense retrieve unscoped ("did the @-mention matter?")
`multimodal_doc/mmlongbench`:
- `--max-docs N` (ingest) — cap downloads at the first N unique PDFs
- `--format {str,int,float,list,none}` (run) — slice by answer format; `none` = the ~22% intentionally unanswerable hallucination probes
- `--skip-unanswerable` (run) — drop unanswerable questions
- `--docs <a.pdf>,<b.pdf>` (run) — scope to specific docs
## Ingestion knobs (vision LLM, processing mode, summarize)
The harness exposes `POST /api/v1/documents/fileupload`'s three knobs on every `ingest` subcommand:
| Flag pair | Effect |
|--------------------------------------------|-----------------------------------------------------------------------------------------|
| `--use-vision-llm` / `--no-vision-llm` | Walk every embedded image in the PDF and inline image-derived text at the image's position (see below). |
| `--processing-mode {basic,premium}` | `premium` carries a 10× page multiplier and routes to a stronger ETL (e.g. LlamaCloud). |
| `--should-summarize` / `--no-summarize` | Generate a per-document summary at ingest. |
The "Default ingest" column in the benchmarks table is what runs if you don't pass any flag. Whatever was actually used is recorded as a `__settings__` header in the doc map (`data/<suite>/maps/<benchmark>_*_map.jsonl`) and as `extra.ingest_settings` in `run_artifact.json`, then surfaced in the report — no need to hunt through CLI history.
> The backend's `ETL_SERVICE` env var (`DOCLING` | `UNSTRUCTURED` | `LLAMACLOUD`) is **not** per-upload. Restart the backend with a different `ETL_SERVICE` and re-ingest to compare ETLs (route through `--processing-mode premium` if your backend uses that mode for the stronger ETL).
### What `--use-vision-llm` produces
When vision is on, the backend's ETL pipeline (`app/etl_pipeline/picture_describer.py`) does, **per embedded image** in the PDF:
1. Extract the raw image bytes via `pypdf` (deduped by sha256, size-capped to match the vision LLM's per-image limit).
2. **Per-image OCR** — re-feed the image as a standalone upload through the configured ETL service (Docling / Azure DI / LlamaCloud) with `vision_llm=None`, so the ETL's OCR engine extracts the literal text-in-image.
3. **Visual description** — call the vision LLM on the image with a description-only prompt (it's explicitly told *not* to transcribe text — that's OCR's job). Steps 2 and 3 run in parallel per image.
4. Splice a horizontal-rule-delimited section **at the image's original position** in the parser markdown (replacing Docling's `<!-- image -->` placeholder + caption, or the bare `Image: <name>` caption a stripped-image parser leaves behind):
```markdown
---
**Embedded image:** `MM-130-a.jpeg`
**OCR text:**
Slice 24 / 60
L R
**Visual description:**
- Axial contrast-enhanced CT showing a large cystic mass in the left upper quadrant.
- Mass effect on the adjacent stomach; left kidney displaced inferiorly.
---
```
This is what makes `--scenario symmetric-cheap` and `--scenario cost-arbitrage` work: a non-vision LLM reading SurfSense's chunks sees the image's text and semantic content as plain markdown, alongside the surrounding case text, in the same retrieved chunk. Without it the cheap LLM would have nothing extra to read.
### A/B testing the same corpus with different settings
SurfSense dedupes uploads by `(filename, search_space_id)`**not** by content hash and **not** by ingestion settings. Re-uploading the same filename to the same SearchSpace with a different `--use-vision-llm` flag silently skips re-processing. Give each variant its own SearchSpace:
```bash
# Baseline arm (vision off)
python -m surfsense_evals setup --suite medical --provider-model anthropic/claude-sonnet-4.5
python -m surfsense_evals ingest medical medxpertqa --no-vision-llm
python -m surfsense_evals run medical medxpertqa --n 100
python -m surfsense_evals teardown --suite medical
# Vision arm (the benchmark default)
python -m surfsense_evals setup --suite medical --provider-model anthropic/claude-sonnet-4.5
python -m surfsense_evals ingest medical medxpertqa
python -m surfsense_evals run medical medxpertqa --n 100
python -m surfsense_evals report --suite medical
```
Both runs land in `data/medical/runs/<ts>/medxpertqa/` with their settings recorded; rendered PDFs stay cached under `data/medical/medxpertqa/pdfs/` so the second `ingest` is upload-only.
## Environment variables
- `SURFSENSE_API_BASE` (default `http://localhost:8000`)
- `OPENROUTER_API_KEY` — required for the `native_pdf` arm and for `models list`
- One of `SURFSENSE_USER_EMAIL` + `SURFSENSE_USER_PASSWORD` (LOCAL), **or** `SURFSENSE_JWT` (+ optional `SURFSENSE_REFRESH_TOKEN`) for GOOGLE/pre-issued JWT
- `EVAL_DATA_DIR` (default `<project>/data`) — datasets, rendered PDFs, ingestion id maps, run outputs, `state.json`
- `EVAL_REPORTS_DIR` (default `<project>/reports`)
- `OPENROUTER_BASE_URL` (default `https://openrouter.ai/api/v1`) — only if you proxy OpenRouter
## Adding a new domain suite
1. Create `surfsense_evals/src/surfsense_evals/suites/<domain>/<benchmark>/` with `__init__.py`, `ingest.py`, `runner.py`, optional `prompt.py`.
2. Implement a `Benchmark` subclass (see `core/registry.py`); compose with `core.clients.*`, `core.arms.*`, `core.parse.*`, `core.metrics.*`.
3. Call `register(MyBenchmark())` at the bottom of `<benchmark>/__init__.py`. Auto-discovery picks it up; `setup --suite <domain>` and `ingest/run <domain> <benchmark>` work immediately.
Each suite gets its own SearchSpace (`eval-<suite>-<UTC-ts>`), `state.json` slot, data dir, reports dir, and pinned LLM. Suites never share a SearchSpace.
## Out of scope (follow-up PRs)
- Docker service for `docker compose run evals run medical medxpertqa`.
- Multi-model sweeps (one slug per `setup` for now; aggregate reports come later).
- A long-context-stuffing arm (give the model the same retrieved chunks SurfSense saw).
- LLM-judge grader for MMLongBench-Doc (paper uses GPT-4 as judge; we ship a deterministic rule-based grader).
- MedXpertQA-MM accuracy by image modality — dataset doesn't tag modality directly; we slice by `medical_task` and `body_system`.
- A `--slot <name>` flag that decouples the state-slot key from the benchmark registry's `suite` attribute, so parallel SearchSpaces with different ingestion settings can coexist on the same benchmark without `teardown` between A/B arms.
See `c:/Users/91882/.cursor/plans/medical_rag_evals_(mirage_+_curev1)_e797a324.plan.md` for the full design rationale.

2
surfsense_evals/data/.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
*
!.gitignore

View file

@ -0,0 +1,60 @@
[project]
name = "surfsense-evals"
version = "0.1.0"
description = "Domain-agnostic evaluation harness for SurfSense (medical RAG suite ships first; legal/finance/code suites slot in under suites/)."
readme = "README.md"
requires-python = ">=3.12"
license = { text = "Apache-2.0" }
authors = [{ name = "SurfSense" }]
dependencies = [
"httpx>=0.27.0",
"httpx-sse>=0.4.0",
"datasets>=2.21.0",
"huggingface_hub>=0.24.0",
"reportlab>=4.0.0",
"Pillow>=10.0.0",
"pyarrow>=15.0.0",
"pydantic>=2.6.0",
"tqdm>=4.66.0",
"numpy>=1.26.0",
"scikit-learn>=1.4.0",
"scipy>=1.12.0",
"python-dotenv>=1.0.0",
"rich>=13.7.0",
"trafilatura>=1.12.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"respx>=0.21.0",
"ruff>=0.5.0",
]
[project.scripts]
surfsense-evals = "surfsense_evals.core.cli:main"
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
include = ["surfsense_evals*"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
markers = [
"integration: opt-in tests that hit a live SurfSense instance (run with `-m integration`)",
]
[tool.ruff]
line-length = 100
target-version = "py312"
[tool.ruff.lint]
select = ["E", "F", "I", "B", "UP", "SIM", "ASYNC"]
ignore = ["E501"]

4
surfsense_evals/reports/.gitignore vendored Normal file
View file

@ -0,0 +1,4 @@
*
!.gitignore
!medical/
!medical/sample_summary.md

View file

@ -0,0 +1,97 @@
"""Download CRAG Task 3's 4 .tar.bz2 parts in parallel.
Run once before ``ingest research crag_t3`` to avoid the ingest
synchronously blocking on a 7 GB download. Skips parts already
present and complete on disk.
"""
from __future__ import annotations
import logging
import sys
import time
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger("download_task3")
_BASE = (
"https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/"
"crag_task_3_dev_v4.tar.bz2.part"
)
_USER_AGENT = "SurfSense-Evals/0.1 (CRAG Task 3 fetch)"
def _expected_size(url: str) -> int:
req = urllib.request.Request(url, method="HEAD", headers={"User-Agent": _USER_AGENT})
with urllib.request.urlopen(req, timeout=30) as resp:
return int(resp.headers.get("content-length", 0))
def download_one(part: int, dest_dir: Path) -> Path:
url = f"{_BASE}{part}"
dest = dest_dir / f"crag_task_3_dev_v4.tar.bz2.part{part}"
expected = _expected_size(url)
if dest.exists() and dest.stat().st_size == expected:
log.info("part%d: cached (%d bytes)", part, expected)
return dest
log.info("part%d: downloading %d bytes ...", part, expected)
tmp = dest.with_suffix(dest.suffix + ".part_dl")
started = time.monotonic()
last_log = started
with urllib.request.urlopen(
urllib.request.Request(url, headers={"User-Agent": _USER_AGENT}),
timeout=900,
) as resp, tmp.open("wb") as fh:
downloaded = 0
chunk = resp.read(1 << 20)
while chunk:
fh.write(chunk)
downloaded += len(chunk)
now = time.monotonic()
if now - last_log > 5.0:
pct = 100 * downloaded / expected if expected else 0
rate_mb = (downloaded / (now - started)) / (1 << 20)
log.info(
"part%d: %5.1f%% (%.1f / %.1f MiB at %.1f MiB/s)",
part, pct, downloaded / (1 << 20), expected / (1 << 20), rate_mb,
)
last_log = now
chunk = resp.read(1 << 20)
tmp.replace(dest)
elapsed = time.monotonic() - started
log.info(
"part%d: done in %.1fs (%.1f MiB/s avg)",
part, elapsed, (expected / (1 << 20)) / max(elapsed, 0.001),
)
return dest
def main() -> int:
dest_dir = Path("data/research/crag_t3/.raw_cache")
dest_dir.mkdir(parents=True, exist_ok=True)
# 4 parts in parallel — typical residential connection saturates around
# 2 streams; GitHub raw serves these fine in parallel.
started = time.monotonic()
with ThreadPoolExecutor(max_workers=4) as ex:
futures = {ex.submit(download_one, i, dest_dir): i for i in range(1, 5)}
for fut in as_completed(futures):
part = futures[fut]
try:
fut.result()
except Exception as exc: # noqa: BLE001
log.error("part%d failed: %s", part, exc)
return 1
log.info("All 4 parts downloaded in %.1fs", time.monotonic() - started)
return 0
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,37 @@
"""Tiny helper to inspect the latest CRAG run's per-question outputs."""
from __future__ import annotations
import glob
import json
from collections import defaultdict
def main() -> None:
raw_path = sorted(glob.glob("data/research/runs/*/crag/raw.jsonl"))[-1]
print(f"Reading: {raw_path}")
rows = [json.loads(line) for line in open(raw_path, encoding="utf-8") if line.strip()]
by_q: dict[str, dict[str, dict]] = defaultdict(dict)
for r in rows:
by_q[r["qid"]][r["arm"]] = r
for qid, arms in list(by_q.items()):
b = arms.get("bare_llm", {})
l = arms.get("long_context", {})
s = arms.get("surfsense", {})
print(f"\n=== {qid} ({b.get('domain')}/{b.get('question_type')}) ===")
print(f" question: {b.get('extra', {}).get('question', '?')!r}")
print(f" gold: {b.get('gold')!r}")
for arm_name, a in (("bare_llm", b), ("long_context", l), ("surfsense", s)):
grade = a.get("graded", {})
text = (a.get("raw_text") or "").strip()
tail = text[-200:] if text else ""
print(
f" [{arm_name}] grade={grade.get('grade')} "
f"method={grade.get('method')}"
)
print(f" -> {tail!r}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,64 @@
"""Show questions where SurfSense was wrong but long-context was right (and vice versa)."""
from __future__ import annotations
import glob
import json
from collections import defaultdict
def main() -> None:
raw_path = sorted(glob.glob("data/research/runs/*/crag/raw.jsonl"))[-1]
print(f"Reading: {raw_path}")
rows = [json.loads(line) for line in open(raw_path, encoding="utf-8") if line.strip()]
by_q: dict[str, dict[str, dict]] = defaultdict(dict)
for r in rows:
by_q[r["qid"]][r["arm"]] = r
surf_wrong_lc_right = []
lc_wrong_surf_right = []
surf_wrong_bare_right = []
for qid, arms in by_q.items():
b = arms.get("bare_llm", {}).get("graded", {}).get("grade")
lc = arms.get("long_context", {}).get("graded", {}).get("grade")
s = arms.get("surfsense", {}).get("graded", {}).get("grade")
if s == "incorrect" and lc == "correct":
surf_wrong_lc_right.append(qid)
if lc == "incorrect" and s == "correct":
lc_wrong_surf_right.append(qid)
if s == "incorrect" and b == "correct":
surf_wrong_bare_right.append(qid)
print(f"\nSurfSense INCORRECT but Long-Context CORRECT: {len(surf_wrong_lc_right)}")
print(f"Long-Context INCORRECT but SurfSense CORRECT: {len(lc_wrong_surf_right)}")
print(f"SurfSense INCORRECT but Bare CORRECT: {len(surf_wrong_bare_right)}")
print("\n=== Where SurfSense is wrong but long-context is right (top 5) ===")
for qid in surf_wrong_lc_right[:5]:
arms = by_q[qid]
b = arms.get("bare_llm", {})
print(f"\n[{qid}] domain={b.get('domain')} qtype={b.get('question_type')}")
print(f" GOLD: {b.get('gold')!r}")
for arm_name in ("bare_llm", "long_context", "surfsense"):
a = arms.get(arm_name, {})
t = (a.get("raw_text") or "").strip()
tail = t[-180:] if t else ""
grade = a.get("graded", {})
print(f" [{arm_name}] {grade.get('grade')} ({grade.get('method')}): {tail!r}")
print("\n=== Where Long-Context is wrong but SurfSense is right (top 5) ===")
for qid in lc_wrong_surf_right[:5]:
arms = by_q[qid]
b = arms.get("bare_llm", {})
print(f"\n[{qid}] domain={b.get('domain')} qtype={b.get('question_type')}")
print(f" GOLD: {b.get('gold')!r}")
for arm_name in ("bare_llm", "long_context", "surfsense"):
a = arms.get(arm_name, {})
t = (a.get("raw_text") or "").strip()
tail = t[-180:] if t else ""
grade = a.get("graded", {})
print(f" [{arm_name}] {grade.get('grade')} ({grade.get('method')}): {tail!r}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,40 @@
"""Quick sanity-check for the CRAG Task 3 doc map after ingest."""
from __future__ import annotations
import json
import sys
from pathlib import Path
def main() -> int:
p = Path("data/research/maps/crag_t3_doc_map.jsonl")
if not p.exists():
print(f"Doc map missing: {p}")
return 1
rows = []
settings = {}
for line in p.read_text(encoding="utf-8").splitlines():
if not line.strip():
continue
row = json.loads(line)
if "__settings__" in row:
settings = row
continue
rows.append(row)
print(f"Settings header: {settings}")
print(f"Doc map rows: {len(rows)}")
for r in rows:
print(f" qid={r['qid']:<10} domain={r['domain']:<8} qtype={r['question_type']}")
print(f" question: {r['question'][:90]}")
print(f" gold: {r['gold_answer'][:90]}")
print(
f" pages: {len(r['page_filenames'])} extracted, "
f"{len(r['document_ids'])} doc_ids, "
f"{len(r['missing_pages'])} missing"
)
return 0
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,65 @@
"""Render a quick textual summary of the latest CRAG run."""
from __future__ import annotations
import glob
import json
def main() -> None:
runs = sorted(glob.glob("data/research/runs/*/crag/run_artifact.json"))
if not runs:
print("(no CRAG runs found)")
return
m = json.load(open(runs[-1], encoding="utf-8"))
metrics = m["metrics"]
print(f"Reading: {runs[-1]}")
print(f"n_questions: {m['extra']['n_questions']}")
print()
print("=== ARMS ===")
for arm in ("bare_llm", "long_context", "surfsense"):
d = metrics[arm]
print(
f"{arm:14s}: "
f"acc={d['accuracy']*100:5.1f}% (Wilson 95% CI "
f"{d['ci_low']*100:.1f}-{d['ci_high']*100:.1f}) | "
f"correct={d['correct_rate']*100:5.1f}% "
f"missing={d['missing_rate']*100:5.1f}% "
f"incorrect={d['incorrect_rate']*100:5.1f}% | "
f"truth={d['truthfulness_score']*100:+5.1f}%"
)
print()
print("=== DELTAS ===")
for key, d in metrics["deltas"].items():
print(
f"{key:30s}: acc={d['accuracy_pp']:+5.1f}pp "
f"truth={d['truthfulness_score_pp']:+5.1f}pp "
f"McNemar p={d['mcnemar_p_value']:.4f} ({d['mcnemar_method']}) "
f"bootstrap CI [{d['bootstrap_ci_low']:+.1f}, {d['bootstrap_ci_high']:+.1f}]"
)
print()
print("=== PER-QUESTION-TYPE TRUTHFULNESS ===")
for qt, row in sorted(metrics["per_question_type"].items()):
n = row["n"]
pieces = [f"{qt:20s} (n={n:3d}):"]
for arm in ("bare_llm", "long_context", "surfsense"):
if arm in row:
pieces.append(f"{arm}={row[arm]['truthfulness_score']*100:+7.1f}%")
print(" ".join(pieces))
print()
print("=== PER-DOMAIN TRUTHFULNESS ===")
for dom, row in sorted(metrics["per_domain"].items()):
n = row["n"]
pieces = [f"{dom:10s} (n={n:3d}):"]
for arm in ("bare_llm", "long_context", "surfsense"):
if arm in row:
pieces.append(f"{arm}={row[arm]['truthfulness_score']*100:+7.1f}%")
print(" ".join(pieces))
if __name__ == "__main__":
main()

View file

@ -0,0 +1,10 @@
"""SurfSense Evals — domain-agnostic eval harness.
Public entry-point is the ``surfsense_evals`` CLI (``python -m surfsense_evals``).
Programmatic embedding is a non-goal for now; everything goes through the CLI
+ filesystem outputs (state.json, raw run JSONL, summary.md/json reports).
"""
from __future__ import annotations
__version__ = "0.1.0"

View file

@ -0,0 +1,13 @@
"""Module entry point: ``python -m surfsense_evals ...``.
Delegates to ``core.cli.main``. ``core.cli`` lazily imports
``surfsense_evals.suites`` so every benchmark gets a chance to register
before argparse builds its subcommand groups.
"""
from __future__ import annotations
from surfsense_evals.core.cli import main
if __name__ == "__main__": # pragma: no cover
raise SystemExit(main())

View file

@ -0,0 +1,8 @@
"""Domain-agnostic infrastructure shared by every suite.
Nothing under ``core/`` knows or cares about a specific evaluation domain.
Suites live under ``surfsense_evals.suites.<domain>.<benchmark>`` and
register themselves with ``core.registry`` on import.
"""
from __future__ import annotations

View file

@ -0,0 +1,44 @@
"""Arm protocol + concrete arms shared across suites.
Concrete arms (``NativePdfArm``, ``SurfSenseArm``, ``BareLlmArm``) are
imported lazily via ``__getattr__`` so consumers that only need the
protocol e.g. the registry's ``Arm`` re-export — don't transitively
pull in ``httpx`` providers or the SurfSense client unless they
actually use those arms.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from .base import Arm, ArmRequest, ArmResult
if TYPE_CHECKING: # pragma: no cover
from .bare_llm import BareLlmArm
from .native_pdf import NativePdfArm
from .surfsense import SurfSenseArm
__all__ = [
"Arm",
"ArmRequest",
"ArmResult",
"BareLlmArm",
"NativePdfArm",
"SurfSenseArm",
]
def __getattr__(name: str): # PEP 562
if name == "NativePdfArm":
from .native_pdf import NativePdfArm
return NativePdfArm
if name == "SurfSenseArm":
from .surfsense import SurfSenseArm
return SurfSenseArm
if name == "BareLlmArm":
from .bare_llm import BareLlmArm
return BareLlmArm
raise AttributeError(f"module 'surfsense_evals.core.arms' has no attribute {name!r}")

View file

@ -0,0 +1,100 @@
"""Bare-LLM arm: chat completion with prompt-only input, no retrieval.
Pairs with ``SurfSenseArm`` for any benchmark that wants to measure
"how much does the model already know without RAG?". For factuality /
multi-hop benchmarks (FRAMES, MuSiQue, ) this produces the published
"naive prompting" baseline e.g. FRAMES's 40.8% on Gemini-Pro-1.5.
Symmetric with ``NativePdfArm`` in shape, but the request carries no
``pdf_paths``: the prompt itself is the only input the model gets.
"""
from __future__ import annotations
import logging
from ..providers.openrouter_chat import OpenRouterChatProvider
from .base import Arm, ArmRequest, ArmResult
logger = logging.getLogger(__name__)
class BareLlmArm(Arm):
"""``Arm`` implementation backed by ``OpenRouterChatProvider``.
``name`` defaults to ``"bare_llm"`` but is overridable per-instance.
Suites that want two distinct OpenRouter chat arms (e.g. CRAG's
``bare_llm`` vs ``long_context`` both backed by chat-completions
but exercising different prompt strategies) instantiate twice with
different names so the metrics aggregator can keep them separate.
"""
name: str = "bare_llm"
def __init__(
self,
*,
provider: OpenRouterChatProvider,
max_output_tokens: int | None = 1024,
system_prompt: str | None = None,
name: str | None = None,
) -> None:
self._provider = provider
self._max_output = max_output_tokens
self._system_prompt = system_prompt
if name:
self.name = name
@classmethod
def from_env(
cls,
*,
api_key: str,
model: str,
base_url: str = "https://openrouter.ai/api/v1",
max_output_tokens: int | None = 1024,
system_prompt: str | None = None,
name: str | None = None,
) -> BareLlmArm:
provider = OpenRouterChatProvider(
api_key=api_key,
base_url=base_url,
model=model,
)
return cls(
provider=provider,
max_output_tokens=max_output_tokens,
system_prompt=system_prompt,
name=name,
)
async def answer(self, request: ArmRequest) -> ArmResult:
try:
response = await self._provider.complete(
prompt=request.prompt,
system_prompt=self._system_prompt,
max_tokens=self._max_output,
)
except Exception as exc: # noqa: BLE001
return ArmResult(
arm=self.name,
question_id=request.question_id,
raw_text="",
error=f"{type(exc).__name__}: {exc}",
)
return ArmResult(
arm=self.name,
question_id=request.question_id,
raw_text=response.text,
input_tokens=response.input_tokens,
output_tokens=response.output_tokens,
cost_micros=response.cost_micros,
latency_ms=response.latency_ms,
extra={
"model": self._provider.model,
"finish_reason": response.finish_reason,
},
)
__all__ = ["BareLlmArm"]

View file

@ -0,0 +1,93 @@
"""Arm protocol + the value types every arm exchanges with a runner.
An ``Arm`` is "one way to answer one question". Two ship in this PR:
* ``NativePdfArm`` drop the PDF straight into an OpenRouter
chat-completions request with ``plugins=[{file-parser, engine:
native}]``. Used for the head-to-head "is the model good enough on
its own?" measurement.
* ``SurfSenseArm`` POST ``/api/v1/new_chat`` with the question
scoped to the relevant ``mentioned_document_ids``; consume the SSE
stream and parse citations.
Both implement the same protocol so a benchmark runner only sees
``Arm.answer(request) -> ArmResult``.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Protocol
@dataclass
class ArmRequest:
"""One arm-call worth of input.
* ``question_id`` is opaque used for logging and joining results.
* ``prompt`` is the fully-formatted text the arm should send. The
runner is responsible for prompt construction so head-to-head
comparisons use byte-identical text.
* ``pdf_paths`` is the per-question source PDFs (used by
``NativePdfArm``). Empty for retrieval-only / corpus-wide
benchmarks.
* ``mentioned_document_ids`` is the SurfSense document scoping list
(used by ``SurfSenseArm``). When ``None`` SurfSense retrieves
across the whole search space.
* ``options`` is a free-form bag of arm-specific overrides
(e.g. SurfSense's ``disabled_tools``).
"""
question_id: str
prompt: str
pdf_paths: list[Path] = field(default_factory=list)
mentioned_document_ids: list[int] | None = None
options: dict[str, Any] = field(default_factory=dict)
@dataclass
class ArmResult:
"""Outcome of one ``Arm.answer`` invocation."""
arm: str
question_id: str
raw_text: str
answer_letter: str | None = None
citations: list[dict[str, Any]] = field(default_factory=list)
input_tokens: int = 0
output_tokens: int = 0
cost_micros: int = 0
latency_ms: int = 0
error: str | None = None
extra: dict[str, Any] = field(default_factory=dict)
@property
def ok(self) -> bool:
return self.error is None
def to_jsonl(self) -> dict[str, Any]:
"""Stable dict shape for ``data/<suite>/runs/<ts>/<bench>_raw.jsonl``."""
return {
"arm": self.arm,
"question_id": self.question_id,
"answer_letter": self.answer_letter,
"raw_text": self.raw_text,
"citations": self.citations,
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cost_micros": self.cost_micros,
"latency_ms": self.latency_ms,
"error": self.error,
"extra": self.extra,
}
class Arm(Protocol):
"""One concrete way to answer questions for a given run."""
name: str
async def answer(self, request: ArmRequest) -> ArmResult: # pragma: no cover - protocol
...

View file

@ -0,0 +1,104 @@
"""Native-PDF arm: drop the PDF straight into OpenRouter chat-completions.
Generic across suites a benchmark just supplies the prompt and the
single PDF path. Multi-PDF questions concatenate in the runner before
calling this arm so each ``answer`` invocation feeds the model exactly
one ``data:application/pdf;base64,...`` block (matches the human
"drag-and-drop one PDF into Claude" intent).
"""
from __future__ import annotations
import logging
from ..parse.answer_letter import extract_answer_letter
from ..providers.openrouter_pdf import OpenRouterPdfProvider, PdfEngine
from .base import Arm, ArmRequest, ArmResult
logger = logging.getLogger(__name__)
class NativePdfArm(Arm):
"""``Arm`` implementation backed by ``OpenRouterPdfProvider``."""
name: str = "native_pdf"
def __init__(
self,
*,
provider: OpenRouterPdfProvider,
max_output_tokens: int | None = 1024,
) -> None:
self._provider = provider
self._max_output = max_output_tokens
@classmethod
def from_env(
cls,
*,
api_key: str,
model: str,
engine: PdfEngine = PdfEngine.NATIVE,
base_url: str = "https://openrouter.ai/api/v1",
max_output_tokens: int | None = 1024,
) -> NativePdfArm:
provider = OpenRouterPdfProvider(
api_key=api_key,
base_url=base_url,
model=model,
engine=engine,
)
return cls(provider=provider, max_output_tokens=max_output_tokens)
async def answer(self, request: ArmRequest) -> ArmResult:
if not request.pdf_paths:
return ArmResult(
arm=self.name,
question_id=request.question_id,
raw_text="",
error="native_pdf arm requires at least one pdf_path",
)
if len(request.pdf_paths) > 1:
# The plan calls out one-PDF-per-question so the head-to-head
# is fair; runners are responsible for upstream concatenation.
logger.debug(
"qid=%s native_pdf got %d pdfs; using first only",
request.question_id,
len(request.pdf_paths),
)
pdf = request.pdf_paths[0]
try:
response = await self._provider.complete(
prompt=request.prompt,
pdf_path=pdf,
max_tokens=self._max_output,
)
except Exception as exc: # noqa: BLE001
return ArmResult(
arm=self.name,
question_id=request.question_id,
raw_text="",
error=f"{type(exc).__name__}: {exc}",
)
letter = extract_answer_letter(response.text)
return ArmResult(
arm=self.name,
question_id=request.question_id,
raw_text=response.text,
answer_letter=letter.letter,
input_tokens=response.input_tokens,
output_tokens=response.output_tokens,
cost_micros=response.cost_micros,
latency_ms=response.latency_ms,
extra={
"model": self._provider.model,
"engine": self._provider.engine.value,
"answer_letter_strategy": letter.strategy,
"finish_reason": response.finish_reason,
"pdf_filename": pdf.name,
},
)
__all__ = ["NativePdfArm"]

View file

@ -0,0 +1,104 @@
"""SurfSense arm: per-question fresh thread + ``/api/v1/new_chat`` stream.
For every question:
* Create a fresh ``NewChatThread`` on the suite's pinned SearchSpace.
This sidesteps the per-thread ``THREAD_BUSY`` 409 (a single thread
serialises turns, see ``surfsense_backend/app/routes/new_chat_routes.py:191-220``).
* POST ``/api/v1/new_chat`` with the prompt and the per-question
``mentioned_document_ids`` (``surfsense_backend/app/schemas/new_chat.py:241-243``).
* Consume the SSE stream via ``NewChatClient.ask`` which accumulates
text deltas and returns ``StreamedAnswer``.
* Optionally delete the thread (default ON for ephemeral runs).
Citations are parsed from the streamed assistant text via the
canonical regex port; chunk ids are returned in ``ArmResult.citations``
for the runner to map back to corpus ids.
"""
from __future__ import annotations
import logging
from ..clients import NewChatClient
from ..parse.answer_letter import extract_answer_letter
from .base import Arm, ArmRequest, ArmResult
logger = logging.getLogger(__name__)
class SurfSenseArm(Arm):
"""``Arm`` implementation backed by ``NewChatClient``."""
name: str = "surfsense"
def __init__(
self,
*,
client: NewChatClient,
search_space_id: int,
ephemeral_threads: bool = True,
thread_title_prefix: str = "eval",
) -> None:
self._client = client
self._search_space_id = search_space_id
self._ephemeral = ephemeral_threads
self._title_prefix = thread_title_prefix
async def answer(self, request: ArmRequest) -> ArmResult:
thread_id: int | None = None
try:
thread_id = await self._client.create_thread(
search_space_id=self._search_space_id,
title=f"{self._title_prefix}:{request.question_id}",
)
answer = await self._client.ask(
thread_id=thread_id,
search_space_id=self._search_space_id,
user_query=request.prompt,
mentioned_document_ids=request.mentioned_document_ids,
disabled_tools=request.options.get("disabled_tools"),
)
except Exception as exc: # noqa: BLE001
return ArmResult(
arm=self.name,
question_id=request.question_id,
raw_text="",
error=f"{type(exc).__name__}: {exc}",
extra={"thread_id": thread_id},
)
finally:
if self._ephemeral and thread_id is not None:
try:
await self._client.delete_thread(thread_id)
except Exception as exc: # noqa: BLE001
logger.debug(
"Failed to delete thread %s: %s", thread_id, exc
)
letter = extract_answer_letter(answer.text)
return ArmResult(
arm=self.name,
question_id=request.question_id,
raw_text=answer.text,
answer_letter=letter.letter,
citations=answer.citations,
latency_ms=answer.latency_ms,
# SurfSense doesn't surface input/output token counts in the
# SSE stream today; leaving the cost / token fields at 0
# documents that gap. Estimating from the raw text would
# bias the comparison against the SurfSense arm.
extra={
"thread_id": thread_id,
"search_space_id": self._search_space_id,
"answer_letter_strategy": letter.strategy,
"user_message_id": answer.user_message_id,
"assistant_message_id": answer.assistant_message_id,
"finished_normally": answer.finished_normally,
"n_raw_events": len(answer.raw_events),
"n_mentioned_documents": len(request.mentioned_document_ids or []),
},
)
__all__ = ["SurfSenseArm"]

View file

@ -0,0 +1,273 @@
"""Dual-mode credential resolver + httpx client factory with 401 auto-refresh.
SurfSense supports ``AUTH_TYPE=LOCAL`` (email + password) and
``AUTH_TYPE=GOOGLE`` (Google OAuth frontend stores JWT in ``localStorage``).
There is no headless equivalent of the Google flow, so the harness handles
both modes by treating the JWT as the universal credential:
* **LOCAL**: harness POSTs form-encoded ``username`` + ``password`` to
``/auth/jwt/login``, reads ``{access_token, refresh_token}``.
* **GOOGLE / pre-issued JWT**: operator pastes their existing JWT (and
optionally refresh token) into ``SURFSENSE_JWT`` /
``SURFSENSE_REFRESH_TOKEN``; harness skips login.
Either way ``client_with_auth`` returns one shared
``httpx.AsyncClient`` with ``Authorization: Bearer <jwt>`` set and an
event hook that, on a 401 with a refresh token in scope, calls
``POST /auth/jwt/refresh`` and retries the original request once. JWT
lifetime defaults to one day backend-side, so this matters for long
MIRAGE runs.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
import httpx
from .config import Config
logger = logging.getLogger(__name__)
class CredentialError(RuntimeError):
"""Raised when no credential mode is configured."""
_NO_CREDENTIALS_MESSAGE = (
"No SurfSense credentials configured. Set ONE of:\n"
" (LOCAL) SURFSENSE_USER_EMAIL + SURFSENSE_USER_PASSWORD\n"
" (GOOGLE) SURFSENSE_JWT (and optionally SURFSENSE_REFRESH_TOKEN)\n"
"For GOOGLE: log in to SurfSense in your browser, open DevTools → "
"Application → Local Storage → copy `surfsense_bearer_token` and "
"`surfsense_refresh_token` into those env vars."
)
@dataclass
class TokenBundle:
"""Mutable token state — refresh hook updates ``access_token`` in place."""
access_token: str
refresh_token: str | None = None
# ``mode`` is informational only ("local" or "jwt"); used in error messages.
mode: str = "jwt"
# ---------------------------------------------------------------------------
# Token acquisition
# ---------------------------------------------------------------------------
async def acquire_token(config: Config, *, http: httpx.AsyncClient | None = None) -> TokenBundle:
"""Resolve credentials → ``TokenBundle``.
Precedence:
1. ``SURFSENSE_JWT`` set use it directly. Refresh token captured if
supplied.
2. ``SURFSENSE_USER_EMAIL`` + ``SURFSENSE_USER_PASSWORD`` set
form-encoded POST to ``/auth/jwt/login``.
3. Neither raise ``CredentialError``.
The optional ``http`` argument lets tests inject a mocked client; if
omitted a one-shot client is created for the login call only.
"""
if config.has_jwt_mode():
return TokenBundle(
access_token=config.surfsense_jwt or "",
refresh_token=config.surfsense_refresh_token,
mode="jwt",
)
if config.has_local_mode():
async def _login(client: httpx.AsyncClient) -> TokenBundle:
response = await client.post(
f"{config.surfsense_api_base}/auth/jwt/login",
data={
"username": config.surfsense_user_email,
"password": config.surfsense_user_password,
},
headers={"Accept": "application/json"},
)
if response.status_code != 200:
raise CredentialError(
f"LOCAL login failed (HTTP {response.status_code}): "
f"{_safe_text(response)}"
)
payload = response.json()
access = payload.get("access_token")
if not access:
raise CredentialError(
f"LOCAL login response missing access_token: {payload!r}"
)
return TokenBundle(
access_token=access,
refresh_token=payload.get("refresh_token") or None,
mode="local",
)
if http is not None:
return await _login(http)
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, connect=10.0)) as client:
return await _login(client)
raise CredentialError(_NO_CREDENTIALS_MESSAGE)
def _safe_text(response: httpx.Response, *, limit: int = 200) -> str:
body = response.text or ""
if len(body) > limit:
return body[:limit] + ""
return body
# ---------------------------------------------------------------------------
# httpx client + 401 auto-refresh
# ---------------------------------------------------------------------------
class _AuthState:
"""Shared mutable holder closed over by the auth event hook.
Kept private so callers can't accidentally mutate the access token
out-of-band; ``client_with_auth`` returns the client directly.
"""
def __init__(self, config: Config, tokens: TokenBundle) -> None:
self.config = config
self.tokens = tokens
self._refresh_in_flight: bool = False
def _build_auth_request(state: _AuthState, request: httpx.Request) -> None:
"""Stamp the current bearer onto ``request`` (request-event hook)."""
request.headers["Authorization"] = f"Bearer {state.tokens.access_token}"
async def _refresh_access_token(
state: _AuthState, transport: httpx.AsyncBaseTransport | None = None
) -> bool:
"""POST ``/auth/jwt/refresh`` with the current refresh token.
Returns ``True`` on success and updates ``state.tokens`` in place.
Returns ``False`` if no refresh token is configured or the call fails.
Recursive 401s are avoided by using a *new* client without the auth
hook.
"""
refresh = state.tokens.refresh_token
if not refresh:
return False
try:
async with httpx.AsyncClient(
timeout=httpx.Timeout(15.0, connect=5.0),
transport=transport,
) as inner:
response = await inner.post(
f"{state.config.surfsense_api_base}/auth/jwt/refresh",
json={"refresh_token": refresh},
headers={"Accept": "application/json"},
)
except httpx.HTTPError as exc:
logger.warning("Token refresh transport error: %s", exc)
return False
if response.status_code != 200:
logger.warning(
"Token refresh rejected (HTTP %s): %s",
response.status_code,
_safe_text(response),
)
return False
payload = response.json()
new_access = payload.get("access_token")
if not new_access:
logger.warning("Refresh response missing access_token: %r", payload)
return False
state.tokens.access_token = new_access
new_refresh = payload.get("refresh_token")
if new_refresh:
state.tokens.refresh_token = new_refresh
return True
def client_with_auth(
config: Config,
tokens: TokenBundle,
*,
timeout: float = 60.0,
transport: httpx.AsyncBaseTransport | None = None,
base_url: str | None = None,
) -> httpx.AsyncClient:
"""Build a single shared ``httpx.AsyncClient`` for the SurfSense API.
* Stamps ``Authorization: Bearer <jwt>`` on every outgoing request.
* On any 401 response, attempts a single refresh (if a refresh token
is configured) and retries the original request once. The retry
uses a fresh stamping of the bearer header, so a successful
refresh transparently unblocks long runs.
* The retry is best-effort repeated 401s after a refresh attempt
are surfaced to the caller so they can re-auth manually.
Pass ``base_url`` to scope a sub-client (e.g. tests). The default
keeps full URLs in calling code, which makes route-spec citations in
the codebase easier to grep.
"""
state = _AuthState(config, tokens)
async def _request_hook(request: httpx.Request) -> None:
_build_auth_request(state, request)
# ``send`` is overridden in ``_AuthAwareClient`` to retry once on 401
# after refreshing the bearer. httpx's response event-hook can't
# *replace* a response, so we need a subclass to do the replay.
client = _AuthAwareClient(
state=state,
transport=transport,
timeout=httpx.Timeout(timeout, connect=10.0),
base_url=base_url or "",
event_hooks={"request": [_request_hook]},
)
return client
class _AuthAwareClient(httpx.AsyncClient):
"""``AsyncClient`` that retries once on 401 after refreshing the token."""
def __init__(self, *, state: _AuthState, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._auth_state = state
async def send( # type: ignore[override]
self, request: httpx.Request, **kwargs: Any
) -> httpx.Response:
response = await super().send(request, **kwargs)
if response.status_code != 401:
return response
# Don't refresh while a refresh is itself in flight.
if self._auth_state._refresh_in_flight:
return response
self._auth_state._refresh_in_flight = True
try:
refreshed = await _refresh_access_token(self._auth_state)
finally:
self._auth_state._refresh_in_flight = False
if not refreshed:
return response
# Re-stamp and replay once. ``request`` is reusable.
await response.aclose()
request.headers["Authorization"] = f"Bearer {self._auth_state.tokens.access_token}"
return await super().send(request, **kwargs)
__all__ = [
"CredentialError",
"TokenBundle",
"acquire_token",
"client_with_auth",
]

View file

@ -0,0 +1,790 @@
"""Argparse CLI for ``python -m surfsense_evals``.
Subcommands:
* ``setup --suite <name> --provider-model <slug> [--agent-llm-id <int>]``
* ``teardown --suite <name>``
* ``models list [--provider openrouter] [--grep <s>]``
* ``suites list``
* ``benchmarks list [--suite <name>]``
* ``ingest <suite> <benchmark> [benchmark flags]``
* ``run <suite> <benchmark> [benchmark flags]``
* ``report --suite <name> [--benchmark <name>]``
The ``ingest`` / ``run`` subparsers are built dynamically from the
registry adding a new benchmark only requires registering it; the
CLI surface comes for free. ``add_run_args`` lets each benchmark
publish its own flags.
Design choices worth flagging:
* ``setup`` rejects ``agent_llm_id == 0`` (Auto / LiteLLM router) so
per-question accuracy is reproducible.
* ``setup`` validates that the picked LLM config has
``provider == "OPENROUTER"`` and ``model_name == --provider-model``
before declaring success both arms of the head-to-head must hit
the same OpenRouter slug.
* Lifecycle state is keyed by suite, so ``setup --suite legal`` does
not touch ``medical``'s SearchSpace, and vice versa.
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import sys
from dataclasses import dataclass
from typing import Any
import sys
import httpx
from rich.console import Console
from rich.table import Table
# Windows' legacy console (cp1252) crashes when Rich tries to write characters
# outside the active codepage (e.g. '->', em-dashes, box-drawing). Force UTF-8
# on stdout/stderr and disable Rich's legacy_windows render path so the file
# stream is used directly. Modern Windows (>=10, VS Code terminal, Windows
# Terminal, PowerShell, cmd) all interpret ANSI escapes natively.
if sys.platform == "win32":
for _stream in (sys.stdout, sys.stderr):
try:
_stream.reconfigure(encoding="utf-8", errors="replace")
except (AttributeError, ValueError):
pass
from . import registry
from .auth import CredentialError, acquire_token, client_with_auth
from .clients import SearchSpaceClient
from .clients.search_space import LlmPreferences
from .config import (
DEFAULT_SCENARIO,
SCENARIOS,
Config,
SuiteState,
clear_suite_state,
get_suite_state,
load_config,
set_suite_state,
utc_iso_timestamp,
)
from .vision_llm import VisionConfigError, resolve_vision_llm
logger = logging.getLogger("surfsense_evals")
console = Console(legacy_windows=False)
# ---------------------------------------------------------------------------
# Discovery
# ---------------------------------------------------------------------------
def _discover_suites() -> list[str]:
"""Trigger ``register(...)`` for every benchmark.
Imported lazily so ``models list`` (which doesn't need any
benchmark) still runs fast.
"""
from surfsense_evals.suites import discover_suites
return discover_suites()
# ---------------------------------------------------------------------------
# Global LLM config fetcher (used by setup + models list)
# ---------------------------------------------------------------------------
@dataclass
class LlmConfigEntry:
id: int
name: str
provider: str
model_name: str
raw: dict[str, Any]
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> LlmConfigEntry:
return cls(
id=int(payload["id"]),
name=str(payload.get("name", "")),
provider=str(payload.get("provider", "")).upper(),
model_name=str(payload.get("model_name", "")),
raw=payload,
)
async def _list_global_llm_configs(http: httpx.AsyncClient, base: str) -> list[LlmConfigEntry]:
response = await http.get(
f"{base}/api/v1/global-new-llm-configs",
headers={"Accept": "application/json"},
)
response.raise_for_status()
payload = response.json()
if not isinstance(payload, list):
raise RuntimeError(f"Unexpected /global-new-llm-configs payload: {payload!r}")
return [LlmConfigEntry.from_payload(item) for item in payload]
def _resolve_openrouter_id(
candidates: list[LlmConfigEntry],
provider_model: str,
*,
explicit_id: int | None,
) -> int:
"""Resolve the SurfSense LLM id for ``provider_model``.
Behaviour:
* If ``explicit_id`` is given: return it directly. The caller is
then expected to GET-validate that the row's
``provider == "OPENROUTER"`` and ``model_name`` matches the slug.
That branch supports positive BYOK ``NewLLMConfig`` rows whose
slugs may overlap with global OpenRouter virtuals.
* Otherwise: filter to ``provider == "OPENROUTER"`` and
``model_name == provider_model``. Expect exactly one match
raise with a friendly message otherwise.
"""
if explicit_id is not None:
return explicit_id
matches = [
c for c in candidates if c.provider == "OPENROUTER" and c.model_name == provider_model
]
if not matches:
sample = ", ".join(
f"{c.model_name} (id={c.id})" for c in candidates if c.provider == "OPENROUTER"
)[:600]
raise RuntimeError(
f"No OpenRouter config found for slug '{provider_model}'. "
"Make sure `openrouter_integration.enabled: true` in "
"global_llm_config.yaml and that the Celery worker has "
"finished its first refresh (the catalogue is fetched at "
"Celery startup per `app/celery_app.py`). "
f"Available OpenRouter slugs (sample): {sample or '<none>'}.\n"
"Browse with: python -m surfsense_evals models list --grep <substring>"
)
if len(matches) > 1:
listing = "\n".join(f" id={c.id} name={c.name!r}" for c in matches)
raise RuntimeError(
f"Multiple OpenRouter configs for slug '{provider_model}':\n{listing}\n"
"Pass --agent-llm-id <id> to disambiguate."
)
return matches[0].id
# ---------------------------------------------------------------------------
# Subcommand implementations
# ---------------------------------------------------------------------------
async def _cmd_setup(args: argparse.Namespace) -> int:
suite = args.suite
provider_model: str = args.provider_model
explicit_id: int | None = args.agent_llm_id
scenario: str = args.scenario
vision_llm_slug: str | None = args.vision_llm
native_arm_model: str | None = args.native_arm_model
skip_vision_setup: bool = args.no_vision_llm_setup
if explicit_id == 0:
console.print(
"[red]agent_llm_id == 0 (Auto / LiteLLM router) is not allowed — "
"results would not be reproducible.[/red]"
)
return 2
if scenario not in SCENARIOS:
console.print(
f"[red]Unknown scenario {scenario!r}. Pick one of: "
f"{', '.join(SCENARIOS)}[/red]"
)
return 2
# Scenario-specific validation. Each branch documents WHY the rule
# exists so the operator's mental model matches what the runner does.
if scenario == "cost-arbitrage":
if not native_arm_model:
console.print(
"[red]--scenario cost-arbitrage requires --native-arm-model "
"<vision-capable slug>.[/red] The native arm needs a vision "
"model to fairly answer image-bearing questions; SurfSense "
"answers from already-extracted text via --provider-model."
)
return 2
if native_arm_model == provider_model:
console.print(
"[yellow]--native-arm-model equals --provider-model in "
"cost-arbitrage; that's degenerate (same as head-to-head). "
"Pick a different slug or switch to --scenario head-to-head.[/yellow]"
)
elif scenario in ("head-to-head", "symmetric-cheap"):
if native_arm_model:
console.print(
f"[yellow]--native-arm-model is ignored for --scenario {scenario} "
f"(both arms answer with --provider-model={provider_model!r}).[/yellow]"
)
native_arm_model = None # don't persist a stale value
config = load_config()
try:
token = await acquire_token(config)
except CredentialError as exc:
console.print(f"[red]{exc}[/red]")
return 2
async with client_with_auth(config, token) as http:
candidates = await _list_global_llm_configs(http, config.surfsense_api_base)
try:
agent_llm_id = _resolve_openrouter_id(
candidates, provider_model, explicit_id=explicit_id
)
except RuntimeError as exc:
console.print(f"[red]{exc}[/red]")
return 2
ss_client = SearchSpaceClient(http, config.surfsense_api_base)
existing = get_suite_state(config, suite)
if existing is not None:
try:
row = await ss_client.get(existing.search_space_id)
console.print(
f"Reusing existing SearchSpace [cyan]{row.name}[/cyan] "
f"(id={row.id}) for suite [bold]{suite}[/bold]."
)
search_space_id = row.id
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 404:
console.print(
f"[yellow]state.json pointed at SearchSpace id={existing.search_space_id} "
f"but backend returned 404; creating a fresh one.[/yellow]"
)
existing = None
else:
raise
if existing is None:
ss_name = f"eval-{suite}-{utc_iso_timestamp()}"
row = await ss_client.create(
ss_name, description=f"surfsense-evals lifecycle ({suite})"
)
console.print(
f"Created SearchSpace [cyan]{row.name}[/cyan] (id={row.id}) "
f"for suite [bold]{suite}[/bold]."
)
search_space_id = row.id
# Resolve + attach the vision LLM config (unless explicitly skipped).
# Asymmetric scenarios make the vision LLM at ingest a hard
# requirement — without it, SurfSense's chunks have no image
# content and the entire framing collapses.
vision_required = scenario in ("symmetric-cheap", "cost-arbitrage")
vision_config_id: int | None = None
vision_provider_model: str | None = None
if not skip_vision_setup and (vision_required or vision_llm_slug is not None):
try:
vision_candidates = await ss_client.list_global_vision_llm_configs()
resolved = resolve_vision_llm(
vision_candidates, explicit_slug=vision_llm_slug
)
except VisionConfigError as exc:
console.print(f"[red]{exc}[/red]")
return 2
vision_config_id = resolved.config_id
vision_provider_model = resolved.provider_model
console.print(
f"Vision LLM at ingest: [cyan]{vision_provider_model}[/cyan] "
f"(id={vision_config_id}, selected_via={resolved.selected_via})."
)
pref_kwargs: dict[str, Any] = {"agent_llm_id": agent_llm_id}
if vision_config_id is not None:
pref_kwargs["vision_llm_config_id"] = vision_config_id
await ss_client.set_llm_preferences(search_space_id, **pref_kwargs)
prefs = await ss_client.get_llm_preferences(search_space_id)
if not _validate_pin(prefs, provider_model):
agent = prefs.agent_llm or {}
console.print(
f"[red]LLM pin validation FAILED.[/red] After PUT, "
f"agent_llm.provider={agent.get('provider')!r}, "
f"model_name={agent.get('model_name')!r}; expected "
f"provider=OPENROUTER, model_name={provider_model!r}."
)
return 2
if vision_config_id is not None and prefs.vision_llm_config_id != vision_config_id:
console.print(
f"[red]Vision LLM pin validation FAILED.[/red] After PUT, "
f"vision_llm_config_id={prefs.vision_llm_config_id!r}; "
f"expected {vision_config_id!r}."
)
return 2
suite_state = SuiteState(
search_space_id=search_space_id,
agent_llm_id=agent_llm_id,
provider_model=provider_model,
created_at=utc_iso_timestamp(),
ingestion_maps=existing.ingestion_maps if existing else {},
scenario=scenario,
vision_llm_config_id=vision_config_id,
vision_provider_model=vision_provider_model,
native_arm_model=native_arm_model,
)
set_suite_state(config, suite, suite_state)
summary_bits = [
f"suite={suite!r}",
f"scenario={scenario!r}",
f"search_space_id={suite_state.search_space_id}",
f"agent_llm_id={suite_state.agent_llm_id}",
f"provider_model={suite_state.provider_model!r}",
]
if suite_state.vision_provider_model:
summary_bits.append(f"vision_provider_model={suite_state.vision_provider_model!r}")
if suite_state.native_arm_model:
summary_bits.append(f"native_arm_model={suite_state.native_arm_model!r}")
console.print(f"[green]setup OK[/green] {' '.join(summary_bits)}")
return 0
def _validate_pin(prefs: LlmPreferences, provider_model: str) -> bool:
agent = prefs.agent_llm or {}
return (
str(agent.get("provider", "")).upper() == "OPENROUTER"
and str(agent.get("model_name", "")) == provider_model
)
async def _cmd_teardown(args: argparse.Namespace) -> int:
suite = args.suite
config = load_config()
state = get_suite_state(config, suite)
if state is None:
console.print(f"[yellow]No state for suite {suite!r}; nothing to tear down.[/yellow]")
return 0
try:
token = await acquire_token(config)
except CredentialError as exc:
console.print(f"[red]{exc}[/red]")
return 2
async with client_with_auth(config, token) as http:
ss_client = SearchSpaceClient(http, config.surfsense_api_base)
try:
await ss_client.delete(state.search_space_id)
except httpx.HTTPStatusError as exc:
console.print(
f"[yellow]DELETE failed (HTTP {exc.response.status_code}); "
"clearing state.json anyway.[/yellow]"
)
clear_suite_state(config, suite)
console.print(
f"[green]teardown OK[/green] suite={suite!r} "
f"(SearchSpace soft-deleted, state.json slot cleared)."
)
return 0
async def _cmd_models_list(args: argparse.Namespace) -> int:
config = load_config()
try:
token = await acquire_token(config)
except CredentialError as exc:
console.print(f"[red]{exc}[/red]")
return 2
async with client_with_auth(config, token) as http:
entries = await _list_global_llm_configs(http, config.surfsense_api_base)
grep = (args.grep or "").lower()
provider_filter = (args.provider or "").upper()
rows: list[LlmConfigEntry] = []
for e in entries:
if provider_filter and e.provider != provider_filter:
continue
if grep and grep not in e.model_name.lower() and grep not in e.name.lower():
continue
rows.append(e)
table = Table(
title=f"Global LLM configs ({len(rows)} of {len(entries)})",
show_lines=False,
)
table.add_column("id", justify="right", style="cyan")
table.add_column("provider", style="magenta")
table.add_column("model_name", style="green")
table.add_column("name")
for e in sorted(rows, key=lambda x: (x.provider, x.model_name)):
table.add_row(str(e.id), e.provider, e.model_name, e.name)
console.print(table)
return 0
def _cmd_suites_list(_args: argparse.Namespace) -> int:
_discover_suites()
suites = registry.list_suites()
if not suites:
console.print(
"[yellow]No suites registered. Drop a benchmark under "
"src/surfsense_evals/suites/<domain>/<benchmark>/.[/yellow]"
)
return 0
table = Table(title=f"Registered suites ({len(suites)})")
table.add_column("suite", style="bold")
table.add_column("benchmarks", style="green")
for suite in suites:
names = [b.name for b in registry.list_benchmarks(suite)]
table.add_row(suite, ", ".join(names) or "<none>")
console.print(table)
return 0
def _cmd_benchmarks_list(args: argparse.Namespace) -> int:
_discover_suites()
benchmarks = registry.list_benchmarks(args.suite)
if not benchmarks:
console.print("[yellow]No benchmarks registered.[/yellow]")
return 0
table = Table(title=f"Benchmarks ({len(benchmarks)})")
table.add_column("suite", style="bold")
table.add_column("name", style="cyan")
table.add_column("headline", justify="center")
table.add_column("description")
for b in benchmarks:
table.add_row(
b.suite,
b.name,
"yes" if b.headline else "no",
getattr(b, "description", ""),
)
console.print(table)
return 0
async def _cmd_ingest(args: argparse.Namespace) -> int:
benchmark = registry.get(args.suite, args.benchmark)
config = load_config()
state = get_suite_state(config, args.suite)
if state is None:
console.print(
f"[red]No setup for suite {args.suite!r}. Run "
f"`python -m surfsense_evals setup --suite {args.suite} "
f"--provider-model <slug>` first.[/red]"
)
return 2
try:
token = await acquire_token(config)
except CredentialError as exc:
console.print(f"[red]{exc}[/red]")
return 2
# Forward parsed CLI flags into ingest() so a benchmark can honour
# its own flags (e.g. MIRAGE's --skip-snippet-filter / --corpus).
extra_kwargs = {
k: v
for k, v in vars(args).items()
if k not in {"_func", "_async", "command", "subcommand", "suite", "benchmark", "log_level"}
}
async with client_with_auth(config, token) as http:
ctx = registry.RunContext(
suite=args.suite,
benchmark=args.benchmark,
config=config,
suite_state=state,
http=http,
)
await benchmark.ingest(ctx, **extra_kwargs)
console.print(f"[green]ingest OK[/green] {args.suite}/{args.benchmark}")
return 0
async def _cmd_run(args: argparse.Namespace) -> int:
benchmark = registry.get(args.suite, args.benchmark)
config = load_config()
state = get_suite_state(config, args.suite)
if state is None:
console.print(
f"[red]No setup for suite {args.suite!r}. Run "
f"`python -m surfsense_evals setup --suite {args.suite} "
f"--provider-model <slug>` first.[/red]"
)
return 2
try:
token = await acquire_token(config)
except CredentialError as exc:
console.print(f"[red]{exc}[/red]")
return 2
extra_kwargs = {
k: v
for k, v in vars(args).items()
if k not in {"_func", "_async", "command", "subcommand", "suite", "benchmark", "log_level"}
}
async with client_with_auth(config, token) as http:
ctx = registry.RunContext(
suite=args.suite,
benchmark=args.benchmark,
config=config,
suite_state=state,
http=http,
)
artifact = await benchmark.run(ctx, **extra_kwargs)
console.print(
f"[green]run OK[/green] {args.suite}/{args.benchmark}"
f"{artifact.raw_path}"
)
return 0
async def _cmd_report(args: argparse.Namespace) -> int:
from .report import write_report
benchmark_filter = args.benchmark
config = load_config()
state = get_suite_state(config, args.suite)
if state is None:
console.print(f"[red]No setup for suite {args.suite!r}.[/red]")
return 2
benchmarks = registry.list_benchmarks(args.suite)
if benchmark_filter:
benchmarks = [b for b in benchmarks if b.name == benchmark_filter]
if not benchmarks:
console.print(
f"[red]No registered benchmark named {benchmark_filter!r} in suite {args.suite!r}.[/red]"
)
return 2
artifacts = _collect_artifacts(config, args.suite, [b.name for b in benchmarks])
if not artifacts:
console.print(
"[yellow]No run artifacts found under "
f"{config.suite_runs_dir(args.suite)}. Run a benchmark first.[/yellow]"
)
return 1
grouped: dict[str, list[registry.RunArtifact]] = {}
for art in artifacts:
grouped.setdefault(art.benchmark, []).append(art)
sections: list[registry.ReportSection] = []
for benchmark in benchmarks:
if benchmark.name not in grouped:
continue
sections.append(benchmark.report_section(grouped[benchmark.name]))
summary_path = write_report(
config=config,
suite=args.suite,
sections=sections,
run_timestamp=utc_iso_timestamp(),
)
console.print(f"[green]report OK[/green] → {summary_path}")
return 0
def _collect_artifacts(
config: Config, suite: str, benchmark_names: list[str]
) -> list[registry.RunArtifact]:
"""Walk ``data/<suite>/runs/*/<benchmark>/`` for the latest artifacts.
Reads any ``run_artifact.json`` written by a benchmark runner. The
runner is responsible for writing this manifest alongside its raw
JSONL so the report writer doesn't have to know benchmark-specific
metric shapes.
"""
runs_dir = config.suite_runs_dir(suite)
if not runs_dir.exists():
return []
artifacts: list[registry.RunArtifact] = []
by_bench: dict[str, registry.RunArtifact] = {}
for ts_dir in sorted(runs_dir.iterdir()):
if not ts_dir.is_dir():
continue
for bench_name in benchmark_names:
bench_dir = ts_dir / bench_name
manifest = bench_dir / "run_artifact.json"
if not manifest.exists():
continue
try:
with manifest.open("r", encoding="utf-8") as fh:
payload = json.load(fh)
except (OSError, json.JSONDecodeError):
continue
artifact = registry.RunArtifact(
suite=suite,
benchmark=bench_name,
run_timestamp=ts_dir.name,
raw_path=bench_dir / payload.get("raw_path", "raw.jsonl"),
metrics=payload.get("metrics", {}),
extra=payload.get("extra", {}),
)
# Latest run wins per benchmark.
by_bench[bench_name] = artifact
artifacts = list(by_bench.values())
return artifacts
# ---------------------------------------------------------------------------
# Argparse wiring
# ---------------------------------------------------------------------------
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="surfsense-evals",
description="SurfSense evaluation harness — domain-agnostic core + pluggable suites.",
)
parser.add_argument(
"--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
)
sub = parser.add_subparsers(dest="command", required=True)
p_setup = sub.add_parser("setup", help="Create per-suite SearchSpace + pin LLM.")
p_setup.add_argument("--suite", required=True)
p_setup.add_argument(
"--provider-model",
required=True,
help=(
"OpenRouter slug for the SurfSense answer LLM (and the native arm "
"too unless --native-arm-model is set), e.g. "
"'anthropic/claude-sonnet-4.5'."
),
)
p_setup.add_argument(
"--agent-llm-id",
type=int,
default=None,
help="Optional override for BYOK NewLLMConfig rows.",
)
p_setup.add_argument(
"--scenario",
choices=SCENARIOS,
default=DEFAULT_SCENARIO,
help=(
"head-to-head (default): both arms answer with --provider-model; "
"symmetric-cheap: both arms use the same cheap text-only slug, "
"SurfSense pre-extracted images at ingest with a vision LLM; "
"cost-arbitrage: native arm uses --native-arm-model (vision), "
"SurfSense uses --provider-model (cheap, text-only) over chunks "
"the vision LLM already extracted at ingest."
),
)
p_setup.add_argument(
"--vision-llm",
default=None,
metavar="SLUG",
help=(
"OpenRouter slug for the vision LLM SurfSense uses at ingest "
"when --use-vision-llm is on. If omitted in symmetric-cheap / "
"cost-arbitrage, the strongest registered vision config is "
"auto-picked (priority: claude-sonnet-4.5 > claude-opus-4.7 > "
"gpt-5 > gemini-2.5-pro)."
),
)
p_setup.add_argument(
"--native-arm-model",
default=None,
metavar="SLUG",
help=(
"Required for --scenario cost-arbitrage. OpenRouter slug used "
"by the native_pdf arm only; SurfSense answers with "
"--provider-model. Ignored for head-to-head / symmetric-cheap."
),
)
p_setup.add_argument(
"--no-vision-llm-setup",
action="store_true",
help=(
"Skip attaching a vision LLM config to the SearchSpace even if "
"the scenario would normally require one. Use when you want to "
"keep whatever is already attached (e.g. a per-user config)."
),
)
p_setup.set_defaults(_func=_cmd_setup, _async=True)
p_teardown = sub.add_parser("teardown", help="Soft-delete the suite SearchSpace + clear state slot.")
p_teardown.add_argument("--suite", required=True)
p_teardown.set_defaults(_func=_cmd_teardown, _async=True)
p_models = sub.add_parser("models", help="LLM-config discovery helpers.")
models_sub = p_models.add_subparsers(dest="subcommand", required=True)
p_models_list = models_sub.add_parser("list", help="List global LLM configs.")
p_models_list.add_argument("--provider", default=None, help="Filter by provider, e.g. openrouter")
p_models_list.add_argument("--grep", default=None, help="Substring filter on name / model_name.")
p_models_list.set_defaults(_func=_cmd_models_list, _async=True)
p_suites = sub.add_parser("suites", help="List registered suites.")
suites_sub = p_suites.add_subparsers(dest="subcommand", required=True)
p_suites_list = suites_sub.add_parser("list", help="List suites.")
p_suites_list.set_defaults(_func=_cmd_suites_list, _async=False)
p_benchmarks = sub.add_parser("benchmarks", help="List registered benchmarks.")
bench_sub = p_benchmarks.add_subparsers(dest="subcommand", required=True)
p_bench_list = bench_sub.add_parser("list", help="List benchmarks.")
p_bench_list.add_argument("--suite", default=None)
p_bench_list.set_defaults(_func=_cmd_benchmarks_list, _async=False)
# Dynamic ingest / run subcommands need the registry populated, so
# discover up-front (cheap on import — modules just register).
_discover_suites()
p_ingest = sub.add_parser("ingest", help="Ingest a benchmark's corpus.")
ingest_sub = p_ingest.add_subparsers(dest="suite", required=True)
for suite in registry.list_suites():
suite_parser = ingest_sub.add_parser(suite, help=f"Ingest a {suite} benchmark.")
suite_bench = suite_parser.add_subparsers(dest="benchmark", required=True)
for benchmark in registry.list_benchmarks(suite):
bp = suite_bench.add_parser(benchmark.name, help=getattr(benchmark, "description", benchmark.name))
if hasattr(benchmark, "add_run_args"):
benchmark.add_run_args(bp)
bp.set_defaults(_func=_cmd_ingest, _async=True)
p_run = sub.add_parser("run", help="Run a benchmark.")
run_sub = p_run.add_subparsers(dest="suite", required=True)
for suite in registry.list_suites():
suite_parser = run_sub.add_parser(suite, help=f"Run a {suite} benchmark.")
suite_bench = suite_parser.add_subparsers(dest="benchmark", required=True)
for benchmark in registry.list_benchmarks(suite):
bp = suite_bench.add_parser(benchmark.name, help=getattr(benchmark, "description", benchmark.name))
if hasattr(benchmark, "add_run_args"):
benchmark.add_run_args(bp)
bp.set_defaults(_func=_cmd_run, _async=True)
p_report = sub.add_parser("report", help="Aggregate latest run artifacts into a summary.")
p_report.add_argument("--suite", required=True)
p_report.add_argument("--benchmark", default=None, help="Optional: report only this benchmark.")
p_report.set_defaults(_func=_cmd_report, _async=True)
return parser
def main(argv: list[str] | None = None) -> int:
parser = _build_parser()
args = parser.parse_args(argv)
logging.basicConfig(
level=getattr(logging, args.log_level),
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
func = getattr(args, "_func", None)
if func is None:
parser.print_help()
return 2
is_async = getattr(args, "_async", False)
try:
if is_async:
return asyncio.run(func(args))
return func(args)
except KeyboardInterrupt:
console.print("[yellow]Interrupted.[/yellow]")
return 130
except Exception as exc: # noqa: BLE001
logger.exception("CLI command failed")
console.print(f"[red]Command failed: {exc}[/red]")
return 1
if __name__ == "__main__": # pragma: no cover
sys.exit(main())

View file

@ -0,0 +1,14 @@
"""HTTP clients for the SurfSense API. All share one ``httpx.AsyncClient``."""
from __future__ import annotations
from .documents import DocumentsClient
from .new_chat import NewChatClient, StreamedAnswer
from .search_space import SearchSpaceClient
__all__ = [
"DocumentsClient",
"NewChatClient",
"SearchSpaceClient",
"StreamedAnswer",
]

View file

@ -0,0 +1,277 @@
"""Client for ``/api/v1/documents/{fileupload,status,{id}/chunks}``.
Verified against:
* ``surfsense_backend/app/routes/documents_routes.py:122-292`` (POST fileupload)
* ``surfsense_backend/app/routes/documents_routes.py:806-871`` (GET status batch)
* ``surfsense_backend/app/routes/documents_routes.py:1062-1128`` (GET {id}/chunks paginated)
Document processing is asynchronous:
* ``POST /documents/fileupload`` returns immediately with
``document_ids`` in ``pending``;
* a Celery worker moves each through ``processing ready/failed``;
* the harness polls ``GET /documents/status?document_ids=...`` until
every doc is ``ready`` (otherwise the retriever sees an empty corpus
and accuracy numbers are meaningless).
"""
from __future__ import annotations
import asyncio
import logging
import mimetypes
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import httpx
logger = logging.getLogger(__name__)
@dataclass
class FileUploadResult:
"""Mirrors the JSON returned by ``POST /documents/fileupload``."""
document_ids: list[int]
duplicate_document_ids: list[int]
total_files: int
pending_files: int
skipped_duplicates: int
message: str = ""
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> FileUploadResult:
return cls(
document_ids=[int(x) for x in payload.get("document_ids", [])],
duplicate_document_ids=[int(x) for x in payload.get("duplicate_document_ids", [])],
total_files=int(payload.get("total_files", 0)),
pending_files=int(payload.get("pending_files", 0)),
skipped_duplicates=int(payload.get("skipped_duplicates", 0)),
message=str(payload.get("message", "")),
)
@dataclass
class DocumentStatus:
document_id: int
title: str
document_type: str
state: str
reason: str | None = None
@property
def is_ready(self) -> bool:
return self.state == "ready"
@property
def is_failed(self) -> bool:
return self.state == "failed"
@dataclass
class ChunkRow:
id: int
document_id: int
content: str = ""
raw: dict[str, Any] = field(default_factory=dict)
class DocumentProcessingFailed(RuntimeError):
"""Raised when a polled document lands in ``failed``."""
def __init__(self, statuses: Sequence[DocumentStatus]) -> None:
details = ", ".join(
f"id={s.document_id} ({s.title!r}): {s.reason or 'unknown'}"
for s in statuses
)
super().__init__(f"Document(s) failed to process: {details}")
self.statuses = list(statuses)
class DocumentProcessingTimeout(RuntimeError):
"""Raised when polling exceeds the per-doc timeout budget."""
class DocumentsClient:
"""Document upload + status polling + chunk listing."""
def __init__(self, http: httpx.AsyncClient, base_url: str) -> None:
self._http = http
self._base = base_url.rstrip("/")
# ------------------------------------------------------------------
# upload
# ------------------------------------------------------------------
async def upload(
self,
files: Iterable[Path],
*,
search_space_id: int,
should_summarize: bool = False,
use_vision_llm: bool = False,
processing_mode: str = "basic",
) -> FileUploadResult:
"""Upload files to ``/api/v1/documents/fileupload``.
``files`` is materialised to a list because we may need to
re-read on retry. Caller is responsible for ensuring each path
exists and respects the per-file size cap (50 MB backend default).
"""
materialised = [Path(p) for p in files]
if not materialised:
return FileUploadResult(
document_ids=[],
duplicate_document_ids=[],
total_files=0,
pending_files=0,
skipped_duplicates=0,
message="No files supplied",
)
opened: list[tuple[str, Any]] = []
try:
for path in materialised:
# ``open`` directly — httpx wraps it in MultipartStream.
file_obj = path.open("rb")
mime, _ = mimetypes.guess_type(path.name)
opened.append(
(
"files",
(path.name, file_obj, mime or "application/octet-stream"),
)
)
response = await self._http.post(
f"{self._base}/api/v1/documents/fileupload",
data={
"search_space_id": str(search_space_id),
"should_summarize": "true" if should_summarize else "false",
"use_vision_llm": "true" if use_vision_llm else "false",
"processing_mode": processing_mode,
},
files=opened,
# Multipart uploads can be slow for big PDFs; bump per-call.
timeout=httpx.Timeout(120.0, connect=10.0),
)
finally:
for _, (_, file_obj, _) in opened:
try:
file_obj.close()
except Exception: # noqa: BLE001
pass
response.raise_for_status()
return FileUploadResult.from_payload(response.json())
# ------------------------------------------------------------------
# status polling
# ------------------------------------------------------------------
async def get_status(
self, *, search_space_id: int, document_ids: Sequence[int]
) -> list[DocumentStatus]:
if not document_ids:
return []
response = await self._http.get(
f"{self._base}/api/v1/documents/status",
params={
"search_space_id": search_space_id,
"document_ids": ",".join(str(d) for d in document_ids),
},
headers={"Accept": "application/json"},
)
response.raise_for_status()
payload = response.json()
return [
DocumentStatus(
document_id=int(item["id"]),
title=str(item.get("title", "")),
document_type=str(item.get("document_type", "")),
state=str((item.get("status") or {}).get("state", "ready")),
reason=(item.get("status") or {}).get("reason"),
)
for item in payload.get("items", [])
]
async def wait_until_ready(
self,
*,
search_space_id: int,
document_ids: Sequence[int],
timeout_s: float = 300.0,
initial_poll_s: float = 1.0,
max_poll_s: float = 10.0,
) -> list[DocumentStatus]:
"""Poll ``GET /documents/status`` until every doc is ``ready``.
Exponential backoff from ``initial_poll_s`` up to ``max_poll_s``.
Raises ``DocumentProcessingFailed`` if any doc lands in
``failed`` (with the offending document ids), or
``DocumentProcessingTimeout`` if the budget is exhausted.
"""
if not document_ids:
return []
deadline = asyncio.get_event_loop().time() + timeout_s
poll = initial_poll_s
while True:
statuses = await self.get_status(
search_space_id=search_space_id, document_ids=document_ids
)
failed = [s for s in statuses if s.is_failed]
if failed:
raise DocumentProcessingFailed(failed)
ready = [s for s in statuses if s.is_ready]
if len(ready) == len(document_ids):
return statuses
now = asyncio.get_event_loop().time()
if now >= deadline:
pending = [s for s in statuses if not s.is_ready and not s.is_failed]
pending_ids = [s.document_id for s in pending]
raise DocumentProcessingTimeout(
f"Timed out after {timeout_s:.0f}s waiting for documents "
f"(still pending/processing: {pending_ids})"
)
await asyncio.sleep(min(poll, max(0.1, deadline - now)))
poll = min(poll * 1.5, max_poll_s)
# ------------------------------------------------------------------
# chunks (chunk_id -> document_id map)
# ------------------------------------------------------------------
async def list_chunks(
self, document_id: int, *, page_size: int = 100
) -> list[ChunkRow]:
"""Walk ``GET /documents/{id}/chunks`` until ``has_more=False``.
Used by ingestion to materialise the ``chunk_id -> document_id``
map needed for retrieval scoring (CUREv1).
"""
rows: list[ChunkRow] = []
page = 0
while True:
response = await self._http.get(
f"{self._base}/api/v1/documents/{document_id}/chunks",
params={"page": page, "page_size": page_size},
headers={"Accept": "application/json"},
)
response.raise_for_status()
payload = response.json()
for item in payload.get("items", []):
rows.append(
ChunkRow(
id=int(item["id"]),
document_id=document_id,
content=str(item.get("content", "")),
raw=item,
)
)
if not payload.get("has_more"):
break
page += 1
return rows

View file

@ -0,0 +1,280 @@
"""Client for ``/api/v1/threads`` and ``/api/v1/new_chat`` (SSE).
Verified against:
* ``surfsense_backend/app/routes/new_chat_routes.py:793-848`` (POST /threads)
* ``surfsense_backend/app/routes/new_chat_routes.py:1073-1142`` (DELETE /threads/{id})
* ``surfsense_backend/app/routes/new_chat_routes.py:1689-1800`` (POST /new_chat SSE)
* ``surfsense_backend/app/routes/new_chat_routes.py:191-220`` (THREAD_BUSY / TURN_CANCELLING 409)
* ``surfsense_backend/app/services/streaming/envelope/sse.py`` (wire framing)
* ``surfsense_backend/app/services/streaming/events/text.py`` (text-delta events)
* ``surfsense_backend/app/schemas/new_chat.py:234-288`` (NewChatRequest body)
The wire format is "Vercel AI SDK"-flavoured SSE with one event per
``data: <json>\n\n`` block (or the literal ``data: [DONE]\n\n``
terminator). Text deltas arrive as ``{"type":"text-delta","id":...,"delta":...}``
events; we accumulate them per ``id`` and emit the final concatenated
text plus parsed citations.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from collections.abc import AsyncIterator, Sequence
from dataclasses import dataclass, field
from typing import Any
import httpx
from ..parse import iter_sse_events, parse_citations
logger = logging.getLogger(__name__)
@dataclass
class StreamedAnswer:
"""Result of a single ``/new_chat`` turn."""
text: str
raw_events: list[dict[str, Any]] = field(default_factory=list)
latency_ms: int = 0
user_message_id: str | None = None
assistant_message_id: str | None = None
finished_normally: bool = False
@property
def citations(self) -> list[dict[str, Any]]:
"""Parsed citation tokens (lazy; small enough to recompute)."""
return [token.to_dict() for token in parse_citations(self.text)]
class ThreadBusyError(RuntimeError):
"""Raised after exhausting retries on a 409 ``THREAD_BUSY`` / ``TURN_CANCELLING``."""
def __init__(self, error_code: str, message: str) -> None:
super().__init__(f"{error_code}: {message}")
self.error_code = error_code
class NewChatClient:
"""Thread create / delete / SSE ask."""
def __init__(self, http: httpx.AsyncClient, base_url: str) -> None:
self._http = http
self._base = base_url.rstrip("/")
# ------------------------------------------------------------------
# threads
# ------------------------------------------------------------------
async def create_thread(
self,
*,
search_space_id: int,
title: str = "eval",
archived: bool = False,
visibility: str = "PRIVATE",
) -> int:
response = await self._http.post(
f"{self._base}/api/v1/threads",
json={
"search_space_id": search_space_id,
"title": title,
"archived": archived,
"visibility": visibility,
},
headers={"Accept": "application/json"},
)
response.raise_for_status()
payload = response.json()
return int(payload["id"])
async def delete_thread(self, thread_id: int) -> None:
response = await self._http.delete(
f"{self._base}/api/v1/threads/{thread_id}",
headers={"Accept": "application/json"},
)
if response.status_code == 404:
return # idempotent
response.raise_for_status()
# ------------------------------------------------------------------
# /new_chat SSE
# ------------------------------------------------------------------
async def ask(
self,
*,
thread_id: int,
search_space_id: int,
user_query: str,
mentioned_document_ids: Sequence[int] | None = None,
disabled_tools: Sequence[str] | None = None,
max_busy_retries: int = 4,
timeout_s: float = 600.0,
) -> StreamedAnswer:
"""Stream a single turn and return the accumulated answer.
Honours backend ``THREAD_BUSY`` / ``TURN_CANCELLING`` 409
responses by sleeping for the ``Retry-After`` header (or the
``retry-after-ms`` header if present) and replaying. Bounded
by ``max_busy_retries`` so a stuck thread never blocks the
whole run.
"""
body: dict[str, Any] = {
"chat_id": thread_id,
"search_space_id": search_space_id,
"user_query": user_query,
}
if mentioned_document_ids:
body["mentioned_document_ids"] = list(mentioned_document_ids)
if disabled_tools:
body["disabled_tools"] = list(disabled_tools)
attempt = 0
while True:
try:
return await self._stream_once(body=body, timeout_s=timeout_s)
except ThreadBusyError as exc:
attempt += 1
if attempt > max_busy_retries:
raise
# Cap wait at 30s; backend retry hint is exponential anyway.
wait = min(30.0, 0.5 * (2 ** attempt))
logger.info(
"thread_id=%s busy (%s); retry %d/%d after %.1fs",
thread_id,
exc.error_code,
attempt,
max_busy_retries,
wait,
)
await asyncio.sleep(wait)
async def _stream_once(
self,
*,
body: dict[str, Any],
timeout_s: float,
) -> StreamedAnswer:
# Per-call timeout — the connect should be quick, the read needs
# to outlive the longest LLM completion.
timeout = httpx.Timeout(timeout_s, connect=10.0)
started = time.monotonic()
async with self._http.stream(
"POST",
f"{self._base}/api/v1/new_chat",
json=body,
headers={"Accept": "text/event-stream"},
timeout=timeout,
) as response:
if response.status_code == 409:
detail = await self._extract_busy_detail(response)
raise ThreadBusyError(
error_code=detail.get("errorCode", "THREAD_BUSY"),
message=detail.get("message", "Thread is busy"),
)
response.raise_for_status()
answer = await self._consume_sse(response)
answer.latency_ms = int((time.monotonic() - started) * 1000)
return answer
@staticmethod
async def _extract_busy_detail(response: httpx.Response) -> dict[str, Any]:
try:
payload = json.loads(await response.aread())
except (json.JSONDecodeError, ValueError):
return {"errorCode": "THREAD_BUSY", "message": response.text}
if isinstance(payload, dict) and isinstance(payload.get("detail"), dict):
return payload["detail"]
return payload if isinstance(payload, dict) else {}
@staticmethod
async def _consume_sse(response: httpx.Response) -> StreamedAnswer:
"""Walk SSE events, accumulate text-delta payloads.
Backend events of interest:
* ``{"type": "text-start", "id": ...}``
* ``{"type": "text-delta", "id": ..., "delta": ...}``
* ``{"type": "text-end", "id": ...}``
* ``{"type": "start", "messageId": ...}`` (top-level message id)
* ``{"type": "finish"}``
* literal ``[DONE]`` sentinel
Multiple ``text-start`` blocks can interleave each gets its
own ``id`` and we concatenate them in arrival order. That
mirrors the AI SDK client behaviour: one continuous assistant
message visible to the user.
"""
ordered_text_ids: list[str] = []
text_buffers: dict[str, list[str]] = {}
raw_events: list[dict[str, Any]] = []
user_message_id: str | None = None
assistant_message_id: str | None = None
finished = False
async for event in iter_sse_events(_aiter_lines(response)):
data = event.data
if data == "[DONE]":
finished = True
continue
try:
payload = json.loads(data)
except (json.JSONDecodeError, ValueError):
logger.debug("Skipping non-JSON SSE payload: %r", data[:120])
continue
if not isinstance(payload, dict):
continue
raw_events.append(payload)
ev_type = payload.get("type")
if ev_type == "text-delta":
tid = str(payload.get("id", ""))
delta = payload.get("delta", "")
if not isinstance(delta, str):
continue
if tid not in text_buffers:
text_buffers[tid] = []
ordered_text_ids.append(tid)
text_buffers[tid].append(delta)
elif ev_type == "text-start":
tid = str(payload.get("id", ""))
if tid and tid not in text_buffers:
text_buffers[tid] = []
ordered_text_ids.append(tid)
elif ev_type == "start":
msg_id = payload.get("messageId")
if isinstance(msg_id, str):
user_message_id = user_message_id or msg_id
elif ev_type == "data-user-message-id":
msg_id = (payload.get("data") or {}).get("id") or payload.get("id")
if isinstance(msg_id, str):
user_message_id = msg_id
elif ev_type == "data-assistant-message-id":
msg_id = (payload.get("data") or {}).get("id") or payload.get("id")
if isinstance(msg_id, str):
assistant_message_id = msg_id
elif ev_type == "finish":
finished = True
text = "".join("".join(text_buffers.get(tid, [])) for tid in ordered_text_ids)
return StreamedAnswer(
text=text,
raw_events=raw_events,
user_message_id=user_message_id,
assistant_message_id=assistant_message_id,
finished_normally=finished,
)
async def _aiter_lines(response: httpx.Response) -> AsyncIterator[str]:
"""Adapter so the parser can consume any line iterator (mockable in tests)."""
async for line in response.aiter_lines():
yield line

View file

@ -0,0 +1,207 @@
"""Client for ``/api/v1/searchspaces`` and ``/api/v1/search-spaces/{id}/llm-preferences``.
Verified against:
* ``surfsense_backend/app/routes/search_spaces_routes.py:116`` (POST create)
* ``surfsense_backend/app/routes/search_spaces_routes.py:234`` (GET by id)
* ``surfsense_backend/app/routes/search_spaces_routes.py:422`` (DELETE soft-delete)
* ``surfsense_backend/app/routes/search_spaces_routes.py:698-849`` (GET/PUT llm-preferences)
* ``surfsense_backend/app/schemas/search_space.py:14`` (SearchSpaceCreate body)
* ``surfsense_backend/app/routes/vision_llm_routes.py:60`` (GET global vision configs)
Note the inconsistent pluralisation in the backend: ``/searchspaces``
(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for the
``llm-preferences`` sub-resource. Both are mirrored verbatim here.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import httpx
@dataclass
class SearchSpaceRow:
"""Subset of the SearchSpace row we care about."""
id: int
name: str
description: str | None
user_id: str
citations_enabled: bool
qna_custom_instructions: str | None
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> SearchSpaceRow:
return cls(
id=int(payload["id"]),
name=str(payload["name"]),
description=payload.get("description"),
user_id=str(payload.get("user_id", "")),
citations_enabled=bool(payload.get("citations_enabled", True)),
qna_custom_instructions=payload.get("qna_custom_instructions"),
)
@dataclass
class VisionLlmConfigEntry:
"""Subset of one ``GET /global-vision-llm-configs`` row.
The backend returns negative ids for global / OpenRouter-derived
vision configs and positive ids for per-user BYOK rows. Either is
accepted by ``set_llm_preferences(vision_llm_config_id=...)``.
"""
id: int
name: str
provider: str
model_name: str
is_auto_mode: bool
raw: dict[str, Any]
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> VisionLlmConfigEntry:
return cls(
id=int(payload.get("id", 0)),
name=str(payload.get("name", "")),
provider=str(payload.get("provider", "")).upper(),
model_name=str(payload.get("model_name", "")),
is_auto_mode=bool(payload.get("is_auto_mode", False)),
raw=payload,
)
@dataclass
class LlmPreferences:
"""Resolved LLM preferences with the embedded full config row.
Mirrors ``LLMPreferencesRead`` from the backend so the lifecycle
command can introspect ``provider`` / ``model_name`` to validate the
OpenRouter pin.
"""
agent_llm_id: int | None
document_summary_llm_id: int | None
image_generation_config_id: int | None
vision_llm_config_id: int | None
agent_llm: dict[str, Any] | None
raw: dict[str, Any]
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> LlmPreferences:
return cls(
agent_llm_id=payload.get("agent_llm_id"),
document_summary_llm_id=payload.get("document_summary_llm_id"),
image_generation_config_id=payload.get("image_generation_config_id"),
vision_llm_config_id=payload.get("vision_llm_config_id"),
agent_llm=payload.get("agent_llm"),
raw=payload,
)
class SearchSpaceClient:
"""Thin wrapper around the SearchSpace + LLM preferences endpoints."""
def __init__(self, http: httpx.AsyncClient, base_url: str) -> None:
self._http = http
self._base = base_url.rstrip("/")
async def create(self, name: str, *, description: str | None = None) -> SearchSpaceRow:
body: dict[str, Any] = {"name": name}
if description is not None:
body["description"] = description
# citations_enabled defaults to True backend-side; keep that default.
response = await self._http.post(
f"{self._base}/api/v1/searchspaces",
json=body,
headers={"Accept": "application/json"},
)
response.raise_for_status()
return SearchSpaceRow.from_payload(response.json())
async def get(self, search_space_id: int) -> SearchSpaceRow:
response = await self._http.get(
f"{self._base}/api/v1/searchspaces/{search_space_id}",
headers={"Accept": "application/json"},
)
response.raise_for_status()
return SearchSpaceRow.from_payload(response.json())
async def delete(self, search_space_id: int) -> None:
"""Soft-delete: backend prefixes name with ``[DELETING]`` and dispatches a Celery cascade."""
response = await self._http.delete(
f"{self._base}/api/v1/searchspaces/{search_space_id}",
headers={"Accept": "application/json"},
)
# 404 means it's already gone — treat as success (idempotent teardown).
if response.status_code == 404:
return
response.raise_for_status()
async def get_llm_preferences(self, search_space_id: int) -> LlmPreferences:
response = await self._http.get(
f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences",
headers={"Accept": "application/json"},
)
response.raise_for_status()
return LlmPreferences.from_payload(response.json())
async def set_llm_preferences(
self,
search_space_id: int,
*,
agent_llm_id: int | None = None,
document_summary_llm_id: int | None = None,
image_generation_config_id: int | None = None,
vision_llm_config_id: int | None = None,
) -> LlmPreferences:
"""PUT a partial update to ``/search-spaces/{id}/llm-preferences``.
Backend uses ``model_dump(exclude_unset=True)`` so omitted fields
are left unchanged.
"""
body: dict[str, Any] = {}
if agent_llm_id is not None:
body["agent_llm_id"] = agent_llm_id
if document_summary_llm_id is not None:
body["document_summary_llm_id"] = document_summary_llm_id
if image_generation_config_id is not None:
body["image_generation_config_id"] = image_generation_config_id
if vision_llm_config_id is not None:
body["vision_llm_config_id"] = vision_llm_config_id
response = await self._http.put(
f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences",
json=body,
headers={"Accept": "application/json"},
)
response.raise_for_status()
return LlmPreferences.from_payload(response.json())
async def list_global_vision_llm_configs(self) -> list[VisionLlmConfigEntry]:
"""List the registered global vision LLM configs.
Used by ``setup`` to (a) resolve an explicit ``--vision-llm <slug>``
to a config id and (b) auto-pick the strongest registered vision
config when the operator doesn't pass one. The ``Auto (Fastest)``
entry (``id=0``) is filtered out accuracy must be reproducible.
"""
response = await self._http.get(
f"{self._base}/api/v1/global-vision-llm-configs",
headers={"Accept": "application/json"},
)
response.raise_for_status()
payload = response.json()
if not isinstance(payload, list):
raise RuntimeError(
f"Unexpected /global-vision-llm-configs payload: {payload!r}"
)
return [
VisionLlmConfigEntry.from_payload(item)
for item in payload
if not bool(item.get("is_auto_mode", False))
]

View file

@ -0,0 +1,279 @@
"""Environment + filesystem configuration for the harness.
Two responsibilities:
1. Load env vars (with sensible defaults) into a single immutable ``Config``
so that every other module reads it from one place.
2. Read / write ``data/state.json``. State is keyed by suite name so multiple
suites can be set up in parallel and torn down independently.
The pinned ``search_space_id`` lives in ``state.json`` (not env) so re-runs
are idempotent without forcing the operator to remember an integer.
"""
from __future__ import annotations
import json
import os
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from dotenv import load_dotenv
# Resolve once at import time. ``find_dotenv`` walks up; an explicit ``.env``
# at the package root or in CWD wins. Silent-no-op if neither exists.
load_dotenv()
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
"""Resolves to ``surfsense_evals/`` (the package root, not ``src/``)."""
def _project_root() -> Path:
"""Return the ``surfsense_evals/`` project root.
Computed from this file's path: ``src/surfsense_evals/core/config.py`` →
walk up four levels. Kept as a function so tests can monkeypatch.
"""
return _PROJECT_ROOT
@dataclass(frozen=True)
class Config:
"""Immutable runtime configuration."""
surfsense_api_base: str
openrouter_api_key: str | None
openrouter_base_url: str
# Credentials — exactly ONE mode must be supplied.
surfsense_jwt: str | None
surfsense_refresh_token: str | None
surfsense_user_email: str | None
surfsense_user_password: str | None
# Filesystem paths.
data_dir: Path
reports_dir: Path
@property
def state_path(self) -> Path:
return self.data_dir / "state.json"
def has_jwt_mode(self) -> bool:
return bool(self.surfsense_jwt)
def has_local_mode(self) -> bool:
return bool(self.surfsense_user_email and self.surfsense_user_password)
def credential_mode(self) -> str:
"""Return ``"jwt"``, ``"local"``, or ``"none"`` (no credentials supplied)."""
if self.has_jwt_mode():
return "jwt"
if self.has_local_mode():
return "local"
return "none"
def suite_data_dir(self, suite: str) -> Path:
return self.data_dir / suite
def suite_reports_dir(self, suite: str) -> Path:
return self.reports_dir / suite
def suite_runs_dir(self, suite: str) -> Path:
return self.suite_data_dir(suite) / "runs"
def suite_maps_dir(self, suite: str) -> Path:
return self.suite_data_dir(suite) / "maps"
def load_config() -> Config:
"""Read the current process env into a ``Config``.
No validation is performed here; callers (e.g. ``auth.acquire_token``,
``cli`` subcommands) decide which fields they require. This keeps
``models list`` and ``suites list`` runnable without OpenRouter creds.
"""
project_root = _project_root()
data_dir = Path(os.environ.get("EVAL_DATA_DIR") or (project_root / "data")).resolve()
reports_dir = Path(os.environ.get("EVAL_REPORTS_DIR") or (project_root / "reports")).resolve()
return Config(
surfsense_api_base=os.environ.get("SURFSENSE_API_BASE", "http://localhost:8000").rstrip("/"),
openrouter_api_key=os.environ.get("OPENROUTER_API_KEY") or None,
openrouter_base_url=os.environ.get(
"OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"
).rstrip("/"),
surfsense_jwt=os.environ.get("SURFSENSE_JWT") or None,
surfsense_refresh_token=os.environ.get("SURFSENSE_REFRESH_TOKEN") or None,
surfsense_user_email=os.environ.get("SURFSENSE_USER_EMAIL") or None,
surfsense_user_password=os.environ.get("SURFSENSE_USER_PASSWORD") or None,
data_dir=data_dir,
reports_dir=reports_dir,
)
# ---------------------------------------------------------------------------
# state.json — per-suite slots
# ---------------------------------------------------------------------------
# Scenario names — chosen at ``setup`` time, persisted in ``state.json``.
#
# * ``head-to-head`` (default, current behaviour): both arms answer with the
# SAME slug pinned via ``--provider-model``. Vision LLM at ingest is
# optional but recommended for image-bearing benchmarks.
# * ``symmetric-cheap``: both arms answer with the SAME (cheap, text-only)
# slug; SurfSense pre-extracted images at ingest with a vision LLM.
# Measures whether vision-RAG ingestion lets a cheap downstream model
# match a vision one. Native arm structurally loses on image questions —
# that's the point, and the report labels it accordingly.
# * ``cost-arbitrage``: native arm answers with an EXPENSIVE vision slug
# (``--native-arm-model``), SurfSense answers with a CHEAP text-only slug
# (``--provider-model``) over chunks the vision LLM already extracted at
# ingest. Measures how close SurfSense gets to native at a fraction of
# the per-query cost. The most compelling "shines" framing.
SCENARIOS: tuple[str, ...] = ("head-to-head", "symmetric-cheap", "cost-arbitrage")
DEFAULT_SCENARIO: str = "head-to-head"
@dataclass
class SuiteState:
"""Per-suite persisted state.
``provider_model`` is the slug pinned to the SearchSpace's
``agent_llm`` what answers SurfSense queries (and what the native
arm uses too, unless ``native_arm_model`` is set for cost-arbitrage).
``vision_provider_model`` is the slug of the OpenRouter vision LLM
config attached to the SearchSpace's ``vision_llm_config_id`` — what
SurfSense uses to extract image content at ingest time when
``use_vision_llm=True``. ``None`` means no vision config was attached
at setup (legacy or text-only suite).
"""
search_space_id: int
agent_llm_id: int
provider_model: str
created_at: str
ingestion_maps: dict[str, str] = field(default_factory=dict)
scenario: str = DEFAULT_SCENARIO
vision_llm_config_id: int | None = None
vision_provider_model: str | None = None
native_arm_model: str | None = None
def to_dict(self) -> dict[str, Any]:
return {
"search_space_id": self.search_space_id,
"agent_llm_id": self.agent_llm_id,
"provider_model": self.provider_model,
"created_at": self.created_at,
"ingestion_maps": dict(self.ingestion_maps),
"scenario": self.scenario,
"vision_llm_config_id": self.vision_llm_config_id,
"vision_provider_model": self.vision_provider_model,
"native_arm_model": self.native_arm_model,
}
@classmethod
def from_dict(cls, payload: Mapping[str, Any]) -> SuiteState:
# ``scenario`` / vision / native fields default for back-compat with
# ``state.json`` written before scenarios shipped.
scenario = str(payload.get("scenario") or DEFAULT_SCENARIO)
if scenario not in SCENARIOS:
scenario = DEFAULT_SCENARIO
raw_vision_id = payload.get("vision_llm_config_id")
return cls(
search_space_id=int(payload["search_space_id"]),
agent_llm_id=int(payload["agent_llm_id"]),
provider_model=str(payload["provider_model"]),
created_at=str(payload.get("created_at") or ""),
ingestion_maps=dict(payload.get("ingestion_maps") or {}),
scenario=scenario,
vision_llm_config_id=int(raw_vision_id) if raw_vision_id is not None else None,
vision_provider_model=(
str(payload["vision_provider_model"])
if payload.get("vision_provider_model")
else None
),
native_arm_model=(
str(payload["native_arm_model"])
if payload.get("native_arm_model")
else None
),
)
@property
def effective_native_arm_model(self) -> str:
"""Slug the native arm should use; falls back to ``provider_model``."""
return self.native_arm_model or self.provider_model
def _load_state(config: Config) -> dict[str, Any]:
if not config.state_path.exists():
return {"suites": {}}
try:
with config.state_path.open("r", encoding="utf-8") as fh:
data = json.load(fh)
except (OSError, json.JSONDecodeError) as exc:
raise RuntimeError(
f"Failed to read state file {config.state_path}: {exc!s}. "
"Delete it if you want to start fresh."
) from exc
if not isinstance(data, dict) or "suites" not in data:
return {"suites": {}}
return data
def _write_state(config: Config, payload: Mapping[str, Any]) -> None:
config.data_dir.mkdir(parents=True, exist_ok=True)
tmp = config.state_path.with_suffix(".json.tmp")
with tmp.open("w", encoding="utf-8") as fh:
json.dump(dict(payload), fh, indent=2, sort_keys=True)
fh.write("\n")
tmp.replace(config.state_path)
def get_suite_state(config: Config, suite: str) -> SuiteState | None:
"""Return ``SuiteState`` for ``suite`` or ``None`` if not set up."""
state = _load_state(config)
raw = (state.get("suites") or {}).get(suite)
if not raw:
return None
return SuiteState.from_dict(raw)
def set_suite_state(config: Config, suite: str, suite_state: SuiteState) -> None:
"""Persist ``suite_state`` under the suite slot. Other suites are untouched."""
state = _load_state(config)
suites = dict(state.get("suites") or {})
suites[suite] = suite_state.to_dict()
state["suites"] = suites
_write_state(config, state)
def clear_suite_state(config: Config, suite: str) -> bool:
"""Remove the slot for ``suite``. Returns ``True`` if removal happened."""
state = _load_state(config)
suites = dict(state.get("suites") or {})
if suite not in suites:
return False
del suites[suite]
state["suites"] = suites
_write_state(config, state)
return True
def utc_iso_timestamp() -> str:
"""Filesystem-safe UTC ISO timestamp, e.g. ``2026-05-11T20-30-00Z``."""
return datetime.now(UTC).strftime("%Y-%m-%dT%H-%M-%SZ")

View file

@ -0,0 +1,311 @@
"""Per-upload ingestion settings shared across every benchmark.
The SurfSense ``POST /api/v1/documents/fileupload`` endpoint exposes
exactly three knobs (verified at
``surfsense_backend/app/routes/documents_routes.py`` and
``surfsense_backend/app/etl_pipeline/etl_document.py``):
* ``processing_mode`` ``"basic"`` (default) | ``"premium"``
* ``use_vision_llm`` ``bool`` (run vision LLM during ingest to
extract image content / captions / tables)
* ``should_summarize`` ``bool`` (generate document summary)
This module gives every benchmark a uniform way to:
1. Receive sensible per-benchmark defaults (text-only benchmarks
default vision off; image-bearing benchmarks default vision on).
2. Accept CLI overrides (``--use-vision-llm`` / ``--no-vision-llm``,
``--processing-mode {basic,premium}``,
``--should-summarize`` / ``--no-summarize``).
3. Persist the *actual* settings used into the doc-map manifest and
the run artifact so reports can show "vision=ON, mode=premium →
65% accuracy" head-to-head with "vision=OFF, mode=basic 52%".
A/B testing on the same corpus
------------------------------
SurfSense dedupes uploads by ``(filename, search_space_id)`` NOT by
content hash and NOT by ingestion settings. Re-uploading the same
filename to the same SearchSpace with a different ``use_vision_llm``
flag will hit the duplicate branch and *not* re-process. To compare
two settings combos head-to-head on the same corpus you must give
each combo its own SearchSpace, which today means:
teardown --suite <s>
setup --suite <s> ...
ingest <s> <bench> --no-vision-llm # baseline run
run <s> <bench>
teardown --suite <s>
setup --suite <s> ...
ingest <s> <bench> --use-vision-llm # vision arm
run <s> <bench>
The runs land in different timestamped subdirectories under
``data/<suite>/runs/`` and ``report --suite <s>`` aggregates whichever
manifest is currently latest per benchmark.
"""
from __future__ import annotations
import argparse
import json
from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any
# Keep the constant list of valid processing modes here so benchmarks
# don't have to re-import from the backend (they don't have access to
# the backend package anyway).
PROCESSING_MODES: tuple[str, ...] = ("basic", "premium")
@dataclass(frozen=True)
class IngestSettings:
"""Resolved per-upload knobs handed to ``DocumentsClient.upload``.
Use ``IngestSettings(...)`` directly to define benchmark defaults,
or ``IngestSettings.merge(defaults, opts)`` to apply CLI overrides
on top of those defaults.
"""
use_vision_llm: bool = False
processing_mode: str = "basic"
should_summarize: bool = False
def to_dict(self) -> dict[str, Any]:
return {
"use_vision_llm": self.use_vision_llm,
"processing_mode": self.processing_mode,
"should_summarize": self.should_summarize,
}
@classmethod
def merge(cls, defaults: IngestSettings, opts: Mapping[str, Any]) -> IngestSettings:
"""Apply CLI overrides on top of ``defaults``.
``opts`` is the kwargs dict built by ``core.cli`` from the
argparse namespace (see ``_cmd_ingest`` / ``_cmd_run``). Keys
we look for: ``use_vision_llm`` (bool or None), ``processing_mode``
(str or None), ``should_summarize`` (bool or None). Anything
else is ignored so benchmarks can pass through their own opts.
"""
return cls(
use_vision_llm=_coerce_bool(opts.get("use_vision_llm"), defaults.use_vision_llm),
processing_mode=_coerce_mode(opts.get("processing_mode"), defaults.processing_mode),
should_summarize=_coerce_bool(opts.get("should_summarize"), defaults.should_summarize),
)
def render_label(self) -> str:
"""Human-readable single-line label for reports / log lines."""
return (
f"vision={'on' if self.use_vision_llm else 'off'}, "
f"mode={self.processing_mode}, "
f"summarize={'on' if self.should_summarize else 'off'}"
)
def _coerce_bool(value: Any, default: bool) -> bool:
"""Argparse with ``BooleanOptionalAction`` yields True/False/None.
``None`` means the operator didn't pass the flag → fall back to
the benchmark default.
"""
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.strip().lower() in {"1", "true", "yes", "on"}
return bool(value)
def _coerce_mode(value: Any, default: str) -> str:
if value is None or value == "":
return default
val = str(value).strip().lower()
if val not in PROCESSING_MODES:
raise ValueError(
f"Invalid processing_mode {val!r}; must be one of {PROCESSING_MODES}"
)
return val
# ---------------------------------------------------------------------------
# Argparse helper
# ---------------------------------------------------------------------------
def _add_bool_pair(
parser: argparse.ArgumentParser,
*,
dest: str,
on_flag: str,
off_flag: str,
on_help: str,
off_help: str,
) -> None:
"""Add a mutually exclusive ``--foo`` / ``--no-foo`` pair.
We don't use ``argparse.BooleanOptionalAction`` because it would
auto-generate ``--no-use-vision-llm`` rather than the friendlier
``--no-vision-llm`` that operators reach for. Default is ``None``
so ``IngestSettings.merge`` can distinguish "silent" from
"explicit false".
"""
group = parser.add_mutually_exclusive_group()
group.add_argument(
on_flag,
dest=dest,
action="store_true",
default=None,
help=on_help,
)
group.add_argument(
off_flag,
dest=dest,
action="store_false",
default=None,
help=off_help,
)
def add_ingest_settings_args(
parser: argparse.ArgumentParser,
*,
defaults: IngestSettings,
) -> None:
"""Attach the three ingest-settings flag pairs to ``parser``.
Each bool exposes a mutually exclusive ``--foo`` / ``--no-foo``
pair so an operator can flip either direction without restating
every flag. Default is ``None`` so that "operator didn't pass the
flag" is distinguishable from "operator explicitly passed false"
``IngestSettings.merge`` then folds in the benchmark default
only when the operator was silent.
"""
settings_group = parser.add_argument_group(
"ingest settings",
f"Per-upload knobs (forwarded to /documents/fileupload). "
f"Defaults for this benchmark: {defaults.render_label()}.",
)
_add_bool_pair(
settings_group,
dest="use_vision_llm",
on_flag="--use-vision-llm",
off_flag="--no-vision-llm",
on_help=(
"Run vision LLM during ingest to extract image content "
f"(default for this benchmark: "
f"{'on' if defaults.use_vision_llm else 'off'})."
),
off_help="Skip vision LLM during ingest (text-only ETL).",
)
settings_group.add_argument(
"--processing-mode",
dest="processing_mode",
choices=PROCESSING_MODES,
default=None,
help=(
"SurfSense ETL processing mode (premium uses a 10x page "
f"multiplier and typically routes to a stronger ETL). "
f"Default for this benchmark: {defaults.processing_mode!r}."
),
)
_add_bool_pair(
settings_group,
dest="should_summarize",
on_flag="--should-summarize",
off_flag="--no-summarize",
on_help=(
"Have SurfSense generate a document summary at ingest "
f"(default for this benchmark: "
f"{'on' if defaults.should_summarize else 'off'})."
),
off_help="Skip per-document summary generation.",
)
# ---------------------------------------------------------------------------
# Doc-map manifest helpers
# ---------------------------------------------------------------------------
#
# Every benchmark writes a doc-map JSONL under ``data/<suite>/maps/`` that
# pairs source identifiers (case_id, snippet_id, doc_path, …) to the
# SurfSense document_ids returned by the upload. To make the report
# self-describing we also write a header line:
#
# {"__settings__": {"use_vision_llm": ..., "processing_mode": ..., ...}}
#
# These two helpers centralise that protocol so each benchmark only has to
# call ``write_settings_header`` and ``read_settings_header``.
SETTINGS_HEADER_KEY = "__settings__"
def settings_header_line(settings: IngestSettings) -> str:
"""Return the JSON-serialised header line (no trailing newline)."""
return json.dumps({SETTINGS_HEADER_KEY: settings.to_dict()})
def is_settings_header(row: Mapping[str, Any]) -> bool:
return SETTINGS_HEADER_KEY in row
def read_settings_header(map_path: Path) -> dict[str, Any]:
"""Read the ``__settings__`` header out of a doc-map JSONL.
Returns ``{}`` on a missing file, an empty file, an unreadable
file, or a file whose first non-blank line is not a settings
header (e.g. a corpus ingested before this feature existed).
Callers use this purely to surface settings in the report; it
must never fail the run.
"""
if not map_path.exists():
return {}
try:
with map_path.open("r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line:
continue
row = json.loads(line)
if isinstance(row, dict) and SETTINGS_HEADER_KEY in row:
return dict(row[SETTINGS_HEADER_KEY])
return {}
except (OSError, json.JSONDecodeError):
return {}
return {}
def format_ingest_settings_md(settings: Any) -> str:
"""Render the resolved settings as a single Markdown bullet line."""
if not isinstance(settings, Mapping) or not settings:
return "- SurfSense ingest settings: (not recorded — re-ingest to capture)"
vision = "on" if settings.get("use_vision_llm") else "off"
mode = settings.get("processing_mode") or "basic"
summarize = "on" if settings.get("should_summarize") else "off"
return (
f"- SurfSense ingest settings: vision_llm=`{vision}`, "
f"processing_mode=`{mode}`, summarize=`{summarize}`"
)
__all__ = [
"PROCESSING_MODES",
"SETTINGS_HEADER_KEY",
"IngestSettings",
"add_ingest_settings_args",
"format_ingest_settings_md",
"is_settings_header",
"read_settings_header",
"settings_header_line",
]

View file

@ -0,0 +1,50 @@
"""Pure-function metric primitives. Lazy imports."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING: # pragma: no cover
from .comparison import McnemarResult, bootstrap_delta_ci, mcnemar_test, paired_aggregate
from .mc_accuracy import AccuracyResult, accuracy_with_wilson_ci, wilson_ci
from .retrieval import RetrievalScores, mrr, ndcg_at_k, recall_at_k, score_run
__all__ = [
"AccuracyResult",
"McnemarResult",
"RetrievalScores",
"accuracy_with_wilson_ci",
"bootstrap_delta_ci",
"mcnemar_test",
"mrr",
"ndcg_at_k",
"paired_aggregate",
"recall_at_k",
"score_run",
"wilson_ci",
]
_MODULE_FOR = {
"AccuracyResult": "mc_accuracy",
"accuracy_with_wilson_ci": "mc_accuracy",
"wilson_ci": "mc_accuracy",
"RetrievalScores": "retrieval",
"mrr": "retrieval",
"ndcg_at_k": "retrieval",
"recall_at_k": "retrieval",
"score_run": "retrieval",
"McnemarResult": "comparison",
"bootstrap_delta_ci": "comparison",
"mcnemar_test": "comparison",
"paired_aggregate": "comparison",
}
def __getattr__(name: str):
if name in _MODULE_FOR:
from importlib import import_module
mod = import_module(f".{_MODULE_FOR[name]}", __name__)
return getattr(mod, name)
raise AttributeError(f"module 'surfsense_evals.core.metrics' has no attribute {name!r}")

View file

@ -0,0 +1,258 @@
"""Paired comparison statistics for head-to-head benchmarks.
In every head-to-head benchmark (currently MedXpertQA-MM and
MMLongBench-Doc) each question is answered by both arms (Native PDF
and SurfSense). That makes per-question outcomes paired, so
``McNemar's test`` on the discordant pairs is the right significance
test for "are the two arms different?". We also expose a bootstrap
delta CI for visualising effect size.
Aggregate cost / latency / token deltas are mean-based; the runner
slices them by arm before passing them in.
"""
from __future__ import annotations
import math
import statistics
from collections.abc import Sequence
from dataclasses import dataclass
import numpy as np
@dataclass(frozen=True)
class McnemarResult:
"""Discordant pair counts + the test statistics."""
n_total: int
b: int # native correct, surfsense wrong
c: int # native wrong, surfsense correct
statistic: float
p_value: float
method: str
def to_dict(self) -> dict[str, float | int | str]:
return {
"n_total": self.n_total,
"b_native_correct_only": self.b,
"c_surfsense_correct_only": self.c,
"statistic": self.statistic,
"p_value": self.p_value,
"method": self.method,
}
def mcnemar_test(
arm_a_correct: Sequence[bool],
arm_b_correct: Sequence[bool],
*,
use_exact_below: int = 11,
) -> McnemarResult:
"""Paired McNemar's test on per-question correctness.
``arm_a_correct`` is treated as the reference arm (typically the
"native" arm); ``arm_b_correct`` is the challenger (typically
"surfsense"). The test statistic only depends on discordant pairs.
Default switch-over (``b + c < 11``): for very small discordant
samples the exact binomial test is preferred; above that the
continuity-corrected chi-square is well-behaved (Edwards 1948).
Callers can raise ``use_exact_below`` if they prefer the more
conservative ``b + c < 25`` rule.
No external statistical package is required: scipy is a heavy dep
and we only need binomial CDFs / chi-square sf, both implementable
in stdlib + numpy without surprises.
"""
if len(arm_a_correct) != len(arm_b_correct):
raise ValueError(
f"Length mismatch: arm_a={len(arm_a_correct)}, arm_b={len(arm_b_correct)}"
)
n = len(arm_a_correct)
b = sum(1 for a, c in zip(arm_a_correct, arm_b_correct) if a and not c)
c = sum(1 for a, cc in zip(arm_a_correct, arm_b_correct) if (not a) and cc)
discordant = b + c
if discordant == 0:
return McnemarResult(
n_total=n, b=b, c=c, statistic=0.0, p_value=1.0, method="degenerate"
)
if discordant < use_exact_below:
# Exact binomial: under H0 each discordant pair is a Bernoulli(0.5).
# p-value = 2 * P(X <= min(b,c) | n=discordant, p=0.5), capped at 1.
k = min(b, c)
cdf = sum(_binom_pmf(discordant, i) for i in range(k + 1))
p_value = min(1.0, 2.0 * cdf)
return McnemarResult(
n_total=n, b=b, c=c, statistic=float(k), p_value=p_value, method="exact"
)
# Chi-square with continuity correction (McNemar-Edwards).
chi = ((abs(b - c) - 1) ** 2) / discordant
p_value = _chi2_sf(chi, df=1)
return McnemarResult(
n_total=n, b=b, c=c, statistic=chi, p_value=p_value, method="chi2_cc"
)
def _binom_pmf(n: int, k: int) -> float:
return math.comb(n, k) * (0.5 ** n)
def _chi2_sf(x: float, *, df: int) -> float:
"""Survival function (1 - CDF) of chi-square; df=1 closed form."""
if x <= 0:
return 1.0
if df == 1:
# Chi^2(1) = N(0,1)^2; sf(x) = 2 * Phi_complement(sqrt(x))
return math.erfc(math.sqrt(x / 2.0))
# General fallback via regularized upper incomplete gamma.
a = df / 2.0
z = x / 2.0
return _gammaincc(a, z)
def _gammaincc(a: float, x: float, *, max_iter: int = 200, tol: float = 1e-12) -> float:
"""Regularised upper incomplete gamma Q(a, x). Series + continued fraction."""
if x < 0 or a <= 0:
return float("nan")
if x == 0:
return 1.0
if x < a + 1.0:
# Series for P(a, x); subtract from 1.
p_series = _gammainc_series(a, x, max_iter=max_iter, tol=tol)
return 1.0 - p_series
return _gammaincc_cf(a, x, max_iter=max_iter, tol=tol)
def _gammainc_series(a: float, x: float, *, max_iter: int, tol: float) -> float:
term = 1.0 / a
summation = term
for n in range(1, max_iter):
term *= x / (a + n)
summation += term
if abs(term) < abs(summation) * tol:
break
log_pre = -x + a * math.log(x) - math.lgamma(a)
return summation * math.exp(log_pre)
def _gammaincc_cf(a: float, x: float, *, max_iter: int, tol: float) -> float:
b = x + 1.0 - a
c_val = 1.0 / 1e-300
d = 1.0 / b
h = d
for i in range(1, max_iter):
an = -i * (i - a)
b += 2.0
d = an * d + b
if abs(d) < 1e-300:
d = 1e-300
c_val = b + an / c_val
if abs(c_val) < 1e-300:
c_val = 1e-300
d = 1.0 / d
delta = d * c_val
h *= delta
if abs(delta - 1.0) < tol:
break
log_pre = -x + a * math.log(x) - math.lgamma(a)
return h * math.exp(log_pre)
# ---------------------------------------------------------------------------
# Bootstrap delta CI
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class BootstrapDelta:
delta: float
ci_low: float
ci_high: float
n_resamples: int
def to_dict(self) -> dict[str, float | int]:
return {
"delta": self.delta,
"ci_low": self.ci_low,
"ci_high": self.ci_high,
"n_resamples": self.n_resamples,
}
def bootstrap_delta_ci(
arm_a_correct: Sequence[bool],
arm_b_correct: Sequence[bool],
*,
n_resamples: int = 5000,
level: float = 0.95,
random_state: int | None = 0,
) -> BootstrapDelta:
"""Paired-sample bootstrap CI for ``mean(arm_b) - mean(arm_a)``.
Resamples *paired indices* with replacement so the dependency
between arms is preserved.
"""
if len(arm_a_correct) != len(arm_b_correct):
raise ValueError("paired arms must have the same length")
n = len(arm_a_correct)
if n == 0:
return BootstrapDelta(0.0, 0.0, 0.0, 0)
a = np.asarray(arm_a_correct, dtype=np.int8)
b = np.asarray(arm_b_correct, dtype=np.int8)
delta = float(b.mean() - a.mean())
rng = np.random.default_rng(random_state)
deltas = np.empty(n_resamples, dtype=np.float64)
for i in range(n_resamples):
idx = rng.integers(0, n, size=n)
deltas[i] = b[idx].mean() - a[idx].mean()
alpha = (1.0 - level) / 2.0
ci_low, ci_high = float(np.quantile(deltas, alpha)), float(np.quantile(deltas, 1 - alpha))
return BootstrapDelta(delta=delta, ci_low=ci_low, ci_high=ci_high, n_resamples=n_resamples)
# ---------------------------------------------------------------------------
# Simple aggregate helpers (cost / latency / tokens)
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Aggregate:
mean: float
median: float
p95: float
n: int
def to_dict(self) -> dict[str, float | int]:
return {"mean": self.mean, "median": self.median, "p95": self.p95, "n": self.n}
def paired_aggregate(values: Sequence[float]) -> Aggregate:
"""Mean / median / p95 of a list of numbers (e.g. cost-per-question)."""
if not values:
return Aggregate(0.0, 0.0, 0.0, 0)
arr = np.asarray(values, dtype=np.float64)
return Aggregate(
mean=float(arr.mean()),
median=float(statistics.median(values)),
p95=float(np.quantile(arr, 0.95)),
n=len(values),
)
__all__ = [
"Aggregate",
"BootstrapDelta",
"McnemarResult",
"bootstrap_delta_ci",
"mcnemar_test",
"paired_aggregate",
]

View file

@ -0,0 +1,130 @@
"""Multiple-choice accuracy + Wilson 95% confidence intervals.
Wilson CI is preferred over normal-approximation because MIRAGE's
per-task subsets can be small (PubMedQA* and BioASQ-Y/N have a few
hundred questions each) and Wilson handles n0 / p{0,1} edges
gracefully.
Reference for the closed form: Wilson (1927); identical to the
``statsmodels.stats.proportion.proportion_confint(method='wilson')``
output and what scikit-learn implements internally for its bounded
estimators.
"""
from __future__ import annotations
import math
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
@dataclass(frozen=True)
class AccuracyResult:
"""Per-task accuracy with Wilson CI."""
n_correct: int
n_total: int
accuracy: float
ci_low: float
ci_high: float
def to_dict(self) -> dict[str, float | int]:
return {
"n_correct": self.n_correct,
"n_total": self.n_total,
"accuracy": self.accuracy,
"ci_low": self.ci_low,
"ci_high": self.ci_high,
}
# Two-sided Wilson z values. 1.959964 ≈ z_{0.975}.
_Z_FOR_LEVEL: dict[float, float] = {
0.90: 1.6448536269514722,
0.95: 1.959963984540054,
0.99: 2.5758293035489004,
}
def wilson_ci(
n_correct: int, n_total: int, *, level: float = 0.95
) -> tuple[float, float]:
"""Two-sided Wilson score confidence interval for a proportion.
Returns ``(low, high)``. ``n_total == 0`` returns ``(0.0, 1.0)``
the maximally uncertain interval.
"""
if n_total <= 0:
return 0.0, 1.0
if level not in _Z_FOR_LEVEL:
raise ValueError(f"Unsupported confidence level {level!r}")
z = _Z_FOR_LEVEL[level]
p = n_correct / n_total
n = n_total
denom = 1.0 + (z * z) / n
centre = (p + (z * z) / (2 * n)) / denom
half = (z / denom) * math.sqrt((p * (1 - p) / n) + (z * z) / (4 * n * n))
low = max(0.0, centre - half)
high = min(1.0, centre + half)
return low, high
def accuracy_with_wilson_ci(
n_correct: int, n_total: int, *, level: float = 0.95
) -> AccuracyResult:
if n_total < 0:
raise ValueError(f"n_total must be >= 0, got {n_total}")
if n_correct < 0 or n_correct > n_total:
raise ValueError(
f"n_correct must be in [0, n_total]; got n_correct={n_correct}, n_total={n_total}"
)
accuracy = (n_correct / n_total) if n_total > 0 else 0.0
low, high = wilson_ci(n_correct, n_total, level=level)
return AccuracyResult(
n_correct=n_correct,
n_total=n_total,
accuracy=accuracy,
ci_low=low,
ci_high=high,
)
def per_task_accuracy(
rows: Sequence[Mapping[str, object]],
*,
task_key: str = "task",
correct_key: str = "is_correct",
level: float = 0.95,
) -> dict[str, AccuracyResult]:
"""Group ``rows`` by ``task_key`` and compute per-task ``AccuracyResult``.
``rows[i][correct_key]`` must be truthy iff the answer was correct.
"""
counts: dict[str, list[int]] = {}
for row in rows:
task = str(row.get(task_key, ""))
bucket = counts.setdefault(task, [0, 0])
bucket[1] += 1
if row.get(correct_key):
bucket[0] += 1
return {
task: accuracy_with_wilson_ci(c[0], c[1], level=level)
for task, c in counts.items()
}
def macro_accuracy(per_task: Mapping[str, AccuracyResult]) -> float:
if not per_task:
return 0.0
return sum(r.accuracy for r in per_task.values()) / len(per_task)
__all__ = [
"AccuracyResult",
"accuracy_with_wilson_ci",
"macro_accuracy",
"per_task_accuracy",
"wilson_ci",
]

View file

@ -0,0 +1,132 @@
"""Retrieval metrics: Recall@k, MRR, nDCG@k.
Used by CUREv1's runner to score the SurfSense arm against the
benchmark's qrels. ``corpus_id`` is the canonical CUREv1 passage id
(string); the runner maps SurfSense ``chunk_id`` ``document_id``
``corpus_id`` before calling these.
Graded relevance (CUREv1 uses 0/1/2 grades) is honoured by ``ndcg_at_k``;
``recall_at_k`` and ``mrr`` flatten anything > 0 to "relevant".
"""
from __future__ import annotations
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
@dataclass(frozen=True)
class RetrievalScores:
"""Aggregated retrieval scores."""
recall_at_k: dict[int, float]
mrr: float
ndcg_at_10: float
n_queries: int
def to_dict(self) -> dict:
return {
"recall_at_k": dict(self.recall_at_k),
"mrr": self.mrr,
"ndcg_at_10": self.ndcg_at_10,
"n_queries": self.n_queries,
}
def recall_at_k(retrieved: Sequence[str], relevant: Iterable[str], k: int) -> float:
"""Fraction of ``relevant`` documents found in ``retrieved[:k]``."""
if not relevant:
return 0.0
relevant_set = set(relevant)
if not relevant_set:
return 0.0
top_k = list(retrieved)[:k]
hits = sum(1 for doc in top_k if doc in relevant_set)
return hits / len(relevant_set)
def mrr(retrieved: Sequence[str], relevant: Iterable[str]) -> float:
"""Reciprocal rank of the first relevant doc, 0 if none found."""
relevant_set = set(relevant)
for rank, doc in enumerate(retrieved, start=1):
if doc in relevant_set:
return 1.0 / rank
return 0.0
def _dcg_at_k(grades: Sequence[float], k: int) -> float:
s = 0.0
for i, grade in enumerate(grades[:k], start=1):
# Standard log-base-2 discount; gain = 2^grade - 1 for graded relevance.
s += (2.0 ** grade - 1.0) / math.log2(i + 1)
return s
def ndcg_at_k(
retrieved: Sequence[str],
qrels: Mapping[str, float],
k: int,
) -> float:
"""nDCG@k against graded ``qrels`` (``{doc_id: grade}``).
Unjudged documents in ``retrieved`` contribute zero gain. The
ideal ordering is ``qrels`` sorted by grade descending.
"""
if not qrels:
return 0.0
grades = [float(qrels.get(doc, 0.0)) for doc in retrieved]
dcg = _dcg_at_k(grades, k)
ideal = sorted(qrels.values(), reverse=True)
idcg = _dcg_at_k([float(g) for g in ideal], k)
if idcg == 0.0:
return 0.0
return dcg / idcg
def score_run(
*,
per_query_retrieved: Mapping[str, Sequence[str]],
per_query_qrels: Mapping[str, Mapping[str, float]],
ks: Sequence[int] = (1, 5, 10, 32),
ndcg_k: int = 10,
) -> RetrievalScores:
"""Aggregate Recall@k, MRR, nDCG@k across a run.
``per_query_retrieved`` maps ``query_id -> ordered list of doc ids``.
``per_query_qrels`` maps ``query_id -> {doc_id: grade}`` (grade > 0
is relevant).
Queries present in retrieved but not in qrels are skipped. Queries
in qrels but missing from retrieved contribute zeros.
"""
qids = set(per_query_qrels.keys()) & set(per_query_retrieved.keys())
if not qids:
return RetrievalScores(recall_at_k={k: 0.0 for k in ks}, mrr=0.0, ndcg_at_10=0.0, n_queries=0)
recall_totals = {k: 0.0 for k in ks}
mrr_total = 0.0
ndcg_total = 0.0
for qid in qids:
retrieved = list(per_query_retrieved[qid])
qrels = per_query_qrels[qid]
relevant_docs = [d for d, g in qrels.items() if g > 0]
for k in ks:
recall_totals[k] += recall_at_k(retrieved, relevant_docs, k)
mrr_total += mrr(retrieved, relevant_docs)
ndcg_total += ndcg_at_k(retrieved, qrels, ndcg_k)
n = len(qids)
return RetrievalScores(
recall_at_k={k: v / n for k, v in recall_totals.items()},
mrr=mrr_total / n,
ndcg_at_10=ndcg_total / n,
n_queries=n,
)
__all__ = ["RetrievalScores", "mrr", "ndcg_at_k", "recall_at_k", "score_run"]

View file

@ -0,0 +1,21 @@
"""Parsers shared across suites: citations, MCQ envelopes, AI-SDK SSE."""
from __future__ import annotations
from .answer_letter import AnswerLetterResult, extract_answer_letter
from .citations import CITATION_REGEX, CitationToken, ChunkCitation, UrlCitation, parse_citations
from .freeform_answer import extract_freeform_answer
from .sse import SseEvent, iter_sse_events
__all__ = [
"CITATION_REGEX",
"CitationToken",
"ChunkCitation",
"UrlCitation",
"parse_citations",
"AnswerLetterResult",
"extract_answer_letter",
"extract_freeform_answer",
"SseEvent",
"iter_sse_events",
]

View file

@ -0,0 +1,122 @@
"""Robust extractor for MCQ answer letters.
Handles three answer shapes seen in the wild:
1. **MedRAG envelope** ``{"step_by_step_thinking": "...", "answer_choice": "A"}``
embedded somewhere in the assistant message (often inside ```` ```json ```` /
``` ``` ``` fences). The regex grabs the JSON object and reads the
``answer_choice`` field.
2. **Final-line letter** e.g. ``Answer: B`` or ``The correct answer is (C).``.
Falls back to a permissive regex over the last few lines.
3. **Bare letter** single uppercase letter at the end of the message.
The function returns the parsed letter (uppercased) plus a discriminator
of which strategy fired so the runner / report can flag suspicious
parses (typically zero-confidence parses indicate the model didn't
follow the prompt).
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from typing import Literal
ParserStrategy = Literal["json_envelope", "answer_line", "bare_letter", "none"]
@dataclass(frozen=True)
class AnswerLetterResult:
letter: str | None
strategy: ParserStrategy
@property
def found(self) -> bool:
return self.letter is not None
# ---------------------------------------------------------------------------
# Strategies
# ---------------------------------------------------------------------------
_JSON_BLOCK = re.compile(r"\{[^{}]*\"answer_choice\"\s*:\s*\"([A-Za-z])\"[^{}]*\}", re.DOTALL)
_FENCED_JSON = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL | re.IGNORECASE)
_ANSWER_LINE = re.compile(
r"(?:final\s*answer|answer\s*choice|the\s+correct\s+answer\s+is|answer)\s*[:=\-]?\s*"
r"\(?\s*([A-Za-z])\s*[\)\.]*\s*$",
re.IGNORECASE | re.MULTILINE,
)
_BARE_LETTER = re.compile(r"^\s*\(?\s*([A-Za-z])\s*[\)\.]*\s*$", re.MULTILINE)
def _from_json_envelope(text: str) -> str | None:
# Try fenced code blocks first (most likely to contain the JSON).
for fence in _FENCED_JSON.finditer(text):
try:
obj = json.loads(fence.group(1))
except (json.JSONDecodeError, ValueError):
continue
if isinstance(obj, dict):
choice = obj.get("answer_choice")
if isinstance(choice, str) and choice.strip():
return choice.strip()[:1].upper()
# Fall back to a tolerant regex over the whole text (handles
# responses that drop the fences).
match = _JSON_BLOCK.search(text)
if match:
return match.group(1).upper()
return None
def _from_answer_line(text: str) -> str | None:
# Walk lines bottom-up; the answer is almost always near the end.
for match in reversed(list(_ANSWER_LINE.finditer(text))):
letter = match.group(1).upper()
if letter.isalpha():
return letter
return None
def _from_bare_letter(text: str) -> str | None:
# Inspect only the final non-empty lines (avoid grabbing in-prose
# mentions of "A" or "I").
lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
for ln in reversed(lines[-3:]):
match = _BARE_LETTER.match(ln)
if match:
return match.group(1).upper()
return None
def extract_answer_letter(text: str) -> AnswerLetterResult:
"""Run strategies in order and return the first hit.
Order: JSON envelope final-answer-line regex bare-letter
fallback. Empty / whitespace-only text returns
``AnswerLetterResult(None, "none")``.
"""
if not text or not text.strip():
return AnswerLetterResult(None, "none")
letter = _from_json_envelope(text)
if letter:
return AnswerLetterResult(letter, "json_envelope")
letter = _from_answer_line(text)
if letter:
return AnswerLetterResult(letter, "answer_line")
letter = _from_bare_letter(text)
if letter:
return AnswerLetterResult(letter, "bare_letter")
return AnswerLetterResult(None, "none")
__all__ = ["AnswerLetterResult", "ParserStrategy", "extract_answer_letter"]

View file

@ -0,0 +1,110 @@
"""Python port of the canonical citation parser.
Source of truth: ``surfsense_web/lib/citations/citation-parser.ts:20-21``.
The pattern is byte-for-byte identical to the TS export ``CITATION_REGEX``
so a SurfSense user reading the web client and a CUREv1 retrieval scorer
running here see the same chunk_ids extracted from the same answer.
The TS reference also handles a ``urlcite{N}`` placeholder produced by
``preprocessCitationMarkdown`` that pre-processing step is web-only
(GFM autolink workaround), so the harness sees raw ``[citation:URL]``
tokens and ``parse_citations`` returns them as ``UrlCitation`` directly.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any, Union
# Pattern preserves the TS source verbatim:
# /[\[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*)\s*\u200B?[\]】]/g
#
# Notes:
# * Matches both ASCII ``[]`` and Chinese fullwidth ``【】`` brackets.
# * Allows an optional ZWSP (``\u200B``) just inside each bracket.
# * ``citation:`` then EITHER a URL (anything not ``]``, ``】``, or ZWSP),
# OR a ``urlcite\d+`` placeholder, OR one or more comma-separated
# chunk ids (each optionally prefixed with ``doc-`` and optionally
# negative).
# * URL char class deliberately excludes the closing brackets so a
# ``[citation:https://x.com]`` doesn't swallow the ``]``.
# The ZWSP must be the actual code-point — the original TS source uses
# the regex literal ``\u200B`` which the JS engine interprets as the
# character. Python's ``re`` doesn't process the ``\u`` escape inside
# the pattern source, so we splice the literal character in via an
# f-string. This keeps our pattern functionally identical to the TS
# reference and lets ``"\u200B" in CITATION_REGEX.pattern`` succeed.
_ZWSP = "\u200B"
CITATION_REGEX = re.compile(
rf"[\[【]{_ZWSP}?citation:\s*("
rf"https?://[^\]】{_ZWSP}]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*"
rf")\s*{_ZWSP}?[\]】]"
)
@dataclass(frozen=True)
class ChunkCitation:
chunk_id: int
is_docs_chunk: bool
def to_dict(self) -> dict[str, Any]:
return {
"kind": "chunk",
"chunk_id": self.chunk_id,
"is_docs_chunk": self.is_docs_chunk,
}
@dataclass(frozen=True)
class UrlCitation:
url: str
def to_dict(self) -> dict[str, Any]:
return {"kind": "url", "url": self.url}
CitationToken = Union[ChunkCitation, UrlCitation]
def parse_citations(text: str, *, url_map: dict[str, str] | None = None) -> list[CitationToken]:
"""Return the citation tokens found in ``text`` in document order.
``url_map`` is the optional ``urlciteN -> URL`` lookup that the web
client builds in its preprocessing step. The harness ordinarily
doesn't preprocess (we don't render the markdown, we score it), so
the default empty map means ``urlciteN`` placeholders are dropped
rather than mis-resolved to a missing URL.
Multi-id payloads like ``[citation:1, doc-2, -3]`` are flattened
into separate ``ChunkCitation`` entries same as the TS reference.
"""
out: list[CitationToken] = []
for match in CITATION_REGEX.finditer(text):
captured = match.group(1)
if captured.startswith("http://") or captured.startswith("https://"):
out.append(UrlCitation(url=captured.strip()))
continue
if captured.startswith("urlcite"):
if url_map and captured in url_map:
out.append(UrlCitation(url=url_map[captured]))
continue
for raw_id in (s.strip() for s in captured.split(",")):
is_docs_chunk = raw_id.startswith("doc-")
number_part = raw_id[4:] if is_docs_chunk else raw_id
try:
chunk_id = int(number_part)
except ValueError:
continue
out.append(ChunkCitation(chunk_id=chunk_id, is_docs_chunk=is_docs_chunk))
return out
__all__ = [
"CITATION_REGEX",
"ChunkCitation",
"UrlCitation",
"CitationToken",
"parse_citations",
]

View file

@ -0,0 +1,85 @@
"""Extract free-form answers from open-ended LLM responses.
Used by benchmarks that don't have a fixed letter set (MMLongBench-Doc,
DocVQA-style benchmarks, future legal/finance suites). The contract:
* Strip leading "Answer:" / "Final answer:" markers if present.
* Drop fenced code blocks if the model wrapped its answer in one.
* Trim leading/trailing whitespace.
* Return the *last* meaningful chunk models often think out loud
before stating the answer.
If the message is empty or only contains a fence, return ``""``.
"""
from __future__ import annotations
import re
_ANSWER_PREFIX = re.compile(
r"^\s*(?:final\s*answer|the\s+answer\s+is|answer)\s*[:=\-]\s*",
re.IGNORECASE,
)
# Marker-only regex (no capture group) used to find every "Answer:"
# token position. We then slice from the LAST marker's end to the
# next newline ourselves — robust to multiple inline answers because
# we never let the engine greedy-capture across markers.
_ANSWER_MARKER = re.compile(
r"(?:final\s*answer|the\s+answer\s+is|answer)\s*[:=\-]\s*",
re.IGNORECASE,
)
_FENCED_BLOCK = re.compile(r"```[a-zA-Z0-9]*\s*([\s\S]*?)\s*```")
def extract_freeform_answer(text: str) -> str:
"""Pull the model's final answer out of a possibly-verbose response."""
if not text or not text.strip():
return ""
# 1. Find the last line that starts with an Answer: marker. If
# nothing matches, walk back to the last non-empty line.
lines = [ln.rstrip() for ln in text.strip().splitlines()]
candidate = ""
for ln in reversed(lines):
if not ln.strip():
continue
if _ANSWER_PREFIX.search(ln):
candidate = _ANSWER_PREFIX.sub("", ln, count=1).strip()
break
if not candidate:
# 2. Inline match: find every "Answer:" marker position and
# slice from the LAST marker's end to the next newline. Robust
# to "preamble.Answer: 42" one-liners and multiple inline
# markers (we always pick the final, freshest one).
marker_matches = list(_ANSWER_MARKER.finditer(text))
if marker_matches:
last = marker_matches[-1]
tail = text[last.end():]
nl = tail.find("\n")
if nl >= 0:
tail = tail[:nl]
candidate = tail.strip()
if not candidate:
# 3. No "Answer:" marker — try fenced blocks.
fences = _FENCED_BLOCK.findall(text)
if fences:
candidate = fences[-1].strip()
else:
# Last non-empty line as a fallback.
for ln in reversed(lines):
if ln.strip():
candidate = ln.strip()
break
# 2. Strip wrapping quotes / parens / trailing punctuation that
# confuse the grader without changing meaning.
candidate = candidate.strip().strip("`").strip()
if candidate.startswith(("\"", "'")) and candidate.endswith(("\"", "'")):
candidate = candidate[1:-1].strip()
return candidate
__all__ = ["extract_freeform_answer"]

View file

@ -0,0 +1,72 @@
"""Minimal SSE consumer compatible with SurfSense's wire format.
SurfSense uses ``app/services/streaming/envelope/sse.py`` to frame events:
* ``data: <single-line-string>\\n\\n``
* ``data: <json-string>\\n\\n`` (most events)
* ``data: [DONE]\\n\\n`` (terminator)
There is no ``event:``, ``id:``, or ``retry:`` framing in production
``format_sse(payload)`` only emits the ``data:`` line. This implementation
is therefore intentionally smaller than ``httpx-sse`` (which we still
list as a dep so callers who want richer parsing can opt in): one event
per ``data:`` line, separated by blank lines.
We accept any line iterator (an ``httpx.Response.aiter_lines`` adapter
in production, a list in tests) so this is unit-testable without a
network mock.
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from dataclasses import dataclass
@dataclass(frozen=True)
class SseEvent:
"""A parsed SSE event. Only the ``data`` field is populated.
Multi-line payloads (``data: a\\ndata: b``) are joined with ``\\n``
per the SSE spec, even though SurfSense doesn't currently emit them.
"""
data: str
async def iter_sse_events(lines: AsyncIterator[str]) -> AsyncIterator[SseEvent]:
"""Yield one ``SseEvent`` per blank-line-terminated frame.
Lines that are empty or whitespace flush the buffer. ``data:`` lines
are accumulated into the buffer; everything else is ignored
(matches the lenient browser EventSource behaviour).
"""
buffer: list[str] = []
async for raw in lines:
if raw is None:
continue
line = raw.rstrip("\r")
if line == "":
if buffer:
yield SseEvent(data="\n".join(buffer))
buffer.clear()
continue
if line.startswith(":"):
# comment / heartbeat
continue
if line.startswith("data:"):
# spec: optional single space after the colon.
payload = line[5:]
if payload.startswith(" "):
payload = payload[1:]
buffer.append(payload)
continue
# Any other field (event:, id:, retry:) is currently unused.
continue
if buffer:
yield SseEvent(data="\n".join(buffer))
__all__ = ["SseEvent", "iter_sse_events"]

View file

@ -0,0 +1,31 @@
"""Domain-agnostic PDF rendering helper. Lazy import."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING: # pragma: no cover
from .render import (
PdfImage,
render_pdf,
render_pdf_with_images,
render_text_files_to_pdf,
)
__all__ = [
"PdfImage",
"render_pdf",
"render_pdf_with_images",
"render_text_files_to_pdf",
]
_LAZY = {"PdfImage", "render_pdf", "render_pdf_with_images", "render_text_files_to_pdf"}
def __getattr__(name: str):
if name in _LAZY:
from . import render as _mod
return getattr(_mod, name)
raise AttributeError(f"module 'surfsense_evals.core.pdf' has no attribute {name!r}")

View file

@ -0,0 +1,351 @@
"""Deterministic ``.txt`` / ``.md`` → single PDF via reportlab.
Used wherever a benchmark needs the same source bytes fed to both the
native-PDF arm and the SurfSense ingestion arm. The head-to-head
comparison is fair only if the *same* PDF is the input to both arms,
which is why we go to lengths to make the rendering deterministic.
Determinism notes:
* We pin the PDF metadata to a fixed creation date and producer
(``reportlab`` accepts neither directly, but ``Canvas.setAuthor`` and
the absence of an ``info`` mutator means the bytes only differ by
``CreationDate`` / ``ModDate``). We post-process the PDF to scrub
those if ``deterministic=True`` is passed.
* Page size, font, margins, and tab handling are fixed in code so the
same input yields the same byte output across machines.
* PDF/A is overkill for our use; basic PDF 1.4 is what every model
expects.
"""
from __future__ import annotations
import io
import re
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from reportlab.lib.pagesizes import LETTER
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.lib.units import inch
from reportlab.lib.utils import ImageReader
from reportlab.platypus import (
Image,
KeepTogether,
PageBreak,
Paragraph,
SimpleDocTemplate,
Spacer,
)
@dataclass
class RenderedPdf:
path: Path
n_pages_estimate: int
n_chars: int
_PDF_DATE_KEY = re.compile(rb"/(?:CreationDate|ModDate)\s*\(D:[^)]*\)")
# reportlab also writes a `/ID [<hex1><hex2>]` trailer entry that
# embeds a per-run hash. Scrub it so two renders of the same input
# produce the same bytes.
_PDF_ID_ARRAY = re.compile(rb"/ID\s*\[\s*<[^>]*>\s*<[^>]*>\s*\]")
def _scrub_dates(pdf_bytes: bytes) -> bytes:
"""Remove ``CreationDate`` / ``ModDate`` / trailer ``/ID`` so the
file is byte-deterministic across runs."""
pdf_bytes = _PDF_DATE_KEY.sub(b"/CreationDate (D:19700101000000Z)", pdf_bytes)
pdf_bytes = _PDF_ID_ARRAY.sub(b"/ID [<00><00>]", pdf_bytes)
return pdf_bytes
_DEFAULT_STYLES = getSampleStyleSheet()
def _build_body_style() -> ParagraphStyle:
base = _DEFAULT_STYLES["BodyText"]
style = ParagraphStyle(
"EvalBody",
parent=base,
fontName="Helvetica",
fontSize=10.5,
leading=14,
spaceAfter=6,
spaceBefore=0,
)
return style
def _build_heading_style() -> ParagraphStyle:
base = _DEFAULT_STYLES["Heading2"]
style = ParagraphStyle(
"EvalHeading",
parent=base,
fontName="Helvetica-Bold",
fontSize=14,
leading=18,
spaceAfter=10,
spaceBefore=8,
)
return style
def _normalise_paragraphs(text: str) -> list[str]:
"""Split a text blob into paragraphs while preserving blank-line structure."""
blocks: list[list[str]] = [[]]
for line in text.splitlines():
stripped = line.rstrip()
if stripped == "":
if blocks[-1]:
blocks.append([])
continue
blocks[-1].append(stripped)
paragraphs: list[str] = []
for block in blocks:
if not block:
continue
# Join lines within a paragraph with spaces (text-from-PDF style).
paragraphs.append(" ".join(block))
return paragraphs
def _escape_html(text: str) -> str:
return (
text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
def render_pdf(
*,
title: str,
sections: Sequence[tuple[str | None, str]],
output_path: Path,
deterministic: bool = True,
) -> RenderedPdf:
"""Render one PDF from a list of ``(section_heading, section_text)`` tuples.
``section_heading`` may be ``None`` for an unnamed section. Each
section is followed by a page break so the model's PDF parser sees
a clean structural boundary between source files.
"""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=LETTER,
leftMargin=0.75 * inch,
rightMargin=0.75 * inch,
topMargin=0.75 * inch,
bottomMargin=0.75 * inch,
title=title,
author="surfsense-evals",
subject="Eval input",
creator="surfsense-evals",
)
body_style = _build_body_style()
heading_style = _build_heading_style()
title_style = ParagraphStyle(
"EvalTitle",
parent=_DEFAULT_STYLES["Title"],
fontName="Helvetica-Bold",
fontSize=18,
leading=22,
spaceAfter=14,
)
flow: list = [Paragraph(_escape_html(title), title_style)]
total_chars = 0
for index, (heading, text) in enumerate(sections):
if index > 0:
flow.append(PageBreak())
if heading:
flow.append(Paragraph(_escape_html(heading), heading_style))
for paragraph in _normalise_paragraphs(text):
total_chars += len(paragraph)
flow.append(Paragraph(_escape_html(paragraph), body_style))
flow.append(Spacer(1, 4))
doc.build(flow)
pdf_bytes = buffer.getvalue()
if deterministic:
pdf_bytes = _scrub_dates(pdf_bytes)
output_path.write_bytes(pdf_bytes)
# Conservative page estimate: ~3000 chars per LETTER page at 10.5pt.
n_pages = max(1, total_chars // 3000 + len(sections))
return RenderedPdf(path=output_path, n_pages_estimate=n_pages, n_chars=total_chars)
@dataclass
class PdfImage:
"""One image to embed inside a section.
``caption`` is rendered below the image (italic). ``max_width_in``
caps the rendered width in inches; height auto-scales to preserve
aspect ratio (read with PIL).
"""
path: Path
caption: str = ""
max_width_in: float = 5.5 # default leaves margin for LETTER 8.5"
def _make_image_flowable(image: PdfImage) -> Image:
"""Build a reportlab Image flowable scaled to fit page width."""
reader = ImageReader(str(image.path))
iw, ih = reader.getSize()
if iw <= 0 or ih <= 0:
raise ValueError(f"Invalid image dimensions for {image.path}: {iw}x{ih}")
target_w = image.max_width_in * inch
target_h = target_w * (ih / iw)
# Cap height too — some medical images are extreme portrait.
max_h = 7.0 * inch
if target_h > max_h:
target_h = max_h
target_w = target_h * (iw / ih)
return Image(str(image.path), width=target_w, height=target_h)
def render_pdf_with_images(
*,
title: str,
sections: Sequence[tuple[str | None, str, Sequence[PdfImage] | None]],
output_path: Path,
deterministic: bool = True,
page_break_between_sections: bool = False,
) -> RenderedPdf:
"""Render a PDF that mixes text and embedded images.
Each section is ``(heading, body_text, images)``. Images render
inline after the body text, each followed by an italic caption.
Set ``page_break_between_sections=True`` if you want explicit
structural boundaries (mostly useful for multi-case PDFs); the
default keeps everything on one page when possible (so a single
MedXpertQA case is one PDF page with case + images + options).
"""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=LETTER,
leftMargin=0.75 * inch,
rightMargin=0.75 * inch,
topMargin=0.75 * inch,
bottomMargin=0.75 * inch,
title=title,
author="surfsense-evals",
subject="Eval input",
creator="surfsense-evals",
)
body_style = _build_body_style()
heading_style = _build_heading_style()
caption_style = ParagraphStyle(
"EvalCaption",
parent=body_style,
fontSize=9,
leading=11,
textColor="#444",
spaceBefore=2,
spaceAfter=10,
)
title_style = ParagraphStyle(
"EvalTitle",
parent=_DEFAULT_STYLES["Title"],
fontName="Helvetica-Bold",
fontSize=18,
leading=22,
spaceAfter=14,
)
flow: list = [Paragraph(_escape_html(title), title_style)]
total_chars = 0
for index, (heading, text, images) in enumerate(sections):
if index > 0 and page_break_between_sections:
flow.append(PageBreak())
if heading:
flow.append(Paragraph(_escape_html(heading), heading_style))
for paragraph in _normalise_paragraphs(text):
total_chars += len(paragraph)
flow.append(Paragraph(_escape_html(paragraph), body_style))
flow.append(Spacer(1, 4))
for image in images or []:
try:
img_flow = _make_image_flowable(image)
except Exception: # noqa: BLE001 — bad image shouldn't kill PDF
continue
grouped = [img_flow]
if image.caption:
grouped.append(Paragraph(_escape_html(image.caption), caption_style))
else:
grouped.append(Spacer(1, 8))
flow.append(KeepTogether(grouped))
doc.build(flow)
pdf_bytes = buffer.getvalue()
if deterministic:
pdf_bytes = _scrub_dates(pdf_bytes)
output_path.write_bytes(pdf_bytes)
n_pages = max(1, total_chars // 3000 + len(sections))
return RenderedPdf(path=output_path, n_pages_estimate=n_pages, n_chars=total_chars)
def render_text_files_to_pdf(
*,
title: str,
files: Iterable[Path],
output_path: Path,
deterministic: bool = True,
) -> RenderedPdf:
"""Convenience wrapper: read a list of text files, render to one PDF.
The heading of each section is the file's name (no extension), so
e.g. ``admission_note.txt`` becomes a section header ``admission_note``
in the rendered PDF. Useful for any text-only benchmark that ships
a corpus as separate ``.txt`` / ``.md`` shards per logical document.
"""
sections: list[tuple[str | None, str]] = []
for path in files:
path = Path(path)
text = path.read_text(encoding="utf-8")
sections.append((path.stem, text))
return render_pdf(
title=title,
sections=sections,
output_path=output_path,
deterministic=deterministic,
)
# Tiny self-check — handy when debugging.
def _self_test() -> None: # pragma: no cover
out = Path("./_render_self_test.pdf")
sections = [
("intro", "Hello world.\n\nThis is a test."),
("body", "Line one.\nLine two."),
]
rendered = render_pdf(title="Self test", sections=sections, output_path=out)
print(f"wrote {rendered.path} ({rendered.n_chars} chars)")
# Importing ``datetime`` keeps the timezone helper handy if a future
# benchmark wants to embed a real timestamp without losing determinism.
_NOW_FROZEN = datetime(2026, 5, 11, tzinfo=UTC)

View file

@ -0,0 +1,22 @@
"""External LLM providers (used by the native arm).
Lazy imports so the SurfSense-only path doesn't transitively load the
OpenRouter client until something actually constructs ``OpenRouterPdfProvider``.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING: # pragma: no cover
from .openrouter_pdf import OpenRouterPdfProvider, OpenRouterResponse
__all__ = ["OpenRouterPdfProvider", "OpenRouterResponse"]
def __getattr__(name: str):
if name in {"OpenRouterPdfProvider", "OpenRouterResponse"}:
from . import openrouter_pdf as _mod
return getattr(_mod, name)
raise AttributeError(f"module 'surfsense_evals.core.providers' has no attribute {name!r}")

View file

@ -0,0 +1,118 @@
"""Bare OpenRouter ``chat/completions`` provider — no PDF, no plugins.
Used by ``BareLlmArm`` to measure "what does the model answer with
zero retrieval context?". Same wire shape as ``OpenRouterPdfProvider``
minus the file-parser plugin and the ``file`` content part:
```json
{
"model": "openai/gpt-5.4-mini",
"messages": [
{"role": "system", "content": "<optional>"},
{"role": "user", "content": "<prompt>"}
]
}
```
The response shape is identical to the PDF provider's, so we re-use
``_parse_chat_completion`` from ``openrouter_pdf`` and only specialise
the request builder. That keeps cost-extraction, token-counting, and
content-array handling in one place.
"""
from __future__ import annotations
import logging
import time
from typing import Any
import httpx
from .openrouter_pdf import (
OpenRouterResponse,
_DEFAULT_HEADERS,
_parse_chat_completion,
)
logger = logging.getLogger(__name__)
class OpenRouterChatProvider:
"""Stateless bare-chat client. No PDF, no file-parser plugin."""
def __init__(
self,
*,
api_key: str,
base_url: str = "https://openrouter.ai/api/v1",
model: str,
timeout_s: float = 600.0,
) -> None:
if not api_key:
raise ValueError("OPENROUTER_API_KEY is required for the bare-LLM arm.")
self._api_key = api_key
self._base = base_url.rstrip("/")
self._model = model
self._timeout = httpx.Timeout(timeout_s, connect=15.0)
@property
def model(self) -> str:
return self._model
def _build_payload(
self,
*,
prompt: str,
system_prompt: str | None,
max_tokens: int | None,
) -> dict[str, Any]:
messages: list[dict[str, Any]] = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
body: dict[str, Any] = {"model": self._model, "messages": messages}
if max_tokens:
body["max_tokens"] = max_tokens
return body
async def complete(
self,
*,
prompt: str,
system_prompt: str | None = None,
max_tokens: int | None = None,
http: httpx.AsyncClient | None = None,
) -> OpenRouterResponse:
"""Single chat completion. Errors are raised verbatim — caller decides retries."""
payload = self._build_payload(
prompt=prompt,
system_prompt=system_prompt,
max_tokens=max_tokens,
)
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
**_DEFAULT_HEADERS,
}
url = f"{self._base}/chat/completions"
started = time.monotonic()
if http is not None:
response = await http.post(url, json=payload, headers=headers, timeout=self._timeout)
else:
async with httpx.AsyncClient(timeout=self._timeout) as client:
response = await client.post(
url, json=payload, headers=headers, timeout=self._timeout
)
latency_ms = int((time.monotonic() - started) * 1000)
if response.status_code >= 400:
raise httpx.HTTPStatusError(
f"OpenRouter HTTP {response.status_code}: {response.text[:300]}",
request=response.request,
response=response,
)
return _parse_chat_completion(response.json(), latency_ms=latency_ms)
__all__ = ["OpenRouterChatProvider"]

View file

@ -0,0 +1,231 @@
"""Native-PDF arm provider: OpenRouter ``chat/completions`` with PDF input.
Per `<https://openrouter.ai/docs/features/multimodal/pdfs>`__ the wire
shape is OpenAI-compatible with one PDF-specific extra:
```json
{
"model": "anthropic/claude-sonnet-4.5",
"messages": [{
"role": "user",
"content": [
{"type": "file", "file": {"filename": "case.pdf",
"file_data": "data:application/pdf;base64,<b64>"}},
{"type": "text", "text": "<prompt>"}
]
}],
"plugins": [{"id": "file-parser", "pdf": {"engine": "native"}}]
}
```
``engine: "native"`` is the only engine that doesn't pre-OCR the
PDF it forwards raw bytes to PDF-native models (Claude, Gemini),
matching what a human user does when "dropping the PDF into Claude".
``mistral-ocr`` and ``cloudflare-ai`` are exposed as enum options for
non-native models.
Headers ``HTTP-Referer`` and ``X-Title`` make spend show up cleanly on
the OpenRouter dashboard.
"""
from __future__ import annotations
import base64
import logging
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
import httpx
logger = logging.getLogger(__name__)
class PdfEngine(str, Enum):
NATIVE = "native"
MISTRAL_OCR = "mistral-ocr"
CLOUDFLARE_AI = "cloudflare-ai"
@dataclass
class OpenRouterResponse:
"""Subset of the OpenRouter response we care about for scoring."""
text: str
input_tokens: int
output_tokens: int
total_tokens: int
cost_micros: int
latency_ms: int
finish_reason: str | None
raw: dict[str, Any]
_DEFAULT_HEADERS = {
"HTTP-Referer": "https://github.com/MODSetter/SurfSense",
"X-Title": "SurfSense-evals",
}
class OpenRouterPdfProvider:
"""Thin httpx-based client. Stateless; safe to reuse per arm instance."""
def __init__(
self,
*,
api_key: str,
base_url: str = "https://openrouter.ai/api/v1",
model: str,
engine: PdfEngine = PdfEngine.NATIVE,
timeout_s: float = 600.0,
) -> None:
if not api_key:
raise ValueError("OPENROUTER_API_KEY is required for the native arm.")
self._api_key = api_key
self._base = base_url.rstrip("/")
self._model = model
self._engine = engine
self._timeout = httpx.Timeout(timeout_s, connect=15.0)
@property
def model(self) -> str:
return self._model
@property
def engine(self) -> PdfEngine:
return self._engine
def _build_payload(
self,
*,
prompt: str,
pdf_path: Path,
max_tokens: int | None,
extra_messages: list[dict[str, Any]] | None,
) -> dict[str, Any]:
b64 = base64.b64encode(pdf_path.read_bytes()).decode("ascii")
user_content: list[dict[str, Any]] = [
{
"type": "file",
"file": {
"filename": pdf_path.name,
"file_data": f"data:application/pdf;base64,{b64}",
},
},
{"type": "text", "text": prompt},
]
messages: list[dict[str, Any]] = list(extra_messages or [])
messages.append({"role": "user", "content": user_content})
body: dict[str, Any] = {
"model": self._model,
"messages": messages,
"plugins": [
{"id": "file-parser", "pdf": {"engine": self._engine.value}}
],
}
if max_tokens:
body["max_tokens"] = max_tokens
return body
async def complete(
self,
*,
prompt: str,
pdf_path: Path,
max_tokens: int | None = None,
extra_messages: list[dict[str, Any]] | None = None,
http: httpx.AsyncClient | None = None,
) -> OpenRouterResponse:
"""Single chat completion. Errors are raised verbatim — runner decides retries."""
payload = self._build_payload(
prompt=prompt,
pdf_path=pdf_path,
max_tokens=max_tokens,
extra_messages=extra_messages,
)
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
**_DEFAULT_HEADERS,
}
url = f"{self._base}/chat/completions"
started = time.monotonic()
if http is not None:
response = await http.post(url, json=payload, headers=headers, timeout=self._timeout)
else:
async with httpx.AsyncClient(timeout=self._timeout) as client:
response = await client.post(
url, json=payload, headers=headers, timeout=self._timeout
)
latency_ms = int((time.monotonic() - started) * 1000)
if response.status_code >= 400:
raise httpx.HTTPStatusError(
f"OpenRouter HTTP {response.status_code}: {response.text[:300]}",
request=response.request,
response=response,
)
data = response.json()
return _parse_chat_completion(data, latency_ms=latency_ms)
def _parse_chat_completion(payload: dict[str, Any], *, latency_ms: int) -> OpenRouterResponse:
"""Tolerant parser for OpenRouter / OpenAI chat-completions JSON.
OpenRouter passes through any provider-specific extras, but the
canonical shape is ``choices[0].message.content`` (string OR array
of content parts) and ``usage.prompt_tokens / completion_tokens / total_tokens``.
Cost lives at the top level (``payload["usage"]["cost"]`` or
``payload["x-or-cost"]``) depending on routing.
"""
text = ""
finish_reason: str | None = None
choices = payload.get("choices") or []
if choices:
message = (choices[0] or {}).get("message") or {}
content = message.get("content")
if isinstance(content, str):
text = content
elif isinstance(content, list):
chunks: list[str] = []
for part in content:
if isinstance(part, dict) and part.get("type") in {"text", "output_text"}:
chunks.append(str(part.get("text", "")))
text = "".join(chunks)
finish_reason = (choices[0] or {}).get("finish_reason") or None
usage = payload.get("usage") or {}
input_tokens = int(usage.get("prompt_tokens") or 0)
output_tokens = int(usage.get("completion_tokens") or 0)
total_tokens = int(usage.get("total_tokens") or (input_tokens + output_tokens))
# OpenRouter exposes cost in dollars on `usage.cost` or `cost`. We
# convert to integer micros to avoid float-summing surprises across
# 7,663 MIRAGE questions.
raw_cost = usage.get("cost")
if raw_cost is None:
raw_cost = payload.get("cost")
cost_micros = 0
if raw_cost is not None:
try:
cost_micros = int(round(float(raw_cost) * 1_000_000))
except (TypeError, ValueError):
cost_micros = 0
return OpenRouterResponse(
text=text,
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
latency_ms=latency_ms,
finish_reason=finish_reason,
raw=payload,
)
__all__ = ["OpenRouterPdfProvider", "OpenRouterResponse", "PdfEngine"]

View file

@ -0,0 +1,265 @@
"""Suite + Benchmark protocols and the global registry.
The extensibility seam: ``core.cli`` walks ``surfsense_evals.suites`` on
import, which auto-imports every benchmark subpackage, which calls
``register(<benchmark>)`` at module bottom. The CLI then iterates the
populated registry to build subcommand groups dynamically.
Adding a new domain = drop a folder under ``suites/<domain>/<bench>/``
that ends in ``register(MyBenchmark())``. No edits anywhere in
``core/`` are required.
"""
from __future__ import annotations
import argparse
from collections.abc import Mapping
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Protocol, runtime_checkable
import httpx
from .clients import DocumentsClient, NewChatClient, SearchSpaceClient
from .config import Config, SuiteState
# ---------------------------------------------------------------------------
# Run context — what every benchmark.ingest/run receives
# ---------------------------------------------------------------------------
@dataclass
class RunContext:
"""Per-invocation environment threaded into ``ingest`` and ``run``.
A benchmark uses this to read pinned suite state, build new HTTP
clients on the shared ``http`` session, find the right data /
reports paths, and discover the active OpenRouter model + key.
``http`` is the authenticated SurfSense client (auth event hook
attached). It is **not** an OpenRouter client providers create
their own short-lived clients because OpenRouter doesn't share the
SurfSense bearer.
"""
suite: str
benchmark: str
config: Config
suite_state: SuiteState
http: httpx.AsyncClient
@property
def search_space_id(self) -> int:
return self.suite_state.search_space_id
@property
def agent_llm_id(self) -> int:
return self.suite_state.agent_llm_id
@property
def provider_model(self) -> str:
"""Slug used by the SurfSense agent (and the native arm by default).
For ``cost-arbitrage`` scenarios this is the *cheap, text-only*
slug SurfSense answers from the chunks the vision LLM already
extracted at ingest. The native arm should use
``native_arm_model`` instead in that scenario.
"""
return self.suite_state.provider_model
@property
def native_arm_model(self) -> str:
"""Slug the native_pdf arm should use.
Defaults to ``provider_model`` (head-to-head / symmetric-cheap);
for ``cost-arbitrage`` it returns the explicit
``--native-arm-model`` so the native arm can fairly answer
image-bearing questions.
"""
return self.suite_state.effective_native_arm_model
@property
def vision_provider_model(self) -> str | None:
"""Slug of the OpenRouter vision LLM SurfSense used at ingest.
``None`` if no vision config was attached at setup (legacy or
text-only suite). Used by runners purely to record what was
actually used in ``RunArtifact.extra`` and to label reports.
"""
return self.suite_state.vision_provider_model
@property
def scenario(self) -> str:
"""Scenario name pinned at setup time (see ``config.SCENARIOS``)."""
return self.suite_state.scenario
def search_space_client(self) -> SearchSpaceClient:
return SearchSpaceClient(self.http, self.config.surfsense_api_base)
def documents_client(self) -> DocumentsClient:
return DocumentsClient(self.http, self.config.surfsense_api_base)
def new_chat_client(self) -> NewChatClient:
return NewChatClient(self.http, self.config.surfsense_api_base)
def maps_dir(self) -> Path:
path = self.config.suite_maps_dir(self.suite)
path.mkdir(parents=True, exist_ok=True)
return path
def runs_dir(self, *, run_timestamp: str) -> Path:
path = self.config.suite_runs_dir(self.suite) / run_timestamp / self.benchmark
path.mkdir(parents=True, exist_ok=True)
return path
def benchmark_data_dir(self) -> Path:
path = self.config.suite_data_dir(self.suite) / self.benchmark
path.mkdir(parents=True, exist_ok=True)
return path
# ---------------------------------------------------------------------------
# Run artifact + report section
# ---------------------------------------------------------------------------
@dataclass
class RunArtifact:
"""Everything a runner persists for the report writer to consume.
``raw_path`` points at the JSONL of per-question ``ArmResult``
rows. ``metrics`` is a free-form dict the benchmark fills in (e.g.
``{"native": {...}, "surfsense": {...}, "delta": {...}}``).
"""
suite: str
benchmark: str
run_timestamp: str
raw_path: Path
metrics: dict[str, Any] = field(default_factory=dict)
extra: dict[str, Any] = field(default_factory=dict)
@dataclass
class ReportSection:
"""One benchmark's slice of the final summary."""
title: str
headline: bool
body_md: str
body_json: dict[str, Any] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# Benchmark protocol + registry
# ---------------------------------------------------------------------------
@runtime_checkable
class Benchmark(Protocol):
"""The contract every benchmark module ends with ``register(<x>)``."""
suite: str
name: str
headline: bool
description: str
async def ingest(self, ctx: RunContext, **opts: Any) -> None: # pragma: no cover - protocol
...
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact: # pragma: no cover - protocol
...
def add_run_args(self, parser: argparse.ArgumentParser) -> None: # pragma: no cover - protocol
"""Add benchmark-specific flags to ``run <suite> <benchmark>``."""
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection: # pragma: no cover - protocol
...
# ---------------------------------------------------------------------------
# Registry storage
# ---------------------------------------------------------------------------
_REGISTRY: dict[tuple[str, str], Benchmark] = {}
def register(benchmark: Benchmark) -> None:
"""Add ``benchmark`` to the registry. Last-wins on duplicate keys.
Duplicate registrations log a warning rather than raising so a
benchmark module imported twice (once via auto-discovery, once via
a test directly importing it) doesn't blow up the CLI.
"""
key = (benchmark.suite, benchmark.name)
if key in _REGISTRY:
import logging
logging.getLogger(__name__).warning(
"Benchmark %s/%s re-registered (overwriting prior)", *key
)
_REGISTRY[key] = benchmark
def unregister(suite: str, name: str) -> None:
"""Test helper: drop a single benchmark from the registry."""
_REGISTRY.pop((suite, name), None)
def reset() -> None:
"""Test helper: wipe the registry (use with monkeypatched discovery)."""
_REGISTRY.clear()
def get(suite: str, name: str) -> Benchmark:
try:
return _REGISTRY[(suite, name)]
except KeyError as exc:
available = ", ".join(f"{s}/{n}" for s, n in sorted(_REGISTRY)) or "<none>"
raise KeyError(
f"Unknown benchmark '{suite}/{name}'. Registered: {available}"
) from exc
def list_suites() -> list[str]:
return sorted({s for s, _ in _REGISTRY})
def list_benchmarks(suite: str | None = None) -> list[Benchmark]:
if suite is None:
return [_REGISTRY[k] for k in sorted(_REGISTRY)]
return [_REGISTRY[k] for k in sorted(_REGISTRY) if k[0] == suite]
def snapshot() -> Mapping[tuple[str, str], Benchmark]:
"""Read-only view for diagnostics (e.g. ``benchmarks list`` rendering)."""
return dict(_REGISTRY)
__all__ = [
"Arm",
"Benchmark",
"ReportSection",
"RunArtifact",
"RunContext",
"get",
"list_benchmarks",
"list_suites",
"register",
"reset",
"snapshot",
"unregister",
]
# Re-export Arm from arms.base so suites can `from core.registry import Arm`.
from .arms.base import Arm # noqa: E402, F401 (deliberate re-export at bottom)

View file

@ -0,0 +1,18 @@
"""Report writer + section composition primitives. Lazy import."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING: # pragma: no cover
from .writer import write_report
__all__ = ["write_report"]
def __getattr__(name: str):
if name == "write_report":
from .writer import write_report
return write_report
raise AttributeError(f"module 'surfsense_evals.core.report' has no attribute {name!r}")

View file

@ -0,0 +1,89 @@
"""Report writer — composes per-benchmark sections into one summary.
Output:
* ``reports/<suite>/<run-timestamp>/summary.md`` human-readable.
Bullet lists only (no tables) per project's coding-standards.
* ``reports/<suite>/<run-timestamp>/summary.json`` same content as
structured JSON for downstream tooling (CI dashboards, regressions).
Headline benchmarks come first in both outputs.
"""
from __future__ import annotations
import json
from collections.abc import Iterable
from pathlib import Path
from ..config import Config
from ..registry import ReportSection
def write_report(
*,
config: Config,
suite: str,
sections: Iterable[ReportSection],
run_timestamp: str,
) -> Path:
"""Write ``summary.md`` + ``summary.json``. Returns the path of the .md file."""
sections_list = list(sections)
sections_list.sort(key=lambda s: (not s.headline, s.title.lower()))
out_dir = config.suite_reports_dir(suite) / run_timestamp
out_dir.mkdir(parents=True, exist_ok=True)
md_path = out_dir / "summary.md"
json_path = out_dir / "summary.json"
md_lines: list[str] = [
f"# SurfSense evals — suite `{suite}`",
"",
f"- Run timestamp: `{run_timestamp}`",
f"- Sections: {len(sections_list)}",
"",
]
headline = [s for s in sections_list if s.headline]
secondary = [s for s in sections_list if not s.headline]
if headline:
md_lines.append("## Headline")
md_lines.append("")
for section in headline:
md_lines.append(f"### {section.title}")
md_lines.append("")
md_lines.append(section.body_md.rstrip())
md_lines.append("")
if secondary:
md_lines.append("## Secondary measurements")
md_lines.append("")
for section in secondary:
md_lines.append(f"### {section.title}")
md_lines.append("")
md_lines.append(section.body_md.rstrip())
md_lines.append("")
md_path.write_text("\n".join(md_lines).rstrip() + "\n", encoding="utf-8")
json_payload = {
"suite": suite,
"run_timestamp": run_timestamp,
"sections": [
{
"title": s.title,
"headline": s.headline,
"body_md": s.body_md,
"body_json": s.body_json,
}
for s in sections_list
],
}
json_path.write_text(
json.dumps(json_payload, indent=2, sort_keys=True) + "\n",
encoding="utf-8",
)
return md_path
__all__ = ["ReportSection", "write_report"]

View file

@ -0,0 +1,58 @@
"""Shared scenario formatting helpers for head-to-head benchmark reports.
The scenario chosen at ``setup`` time (``head-to-head``, ``symmetric-cheap``,
``cost-arbitrage``) materially changes how a head-to-head report should be
read. This module produces the one-bullet summary every head-to-head
runner stamps near the top of its ``report_section`` body so reviewers
immediately see the framing no need to dig into ``run_artifact.json``.
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
def format_scenario_md(extra: Mapping[str, Any] | None) -> str:
"""Render a scenario-aware bullet for a benchmark report.
Reads ``extra["scenario"]`` plus the runtime LLM slugs the runner
recorded. Falls back to a sensible "head-to-head" line if the artifact
pre-dates scenarios so old runs still render cleanly.
"""
extra = dict(extra or {})
scenario = str(extra.get("scenario") or "head-to-head")
surf_slug = str(extra.get("provider_model") or "?")
native_slug = str(extra.get("native_arm_model") or surf_slug)
vision_slug = extra.get("vision_provider_model")
if scenario == "cost-arbitrage":
body = (
f"- Scenario: **cost-arbitrage** — native arm answers with "
f"`{native_slug}` (vision); SurfSense answers with `{surf_slug}` "
f"over chunks vision-extracted at ingest"
f"{f' by `{vision_slug}`' if vision_slug else ''}. "
"Measures how close SurfSense gets to native at a fraction of "
"the per-query cost."
)
elif scenario == "symmetric-cheap":
body = (
f"- Scenario: **symmetric-cheap** — both arms answer with "
f"`{surf_slug}`; SurfSense pre-extracted images at ingest"
f"{f' via `{vision_slug}`' if vision_slug else ''}. "
"Native arm structurally loses on image-bearing questions "
"(text-only model can't see images) — that's the point."
)
else:
body = (
f"- Scenario: head-to-head — both arms answer with `{surf_slug}` "
"via OpenRouter."
)
if vision_slug:
body += f" SurfSense ingest VLM: `{vision_slug}`."
return body
__all__ = ["format_scenario_md"]

View file

@ -0,0 +1,127 @@
"""Vision LLM resolution + auto-pick logic for the harness's ``setup`` command.
Two responsibilities:
1. Resolve an explicit ``--vision-llm <slug>`` to a global OpenRouter
vision LLM config id that ``set_llm_preferences(vision_llm_config_id=...)``
can accept.
2. Auto-pick the strongest registered vision config when the operator
doesn't pass ``--vision-llm`` but the scenario / benchmark needs one.
The priority list mirrors the recommended slugs in the README so the
auto-pick is deterministic and reviewable.
"""
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
from .clients.search_space import VisionLlmConfigEntry
# Order matters — first match wins when auto-picking. Keep these in sync
# with the "Recommended vision slugs" table in the README so the
# auto-pick story is the same one users read about.
RECOMMENDED_VISION_PRIORITY: tuple[str, ...] = (
"anthropic/claude-sonnet-4.5",
"anthropic/claude-opus-4.7",
"openai/gpt-5",
"google/gemini-2.5-pro",
)
class VisionConfigError(RuntimeError):
"""Raised when no vision config can be resolved (explicit or auto)."""
@dataclass(frozen=True)
class ResolvedVisionConfig:
"""Result of ``resolve_vision_llm`` — what to attach + a label for logs."""
config_id: int
provider_model: str
selected_via: str # "explicit" | "auto-priority" | "auto-fallback"
def _openrouter_only(entries: Iterable[VisionLlmConfigEntry]) -> list[VisionLlmConfigEntry]:
return [e for e in entries if e.provider == "OPENROUTER" and not e.is_auto_mode]
def resolve_vision_llm(
candidates: list[VisionLlmConfigEntry],
*,
explicit_slug: str | None,
) -> ResolvedVisionConfig:
"""Resolve a vision LLM config id from a slug or by auto-picking.
* If ``explicit_slug`` is given: must match exactly one OpenRouter
vision config's ``model_name``. Raises ``VisionConfigError`` with a
friendly listing if zero / many match.
* Otherwise: walk ``RECOMMENDED_VISION_PRIORITY`` in order and return
the first registered one. If none of the recommended slugs are
registered, fall back to the first OpenRouter vision config in the
list (deterministic by listing order). Raises ``VisionConfigError``
if zero are registered at all.
"""
or_vision = _openrouter_only(candidates)
if explicit_slug is not None:
matches = [e for e in or_vision if e.model_name == explicit_slug]
if not matches:
sample = ", ".join(e.model_name for e in or_vision[:8]) or "<none>"
raise VisionConfigError(
f"No OpenRouter vision config found for slug '{explicit_slug}'. "
"Make sure `openrouter_integration.vision_enabled: true` in "
"global_llm_config.yaml and that the Celery worker has finished "
"its first refresh. "
f"Available OpenRouter vision slugs (sample): {sample}."
)
if len(matches) > 1:
listing = "\n".join(f" id={e.id} name={e.name!r}" for e in matches)
raise VisionConfigError(
f"Multiple OpenRouter vision configs match '{explicit_slug}':\n{listing}"
)
only = matches[0]
return ResolvedVisionConfig(
config_id=only.id,
provider_model=only.model_name,
selected_via="explicit",
)
if not or_vision:
raise VisionConfigError(
"No OpenRouter vision LLM configs are registered with this "
"SurfSense backend. Either pass `--no-vision-llm` to the ingest "
"step (text-only ingestion), or enable "
"`openrouter_integration.vision_enabled: true` in "
"global_llm_config.yaml so the Celery worker syncs vision-capable "
"OpenRouter models on next refresh."
)
by_slug = {e.model_name: e for e in or_vision}
for preferred in RECOMMENDED_VISION_PRIORITY:
match = by_slug.get(preferred)
if match is not None:
return ResolvedVisionConfig(
config_id=match.id,
provider_model=match.model_name,
selected_via="auto-priority",
)
# Fallback: first registered OpenRouter vision config. Deterministic
# because the backend returns them in a stable order.
fallback = or_vision[0]
return ResolvedVisionConfig(
config_id=fallback.id,
provider_model=fallback.model_name,
selected_via="auto-fallback",
)
__all__ = [
"RECOMMENDED_VISION_PRIORITY",
"ResolvedVisionConfig",
"VisionConfigError",
"resolve_vision_llm",
]

View file

@ -0,0 +1,66 @@
"""Suite registry auto-discovery.
Importing ``surfsense_evals.suites`` walks every subpackage one level deep
(domain like ``medical``) AND its benchmark subpackages
(``medical/medxpertqa``, ``medical/mirage``, ``medical/cure``). Each
benchmark's ``__init__.py`` is expected to call
``core.registry.register(<Benchmark>)`` at module bottom; merely importing
the module is enough to populate the registry.
Adding a new domain is therefore: drop a folder under ``suites/`` with the
right structure. No edits anywhere else.
Subpackages whose name starts with ``_`` are skipped that's reserved for
test fixtures (e.g. ``suites/_demo/``) so they don't accidentally show up
in ``benchmarks list``.
"""
from __future__ import annotations
import importlib
import logging
import pkgutil
from typing import Iterable
logger = logging.getLogger(__name__)
def _iter_subpackages(package) -> Iterable[str]:
"""Yield fully-qualified subpackage names one level deep, skipping ``_*``."""
for module_info in pkgutil.iter_modules(package.__path__, prefix=f"{package.__name__}."):
if not module_info.ispkg:
continue
leaf = module_info.name.rsplit(".", 1)[-1]
if leaf.startswith("_"):
continue
yield module_info.name
def discover_suites() -> list[str]:
"""Import every domain + benchmark subpackage so registrations fire.
Returns the list of fully-qualified benchmark module names that were
successfully imported. Failures are logged (not raised) so a single
broken benchmark doesn't take down the whole CLI — the operator still
sees the working benchmarks via ``benchmarks list``.
"""
import surfsense_evals.suites as _suites # self-import for __path__
imported: list[str] = []
for domain_name in _iter_subpackages(_suites):
try:
domain_pkg = importlib.import_module(domain_name)
except Exception as exc: # noqa: BLE001
logger.warning("Failed to import suite domain %s: %s", domain_name, exc)
continue
for benchmark_name in _iter_subpackages(domain_pkg):
try:
importlib.import_module(benchmark_name)
imported.append(benchmark_name)
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to import benchmark %s: %s", benchmark_name, exc
)
return imported

View file

@ -0,0 +1,8 @@
"""Test fixture suite — skipped by the auto-discovery walker (name starts with ``_``).
Imported explicitly by ``tests/core/test_registry.py`` to prove the
register-on-import contract works without polluting the production
benchmark list.
"""
from __future__ import annotations

View file

@ -0,0 +1,46 @@
"""Demo benchmark — registers on import, used only by the registry tests."""
from __future__ import annotations
import argparse
from typing import Any
from ....core.registry import (
Benchmark,
ReportSection,
RunArtifact,
RunContext,
register,
)
class HelloBenchmark:
suite: str = "_demo"
name: str = "hello"
headline: bool = False
description: str = "Demo benchmark used by the registry test."
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument("--echo", default="hi")
async def ingest(self, ctx: RunContext, **_opts: Any) -> None: # pragma: no cover
return None
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact: # pragma: no cover
return RunArtifact(
suite=self.suite,
benchmark=self.name,
run_timestamp="0",
raw_path=ctx.benchmark_data_dir() / "raw.jsonl",
metrics={"echo": opts.get("echo")},
)
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
return ReportSection(
title="Hello demo",
headline=False,
body_md="- runs: " + str(len(artifacts)),
)
register(HelloBenchmark())

View file

@ -0,0 +1,7 @@
"""Medical RAG benchmarks (MedXpertQA-MM headline + MIRAGE/CUREv1 secondary).
Subpackages register themselves with ``core.registry`` on import. The
``suites/__init__.py`` discovery walker imports them automatically.
"""
from __future__ import annotations

View file

@ -0,0 +1,18 @@
"""CUREv1 — secondary single-arm SurfSense retrieval measurement.
Source: https://huggingface.co/datasets/clinia/CUREv1
Paper: https://arxiv.org/html/2412.06954v4
Pure retrieval benchmark 10 medical disciplines, English/French/Spanish
queries, expert-curated qrels (graded 0/1/2). The harness ingests the
corpus, runs each query via SurfSense's ``/api/v1/new_chat``, parses
chunk citations, maps them back to CUREv1 ``corpus-id``, and scores
Recall@k / MRR / nDCG@10 against qrels.
"""
from __future__ import annotations
from .runner import CureBenchmark
from ....core import registry as _registry
_registry.register(CureBenchmark())

View file

@ -0,0 +1,239 @@
"""CUREv1 ingestion.
For each (lang, discipline) requested, downloads the corpus split via
``datasets.load_dataset(path="clinia/CUREv1", name="corpus", split=<discipline>)``,
batches passages into ~5 MB markdown bundles, uploads them to
SurfSense, polls until ``ready``, and persists the
``corpus_id -> document_id`` map under
``data/medical/maps/cure_corpus_map_<discipline>.jsonl``. A union map
``cure_corpus_map.jsonl`` is also written so the runner can resolve
citations across disciplines without juggling per-file paths.
"""
from __future__ import annotations
import io
import json
import logging
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from ....core.config import set_suite_state
from ....core.ingest_settings import IngestSettings, settings_header_line
from ....core.registry import RunContext
logger = logging.getLogger(__name__)
_BATCH_SIZE_BYTES = 5 * 1024 * 1024
# 10 disciplines covered by the dataset card. We exhaustively list
# them so a smoke test can default to one.
DISCIPLINES = (
"anesthesiology",
"cardiology",
"dermatology",
"endocrinology",
"gastroenterology",
"hematology",
"nephrology",
"neurology",
"obstetrics_gynecology",
"psychiatry",
)
@dataclass
class CorpusPassage:
corpus_id: str
title: str
text: str
def to_markdown(self) -> str:
title = (self.title or "").strip() or "Untitled"
body = (self.text or "").strip()
return f"# {title}\n\n_id: `{self.corpus_id}`_\n\n{body}\n"
@dataclass
class PassageBatch:
path: Path
corpus_ids: list[str]
def _stream_corpus(discipline: str) -> Iterable[CorpusPassage]:
"""Stream corpus rows for one discipline via the ``datasets`` library."""
from datasets import load_dataset # noqa: PLC0415
logger.info("Loading CUREv1 corpus for discipline=%s", discipline)
ds = load_dataset(path="clinia/CUREv1", name="corpus", split=discipline)
for row in ds:
cid = str(row.get("_id") or "")
if not cid:
continue
yield CorpusPassage(
corpus_id=cid,
title=str(row.get("title") or ""),
text=str(row.get("text") or ""),
)
def _write_batches(
passages: Iterable[CorpusPassage],
*,
out_dir: Path,
discipline: str,
batch_bytes: int = _BATCH_SIZE_BYTES,
) -> list[PassageBatch]:
out_dir.mkdir(parents=True, exist_ok=True)
batches: list[PassageBatch] = []
current_buffer = io.StringIO()
current_ids: list[str] = []
current_bytes = 0
batch_idx = 0
def _flush() -> None:
nonlocal current_buffer, current_ids, current_bytes, batch_idx
if not current_ids:
return
path = out_dir / f"cure_{discipline}_{batch_idx:04d}.md"
path.write_text(current_buffer.getvalue(), encoding="utf-8")
batches.append(PassageBatch(path=path, corpus_ids=current_ids))
batch_idx += 1
current_buffer = io.StringIO()
current_ids = []
current_bytes = 0
for passage in passages:
chunk = passage.to_markdown() + "\n---\n\n"
chunk_bytes = len(chunk.encode("utf-8"))
if current_bytes + chunk_bytes > batch_bytes and current_ids:
_flush()
current_buffer.write(chunk)
current_ids.append(passage.corpus_id)
current_bytes += chunk_bytes
_flush()
return batches
async def run_ingest(
ctx: RunContext,
*,
disciplines: list[str] | None = None,
max_per_discipline: int | None = None,
settings: IngestSettings | None = None,
) -> None:
disciplines = disciplines or list(DISCIPLINES)
settings = settings or IngestSettings(use_vision_llm=False, processing_mode="basic")
bench_dir = ctx.benchmark_data_dir()
batches_root = bench_dir / "batches"
batches_root.mkdir(parents=True, exist_ok=True)
docs_client = ctx.documents_client()
union_map_path = ctx.maps_dir() / "cure_corpus_map.jsonl"
union_map_fh = union_map_path.open("w", encoding="utf-8")
# Header row records the ingest-time settings so the runner can
# surface them in the report (see core/ingest_settings.py).
union_map_fh.write(settings_header_line(settings) + "\n")
try:
for discipline in disciplines:
try:
passages_iter = _stream_corpus(discipline)
if max_per_discipline is not None:
passages_iter = _take(passages_iter, max_per_discipline)
batches = _write_batches(
passages_iter,
out_dir=batches_root / discipline,
discipline=discipline,
)
except Exception as exc: # noqa: BLE001
logger.warning("Skipping discipline %s: %s", discipline, exc)
continue
if not batches:
logger.warning("Discipline %s produced 0 batches; skipping upload", discipline)
continue
logger.info(
"Uploading %d batches for discipline %s", len(batches), discipline
)
upload_result = await docs_client.upload(
files=[b.path for b in batches],
search_space_id=ctx.search_space_id,
should_summarize=settings.should_summarize,
use_vision_llm=settings.use_vision_llm,
processing_mode=settings.processing_mode,
)
new_doc_ids = list(upload_result.document_ids)
if new_doc_ids:
await docs_client.wait_until_ready(
search_space_id=ctx.search_space_id,
document_ids=new_doc_ids,
timeout_s=3600.0,
max_poll_s=15.0,
)
statuses = await docs_client.get_status(
search_space_id=ctx.search_space_id,
document_ids=new_doc_ids + upload_result.duplicate_document_ids,
)
title_to_doc = {s.title: s.document_id for s in statuses}
per_discipline_path = (
ctx.maps_dir() / f"cure_corpus_map_{discipline}.jsonl"
)
with per_discipline_path.open("w", encoding="utf-8") as fh:
fh.write(settings_header_line(settings) + "\n")
for batch in batches:
doc_id = title_to_doc.get(batch.path.name)
if doc_id is None:
logger.warning("No document_id for batch %s", batch.path.name)
continue
for cid in batch.corpus_ids:
record = {
"corpus_id": cid,
"document_id": doc_id,
"discipline": discipline,
}
fh.write(json.dumps(record) + "\n")
union_map_fh.write(json.dumps(record) + "\n")
chunks_map_path = ctx.maps_dir() / f"cure_chunk_map_{discipline}.jsonl"
with chunks_map_path.open("w", encoding="utf-8") as fh:
for doc_id in {title_to_doc.get(b.path.name) for b in batches} - {None}:
try:
chunks = await docs_client.list_chunks(int(doc_id))
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to list chunks for doc_id=%s: %s", doc_id, exc
)
continue
for chunk in chunks:
fh.write(
json.dumps(
{
"chunk_id": chunk.id,
"document_id": doc_id,
"discipline": discipline,
}
)
+ "\n"
)
finally:
union_map_fh.close()
new_state = ctx.suite_state
new_state.ingestion_maps["cure"] = str(union_map_path)
set_suite_state(ctx.config, ctx.suite, new_state)
logger.info("CUREv1 ingestion complete; union map at %s", union_map_path)
def _take(it: Iterable, n: int) -> Iterable:
yielded = 0
for x in it:
if yielded >= n:
return
yield x
yielded += 1
__all__ = ["DISCIPLINES", "CorpusPassage", "PassageBatch", "run_ingest"]

View file

@ -0,0 +1,397 @@
"""CUREv1 runner — single-arm SurfSense retrieval scoring.
For each query we ask SurfSense via ``/api/v1/new_chat`` (no
``mentioned_document_ids``) and parse chunk citations from the
streamed answer. Cited ``chunk_id`` ``document_id`` (chunk map)
``corpus_id`` (corpus map). The resulting ranked list is scored
against the dataset's qrels.
The prompt nudges the model to surface its supporting passages via
SurfSense's standard ``[citation:CHUNK_ID]`` format (already required
by the agent system prompt), so we recover retrieval ordering from
the answer text without needing a separate retrieval API.
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from ....core.arms import ArmRequest, ArmResult, SurfSenseArm
from ....core.config import utc_iso_timestamp
from ....core.ingest_settings import (
IngestSettings,
add_ingest_settings_args,
format_ingest_settings_md,
is_settings_header,
read_settings_header,
)
from ....core.metrics.retrieval import score_run
from ....core.registry import (
Benchmark,
ReportSection,
RunArtifact,
RunContext,
)
logger = logging.getLogger(__name__)
_PROMPT = """\
You are a medical literature retrieval assistant for the question
below. Identify the top passages from the knowledge base that best
answer it and cite each one in the standard format
[citation:CHUNK_ID]. List as many citations as are useful, ordered
from most to least relevant. Provide a one-sentence justification
for each citation.
Query: {query}
"""
_DESCRIPTION = "CUREv1 retrieval (single-arm SurfSense): Recall@k / MRR / nDCG@10."
# CUREv1 corpus is text-only markdown bundles; vision LLM at ingest
# is wasted by default but the operator can flip it via CLI for an
# A/B comparison.
_DEFAULT_INGEST_SETTINGS = IngestSettings(
use_vision_llm=False,
processing_mode="basic",
should_summarize=False,
)
@dataclass
class CureQuery:
qid: str
text: str
discipline: str
def _load_chunk_map(maps_dir: Path) -> dict[int, int]:
"""Union all ``cure_chunk_map_<discipline>.jsonl`` into one dict."""
out: dict[int, int] = {}
for path in sorted(maps_dir.glob("cure_chunk_map_*.jsonl")):
with path.open("r", encoding="utf-8") as fh:
for line in fh:
if not line.strip():
continue
row = json.loads(line)
if is_settings_header(row):
continue
try:
out[int(row["chunk_id"])] = int(row["document_id"])
except (KeyError, TypeError, ValueError):
continue
return out
def _load_doc_to_corpus(maps_dir: Path) -> dict[int, list[str]]:
"""Map ``document_id -> [corpus_id, ...]`` from the union map.
Multiple corpus passages may live in one batched markdown
document, so each doc_id maps to a list. Citation ordering of the
first occurrence is preserved.
"""
out: dict[int, list[str]] = defaultdict(list)
union_path = maps_dir / "cure_corpus_map.jsonl"
if not union_path.exists():
return out
with union_path.open("r", encoding="utf-8") as fh:
for line in fh:
if not line.strip():
continue
row = json.loads(line)
if is_settings_header(row):
continue
try:
out[int(row["document_id"])].append(str(row["corpus_id"]))
except (KeyError, TypeError, ValueError):
continue
return out
def _load_queries(*, lang: str, disciplines: list[str], sample_n: int | None) -> list[CureQuery]:
from datasets import load_dataset # noqa: PLC0415
out: list[CureQuery] = []
for discipline in disciplines:
try:
ds = load_dataset(path="clinia/CUREv1", name=f"queries-{lang}", split=discipline)
except Exception as exc: # noqa: BLE001
logger.warning("Skipping queries for %s/%s: %s", lang, discipline, exc)
continue
for row in ds:
qid = str(row.get("_id") or "")
text = str(row.get("text") or "")
if not qid or not text:
continue
out.append(CureQuery(qid=qid, text=text, discipline=discipline))
out.sort(key=lambda q: (q.discipline, q.qid))
if sample_n is not None and sample_n > 0:
# Stratified-by-discipline slice.
per_d = max(1, sample_n // max(1, len(disciplines)))
sliced: list[CureQuery] = []
counter: dict[str, int] = defaultdict(int)
for q in out:
if counter[q.discipline] >= per_d:
continue
sliced.append(q)
counter[q.discipline] += 1
if len(sliced) >= sample_n:
break
out = sliced
return out
def _load_qrels(*, disciplines: list[str]) -> dict[str, dict[str, float]]:
from datasets import load_dataset # noqa: PLC0415
out: dict[str, dict[str, float]] = defaultdict(dict)
for discipline in disciplines:
try:
ds = load_dataset(path="clinia/CUREv1", name="qrels", split=discipline)
except Exception as exc: # noqa: BLE001
logger.warning("Skipping qrels for %s: %s", discipline, exc)
continue
for row in ds:
qid = str(row.get("query-id") or row.get("query_id") or "")
cid = str(row.get("corpus-id") or row.get("corpus_id") or "")
score = row.get("score")
if not qid or not cid or score is None:
continue
try:
out[qid][cid] = float(score)
except (TypeError, ValueError):
continue
return out
async def _gather_with_limit(coros, *, concurrency: int) -> list[Any]:
sem = asyncio.Semaphore(max(1, concurrency))
async def _wrap(c):
async with sem:
return await c
return await asyncio.gather(*(_wrap(c) for c in coros))
class CureBenchmark:
suite: str = "medical"
name: str = "cure"
headline: bool = False
description: str = _DESCRIPTION
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument("--lang", default="en", choices=("en", "es", "fr"))
parser.add_argument("--discipline", default=None,
help="Restrict to one discipline (default: all ingested).")
parser.add_argument("--n", dest="sample_n", type=int, default=None)
parser.add_argument("--concurrency", type=int, default=4)
parser.add_argument(
"--max-passages-per-discipline", type=int, default=None,
help="(ingest only) cap corpus rows per discipline for smoke testing.",
)
# Per-upload knobs forwarded to /documents/fileupload at ingest;
# ignored at run-time (runner reads resolved settings from the
# union-map header).
add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
from .ingest import DISCIPLINES, run_ingest
settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
await run_ingest(
ctx,
disciplines=list(DISCIPLINES),
max_per_discipline=opts.get("max_passages_per_discipline"),
settings=settings,
)
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
lang = opts.get("lang") or "en"
discipline_filter = opts.get("discipline")
sample_n = opts.get("sample_n")
concurrency = int(opts.get("concurrency") or 4)
maps_dir = ctx.maps_dir()
chunk_to_doc = _load_chunk_map(maps_dir)
doc_to_corpus = _load_doc_to_corpus(maps_dir)
ingest_settings = read_settings_header(maps_dir / "cure_corpus_map.jsonl")
if not chunk_to_doc or not doc_to_corpus:
raise RuntimeError(
"CUREv1 not ingested for this suite. Run "
"`python -m surfsense_evals ingest medical cure` first."
)
# Disciplines to query are determined by the per-discipline maps
# actually present (either user-filtered or whatever was ingested).
ingested_disciplines = sorted({
row_disc
for path in maps_dir.glob("cure_corpus_map_*.jsonl")
for row_disc in [path.stem[len("cure_corpus_map_"):]]
})
if discipline_filter:
disciplines = [discipline_filter]
else:
disciplines = ingested_disciplines or ["dermatology"]
queries = _load_queries(lang=lang, disciplines=disciplines, sample_n=sample_n)
if not queries:
raise RuntimeError(
f"No CUREv1 queries matched lang={lang!r} disciplines={disciplines!r}."
)
qrels = _load_qrels(disciplines=disciplines)
logger.info(
"CUREv1: %d queries / %d qrels across disciplines %s",
len(queries),
len(qrels),
disciplines,
)
arm = SurfSenseArm(
client=ctx.new_chat_client(),
search_space_id=ctx.search_space_id,
ephemeral_threads=True,
)
async def _ask(q: CureQuery) -> ArmResult:
return await arm.answer(
ArmRequest(
question_id=f"{q.discipline}::{q.qid}",
prompt=_PROMPT.format(query=q.text.strip()),
)
)
results: list[ArmResult] = await _gather_with_limit(
(_ask(q) for q in queries), concurrency=concurrency
)
per_query_retrieved: dict[str, list[str]] = {}
for q, res in zip(queries, results):
chunk_ids: list[int] = []
seen: set[int] = set()
for citation in res.citations:
if citation.get("kind") != "chunk":
continue
cid = int(citation.get("chunk_id"))
if cid in seen:
continue
chunk_ids.append(cid)
seen.add(cid)
corpus_ids: list[str] = []
seen_corpus: set[str] = set()
for cid in chunk_ids:
doc_id = chunk_to_doc.get(cid)
if doc_id is None:
continue
for corpus_id in doc_to_corpus.get(doc_id, []):
if corpus_id in seen_corpus:
continue
corpus_ids.append(corpus_id)
seen_corpus.add(corpus_id)
per_query_retrieved[q.qid] = corpus_ids
scores = score_run(
per_query_retrieved=per_query_retrieved,
per_query_qrels=qrels,
ks=(1, 5, 10, 32),
ndcg_k=10,
)
run_timestamp = utc_iso_timestamp()
run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
raw_path = run_dir / "raw.jsonl"
with raw_path.open("w", encoding="utf-8") as fh:
for q, res in zip(queries, results):
fh.write(
json.dumps(
{
"discipline": q.discipline,
"qid": q.qid,
"lang": lang,
"retrieved_corpus_ids": per_query_retrieved.get(q.qid, []),
**res.to_jsonl(),
}
)
+ "\n"
)
metrics = scores.to_dict()
metrics["lang"] = lang
metrics["disciplines"] = disciplines
artifact = RunArtifact(
suite=self.suite,
benchmark=self.name,
run_timestamp=run_timestamp,
raw_path=raw_path,
metrics=metrics,
extra={
"n_queries": len(queries),
"lang": lang,
"disciplines": disciplines,
"concurrency": concurrency,
"provider_model": ctx.provider_model,
"ingest_settings": ingest_settings,
},
)
manifest_path = run_dir / "run_artifact.json"
manifest_path.write_text(
json.dumps(
{
"suite": self.suite,
"benchmark": self.name,
"raw_path": "raw.jsonl",
"metrics": metrics,
"extra": artifact.extra,
},
indent=2,
sort_keys=True,
)
+ "\n",
encoding="utf-8",
)
return artifact
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
if not artifacts:
return ReportSection(
title="CUREv1 — single-arm SurfSense retrieval",
headline=False,
body_md="(no run artifacts found)",
body_json={},
)
latest = max(artifacts, key=lambda a: a.run_timestamp)
m = latest.metrics
recall = m.get("recall_at_k", {})
lines: list[str] = [
format_ingest_settings_md(latest.extra.get("ingest_settings")),
f"- Language: {m.get('lang', '?')}",
f"- Disciplines: {', '.join(m.get('disciplines', []) or ['?'])}",
f"- n_queries (after qrels intersection): {m.get('n_queries', 0)}",
]
for k in (1, 5, 10, 32):
v = recall.get(str(k), recall.get(k))
if v is not None:
lines.append(f"- Recall@{k}: {float(v):.3f}")
lines.append(f"- MRR: {float(m.get('mrr', 0.0)):.3f}")
lines.append(f"- nDCG@10: {float(m.get('ndcg_at_10', 0.0)):.3f}")
return ReportSection(
title="CUREv1 — single-arm SurfSense retrieval",
headline=False,
body_md="\n".join(lines),
body_json=m,
)
__all__ = ["CureBenchmark", "CureQuery"]

View file

@ -0,0 +1,25 @@
"""MedXpertQA-MM — multimodal medical exam head-to-head (medical suite headline).
Source: https://huggingface.co/datasets/TsinghuaC3I/MedXpertQA
Paper: https://arxiv.org/abs/2501.18362 (ICML 2025)
* MM subset: ~2,000 expert-level exam questions with diverse medical
images (radiology, dermatology, pathology, ECGs, gross specimens,
fundus photos) and structured patient information embedded in the
question stem.
* 5 answer choices per MM question (AE).
* USMLE / COMLEX / 17 specialty board sources; rigorously filtered
and reviewed by physicians.
Real diagnostic images carry signal that text-only patient charts
cannot (e.g. CT scans, dermoscopy), so this benchmark exercises the
full vision RAG pipeline end-to-end against a vision-capable model
fed the same PDF natively.
"""
from __future__ import annotations
from ....core import registry as _registry
from .runner import MedXpertQAMMBenchmark
_registry.register(MedXpertQAMMBenchmark())

View file

@ -0,0 +1,394 @@
"""MedXpertQA-MM ingestion.
Steps:
1. Pull ``MM/test.jsonl`` (and optionally ``MM/dev.jsonl``) plus
``images.zip`` from
``hf://datasets/TsinghuaC3I/MedXpertQA``. Cache under
``<data_dir>/medical/medxpertqa/``.
2. Extract ``images.zip`` once into ``<data_dir>/medical/medxpertqa/images/``.
3. Render one PDF per MM question (text question + structured patient
info embedded in the question stem + each image flowable + answer
options). Output: ``<data_dir>/medical/medxpertqa/pdfs/<id>.pdf``.
4. Upload each PDF to SurfSense with ``use_vision_llm=True``; persist
``id -> document_id`` in
``<data_dir>/medical/maps/medxpertqa_doc_map.jsonl``.
Both arms then receive byte-identical PDFs. The native arm sends the
PDF directly to OpenRouter; SurfSense ingests via its own vision
pipeline and the runner queries with ``mentioned_document_ids=[...]``
to scope retrieval to the question's PDF.
"""
from __future__ import annotations
import json
import logging
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from ....core.config import set_suite_state
from ....core.ingest_settings import IngestSettings, settings_header_line
from ....core.pdf import PdfImage, render_pdf_with_images
from ....core.registry import RunContext
from .prompt import format_options
logger = logging.getLogger(__name__)
HF_REPO_ID = "TsinghuaC3I/MedXpertQA"
HF_REPO_TYPE = "dataset"
def _hf_hub_download(*args, **kwargs):
from huggingface_hub import hf_hub_download
return hf_hub_download(*args, **kwargs)
# ---------------------------------------------------------------------------
# Question shape
# ---------------------------------------------------------------------------
@dataclass
class MedXpertQuestion:
qid: str # e.g. "MM-26"
question: str # full question text (case + ask)
options: dict[str, str] # A-E
label: str # "A".."E"
image_files: list[str] # filenames inside images.zip
medical_task: str
body_system: str
question_type: str
split: str # "test" or "dev"
def _load_jsonl(path: Path, *, split: str) -> list[MedXpertQuestion]:
out: list[MedXpertQuestion] = []
with path.open("r", encoding="utf-8") as fh:
for raw_line in fh:
line = raw_line.strip()
if not line:
continue
row = json.loads(line)
qid = str(row.get("id") or "").strip()
question = str(row.get("question") or "").strip()
options = row.get("options") or {}
label = str(row.get("label") or "").strip().upper()
if not qid or not question or not isinstance(options, dict) or not label:
continue
opts = {str(k).strip().upper(): str(v).strip() for k, v in options.items()}
images = row.get("images") or []
if not isinstance(images, list):
images = []
out.append(MedXpertQuestion(
qid=qid,
question=question,
options=opts,
label=label,
image_files=[str(x).strip() for x in images if str(x).strip()],
medical_task=str(row.get("medical_task") or "").strip(),
body_system=str(row.get("body_system") or "").strip(),
question_type=str(row.get("question_type") or "").strip(),
split=split,
))
return out
# ---------------------------------------------------------------------------
# Image archive helpers
# ---------------------------------------------------------------------------
def _ensure_images_extracted(images_zip: Path, images_dir: Path) -> None:
"""Extract images.zip once, tolerantly handle re-runs."""
marker = images_dir / ".extracted_ok"
if marker.exists():
return
images_dir.mkdir(parents=True, exist_ok=True)
logger.info("Extracting MedXpertQA images.zip -> %s", images_dir)
with zipfile.ZipFile(images_zip) as zf:
zf.extractall(images_dir)
marker.write_text("ok\n", encoding="utf-8")
def _resolve_image_path(image_filename: str, images_dir: Path) -> Path | None:
"""Find a question's image in the (possibly nested) extract directory.
The zip layout sometimes nests under ``images/`` and sometimes
flat handle both.
"""
direct = images_dir / image_filename
if direct.exists():
return direct
nested = images_dir / "images" / image_filename
if nested.exists():
return nested
# Last-ditch: glob recursively (slow but correct for unusual layouts).
matches = list(images_dir.rglob(image_filename))
return matches[0] if matches else None
# ---------------------------------------------------------------------------
# PDF rendering
# ---------------------------------------------------------------------------
def _render_question_pdf(
q: MedXpertQuestion,
*,
images_dir: Path,
pdfs_dir: Path,
) -> tuple[Path, list[str]]:
"""Render one MedXpertQA question into a PDF.
Layout:
Title: MedXpertQA <qid> (medical_task / body_system)
Section 1 (case): full question text
Section 1 images: each image flowable + caption
Section 2 (options): A) ... B) ... C) ... D) ... E) ...
Returns (pdf_path, missing_images) so the caller can warn on
questions where some image files weren't found.
"""
out_path = pdfs_dir / f"{q.qid}.pdf"
images: list[PdfImage] = []
missing: list[str] = []
for fname in q.image_files:
resolved = _resolve_image_path(fname, images_dir)
if resolved is None:
missing.append(fname)
continue
images.append(PdfImage(path=resolved, caption=f"Image: {fname}", max_width_in=5.5))
title_meta_parts = []
if q.medical_task:
title_meta_parts.append(q.medical_task)
if q.body_system:
title_meta_parts.append(q.body_system)
if q.question_type:
title_meta_parts.append(q.question_type)
title_suffix = f" ({' / '.join(title_meta_parts)})" if title_meta_parts else ""
sections = [
("Clinical case", q.question, images),
("Answer choices", format_options(q.options), None),
]
render_pdf_with_images(
title=f"MedXpertQA-MM {q.qid}{title_suffix}",
sections=sections,
output_path=out_path,
)
return out_path, missing
# ---------------------------------------------------------------------------
# Upload helper
# ---------------------------------------------------------------------------
async def _upload_pdfs(
ctx: RunContext,
pdf_paths: Iterable[Path],
*,
batch_size: int,
settings: IngestSettings,
) -> dict[str, int]:
docs_client = ctx.documents_client()
name_to_id: dict[str, int] = {}
pdf_list = list(pdf_paths)
for batch_start in range(0, len(pdf_list), batch_size):
batch = pdf_list[batch_start:batch_start + batch_size]
result = await docs_client.upload(
files=batch,
search_space_id=ctx.search_space_id,
should_summarize=settings.should_summarize,
use_vision_llm=settings.use_vision_llm,
processing_mode=settings.processing_mode,
)
all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
if all_ids:
await docs_client.wait_until_ready(
search_space_id=ctx.search_space_id,
document_ids=result.document_ids,
timeout_s=1800.0,
)
statuses = await docs_client.get_status(
search_space_id=ctx.search_space_id,
document_ids=all_ids,
)
for s in statuses:
name_to_id[s.title] = s.document_id
logger.info(
"Uploaded MedXpertQA batch %d-%d: %d new, %d duplicate",
batch_start, batch_start + len(batch),
len(result.document_ids), len(result.duplicate_document_ids),
)
return name_to_id
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
async def run_ingest(
ctx: RunContext,
*,
split: str = "test",
max_questions: int | None = None,
upload_batch_size: int = 8,
skip_upload: bool = False,
include_dev: bool = False,
settings: IngestSettings | None = None,
) -> None:
"""Ingest MedXpertQA-MM into the medical suite.
Parameters
----------
split : 'test' (default), 'dev', or 'both'
Which subset to render + upload.
max_questions : int | None
Cap on number of questions ingested (handy for fast iteration).
upload_batch_size : int
PDFs per ``fileupload`` call.
skip_upload : bool
Render PDFs locally but don't push to SurfSense.
include_dev : bool
Convenience: equivalent to ``split='both'``.
"""
settings = settings or IngestSettings(use_vision_llm=True, processing_mode="basic")
bench_dir = ctx.benchmark_data_dir()
images_zip_local = bench_dir / "images.zip"
images_dir = bench_dir / "images"
pdfs_dir = bench_dir / "pdfs"
pdfs_dir.mkdir(parents=True, exist_ok=True)
hf_cache = bench_dir / ".hf_cache"
hf_cache.mkdir(parents=True, exist_ok=True)
# Step 1: download jsonl(s)
splits_to_load: list[str] = []
if split == "both" or include_dev:
splits_to_load = ["dev", "test"]
elif split in {"dev", "test"}:
splits_to_load = [split]
else:
raise ValueError(f"Unknown split {split!r}; use 'test' / 'dev' / 'both'")
questions: list[MedXpertQuestion] = []
for sp in splits_to_load:
rel = f"MM/{sp}.jsonl"
local = _hf_hub_download(
repo_id=HF_REPO_ID,
filename=rel,
repo_type=HF_REPO_TYPE,
cache_dir=str(hf_cache),
)
loaded = _load_jsonl(Path(local), split=sp)
questions.extend(loaded)
logger.info("Loaded %d MedXpertQA-MM questions from %s split", len(loaded), sp)
if max_questions is not None and max_questions > 0:
questions = questions[:max_questions]
if not questions:
raise RuntimeError("No MedXpertQA-MM questions loaded; check the split argument.")
# Step 2: download images.zip + extract once
if not images_zip_local.exists():
local_zip = _hf_hub_download(
repo_id=HF_REPO_ID,
filename="images.zip",
repo_type=HF_REPO_TYPE,
cache_dir=str(hf_cache),
)
# Materialise into bench_dir so the path is stable.
try:
from os import link as _link
_link(local_zip, images_zip_local)
except OSError:
from shutil import copy2
copy2(local_zip, images_zip_local)
_ensure_images_extracted(images_zip_local, images_dir)
# Step 3: render PDFs
pdf_paths: dict[str, Path] = {}
missing_image_count = 0
for i, q in enumerate(questions, start=1):
try:
pdf, missing = _render_question_pdf(q, images_dir=images_dir, pdfs_dir=pdfs_dir)
pdf_paths[q.qid] = pdf
if missing:
missing_image_count += len(missing)
logger.debug("qid=%s missing %d images: %s", q.qid, len(missing), missing)
except Exception as exc: # noqa: BLE001
logger.warning("Failed to render MedXpertQA PDF for %s: %s", q.qid, exc)
if i % 50 == 0:
logger.info(" ... rendered %d / %d PDFs", i, len(questions))
if missing_image_count:
logger.warning(
"MedXpertQA: %d image references could not be resolved on disk "
"(rendered PDFs may be missing some images).",
missing_image_count,
)
# Step 4: upload
name_to_id: dict[str, int] = {}
if skip_upload:
logger.info("MedXpertQA: --skip-upload set; skipping SurfSense ingestion")
else:
logger.info("MedXpertQA upload settings: %s", settings.render_label())
name_to_id = await _upload_pdfs(
ctx,
pdf_paths.values(),
batch_size=upload_batch_size,
settings=settings,
)
# Step 5: persist manifest + questions
questions_jsonl = bench_dir / "questions.jsonl"
with questions_jsonl.open("w", encoding="utf-8") as fh:
for q in questions:
fh.write(json.dumps({
"qid": q.qid,
"question": q.question,
"options": q.options,
"label": q.label,
"image_files": q.image_files,
"medical_task": q.medical_task,
"body_system": q.body_system,
"question_type": q.question_type,
"split": q.split,
}) + "\n")
logger.info("Wrote %d MedXpertQA questions to %s", len(questions), questions_jsonl)
map_path = ctx.maps_dir() / "medxpertqa_doc_map.jsonl"
with map_path.open("w", encoding="utf-8") as fh:
# Header line records the resolved ingest settings
# (see core/ingest_settings.py).
fh.write(settings_header_line(settings) + "\n")
for q in questions:
local = pdf_paths.get(q.qid)
if local is None:
continue
fh.write(json.dumps({
"qid": q.qid,
"document_id": name_to_id.get(local.name),
"pdf_path": str(local),
"n_images": len(q.image_files),
"split": q.split,
}) + "\n")
logger.info("Wrote MedXpertQA doc map to %s", map_path)
new_state = ctx.suite_state
new_state.ingestion_maps["medxpertqa"] = str(map_path)
set_suite_state(ctx.config, ctx.suite, new_state)
__all__ = ["MedXpertQuestion", "run_ingest"]

View file

@ -0,0 +1,54 @@
"""MedXpertQA-MM prompt.
Mirrors the upstream paper's evaluation prompt (Zuo et al., ICML 2025
§3.4): present case + 5 options A-E, ask for a single letter answer.
We also instruct the model to use the embedded images explicitly,
since the whole point of the MM subset is that the answer depends on
visual evidence (radiology / dermoscopy / pathology / ECG, etc.).
"""
from __future__ import annotations
from collections.abc import Mapping
ANSWER_LETTERS = ("A", "B", "C", "D", "E")
_PROMPT = """\
You are a board-certified physician. The following exam question
includes a clinical case and one or more medical images (radiology,
dermatology, pathology, ECG, etc.). Use BOTH the text and the images
to choose the best answer. Do not rely on memorisation of the case;
read the images carefully they often determine the correct answer.
Case + question:
{question}
Answer choices:
{options_block}
Respond on a single line in the format `Answer: X` where X is one of
A, B, C, D, or E.
"""
def format_options(options: Mapping[str, str]) -> str:
"""Render the ``A) ... E) ...`` options block."""
parts: list[str] = []
for letter in ANSWER_LETTERS:
text = options.get(letter)
if text is None or str(text).strip() == "":
continue
parts.append(f"{letter}) {str(text).strip()}")
return "\n".join(parts)
def build_prompt(question: str, options: Mapping[str, str]) -> str:
return _PROMPT.format(
question=question.strip(),
options_block=format_options(options),
)
__all__ = ["ANSWER_LETTERS", "build_prompt", "format_options"]

View file

@ -0,0 +1,681 @@
"""MedXpertQA-MM runner — Native PDF (vision) vs SurfSense (vision RAG).
Headline benchmark for the medical suite.
* Native arm reads the rendered PDF (case + images + options) via
OpenRouter ``chat/completions`` + the file-parser plugin.
* SurfSense arm queries ``POST /api/v1/new_chat`` scoped via
``mentioned_document_ids=[doc_id]`` to the same per-question PDF.
Operational notes:
* PDFs contain real images (radiology, dermoscopy, pathology, ECGs).
Operator must pin a vision-capable model via
``setup --provider-model anthropic/claude-sonnet-4.5`` (or similar);
the runner emits a warning if a known text-only slug is pinned.
* MedXpertQA tags ``medical_task`` (Diagnosis / Treatment / Basic
Medicine) and ``body_system`` (Cardiovascular / Lymphatic / )
directly on every row; we slice the report by both.
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import os
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from ....core.arms import ArmRequest, ArmResult, NativePdfArm, SurfSenseArm
from ....core.config import utc_iso_timestamp
from ....core.ingest_settings import (
IngestSettings,
add_ingest_settings_args,
format_ingest_settings_md,
is_settings_header,
)
from ....core.metrics.comparison import (
bootstrap_delta_ci,
mcnemar_test,
paired_aggregate,
)
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci
from ....core.providers.openrouter_pdf import OpenRouterPdfProvider, PdfEngine
from ....core.registry import (
ReportSection,
RunArtifact,
RunContext,
)
from ....core.scenarios import format_scenario_md
from .prompt import ANSWER_LETTERS, build_prompt
logger = logging.getLogger(__name__)
_TEXT_ONLY_HINTS = ("gpt-5.4-mini", "gpt-3.5", "text-only", "instruct-")
@dataclass
class MXQuestion:
qid: str
question: str
options: dict[str, str]
label: str
medical_task: str
body_system: str
question_type: str
split: str
n_images: int
pdf_path: Path
document_id: int | None
def _load_doc_map(map_path: Path) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
"""Read the doc map JSONL.
Returns ``(rows, settings)`` where ``settings`` is the
``__settings__`` header blob (or ``{}`` for legacy maps).
"""
rows: dict[str, dict[str, Any]] = {}
settings: dict[str, Any] = {}
with map_path.open("r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line:
continue
row = json.loads(line)
if is_settings_header(row):
settings = dict(row["__settings__"])
continue
rows[str(row["qid"])] = row
return rows, settings
def _load_questions(
questions_jsonl: Path,
doc_map: dict[str, dict[str, Any]],
*,
split_filter: str | None,
task_filter: str | None,
body_filter: str | None,
require_images: bool,
sample_n: int | None,
) -> list[MXQuestion]:
out: list[MXQuestion] = []
with questions_jsonl.open("r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line:
continue
row = json.loads(line)
qid = str(row.get("qid") or "").strip()
if not qid:
continue
if split_filter and split_filter != "all" and row.get("split") != split_filter:
continue
if task_filter and task_filter != "all" and row.get("medical_task") != task_filter:
continue
if body_filter and body_filter != "all" and row.get("body_system") != body_filter:
continue
map_row = doc_map.get(qid)
if map_row is None:
logger.debug("No doc-map entry for %s; skipping", qid)
continue
n_images = int(map_row.get("n_images", 0))
if require_images and n_images <= 0:
continue
out.append(MXQuestion(
qid=qid,
question=str(row.get("question") or ""),
options={str(k).upper(): str(v) for k, v in (row.get("options") or {}).items()},
label=str(row.get("label") or "").strip().upper(),
medical_task=str(row.get("medical_task") or "").strip(),
body_system=str(row.get("body_system") or "").strip(),
question_type=str(row.get("question_type") or "").strip(),
split=str(row.get("split") or ""),
n_images=n_images,
pdf_path=Path(map_row["pdf_path"]),
document_id=map_row.get("document_id"),
))
out.sort(key=lambda q: (q.split, q.qid))
if sample_n is not None and sample_n > 0:
out = out[:sample_n]
return out
async def _gather_with_limit(coros: Iterable, *, concurrency: int) -> list[Any]:
sem = asyncio.Semaphore(max(1, concurrency))
async def _wrap(coro):
async with sem:
return await coro
return await asyncio.gather(*(_wrap(c) for c in coros))
_DESCRIPTION = (
"MedXpertQA-MM (~2,000 multimodal medical exam questions, 5 options, with images) — "
"Native PDF (vision) vs SurfSense (vision RAG) head-to-head."
)
# MedXpertQA-MM PDFs embed clinical images; vision LLM at ingest is
# the whole point. Operators can flip ``--no-vision-llm`` to measure
# how much we degrade without it (likely material).
_DEFAULT_INGEST_SETTINGS = IngestSettings(
use_vision_llm=True,
processing_mode="basic",
should_summarize=False,
)
class MedXpertQAMMBenchmark:
"""Multimodal medical exam head-to-head."""
suite: str = "medical"
name: str = "medxpertqa"
headline: bool = True # The medical suite headline.
description: str = _DESCRIPTION
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--split", default="test", choices=["test", "dev", "all"],
help="Which MedXpertQA-MM split to run (default: test).",
)
parser.add_argument(
"--task", default="all",
help="Filter by medical_task value (e.g. Diagnosis, Treatment, Basic Medicine).",
)
parser.add_argument(
"--body-system", dest="body_filter", default="all",
help="Filter by body_system value (e.g. Cardiovascular, Lymphatic).",
)
parser.add_argument(
"--require-images", dest="require_images", action="store_true",
help="Skip rare MM rows that ended up with zero resolvable images.",
)
parser.add_argument("--n", dest="sample_n", type=int, default=None,
help="Run only the first N questions after filters apply.")
parser.add_argument("--concurrency", type=int, default=4,
help="Parallel question workers per arm.")
parser.add_argument("--no-mentions", dest="no_mentions", action="store_true",
help="SurfSense arm: skip mentioned_document_ids (unscoped retrieval).")
parser.add_argument(
"--pdf-engine", default="native",
choices=[e.value for e in PdfEngine],
help="OpenRouter file-parser engine for the native arm.",
)
parser.add_argument(
"--max-output-tokens", type=int, default=512,
help="Cap on completion length for both arms.",
)
# Ingest-only knobs (forwarded by the CLI to ingest.run_ingest).
parser.add_argument(
"--max-questions", dest="max_questions", type=int, default=None,
help="(ingest only) cap on number of MM questions to render + upload.",
)
parser.add_argument(
"--upload-batch-size", dest="upload_batch_size", type=int, default=8,
help="(ingest only) PDFs per fileupload call.",
)
parser.add_argument(
"--skip-upload", dest="skip_upload", action="store_true",
help="(ingest only) render PDFs locally but don't push to SurfSense.",
)
parser.add_argument(
"--include-dev", dest="include_dev", action="store_true",
help="(ingest only) shorthand for --split all.",
)
# Per-upload knobs forwarded to /documents/fileupload at ingest;
# ignored at run-time (runner reads the resolved settings out of
# the doc-map manifest header).
add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
from .ingest import run_ingest
settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
await run_ingest(
ctx,
split=opts.get("split") or "test",
max_questions=opts.get("max_questions"),
upload_batch_size=int(opts.get("upload_batch_size") or 8),
skip_upload=bool(opts.get("skip_upload", False)),
include_dev=bool(opts.get("include_dev", False)),
settings=settings,
)
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
split_filter = opts.get("split") or "test"
task_filter = opts.get("task") or "all"
body_filter = opts.get("body_filter") or "all"
require_images = bool(opts.get("require_images"))
sample_n = opts.get("sample_n")
concurrency = int(opts.get("concurrency") or 4)
no_mentions = bool(opts.get("no_mentions"))
pdf_engine_name = opts.get("pdf_engine") or "native"
max_output_tokens = int(opts.get("max_output_tokens") or 512)
bench_dir = ctx.benchmark_data_dir()
questions_jsonl = bench_dir / "questions.jsonl"
map_path = ctx.maps_dir() / "medxpertqa_doc_map.jsonl"
if not questions_jsonl.exists() or not map_path.exists():
raise RuntimeError(
"MedXpertQA-MM not ingested for this suite. Run "
"`python -m surfsense_evals ingest medical medxpertqa` first."
)
doc_map, ingest_settings = _load_doc_map(map_path)
questions = _load_questions(
questions_jsonl, doc_map,
split_filter=split_filter,
task_filter=task_filter if task_filter != "all" else None,
body_filter=body_filter if body_filter != "all" else None,
require_images=require_images,
sample_n=sample_n,
)
if not questions:
raise RuntimeError(
"No MedXpertQA-MM questions matched the filters; broaden --split/--task/--body-system/--n."
)
logger.info("MedXpertQA-MM: scheduled %d questions", len(questions))
api_key = os.environ.get("OPENROUTER_API_KEY")
if not api_key:
raise RuntimeError("OPENROUTER_API_KEY env var is required for the native arm.")
# Native arm slug differs from SurfSense slug only in cost-arbitrage
# scenario; otherwise both arms answer with provider_model.
native_arm_model = ctx.native_arm_model
if any(hint in native_arm_model.lower() for hint in _TEXT_ONLY_HINTS):
if ctx.scenario == "symmetric-cheap":
logger.info(
"symmetric-cheap: native arm pinned to text-only %r as "
"intended; expect it to lose on image-bearing questions "
"(SurfSense answers from vision-extracted chunks).",
native_arm_model,
)
else:
logger.warning(
"Native arm slug %r looks text-only; image content in "
"MedXpertQA PDFs will be ignored. Re-pin via "
"`setup --provider-model anthropic/claude-sonnet-4.5` "
"(or pass --native-arm-model and --scenario cost-arbitrage "
"to make this asymmetry explicit).",
native_arm_model,
)
provider = OpenRouterPdfProvider(
api_key=api_key,
base_url=ctx.config.openrouter_base_url,
model=native_arm_model,
engine=PdfEngine(pdf_engine_name),
)
native_arm = NativePdfArm(provider=provider, max_output_tokens=max_output_tokens)
surf_arm = SurfSenseArm(
client=ctx.new_chat_client(),
search_space_id=ctx.search_space_id,
ephemeral_threads=True,
)
run_timestamp = utc_iso_timestamp()
run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
raw_path = run_dir / "raw.jsonl"
async def _native_one(q: MXQuestion) -> ArmResult:
return await native_arm.answer(_make_native_request(q, max_output_tokens))
async def _surf_one(q: MXQuestion) -> ArmResult:
return await surf_arm.answer(_make_surfsense_request(q, no_mentions=no_mentions))
native_results, surf_results = await asyncio.gather(
_gather_with_limit((_native_one(q) for q in questions), concurrency=concurrency),
_gather_with_limit((_surf_one(q) for q in questions), concurrency=concurrency),
)
with raw_path.open("w", encoding="utf-8") as fh:
for q, n_res, s_res in zip(questions, native_results, surf_results, strict=False):
meta = {
"qid": q.qid,
"split": q.split,
"medical_task": q.medical_task,
"body_system": q.body_system,
"question_type": q.question_type,
"n_images": q.n_images,
"correct": q.label,
"document_id": q.document_id,
}
fh.write(json.dumps({**meta, **n_res.to_jsonl()}) + "\n")
fh.write(json.dumps({**meta, **s_res.to_jsonl()}) + "\n")
metrics = _compute_metrics(questions, native_results, surf_results)
artifact = RunArtifact(
suite=self.suite,
benchmark=self.name,
run_timestamp=run_timestamp,
raw_path=raw_path,
metrics=metrics,
extra={
"n_questions": len(questions),
"concurrency": concurrency,
"split_filter": split_filter,
"task_filter": task_filter,
"body_filter": body_filter,
"require_images": require_images,
"no_mentions": no_mentions,
"pdf_engine": pdf_engine_name,
"scenario": ctx.scenario,
"provider_model": ctx.provider_model,
"native_arm_model": native_arm_model,
"vision_provider_model": ctx.vision_provider_model,
"agent_llm_id": ctx.agent_llm_id,
"ingest_settings": ingest_settings,
},
)
manifest_path = run_dir / "run_artifact.json"
manifest_path.write_text(
json.dumps({
"suite": self.suite,
"benchmark": self.name,
"raw_path": "raw.jsonl",
"metrics": metrics,
"extra": artifact.extra,
}, indent=2, sort_keys=True) + "\n",
encoding="utf-8",
)
return artifact
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
if not artifacts:
return ReportSection(
title="MedXpertQA-MM — Native PDF (vision) vs SurfSense (vision RAG)",
headline=False,
body_md="(no run artifacts found)",
body_json={},
)
latest = max(artifacts, key=lambda a: a.run_timestamp)
m = latest.metrics
native = m.get("native", {})
surf = m.get("surfsense", {})
delta = m.get("delta", {})
per_task = m.get("per_task", {})
per_body = m.get("per_body_system", {})
extra = latest.extra
body_lines: list[str] = []
body_lines.append(
f"- Sample size: {extra.get('n_questions', '?')} questions "
f"(split: `{extra.get('split_filter', 'test')}`, "
f"task: `{extra.get('task_filter', 'all')}`, "
f"body: `{extra.get('body_filter', 'all')}`, "
f"engine: `{extra.get('pdf_engine', 'native')}`)."
)
body_lines.append(format_scenario_md(extra))
body_lines.append(format_ingest_settings_md(extra.get("ingest_settings")))
body_lines.append(
"- Native arm (OpenRouter `chat/completions` + file plugin, "
f"`{extra.get('native_arm_model') or extra.get('provider_model', '?')}`):"
)
body_lines.append(_arm_summary_lines(native, indent=" "))
body_lines.append(
"- SurfSense arm (`POST /api/v1/new_chat`, vision RAG over chunks, "
f"`{extra.get('provider_model', '?')}`):"
)
body_lines.append(_arm_summary_lines(surf, indent=" "))
body_lines.append("- Delta (paired):")
body_lines.append(
f" - Accuracy: SurfSense {_pp(delta.get('accuracy_pp'))} pp "
f"(McNemar p={_fmt(delta.get('mcnemar_p_value'), 4)}, "
f"method={delta.get('mcnemar_method')})"
)
body_lines.append(
f" - Bootstrap 95% CI on delta: "
f"[{_pp(delta.get('bootstrap_ci_low'))}pp, {_pp(delta.get('bootstrap_ci_high'))}pp]"
)
body_lines.append(
f" - Cost / question: native ${_dollars(native.get('cost_micros_mean'))}, "
f"surfsense ${_dollars(surf.get('cost_micros_mean'))} "
f"(SurfSense delta {_pct_change(delta.get('cost_micros_pct'))})"
)
body_lines.append(
f" - Latency p50: native {_ms_to_s(native.get('latency_ms_median'))}, "
f"surfsense {_ms_to_s(surf.get('latency_ms_median'))} "
f"(SurfSense delta {_pct_change(delta.get('latency_ms_pct'))})"
)
if per_task:
body_lines.append("- Per-medical_task split:")
for task_name, vals in sorted(per_task.items()):
body_lines.append(
f" - {task_name}: SurfSense {_pp(vals.get('delta_accuracy_pp'))} pp "
f"(n={vals.get('n')})"
)
if per_body:
body_lines.append("- Per-body_system split (top 5 by sample size):")
top = sorted(per_body.items(), key=lambda kv: -kv[1].get("n", 0))[:5]
for body_name, vals in top:
body_lines.append(
f" - {body_name}: SurfSense {_pp(vals.get('delta_accuracy_pp'))} pp "
f"(n={vals.get('n')})"
)
return ReportSection(
title="MedXpertQA-MM — Native PDF (vision) vs SurfSense (vision RAG)",
headline=False,
body_md="\n".join(body_lines),
body_json=m,
)
# ---------------------------------------------------------------------------
# Per-question helpers
# ---------------------------------------------------------------------------
def _make_native_request(q: MXQuestion, max_tokens: int) -> ArmRequest:
prompt = build_prompt(q.question, q.options)
return ArmRequest(
question_id=q.qid,
prompt=prompt,
pdf_paths=[q.pdf_path],
options={"max_tokens": max_tokens},
)
def _make_surfsense_request(q: MXQuestion, *, no_mentions: bool) -> ArmRequest:
prompt = build_prompt(q.question, q.options)
mentions: list[int] | None = None
if not no_mentions and q.document_id is not None:
mentions = [int(q.document_id)]
return ArmRequest(
question_id=q.qid,
prompt=prompt,
mentioned_document_ids=mentions,
)
# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------
def _compute_metrics(
questions: list[MXQuestion],
native_results: list[ArmResult],
surf_results: list[ArmResult],
) -> dict[str, Any]:
native_correct: list[bool] = []
surf_correct: list[bool] = []
for q, n_res, s_res in zip(questions, native_results, surf_results, strict=False):
gold = q.label
n_ok = (n_res.answer_letter or "").upper() == gold and gold in ANSWER_LETTERS
s_ok = (s_res.answer_letter or "").upper() == gold and gold in ANSWER_LETTERS
native_correct.append(n_ok)
surf_correct.append(s_ok)
native_costs = [float(r.cost_micros) for r in native_results]
surf_costs = [float(r.cost_micros) for r in surf_results]
native_lats = [float(r.latency_ms) for r in native_results]
surf_lats = [float(r.latency_ms) for r in surf_results]
native_in = [float(r.input_tokens) for r in native_results]
native_out = [float(r.output_tokens) for r in native_results]
native_acc = accuracy_with_wilson_ci(sum(native_correct), len(native_correct))
surf_acc = accuracy_with_wilson_ci(sum(surf_correct), len(surf_correct))
mc = mcnemar_test(native_correct, surf_correct)
boot = bootstrap_delta_ci(native_correct, surf_correct, n_resamples=2000)
native_cost_agg = paired_aggregate(native_costs)
surf_cost_agg = paired_aggregate(surf_costs)
native_lat_agg = paired_aggregate(native_lats)
surf_lat_agg = paired_aggregate(surf_lats)
cost_pct = _safe_pct(surf_cost_agg.mean, native_cost_agg.mean)
lat_pct = _safe_pct(surf_lat_agg.median, native_lat_agg.median)
per_task = _per_field(questions, native_correct, surf_correct, key=lambda q: q.medical_task or "unknown")
per_body = _per_field(questions, native_correct, surf_correct, key=lambda q: q.body_system or "unknown")
return {
"native": {
**native_acc.to_dict(),
"cost_micros_mean": native_cost_agg.mean,
"cost_micros_median": native_cost_agg.median,
"latency_ms_mean": native_lat_agg.mean,
"latency_ms_median": native_lat_agg.median,
"latency_ms_p95": native_lat_agg.p95,
"input_tokens_mean": (sum(native_in) / len(native_in)) if native_in else 0.0,
"output_tokens_mean": (sum(native_out) / len(native_out)) if native_out else 0.0,
},
"surfsense": {
**surf_acc.to_dict(),
"cost_micros_mean": surf_cost_agg.mean,
"cost_micros_median": surf_cost_agg.median,
"latency_ms_mean": surf_lat_agg.mean,
"latency_ms_median": surf_lat_agg.median,
"latency_ms_p95": surf_lat_agg.p95,
},
"delta": {
"accuracy_pp": 100.0 * (surf_acc.accuracy - native_acc.accuracy),
"mcnemar_p_value": mc.p_value,
"mcnemar_method": mc.method,
"mcnemar_b_native_only": mc.b,
"mcnemar_c_surfsense_only": mc.c,
"bootstrap_ci_low": 100.0 * boot.ci_low,
"bootstrap_ci_high": 100.0 * boot.ci_high,
"cost_micros_pct": cost_pct,
"latency_ms_pct": lat_pct,
},
"per_task": per_task,
"per_body_system": per_body,
}
def _per_field(
questions: list[MXQuestion],
native_correct: list[bool],
surf_correct: list[bool],
*,
key,
) -> dict[str, dict[str, Any]]:
bucket: dict[str, list[tuple[bool, bool]]] = {}
for q, n_ok, s_ok in zip(questions, native_correct, surf_correct, strict=False):
bucket.setdefault(key(q), []).append((n_ok, s_ok))
out: dict[str, dict[str, Any]] = {}
for k, pairs in bucket.items():
n_correct = [a for a, _ in pairs]
s_correct = [b for _, b in pairs]
out[k] = {
"n": len(pairs),
"native_accuracy": (sum(n_correct) / len(pairs)) if pairs else 0.0,
"surfsense_accuracy": (sum(s_correct) / len(pairs)) if pairs else 0.0,
"delta_accuracy_pp": (
100.0 * (sum(s_correct) - sum(n_correct)) / len(pairs)
if pairs else 0.0
),
}
return out
def _safe_pct(numerator: float, denominator: float) -> float | None:
if denominator == 0:
return None
return 100.0 * (numerator - denominator) / denominator
# ---------------------------------------------------------------------------
# Formatters
# ---------------------------------------------------------------------------
def _arm_summary_lines(d: dict[str, Any], *, indent: str) -> str:
if not d:
return f"{indent}(no data)"
acc = d.get("accuracy", 0.0)
low = d.get("ci_low", 0.0)
high = d.get("ci_high", 0.0)
lines = [
f"{indent}- Accuracy: {acc * 100:.1f}% (Wilson 95% CI: {low * 100:.1f}% {high * 100:.1f}%)",
f"{indent}- Cost / question: ${_dollars(d.get('cost_micros_mean'))} (mean), "
f"${_dollars(d.get('cost_micros_median'))} (median)",
f"{indent}- Latency: p50 {_ms_to_s(d.get('latency_ms_median'))}, "
f"p95 {_ms_to_s(d.get('latency_ms_p95'))}",
]
if "input_tokens_mean" in d:
lines.append(
f"{indent}- Mean tokens / question: in {d.get('input_tokens_mean', 0):.0f}, "
f"out {d.get('output_tokens_mean', 0):.0f}"
)
return "\n".join(lines)
def _dollars(micros: Any) -> str:
if micros is None:
return "?"
try:
return f"{(float(micros) / 1_000_000):.4f}"
except (TypeError, ValueError):
return "?"
def _ms_to_s(ms: Any) -> str:
if ms is None:
return "?"
try:
return f"{float(ms) / 1000:.1f}s"
except (TypeError, ValueError):
return "?"
def _pp(value: Any) -> str:
if value is None:
return "?"
try:
return f"{float(value):+.1f}"
except (TypeError, ValueError):
return "?"
def _pct_change(value: Any) -> str:
if value is None:
return "?"
try:
return f"{float(value):+.0f}%"
except (TypeError, ValueError):
return "?"
def _fmt(value: Any, ndigits: int) -> str:
if value is None:
return "?"
try:
return f"{float(value):.{ndigits}f}"
except (TypeError, ValueError):
return "?"
__all__ = ["MedXpertQAMMBenchmark", "MXQuestion"]

View file

@ -0,0 +1,17 @@
"""MIRAGE — secondary single-arm SurfSense MCQ measurement.
Source: https://github.com/Teddy-XiongGZ/MIRAGE, paper
https://aclanthology.org/2024.findings-acl.372/. 7,663 questions
across MMLU-Med, MedQA-US, MedMCQA, PubMedQA*, BioASQ-Y/N.
This is a SurfSense-only measurement (not a head-to-head); native
PDF-in-LLM doesn't apply because there is no per-question discrete
document the corpus is millions of biomedical snippets.
"""
from __future__ import annotations
from .runner import MirageBenchmark
from ....core import registry as _registry
_registry.register(MirageBenchmark())

View file

@ -0,0 +1,548 @@
"""MIRAGE ingestion.
Downloads:
* ``benchmark.json`` ( 4 MB; questions for the 5 sub-tasks).
* ``retrieved_snippets_10k.zip`` (the union of top-10k snippet ids
retrieved by every retriever in the MedRAG paper, per task a
recall ceiling that avoids needing the full 23.9M-doc PubMed mirror).
Snippet *content* lives in the MedRAG HF mirrors
(``MedRAG/textbooks``, ``MedRAG/pubmed``, ``MedRAG/statpearls``,
``MedRAG/wikipedia``). We default to ``MedRAG/textbooks`` (212 MB,
125k snippets) which is the smallest and covers the majority of
``MedQA-US`` and the medical examination subsets. Operators can
opt into larger corpora with ``--corpus``.
Each snippet is written as one markdown file then batched into
``~5 MB`` markdown bundles for SurfSense's file upload (smaller
than backend default ``MAX_FILE_SIZE_BYTES`` and avoids the per-call
overhead of one HTTP request per snippet).
The ingestion produces two maps under ``data/medical/maps/``:
* ``mirage_snippet_map.jsonl`` ``{snippet_id, document_id, batch_path}``
* ``mirage_chunk_map.jsonl`` ``{chunk_id, document_id, snippet_id?}``
(best-effort; chunk text is heuristically attributed to the
snippet it overlaps when the SurfSense chunker splits a batched
markdown).
"""
from __future__ import annotations
import asyncio
import io
import json
import logging
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
import httpx
from ....core.config import set_suite_state
from ....core.ingest_settings import IngestSettings, settings_header_line
from ....core.registry import RunContext
logger = logging.getLogger(__name__)
MIRAGE_BENCHMARK_URL = (
"https://raw.githubusercontent.com/Teddy-XiongGZ/MIRAGE/main/benchmark.json"
)
# Upstream only ships ONE zip — top-10k retrievals across 5 retrievers,
# ~16 GB. We default to skipping it (see `--skip-snippet-filter`) and
# ingesting the chosen corpus in full; this URL is only fetched when
# the operator explicitly opts in.
MIRAGE_SNIPPETS_ZIP_URL = (
"https://virginia.box.com/shared/static/cxq17th6eisl2pn04vp0x723zczlvlzc.zip"
)
_DEFAULT_CORPUS = "MedRAG/textbooks"
_BATCH_SIZE_BYTES = 5 * 1024 * 1024
# 2 GB safety cap. Anything larger requires --allow-large-download.
# Set high enough that ``benchmark.json`` and small zips pass through
# untouched but the 16 GB MIRAGE retrievals zip trips the guard.
_LARGE_DOWNLOAD_BYTES = 2 * 1024 * 1024 * 1024
_DOWNLOAD_RETRIES = 5
_RETRYABLE_NET_EXC: tuple[type[BaseException], ...] = (
httpx.RemoteProtocolError,
httpx.ReadError,
httpx.ReadTimeout,
httpx.ConnectError,
httpx.ConnectTimeout,
)
@dataclass
class SnippetRow:
snippet_id: str
title: str
content: str
def to_markdown(self) -> str:
title = (self.title or "").strip() or "Untitled"
body = (self.content or "").strip()
return f"# {title}\n\n_id: `{self.snippet_id}`_\n\n{body}\n"
# ---------------------------------------------------------------------------
# Download helpers
# ---------------------------------------------------------------------------
async def _fetch_to_path(
url: str,
*,
dest: Path,
label: str,
timeout_s: float = 600.0,
allow_large_download: bool = False,
expect_zip: bool = False,
) -> Path:
"""Download ``url`` to ``dest`` with retry, atomic-rename, and
HTTP ``Range`` resume.
Operational properties:
* If ``dest`` already exists *and* (when ``expect_zip`` is True) the
cached file is a valid ZIP, returns it immediately. A corrupt ZIP
is removed and re-downloaded this is the safety net for the
`box.com truncated 16 GB zip` failure mode where the previous
run wrote a half-completed file then exited with an exception.
* Bytes are written to ``<dest>.partial`` and renamed only after the
stream completes cleanly (and, for zips, only after a quick
central-directory check). A failure mid-download leaves the
``.partial`` file in place so the next attempt can resume from
where it stopped via an HTTP ``Range`` header.
* Retries on transient network errors (``RemoteProtocolError``,
``ReadError``, ``ReadTimeout``, ``ConnectError``,
``ConnectTimeout``) with exponential backoff, up to
``_DOWNLOAD_RETRIES``.
* Aborts before downloading if the ``Content-Length`` (or already-
downloaded ``.partial`` size) is over ``_LARGE_DOWNLOAD_BYTES``
and ``allow_large_download`` is False, to keep an operator from
surprise-grabbing 16 GB on a slow link.
"""
if dest.exists():
if expect_zip and not _is_valid_zip(dest):
logger.warning(
"Cached %s at %s failed ZIP validation (size=%d B); deleting "
"and re-downloading.",
label,
dest,
dest.stat().st_size,
)
dest.unlink(missing_ok=True)
else:
logger.info("Using cached %s at %s", label, dest)
return dest
dest.parent.mkdir(parents=True, exist_ok=True)
partial = dest.with_suffix(dest.suffix + ".partial")
last_exc: BaseException | None = None
for attempt in range(1, _DOWNLOAD_RETRIES + 1):
existing_bytes = partial.stat().st_size if partial.exists() else 0
headers: dict[str, str] = {}
if existing_bytes:
headers["Range"] = f"bytes={existing_bytes}-"
logger.info(
"Resuming %s from byte %d (attempt %d/%d)",
label,
existing_bytes,
attempt,
_DOWNLOAD_RETRIES,
)
else:
logger.info(
"Downloading %s from %s (attempt %d/%d)",
label,
url,
attempt,
_DOWNLOAD_RETRIES,
)
try:
async with httpx.AsyncClient(
timeout=httpx.Timeout(timeout_s, connect=20.0),
follow_redirects=True,
) as client:
async with client.stream("GET", url, headers=headers) as response:
if existing_bytes and response.status_code == 200:
logger.warning(
"Server ignored Range header for %s; restarting from 0.",
label,
)
partial.unlink(missing_ok=True)
existing_bytes = 0
elif response.status_code == 416:
# Range not satisfiable — the .partial is at or
# past the end. Treat as "already downloaded";
# validate by closing and re-opening for atomic
# rename below.
logger.info(
"Server reports %s already complete (HTTP 416).",
label,
)
elif response.status_code not in (200, 206):
response.raise_for_status()
total_size = _planned_total_size(response, existing_bytes)
if (
total_size is not None
and total_size > _LARGE_DOWNLOAD_BYTES
and not allow_large_download
):
raise _LargeDownloadAbort(label, total_size)
mode = "ab" if existing_bytes else "wb"
with partial.open(mode) as fh:
async for chunk in response.aiter_bytes(chunk_size=1 << 18):
fh.write(chunk)
# Optional content sanity check before promoting to dest.
if expect_zip and not _is_valid_zip(partial):
raise zipfile.BadZipFile(
f"{label} downloaded to {partial} but failed central-"
"directory check; will retry."
)
partial.replace(dest)
return dest
except _LargeDownloadAbort:
raise
except _RETRYABLE_NET_EXC as exc:
last_exc = exc
wait = min(60.0, 2.0 ** attempt)
logger.warning(
"Network error fetching %s (%s: %s); retrying in %.0fs.",
label,
type(exc).__name__,
exc,
wait,
)
await asyncio.sleep(wait)
except zipfile.BadZipFile as exc:
last_exc = exc
# Truncated body — drop the partial and retry from scratch.
partial.unlink(missing_ok=True)
wait = min(60.0, 2.0 ** attempt)
logger.warning(
"Truncated ZIP for %s; restarting from byte 0 in %.0fs.",
label,
wait,
)
await asyncio.sleep(wait)
raise RuntimeError(
f"Failed to download {label} after {_DOWNLOAD_RETRIES} attempts: {last_exc!s}"
)
def _planned_total_size(response: httpx.Response, existing_bytes: int) -> int | None:
"""Best-effort total size including any already-buffered .partial bytes."""
cl = response.headers.get("Content-Length")
if not cl:
return None
try:
remaining = int(cl)
except ValueError:
return None
return existing_bytes + remaining
def _is_valid_zip(path: Path) -> bool:
"""Cheap ZIP validity check via central-directory parse."""
try:
with zipfile.ZipFile(path) as zf:
# ``namelist`` forces the central directory to be parsed.
zf.namelist()
return True
except (zipfile.BadZipFile, OSError):
return False
class _LargeDownloadAbort(RuntimeError):
"""Raised when a download exceeds the safety threshold without opt-in."""
def __init__(self, label: str, size_bytes: int) -> None:
gb = size_bytes / (1024 ** 3)
super().__init__(
f"{label} would download ~{gb:.1f} GB, above the {_LARGE_DOWNLOAD_BYTES / (1024 ** 3):.0f} GB safety cap. "
"Re-run with `--allow-large-download` to acknowledge, or use "
"`--skip-snippet-filter` to bypass this download entirely and "
"ingest the full corpus instead."
)
def _read_snippet_ids(zip_path: Path, *, tasks: list[str]) -> dict[str, set[str]]:
"""Walk the ZIP for files whose path contains any task name.
Each MedRAG retriever produces one JSON file per task in the zip;
we union all retrievers' top-K ids. The exact directory layout has
historically been ``<retriever>/<task>.json`` mapping
``question_id -> [snippet_id, ...]``.
"""
out: dict[str, set[str]] = {t: set() for t in tasks}
with zipfile.ZipFile(zip_path, "r") as zf:
for member in zf.namelist():
if not member.lower().endswith(".json"):
continue
stem = Path(member).stem.lower()
for task in tasks:
if task.lower() in stem:
try:
with zf.open(member) as fh:
payload = json.loads(fh.read().decode("utf-8"))
except (json.JSONDecodeError, KeyError):
continue
for ids in payload.values():
if isinstance(ids, list):
for sid in ids:
if isinstance(sid, str):
out[task].add(sid)
elif isinstance(sid, dict) and "id" in sid:
out[task].add(str(sid["id"]))
break
return out
def _load_corpus(
corpus_name: str, snippet_ids: set[str] | None
) -> Iterable[SnippetRow]:
"""Stream rows from a MedRAG HF corpus.
* ``snippet_ids=None`` yield every row (full-corpus ingestion path).
* ``snippet_ids={...}`` filter to the requested ids.
Imported lazily ``datasets`` is a heavyweight dep.
"""
if snippet_ids is not None and not snippet_ids:
return iter(())
from datasets import load_dataset # noqa: PLC0415
logger.info("Loading corpus %s (this may take a while)", corpus_name)
ds = load_dataset(corpus_name, split="train", streaming=True)
for row in ds:
sid = str(row.get("id") or "")
if snippet_ids is not None and sid not in snippet_ids:
continue
yield SnippetRow(
snippet_id=sid,
title=str(row.get("title") or ""),
content=str(row.get("content") or row.get("contents") or ""),
)
# ---------------------------------------------------------------------------
# Batching + upload
# ---------------------------------------------------------------------------
@dataclass
class SnippetBatch:
path: Path
snippet_ids: list[str]
def _write_batches(
snippets: Iterable[SnippetRow],
*,
out_dir: Path,
batch_bytes: int = _BATCH_SIZE_BYTES,
prefix: str = "mirage",
) -> list[SnippetBatch]:
out_dir.mkdir(parents=True, exist_ok=True)
batches: list[SnippetBatch] = []
current_buffer = io.StringIO()
current_ids: list[str] = []
current_bytes = 0
batch_idx = 0
def _flush() -> None:
nonlocal current_buffer, current_ids, current_bytes, batch_idx
if not current_ids:
return
path = out_dir / f"{prefix}_{batch_idx:04d}.md"
path.write_text(current_buffer.getvalue(), encoding="utf-8")
batches.append(SnippetBatch(path=path, snippet_ids=current_ids))
batch_idx += 1
current_buffer = io.StringIO()
current_ids = []
current_bytes = 0
for snippet in snippets:
chunk = snippet.to_markdown() + "\n---\n\n"
chunk_bytes = len(chunk.encode("utf-8"))
if current_bytes + chunk_bytes > batch_bytes and current_ids:
_flush()
current_buffer.write(chunk)
current_ids.append(snippet.snippet_id)
current_bytes += chunk_bytes
_flush()
return batches
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
async def run_ingest(
ctx: RunContext,
*,
tasks: list[str] | None = None,
corpus: str = _DEFAULT_CORPUS,
max_snippets_per_task: int | None = None,
skip_snippet_filter: bool = True,
allow_large_download: bool = False,
settings: IngestSettings | None = None,
) -> None:
"""Ingest a MedRAG corpus into the suite SearchSpace.
By default (``skip_snippet_filter=True``) we ingest the **entire**
chosen corpus and let SurfSense's own retriever do the work. The
upstream MIRAGE retrieval zip is ~16 GB and only useful when you
want to pre-filter the corpus to the set of snippets some other
retriever surfaced; for ``MedRAG/textbooks`` (212 MB / 125k snippets)
that pre-filter is unnecessary overhead and routinely fails to
download (box.com truncates the stream). Set
``skip_snippet_filter=False`` (CLI: ``--use-snippet-filter``) only
if you specifically want the upstream filter and budget the
16 GB zip transfer.
"""
tasks = tasks or ["mmlu", "medqa", "medmcqa", "pubmedqa", "bioasq"]
settings = settings or IngestSettings(use_vision_llm=False, processing_mode="basic")
bench_path = ctx.benchmark_data_dir() / "benchmark.json"
await _fetch_to_path(MIRAGE_BENCHMARK_URL, dest=bench_path, label="MIRAGE benchmark.json")
if skip_snippet_filter:
logger.info(
"Skipping retrieved_snippets_10k.zip (skip_snippet_filter=True); "
"ingesting entire corpus %s.",
corpus,
)
snippets = list(_load_corpus(corpus, snippet_ids=None))
else:
zip_path = ctx.benchmark_data_dir() / "retrieved_snippets_10k.zip"
await _fetch_to_path(
MIRAGE_SNIPPETS_ZIP_URL,
dest=zip_path,
label="MIRAGE retrieved_snippets_10k.zip",
allow_large_download=allow_large_download,
expect_zip=True,
)
by_task = _read_snippet_ids(zip_path, tasks=tasks)
if max_snippets_per_task is not None:
by_task = {k: set(list(v)[:max_snippets_per_task]) for k, v in by_task.items()}
union_ids = set().union(*by_task.values())
logger.info(
"MIRAGE: tasks=%s, snippet ids per task: %s, union=%d",
tasks,
{k: len(v) for k, v in by_task.items()},
len(union_ids),
)
if not union_ids:
raise RuntimeError(
f"No snippet ids parsed for tasks {tasks!r} from {zip_path}. "
"Check the zip layout (the upstream archive may have changed)."
)
snippets = list(_load_corpus(corpus, snippet_ids=union_ids))
logger.info(
"Loaded %d / %d requested snippets from corpus %s",
len(snippets),
len(union_ids),
corpus,
)
if not snippets:
raise RuntimeError(
f"Corpus {corpus} returned 0 matching rows. Either the snippet "
"ids reference a different corpus (e.g. PubMed) or the HF mirror "
"is unavailable. Pass --corpus to override."
)
batches_dir = ctx.benchmark_data_dir() / "batches"
batches = _write_batches(snippets, out_dir=batches_dir)
logger.info("Wrote %d snippet batches to %s", len(batches), batches_dir)
docs_client = ctx.documents_client()
upload_result = await docs_client.upload(
files=[b.path for b in batches],
search_space_id=ctx.search_space_id,
should_summarize=settings.should_summarize,
use_vision_llm=settings.use_vision_llm,
processing_mode=settings.processing_mode,
)
logger.info("MIRAGE upload settings: %s", settings.render_label())
new_doc_ids = list(upload_result.document_ids)
if new_doc_ids:
await docs_client.wait_until_ready(
search_space_id=ctx.search_space_id,
document_ids=new_doc_ids,
timeout_s=3600.0,
max_poll_s=15.0,
)
statuses = await docs_client.get_status(
search_space_id=ctx.search_space_id,
document_ids=new_doc_ids + upload_result.duplicate_document_ids,
)
title_to_doc = {s.title: s.document_id for s in statuses}
snippet_map_path = ctx.maps_dir() / "mirage_snippet_map.jsonl"
chunk_map_path = ctx.maps_dir() / "mirage_chunk_map.jsonl"
with snippet_map_path.open("w", encoding="utf-8") as fh:
# Header line records the ingest-time settings (see
# core/ingest_settings.py for the protocol).
fh.write(settings_header_line(settings) + "\n")
for batch in batches:
doc_id = title_to_doc.get(batch.path.name)
if doc_id is None:
logger.warning("No document_id for batch %s", batch.path.name)
continue
for sid in batch.snippet_ids:
fh.write(
json.dumps(
{
"snippet_id": sid,
"document_id": doc_id,
"batch_path": str(batch.path),
}
)
+ "\n"
)
# Best-effort chunk map. SurfSense doesn't expose snippet attribution
# per chunk, so we just record (chunk_id -> document_id) here; the
# MIRAGE runner only needs document_id for accuracy scoring.
with chunk_map_path.open("w", encoding="utf-8") as fh:
for doc_id in {b.path.name and title_to_doc.get(b.path.name) for b in batches} - {None}:
try:
chunks = await docs_client.list_chunks(int(doc_id))
except Exception as exc: # noqa: BLE001
logger.warning("Failed to list chunks for doc_id=%s: %s", doc_id, exc)
continue
for chunk in chunks:
fh.write(
json.dumps({"chunk_id": chunk.id, "document_id": doc_id})
+ "\n"
)
new_state = ctx.suite_state
new_state.ingestion_maps["mirage"] = str(snippet_map_path)
set_suite_state(ctx.config, ctx.suite, new_state)
logger.info("Wrote MIRAGE maps to %s and %s", snippet_map_path, chunk_map_path)
__all__ = ["run_ingest", "SnippetRow", "SnippetBatch"]

View file

@ -0,0 +1,44 @@
"""MedRAG ``{step_by_step_thinking, answer_choice}`` MCQ prompt.
Mirrors the MedRAG paper's prompt format so accuracy numbers are
comparable to the published MIRAGE leaderboard.
"""
from __future__ import annotations
from collections.abc import Mapping
_PROMPT_TEMPLATE = """\
You are a helpful medical expert. Answer the following multiple-choice
question using the relevant medical knowledge available to you (and any
retrieved context, if provided).
Respond with a JSON object on a single line:
{{"step_by_step_thinking": "<your reasoning>", "answer_choice": "<letter>"}}
Question: {question}
Options:
{options_block}
"""
def _options_block(options: Mapping[str, str]) -> str:
parts: list[str] = []
for letter in sorted(options.keys()):
text = options.get(letter)
if text is None or text == "":
continue
parts.append(f"{letter}) {text}")
return "\n".join(parts)
def build_prompt(question: str, options: Mapping[str, str]) -> str:
return _PROMPT_TEMPLATE.format(
question=question.strip(),
options_block=_options_block(options),
)
__all__ = ["build_prompt"]

View file

@ -0,0 +1,332 @@
"""MIRAGE runner: SurfSense-only per-task accuracy.
The benchmark file format is one top-level dict per task (``mmlu``,
``medqa``, ``medmcqa``, ``pubmedqa``, ``bioasq``); each task value is
``{question_id: {question, options, answer}}``.
We restrict retrieval to the suite SearchSpace's full corpus (no
``mentioned_document_ids`` MIRAGE has no per-question ground-truth
document; retrieval *is* the test). Accuracy is paired against the
``answer`` letter from the dataset.
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
from dataclasses import dataclass
from typing import Any
from ....core.arms import ArmRequest, ArmResult, SurfSenseArm
from ....core.config import utc_iso_timestamp
from ....core.ingest_settings import (
IngestSettings,
add_ingest_settings_args,
format_ingest_settings_md,
read_settings_header,
)
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci, macro_accuracy
from ....core.registry import (
Benchmark,
ReportSection,
RunArtifact,
RunContext,
)
from .prompt import build_prompt
logger = logging.getLogger(__name__)
_TASKS = ("mmlu", "medqa", "medmcqa", "pubmedqa", "bioasq")
_DESCRIPTION = "MIRAGE (7,663 medical MCQs) — single-arm SurfSense per-task accuracy."
# MIRAGE corpus is text-only (textbook + abstract markdown). Vision
# LLM at ingest is wasted compute by default; flip ``--use-vision-llm``
# to measure cost.
_DEFAULT_INGEST_SETTINGS = IngestSettings(
use_vision_llm=False,
processing_mode="basic",
should_summarize=False,
)
@dataclass
class MirageQuestion:
task: str
qid: str
question: str
options: dict[str, str]
correct: str
@property
def question_id(self) -> str:
return f"{self.task}::{self.qid}"
def _load_questions(
benchmark: dict[str, Any],
*,
tasks: list[str],
sample_n: int | None,
) -> list[MirageQuestion]:
out: list[MirageQuestion] = []
for task in tasks:
rows = benchmark.get(task) or {}
if not isinstance(rows, dict):
continue
for qid, raw in rows.items():
if not isinstance(raw, dict):
continue
options = raw.get("options") or {}
if not isinstance(options, dict):
continue
answer_raw = str(raw.get("answer") or "").strip()
if not answer_raw:
continue
answer_letter = answer_raw[:1].upper()
out.append(
MirageQuestion(
task=task,
qid=str(qid),
question=str(raw.get("question", "")),
options={str(k): str(v) for k, v in options.items() if v},
correct=answer_letter,
)
)
out.sort(key=lambda q: (q.task, q.qid))
if sample_n is not None and sample_n > 0:
# Stratified-by-task slice so smoke runs cover every task.
per_task = max(1, sample_n // max(1, len(tasks)))
sliced: list[MirageQuestion] = []
per_task_counter: dict[str, int] = {}
for q in out:
n = per_task_counter.get(q.task, 0)
if n >= per_task:
continue
sliced.append(q)
per_task_counter[q.task] = n + 1
if len(sliced) >= sample_n:
break
out = sliced
return out
async def _gather_with_limit(coros, *, concurrency: int) -> list[Any]:
sem = asyncio.Semaphore(max(1, concurrency))
async def _wrap(c):
async with sem:
return await c
return await asyncio.gather(*(_wrap(c) for c in coros))
class MirageBenchmark:
suite: str = "medical"
name: str = "mirage"
headline: bool = False
description: str = _DESCRIPTION
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--task",
default="all",
choices=("all", *_TASKS),
help="Run a single task or all (default: all).",
)
parser.add_argument("--n", dest="sample_n", type=int, default=None,
help="Stratified sample size across tasks.")
parser.add_argument("--concurrency", type=int, default=4)
parser.add_argument(
"--corpus", default="MedRAG/textbooks",
help="HF MedRAG corpus to ingest from (default: MedRAG/textbooks).",
)
parser.add_argument(
"--max-snippets-per-task", type=int, default=None,
help="Cap the per-task ingestion to N snippets (smoke).",
)
# Mutually exclusive: by default we skip the upstream 16 GB
# retrievals zip and ingest the entire corpus. Operators who
# want the upstream pre-filter pass --use-snippet-filter (and,
# if their corpus mismatch warrants the 16 GB transfer,
# --allow-large-download).
snippet_group = parser.add_mutually_exclusive_group()
snippet_group.add_argument(
"--use-snippet-filter", dest="use_snippet_filter", action="store_true",
default=False,
help="Download retrieved_snippets_10k.zip (~16 GB) and "
"filter the corpus to those ids before ingest. "
"Default: skip and ingest entire corpus.",
)
snippet_group.add_argument(
"--skip-snippet-filter", dest="use_snippet_filter", action="store_false",
help="(Default) Skip the 16 GB upstream zip; ingest entire corpus.",
)
parser.add_argument(
"--allow-large-download", action="store_true", default=False,
help="Permit downloads larger than 2 GB (e.g. retrieved_snippets_10k.zip).",
)
# Per-upload knobs; ignored at run-time (runner reads the
# resolved settings out of the snippet-map manifest header).
add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
from .ingest import run_ingest
settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
await run_ingest(
ctx,
corpus=str(opts.get("corpus") or "MedRAG/textbooks"),
max_snippets_per_task=opts.get("max_snippets_per_task"),
skip_snippet_filter=not bool(opts.get("use_snippet_filter")),
allow_large_download=bool(opts.get("allow_large_download")),
settings=settings,
)
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
task_filter = opts.get("task") or "all"
tasks = list(_TASKS) if task_filter == "all" else [task_filter]
sample_n = opts.get("sample_n")
concurrency = int(opts.get("concurrency") or 4)
bench_path = ctx.benchmark_data_dir() / "benchmark.json"
if not bench_path.exists():
raise RuntimeError(
"MIRAGE benchmark.json missing. Run "
"`python -m surfsense_evals ingest medical mirage` first."
)
benchmark = json.loads(bench_path.read_text(encoding="utf-8"))
ingest_settings = read_settings_header(
ctx.maps_dir() / "mirage_snippet_map.jsonl"
)
questions = _load_questions(benchmark, tasks=tasks, sample_n=sample_n)
if not questions:
raise RuntimeError(
f"No MIRAGE questions matched task={task_filter!r} sample_n={sample_n!r}."
)
logger.info("MIRAGE: scheduled %d questions across tasks %s",
len(questions), tasks)
arm = SurfSenseArm(
client=ctx.new_chat_client(),
search_space_id=ctx.search_space_id,
ephemeral_threads=True,
)
async def _ask(q: MirageQuestion) -> ArmResult:
request = ArmRequest(
question_id=q.question_id,
prompt=build_prompt(q.question, q.options),
)
return await arm.answer(request)
results: list[ArmResult] = await _gather_with_limit(
(_ask(q) for q in questions), concurrency=concurrency
)
run_timestamp = utc_iso_timestamp()
run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
raw_path = run_dir / "raw.jsonl"
with raw_path.open("w", encoding="utf-8") as fh:
for q, res in zip(questions, results):
fh.write(
json.dumps(
{
"task": q.task,
"qid": q.qid,
"correct": q.correct,
**res.to_jsonl(),
}
)
+ "\n"
)
per_task_acc: dict[str, dict[str, Any]] = {}
for task in tasks:
n_correct = 0
n_total = 0
for q, res in zip(questions, results):
if q.task != task:
continue
n_total += 1
if (res.answer_letter or "").upper() == q.correct:
n_correct += 1
acc = accuracy_with_wilson_ci(n_correct, n_total)
per_task_acc[task] = acc.to_dict()
macro = macro_accuracy(
{t: accuracy_with_wilson_ci(d["n_correct"], d["n_total"]) for t, d in per_task_acc.items()}
)
metrics = {"per_task": per_task_acc, "macro_accuracy": macro}
artifact = RunArtifact(
suite=self.suite,
benchmark=self.name,
run_timestamp=run_timestamp,
raw_path=raw_path,
metrics=metrics,
extra={
"n_questions": len(questions),
"task_filter": task_filter,
"concurrency": concurrency,
"provider_model": ctx.provider_model,
"ingest_settings": ingest_settings,
},
)
manifest_path = run_dir / "run_artifact.json"
manifest_path.write_text(
json.dumps(
{
"suite": self.suite,
"benchmark": self.name,
"raw_path": "raw.jsonl",
"metrics": metrics,
"extra": artifact.extra,
},
indent=2,
sort_keys=True,
)
+ "\n",
encoding="utf-8",
)
return artifact
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
if not artifacts:
return ReportSection(
title="MIRAGE — single-arm SurfSense per-task accuracy",
headline=False,
body_md="(no run artifacts found)",
body_json={},
)
latest = max(artifacts, key=lambda a: a.run_timestamp)
per_task = latest.metrics.get("per_task", {})
macro = latest.metrics.get("macro_accuracy", 0.0)
lines: list[str] = []
lines.append(format_ingest_settings_md(latest.extra.get("ingest_settings")))
for task in _TASKS:
row = per_task.get(task)
if not row:
continue
acc = row.get("accuracy", 0.0)
low = row.get("ci_low", 0.0)
high = row.get("ci_high", 0.0)
lines.append(
f"- {task}: {acc * 100:.1f}% "
f"(Wilson 95% CI: {low * 100:.1f}% {high * 100:.1f}%, "
f"n={row.get('n_total', '?')})"
)
if not lines:
lines.append("- (no per-task results)")
lines.append(f"- Macro accuracy: {macro * 100:.2f}%")
return ReportSection(
title="MIRAGE — single-arm SurfSense per-task accuracy",
headline=False,
body_md="\n".join(lines),
body_json=latest.metrics,
)
__all__ = ["MirageBenchmark", "MirageQuestion"]

View file

@ -0,0 +1,14 @@
"""Multimodal long-document benchmarks (PDFs with embedded images/charts/tables).
Distinct from the medical suite because these documents are domain-mixed
(research reports, financials, manuals, government, brochures, papers).
The hypothesis being tested here is *general*: does SurfSense's
chunking-based vision RAG preserve information that lives in pixels
across long PDFs, across pages versus feeding the same PDF directly
to a vision-capable model?
Subpackages register themselves with ``core.registry`` on import. The
``suites/__init__.py`` discovery walker imports them automatically.
"""
from __future__ import annotations

View file

@ -0,0 +1,19 @@
"""MMLongBench-Doc — head-to-head Native PDF (vision) vs SurfSense (vision RAG).
Source: https://huggingface.co/datasets/yubo2333/MMLongBench-Doc
Paper: https://arxiv.org/abs/2407.01523 (NeurIPS 2024 D&B Track)
* 135 long PDFs (avg 47 pages, multi-modal: text, images, charts, tables)
* 1,091 expert-annotated questions
* 33% require evidence from multiple pages
* ~22% intentionally unanswerable (tests hallucination resistance)
* 7 document types: research report, tutorial/workshop, academic paper,
financial report, brochure, government, manuals
"""
from __future__ import annotations
from ....core import registry as _registry
from .runner import MMLongBenchDocBenchmark
_registry.register(MMLongBenchDocBenchmark())

View file

@ -0,0 +1,236 @@
"""Format-aware grader for MMLongBench-Doc answers.
The dataset ships with five ``answer_format`` values per question:
* ``Str`` short factoid string
* ``Int`` integer count / year
* ``Float`` decimal number (often with units stripped)
* ``List`` comma- or semicolon-separated bag of items
* ``None`` gold answer is literally "Not answerable" (hallucination probe)
The official MMLongBench-Doc paper grades with GPT-4 as judge. We
implement a *deterministic* rule-based grader as the default (so two
researchers running the same harness get the same number); an
LLM-judge mode is exposed via ``--judge gpt5`` and routed through the
same OpenRouter key the arms use, but is opt-in to keep cost down.
Returned by every grading call:
* ``correct: bool`` final pass/fail used for accuracy + McNemar
* ``f1: float`` token-level F1 (continuous credit, useful when
comparing arms that get *most* of a list right)
* ``method: str`` which path graded the row (one of
``str_norm`` / ``int_eq`` / ``float_tol`` / ``list_set`` /
``none_match`` / ``llm_judge``).
"""
from __future__ import annotations
import re
import string
from collections import Counter
from dataclasses import dataclass
# ---------------------------------------------------------------------------
# Public types
# ---------------------------------------------------------------------------
@dataclass
class GradeResult:
correct: bool
f1: float
method: str
normalised_pred: str = ""
normalised_gold: str = ""
# ---------------------------------------------------------------------------
# Normalisation helpers (shared)
# ---------------------------------------------------------------------------
_PUNCT_TABLE = str.maketrans({c: " " for c in string.punctuation})
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.IGNORECASE)
_WS = re.compile(r"\s+")
_NOT_ANSWERABLE_TOKENS = {
"not answerable",
"cannot be answered",
"cannot answer",
"no answer",
"unknown",
"none",
"not specified",
"not mentioned",
"not provided",
"the answer is not in the document",
}
# Abbreviations that should be matched literally on the lowercased
# prediction (because normalisation strips their punctuation and
# leaves them too short to be safe as substring tokens).
_NOT_ANSWERABLE_LITERAL = {"n/a", "na/", "n.a.", "n a"}
def _normalise_text(s: str) -> str:
"""SQuAD-style normalisation: lowercase, drop punctuation/articles, squash whitespace."""
s = s.lower()
s = s.translate(_PUNCT_TABLE)
s = _ARTICLES.sub(" ", s)
s = _WS.sub(" ", s).strip()
return s
# ---------------------------------------------------------------------------
# Per-format graders
# ---------------------------------------------------------------------------
def _grade_str(pred: str, gold: str) -> GradeResult:
p = _normalise_text(pred)
g = _normalise_text(gold)
if not p:
return GradeResult(False, 0.0, "str_norm", p, g)
if p == g:
return GradeResult(True, 1.0, "str_norm", p, g)
# Substring match in either direction = correct (handles the common
# "model emits a fuller sentence containing the gold" case).
if g and (g in p or p in g):
return GradeResult(True, _f1_tokens(p, g), "str_norm", p, g)
return GradeResult(False, _f1_tokens(p, g), "str_norm", p, g)
_INT_RE = re.compile(r"-?\d[\d,]*")
def _grade_int(pred: str, gold: str) -> GradeResult:
g_match = _INT_RE.search(gold)
if g_match is None:
return _grade_str(pred, gold)
g_val = int(g_match.group(0).replace(",", ""))
p_match = _INT_RE.search(pred)
if p_match is None:
return GradeResult(False, 0.0, "int_eq", str(p_match), str(g_val))
p_val = int(p_match.group(0).replace(",", ""))
return GradeResult(p_val == g_val, 1.0 if p_val == g_val else 0.0,
"int_eq", str(p_val), str(g_val))
_FLOAT_RE = re.compile(r"-?\d+(?:[.,]\d+)?")
def _grade_float(pred: str, gold: str, *, rel_tol: float = 1e-2) -> GradeResult:
g_match = _FLOAT_RE.search(gold)
if g_match is None:
return _grade_str(pred, gold)
g_val = float(g_match.group(0).replace(",", "."))
p_match = _FLOAT_RE.search(pred)
if p_match is None:
return GradeResult(False, 0.0, "float_tol", "", str(g_val))
p_val = float(p_match.group(0).replace(",", "."))
# Tolerance: 1% relative or 0.01 absolute, whichever is looser.
abs_diff = abs(p_val - g_val)
tol = max(abs(g_val) * rel_tol, 0.01)
ok = abs_diff <= tol
return GradeResult(ok, 1.0 if ok else 0.0, "float_tol", str(p_val), str(g_val))
_LIST_SPLIT = re.compile(r"[;,\n]")
def _grade_list(pred: str, gold: str) -> GradeResult:
g_items = {_normalise_text(x) for x in _LIST_SPLIT.split(gold) if x.strip()}
p_items = {_normalise_text(x) for x in _LIST_SPLIT.split(pred) if x.strip()}
if not g_items:
return _grade_str(pred, gold)
inter = g_items & p_items
if not inter:
return GradeResult(False, 0.0, "list_set",
", ".join(sorted(p_items)),
", ".join(sorted(g_items)))
precision = len(inter) / len(p_items) if p_items else 0.0
recall = len(inter) / len(g_items)
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
return GradeResult(f1 >= 0.999, f1, "list_set",
", ".join(sorted(p_items)),
", ".join(sorted(g_items)))
def _grade_none(pred: str, gold: str) -> GradeResult:
"""Gold == 'Not answerable'. The arm earns credit if its prediction
expresses inability to answer.
Two passes:
1. Literal-substring check on the lowercased+stripped pred for
ambiguous abbreviations like ``n/a`` (since normalisation
strips the punctuation and would over-match).
2. Word-boundary substring check on the normalised pred for the
multi-word phrases (``cannot answer``, ``not specified`` etc.).
"""
raw_lower = (pred or "").strip().lower()
p = _normalise_text(pred)
expressed_unknown = False
# Pass 1: literal abbreviation hits on the raw lowercased text.
if any(lit in raw_lower for lit in _NOT_ANSWERABLE_LITERAL):
expressed_unknown = True
# Pass 2: word-boundary check on normalised tokens.
if not expressed_unknown:
p_padded = f" {p} "
for tok_raw in _NOT_ANSWERABLE_TOKENS:
tok = _normalise_text(tok_raw)
if not tok or len(tok) < 3:
continue
if f" {tok} " in p_padded:
expressed_unknown = True
break
return GradeResult(
expressed_unknown, 1.0 if expressed_unknown else 0.0,
"none_match", p, _normalise_text(gold),
)
def _f1_tokens(pred: str, gold: str) -> float:
p_tok = pred.split()
g_tok = gold.split()
if not p_tok or not g_tok:
return 0.0
common = Counter(p_tok) & Counter(g_tok)
overlap = sum(common.values())
if overlap == 0:
return 0.0
precision = overlap / len(p_tok)
recall = overlap / len(g_tok)
return 2 * precision * recall / (precision + recall)
# ---------------------------------------------------------------------------
# Public dispatcher
# ---------------------------------------------------------------------------
_FORMAT_DISPATCH = {
"str": _grade_str,
"int": _grade_int,
"float": _grade_float,
"list": _grade_list,
"none": _grade_none,
}
def grade(*, pred: str, gold: str, answer_format: str) -> GradeResult:
"""Grade a single (prediction, gold) pair.
``answer_format`` is the dataset's ``answer_format`` column value.
Unknown / blank values fall through to string grading.
"""
fmt = (answer_format or "").strip().lower()
fn = _FORMAT_DISPATCH.get(fmt, _grade_str)
return fn(pred or "", gold or "")
__all__ = ["GradeResult", "grade"]

View file

@ -0,0 +1,365 @@
"""MMLongBench-Doc ingestion.
Steps:
1. Pull the questions parquet from
``hf://datasets/yubo2333/MMLongBench-Doc/data/`` and cache locally.
2. Resolve the unique set of ``doc_id`` referenced by questions, and
download each PDF from
``hf://datasets/yubo2333/MMLongBench-Doc/documents/<doc_id>``.
``huggingface_hub.hf_hub_download`` is resumable + content-hash
verifying; we cache PDFs under ``<data_dir>/multimodal_doc/mmlongbench/pdfs/``.
3. Upload every PDF to SurfSense via ``DocumentsClient.upload`` with
``use_vision_llm=True`` so SurfSense's Pillow + LiteLLM vision
pipeline extracts captions / OCR for embedded images, charts, and
tables.
4. Wait for ``processed`` status and persist
``doc_id -> document_id`` in
``<data_dir>/multimodal_doc/maps/mmlongbench_doc_map.jsonl``.
By default we ingest **all** 135 PDFs (~660 MB, totally manageable).
Operators can scope to a subset with ``--max-docs N`` if iterating on
a slow vision pipeline.
"""
from __future__ import annotations
import json
import logging
import os
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from ....core.config import set_suite_state
from ....core.ingest_settings import IngestSettings, settings_header_line
from ....core.registry import RunContext
logger = logging.getLogger(__name__)
HF_REPO_ID = "yubo2333/MMLongBench-Doc"
HF_REPO_TYPE = "dataset"
# Lazy import: huggingface_hub + pyarrow are heavyweight; keep the
# benchmark module importable on machines that have only the core
# install (e.g. CI lint jobs).
def _hf_hub_download(*args, **kwargs):
from huggingface_hub import hf_hub_download
return hf_hub_download(*args, **kwargs)
def _list_repo_files() -> list[str]:
from huggingface_hub import list_repo_files
return list_repo_files(repo_id=HF_REPO_ID, repo_type=HF_REPO_TYPE)
# ---------------------------------------------------------------------------
# Question parquet -> Python rows
# ---------------------------------------------------------------------------
@dataclass
class MMLongBenchQuestion:
doc_id: str # filename inside the documents/ folder
doc_type: str
question: str
answer: str
answer_format: str # Str / Int / Float / List / None
evidence_pages: list[int]
evidence_sources: list[str]
def _load_questions_from_parquet(parquet_path: Path) -> list[MMLongBenchQuestion]:
import pyarrow.parquet as pq
table = pq.read_table(parquet_path)
rows = table.to_pylist()
out: list[MMLongBenchQuestion] = []
for row in rows:
doc_id = str(row.get("doc_id") or "").strip()
if not doc_id:
continue
question = str(row.get("question") or "").strip()
if not question:
continue
out.append(
MMLongBenchQuestion(
doc_id=doc_id,
doc_type=str(row.get("doc_type") or "").strip(),
question=question,
answer=str(row.get("answer") or "").strip(),
answer_format=str(row.get("answer_format") or "").strip(),
evidence_pages=_parse_int_list(row.get("evidence_pages")),
evidence_sources=_parse_str_list(row.get("evidence_sources")),
)
)
return out
def _parse_int_list(raw) -> list[int]:
if raw is None:
return []
if isinstance(raw, list):
out = []
for x in raw:
try:
out.append(int(x))
except (TypeError, ValueError):
continue
return out
text = str(raw).strip().strip("[]")
if not text:
return []
out: list[int] = []
for tok in text.split(","):
tok = tok.strip().strip("'\"")
if tok.isdigit():
out.append(int(tok))
return out
def _parse_str_list(raw) -> list[str]:
if raw is None:
return []
if isinstance(raw, list):
return [str(x).strip().strip("'\"") for x in raw if str(x).strip()]
text = str(raw).strip().strip("[]")
if not text:
return []
return [tok.strip().strip("'\"") for tok in text.split(",") if tok.strip()]
# ---------------------------------------------------------------------------
# Download helpers
# ---------------------------------------------------------------------------
def _download_questions_parquet(cache_dir: Path) -> Path:
"""Download every parquet under ``data/`` and concatenate.
The HF dataset usually publishes a single ``train`` split, but we
enumerate to be robust to repo restructuring.
"""
parquet_paths: list[Path] = []
files = _list_repo_files()
data_files = [f for f in files if f.startswith("data/") and f.endswith(".parquet")]
if not data_files:
raise RuntimeError(
f"No parquet files found under data/ in {HF_REPO_ID}; "
f"upstream repo may have been restructured."
)
for rel in sorted(data_files):
local = _hf_hub_download(
repo_id=HF_REPO_ID,
filename=rel,
repo_type=HF_REPO_TYPE,
cache_dir=str(cache_dir),
)
parquet_paths.append(Path(local))
logger.info("Cached MMLongBench parquet shard %s -> %s", rel, local)
return parquet_paths[0] if len(parquet_paths) == 1 else _merge_parquets(parquet_paths, cache_dir)
def _merge_parquets(paths: list[Path], cache_dir: Path) -> Path:
"""Combine multiple parquet shards into one (rare branch, but correct)."""
import pyarrow as pa
import pyarrow.parquet as pq
tables = [pq.read_table(p) for p in paths]
merged = pa.concat_tables(tables, promote_options="default")
out = cache_dir / "merged_questions.parquet"
pq.write_table(merged, out)
return out
def _download_pdf(doc_id: str, cache_dir: Path, pdfs_dir: Path) -> Path:
"""Download a single PDF (resumable via huggingface_hub cache)."""
rel = f"documents/{doc_id}"
local = _hf_hub_download(
repo_id=HF_REPO_ID,
filename=rel,
repo_type=HF_REPO_TYPE,
cache_dir=str(cache_dir),
)
# Materialise to a stable path inside our data/ tree so the runner
# has a deterministic location regardless of HF cache internals.
dest = pdfs_dir / doc_id
if not dest.exists() or dest.stat().st_size != Path(local).stat().st_size:
# Use a hardlink when possible (cheap), fall back to copy.
try:
if dest.exists():
dest.unlink()
os.link(local, dest)
except OSError:
from shutil import copy2
copy2(local, dest)
return dest
# ---------------------------------------------------------------------------
# Upload helpers
# ---------------------------------------------------------------------------
async def _upload_pdfs(
ctx: RunContext,
pdf_paths: Iterable[Path],
*,
batch_size: int,
settings: IngestSettings,
) -> dict[str, int]:
"""Upload PDFs in batches, return ``filename -> document_id`` map."""
docs_client = ctx.documents_client()
name_to_id: dict[str, int] = {}
pdf_list = list(pdf_paths)
for batch_start in range(0, len(pdf_list), batch_size):
batch = pdf_list[batch_start:batch_start + batch_size]
result = await docs_client.upload(
files=batch,
search_space_id=ctx.search_space_id,
should_summarize=settings.should_summarize,
use_vision_llm=settings.use_vision_llm,
processing_mode=settings.processing_mode,
)
all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
if all_ids:
await docs_client.wait_until_ready(
search_space_id=ctx.search_space_id,
document_ids=result.document_ids, # only newly added need polling
timeout_s=1800.0, # vision pipeline is slow on long PDFs
)
statuses = await docs_client.get_status(
search_space_id=ctx.search_space_id,
document_ids=all_ids,
)
for s in statuses:
name_to_id[s.title] = s.document_id
logger.info(
"Uploaded MMLongBench batch %d-%d: %d new, %d duplicate",
batch_start, batch_start + len(batch),
len(result.document_ids), len(result.duplicate_document_ids),
)
return name_to_id
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
async def run_ingest(
ctx: RunContext,
*,
max_docs: int | None = None,
upload_batch_size: int = 8,
skip_upload: bool = False,
settings: IngestSettings | None = None,
) -> None:
"""Ingest MMLongBench-Doc into the multimodal_doc suite.
Parameters
----------
max_docs : int | None
Cap the number of PDFs to download + upload. ``None`` = all 135.
Useful when iterating on the runner without paying for the full
vision pipeline pass each time.
upload_batch_size : int
How many PDFs to send per ``fileupload`` call. Smaller batches
recover faster from individual failures; larger batches reduce
round-trip overhead.
skip_upload : bool
Download + cache PDFs locally but skip SurfSense ingestion.
Useful for testing the native arm in isolation.
"""
settings = settings or IngestSettings(use_vision_llm=True, processing_mode="basic")
bench_dir = ctx.benchmark_data_dir()
pdfs_dir = bench_dir / "pdfs"
pdfs_dir.mkdir(parents=True, exist_ok=True)
hf_cache = bench_dir / ".hf_cache"
hf_cache.mkdir(parents=True, exist_ok=True)
# Step 1: questions
parquet_path = _download_questions_parquet(hf_cache)
questions = _load_questions_from_parquet(parquet_path)
if not questions:
raise RuntimeError(
"MMLongBench-Doc parquet contains no parseable questions. "
"Upstream may have changed schema."
)
# Persist a copy alongside the PDFs so the runner has one place to read.
questions_jsonl = bench_dir / "questions.jsonl"
with questions_jsonl.open("w", encoding="utf-8") as fh:
for q in questions:
fh.write(json.dumps({
"doc_id": q.doc_id,
"doc_type": q.doc_type,
"question": q.question,
"answer": q.answer,
"answer_format": q.answer_format,
"evidence_pages": q.evidence_pages,
"evidence_sources": q.evidence_sources,
}) + "\n")
logger.info("Wrote %d MMLongBench questions to %s", len(questions), questions_jsonl)
# Step 2: download unique PDFs
unique_doc_ids = sorted({q.doc_id for q in questions})
if max_docs is not None and max_docs > 0:
unique_doc_ids = unique_doc_ids[:max_docs]
logger.info("MMLongBench: downloading %d unique PDFs", len(unique_doc_ids))
pdf_paths: dict[str, Path] = {}
for i, doc_id in enumerate(unique_doc_ids, start=1):
try:
pdf_paths[doc_id] = _download_pdf(doc_id, hf_cache, pdfs_dir)
if i % 10 == 0:
logger.info(" ... %d / %d PDFs cached", i, len(unique_doc_ids))
except Exception as exc: # noqa: BLE001
logger.warning("Failed to download MMLongBench PDF %s: %s", doc_id, exc)
# Step 3: upload to SurfSense
name_to_id: dict[str, int] = {}
if skip_upload:
logger.info("MMLongBench: --skip-upload set; skipping SurfSense ingestion")
else:
logger.info("MMLongBench upload settings: %s", settings.render_label())
name_to_id = await _upload_pdfs(
ctx,
pdf_paths.values(),
batch_size=upload_batch_size,
settings=settings,
)
# Step 4: persist doc_id -> document_id manifest
map_path = ctx.maps_dir() / "mmlongbench_doc_map.jsonl"
with map_path.open("w", encoding="utf-8") as fh:
# Header line records the resolved ingest settings
# (see core/ingest_settings.py).
fh.write(settings_header_line(settings) + "\n")
for doc_id in unique_doc_ids:
local = pdf_paths.get(doc_id)
if local is None:
continue
fh.write(json.dumps({
"doc_id": doc_id,
"document_id": name_to_id.get(local.name),
"pdf_path": str(local),
"n_questions": sum(1 for q in questions if q.doc_id == doc_id),
}) + "\n")
logger.info("Wrote MMLongBench doc map to %s", map_path)
new_state = ctx.suite_state
new_state.ingestion_maps["mmlongbench"] = str(map_path)
set_suite_state(ctx.config, ctx.suite, new_state)
__all__ = ["MMLongBenchQuestion", "run_ingest"]

View file

@ -0,0 +1,60 @@
"""MMLongBench-Doc prompt template.
Both arms get the same prompt only the document delivery channel
differs (native PDF embedded in the OpenRouter request vs SurfSense
RAG retrieval). The format hint in the prompt mirrors what the
upstream paper uses so the grader's regex can reliably extract the
answer.
"""
from __future__ import annotations
# ---------------------------------------------------------------------------
# Per-format hint blocks
# ---------------------------------------------------------------------------
_FORMAT_HINTS: dict[str, str] = {
"str": (
"Respond with the answer as a short phrase, no full sentence. "
"Format your final line as `Answer: <text>`."
),
"int": (
"Respond with a single integer only. "
"Format your final line as `Answer: <integer>`."
),
"float": (
"Respond with a single decimal number only (no units). "
"Format your final line as `Answer: <number>`."
),
"list": (
"Respond with a comma-separated list of items, no extra text. "
"Format your final line as `Answer: item1, item2, item3`."
),
"none": (
"If the answer cannot be determined from the document, say so explicitly. "
"Format your final line as `Answer: Not answerable`."
),
}
_PROMPT = """\
You are a document-understanding assistant. Use ONLY the provided
document to answer the question. The document may contain text,
tables, charts, figures, and images. If the answer is in a chart or
image, read it carefully. Do not use external knowledge.
Question: {question}
{format_hint}
"""
def build_prompt(question: str, *, answer_format: str) -> str:
"""Assemble the full prompt for one MMLongBench question."""
fmt = (answer_format or "str").strip().lower()
hint = _FORMAT_HINTS.get(fmt, _FORMAT_HINTS["str"])
return _PROMPT.format(question=question.strip(), format_hint=hint)
__all__ = ["build_prompt"]

View file

@ -0,0 +1,704 @@
"""MMLongBench-Doc runner — head-to-head Native PDF (vision) vs SurfSense (vision RAG).
Differences from a typical MCQ head-to-head:
* Open-ended answers (Str / Int / Float / List / Not-answerable) uses
``extract_freeform_answer`` instead of ``extract_answer_letter``.
* Format-aware grader (see ``.grader``) returns both binary correctness
(for accuracy / McNemar) and continuous F1 (for nuanced reporting).
* Native arm requires a vision-capable model we don't enforce this
in code (operator's choice via ``setup --provider-model``) but we
emit a warning if the pinned slug looks text-only.
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import os
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from ....core.arms import ArmRequest, ArmResult, NativePdfArm, SurfSenseArm
from ....core.config import utc_iso_timestamp
from ....core.ingest_settings import (
IngestSettings,
add_ingest_settings_args,
format_ingest_settings_md,
is_settings_header,
)
from ....core.metrics.comparison import (
bootstrap_delta_ci,
mcnemar_test,
paired_aggregate,
)
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci
from ....core.parse.freeform_answer import extract_freeform_answer
from ....core.providers.openrouter_pdf import OpenRouterPdfProvider, PdfEngine
from ....core.registry import (
ReportSection,
RunArtifact,
RunContext,
)
from ....core.scenarios import format_scenario_md
from .grader import GradeResult, grade
from .prompt import build_prompt
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Question + map row shapes
# ---------------------------------------------------------------------------
@dataclass
class MMLBQuestion:
qid: str # synthesised from doc_id + index
doc_id: str # filename inside the documents/ folder
doc_type: str
question: str
gold_answer: str
answer_format: str
evidence_pages: list[int]
evidence_sources: list[str]
pdf_path: Path
document_id: int | None # SurfSense doc id (None if upload skipped)
def _load_doc_map(map_path: Path) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
"""Read the doc map JSONL.
Returns ``(rows, settings)`` where ``settings`` is the
``__settings__`` header blob (or ``{}`` for legacy maps).
"""
rows: dict[str, dict[str, Any]] = {}
settings: dict[str, Any] = {}
with map_path.open("r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line:
continue
row = json.loads(line)
if is_settings_header(row):
settings = dict(row["__settings__"])
continue
rows[str(row["doc_id"])] = row
return rows, settings
def _load_questions(
questions_jsonl: Path,
doc_map: dict[str, dict[str, Any]],
*,
doc_filter: list[str] | None,
format_filter: str | None,
sample_n: int | None,
skip_unanswerable: bool,
) -> list[MMLBQuestion]:
out: list[MMLBQuestion] = []
per_doc_counter: dict[str, int] = {}
with questions_jsonl.open("r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line:
continue
row = json.loads(line)
doc_id = str(row.get("doc_id") or "").strip()
if not doc_id:
continue
if doc_filter and doc_id not in doc_filter:
continue
map_row = doc_map.get(doc_id)
if map_row is None:
logger.debug("No doc-map entry for %s; skipping", doc_id)
continue
answer_format = str(row.get("answer_format") or "").strip().lower()
if format_filter and format_filter != "all" and format_filter != answer_format:
continue
gold = str(row.get("answer") or "").strip()
if skip_unanswerable and answer_format == "none":
continue
idx = per_doc_counter.get(doc_id, 0)
per_doc_counter[doc_id] = idx + 1
out.append(MMLBQuestion(
qid=f"{doc_id}::Q{idx:03d}",
doc_id=doc_id,
doc_type=str(row.get("doc_type") or "").strip(),
question=str(row.get("question") or "").strip(),
gold_answer=gold,
answer_format=answer_format,
evidence_pages=list(row.get("evidence_pages") or []),
evidence_sources=list(row.get("evidence_sources") or []),
pdf_path=Path(map_row["pdf_path"]),
document_id=map_row.get("document_id"),
))
out.sort(key=lambda q: (q.doc_id, q.qid))
if sample_n is not None and sample_n > 0:
out = out[:sample_n]
return out
# ---------------------------------------------------------------------------
# Bounded concurrency helper
# ---------------------------------------------------------------------------
async def _gather_with_limit(coros: Iterable, *, concurrency: int) -> list[Any]:
sem = asyncio.Semaphore(max(1, concurrency))
async def _wrap(coro):
async with sem:
return await coro
return await asyncio.gather(*(_wrap(c) for c in coros))
# ---------------------------------------------------------------------------
# Benchmark
# ---------------------------------------------------------------------------
_DESCRIPTION = (
"MMLongBench-Doc (135 long PDFs, 1,091 multimodal questions) — "
"Native PDF (vision) vs SurfSense (vision RAG) head-to-head."
)
_TEXT_ONLY_HINTS = ("gpt-5.4-mini", "gpt-3.5", "text-only", "instruct-")
# MMLongBench-Doc PDFs are long documents with figures, charts, and
# tables. Vision LLM at ingest is the whole point; flip --no-vision-llm
# to measure how much SurfSense degrades on real document images.
_DEFAULT_INGEST_SETTINGS = IngestSettings(
use_vision_llm=True,
processing_mode="basic",
should_summarize=False,
)
class MMLongBenchDocBenchmark:
"""Long-document multimodal RAG vs native vision."""
suite: str = "multimodal_doc"
name: str = "mmlongbench"
headline: bool = True
description: str = _DESCRIPTION
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--docs",
default=None,
help="Comma-separated doc_ids (filenames) to run (default: all).",
)
parser.add_argument(
"--format",
default="all",
choices=["all", "str", "int", "float", "list", "none"],
help="Filter to one answer format. 'none' = unanswerable probes only.",
)
parser.add_argument(
"--n", dest="sample_n", type=int, default=None,
help="Run only the first N questions after filters apply.",
)
parser.add_argument(
"--skip-unanswerable", dest="skip_unanswerable", action="store_true",
help="Drop ~22%% unanswerable questions (use to compare against baselines that don't include them).",
)
parser.add_argument(
"--concurrency", type=int, default=4,
help="Parallel question workers per arm.",
)
parser.add_argument(
"--no-mentions", dest="no_mentions", action="store_true",
help="SurfSense arm: skip mentioned_document_ids (unscoped retrieval).",
)
parser.add_argument(
"--pdf-engine", default="native",
choices=[e.value for e in PdfEngine],
help="OpenRouter file-parser engine for the native arm.",
)
parser.add_argument(
"--max-output-tokens", type=int, default=512,
help="Cap on completion length for both arms.",
)
# Ingest-only knobs (forwarded by the CLI to ingest.run_ingest).
parser.add_argument(
"--max-docs", dest="max_docs", type=int, default=None,
help="(ingest only) cap on number of unique PDFs to download + upload.",
)
parser.add_argument(
"--upload-batch-size", dest="upload_batch_size", type=int, default=8,
help="(ingest only) PDFs per fileupload call.",
)
parser.add_argument(
"--skip-upload", dest="skip_upload", action="store_true",
help="(ingest only) cache PDFs locally but don't push to SurfSense.",
)
# Per-upload knobs forwarded to /documents/fileupload at ingest;
# ignored at run-time (runner reads the resolved settings out of
# the doc-map manifest header).
add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
from .ingest import run_ingest
settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
await run_ingest(
ctx,
max_docs=opts.get("max_docs"),
upload_batch_size=int(opts.get("upload_batch_size") or 8),
skip_upload=bool(opts.get("skip_upload", False)),
settings=settings,
)
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
docs_raw: str | None = opts.get("docs")
doc_filter = [d.strip() for d in docs_raw.split(",")] if docs_raw else None
format_filter = opts.get("format") or "all"
sample_n = opts.get("sample_n")
skip_unanswerable = bool(opts.get("skip_unanswerable"))
concurrency = int(opts.get("concurrency") or 4)
no_mentions = bool(opts.get("no_mentions"))
pdf_engine_name = opts.get("pdf_engine") or "native"
max_output_tokens = int(opts.get("max_output_tokens") or 512)
bench_dir = ctx.benchmark_data_dir()
questions_jsonl = bench_dir / "questions.jsonl"
map_path = ctx.maps_dir() / "mmlongbench_doc_map.jsonl"
if not questions_jsonl.exists() or not map_path.exists():
raise RuntimeError(
"MMLongBench-Doc not ingested for this suite. Run "
"`python -m surfsense_evals ingest multimodal_doc mmlongbench` first."
)
doc_map, ingest_settings = _load_doc_map(map_path)
questions = _load_questions(
questions_jsonl, doc_map,
doc_filter=doc_filter,
format_filter=None if format_filter == "all" else format_filter,
sample_n=sample_n,
skip_unanswerable=skip_unanswerable,
)
if not questions:
raise RuntimeError(
"No MMLongBench questions matched the filters; broaden --docs/--format/--n."
)
logger.info("MMLongBench-Doc: scheduled %d questions", len(questions))
api_key = os.environ.get("OPENROUTER_API_KEY")
if not api_key:
raise RuntimeError(
"OPENROUTER_API_KEY env var is required for the native arm."
)
# Native arm slug differs from SurfSense slug only in cost-arbitrage
# scenario; otherwise both arms answer with provider_model.
native_arm_model = ctx.native_arm_model
if any(hint in native_arm_model.lower() for hint in _TEXT_ONLY_HINTS):
if ctx.scenario == "symmetric-cheap":
logger.info(
"symmetric-cheap: native arm pinned to text-only %r as "
"intended; expect it to lose on image-bearing pages "
"(SurfSense answers from vision-extracted chunks).",
native_arm_model,
)
else:
logger.warning(
"Native arm slug %r looks text-only; image content in "
"PDFs will be ignored. Re-pin via "
"`setup --provider-model anthropic/claude-sonnet-4.5` "
"(or pass --native-arm-model and --scenario cost-arbitrage "
"to make this asymmetry explicit).",
native_arm_model,
)
provider = OpenRouterPdfProvider(
api_key=api_key,
base_url=ctx.config.openrouter_base_url,
model=native_arm_model,
engine=PdfEngine(pdf_engine_name),
)
native_arm = NativePdfArm(provider=provider, max_output_tokens=max_output_tokens)
surf_arm = SurfSenseArm(
client=ctx.new_chat_client(),
search_space_id=ctx.search_space_id,
ephemeral_threads=True,
)
run_timestamp = utc_iso_timestamp()
run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
raw_path = run_dir / "raw.jsonl"
async def _native_one(q: MMLBQuestion) -> ArmResult:
return await native_arm.answer(_make_native_request(q, max_output_tokens))
async def _surf_one(q: MMLBQuestion) -> ArmResult:
return await surf_arm.answer(_make_surfsense_request(q, no_mentions=no_mentions))
native_results, surf_results = await asyncio.gather(
_gather_with_limit((_native_one(q) for q in questions), concurrency=concurrency),
_gather_with_limit((_surf_one(q) for q in questions), concurrency=concurrency),
)
native_grades = [_grade_one(q, r) for q, r in zip(questions, native_results, strict=False)]
surf_grades = [_grade_one(q, r) for q, r in zip(questions, surf_results, strict=False)]
with raw_path.open("w", encoding="utf-8") as fh:
for q, n_res, s_res, n_g, s_g in zip(
questions, native_results, surf_results, native_grades, surf_grades, strict=False
):
meta = {
"qid": q.qid,
"doc_id": q.doc_id,
"doc_type": q.doc_type,
"answer_format": q.answer_format,
"gold": q.gold_answer,
"evidence_pages": q.evidence_pages,
"evidence_sources": q.evidence_sources,
"document_id": q.document_id,
}
fh.write(json.dumps({
**meta,
**n_res.to_jsonl(),
"graded": _grade_to_jsonl(n_g),
}) + "\n")
fh.write(json.dumps({
**meta,
**s_res.to_jsonl(),
"graded": _grade_to_jsonl(s_g),
}) + "\n")
metrics = _compute_metrics(questions, native_results, surf_results, native_grades, surf_grades)
artifact = RunArtifact(
suite=self.suite,
benchmark=self.name,
run_timestamp=run_timestamp,
raw_path=raw_path,
metrics=metrics,
extra={
"n_questions": len(questions),
"concurrency": concurrency,
"format_filter": format_filter,
"skip_unanswerable": skip_unanswerable,
"no_mentions": no_mentions,
"pdf_engine": pdf_engine_name,
"scenario": ctx.scenario,
"provider_model": ctx.provider_model,
"native_arm_model": native_arm_model,
"vision_provider_model": ctx.vision_provider_model,
"agent_llm_id": ctx.agent_llm_id,
"ingest_settings": ingest_settings,
},
)
manifest_path = run_dir / "run_artifact.json"
manifest_path.write_text(
json.dumps({
"suite": self.suite,
"benchmark": self.name,
"raw_path": "raw.jsonl",
"metrics": metrics,
"extra": artifact.extra,
}, indent=2, sort_keys=True) + "\n",
encoding="utf-8",
)
return artifact
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
if not artifacts:
return ReportSection(
title="MMLongBench-Doc — Native PDF (vision) vs SurfSense (vision RAG)",
headline=True,
body_md="(no run artifacts found)",
body_json={},
)
latest = max(artifacts, key=lambda a: a.run_timestamp)
m = latest.metrics
native = m.get("native", {})
surf = m.get("surfsense", {})
delta = m.get("delta", {})
per_format = m.get("per_format", {})
extra = latest.extra
body_lines: list[str] = []
body_lines.append(
f"- Sample size: {extra.get('n_questions', '?')} questions "
f"(format filter: `{extra.get('format_filter', 'all')}`, "
f"skip-unanswerable: `{extra.get('skip_unanswerable', False)}`, "
f"engine: `{extra.get('pdf_engine', 'native')}`)."
)
body_lines.append(format_scenario_md(extra))
body_lines.append(format_ingest_settings_md(extra.get("ingest_settings")))
body_lines.append(
"- Native arm (OpenRouter `chat/completions` + file plugin, "
f"`{extra.get('native_arm_model') or extra.get('provider_model', '?')}`):"
)
body_lines.append(_arm_summary_lines(native, indent=" "))
body_lines.append(
"- SurfSense arm (`POST /api/v1/new_chat`, vision RAG over chunks, "
f"`{extra.get('provider_model', '?')}`):"
)
body_lines.append(_arm_summary_lines(surf, indent=" "))
body_lines.append("- Delta (paired):")
body_lines.append(
f" - Accuracy: SurfSense {_pp(delta.get('accuracy_pp'))} pp "
f"(McNemar p={_fmt(delta.get('mcnemar_p_value'), 4)}, "
f"method={delta.get('mcnemar_method')})"
)
body_lines.append(
f" - F1 (mean): SurfSense {_pp(delta.get('f1_pp'))} pp"
)
body_lines.append(
f" - Bootstrap 95% CI on accuracy delta: "
f"[{_pp(delta.get('bootstrap_ci_low'))}pp, {_pp(delta.get('bootstrap_ci_high'))}pp]"
)
body_lines.append(
f" - Cost / question: native ${_dollars(native.get('cost_micros_mean'))}, "
f"surfsense ${_dollars(surf.get('cost_micros_mean'))} "
f"(SurfSense delta {_pct_change(delta.get('cost_micros_pct'))})"
)
body_lines.append(
f" - Latency p50: native {_ms_to_s(native.get('latency_ms_median'))}, "
f"surfsense {_ms_to_s(surf.get('latency_ms_median'))} "
f"(SurfSense delta {_pct_change(delta.get('latency_ms_pct'))})"
)
if per_format:
body_lines.append("- Per-format split (accuracy delta in pp):")
for fmt, vals in sorted(per_format.items()):
body_lines.append(
f" - {fmt}: SurfSense {_pp(vals.get('delta_accuracy_pp'))} pp "
f"(n={vals.get('n')}, native acc={vals.get('native_accuracy', 0)*100:.1f}%, "
f"surf acc={vals.get('surfsense_accuracy', 0)*100:.1f}%)"
)
return ReportSection(
title="MMLongBench-Doc — Native PDF (vision) vs SurfSense (vision RAG)",
headline=True,
body_md="\n".join(body_lines),
body_json=m,
)
# ---------------------------------------------------------------------------
# Per-question helpers
# ---------------------------------------------------------------------------
def _make_native_request(q: MMLBQuestion, max_tokens: int) -> ArmRequest:
prompt = build_prompt(q.question, answer_format=q.answer_format)
return ArmRequest(
question_id=q.qid,
prompt=prompt,
pdf_paths=[q.pdf_path],
options={"max_tokens": max_tokens},
)
def _make_surfsense_request(q: MMLBQuestion, *, no_mentions: bool) -> ArmRequest:
prompt = build_prompt(q.question, answer_format=q.answer_format)
mentions: list[int] | None = None
if not no_mentions and q.document_id is not None:
mentions = [int(q.document_id)]
return ArmRequest(
question_id=q.qid,
prompt=prompt,
mentioned_document_ids=mentions,
)
def _grade_one(q: MMLBQuestion, result: ArmResult) -> GradeResult:
pred_text = extract_freeform_answer(result.raw_text or "")
return grade(pred=pred_text, gold=q.gold_answer, answer_format=q.answer_format)
def _grade_to_jsonl(g: GradeResult) -> dict[str, Any]:
return {
"correct": g.correct,
"f1": g.f1,
"method": g.method,
"normalised_pred": g.normalised_pred,
"normalised_gold": g.normalised_gold,
}
# ---------------------------------------------------------------------------
# Metrics aggregation
# ---------------------------------------------------------------------------
def _compute_metrics(
questions: list[MMLBQuestion],
native_results: list[ArmResult],
surf_results: list[ArmResult],
native_grades: list[GradeResult],
surf_grades: list[GradeResult],
) -> dict[str, Any]:
native_correct = [g.correct for g in native_grades]
surf_correct = [g.correct for g in surf_grades]
native_f1 = [g.f1 for g in native_grades]
surf_f1 = [g.f1 for g in surf_grades]
native_costs = [float(r.cost_micros) for r in native_results]
surf_costs = [float(r.cost_micros) for r in surf_results]
native_latencies = [float(r.latency_ms) for r in native_results]
surf_latencies = [float(r.latency_ms) for r in surf_results]
native_in_tokens = [float(r.input_tokens) for r in native_results]
native_out_tokens = [float(r.output_tokens) for r in native_results]
native_acc = accuracy_with_wilson_ci(sum(native_correct), len(native_correct))
surf_acc = accuracy_with_wilson_ci(sum(surf_correct), len(surf_correct))
mc = mcnemar_test(native_correct, surf_correct)
boot = bootstrap_delta_ci(native_correct, surf_correct, n_resamples=2000)
native_cost_agg = paired_aggregate(native_costs)
surf_cost_agg = paired_aggregate(surf_costs)
native_latency_agg = paired_aggregate(native_latencies)
surf_latency_agg = paired_aggregate(surf_latencies)
cost_pct = _safe_pct(surf_cost_agg.mean, native_cost_agg.mean)
latency_pct = _safe_pct(surf_latency_agg.median, native_latency_agg.median)
per_format_pairs: dict[str, list[tuple[bool, bool]]] = {}
for q, n_ok, s_ok in zip(questions, native_correct, surf_correct, strict=False):
per_format_pairs.setdefault(q.answer_format or "unknown", []).append((n_ok, s_ok))
per_format: dict[str, dict[str, Any]] = {}
for fmt, pairs in per_format_pairs.items():
n_correct = [a for a, _ in pairs]
s_correct = [b for _, b in pairs]
per_format[fmt] = {
"n": len(pairs),
"native_accuracy": (sum(n_correct) / len(pairs)) if pairs else 0.0,
"surfsense_accuracy": (sum(s_correct) / len(pairs)) if pairs else 0.0,
"delta_accuracy_pp": (
100.0 * (sum(s_correct) - sum(n_correct)) / len(pairs)
if pairs else 0.0
),
}
native_f1_mean = sum(native_f1) / len(native_f1) if native_f1 else 0.0
surf_f1_mean = sum(surf_f1) / len(surf_f1) if surf_f1 else 0.0
return {
"native": {
**native_acc.to_dict(),
"f1_mean": native_f1_mean,
"cost_micros_mean": native_cost_agg.mean,
"cost_micros_median": native_cost_agg.median,
"latency_ms_mean": native_latency_agg.mean,
"latency_ms_median": native_latency_agg.median,
"latency_ms_p95": native_latency_agg.p95,
"input_tokens_mean": (sum(native_in_tokens) / len(native_in_tokens)) if native_in_tokens else 0.0,
"output_tokens_mean": (sum(native_out_tokens) / len(native_out_tokens)) if native_out_tokens else 0.0,
},
"surfsense": {
**surf_acc.to_dict(),
"f1_mean": surf_f1_mean,
"cost_micros_mean": surf_cost_agg.mean,
"cost_micros_median": surf_cost_agg.median,
"latency_ms_mean": surf_latency_agg.mean,
"latency_ms_median": surf_latency_agg.median,
"latency_ms_p95": surf_latency_agg.p95,
},
"delta": {
"accuracy_pp": 100.0 * (surf_acc.accuracy - native_acc.accuracy),
"f1_pp": 100.0 * (surf_f1_mean - native_f1_mean),
"mcnemar_p_value": mc.p_value,
"mcnemar_method": mc.method,
"mcnemar_b_native_only": mc.b,
"mcnemar_c_surfsense_only": mc.c,
"bootstrap_ci_low": 100.0 * boot.ci_low,
"bootstrap_ci_high": 100.0 * boot.ci_high,
"cost_micros_pct": cost_pct,
"latency_ms_pct": latency_pct,
},
"per_format": per_format,
}
def _safe_pct(numerator: float, denominator: float) -> float | None:
if denominator == 0:
return None
return 100.0 * (numerator - denominator) / denominator
# ---------------------------------------------------------------------------
# Tiny formatting helpers used by report_section
# ---------------------------------------------------------------------------
def _arm_summary_lines(d: dict[str, Any], *, indent: str) -> str:
if not d:
return f"{indent}(no data)"
acc = d.get("accuracy", 0.0)
low = d.get("ci_low", 0.0)
high = d.get("ci_high", 0.0)
f1 = d.get("f1_mean", 0.0)
lines = [
f"{indent}- Accuracy: {acc * 100:.1f}% (Wilson 95% CI: {low * 100:.1f}% {high * 100:.1f}%)",
f"{indent}- F1 (token-level mean): {f1 * 100:.1f}%",
f"{indent}- Cost / question: ${_dollars(d.get('cost_micros_mean'))} (mean), "
f"${_dollars(d.get('cost_micros_median'))} (median)",
f"{indent}- Latency: p50 {_ms_to_s(d.get('latency_ms_median'))}, "
f"p95 {_ms_to_s(d.get('latency_ms_p95'))}",
]
if "input_tokens_mean" in d:
lines.append(
f"{indent}- Mean tokens / question: in {d.get('input_tokens_mean', 0):.0f}, "
f"out {d.get('output_tokens_mean', 0):.0f}"
)
return "\n".join(lines)
def _dollars(micros: Any) -> str:
if micros is None:
return "?"
try:
return f"{(float(micros) / 1_000_000):.4f}"
except (TypeError, ValueError):
return "?"
def _ms_to_s(ms: Any) -> str:
if ms is None:
return "?"
try:
return f"{float(ms) / 1000:.1f}s"
except (TypeError, ValueError):
return "?"
def _pp(value: Any) -> str:
if value is None:
return "?"
try:
return f"{float(value):+.1f}"
except (TypeError, ValueError):
return "?"
def _pct_change(value: Any) -> str:
if value is None:
return "?"
try:
return f"{float(value):+.0f}%"
except (TypeError, ValueError):
return "?"
def _fmt(value: Any, ndigits: int) -> str:
if value is None:
return "?"
try:
return f"{float(value):.{ndigits}f}"
except (TypeError, ValueError):
return "?"
__all__ = ["MMLBQuestion", "MMLongBenchDocBenchmark"]

View file

@ -0,0 +1,18 @@
"""Research / multi-document RAG benchmarks.
Distinct from ``multimodal_doc`` (PDF-bound) and ``medical`` (one
question = one source PDF). Benchmarks here put *retrieval and
reasoning across many documents* in the critical path the regime
where SurfSense's chunk-level RAG should shine versus "pour the
entire document into the LLM" or "ask the LLM cold".
* ``frames`` (google/frames-benchmark) 824 multi-hop Wikipedia
questions; tests bare-LLM vs SurfSense over a shared ~330-doc
corpus.
* ``crag`` (facebookresearch/CRAG, KDD Cup 2024) 2,706 web QA
pairs with 5 pre-retrieved HTML pages each; tests bare-LLM vs
long-context-stuffed LLM vs SurfSense over the question's 5
scoped pages the closest comparison to a competing RAG product.
"""
from __future__ import annotations

View file

@ -0,0 +1,57 @@
"""CRAG — Comprehensive RAG Benchmark (Yang et al., Meta, KDD Cup 2024).
Source: https://github.com/facebookresearch/CRAG (Tasks 1, 2, and 3)
Paper: https://arxiv.org/abs/2406.04744
This package registers two siblings:
* ``crag`` Tasks 1 & 2: 5 candidate pages per question.
* ``crag_t3`` Task 3: 50 candidate pages per question. The
long-context arm is capped to the top-5 (the realistic "naive
RAG = pick top-K results" baseline); SurfSense retrieves over
all 50, where its rerank becomes the entire contribution.
Both share the grader, prompt, runner, and report code; only the
ingest path differs (single bz2 vs 4-part tar.bz2 streamed).
CRAG ships ~2,706 factual QA pairs, each paired with **5 full HTML
pages** retrieved as the top-5 of a real web search at ``query_time``
(50 in Task 3).
The benchmark spans 5 domains (finance, music, movie, sports, open)
and 8 question types (simple, comparison, aggregation, set, multi-hop,
post-processing, false_premise, simple_w_condition) heads/torsos/
tails of entity popularity and an explicit staticreal-time
freshness axis.
Why CRAG demonstrates SurfSense more clearly than FRAMES
--------------------------------------------------------
FRAMES tested SurfSense vs. *no retrieval at all* a fair "naive
prompting" baseline (the published 40.8% number) but not a competing
RAG product. CRAG enables a three-way comparison:
* ``bare_llm`` chat completion with the question only. CRAG
paper: 34% accuracy ("LLM cold").
* ``long_context`` stuff all 5 extracted page texts straight into
the prompt (the "naive RAG" / "straightforward RAG" arm in the
paper). Published baseline: ~44%.
* ``surfsense`` POST ``/api/v1/new_chat`` with retrieval scoped
to the question's 5 ingested pages (``mentioned_document_ids``).
So the headline becomes "SurfSense vs. context-stuffed long-context
LLM, both fed the same 5 pages" — a head-to-head against the simplest
realistic RAG strategy, not against an unarmed model.
Scoring follows the CRAG paper: each prediction is graded as
**correct** (+1), **missing/I-don't-know** (0), or **incorrect** (-1),
and the headline metric is the *Truthfulness Score*:
``(#correct - #incorrect) / total`` — penalising hallucinations
relative to refusals.
"""
from __future__ import annotations
from ....core import registry as _registry
from .runner import CragBenchmark, CragTask3Benchmark
_registry.register(CragBenchmark())
_registry.register(CragTask3Benchmark())

View file

@ -0,0 +1,335 @@
"""CRAG dataset loader — download ``crag_task_1_and_2_dev_v4.jsonl.bz2`` and parse.
The CRAG repo (``facebookresearch/CRAG``) ships Tasks 1 & 2 as a
single bzip2-compressed JSONL on GitHub raw. Each row carries:
* ``interaction_id`` opaque per-question id (we keep verbatim)
* ``query_time`` wall clock of the original web search
* ``domain`` finance | music | movie | sports | open
* ``question_type`` simple | comparison | aggregation | set |
multi-hop | post-processing | false_premise |
simple_w_condition
* ``static_or_dynamic`` static | slow-changing | fast-changing | real-time
* ``query`` the question
* ``answer`` gold short answer
* ``alt_ans`` list[str] of alternative valid answers
(paraphrases / synonyms / unit variants)
* ``split`` 0 = validation, 1 = public test
* ``popularity`` head | torso | tail (KG questions); empty for web
* ``search_results`` list of up to 5 ``{page_name, page_url,
page_snippet, page_result, page_last_modified}``;
``page_result`` is full HTML.
We materialise this into ``CragQuestion`` objects keeping ``pages`` as
a list of ``CragPage`` so downstream ingest can save each as its own
file and SurfSense can dedupe on filename.
"""
from __future__ import annotations
import bz2
import hashlib
import io
import json
import logging
import urllib.request
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
# Tasks 1 & 2 share the same JSONL on the public CRAG repo.
CRAG_TASK_1_2_URL = (
"https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/"
"crag_task_1_and_2_dev_v4.jsonl.bz2"
)
CRAG_TASK_1_2_FILENAME = "crag_task_1_and_2_dev_v4.jsonl.bz2"
# ---------------------------------------------------------------------------
# Question / page dataclasses
# ---------------------------------------------------------------------------
@dataclass
class CragPage:
"""One of the up-to-5 pre-retrieved web pages for a CRAG question."""
page_name: str
page_url: str
page_snippet: str
page_html: str
page_last_modified: str | None = None
@property
def url_hash(self) -> str:
"""Stable 12-hex digest of the page URL for filename keys.
We can't use the raw URL as a filename (slashes, query strings,
unicode), and we *do* want collision-safety across the whole
ingest sample. ``sha1[:12]`` gives us 48 bits of namespace
which is overkill for a corpus capped at a few thousand pages.
"""
return hashlib.sha1(self.page_url.encode("utf-8")).hexdigest()[:12]
@dataclass
class CragQuestion:
"""One row of CRAG (Tasks 1 & 2)."""
qid: str # synthesised "C00000".."C02705"
interaction_id: str
query_time: str
query: str
gold_answer: str
alt_answers: list[str]
domain: str
question_type: str
static_or_dynamic: str
popularity: str # may be "" for web-sourced questions
split: int # 0=validation, 1=public_test
raw_index: int # row index in the source JSONL
pages: list[CragPage] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return {
"qid": self.qid,
"interaction_id": self.interaction_id,
"query_time": self.query_time,
"query": self.query,
"gold_answer": self.gold_answer,
"alt_answers": list(self.alt_answers),
"domain": self.domain,
"question_type": self.question_type,
"static_or_dynamic": self.static_or_dynamic,
"popularity": self.popularity,
"split": self.split,
"raw_index": self.raw_index,
"n_pages": len(self.pages),
"page_urls": [p.page_url for p in self.pages],
}
# ---------------------------------------------------------------------------
# Download + decompress
# ---------------------------------------------------------------------------
def download_task_1_2(cache_dir: Path) -> Path:
"""Download the bz2 archive into ``cache_dir`` (skip if cached).
Returns the path to the local ``.jsonl.bz2``. We use stdlib
``urllib`` rather than ``httpx`` to keep the download synchronous
and trivially resumable (re-running the function is a no-op once
the file is on disk and non-empty).
"""
cache_dir.mkdir(parents=True, exist_ok=True)
dest = cache_dir / CRAG_TASK_1_2_FILENAME
if dest.exists() and dest.stat().st_size > 0:
logger.debug("CRAG bz2 already cached at %s", dest)
return dest
logger.info("Downloading CRAG (Tasks 1 & 2) from %s ...", CRAG_TASK_1_2_URL)
tmp = dest.with_suffix(dest.suffix + ".part")
req = urllib.request.Request(
CRAG_TASK_1_2_URL,
headers={"User-Agent": "SurfSense-Evals/0.1 (CRAG dataset fetch)"},
)
with urllib.request.urlopen(req, timeout=600) as response, tmp.open("wb") as fh:
chunk = response.read(1 << 20)
while chunk:
fh.write(chunk)
chunk = response.read(1 << 20)
tmp.replace(dest)
logger.info("CRAG bz2 downloaded: %s (%.1f MiB)", dest, dest.stat().st_size / 1024 / 1024)
return dest
# ---------------------------------------------------------------------------
# Parse
# ---------------------------------------------------------------------------
def _parse_pages(raw_search_results: Any) -> list[CragPage]:
if not isinstance(raw_search_results, list):
return []
pages: list[CragPage] = []
for entry in raw_search_results:
if not isinstance(entry, dict):
continue
url = str(entry.get("page_url") or "").strip()
html = str(entry.get("page_result") or "")
if not url or not html.strip():
# No URL or empty HTML => useless for retrieval.
continue
pages.append(CragPage(
page_name=str(entry.get("page_name") or "").strip(),
page_url=url,
page_snippet=str(entry.get("page_snippet") or "").strip(),
page_html=html,
page_last_modified=(
str(entry.get("page_last_modified")).strip()
if entry.get("page_last_modified") else None
),
))
return pages
def _parse_alt_answers(raw: Any) -> list[str]:
if isinstance(raw, list):
return [str(x).strip() for x in raw if str(x).strip()]
if isinstance(raw, str) and raw.strip():
return [raw.strip()]
return []
def iter_questions(jsonl_bz2_path: Path) -> list[CragQuestion]:
"""Stream-decompress + parse the CRAG JSONL into ``CragQuestion`` objects.
The bz2 expansion ratio is ~10x and the decompressed file is
multi-GB; we therefore decompress *line by line* via
``bz2.open(..., "rt")``. Each row is a single (potentially very
large, due to embedded HTML) JSON object. We keep the entire row
in memory because we materialise the pages to disk immediately
after parsing in the ingest pipeline the runner never holds
more than the current sample's worth of HTML.
"""
out: list[CragQuestion] = []
with bz2.open(jsonl_bz2_path, mode="rt", encoding="utf-8") as fh:
for raw_idx, line in enumerate(fh):
line = line.strip()
if not line:
continue
try:
row = json.loads(line)
except json.JSONDecodeError as exc:
logger.warning("Skipping malformed CRAG row %d: %s", raw_idx, exc)
continue
query = str(row.get("query") or "").strip()
answer = str(row.get("answer") or "").strip()
if not query or not answer:
logger.debug("Skipping CRAG row %d with missing query/answer", raw_idx)
continue
interaction_id = str(row.get("interaction_id") or "").strip()
pages = _parse_pages(row.get("search_results"))
out.append(CragQuestion(
qid=f"C{raw_idx:05d}",
interaction_id=interaction_id,
query_time=str(row.get("query_time") or "").strip(),
query=query,
gold_answer=answer,
alt_answers=_parse_alt_answers(row.get("alt_ans")),
domain=str(row.get("domain") or "").strip().lower(),
question_type=str(row.get("question_type") or "").strip().lower(),
static_or_dynamic=str(row.get("static_or_dynamic") or "").strip().lower(),
popularity=str(row.get("popularity") or "").strip().lower(),
split=int(row.get("split") or 0),
raw_index=raw_idx,
pages=pages,
))
return out
def stratified_sample(
questions: list[CragQuestion],
*,
n: int,
seed: int = 17,
) -> list[CragQuestion]:
"""Take ``n`` questions that roughly preserve the domain × question-type mix.
CRAG is only ~2.7k rows so naive head-of-list sampling badly
over-weights ``finance`` (because the dataset isn't shuffled by
domain). We bucket on ``(domain, question_type)`` and round-robin
pick from each bucket until we hit ``n`` this gives every
bucket a fair shot and keeps the sample composition stable across
re-runs (deterministic via the seeded shuffle inside each bucket).
"""
if n <= 0 or n >= len(questions):
return list(questions)
import random
rng = random.Random(seed)
buckets: dict[tuple[str, str], list[CragQuestion]] = {}
for q in questions:
buckets.setdefault((q.domain, q.question_type), []).append(q)
for items in buckets.values():
rng.shuffle(items)
keys = sorted(buckets.keys())
chosen: list[CragQuestion] = []
cursor = 0
while len(chosen) < n and any(buckets[k] for k in keys):
key = keys[cursor % len(keys)]
cursor += 1
if buckets[key]:
chosen.append(buckets[key].pop())
chosen.sort(key=lambda q: q.raw_index)
return chosen
def write_questions_jsonl(questions: list[CragQuestion], dest: Path) -> None:
"""Persist a parsed copy (without page HTML) under the benchmark data dir."""
dest.parent.mkdir(parents=True, exist_ok=True)
with dest.open("w", encoding="utf-8") as fh:
for q in questions:
fh.write(json.dumps(q.to_dict()) + "\n")
# ---------------------------------------------------------------------------
# Reading the lightweight questions.jsonl back
# ---------------------------------------------------------------------------
def load_questions_jsonl(path: Path) -> list[dict[str, Any]]:
"""Re-load the lightweight (no-HTML) questions JSONL from disk."""
out: list[dict[str, Any]] = []
if not path.exists():
return out
with path.open("r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line:
continue
try:
out.append(json.loads(line))
except json.JSONDecodeError:
continue
return out
# ---------------------------------------------------------------------------
# Convenience: decompress a snippet to memory for tests
# ---------------------------------------------------------------------------
def decompress_to_memory(jsonl_bz2_path: Path) -> io.StringIO:
"""For tests / one-off scripts: read the whole bz2 into a StringIO.
Avoids leaking gigabytes; use ``iter_questions`` in production.
"""
with bz2.open(jsonl_bz2_path, mode="rb") as fh:
return io.StringIO(fh.read().decode("utf-8"))
__all__ = [
"CRAG_TASK_1_2_FILENAME",
"CRAG_TASK_1_2_URL",
"CragPage",
"CragQuestion",
"decompress_to_memory",
"download_task_1_2",
"iter_questions",
"load_questions_jsonl",
"stratified_sample",
"write_questions_jsonl",
]

View file

@ -0,0 +1,263 @@
"""CRAG Task 3 dataset loader — 4-part tar.bz2 → streaming JSONL.
Task 3 ships ~7 GB of compressed data split into 4 parts on GitHub:
crag_task_3_dev_v4.tar.bz2.part1 (2 GB)
crag_task_3_dev_v4.tar.bz2.part2 (2 GB)
crag_task_3_dev_v4.tar.bz2.part3 (2 GB)
crag_task_3_dev_v4.tar.bz2.part4 (1.3 GB)
Concatenated, they form a tar archive containing a single JSONL file.
Decompressed, that JSONL is on the order of 30-50 GB because each row
embeds 50 full HTML pages (vs 5 in Tasks 1 & 2).
Materialising the JSONL would blow the disk budget (we have ~50 GB
free at the time of writing), so we stream the whole thing instead:
1. Download parts (idempotent; ``scripts/download_crag_task3.py``).
2. Concat them into a virtual file via ``_MultiPartReader``.
3. Wrap in ``bz2.BZ2File`` for on-the-fly decompression.
4. Wrap in ``tarfile.open(fileobj=..., mode="r|")`` for streaming
tar member iteration.
5. For the JSONL member inside, ``tar.extractfile()`` returns a
binary file-like; we iterate lines and yield parsed dicts.
The caller can ``break`` out as soon as they have enough samples
nothing past the consumed point is decompressed.
Schema is identical to Tasks 1 & 2 (see ``dataset.py``); only
``search_results`` is bigger (50 entries instead of 5).
"""
from __future__ import annotations
import bz2
import json
import logging
import tarfile
from collections.abc import Iterator
from pathlib import Path
from typing import IO
from .dataset import (
CragPage,
CragQuestion,
_parse_alt_answers,
_parse_pages,
)
logger = logging.getLogger(__name__)
CRAG_TASK_3_PART_URLS: tuple[str, ...] = tuple(
"https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/"
f"crag_task_3_dev_v4.tar.bz2.part{i}"
for i in (1, 2, 3, 4)
)
CRAG_TASK_3_PART_NAMES: tuple[str, ...] = tuple(
f"crag_task_3_dev_v4.tar.bz2.part{i}" for i in (1, 2, 3, 4)
)
# ---------------------------------------------------------------------------
# Multi-part virtual file (concatenates N files transparently)
# ---------------------------------------------------------------------------
class _MultiPartReader:
"""Read N files end-to-end as if they were one big file.
Implements just enough of the file protocol for ``bz2.BZ2File``
to consume it: ``read(n)``, ``readable()``, ``close()``.
Doesn't implement ``seek`` — the bz2 + tarfile streaming path
is forward-only, which is what we want here.
"""
def __init__(self, paths: list[Path]) -> None:
if not paths:
raise ValueError("_MultiPartReader needs at least one path")
for p in paths:
if not p.exists():
raise FileNotFoundError(p)
self._paths = list(paths)
self._idx = 0
self._fh: IO[bytes] | None = self._paths[0].open("rb")
self._closed = False
def read(self, n: int = -1) -> bytes:
if self._closed:
raise ValueError("read of closed _MultiPartReader")
if n is None or n < 0:
chunks: list[bytes] = []
while self._fh is not None:
chunks.append(self._fh.read())
self._advance()
return b"".join(chunks)
out: list[bytes] = []
remaining = n
while remaining > 0 and self._fh is not None:
chunk = self._fh.read(remaining)
if not chunk:
self._advance()
continue
out.append(chunk)
remaining -= len(chunk)
return b"".join(out)
def _advance(self) -> None:
if self._fh is not None:
self._fh.close()
self._fh = None
self._idx += 1
if self._idx < len(self._paths):
self._fh = self._paths[self._idx].open("rb")
def readable(self) -> bool:
return not self._closed
def close(self) -> None:
if self._fh is not None:
self._fh.close()
self._fh = None
self._closed = True
def __enter__(self) -> _MultiPartReader:
return self
def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[no-untyped-def]
self.close()
# ---------------------------------------------------------------------------
# Stream the JSONL inside the tar.bz2
# ---------------------------------------------------------------------------
def _is_jsonl_member(name: str) -> bool:
return name.endswith(".jsonl") or name.endswith(".jsonl.txt")
def iter_questions_task3(
parts_dir: Path,
*,
max_questions: int | None = None,
) -> list[CragQuestion]:
"""Stream-parse Task 3 rows into ``CragQuestion`` objects.
The Task 3 archive ships its 2,706 questions sharded across
multiple JSONL files inside the tar (e.g.
``crag_task_3_dev_v4_0.jsonl``, ``..._1.jsonl``, ). We iterate
members in-stream, parse every JSONL one we encounter, and stop
as soon as ``max_questions`` is reached at which point we
don't decompress any further members.
For a typical n=50 sample at ~3 MB per row we touch ~150 MB of
decompressed JSONL almost always inside the first shard.
"""
parts = [parts_dir / name for name in CRAG_TASK_3_PART_NAMES]
multi = _MultiPartReader(parts)
bz = bz2.BZ2File(multi, mode="rb")
tar = tarfile.open(fileobj=bz, mode="r|")
out: list[CragQuestion] = []
raw_idx = 0
found_jsonl = False
try:
for member in tar:
if not member.isfile() or not _is_jsonl_member(member.name):
continue
found_jsonl = True
logger.info(
"CRAG Task 3: streaming JSONL shard %s (size: %d bytes)",
member.name, member.size,
)
fh = tar.extractfile(member)
if fh is None:
logger.warning("tar.extractfile returned None for %s; skipping", member.name)
continue
try:
for raw_line in fh:
line = raw_line.decode("utf-8", errors="replace").strip()
if not line:
continue
try:
row = json.loads(line)
except json.JSONDecodeError as exc:
logger.warning(
"Skipping malformed CRAG Task 3 row %d in %s: %s",
raw_idx, member.name, exc,
)
raw_idx += 1
continue
query = str(row.get("query") or "").strip()
answer = str(row.get("answer") or "").strip()
if not query or not answer:
raw_idx += 1
continue
out.append(CragQuestion(
qid=f"T3_{raw_idx:05d}",
interaction_id=str(row.get("interaction_id") or "").strip(),
query_time=str(row.get("query_time") or "").strip(),
query=query,
gold_answer=answer,
alt_answers=_parse_alt_answers(row.get("alt_ans")),
domain=str(row.get("domain") or "").strip().lower(),
question_type=str(row.get("question_type") or "").strip().lower(),
static_or_dynamic=str(row.get("static_or_dynamic") or "").strip().lower(),
popularity=str(row.get("popularity") or "").strip().lower(),
split=int(row.get("split") or 0),
raw_index=raw_idx,
pages=_parse_pages(row.get("search_results")),
))
raw_idx += 1
if max_questions is not None and len(out) >= max_questions:
return out
finally:
try:
fh.close()
except Exception: # noqa: BLE001
pass
if not found_jsonl:
raise RuntimeError(
"No JSONL member found inside Task 3 tar.bz2 archive; "
"schema may have changed upstream."
)
finally:
try:
tar.close()
except Exception: # noqa: BLE001
pass
try:
bz.close()
except Exception: # noqa: BLE001
pass
try:
multi.close()
except Exception: # noqa: BLE001
pass
return out
def parts_present(parts_dir: Path) -> bool:
"""``True`` iff all 4 parts exist on disk and are non-empty."""
for name in CRAG_TASK_3_PART_NAMES:
p = parts_dir / name
if not p.exists() or p.stat().st_size == 0:
return False
return True
# ---------------------------------------------------------------------------
# Re-exports for convenience
# ---------------------------------------------------------------------------
__all__ = [
"CRAG_TASK_3_PART_NAMES",
"CRAG_TASK_3_PART_URLS",
"CragPage",
"CragQuestion",
"iter_questions_task3",
"parts_present",
]

View file

@ -0,0 +1,540 @@
"""CRAG 3-class grader: ``correct`` (+1) / ``missing`` (0) / ``incorrect`` (-1).
The CRAG paper's headline metric is the **Truthfulness Score**:
score = (#correct - #incorrect) / total
which rewards calibrated abstention refusing to answer is neutral
(0), guessing wrong is negative (-1). Grading is therefore a 3-class
problem rather than the 2-class accuracy used for FRAMES.
Pipeline per (pred, gold, alt_ans, question_type):
1. Detect refusal first (``Answer: I don't know`` / "I don't know" /
"no information") ``missing`` (deterministic, never billed).
2. ``false_premise`` questions: gold is canonically "the question
contains a false premise" — reward any answer that flags the
false premise (substring "false premise" / "incorrect premise" /
"no such") as correct.
3. Run the FRAMES-style deterministic shortcut (exact / numeric /
substring) on ``pred`` against ``gold alt_ans``. Hit correct.
4. Fall through to the LLM judge (if configured), which returns one
of ``{correct, missing, incorrect}`` verbatim CRAG protocol.
5. No judge configured record ``incorrect`` (pessimistic but at
least monotone with the deterministic grader).
The judge is throttled by an asyncio.Semaphore so it doesn't outrun
the OpenRouter rate limit; the pre-judge deterministic pass keeps
the bill bounded (most easy "Beyoncé"-vs-"Beyoncé Knowles" cases
short-circuit before we burn judge tokens).
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import string
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Literal
from ....core.providers.openrouter_chat import OpenRouterChatProvider
logger = logging.getLogger(__name__)
GradeClass = Literal["correct", "missing", "incorrect"]
# ---------------------------------------------------------------------------
# Public type
# ---------------------------------------------------------------------------
@dataclass
class CragGradeResult:
"""One graded (pred, gold) pair under CRAG's 3-class rubric."""
grade: GradeClass
score: int # +1 / 0 / -1
method: str # exact, numeric, substring, refusal,
# false_premise_correct, false_premise_miss,
# llm_judge, lexical_miss, ...
normalised_pred: str = ""
normalised_gold: str = ""
judge_rationale: str = ""
@property
def correct(self) -> bool:
return self.grade == "correct"
@property
def missing(self) -> bool:
return self.grade == "missing"
@property
def incorrect(self) -> bool:
return self.grade == "incorrect"
def to_dict(self) -> dict[str, Any]:
return {
"grade": self.grade,
"score": self.score,
"method": self.method,
"normalised_pred": self.normalised_pred,
"normalised_gold": self.normalised_gold,
"judge_rationale": self.judge_rationale,
}
def _grade_to_score(grade: GradeClass) -> int:
return {"correct": 1, "missing": 0, "incorrect": -1}[grade]
# ---------------------------------------------------------------------------
# Normalisation
# ---------------------------------------------------------------------------
_PUNCT_TABLE = str.maketrans({c: " " for c in string.punctuation})
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.IGNORECASE)
_WS = re.compile(r"\s+")
def _normalise(s: str) -> str:
s = (s or "").lower()
s = s.translate(_PUNCT_TABLE)
s = _ARTICLES.sub(" ", s)
s = _WS.sub(" ", s).strip()
return s
_WORD_NUMBERS = {
"zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
"six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11,
"twelve": 12, "thirteen": 13, "fourteen": 14, "fifteen": 15, "sixteen": 16,
"seventeen": 17, "eighteen": 18, "nineteen": 19, "twenty": 20,
}
_NUMERIC_RE = re.compile(r"-?\d+(?:[.,]\d+)?")
def _maybe_number(s: str) -> float | None:
"""Extract a single numeric value from raw lowercased text."""
raw = (s or "").strip().lower()
if not raw:
return None
match = _NUMERIC_RE.search(raw)
if match:
try:
return float(match.group(0).replace(",", ""))
except ValueError:
pass
for tok in _normalise(s).split():
if tok in _WORD_NUMBERS:
return float(_WORD_NUMBERS[tok])
return None
def _whole_word_substring(haystack: str, needle: str) -> bool:
if not needle:
return False
return f" {needle} " in f" {haystack} "
# ---------------------------------------------------------------------------
# Refusal detection
# ---------------------------------------------------------------------------
_REFUSAL_PATTERNS = [
re.compile(r"\bi\s+don'?t\s+know\b", re.IGNORECASE),
re.compile(r"\bi\s+do\s+not\s+know\b", re.IGNORECASE),
re.compile(r"\bnot\s+enough\s+information\b", re.IGNORECASE),
re.compile(r"\binsufficient\s+information\b", re.IGNORECASE),
re.compile(r"\bcannot\s+(?:be\s+)?(?:answered|determined)\b", re.IGNORECASE),
re.compile(r"\bunable\s+to\s+(?:answer|determine)\b", re.IGNORECASE),
re.compile(r"\bno\s+(?:information|data|evidence)\b", re.IGNORECASE),
]
def _is_refusal(pred: str) -> bool:
"""Cheap deterministic check for "I don't know" -shaped responses."""
if not pred or not pred.strip():
return True # empty answer is a de facto refusal
return any(p.search(pred) for p in _REFUSAL_PATTERNS)
# ---------------------------------------------------------------------------
# False-premise handling
# ---------------------------------------------------------------------------
_FALSE_PREMISE_PATTERNS = [
re.compile(r"false\s+premise", re.IGNORECASE),
re.compile(r"incorrect\s+premise", re.IGNORECASE),
re.compile(r"premise\s+(?:is|of)\s+the\s+question", re.IGNORECASE),
re.compile(r"\bno\s+such\b", re.IGNORECASE),
re.compile(r"never\s+(?:happened|occurred|existed)", re.IGNORECASE),
re.compile(r"\bdid\s+not\s+(?:happen|occur|exist)\b", re.IGNORECASE),
re.compile(r"\bdoes\s+not\s+exist\b", re.IGNORECASE),
re.compile(r"is\s+not\s+(?:true|correct|accurate)", re.IGNORECASE),
re.compile(r"\bisn'?t\s+(?:true|correct|accurate)\b", re.IGNORECASE),
re.compile(r"\binvalid\s+(?:premise|question|assumption)\b", re.IGNORECASE),
]
def _flags_false_premise(pred: str) -> bool:
return any(p.search(pred) for p in _FALSE_PREMISE_PATTERNS)
# ---------------------------------------------------------------------------
# Deterministic grader
# ---------------------------------------------------------------------------
def grade_deterministic(
*,
pred: str,
gold: str,
alt_answers: Sequence[str] = (),
question_type: str = "",
) -> CragGradeResult:
"""Try to grade without the LLM judge. Returns a final result.
Always returns *some* result the caller checks ``method`` to
decide whether the LLM judge should overturn it. ``lexical_miss``
and ``false_premise_unclear`` are the two methods that trigger the
judge fallback.
"""
qtype = (question_type or "").lower()
n_pred = _normalise(pred)
n_gold = _normalise(gold)
if _is_refusal(pred):
# CRAG protocol: refusal is *missing* (0), even on false-premise
# questions where one might argue refusal == correct. We
# follow the paper's grading literally.
return CragGradeResult(
grade="missing",
score=0,
method="refusal",
normalised_pred=n_pred,
normalised_gold=n_gold,
)
# Empty-gold guard (shouldn't happen, but defensively):
if not n_gold:
return CragGradeResult(
grade="incorrect",
score=-1,
method="empty_gold",
normalised_pred=n_pred,
normalised_gold=n_gold,
)
# False-premise questions: gold is typically "the question contains
# a false premise" / "no such X" / similar. Any answer that
# explicitly flags the false premise is correct.
if qtype == "false_premise":
if _flags_false_premise(pred):
return CragGradeResult(
grade="correct",
score=1,
method="false_premise_flagged",
normalised_pred=n_pred,
normalised_gold=n_gold,
)
# If the model commits to *any* concrete answer on a false-
# premise question without flagging the premise, it is wrong.
# But we don't classify ourselves — let the judge decide on
# the off chance the gold itself is e.g. "no" and the pred
# is "no" without explicit "false premise" wording.
return CragGradeResult(
grade="incorrect",
score=-1,
method="false_premise_unclear",
normalised_pred=n_pred,
normalised_gold=n_gold,
)
# All non-false-premise questions: try the standard chain against
# gold and each alt answer. First match wins.
candidates = [gold, *list(alt_answers)]
for candidate in candidates:
if not candidate or not str(candidate).strip():
continue
cand_norm = _normalise(candidate)
if not cand_norm:
continue
if n_pred == cand_norm:
return CragGradeResult(
grade="correct", score=1, method="exact",
normalised_pred=n_pred, normalised_gold=cand_norm,
)
p_num = _maybe_number(pred)
c_num = _maybe_number(candidate)
if p_num is not None and c_num is not None:
# Pure 1% relative tolerance for CRAG (currency, counts,
# ratios). Unlike FRAMES (which uses a 0.5 absolute floor
# for year-shaped answers), CRAG's numeric questions are
# often small-value (stock prices, percentages) where a
# 0.5 floor would let "$2.05" match "$2.17". The judge is
# the safety net for borderline rounding cases.
tol = abs(c_num) * 0.01
if abs(p_num - c_num) <= tol:
return CragGradeResult(
grade="correct", score=1, method="numeric",
normalised_pred=n_pred, normalised_gold=cand_norm,
)
# Numeric question with different numbers — keep looking
# at other candidates rather than declaring miss now;
# alt answers may include word forms that pass.
if _whole_word_substring(n_pred, cand_norm):
return CragGradeResult(
grade="correct", score=1, method="substring",
normalised_pred=n_pred, normalised_gold=cand_norm,
)
if _whole_word_substring(cand_norm, n_pred) and len(n_pred) >= 3:
return CragGradeResult(
grade="correct", score=1, method="substring_reverse",
normalised_pred=n_pred, normalised_gold=cand_norm,
)
return CragGradeResult(
grade="incorrect",
score=-1,
method="lexical_miss",
normalised_pred=n_pred,
normalised_gold=n_gold,
)
# ---------------------------------------------------------------------------
# LLM-as-judge (3-class)
# ---------------------------------------------------------------------------
_JUDGE_SYSTEM = (
"You are an impartial grader for short-answer factual questions, "
"following the CRAG benchmark rubric. Given a question, the gold "
"answer (and any alternative valid answers), and a model's "
"prediction, classify the prediction into exactly one of three "
"categories:\n\n"
"* \"correct\" — the prediction expresses the same factual "
"content as the gold answer (paraphrasing OK; numbers as words "
"OK; partial-but-correct names OK; non-contradictory extra "
"detail OK).\n"
"* \"missing\" — the prediction explicitly refuses, says \"I "
"don't know\", says there is insufficient information, or hedges "
"without committing.\n"
"* \"incorrect\" — the prediction commits to a fact that is "
"different from the gold answer, or fails to flag a false "
"premise when the question contains one.\n\n"
"Special case: if the question contains a false premise and the "
"gold answer says so, then a prediction that flags the false "
"premise is \"correct\".\n\n"
"Respond with ONLY a JSON object on a single line:\n"
'{\"grade\": \"correct\"|\"missing\"|\"incorrect\", \"rationale\": \"<one short sentence>\"}'
)
_JUDGE_TEMPLATE = """\
Question: {question}
Question type: {question_type}
Gold answer: {gold}
{alt_block}Model prediction: {pred}
Decide whether the prediction is correct, missing, or incorrect.
"""
@dataclass
class CragJudgeConfig:
api_key: str
model: str = "anthropic/claude-sonnet-4.5"
base_url: str = "https://openrouter.ai/api/v1"
max_tokens: int = 200
concurrency: int = 4
class CragLlmJudge:
"""Async LLM judge over OpenRouter chat completions, 3-class output."""
def __init__(self, *, config: CragJudgeConfig) -> None:
self._config = config
self._provider = OpenRouterChatProvider(
api_key=config.api_key,
base_url=config.base_url,
model=config.model,
)
self._sem = asyncio.Semaphore(max(1, config.concurrency))
@property
def model(self) -> str:
return self._config.model
async def judge(
self,
*,
question: str,
gold: str,
alt_answers: Sequence[str],
pred: str,
question_type: str = "",
) -> tuple[GradeClass, str]:
"""Return ``(grade, rationale)``. Errors return incorrect + reason."""
alt_block = ""
if alt_answers:
alt_lines = "\n".join(f" - {a}" for a in alt_answers if a)
if alt_lines:
alt_block = f"Alternative valid answers:\n{alt_lines}\n"
prompt = _JUDGE_TEMPLATE.format(
question=question,
question_type=question_type or "unknown",
gold=gold,
alt_block=alt_block,
pred=pred,
)
try:
async with self._sem:
response = await self._provider.complete(
prompt=prompt,
system_prompt=_JUDGE_SYSTEM,
max_tokens=self._config.max_tokens,
)
except Exception as exc: # noqa: BLE001
return "incorrect", f"judge_error: {type(exc).__name__}: {exc}"
return _parse_judge_response(response.text)
def _parse_judge_response(text: str) -> tuple[GradeClass, str]:
"""Parse the judge reply into a 3-class label + rationale."""
if not text or not text.strip():
return "incorrect", "judge_returned_empty"
match = re.search(r"\{[^{}]*\}", text, flags=re.DOTALL)
candidate = match.group(0) if match else text
try:
data = json.loads(candidate)
except (json.JSONDecodeError, ValueError):
lowered = text.strip().lower()
if "correct" in lowered and "incorrect" not in lowered:
return "correct", "yes (parser_fallback)"
if "missing" in lowered or "i don" in lowered:
return "missing", "missing (parser_fallback)"
return "incorrect", f"unparseable_judge_response: {text[:200]}"
raw_grade = str(data.get("grade") or "").strip().lower()
rationale = str(data.get("rationale", "")).strip()[:280]
if raw_grade in {"correct", "missing", "incorrect"}:
return raw_grade, rationale # type: ignore[return-value]
return "incorrect", f"unknown_grade={raw_grade!r}; {rationale}"
# ---------------------------------------------------------------------------
# Combined grader
# ---------------------------------------------------------------------------
# Methods that should *not* trigger the LLM judge — the deterministic
# verdict is conclusive (refusal, exact match, numeric mismatch, etc.).
_TERMINAL_METHODS = frozenset({
"refusal",
"exact",
"numeric",
"substring",
"substring_reverse",
"false_premise_flagged",
"empty_gold",
})
async def grade_with_judge(
*,
pred: str,
gold: str,
alt_answers: Sequence[str],
question: str,
question_type: str,
judge: CragLlmJudge | None,
) -> CragGradeResult:
"""One row → deterministic shortcut → optional LLM judge fallback."""
det = grade_deterministic(
pred=pred,
gold=gold,
alt_answers=alt_answers,
question_type=question_type,
)
if det.method in _TERMINAL_METHODS:
return det
if judge is None:
return det # ``lexical_miss`` / ``false_premise_unclear`` → keep as-is
grade, rationale = await judge.judge(
question=question,
gold=gold,
alt_answers=alt_answers,
pred=pred,
question_type=question_type,
)
return CragGradeResult(
grade=grade,
score=_grade_to_score(grade),
method="llm_judge",
normalised_pred=det.normalised_pred,
normalised_gold=det.normalised_gold,
judge_rationale=rationale,
)
@dataclass
class CragGradeRow:
"""One row to grade. Mirrors the FRAMES grader's tuple but typed."""
qid: str
question: str
gold: str
alt_answers: list[str]
pred: str
question_type: str = ""
async def grade_many(
*,
rows: Sequence[CragGradeRow],
judge: CragLlmJudge | None,
) -> list[CragGradeResult]:
"""Grade every row concurrently. Judge enforces its own concurrency cap."""
if not rows:
return []
coros = [
grade_with_judge(
pred=r.pred,
gold=r.gold,
alt_answers=r.alt_answers,
question=r.question,
question_type=r.question_type,
judge=judge,
)
for r in rows
]
return list(await asyncio.gather(*coros))
__all__ = [
"CragGradeResult",
"CragGradeRow",
"CragJudgeConfig",
"CragLlmJudge",
"GradeClass",
"grade_deterministic",
"grade_many",
"grade_with_judge",
]

View file

@ -0,0 +1,206 @@
"""HTML → markdown for CRAG pages, with boilerplate removal.
Each CRAG page is a *full* HTML document (nav, ads, recommended-for-
you, footer, ...). Without removing that boilerplate, retrieval over
the chunks would surface menu items and "subscribe to our newsletter"
boxes instead of the actual page content. We use ``trafilatura``,
which is purpose-built for main-content extraction (the same library
Common Crawl downstream pipelines use). It outputs clean prose with
section headers, lists, and tables preserved.
Extraction policy:
1. ``trafilatura.extract`` with ``output_format="markdown"`` main
content only, headers preserved, tables kept.
2. If extraction fails or returns < 200 chars (paywalled / JS-only
page / extraction confused), fall back to a plain stdlib
``HTMLParser`` that strips tags and collapses whitespace. Some
text is better than no text SurfSense's chunker handles noisy
prose.
We *intentionally* keep the page name and URL as visible H1 / link
metadata so the SurfSense chunker preserves doc identity at the top of
the first chunk (mirrors what we do for FRAMES Wikipedia pages).
"""
from __future__ import annotations
import html
import logging
import re
from dataclasses import dataclass
from html.parser import HTMLParser
logger = logging.getLogger(__name__)
_MIN_TRAFILATURA_LENGTH = 200
_MAX_OUTPUT_CHARS = 200_000 # cap to keep upload payloads sane
@dataclass
class ExtractionResult:
"""Outcome of converting one HTML blob to plain markdown."""
text: str
method: str # "trafilatura" | "fallback_strip" | "empty"
n_chars: int
@property
def ok(self) -> bool:
return self.n_chars > 0
# ---------------------------------------------------------------------------
# Trafilatura wrapper (lazy import so tests / small scripts don't pay)
# ---------------------------------------------------------------------------
def _trafilatura_extract(html_text: str, *, url: str) -> str | None:
try:
import trafilatura
except ImportError: # pragma: no cover - dependency is required
logger.warning("trafilatura not installed; falling back to strip-tags only")
return None
try:
text = trafilatura.extract(
html_text,
url=url or None,
output_format="markdown",
include_links=False,
include_images=False,
include_tables=True,
favor_recall=True,
)
except Exception as exc: # noqa: BLE001 - trafilatura raises a zoo
logger.debug("trafilatura.extract crashed for %s: %s", url, exc)
return None
if not text:
return None
return text.strip()
# ---------------------------------------------------------------------------
# Stdlib fallback: strip HTML tags
# ---------------------------------------------------------------------------
class _StripHTMLParser(HTMLParser):
"""Collect text content, treating block tags as paragraph breaks.
We deliberately drop ``<script>``, ``<style>``, ``<nav>``,
``<header>``, ``<footer>``, and ``<aside>`` content these are
almost always boilerplate and they are the dominant source of
noise SurfSense ends up retrieving against if not removed.
"""
_SKIP_TAGS = frozenset({"script", "style", "nav", "header", "footer", "aside", "svg"})
_BLOCK_TAGS = frozenset({
"p", "div", "section", "article", "li", "ul", "ol",
"h1", "h2", "h3", "h4", "h5", "h6", "br", "tr",
"td", "th", "table", "blockquote", "pre",
})
def __init__(self) -> None:
super().__init__(convert_charrefs=True)
self._buffer: list[str] = []
self._skip_depth: int = 0
def handle_starttag(self, tag: str, attrs: list) -> None: # noqa: ARG002
if tag in self._SKIP_TAGS:
self._skip_depth += 1
if tag in self._BLOCK_TAGS:
self._buffer.append("\n")
def handle_endtag(self, tag: str) -> None:
if tag in self._SKIP_TAGS and self._skip_depth > 0:
self._skip_depth -= 1
if tag in self._BLOCK_TAGS:
self._buffer.append("\n")
def handle_data(self, data: str) -> None:
if self._skip_depth:
return
self._buffer.append(data)
def get_text(self) -> str:
text = "".join(self._buffer)
# Decode any leftover entities and collapse whitespace.
text = html.unescape(text)
text = re.sub(r"[ \t]+", " ", text)
text = re.sub(r"\n[ \t]+", "\n", text)
text = re.sub(r"\n{3,}", "\n\n", text)
return text.strip()
def _strip_tags(html_text: str) -> str:
parser = _StripHTMLParser()
try:
parser.feed(html_text)
except Exception as exc: # noqa: BLE001 - HTMLParser is fragile on garbage input
logger.debug("HTMLParser failed; using regex strip: %s", exc)
no_tags = re.sub(r"<[^>]+>", " ", html_text)
return re.sub(r"\s+", " ", html.unescape(no_tags)).strip()
return parser.get_text()
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def extract_main_content(
html_text: str,
*,
url: str = "",
page_name: str = "",
last_modified: str | None = None,
) -> ExtractionResult:
"""Convert one HTML blob into clean markdown for ingest.
The returned ``text`` is prefixed with a small metadata header
(``# {page_name}\\n\\nSource: {url}\\n``) so that:
* SurfSense's chunker has a stable doc-identity anchor at the top
of the first chunk (matches what we do for FRAMES Wikipedia).
* The retrieval-augmented arm sees the URL inline, which the LLM
can surface as a citation if the prompt asks for one.
"""
body = ""
method = "empty"
if html_text and html_text.strip():
body = _trafilatura_extract(html_text, url=url) or ""
if body and len(body) >= _MIN_TRAFILATURA_LENGTH:
method = "trafilatura"
else:
stripped = _strip_tags(html_text)
# Prefer trafilatura output even if short, but only if it
# contained any prose at all — empty trafilatura fall-through
# to the stripped form.
if body and stripped and len(stripped) > len(body) * 1.5:
body = stripped
method = "fallback_strip"
elif not body and stripped:
body = stripped
method = "fallback_strip"
elif body:
method = "trafilatura"
body = body.strip()
if len(body) > _MAX_OUTPUT_CHARS:
body = body[:_MAX_OUTPUT_CHARS] + "\n\n[...truncated...]"
if not body:
return ExtractionResult(text="", method="empty", n_chars=0)
title_line = (page_name or url or "Untitled").strip()
header_lines = [f"# {title_line}"]
if url:
header_lines.append(f"Source: {url}")
if last_modified:
header_lines.append(f"Last modified: {last_modified}")
final = "\n".join(header_lines) + "\n\n" + body + "\n"
return ExtractionResult(text=final, method=method, n_chars=len(final))
__all__ = ["ExtractionResult", "extract_main_content"]

View file

@ -0,0 +1,447 @@
"""CRAG ingestion: download → extract → upload → per-question doc map.
Steps:
1. Download ``crag_task_1_and_2_dev_v4.jsonl.bz2`` from
``facebookresearch/CRAG`` (skip if cached).
2. Stream-parse into ``CragQuestion`` objects.
3. Optionally cap to ``--n-questions N`` (and *stratified* sample
across ``(domain, question_type)`` so the smoke / partial run
isn't dominated by ``finance`` or ``simple``).
4. For each question, extract the 5 web pages to clean markdown via
``trafilatura`` and write them to
``<bench_dir>/pages/<qid>__<page_idx>__<url_hash>.md``. The
filename is unique across the whole sample (so SurfSense's
``(filename, search_space)`` dedup never collides between
questions) and round-trippable (the ``<qid>__`` prefix lets the
ingest infer doc-membership at the title level even before we
land on a stable status response).
5. Upload all extracted pages to SurfSense in batches with text-only
ETL (``use_vision_llm=False, processing_mode="basic"``) these
are extracted plaintext, no images involved.
6. Persist a doc map at
``<suite_data>/maps/crag_doc_map.jsonl`` with one row per question:
{"qid": "C00042",
"interaction_id": "<uuid>",
"question": "<text>",
"gold_answer": "<text>",
"alt_answers": [...],
"domain": "...", "question_type": "...",
"static_or_dynamic": "...", "popularity": "...",
"query_time": "...",
"page_filenames": ["C00042__0__abc123.md", ...],
"document_ids": [42101, 42102, ...],
"missing_pages": [...] # filenames whose upload failed
}
The runner uses ``document_ids`` to scope SurfSense retrieval to
exactly the 5 pages of the question (matches CRAG protocol the
benchmark explicitly hands over its own retrieved pages).
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from ....core.clients.documents import (
DocumentProcessingFailed,
DocumentProcessingTimeout,
)
from ....core.config import set_suite_state
from ....core.ingest_settings import IngestSettings, settings_header_line
from ....core.registry import RunContext
from .dataset import (
CragPage,
CragQuestion,
download_task_1_2,
iter_questions,
stratified_sample,
write_questions_jsonl,
)
from .html_extract import extract_main_content
logger = logging.getLogger(__name__)
_FILENAME_SAFE = re.compile(r"[^A-Za-z0-9._\-]+")
def _page_filename(qid: str, page_idx: int, page: CragPage) -> str:
"""Filesystem-safe, globally unique markdown filename for a CRAG page.
Format: ``<qid>__<idx>__<url_hash>.md``. Both the qid (``C00042``)
and the URL-hash (``[:12]``) are alphanumeric so we don't need to
sanitise them, but we strip anything else just in case.
"""
qid_safe = _FILENAME_SAFE.sub("_", qid)
return f"{qid_safe}__{page_idx:02d}__{page.url_hash}.md"
# ---------------------------------------------------------------------------
# Stats
# ---------------------------------------------------------------------------
@dataclass
class _IngestStats:
n_questions: int
n_pages_total: int
n_pages_extracted: int
n_pages_empty: int
n_uploaded: int
n_existing: int
bench_dir: Path
map_path: Path
# ---------------------------------------------------------------------------
# Page extraction
# ---------------------------------------------------------------------------
def _materialise_pages(
questions: list[CragQuestion],
*,
pages_dir: Path,
overwrite: bool = False,
) -> tuple[dict[str, list[str]], dict[str, str]]:
"""Extract every page in every question to ``pages_dir`` as markdown.
Returns:
* ``qid -> [filename, filename, ...]`` (in page order, only
successful extractions)
* ``filename -> source_url`` for diagnostics
Empty extractions (paywall / JS / parse-fail with no fallback
output) are skipped better to retrieve from 4 pages than feed
SurfSense's chunker an empty file.
"""
pages_dir.mkdir(parents=True, exist_ok=True)
qid_to_files: dict[str, list[str]] = {}
file_to_url: dict[str, str] = {}
method_counts: dict[str, int] = {}
n_empty = 0
for q in questions:
names: list[str] = []
for idx, page in enumerate(q.pages):
filename = _page_filename(q.qid, idx, page)
dest = pages_dir / filename
if dest.exists() and dest.stat().st_size > 0 and not overwrite:
method_counts["cache_hit"] = method_counts.get("cache_hit", 0) + 1
names.append(filename)
file_to_url[filename] = page.page_url
continue
result = extract_main_content(
page.page_html,
url=page.page_url,
page_name=page.page_name,
last_modified=page.page_last_modified,
)
method_counts[result.method] = method_counts.get(result.method, 0) + 1
if not result.ok:
n_empty += 1
continue
dest.write_text(result.text, encoding="utf-8")
names.append(filename)
file_to_url[filename] = page.page_url
qid_to_files[q.qid] = names
logger.info(
"CRAG page extraction: %s; empty=%d, total_files=%d across %d questions",
method_counts, n_empty, len(file_to_url), len(qid_to_files),
)
return qid_to_files, file_to_url
# ---------------------------------------------------------------------------
# Upload
# ---------------------------------------------------------------------------
async def _upload_pages(
ctx: RunContext,
*,
pages_dir: Path,
filenames: list[str],
batch_size: int,
settings: IngestSettings,
) -> dict[str, int]:
"""Upload ``filenames`` (already on disk under ``pages_dir``) and return name → doc_id."""
if not filenames:
return {}
docs_client = ctx.documents_client()
name_to_id: dict[str, int] = {}
paths = [pages_dir / fn for fn in filenames if (pages_dir / fn).exists()]
for batch_start in range(0, len(paths), batch_size):
batch = paths[batch_start : batch_start + batch_size]
result = await docs_client.upload(
files=batch,
search_space_id=ctx.search_space_id,
should_summarize=settings.should_summarize,
use_vision_llm=settings.use_vision_llm,
processing_mode=settings.processing_mode,
)
all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
if result.document_ids:
try:
await docs_client.wait_until_ready(
search_space_id=ctx.search_space_id,
document_ids=result.document_ids,
timeout_s=900.0,
)
except (DocumentProcessingFailed, DocumentProcessingTimeout) as exc:
logger.warning("CRAG batch processing issue: %s", exc)
if all_ids:
statuses = await docs_client.get_status(
search_space_id=ctx.search_space_id,
document_ids=all_ids,
)
for s in statuses:
stem = Path(s.title).stem if s.title.endswith(".md") else s.title
name_to_id[stem] = s.document_id
name_to_id[s.title] = s.document_id
if not s.title.endswith(".md"):
name_to_id[f"{s.title}.md"] = s.document_id
logger.info(
"CRAG upload batch %d-%d: %d new, %d duplicate",
batch_start, batch_start + len(batch),
len(result.document_ids), len(result.duplicate_document_ids),
)
return name_to_id
# ---------------------------------------------------------------------------
# Doc map writer
# ---------------------------------------------------------------------------
def _resolve_question_doc_ids(
questions: list[CragQuestion],
qid_to_files: dict[str, list[str]],
name_to_id: dict[str, int],
) -> list[dict[str, Any]]:
rows: list[dict[str, Any]] = []
for q in questions:
filenames = qid_to_files.get(q.qid, [])
doc_ids: list[int] = []
missing: list[str] = []
for fn in filenames:
stem = Path(fn).stem
doc_id = name_to_id.get(stem) or name_to_id.get(fn)
if doc_id is not None and doc_id not in doc_ids:
doc_ids.append(doc_id)
else:
missing.append(fn)
rows.append({
"qid": q.qid,
"interaction_id": q.interaction_id,
"raw_index": q.raw_index,
"question": q.query,
"gold_answer": q.gold_answer,
"alt_answers": list(q.alt_answers),
"domain": q.domain,
"question_type": q.question_type,
"static_or_dynamic": q.static_or_dynamic,
"popularity": q.popularity,
"query_time": q.query_time,
"split": q.split,
"page_filenames": filenames,
"document_ids": doc_ids,
"missing_pages": missing,
"n_pages": len(filenames),
})
return rows
# ---------------------------------------------------------------------------
# Public entry
# ---------------------------------------------------------------------------
async def run_ingest(
ctx: RunContext,
*,
n_questions: int | None = None,
upload_batch_size: int = 16,
skip_upload: bool = False,
overwrite_extract: bool = False,
settings: IngestSettings | None = None,
sample_seed: int = 17,
) -> None:
"""Ingest the CRAG benchmark (Tasks 1 & 2) into the research suite.
Parameters
----------
n_questions
Cap on the number of CRAG questions to materialise.
``None`` = all 2,706 (~13,500 pages large; smoke runs
should pass 10-20 and full runs ~200).
upload_batch_size
Markdown files per ``/documents/fileupload`` call.
skip_upload
Extract + cache markdown locally but don't push to SurfSense
(useful for debugging the extraction step).
overwrite_extract
Re-run trafilatura even when a cached markdown file exists.
Default False so re-running ingest is idempotent.
settings
Override per-upload knobs. CRAG defaults to text-only basic
ETL these are *extracted* plaintext, no images.
sample_seed
RNG seed for ``stratified_sample``. Pin this for reproducibility.
"""
settings = settings or IngestSettings(
use_vision_llm=False,
processing_mode="basic",
should_summarize=False,
)
bench_dir = ctx.benchmark_data_dir()
pages_dir = bench_dir / "pages"
raw_cache = bench_dir / ".raw_cache"
raw_cache.mkdir(parents=True, exist_ok=True)
bz2_path = download_task_1_2(raw_cache)
logger.info("CRAG: parsing %s ...", bz2_path.name)
all_questions = iter_questions(bz2_path)
if not all_questions:
raise RuntimeError(
"CRAG JSONL contained no parseable rows; upstream may have changed schema."
)
logger.info("CRAG: parsed %d total questions", len(all_questions))
if n_questions is not None and n_questions > 0:
questions = stratified_sample(all_questions, n=n_questions, seed=sample_seed)
logger.info(
"CRAG: stratified sample of %d questions across %d (domain, qtype) buckets",
len(questions),
len({(q.domain, q.question_type) for q in questions}),
)
else:
questions = all_questions
questions_jsonl = bench_dir / "questions.jsonl"
write_questions_jsonl(questions, questions_jsonl)
n_pages_total = sum(len(q.pages) for q in questions)
logger.info(
"CRAG: extracting up to %d pages across %d questions ...",
n_pages_total, len(questions),
)
qid_to_files, file_to_url = _materialise_pages(
questions, pages_dir=pages_dir, overwrite=overwrite_extract,
)
n_pages_extracted = sum(len(v) for v in qid_to_files.values())
name_to_id: dict[str, int] = {}
if skip_upload:
logger.info("CRAG: --skip-upload; skipping SurfSense ingestion")
else:
all_filenames = sorted({fn for fns in qid_to_files.values() for fn in fns})
logger.info("CRAG: uploading %d unique pages ...", len(all_filenames))
name_to_id = await _upload_pages(
ctx,
pages_dir=pages_dir,
filenames=all_filenames,
batch_size=upload_batch_size,
settings=settings,
)
doc_rows = _resolve_question_doc_ids(questions, qid_to_files, name_to_id)
map_path = ctx.maps_dir() / "crag_doc_map.jsonl"
with map_path.open("w", encoding="utf-8") as fh:
fh.write(settings_header_line(settings) + "\n")
for row in doc_rows:
fh.write(json.dumps(row) + "\n")
logger.info("Wrote CRAG doc map to %s (%d rows)", map_path, len(doc_rows))
new_state = ctx.suite_state
new_state.ingestion_maps["crag"] = str(map_path)
set_suite_state(ctx.config, ctx.suite, new_state)
stats = _IngestStats(
n_questions=len(questions),
n_pages_total=n_pages_total,
n_pages_extracted=n_pages_extracted,
n_pages_empty=n_pages_total - n_pages_extracted,
n_uploaded=len(name_to_id),
n_existing=0,
bench_dir=bench_dir,
map_path=map_path,
)
logger.info("CRAG ingest done: %s", stats)
# ---------------------------------------------------------------------------
# For runner: read extracted page text back from disk
# ---------------------------------------------------------------------------
def read_page_markdown(bench_dir: Path, filename: str) -> str | None:
"""Return the on-disk markdown body for a previously-extracted page.
Used by the long-context runner arm to assemble the prompt at
inference time we don't keep all 5×N pages in memory between
ingest and run.
"""
path = bench_dir / "pages" / filename
if not path.exists():
return None
try:
return path.read_text(encoding="utf-8")
except OSError:
return None
async def _retry_upload_idempotent( # noqa: D401 - hidden helper
ctx: RunContext,
*,
pages_dir: Path,
filenames: list[str],
batch_size: int,
settings: IngestSettings,
max_attempts: int = 2,
) -> dict[str, int]:
"""Future-proofing hook (unused today): retry the ingest upload pass."""
last_exc: Exception | None = None
for attempt in range(max_attempts):
try:
return await _upload_pages(
ctx,
pages_dir=pages_dir,
filenames=filenames,
batch_size=batch_size,
settings=settings,
)
except Exception as exc: # noqa: BLE001
last_exc = exc
logger.warning("CRAG upload attempt %d failed: %s", attempt + 1, exc)
await asyncio.sleep(2.0 * (attempt + 1))
if last_exc is not None:
raise last_exc
return {}
__all__ = [
"_IngestStats",
"_materialise_pages",
"_page_filename",
"_resolve_question_doc_ids",
"_upload_pages",
"read_page_markdown",
"run_ingest",
]

View file

@ -0,0 +1,191 @@
"""CRAG Task 3 ingestion: 4-part download → streaming JSONL → upload.
Same flow as ``ingest.run_ingest`` for Tasks 1 & 2 (extract HTML
upload markdown resolve doc_ids write doc map), but:
* Source: 4 .tar.bz2 parts streamed via ``dataset_task3``.
* Page count: 50 per question instead of 5 the whole point of
Task 3 (the long-context arm now structurally has to choose what
to keep, while SurfSense's retrieval becomes mandatory).
* Stratified sampling re-uses the Task 1 helper since the question
schema is identical.
Doc map lands at ``<suite_data>/maps/crag_t3_doc_map.jsonl`` with the
same row shape as Task 1's map (so the runner only needs to know
which file to load; everything else is shared).
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from ....core.config import set_suite_state
from ....core.ingest_settings import IngestSettings, settings_header_line
from ....core.registry import RunContext
from .dataset import stratified_sample, write_questions_jsonl
from .dataset_task3 import (
CRAG_TASK_3_PART_NAMES,
iter_questions_task3,
parts_present,
)
from .ingest import (
_IngestStats,
_materialise_pages,
_resolve_question_doc_ids,
_upload_pages,
)
logger = logging.getLogger(__name__)
_INSTRUCTIONS_TO_DOWNLOAD = (
"Run `python scripts/download_crag_task3.py` first to fetch the "
"4 tar.bz2 parts (~7 GB total) into "
"data/research/crag_t3/.raw_cache/. The downloader is idempotent "
"and parallel."
)
async def run_ingest_task3(
ctx: RunContext,
*,
n_questions: int | None = None,
upload_batch_size: int = 16,
skip_upload: bool = False,
overwrite_extract: bool = False,
settings: IngestSettings | None = None,
sample_seed: int = 17,
parse_cap: int | None = None,
) -> None:
"""Ingest CRAG Task 3 (50 pages per question) into the research suite.
Parameters
----------
n_questions
Cap on the post-stratified-sample question count. ``None`` =
"use whatever ``parse_cap`` produced". For real runs aim for
50 (~2,500 pages) n=200 (10k pages) is doable but slow.
parse_cap
Hard cap on how many rows we *parse* from the streaming
archive before stratified sampling. Defaults to
``max(400, 6*n_questions)`` enough to cover all (domain,
question_type) buckets ~5x but small enough to fit in the
first shard or two (each shard is 5 GB decompressed and
holds ~300 rows; bz2 throughput is ~50 MB/s). Lowering this
is the only knob that bounds streaming cost since we can
``break`` out of the JSONL stream early without decompressing
the rest of the ~50 GB archive body.
upload_batch_size
Markdown files per ``/documents/fileupload`` call.
skip_upload
Extract markdown locally, don't push to SurfSense.
overwrite_extract
Re-run trafilatura even when a cached markdown is present.
settings
Per-upload knobs override (default: text-only basic ETL).
sample_seed
RNG seed for stratified sampling (deterministic).
"""
settings = settings or IngestSettings(
use_vision_llm=False,
processing_mode="basic",
should_summarize=False,
)
bench_dir = ctx.benchmark_data_dir()
pages_dir = bench_dir / "pages"
raw_cache = bench_dir / ".raw_cache"
raw_cache.mkdir(parents=True, exist_ok=True)
if not parts_present(raw_cache):
missing = [
n for n in CRAG_TASK_3_PART_NAMES
if not (raw_cache / n).exists()
]
raise RuntimeError(
f"CRAG Task 3 parts missing from {raw_cache}: {missing}. "
f"{_INSTRUCTIONS_TO_DOWNLOAD}"
)
# 1. Stream-parse (capped). For n=50 we don't need the full 2,706
# rows — just enough that the stratified sampler can balance.
# Each tar shard ~5 GB / ~300 rows / ~2 min decompress, so
# 400-500 rows = shard 0 + a slice of shard 1 ≈ 3-4 min.
parse_cap = parse_cap or (
max(400, 6 * (n_questions or 50)) if n_questions else None
)
logger.info(
"CRAG Task 3: streaming JSONL (parse_cap=%s) ...",
parse_cap if parse_cap else "no-cap",
)
all_questions = iter_questions_task3(raw_cache, max_questions=parse_cap)
logger.info("CRAG Task 3: parsed %d rows", len(all_questions))
if not all_questions:
raise RuntimeError("CRAG Task 3 streaming returned 0 rows; check archive integrity.")
if n_questions is not None and n_questions > 0:
questions = stratified_sample(all_questions, n=n_questions, seed=sample_seed)
logger.info(
"CRAG Task 3: stratified sample of %d questions across %d (domain, qtype) buckets",
len(questions),
len({(q.domain, q.question_type) for q in questions}),
)
else:
questions = all_questions
questions_jsonl = bench_dir / "questions.jsonl"
write_questions_jsonl(questions, questions_jsonl)
n_pages_total = sum(len(q.pages) for q in questions)
logger.info(
"CRAG Task 3: extracting up to %d pages across %d questions ...",
n_pages_total, len(questions),
)
qid_to_files, _file_to_url = _materialise_pages(
questions, pages_dir=pages_dir, overwrite=overwrite_extract,
)
n_pages_extracted = sum(len(v) for v in qid_to_files.values())
name_to_id: dict[str, int] = {}
if skip_upload:
logger.info("CRAG Task 3: --skip-upload; skipping SurfSense ingestion")
else:
all_filenames = sorted({fn for fns in qid_to_files.values() for fn in fns})
logger.info("CRAG Task 3: uploading %d unique pages ...", len(all_filenames))
name_to_id = await _upload_pages(
ctx,
pages_dir=pages_dir,
filenames=all_filenames,
batch_size=upload_batch_size,
settings=settings,
)
doc_rows = _resolve_question_doc_ids(questions, qid_to_files, name_to_id)
map_path = ctx.maps_dir() / "crag_t3_doc_map.jsonl"
with map_path.open("w", encoding="utf-8") as fh:
fh.write(settings_header_line(settings) + "\n")
for row in doc_rows:
fh.write(json.dumps(row) + "\n")
logger.info("Wrote CRAG Task 3 doc map to %s (%d rows)", map_path, len(doc_rows))
new_state = ctx.suite_state
new_state.ingestion_maps["crag_t3"] = str(map_path)
set_suite_state(ctx.config, ctx.suite, new_state)
stats = _IngestStats(
n_questions=len(questions),
n_pages_total=n_pages_total,
n_pages_extracted=n_pages_extracted,
n_pages_empty=n_pages_total - n_pages_extracted,
n_uploaded=len(name_to_id),
n_existing=0,
bench_dir=bench_dir,
map_path=map_path,
)
logger.info("CRAG Task 3 ingest done: %s", stats)
__all__ = ["run_ingest_task3"]

View file

@ -0,0 +1,146 @@
"""CRAG prompt templates for the three competing arms.
The CRAG paper grades each prediction as one of:
* **correct** answer matches gold (with paraphrasing tolerance)
* **missing** model refuses or says "I don't know"
* **incorrect** model commits to a wrong answer (hallucination)
The truthfulness score `(correct - incorrect) / total` rewards
calibrated abstention, so the prompts below explicitly *invite* the
model to refuse when it isn't confident — otherwise the bare-LLM arm
gets penalised twice (no docs *and* a no-refusal prompt) and the
comparison stops being fair to the LLM-only baseline.
Three templates, byte-identical instructions:
* ``build_bare_prompt(q)`` question-only.
* ``build_long_context_prompt(q, contexts)`` question + concatenated
page extracts, all stuffed into the user message. Mirrors the
paper's "straightforward RAG" baseline.
* ``build_surfsense_prompt(q)`` question + a hint that retrieval
over the question's 5 ingested pages is available; the SurfSense
agent itself owns the retrieval step.
The ``Answer:`` line at the end is parsed by ``extract_freeform_answer``
in the runner, so the format is mandatory.
"""
from __future__ import annotations
_BASE_INSTRUCTIONS = (
"You are a careful question-answering assistant. The question is a "
"real-world factual question that may be about finance, music, "
"movies, sports, or any other domain.\n\n"
"Important rules:\n"
"1. If the question contains a false premise (an assumption that "
"is factually wrong), say so explicitly in your final answer "
"rather than answering as if the premise were true.\n"
"2. If you are not confident in an answer, prefer saying \"I don't "
"know\" over guessing. A wrong commit is penalised more than a "
"refusal.\n"
"3. Keep the final answer short — a name, a number, a date, a "
"phrase. Do not repeat the question.\n\n"
"Format your final line EXACTLY as:\n"
"Answer: <short answer>\n"
"If you don't know, write `Answer: I don't know`."
)
_BARE_TEMPLATE = """\
{instructions}
Question: {question}
Question time: {query_time}
"""
_SURFSENSE_TEMPLATE = """\
{instructions}
You have access to a search index of up to 5 web pages that were
retrieved for this question. Use the retrieval tool to look up any
facts you are not confident about. The pages may be partially or fully
relevant; some may contradict each other (prefer the more authoritative
or more recent source).
Question: {question}
Question time: {query_time}
"""
_LONG_CONTEXT_TEMPLATE = """\
{instructions}
You are given the full text of {n_contexts} web pages that were
retrieved for this question. Read all of them, then answer. The
pages may be partially or fully relevant; some may contradict each
other (prefer the more authoritative or more recent source).
{contexts}
Question: {question}
Question time: {query_time}
"""
def build_bare_prompt(question: str, *, query_time: str = "") -> str:
"""Prompt for the no-retrieval baseline arm."""
return _BARE_TEMPLATE.format(
instructions=_BASE_INSTRUCTIONS,
question=question.strip(),
query_time=query_time.strip() or "unknown",
)
def build_surfsense_prompt(question: str, *, query_time: str = "") -> str:
"""Prompt for the SurfSense arm (agent does retrieval itself)."""
return _SURFSENSE_TEMPLATE.format(
instructions=_BASE_INSTRUCTIONS,
question=question.strip(),
query_time=query_time.strip() or "unknown",
)
def build_long_context_prompt(
question: str,
*,
contexts: list[tuple[str, str]],
query_time: str = "",
per_page_char_cap: int = 12_000,
) -> str:
"""Prompt for the "stuff all pages into the prompt" baseline.
``contexts`` is a list of ``(page_title_or_url, page_text)`` pairs.
Each page is truncated at ``per_page_char_cap`` (default 12k chars
3k tokens) so a 5-page CRAG question fits well under any
modern long-context window with room for the question + reasoning.
"""
blocks: list[str] = []
for idx, (title, text) in enumerate(contexts, start=1):
body = (text or "").strip()
if len(body) > per_page_char_cap:
body = body[:per_page_char_cap].rstrip() + "\n[...truncated...]"
title_clean = (title or f"page_{idx}").strip().replace("\n", " ")
blocks.append(
f"--- PAGE {idx}: {title_clean} ---\n{body}\n"
)
contexts_block = "\n".join(blocks) if blocks else "(no pages retrieved)"
return _LONG_CONTEXT_TEMPLATE.format(
instructions=_BASE_INSTRUCTIONS,
n_contexts=len(contexts),
contexts=contexts_block,
question=question.strip(),
query_time=query_time.strip() or "unknown",
)
__all__ = [
"build_bare_prompt",
"build_long_context_prompt",
"build_surfsense_prompt",
]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,29 @@
"""FRAMES — multi-hop Wikipedia retrieval & reasoning (google/frames-benchmark).
Source: https://huggingface.co/datasets/google/frames-benchmark
Paper: https://arxiv.org/abs/2409.12941 (Krishna et al., 2024)
* 824 multi-hop questions, each requiring 2-15 Wikipedia articles
* 5 reasoning types: numerical, tabular, multiple constraints,
temporal, post-processing
* Published Gemini-Pro-1.5 baselines:
- Naive prompting (no retrieval): 40.8%
- BM25, top-4: 47.4%
- Multi-step retrieval & reasoning: 66.0%
- Oracle retrieval (gold articles): 72.9%
This is the benchmark that *finally* puts SurfSense's strongest claim
on trial: cross-document iterative retrieval. The harness ingests
every Wikipedia article referenced by any question in the run sample
into a single SearchSpace; SurfSense answers without
``mentioned_document_ids`` so its agent has to actually retrieve.
The bare-LLM arm answers from the prompt only (the published 40.8%
baseline number).
"""
from __future__ import annotations
from ....core import registry as _registry
from .runner import FramesBenchmark
_registry.register(FramesBenchmark())

View file

@ -0,0 +1,174 @@
"""FRAMES dataset loader — download ``test.tsv`` from HF and parse rows.
The HF repo (``google/frames-benchmark``) ships a single tab-separated
file at ``test.tsv`` (824 rows). Columns of interest for us:
* unnamed first column row index (``id`` we synthesise as ``Q000``..)
* ``Prompt`` the question (free-text, often multi-clause)
* ``Answer`` gold answer (short string: name, number, year, ...)
* ``wikipedia_link_1`` ... ``wikipedia_link_11+`` sparse per-question
link cells (we ignore in favour of the consolidated column below).
* ``reasoning_types`` pipe-separated tags (``"Numerical reasoning |
Tabular reasoning | Multiple constraints"``)
* ``wiki_links`` Python-list literal of every URL the question relies
on, e.g. ``"['https://en.wikipedia.org/wiki/...', '...']"``
We use ``wiki_links`` (already deduplicated per row) and
``ast.literal_eval`` to materialise it. The legacy
``wikipedia_link_*`` columns are kept around only so a curious
operator can compare cell-vs-list if upstream ever drift apart.
"""
from __future__ import annotations
import ast
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
HF_REPO_ID = "google/frames-benchmark"
HF_REPO_TYPE = "dataset"
HF_TEST_FILE = "test.tsv"
def _hf_hub_download(*args: Any, **kwargs: Any) -> str:
from huggingface_hub import hf_hub_download
return hf_hub_download(*args, **kwargs)
# ---------------------------------------------------------------------------
# Question dataclass
# ---------------------------------------------------------------------------
@dataclass
class FramesQuestion:
"""One row of FRAMES (post-parse)."""
qid: str # synthesised "Q000" .. "Q823"
question: str
gold_answer: str
wiki_urls: list[str] # deduped, in original order
reasoning_types: list[str] # split on "|"
raw_index: int # row index from the TSV (for debugging)
def to_dict(self) -> dict[str, Any]:
return {
"qid": self.qid,
"question": self.question,
"gold_answer": self.gold_answer,
"wiki_urls": list(self.wiki_urls),
"reasoning_types": list(self.reasoning_types),
"raw_index": self.raw_index,
}
# ---------------------------------------------------------------------------
# Download + parse
# ---------------------------------------------------------------------------
def download_test_tsv(cache_dir: Path) -> Path:
"""Resumable download of ``test.tsv`` via ``huggingface_hub``."""
cache_dir.mkdir(parents=True, exist_ok=True)
local = _hf_hub_download(
repo_id=HF_REPO_ID,
filename=HF_TEST_FILE,
repo_type=HF_REPO_TYPE,
cache_dir=str(cache_dir),
)
return Path(local)
def _parse_wiki_links(raw: Any) -> list[str]:
"""Convert the ``wiki_links`` cell (Python list literal) to ``list[str]``."""
if not raw:
return []
if isinstance(raw, list):
return [str(x).strip() for x in raw if str(x).strip()]
text = str(raw).strip()
if not text:
return []
try:
parsed = ast.literal_eval(text)
except (SyntaxError, ValueError):
# Fall back: maybe it's a comma-separated string with no quotes.
return [tok.strip() for tok in text.strip("[]").split(",") if tok.strip()]
if isinstance(parsed, (list, tuple)):
return [str(x).strip() for x in parsed if str(x).strip()]
return [str(parsed).strip()]
def _parse_reasoning_types(raw: Any) -> list[str]:
if not raw:
return []
text = str(raw).strip()
if not text:
return []
return [tok.strip() for tok in text.split("|") if tok.strip()]
def load_questions(tsv_path: Path) -> list[FramesQuestion]:
"""Read FRAMES rows from disk into ``FramesQuestion`` objects.
Uses pandas for robust TSV parsing (tabs inside quoted strings are
rare in this dataset but pandas handles them; the stdlib ``csv``
module is fine too if pandas ever becomes a problem). We pin
``index_col=0`` because the upstream TSV uses the first unnamed
column as the row index.
"""
import pandas as pd
df = pd.read_csv(tsv_path, sep="\t", index_col=0, keep_default_na=False)
out: list[FramesQuestion] = []
for raw_idx, row in df.iterrows():
prompt = str(row.get("Prompt") or "").strip()
answer = str(row.get("Answer") or "").strip()
if not prompt or not answer:
logger.debug("Skipping FRAMES row %s with missing Prompt/Answer", raw_idx)
continue
urls = _parse_wiki_links(row.get("wiki_links"))
if not urls:
# Fall back to the per-cell ``wikipedia_link_*`` columns.
urls = []
for col in row.index:
if col.startswith("wikipedia_link"):
val = str(row.get(col) or "").strip()
if val and val not in urls:
urls.append(val)
reasoning = _parse_reasoning_types(row.get("reasoning_types"))
out.append(FramesQuestion(
qid=f"Q{int(raw_idx):03d}",
question=prompt,
gold_answer=answer,
wiki_urls=urls,
reasoning_types=reasoning,
raw_index=int(raw_idx),
))
return out
def write_questions_jsonl(questions: list[FramesQuestion], dest: Path) -> None:
"""Persist a parsed copy under the benchmark data dir."""
dest.parent.mkdir(parents=True, exist_ok=True)
with dest.open("w", encoding="utf-8") as fh:
for q in questions:
fh.write(json.dumps(q.to_dict()) + "\n")
__all__ = [
"FramesQuestion",
"download_test_tsv",
"load_questions",
"write_questions_jsonl",
]

View file

@ -0,0 +1,341 @@
"""FRAMES grader: deterministic shortcut + LLM-as-judge fallback.
FRAMES gold answers are short factoids (a name, a year, an ordinal,
a count). The published paper uses an LLM judge for grading, citing
the long tail of paraphrasing ("Jane Ballou" vs "Mrs. Ballou (Jane)";
"5" vs "five"; "London, England" vs "London"). We replicate that
faithfully *but* avoid burning judge tokens on the obvious cases.
Pipeline per (pred, gold):
1. Normalise both sides (SQuAD-style).
2. If normalised pred == normalised gold CORRECT (``method=exact``).
3. Numeric path: if both extract to a single number and the values
match within 1% relative tolerance CORRECT (``method=numeric``).
4. Substring path: if normalised gold appears as a *whole-word phrase*
inside normalised pred (or vice versa) CORRECT
(``method=substring``).
5. Otherwise call the LLM judge if a judge is wired; the judge
returns yes/no with a one-line rationale.
6. If no judge is configured, fall through to ``False``
(``method=lexical_miss``).
The judge is called *concurrently* across the run via a semaphore (so
it doesn't outrun the upstream rate limit). Cached on
``(arm, qid)`` so re-running ``report`` doesn't re-judge.
Returned shape mirrors ``mmlongbench.grader.GradeResult`` to keep
report writers uniform across benchmarks.
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import string
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from ....core.providers.openrouter_chat import OpenRouterChatProvider
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Public types
# ---------------------------------------------------------------------------
@dataclass
class GradeResult:
"""Shape mirrors mmlongbench.grader.GradeResult for report uniformity."""
correct: bool
f1: float
method: str
normalised_pred: str = ""
normalised_gold: str = ""
judge_rationale: str = ""
def to_dict(self) -> dict[str, Any]:
return {
"correct": self.correct,
"f1": self.f1,
"method": self.method,
"normalised_pred": self.normalised_pred,
"normalised_gold": self.normalised_gold,
"judge_rationale": self.judge_rationale,
}
# ---------------------------------------------------------------------------
# Normalisation
# ---------------------------------------------------------------------------
_PUNCT_TABLE = str.maketrans({c: " " for c in string.punctuation})
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.IGNORECASE)
_WS = re.compile(r"\s+")
def _normalise(s: str) -> str:
s = (s or "").lower()
s = s.translate(_PUNCT_TABLE)
s = _ARTICLES.sub(" ", s)
s = _WS.sub(" ", s).strip()
return s
_WORD_NUMBERS = {
"zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
"six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11,
"twelve": 12, "thirteen": 13, "fourteen": 14, "fifteen": 15, "sixteen": 16,
"seventeen": 17, "eighteen": 18, "nineteen": 19, "twenty": 20,
}
_NUMERIC_RE = re.compile(r"-?\d+(?:[.,]\d+)?")
def _maybe_number(s: str) -> float | None:
"""Extract a single numeric value, recognising digit and word forms.
Operates on the lowercased *raw* text (rather than the
punctuation-stripped normalisation) so that thousands separators
like ``1,234`` are preserved through the regex and parsed
correctly. We only fall back to ``_normalise`` for the word-number
pass, which doesn't care about punctuation.
"""
raw = (s or "").strip().lower()
if not raw:
return None
match = _NUMERIC_RE.search(raw)
if match:
try:
return float(match.group(0).replace(",", ""))
except ValueError:
pass
for tok in _normalise(s).split():
if tok in _WORD_NUMBERS:
return float(_WORD_NUMBERS[tok])
return None
def _whole_word_substring(haystack: str, needle: str) -> bool:
"""Is ``needle`` present as a whole-word phrase in ``haystack``?"""
if not needle:
return False
pad_h = f" {haystack} "
pad_n = f" {needle} "
return pad_n in pad_h
# ---------------------------------------------------------------------------
# Deterministic shortcut
# ---------------------------------------------------------------------------
def grade_deterministic(*, pred: str, gold: str) -> GradeResult:
"""Try to grade without the LLM judge. Returns a final-result object.
A ``False`` result with ``method == "lexical_miss"`` is the signal
to the caller that the LLM judge should be consulted (if available).
"""
if not (pred or "").strip():
return GradeResult(False, 0.0, "empty_pred", "", _normalise(gold))
p = _normalise(pred)
g = _normalise(gold)
if not g:
# Defensively: gold should never be empty; if it is, we can't grade.
return GradeResult(False, 0.0, "empty_gold", p, g)
if p == g:
return GradeResult(True, 1.0, "exact", p, g)
p_num = _maybe_number(pred)
g_num = _maybe_number(gold)
if p_num is not None and g_num is not None:
# 1% relative tolerance, 0.5 absolute floor (handles year-ish answers).
tol = max(abs(g_num) * 0.01, 0.5)
if abs(p_num - g_num) <= tol:
return GradeResult(True, 1.0, "numeric", p, g)
return GradeResult(False, 0.0, "numeric_miss", p, g)
if _whole_word_substring(p, g):
return GradeResult(True, 1.0, "substring", p, g)
if _whole_word_substring(g, p) and len(p) >= 3:
# Be conservative the other direction — only credit if pred is
# at least 3 normalised chars (avoids "John" matching gold
# "John F. Kennedy" as correct).
return GradeResult(True, 1.0, "substring_reverse", p, g)
return GradeResult(False, 0.0, "lexical_miss", p, g)
# ---------------------------------------------------------------------------
# LLM-as-judge
# ---------------------------------------------------------------------------
_JUDGE_SYSTEM = (
"You are an impartial grader for short-answer factual questions. "
"Given a question, the gold answer, and a model's prediction, "
"decide whether the prediction is correct. The prediction is "
"correct if it expresses the same factual content as the gold "
"answer, allowing for paraphrasing, surface-level differences "
"(numbers as words, names with/without titles), and additional "
"non-contradictory detail. The prediction is incorrect if it "
"expresses a different fact, omits the central answer, or hedges "
"without committing.\n\n"
"Respond with ONLY a JSON object on a single line:\n"
'{\"correct\": true|false, \"rationale\": \"<one short sentence>\"}'
)
_JUDGE_TEMPLATE = """\
Question: {question}
Gold answer: {gold}
Model prediction: {pred}
Decide whether the prediction is correct.
"""
@dataclass
class JudgeConfig:
"""Configuration handed to ``LlmJudge`` at construction time."""
api_key: str
model: str = "anthropic/claude-sonnet-4.5"
base_url: str = "https://openrouter.ai/api/v1"
max_tokens: int = 200
concurrency: int = 4
class LlmJudge:
"""Async LLM judge over OpenRouter chat completions."""
def __init__(self, *, config: JudgeConfig) -> None:
self._config = config
self._provider = OpenRouterChatProvider(
api_key=config.api_key,
base_url=config.base_url,
model=config.model,
)
self._sem = asyncio.Semaphore(max(1, config.concurrency))
@property
def model(self) -> str:
return self._config.model
async def judge(
self,
*,
question: str,
gold: str,
pred: str,
) -> tuple[bool, str]:
"""Return ``(is_correct, rationale)``. Errors return False + reason."""
prompt = _JUDGE_TEMPLATE.format(question=question, gold=gold, pred=pred)
try:
async with self._sem:
response = await self._provider.complete(
prompt=prompt,
system_prompt=_JUDGE_SYSTEM,
max_tokens=self._config.max_tokens,
)
except Exception as exc: # noqa: BLE001
return False, f"judge_error: {type(exc).__name__}: {exc}"
return _parse_judge_response(response.text)
def _parse_judge_response(text: str) -> tuple[bool, str]:
"""Pull ``correct`` + ``rationale`` out of the judge's reply."""
if not text or not text.strip():
return False, "judge_returned_empty"
# Accept JSON anywhere in the message; some models prepend prose.
match = re.search(r"\{[^{}]*\}", text, flags=re.DOTALL)
candidate = match.group(0) if match else text
try:
data = json.loads(candidate)
except (json.JSONDecodeError, ValueError):
# Fallback: yes/no parsing.
lowered = text.strip().lower()
if lowered.startswith("yes") or "correct: yes" in lowered or '"correct": true' in lowered:
return True, "yes (parser_fallback)"
if lowered.startswith("no") or "correct: no" in lowered or '"correct": false' in lowered:
return False, "no (parser_fallback)"
return False, f"unparseable_judge_response: {text[:200]}"
correct = bool(data.get("correct"))
rationale = str(data.get("rationale", "")).strip()[:280]
return correct, rationale
# ---------------------------------------------------------------------------
# Combined grader
# ---------------------------------------------------------------------------
async def grade_with_judge(
*,
pred: str,
gold: str,
question: str,
judge: LlmJudge | None,
) -> GradeResult:
"""Grade one row: deterministic shortcut → optional LLM judge fallback."""
det = grade_deterministic(pred=pred, gold=gold)
if det.correct or det.method != "lexical_miss":
return det
if judge is None:
return det
is_correct, rationale = await judge.judge(question=question, gold=gold, pred=pred)
return GradeResult(
correct=is_correct,
f1=1.0 if is_correct else 0.0,
method="llm_judge",
normalised_pred=det.normalised_pred,
normalised_gold=det.normalised_gold,
judge_rationale=rationale,
)
async def grade_many(
*,
rows: Sequence[tuple[str, str, str, str]],
judge: LlmJudge | None,
) -> list[GradeResult]:
"""Grade ``[(qid, question, gold, pred), ...]`` concurrently.
The judge already enforces its own concurrency cap; this just
schedules everything via ``asyncio.gather``. ``qid`` is unused
inside the grader but threaded through so callers can correlate
results back to their rows.
"""
if not rows:
return []
coros = [
grade_with_judge(pred=p, gold=g, question=q, judge=judge)
for _qid, q, g, p in rows
]
return list(await asyncio.gather(*coros))
__all__ = [
"GradeResult",
"JudgeConfig",
"LlmJudge",
"grade_deterministic",
"grade_many",
"grade_with_judge",
]

View file

@ -0,0 +1,341 @@
"""FRAMES ingestion: download → fetch Wikipedia → upload markdown.
Steps:
1. Download ``test.tsv`` from ``hf://datasets/google/frames-benchmark``.
2. Parse rows into ``FramesQuestion`` objects.
3. Optionally cap to the first ``--max-questions N`` so a smoke run
doesn't trigger a 1k-article fetch.
4. Build the **deduplicated** set of Wikipedia URLs across the chosen
sample (questions share many articles Q1 and Q42 might both
reference ``James_A._Garfield``).
5. Fetch each unique article via ``WikiFetcher`` (polite 2 RPS) into
``<bench_dir>/wiki/<title>.md``.
6. Upload the resulting markdown files to SurfSense in batches with
``use_vision_llm=False, processing_mode="basic"`` (text-only no
reason to pay vision LLM costs on Wikipedia plaintext).
7. Persist a doc map at
``<suite_data>/maps/frames_doc_map.jsonl`` with one row per question
listing its ``document_ids`` (so the runner *could* scope retrieval
if requested, though by default we don't — see ``runner.py``).
The doc map row shape:
{"qid": "Q000",
"wiki_titles": ["President of the United States", "James Buchanan", ...],
"document_ids": [123, 124, ...],
"missing_titles": []}
We resolve titles SurfSense document_ids via the post-upload
``DocumentStatus.title`` field. SurfSense's title is the uploaded
filename (without extension), so we round-trip via
``cache_filename_for_title`` to match.
"""
from __future__ import annotations
import asyncio
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from ....core.clients.documents import (
DocumentProcessingFailed,
DocumentProcessingTimeout,
)
from ....core.config import set_suite_state
from ....core.ingest_settings import IngestSettings, settings_header_line
from ....core.registry import RunContext
from .dataset import (
download_test_tsv,
load_questions,
write_questions_jsonl,
)
from .wiki_fetch import (
WikiArticle,
WikiFetcher,
cache_filename_for_title,
title_from_url,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@dataclass
class _IngestStats:
n_questions: int
n_unique_urls: int
n_fetched: int
n_cached_hits: int
n_missing: int
n_uploaded: int
n_existing: int
bench_dir: Path
map_path: Path
async def _fetch_articles(
fetcher: WikiFetcher,
urls: list[str],
) -> tuple[dict[str, WikiArticle], list[str]]:
"""Fetch each URL serially (the WikiFetcher's rate-limiter serialises anyway).
Returns ``(url -> WikiArticle, missing_urls)``. Missing means
Wikipedia reported the title doesn't exist, the URL was non-wiki,
or the API returned an empty extract.
"""
fetched: dict[str, WikiArticle] = {}
missing: list[str] = []
n_total = len(urls)
for i, url in enumerate(urls, start=1):
try:
article = await fetcher.fetch(url)
except Exception as exc: # noqa: BLE001
logger.warning("FRAMES wiki fetch %s failed: %s", url, exc)
missing.append(url)
continue
if article is None:
missing.append(url)
continue
fetched[url] = article
if i % 25 == 0 or i == n_total:
logger.info(" ... fetched %d / %d Wikipedia articles", i, n_total)
return fetched, missing
async def _upload_markdowns(
ctx: RunContext,
articles: list[WikiArticle],
*,
batch_size: int,
settings: IngestSettings,
) -> dict[str, int]:
"""Upload deduplicated markdown files. Returns ``filename -> document_id``.
SurfSense dedupes uploads on ``(filename, search_space_id)``, so
re-running ingest after a crash is idempotent duplicates land in
``duplicate_document_ids`` and we still recover their ids via the
status endpoint.
"""
if not articles:
return {}
docs_client = ctx.documents_client()
name_to_id: dict[str, int] = {}
paths = [a.markdown_path for a in articles]
for batch_start in range(0, len(paths), batch_size):
batch = paths[batch_start : batch_start + batch_size]
result = await docs_client.upload(
files=batch,
search_space_id=ctx.search_space_id,
should_summarize=settings.should_summarize,
use_vision_llm=settings.use_vision_llm,
processing_mode=settings.processing_mode,
)
all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
if result.document_ids:
try:
await docs_client.wait_until_ready(
search_space_id=ctx.search_space_id,
document_ids=result.document_ids,
timeout_s=900.0,
)
except (DocumentProcessingFailed, DocumentProcessingTimeout) as exc:
logger.warning("FRAMES batch processing issue: %s", exc)
if all_ids:
statuses = await docs_client.get_status(
search_space_id=ctx.search_space_id,
document_ids=all_ids,
)
for s in statuses:
# SurfSense stores the uploaded filename as ``title`` (no extension).
stem = Path(s.title).stem if s.title.endswith(".md") else s.title
name_to_id[stem] = s.document_id
name_to_id[s.title] = s.document_id
logger.info(
"FRAMES upload batch %d-%d: %d new, %d duplicate",
batch_start, batch_start + len(batch),
len(result.document_ids), len(result.duplicate_document_ids),
)
return name_to_id
def _resolve_question_doc_ids(
questions: list[Any],
fetched: dict[str, WikiArticle],
name_to_id: dict[str, int],
) -> list[dict[str, Any]]:
"""For each question, list the document_ids of its (fetched) wiki articles."""
rows: list[dict[str, Any]] = []
for q in questions:
doc_ids: list[int] = []
titles: list[str] = []
missing: list[str] = []
for url in q.wiki_urls:
article = fetched.get(url)
if article is None:
missing.append(url)
continue
titles.append(article.title)
stem = Path(cache_filename_for_title(article.title)).stem
doc_id = name_to_id.get(stem) or name_to_id.get(article.markdown_path.name)
if doc_id is not None and doc_id not in doc_ids:
doc_ids.append(doc_id)
rows.append({
"qid": q.qid,
"raw_index": q.raw_index,
"n_wiki_urls": len(q.wiki_urls),
"wiki_titles": titles,
"document_ids": doc_ids,
"missing_urls": missing,
})
return rows
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
async def run_ingest(
ctx: RunContext,
*,
max_questions: int | None = None,
upload_batch_size: int = 16,
skip_upload: bool = False,
fetch_rate_limit_rps: float = 2.0,
settings: IngestSettings | None = None,
) -> None:
"""Ingest the FRAMES benchmark into the research suite.
Parameters
----------
max_questions : int | None
Cap on the number of FRAMES questions to materialise. ``None`` =
all 824 (300+ unique articles). Smoke runs should pass 5-10.
upload_batch_size : int
Markdown files per ``/documents/fileupload`` call. Larger
batches reduce round-trip overhead; smaller batches recover
faster from individual processing failures.
skip_upload : bool
Fetch + cache Wikipedia articles locally but don't push to
SurfSense. Useful for debugging the fetcher in isolation.
fetch_rate_limit_rps : float
Maximum requests-per-second to the Wikipedia API. Default 2.0
is a polite ceiling; raise cautiously.
settings : IngestSettings | None
Override per-upload knobs. FRAMES defaults to text-only
(no vision LLM, basic mode) the corpus is plain wikitext.
"""
settings = settings or IngestSettings(
use_vision_llm=False,
processing_mode="basic",
should_summarize=False,
)
bench_dir = ctx.benchmark_data_dir()
wiki_cache = bench_dir / "wiki"
wiki_cache.mkdir(parents=True, exist_ok=True)
hf_cache = bench_dir / ".hf_cache"
hf_cache.mkdir(parents=True, exist_ok=True)
# 1. Download + parse questions.
tsv_path = download_test_tsv(hf_cache)
questions = load_questions(tsv_path)
if not questions:
raise RuntimeError(
"FRAMES test.tsv contained no parseable rows; upstream may "
"have changed schema."
)
logger.info("FRAMES: parsed %d questions from %s", len(questions), tsv_path.name)
if max_questions is not None and max_questions > 0:
questions = questions[:max_questions]
logger.info("FRAMES: capped to first %d questions", len(questions))
questions_jsonl = bench_dir / "questions.jsonl"
write_questions_jsonl(questions, questions_jsonl)
# 2. Build deduplicated URL set (preserving first-seen order).
seen_urls: dict[str, None] = {}
for q in questions:
for url in q.wiki_urls:
seen_urls.setdefault(url, None)
unique_urls = list(seen_urls.keys())
logger.info(
"FRAMES: %d unique Wikipedia URLs across %d questions",
len(unique_urls), len(questions),
)
# 3. Fetch (with cache).
fetcher = WikiFetcher(cache_dir=wiki_cache, rate_limit_rps=fetch_rate_limit_rps)
n_cached = sum(
1 for url in unique_urls
if (wiki_cache / cache_filename_for_title(_safe_title(url))).exists()
)
fetched, missing_urls = await _fetch_articles(fetcher, unique_urls)
logger.info(
"FRAMES: fetched=%d, cache_hits=%d, missing=%d",
len(fetched), n_cached, len(missing_urls),
)
# 4. Upload to SurfSense (deduped by filename).
name_to_id: dict[str, int] = {}
if skip_upload:
logger.info("FRAMES: --skip-upload; skipping SurfSense ingestion")
else:
unique_articles = list({a.markdown_path: a for a in fetched.values()}.values())
name_to_id = await _upload_markdowns(
ctx,
unique_articles,
batch_size=upload_batch_size,
settings=settings,
)
# 5. Persist per-question doc map.
doc_rows = _resolve_question_doc_ids(questions, fetched, name_to_id)
map_path = ctx.maps_dir() / "frames_doc_map.jsonl"
with map_path.open("w", encoding="utf-8") as fh:
fh.write(settings_header_line(settings) + "\n")
for row in doc_rows:
fh.write(json.dumps(row) + "\n")
logger.info("Wrote FRAMES doc map to %s (%d rows)", map_path, len(doc_rows))
# 6. Update suite state.
new_state = ctx.suite_state
new_state.ingestion_maps["frames"] = str(map_path)
set_suite_state(ctx.config, ctx.suite, new_state)
stats = _IngestStats(
n_questions=len(questions),
n_unique_urls=len(unique_urls),
n_fetched=len(fetched),
n_cached_hits=n_cached,
n_missing=len(missing_urls),
n_uploaded=len(name_to_id),
n_existing=0,
bench_dir=bench_dir,
map_path=map_path,
)
logger.info("FRAMES ingest done: %s", stats)
def _safe_title(url: str) -> str:
"""Pre-cache title resolution; returns ``""`` on bad URL."""
try:
return title_from_url(url)
except ValueError:
return ""
__all__ = ["run_ingest"]

View file

@ -0,0 +1,71 @@
"""FRAMES prompt templates.
Two templates: one for the bare-LLM arm (no retrieval), one for
SurfSense (the agent retrieves; we mostly just instruct it on
output format). Both arms must use byte-identical *content* for the
question itself so the head-to-head is fair the wrappers diverge
only in framing.
Format expectations (mirrors the FRAMES paper, section 4):
* Short factual answer names, dates, numbers, ordinals
* No extra explanation in the final line; we anchor on
``Answer: <text>`` for deterministic extraction
* Free-text reasoning is *allowed* before the final ``Answer:`` line
multi-hop questions often benefit from it. We just don't grade it.
"""
from __future__ import annotations
_BASE_INSTRUCTIONS = (
"You are a careful question-answering assistant. The question may "
"require combining facts from multiple sources, doing arithmetic, "
"or reasoning about dates. Think step by step if needed, then give "
"the final answer.\n\n"
"Format your final line EXACTLY as:\n"
"Answer: <short answer>\n\n"
"The answer should be as short as possible — a name, a number, a "
"date, a single phrase. Do not repeat the question. Do not include "
"punctuation at the end unless it is part of the answer."
)
_BARE_TEMPLATE = """\
{instructions}
Question: {question}
"""
_SURFSENSE_TEMPLATE = """\
{instructions}
You have access to a Wikipedia knowledge base via retrieval. Use it
to look up any facts you are not confident about. The corpus contains
the Wikipedia articles needed to answer this question, but you must
retrieve them yourself they are not pre-selected.
Question: {question}
"""
def build_bare_prompt(question: str) -> str:
"""Prompt for the no-retrieval baseline arm."""
return _BARE_TEMPLATE.format(
instructions=_BASE_INSTRUCTIONS,
question=question.strip(),
)
def build_surfsense_prompt(question: str) -> str:
"""Prompt for the SurfSense arm (retrieval-augmented)."""
return _SURFSENSE_TEMPLATE.format(
instructions=_BASE_INSTRUCTIONS,
question=question.strip(),
)
__all__ = ["build_bare_prompt", "build_surfsense_prompt"]

View file

@ -0,0 +1,686 @@
"""FRAMES runner — Bare LLM (no retrieval) vs SurfSense (multi-hop RAG).
Two arms run paired on every question in the sample:
1. ``BareLlmArm`` OpenRouter chat completion with the question only.
Reproduces the published "naive prompting" baseline (40.8% on
Gemini-Pro-1.5).
2. ``SurfSenseArm`` POST ``/api/v1/new_chat`` with **no**
``mentioned_document_ids`` so the agent retrieves over the entire
ingested Wikipedia corpus. This is the "multi-step retrieval &
reasoning" cell in the FRAMES paper.
Open-ended grading: deterministic shortcut + optional LLM-as-judge
(``--no-judge`` to disable). Cost / latency / token aggregates are
collected per arm. Paired stats (McNemar, bootstrap CI) for the
accuracy delta. Per-reasoning-type breakdown to surface where one
arm beats the other (numerical vs temporal vs multi-constraint, ...).
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import os
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from ....core.arms import ArmRequest, ArmResult, BareLlmArm, SurfSenseArm
from ....core.config import utc_iso_timestamp
from ....core.ingest_settings import (
IngestSettings,
add_ingest_settings_args,
format_ingest_settings_md,
is_settings_header,
)
from ....core.metrics.comparison import (
bootstrap_delta_ci,
mcnemar_test,
paired_aggregate,
)
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci
from ....core.parse.freeform_answer import extract_freeform_answer
from ....core.providers.openrouter_chat import OpenRouterChatProvider
from ....core.registry import ReportSection, RunArtifact, RunContext
from ....core.scenarios import format_scenario_md
from .grader import GradeResult, JudgeConfig, LlmJudge, grade_many
from .prompt import build_bare_prompt, build_surfsense_prompt
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Question shape
# ---------------------------------------------------------------------------
@dataclass
class FramesRunnerQuestion:
qid: str
raw_index: int
question: str
gold_answer: str
reasoning_types: list[str]
document_ids: list[int] # subset of corpus relevant to this Q (may be empty)
n_wiki_urls: int
missing_urls: list[str]
def _load_doc_map(map_path: Path) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
rows: dict[str, dict[str, Any]] = {}
settings: dict[str, Any] = {}
with map_path.open("r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line:
continue
row = json.loads(line)
if is_settings_header(row):
settings = dict(row["__settings__"])
continue
rows[str(row["qid"])] = row
return rows, settings
def _load_questions(
questions_jsonl: Path,
doc_map: dict[str, dict[str, Any]],
*,
sample_n: int | None,
reasoning_filter: str | None,
) -> list[FramesRunnerQuestion]:
out: list[FramesRunnerQuestion] = []
with questions_jsonl.open("r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line:
continue
row = json.loads(line)
qid = str(row.get("qid") or "").strip()
if not qid:
continue
map_row = doc_map.get(qid, {})
reasoning = list(row.get("reasoning_types") or [])
if reasoning_filter and reasoning_filter not in [r.lower() for r in reasoning]:
continue
out.append(FramesRunnerQuestion(
qid=qid,
raw_index=int(row.get("raw_index") or 0),
question=str(row.get("question") or "").strip(),
gold_answer=str(row.get("gold_answer") or "").strip(),
reasoning_types=reasoning,
document_ids=list(map_row.get("document_ids") or []),
n_wiki_urls=int(map_row.get("n_wiki_urls") or 0),
missing_urls=list(map_row.get("missing_urls") or []),
))
out.sort(key=lambda q: q.raw_index)
if sample_n is not None and sample_n > 0:
out = out[:sample_n]
return out
# ---------------------------------------------------------------------------
# Bounded concurrency helper
# ---------------------------------------------------------------------------
async def _gather_with_limit(coros: Iterable, *, concurrency: int) -> list[Any]:
sem = asyncio.Semaphore(max(1, concurrency))
async def _wrap(coro):
async with sem:
return await coro
return await asyncio.gather(*(_wrap(c) for c in coros))
# ---------------------------------------------------------------------------
# Benchmark
# ---------------------------------------------------------------------------
_DESCRIPTION = (
"FRAMES (824 multi-hop Wikipedia questions, 5 reasoning types) — "
"Bare LLM (no retrieval) vs SurfSense (multi-step RAG over the "
"Wikipedia corpus). Tests cross-document retrieval + reasoning."
)
_DEFAULT_INGEST_SETTINGS = IngestSettings(
use_vision_llm=False,
processing_mode="basic",
should_summarize=False,
)
class FramesBenchmark:
"""Multi-hop Wikipedia RAG vs naive prompting."""
suite: str = "research"
name: str = "frames"
headline: bool = True
description: str = _DESCRIPTION
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--n", dest="sample_n", type=int, default=None,
help="Run only the first N questions after filters (default: all 824).",
)
parser.add_argument(
"--reasoning",
dest="reasoning_filter",
default=None,
help=(
"Filter to questions tagged with this reasoning type "
"(e.g. 'numerical reasoning', 'temporal reasoning'). "
"Case-insensitive substring against the upstream tags."
),
)
parser.add_argument(
"--concurrency", type=int, default=4,
help="Parallel question workers per arm.",
)
parser.add_argument(
"--scope-mentions", dest="scope_mentions", action="store_true",
help=(
"SurfSense arm: scope retrieval to the per-question "
"document_ids (oracle-retrieval upper bound). Default "
"is full-corpus retrieval (the realistic FRAMES setting)."
),
)
parser.add_argument(
"--max-output-tokens", type=int, default=512,
help="Cap on completion length for both arms.",
)
parser.add_argument(
"--no-judge", dest="no_judge", action="store_true",
help=(
"Disable LLM-as-judge fallback grading; use only the "
"deterministic grader (faster but more pessimistic)."
),
)
parser.add_argument(
"--judge-model",
dest="judge_model",
default="anthropic/claude-sonnet-4.5",
help="OpenRouter slug for the LLM judge (default: claude-sonnet-4.5).",
)
parser.add_argument(
"--judge-concurrency",
dest="judge_concurrency",
type=int,
default=4,
help="Parallel judge calls (default: 4).",
)
# Ingest-only knobs.
parser.add_argument(
"--max-questions", dest="max_questions", type=int, default=None,
help="(ingest only) cap on number of questions to materialise + ingest.",
)
parser.add_argument(
"--upload-batch-size", dest="upload_batch_size", type=int, default=16,
help="(ingest only) markdown files per fileupload call.",
)
parser.add_argument(
"--skip-upload", dest="skip_upload", action="store_true",
help="(ingest only) cache wiki articles locally but don't push to SurfSense.",
)
parser.add_argument(
"--fetch-rps", dest="fetch_rate_limit_rps", type=float, default=2.0,
help="(ingest only) max requests/second to the Wikipedia API.",
)
add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
from .ingest import run_ingest
settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
await run_ingest(
ctx,
max_questions=opts.get("max_questions"),
upload_batch_size=int(opts.get("upload_batch_size") or 16),
skip_upload=bool(opts.get("skip_upload", False)),
fetch_rate_limit_rps=float(opts.get("fetch_rate_limit_rps") or 2.0),
settings=settings,
)
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
sample_n = opts.get("sample_n")
reasoning_filter = opts.get("reasoning_filter")
if reasoning_filter:
reasoning_filter = reasoning_filter.strip().lower() or None
concurrency = int(opts.get("concurrency") or 4)
scope_mentions = bool(opts.get("scope_mentions"))
max_output_tokens = int(opts.get("max_output_tokens") or 512)
no_judge = bool(opts.get("no_judge"))
judge_model = str(opts.get("judge_model") or "anthropic/claude-sonnet-4.5")
judge_concurrency = int(opts.get("judge_concurrency") or 4)
bench_dir = ctx.benchmark_data_dir()
questions_jsonl = bench_dir / "questions.jsonl"
map_path = ctx.maps_dir() / "frames_doc_map.jsonl"
if not questions_jsonl.exists() or not map_path.exists():
raise RuntimeError(
"FRAMES not ingested for this suite. Run "
"`python -m surfsense_evals ingest research frames` first."
)
doc_map, ingest_settings = _load_doc_map(map_path)
questions = _load_questions(
questions_jsonl, doc_map,
sample_n=sample_n,
reasoning_filter=reasoning_filter,
)
if not questions:
raise RuntimeError(
"No FRAMES questions matched the filters; broaden --reasoning/--n."
)
logger.info("FRAMES: scheduled %d questions", len(questions))
api_key = os.environ.get("OPENROUTER_API_KEY")
if not api_key:
raise RuntimeError(
"OPENROUTER_API_KEY env var is required for the bare-LLM arm."
)
bare_provider = OpenRouterChatProvider(
api_key=api_key,
base_url=ctx.config.openrouter_base_url,
model=ctx.native_arm_model,
)
bare_arm = BareLlmArm(
provider=bare_provider,
max_output_tokens=max_output_tokens,
)
surf_arm = SurfSenseArm(
client=ctx.new_chat_client(),
search_space_id=ctx.search_space_id,
ephemeral_threads=True,
)
judge: LlmJudge | None = None
if not no_judge:
judge = LlmJudge(config=JudgeConfig(
api_key=api_key,
model=judge_model,
base_url=ctx.config.openrouter_base_url,
concurrency=judge_concurrency,
))
run_timestamp = utc_iso_timestamp()
run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
raw_path = run_dir / "raw.jsonl"
async def _bare_one(q: FramesRunnerQuestion) -> ArmResult:
return await bare_arm.answer(_make_bare_request(q, max_output_tokens))
async def _surf_one(q: FramesRunnerQuestion) -> ArmResult:
return await surf_arm.answer(
_make_surfsense_request(q, scope_mentions=scope_mentions)
)
bare_results, surf_results = await asyncio.gather(
_gather_with_limit((_bare_one(q) for q in questions), concurrency=concurrency),
_gather_with_limit((_surf_one(q) for q in questions), concurrency=concurrency),
)
bare_grades = await _grade_results(questions, bare_results, judge=judge)
surf_grades = await _grade_results(questions, surf_results, judge=judge)
with raw_path.open("w", encoding="utf-8") as fh:
for q, b_res, s_res, b_g, s_g in zip(
questions, bare_results, surf_results, bare_grades, surf_grades, strict=False
):
meta = {
"qid": q.qid,
"raw_index": q.raw_index,
"reasoning_types": q.reasoning_types,
"n_wiki_urls": q.n_wiki_urls,
"n_resolved_doc_ids": len(q.document_ids),
"n_missing_urls": len(q.missing_urls),
"gold": q.gold_answer,
}
fh.write(json.dumps({
**meta,
**b_res.to_jsonl(),
"graded": b_g.to_dict(),
}) + "\n")
fh.write(json.dumps({
**meta,
**s_res.to_jsonl(),
"graded": s_g.to_dict(),
}) + "\n")
metrics = _compute_metrics(questions, bare_results, surf_results, bare_grades, surf_grades)
artifact = RunArtifact(
suite=self.suite,
benchmark=self.name,
run_timestamp=run_timestamp,
raw_path=raw_path,
metrics=metrics,
extra={
"n_questions": len(questions),
"concurrency": concurrency,
"reasoning_filter": reasoning_filter,
"scope_mentions": scope_mentions,
"no_judge": no_judge,
"judge_model": judge_model if not no_judge else None,
"scenario": ctx.scenario,
"provider_model": ctx.provider_model,
"native_arm_model": ctx.native_arm_model,
"vision_provider_model": ctx.vision_provider_model,
"agent_llm_id": ctx.agent_llm_id,
"ingest_settings": ingest_settings,
"bare_arm_label": "bare_llm",
},
)
manifest_path = run_dir / "run_artifact.json"
manifest_path.write_text(
json.dumps({
"suite": self.suite,
"benchmark": self.name,
"raw_path": "raw.jsonl",
"metrics": metrics,
"extra": artifact.extra,
}, indent=2, sort_keys=True) + "\n",
encoding="utf-8",
)
return artifact
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
if not artifacts:
return ReportSection(
title="FRAMES — Bare LLM vs SurfSense (multi-hop Wikipedia RAG)",
headline=True,
body_md="(no run artifacts found)",
body_json={},
)
latest = max(artifacts, key=lambda a: a.run_timestamp)
m = latest.metrics
bare = m.get("bare", {})
surf = m.get("surfsense", {})
delta = m.get("delta", {})
per_reasoning = m.get("per_reasoning", {})
extra = latest.extra
body_lines: list[str] = []
body_lines.append(
f"- Sample size: {extra.get('n_questions', '?')} questions "
f"(reasoning filter: `{extra.get('reasoning_filter') or 'none'}`, "
f"scope-mentions: `{extra.get('scope_mentions', False)}`, "
f"judge: `{extra.get('judge_model') or 'deterministic-only'}`)."
)
body_lines.append(format_scenario_md(extra))
body_lines.append(format_ingest_settings_md(extra.get("ingest_settings")))
body_lines.append(
"- Bare LLM arm (OpenRouter chat, no retrieval, "
f"`{extra.get('native_arm_model') or extra.get('provider_model', '?')}`):"
)
body_lines.append(_arm_summary_lines(bare, indent=" "))
body_lines.append(
"- SurfSense arm (`POST /api/v1/new_chat`, multi-step RAG, "
f"`{extra.get('provider_model', '?')}`):"
)
body_lines.append(_arm_summary_lines(surf, indent=" "))
body_lines.append("- Delta (paired):")
body_lines.append(
f" - Accuracy: SurfSense {_pp(delta.get('accuracy_pp'))} pp "
f"(McNemar p={_fmt(delta.get('mcnemar_p_value'), 4)}, "
f"method={delta.get('mcnemar_method')})"
)
body_lines.append(
f" - Bootstrap 95% CI on accuracy delta: "
f"[{_pp(delta.get('bootstrap_ci_low'))}pp, {_pp(delta.get('bootstrap_ci_high'))}pp]"
)
body_lines.append(
f" - Cost / question: bare ${_dollars(bare.get('cost_micros_mean'))}, "
f"surfsense ${_dollars(surf.get('cost_micros_mean'))} "
f"(SurfSense delta {_pct_change(delta.get('cost_micros_pct'))})"
)
body_lines.append(
f" - Latency p50: bare {_ms_to_s(bare.get('latency_ms_median'))}, "
f"surfsense {_ms_to_s(surf.get('latency_ms_median'))} "
f"(SurfSense delta {_pct_change(delta.get('latency_ms_pct'))})"
)
if per_reasoning:
body_lines.append("- Per-reasoning-type split (accuracy delta in pp):")
for tag, vals in sorted(per_reasoning.items()):
body_lines.append(
f" - {tag}: SurfSense {_pp(vals.get('delta_accuracy_pp'))} pp "
f"(n={vals.get('n')}, bare acc={vals.get('bare_accuracy', 0)*100:.1f}%, "
f"surf acc={vals.get('surfsense_accuracy', 0)*100:.1f}%)"
)
return ReportSection(
title="FRAMES — Bare LLM vs SurfSense (multi-hop Wikipedia RAG)",
headline=True,
body_md="\n".join(body_lines),
body_json=m,
)
# ---------------------------------------------------------------------------
# Per-question helpers
# ---------------------------------------------------------------------------
def _make_bare_request(q: FramesRunnerQuestion, max_tokens: int) -> ArmRequest:
return ArmRequest(
question_id=q.qid,
prompt=build_bare_prompt(q.question),
options={"max_tokens": max_tokens},
)
def _make_surfsense_request(q: FramesRunnerQuestion, *, scope_mentions: bool) -> ArmRequest:
mentions: list[int] | None = None
if scope_mentions and q.document_ids:
mentions = list(q.document_ids)
return ArmRequest(
question_id=q.qid,
prompt=build_surfsense_prompt(q.question),
mentioned_document_ids=mentions,
)
async def _grade_results(
questions: list[FramesRunnerQuestion],
results: list[ArmResult],
*,
judge: LlmJudge | None,
) -> list[GradeResult]:
rows: list[tuple[str, str, str, str]] = []
for q, r in zip(questions, results, strict=False):
pred = extract_freeform_answer(r.raw_text or "")
rows.append((q.qid, q.question, q.gold_answer, pred))
return await grade_many(rows=rows, judge=judge)
# ---------------------------------------------------------------------------
# Metrics aggregation
# ---------------------------------------------------------------------------
def _compute_metrics(
questions: list[FramesRunnerQuestion],
bare_results: list[ArmResult],
surf_results: list[ArmResult],
bare_grades: list[GradeResult],
surf_grades: list[GradeResult],
) -> dict[str, Any]:
bare_correct = [g.correct for g in bare_grades]
surf_correct = [g.correct for g in surf_grades]
bare_costs = [float(r.cost_micros) for r in bare_results]
surf_costs = [float(r.cost_micros) for r in surf_results]
bare_latencies = [float(r.latency_ms) for r in bare_results]
surf_latencies = [float(r.latency_ms) for r in surf_results]
bare_in_tokens = [float(r.input_tokens) for r in bare_results]
bare_out_tokens = [float(r.output_tokens) for r in bare_results]
bare_acc = accuracy_with_wilson_ci(sum(bare_correct), len(bare_correct))
surf_acc = accuracy_with_wilson_ci(sum(surf_correct), len(surf_correct))
mc = mcnemar_test(bare_correct, surf_correct)
boot = bootstrap_delta_ci(bare_correct, surf_correct, n_resamples=2000)
bare_cost_agg = paired_aggregate(bare_costs)
surf_cost_agg = paired_aggregate(surf_costs)
bare_latency_agg = paired_aggregate(bare_latencies)
surf_latency_agg = paired_aggregate(surf_latencies)
cost_pct = _safe_pct(surf_cost_agg.mean, bare_cost_agg.mean)
latency_pct = _safe_pct(surf_latency_agg.median, bare_latency_agg.median)
# Per-reasoning-type breakdown. Each question may carry multiple
# reasoning tags; we count it under each tag (so totals don't
# equal len(questions) — the reader is expected to look at the
# per-tag ``n``).
per_reasoning_pairs: dict[str, list[tuple[bool, bool]]] = {}
for q, b_ok, s_ok in zip(questions, bare_correct, surf_correct, strict=False):
tags = q.reasoning_types or ["(untagged)"]
for tag in tags:
per_reasoning_pairs.setdefault(tag, []).append((b_ok, s_ok))
per_reasoning: dict[str, dict[str, Any]] = {}
for tag, pairs in per_reasoning_pairs.items():
b_correct = [a for a, _ in pairs]
s_correct = [b for _, b in pairs]
per_reasoning[tag] = {
"n": len(pairs),
"bare_accuracy": (sum(b_correct) / len(pairs)) if pairs else 0.0,
"surfsense_accuracy": (sum(s_correct) / len(pairs)) if pairs else 0.0,
"delta_accuracy_pp": (
100.0 * (sum(s_correct) - sum(b_correct)) / len(pairs)
if pairs else 0.0
),
}
grader_methods = {
"bare": _count_methods(bare_grades),
"surfsense": _count_methods(surf_grades),
}
return {
"bare": {
**bare_acc.to_dict(),
"cost_micros_mean": bare_cost_agg.mean,
"cost_micros_median": bare_cost_agg.median,
"latency_ms_mean": bare_latency_agg.mean,
"latency_ms_median": bare_latency_agg.median,
"latency_ms_p95": bare_latency_agg.p95,
"input_tokens_mean": (sum(bare_in_tokens) / len(bare_in_tokens)) if bare_in_tokens else 0.0,
"output_tokens_mean": (sum(bare_out_tokens) / len(bare_out_tokens)) if bare_out_tokens else 0.0,
},
"surfsense": {
**surf_acc.to_dict(),
"cost_micros_mean": surf_cost_agg.mean,
"cost_micros_median": surf_cost_agg.median,
"latency_ms_mean": surf_latency_agg.mean,
"latency_ms_median": surf_latency_agg.median,
"latency_ms_p95": surf_latency_agg.p95,
},
"delta": {
"accuracy_pp": 100.0 * (surf_acc.accuracy - bare_acc.accuracy),
"mcnemar_p_value": mc.p_value,
"mcnemar_method": mc.method,
"mcnemar_b_bare_only": mc.b,
"mcnemar_c_surfsense_only": mc.c,
"bootstrap_ci_low": 100.0 * boot.ci_low,
"bootstrap_ci_high": 100.0 * boot.ci_high,
"cost_micros_pct": cost_pct,
"latency_ms_pct": latency_pct,
},
"per_reasoning": per_reasoning,
"grader_methods": grader_methods,
}
def _count_methods(grades: list[GradeResult]) -> dict[str, int]:
out: dict[str, int] = {}
for g in grades:
out[g.method] = out.get(g.method, 0) + 1
return out
def _safe_pct(numerator: float, denominator: float) -> float | None:
if denominator == 0:
return None
return 100.0 * (numerator - denominator) / denominator
# ---------------------------------------------------------------------------
# Tiny formatting helpers used by report_section
# ---------------------------------------------------------------------------
def _arm_summary_lines(d: dict[str, Any], *, indent: str) -> str:
if not d:
return f"{indent}(no data)"
acc = d.get("accuracy", 0.0)
low = d.get("ci_low", 0.0)
high = d.get("ci_high", 0.0)
lines = [
f"{indent}- Accuracy: {acc * 100:.1f}% (Wilson 95% CI: {low * 100:.1f}% {high * 100:.1f}%)",
f"{indent}- Cost / question: ${_dollars(d.get('cost_micros_mean'))} (mean), "
f"${_dollars(d.get('cost_micros_median'))} (median)",
f"{indent}- Latency: p50 {_ms_to_s(d.get('latency_ms_median'))}, "
f"p95 {_ms_to_s(d.get('latency_ms_p95'))}",
]
if "input_tokens_mean" in d:
lines.append(
f"{indent}- Mean tokens / question: in {d.get('input_tokens_mean', 0):.0f}, "
f"out {d.get('output_tokens_mean', 0):.0f}"
)
return "\n".join(lines)
def _dollars(micros: Any) -> str:
if micros is None:
return "?"
try:
return f"{(float(micros) / 1_000_000):.4f}"
except (TypeError, ValueError):
return "?"
def _ms_to_s(ms: Any) -> str:
if ms is None:
return "?"
try:
return f"{float(ms) / 1000:.1f}s"
except (TypeError, ValueError):
return "?"
def _pp(value: Any) -> str:
if value is None:
return "?"
try:
return f"{float(value):+.1f}"
except (TypeError, ValueError):
return "?"
def _pct_change(value: Any) -> str:
if value is None:
return "?"
try:
return f"{float(value):+.0f}%"
except (TypeError, ValueError):
return "?"
def _fmt(value: Any, ndigits: int) -> str:
if value is None:
return "?"
try:
return f"{float(value):.{ndigits}f}"
except (TypeError, ValueError):
return "?"
__all__ = ["FramesBenchmark", "FramesRunnerQuestion"]

View file

@ -0,0 +1,241 @@
"""Wikipedia article fetcher → plain-text markdown, with disk cache.
We hit the MediaWiki action API for *plain text* extracts:
GET https://en.wikipedia.org/w/api.php
?action=query&prop=extracts&explaintext=true
&redirects=1&titles=<Title>&format=json&formatversion=2
This avoids HTMLmarkdown conversion (and its many edge cases). The
``explaintext=true`` mode strips infoboxes / templates / wikilinks
and returns clean section-headered prose, which is exactly what we
want SurfSense to chunk + embed. We prepend ``# <Title>\n\n`` so the
markdown has a visible H1 (helps SurfSense's chunker preserve doc
identity at the top of the first chunk).
Caching: every fetched article lands in
``<bench_dir>/wiki/<sanitised-title>.md`` and is reused on subsequent
runs. The cache key is the URL-decoded title (e.g.
``Charlotte_Brontë`` regardless of source URL casing or
percent-encoding).
Politeness: 2 RPS rate limit + a descriptive User-Agent (Wikimedia
asks for one). We don't parallelise above 2 RPS — this is a courtesy
to Wikipedia and only ~300 articles for the n=100 sample.
"""
from __future__ import annotations
import asyncio
import logging
import re
import urllib.parse
from dataclasses import dataclass
from pathlib import Path
import httpx
logger = logging.getLogger(__name__)
WIKI_API = "https://en.wikipedia.org/w/api.php"
USER_AGENT = (
"SurfSense-Evals/0.1 (https://github.com/MODSetter/SurfSense; "
"research-benchmark fetch; respects 2 RPS rate limit)"
)
@dataclass(frozen=True)
class WikiArticle:
"""One fetched article + metadata."""
title: str # canonical title returned by MW (post-redirect)
source_url: str # the URL we were asked to fetch
markdown_path: Path # where the cached body lives on disk
n_chars: int # length of the body (post-prepend H1)
redirected_from: str | None = None
# ---------------------------------------------------------------------------
# Title <-> URL helpers
# ---------------------------------------------------------------------------
_WIKI_PATH_RE = re.compile(r"^/wiki/(?P<title>[^?#]+)$")
def title_from_url(url: str) -> str:
"""Pull the page title out of a wiki URL.
``https://en.wikipedia.org/wiki/Charlotte_Bront%C3%AB`` ``Charlotte Brontë``.
Spaces are preserved (the API accepts spaces and underscores
interchangeably; we use spaces to keep cache filenames human-readable).
"""
parsed = urllib.parse.urlparse(url)
if parsed.netloc and "wikipedia.org" not in parsed.netloc:
raise ValueError(f"Not a Wikipedia URL: {url!r}")
match = _WIKI_PATH_RE.match(parsed.path)
if not match:
raise ValueError(f"Unrecognised wiki path: {parsed.path!r}")
raw_title = urllib.parse.unquote(match.group("title"))
# MW treats underscores and spaces as equivalent; spaces are friendlier.
return raw_title.replace("_", " ").strip()
_FILENAME_SAFE = re.compile(r"[^A-Za-z0-9._\- ]")
def cache_filename_for_title(title: str) -> str:
"""Map a title to a filesystem-safe filename.
Replaces every non-(alnum / ``._- `` / space) character with ``_``.
Title collisions are rare (FRAMES only has English Wikipedia titles)
and a final ``hash(title)[:8]`` would obscure the otherwise-readable
filenames; we accept the (vanishingly small) collision risk.
"""
safe = _FILENAME_SAFE.sub("_", title)
safe = safe.strip().replace(" ", "_")
return f"{safe}.md"
# ---------------------------------------------------------------------------
# Async fetcher with rate limiting + retry
# ---------------------------------------------------------------------------
class WikiFetcher:
"""Polite fetch + disk cache + redirect handling."""
def __init__(
self,
*,
cache_dir: Path,
rate_limit_rps: float = 2.0,
timeout_s: float = 30.0,
max_retries: int = 3,
) -> None:
self._cache_dir = Path(cache_dir)
self._cache_dir.mkdir(parents=True, exist_ok=True)
self._min_interval = 1.0 / max(rate_limit_rps, 0.1)
self._last_request_at = 0.0
self._rate_lock = asyncio.Lock()
self._timeout = httpx.Timeout(timeout_s, connect=10.0)
self._max_retries = max_retries
async def _throttle(self) -> None:
async with self._rate_lock:
now = asyncio.get_event_loop().time()
wait = self._last_request_at + self._min_interval - now
if wait > 0:
await asyncio.sleep(wait)
self._last_request_at = asyncio.get_event_loop().time()
async def fetch(
self,
url: str,
*,
http: httpx.AsyncClient | None = None,
) -> WikiArticle | None:
"""Fetch one article. Returns ``None`` only if MW reports the title is missing.
Raises on transport errors after retries. Caller decides
whether to abort the whole ingest or continue with the
successfully-fetched subset.
"""
try:
title = title_from_url(url)
except ValueError as exc:
logger.warning("Skipping non-wiki URL %s: %s", url, exc)
return None
cache_path = self._cache_dir / cache_filename_for_title(title)
if cache_path.exists() and cache_path.stat().st_size > 0:
return WikiArticle(
title=title,
source_url=url,
markdown_path=cache_path,
n_chars=cache_path.stat().st_size,
)
last_exc: Exception | None = None
for attempt in range(self._max_retries):
try:
await self._throttle()
payload = await self._fetch_extract(title, http=http)
break
except (httpx.HTTPError, RuntimeError) as exc:
last_exc = exc
wait = 1.0 * (2 ** attempt)
logger.warning(
"wiki fetch %r attempt %d failed: %s; retry in %.1fs",
title, attempt + 1, exc, wait,
)
await asyncio.sleep(wait)
else:
assert last_exc is not None
raise last_exc
page = payload.get("page") or {}
if not page or page.get("missing"):
logger.warning("Wikipedia reports missing page for %r (url=%s)", title, url)
return None
canonical_title = str(page.get("title") or title).strip()
body = str(page.get("extract") or "").strip()
if not body:
logger.warning("Wikipedia returned empty extract for %r", title)
return None
markdown = f"# {canonical_title}\n\n{body}\n"
cache_path.write_text(markdown, encoding="utf-8")
return WikiArticle(
title=canonical_title,
source_url=url,
markdown_path=cache_path,
n_chars=len(markdown),
redirected_from=title if canonical_title != title else None,
)
async def _fetch_extract(
self,
title: str,
*,
http: httpx.AsyncClient | None,
) -> dict:
"""One MW API call. Returns ``{'page': {...}}`` (formatversion=2)."""
params = {
"action": "query",
"prop": "extracts",
"explaintext": "true",
"redirects": "1",
"format": "json",
"formatversion": "2",
"titles": title,
}
headers = {"User-Agent": USER_AGENT, "Accept": "application/json"}
if http is not None:
response = await http.get(WIKI_API, params=params, headers=headers, timeout=self._timeout)
else:
async with httpx.AsyncClient(timeout=self._timeout) as client:
response = await client.get(WIKI_API, params=params, headers=headers, timeout=self._timeout)
response.raise_for_status()
data = response.json()
if "error" in data:
raise RuntimeError(f"MediaWiki API error: {data['error']!r}")
pages = (data.get("query") or {}).get("pages") or []
if not pages:
return {"page": {}}
return {"page": pages[0]}
__all__ = [
"WIKI_API",
"USER_AGENT",
"WikiArticle",
"WikiFetcher",
"cache_filename_for_title",
"title_from_url",
]

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,34 @@
"""Shared pytest fixtures for surfsense-evals."""
from __future__ import annotations
import os
from pathlib import Path
import pytest
from surfsense_evals.core.config import Config
@pytest.fixture
def tmp_env(monkeypatch, tmp_path: Path) -> Path:
"""Isolate env vars + filesystem state per test.
Wipes every ``SURFSENSE_*`` / ``OPENROUTER_*`` / ``EVAL_*`` var so a
test that wants a specific credential mode can ``monkeypatch.setenv``
just what it needs without leakage from the caller's shell.
"""
for key in list(os.environ):
if key.startswith(("SURFSENSE_", "OPENROUTER_", "EVAL_")):
monkeypatch.delenv(key, raising=False)
monkeypatch.setenv("EVAL_DATA_DIR", str(tmp_path / "data"))
monkeypatch.setenv("EVAL_REPORTS_DIR", str(tmp_path / "reports"))
return tmp_path
@pytest.fixture
def isolated_config(tmp_env: Path) -> Config: # noqa: ARG001
from surfsense_evals.core.config import load_config
return load_config()

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,95 @@
"""Auth credential resolution + 401 refresh hook."""
from __future__ import annotations
import httpx
import pytest
import respx
from surfsense_evals.core.auth import (
CredentialError,
acquire_token,
client_with_auth,
)
from surfsense_evals.core.config import Config
def _make_config(**overrides) -> Config:
base = {
"surfsense_api_base": "http://test",
"openrouter_api_key": None,
"openrouter_base_url": "https://openrouter.ai/api/v1",
"surfsense_jwt": None,
"surfsense_refresh_token": None,
"surfsense_user_email": None,
"surfsense_user_password": None,
"data_dir": None,
"reports_dir": None,
}
base.update(overrides)
# Path objects required by Config; tests don't touch the FS.
from pathlib import Path
base["data_dir"] = base["data_dir"] or Path("/tmp/eval_test_data")
base["reports_dir"] = base["reports_dir"] or Path("/tmp/eval_test_reports")
return Config(**base)
@pytest.mark.asyncio
async def test_acquire_token_jwt_mode_short_circuits():
config = _make_config(surfsense_jwt="abc", surfsense_refresh_token="ref")
bundle = await acquire_token(config)
assert bundle.access_token == "abc"
assert bundle.refresh_token == "ref"
assert bundle.mode == "jwt"
@pytest.mark.asyncio
@respx.mock
async def test_acquire_token_local_mode_posts_form():
respx.post("http://test/auth/jwt/login").mock(
return_value=httpx.Response(
200, json={"access_token": "T", "refresh_token": "R", "token_type": "bearer"}
)
)
config = _make_config(
surfsense_user_email="u@example.com", surfsense_user_password="pw"
)
bundle = await acquire_token(config)
assert bundle.access_token == "T"
assert bundle.refresh_token == "R"
assert bundle.mode == "local"
@pytest.mark.asyncio
async def test_acquire_token_no_credentials():
config = _make_config()
with pytest.raises(CredentialError) as exc:
await acquire_token(config)
assert "SURFSENSE_USER_EMAIL" in str(exc.value)
assert "SURFSENSE_JWT" in str(exc.value)
@pytest.mark.asyncio
@respx.mock
async def test_client_with_auth_refreshes_on_401():
config = _make_config(surfsense_jwt="old", surfsense_refresh_token="ref")
bundle = await acquire_token(config)
respx.post("http://test/auth/jwt/refresh").mock(
return_value=httpx.Response(200, json={"access_token": "new", "refresh_token": "ref2"})
)
# First call returns 401; the retry (post-refresh) returns 200.
respx.get("http://test/api/v1/searchspaces").mock(
side_effect=[
httpx.Response(401, json={"detail": "expired"}),
httpx.Response(200, json=[]),
]
)
async with client_with_auth(config, bundle) as client:
response = await client.get("http://test/api/v1/searchspaces")
assert response.status_code == 200
assert bundle.access_token == "new"
assert bundle.refresh_token == "ref2"

View file

@ -0,0 +1,262 @@
"""respx-mocked tests for the SurfSense HTTP clients."""
from __future__ import annotations
import json
from pathlib import Path
import httpx
import pytest
import respx
from surfsense_evals.core.clients import (
DocumentsClient,
NewChatClient,
SearchSpaceClient,
)
from surfsense_evals.core.clients.new_chat import ThreadBusyError
_BASE = "http://test"
@pytest.fixture
def http() -> httpx.AsyncClient:
return httpx.AsyncClient(base_url=_BASE)
# ---------------------------------------------------------------------------
# SearchSpaceClient
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_create_search_space_returns_row(respx_mock, http):
respx_mock.post("/api/v1/searchspaces").mock(
return_value=httpx.Response(
200,
json={
"id": 99,
"name": "eval-medical-2026",
"description": None,
"user_id": "user-x",
"citations_enabled": True,
"qna_custom_instructions": None,
},
)
)
client = SearchSpaceClient(http, _BASE)
row = await client.create("eval-medical-2026")
assert row.id == 99
assert row.name == "eval-medical-2026"
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_delete_search_space_idempotent_on_404(respx_mock, http):
respx_mock.delete("/api/v1/searchspaces/42").mock(
return_value=httpx.Response(404, json={"detail": "gone"})
)
client = SearchSpaceClient(http, _BASE)
await client.delete(42) # must not raise
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_set_llm_preferences_partial_update(respx_mock, http):
route = respx_mock.put("/api/v1/search-spaces/42/llm-preferences").mock(
return_value=httpx.Response(
200,
json={
"agent_llm_id": -10042,
"document_summary_llm_id": None,
"image_generation_config_id": None,
"vision_llm_config_id": None,
"agent_llm": {
"id": -10042,
"provider": "OPENROUTER",
"model_name": "anthropic/claude-sonnet-4.5",
},
},
)
)
client = SearchSpaceClient(http, _BASE)
prefs = await client.set_llm_preferences(42, agent_llm_id=-10042)
assert prefs.agent_llm_id == -10042
assert prefs.agent_llm["provider"] == "OPENROUTER"
sent_body = json.loads(route.calls[-1].request.content)
assert sent_body == {"agent_llm_id": -10042}
# ---------------------------------------------------------------------------
# DocumentsClient
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_documents_status_parses_state(respx_mock, http):
respx_mock.get("/api/v1/documents/status").mock(
return_value=httpx.Response(
200,
json={
"items": [
{"id": 1, "title": "a.pdf", "document_type": "FILE",
"status": {"state": "ready", "reason": None}},
{"id": 2, "title": "b.pdf", "document_type": "FILE",
"status": {"state": "failed", "reason": "ETL boom"}},
]
},
)
)
client = DocumentsClient(http, _BASE)
statuses = await client.get_status(search_space_id=1, document_ids=[1, 2])
assert {s.document_id for s in statuses} == {1, 2}
assert {s.is_ready for s in statuses} == {True, False}
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_documents_upload_returns_payload(respx_mock, http, tmp_path: Path):
f1 = tmp_path / "a.pdf"
f1.write_bytes(b"%PDF-1.4 small")
respx_mock.post("/api/v1/documents/fileupload").mock(
return_value=httpx.Response(
200,
json={
"message": "Files uploaded",
"document_ids": [101],
"duplicate_document_ids": [],
"total_files": 1,
"pending_files": 1,
"skipped_duplicates": 0,
},
)
)
client = DocumentsClient(http, _BASE)
result = await client.upload(files=[f1], search_space_id=7)
assert result.document_ids == [101]
assert result.pending_files == 1
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_documents_list_chunks_paginated(respx_mock, http):
respx_mock.get("/api/v1/documents/5/chunks").mock(
side_effect=[
httpx.Response(200, json={
"items": [{"id": 1, "content": "a"}, {"id": 2, "content": "b"}],
"total": 3, "page": 0, "page_size": 2, "has_more": True,
}),
httpx.Response(200, json={
"items": [{"id": 3, "content": "c"}],
"total": 3, "page": 1, "page_size": 2, "has_more": False,
}),
]
)
client = DocumentsClient(http, _BASE)
rows = await client.list_chunks(5, page_size=2)
assert [r.id for r in rows] == [1, 2, 3]
# ---------------------------------------------------------------------------
# NewChatClient
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_create_thread_returns_id(respx_mock, http):
respx_mock.post("/api/v1/threads").mock(
return_value=httpx.Response(
200,
json={
"id": 555,
"title": "eval",
"archived": False,
"visibility": "PRIVATE",
"search_space_id": 1,
"messages": [],
"created_at": "2026-05-11T00:00:00Z",
"updated_at": "2026-05-11T00:00:00Z",
},
)
)
client = NewChatClient(http, _BASE)
tid = await client.create_thread(search_space_id=1)
assert tid == 555
def _sse_body(events: list[dict]) -> bytes:
parts = []
for ev in events:
parts.append(f"data: {json.dumps(ev)}\n\n")
parts.append("data: [DONE]\n\n")
return "".join(parts).encode("utf-8")
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_ask_accumulates_text_deltas(respx_mock, http):
body = _sse_body([
{"type": "start", "messageId": "m1"},
{"type": "text-start", "id": "t1"},
{"type": "text-delta", "id": "t1", "delta": "Answer "},
{"type": "text-delta", "id": "t1", "delta": "is "},
{"type": "text-delta", "id": "t1", "delta": "B [citation:42]."},
{"type": "text-end", "id": "t1"},
{"type": "finish"},
])
respx_mock.post("/api/v1/new_chat").mock(
return_value=httpx.Response(
200,
content=body,
headers={"Content-Type": "text/event-stream"},
)
)
client = NewChatClient(http, _BASE)
answer = await client.ask(
thread_id=1, search_space_id=2, user_query="What is the answer?"
)
assert answer.text == "Answer is B [citation:42]."
assert answer.finished_normally is True
assert any(c["chunk_id"] == 42 for c in answer.citations)
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_ask_409_thread_busy_retries(respx_mock, http):
body = _sse_body([
{"type": "text-delta", "id": "t1", "delta": "ok"},
{"type": "finish"},
])
busy = httpx.Response(
409,
json={"detail": {"errorCode": "THREAD_BUSY", "message": "busy"}},
headers={"Retry-After": "1"},
)
success = httpx.Response(
200, content=body, headers={"Content-Type": "text/event-stream"}
)
respx_mock.post("/api/v1/new_chat").mock(side_effect=[busy, success])
client = NewChatClient(http, _BASE)
answer = await client.ask(
thread_id=1, search_space_id=2, user_query="hi", max_busy_retries=2
)
assert answer.text == "ok"
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_ask_409_exhausts_retries(respx_mock, http):
busy = httpx.Response(
409,
json={"detail": {"errorCode": "TURN_CANCELLING", "message": "wait"}},
headers={"Retry-After": "1"},
)
respx_mock.post("/api/v1/new_chat").mock(return_value=busy)
client = NewChatClient(http, _BASE)
with pytest.raises(ThreadBusyError):
await client.ask(
thread_id=1, search_space_id=2, user_query="hi", max_busy_retries=1
)

View file

@ -0,0 +1,160 @@
"""Tests for env loading + state.json read/write."""
from __future__ import annotations
import json
from surfsense_evals.core.config import (
DEFAULT_SCENARIO,
SCENARIOS,
SuiteState,
clear_suite_state,
get_suite_state,
load_config,
set_suite_state,
)
def test_load_config_defaults_to_localhost(tmp_env): # noqa: ARG001
config = load_config()
assert config.surfsense_api_base == "http://localhost:8000"
assert config.has_jwt_mode() is False
assert config.has_local_mode() is False
assert config.credential_mode() == "none"
def test_load_config_picks_up_jwt_env(tmp_env, monkeypatch): # noqa: ARG001
monkeypatch.setenv("SURFSENSE_JWT", "tok")
config = load_config()
assert config.credential_mode() == "jwt"
def test_load_config_picks_up_local_env(tmp_env, monkeypatch): # noqa: ARG001
monkeypatch.setenv("SURFSENSE_USER_EMAIL", "u@x.com")
monkeypatch.setenv("SURFSENSE_USER_PASSWORD", "pw")
config = load_config()
assert config.credential_mode() == "local"
def test_state_roundtrip_per_suite(tmp_env): # noqa: ARG001
config = load_config()
assert get_suite_state(config, "medical") is None
state = SuiteState(
search_space_id=1,
agent_llm_id=-10042,
provider_model="anthropic/claude-sonnet-4.5",
created_at="2026-05-11T20-30-00Z",
)
set_suite_state(config, "medical", state)
legal = SuiteState(
search_space_id=2,
agent_llm_id=-1,
provider_model="openai/gpt-5",
created_at="2026-05-11T21-00-00Z",
)
set_suite_state(config, "legal", legal)
fetched = get_suite_state(config, "medical")
assert fetched.search_space_id == 1
assert fetched.provider_model == "anthropic/claude-sonnet-4.5"
# Other suite untouched after teardown.
cleared = clear_suite_state(config, "medical")
assert cleared is True
assert get_suite_state(config, "medical") is None
assert get_suite_state(config, "legal").search_space_id == 2
raw = json.loads(config.state_path.read_text(encoding="utf-8"))
assert "medical" not in raw["suites"]
assert "legal" in raw["suites"]
def test_paths_are_per_suite(tmp_env): # noqa: ARG001
config = load_config()
a = config.suite_data_dir("medical")
b = config.suite_data_dir("legal")
assert a != b
assert config.suite_reports_dir("medical").parent == config.reports_dir
assert config.suite_runs_dir("medical").name == "runs"
assert config.suite_maps_dir("medical").name == "maps"
# ---------------------------------------------------------------------------
# Scenario state — back-compat + new fields
# ---------------------------------------------------------------------------
def test_legacy_state_back_compat_defaults_to_head_to_head():
"""state.json files written before scenarios shipped must still load.
Missing ``scenario`` / ``vision_*`` / ``native_arm_model`` keys all
default to ``head-to-head`` / ``None`` so old setups keep working
after upgrade the runner's behaviour exactly mirrors the legacy
one (both arms answer with ``provider_model``).
"""
legacy = {
"search_space_id": 7,
"agent_llm_id": -123,
"provider_model": "anthropic/claude-sonnet-4.5",
"created_at": "2026-05-11T20-30-00Z",
"ingestion_maps": {},
}
state = SuiteState.from_dict(legacy)
assert state.scenario == DEFAULT_SCENARIO == "head-to-head"
assert state.vision_llm_config_id is None
assert state.vision_provider_model is None
assert state.native_arm_model is None
# The native arm should still answer with the same slug as SurfSense.
assert state.effective_native_arm_model == state.provider_model
def test_unknown_scenario_falls_back_to_default():
"""Garbage scenario in state.json → default, not crash.
Defensive: we'd rather a stale state file render with the safe
head-to-head behaviour than break the whole run with a KeyError.
"""
payload = {
"search_space_id": 1,
"agent_llm_id": -1,
"provider_model": "openai/gpt-5",
"scenario": "unknown-scenario-name",
}
state = SuiteState.from_dict(payload)
assert state.scenario == DEFAULT_SCENARIO
def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG001
config = load_config()
state = SuiteState(
search_space_id=42,
agent_llm_id=-1,
provider_model="openai/gpt-5.4-mini",
created_at="2026-05-11T20-30-00Z",
scenario="cost-arbitrage",
vision_llm_config_id=-101,
vision_provider_model="anthropic/claude-sonnet-4.5",
native_arm_model="anthropic/claude-sonnet-4.5",
)
set_suite_state(config, "medical", state)
fetched = get_suite_state(config, "medical")
assert fetched.scenario == "cost-arbitrage"
assert fetched.vision_llm_config_id == -101
assert fetched.vision_provider_model == "anthropic/claude-sonnet-4.5"
assert fetched.native_arm_model == "anthropic/claude-sonnet-4.5"
# Cost arbitrage's whole point: native arm slug != surfsense slug.
assert fetched.effective_native_arm_model != fetched.provider_model
assert fetched.effective_native_arm_model == "anthropic/claude-sonnet-4.5"
raw = json.loads(config.state_path.read_text(encoding="utf-8"))
assert raw["suites"]["medical"]["scenario"] == "cost-arbitrage"
def test_scenario_constants_are_stable():
"""Pin the public scenario list; runners + tests key off these strings."""
assert SCENARIOS == ("head-to-head", "symmetric-cheap", "cost-arbitrage")
assert DEFAULT_SCENARIO == "head-to-head"

View file

@ -0,0 +1,269 @@
"""Unit tests for ``surfsense_evals.core.ingest_settings``.
Covers:
* ``IngestSettings.merge`` honours operator overrides and falls back
to per-benchmark defaults when the operator is silent.
* ``add_ingest_settings_args`` exposes the three flag pairs and
argparse defaults of ``None`` correctly distinguish "not passed"
from "explicitly false".
* ``settings_header_line`` / ``read_settings_header`` round-trip
through a JSONL file.
* ``read_settings_header`` is fault-tolerant: missing files, missing
header, malformed JSON.
* ``format_ingest_settings_md`` produces a stable Markdown bullet.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import pytest
from surfsense_evals.core.ingest_settings import (
PROCESSING_MODES,
SETTINGS_HEADER_KEY,
IngestSettings,
add_ingest_settings_args,
format_ingest_settings_md,
is_settings_header,
read_settings_header,
settings_header_line,
)
# ---------------------------------------------------------------------------
# IngestSettings.merge
# ---------------------------------------------------------------------------
class TestMerge:
def test_silent_operator_uses_defaults(self) -> None:
defaults = IngestSettings(use_vision_llm=True, processing_mode="basic", should_summarize=True)
merged = IngestSettings.merge(defaults, {})
assert merged == defaults
def test_explicit_false_overrides_default_true(self) -> None:
defaults = IngestSettings(use_vision_llm=True)
merged = IngestSettings.merge(
defaults, {"use_vision_llm": False}
)
assert merged.use_vision_llm is False
def test_explicit_true_overrides_default_false(self) -> None:
defaults = IngestSettings(use_vision_llm=False)
merged = IngestSettings.merge(
defaults, {"use_vision_llm": True}
)
assert merged.use_vision_llm is True
def test_none_means_silent(self) -> None:
# Argparse with BooleanOptionalAction yields None when the
# operator passed neither --use-vision-llm nor --no-vision-llm.
defaults = IngestSettings(use_vision_llm=True)
merged = IngestSettings.merge(
defaults, {"use_vision_llm": None}
)
assert merged.use_vision_llm is True
def test_processing_mode_override(self) -> None:
defaults = IngestSettings(processing_mode="basic")
merged = IngestSettings.merge(
defaults, {"processing_mode": "premium"}
)
assert merged.processing_mode == "premium"
def test_processing_mode_invalid_raises(self) -> None:
defaults = IngestSettings(processing_mode="basic")
with pytest.raises(ValueError, match="Invalid processing_mode"):
IngestSettings.merge(defaults, {"processing_mode": "exotic"})
def test_processing_mode_blank_falls_back(self) -> None:
defaults = IngestSettings(processing_mode="basic")
merged = IngestSettings.merge(defaults, {"processing_mode": ""})
assert merged.processing_mode == "basic"
def test_string_truthy_coerced(self) -> None:
defaults = IngestSettings(use_vision_llm=False)
merged = IngestSettings.merge(defaults, {"use_vision_llm": "yes"})
assert merged.use_vision_llm is True
def test_string_falsy_coerced(self) -> None:
defaults = IngestSettings(use_vision_llm=True)
merged = IngestSettings.merge(defaults, {"use_vision_llm": "false"})
assert merged.use_vision_llm is False
def test_other_keys_ignored(self) -> None:
# Benchmarks pass the whole opts dict; merge must tolerate
# unrelated keys without crashing.
defaults = IngestSettings(use_vision_llm=True, processing_mode="basic")
merged = IngestSettings.merge(
defaults,
{
"use_vision_llm": False,
"concurrency": 4,
"task_filter": "all",
"no_mentions": True,
},
)
assert merged.use_vision_llm is False
assert merged.processing_mode == "basic"
def test_to_dict_round_trips(self) -> None:
s = IngestSettings(use_vision_llm=True, processing_mode="premium", should_summarize=False)
d = s.to_dict()
assert d == {
"use_vision_llm": True,
"processing_mode": "premium",
"should_summarize": False,
}
def test_render_label_format(self) -> None:
s = IngestSettings(use_vision_llm=True, processing_mode="premium", should_summarize=True)
assert s.render_label() == "vision=on, mode=premium, summarize=on"
# ---------------------------------------------------------------------------
# add_ingest_settings_args
# ---------------------------------------------------------------------------
class TestAddArgs:
@pytest.fixture
def parser(self) -> argparse.ArgumentParser:
p = argparse.ArgumentParser()
add_ingest_settings_args(
p,
defaults=IngestSettings(
use_vision_llm=False, processing_mode="basic", should_summarize=False
),
)
return p
def test_silent_invocation_yields_none(self, parser: argparse.ArgumentParser) -> None:
args = parser.parse_args([])
assert args.use_vision_llm is None
assert args.processing_mode is None
assert args.should_summarize is None
def test_use_vision_llm_flag(self, parser: argparse.ArgumentParser) -> None:
args = parser.parse_args(["--use-vision-llm"])
assert args.use_vision_llm is True
def test_no_vision_llm_flag(self, parser: argparse.ArgumentParser) -> None:
args = parser.parse_args(["--no-vision-llm"])
assert args.use_vision_llm is False
def test_processing_mode_choices(self, parser: argparse.ArgumentParser) -> None:
for mode in PROCESSING_MODES:
args = parser.parse_args(["--processing-mode", mode])
assert args.processing_mode == mode
def test_processing_mode_rejects_unknown(
self, parser: argparse.ArgumentParser
) -> None:
with pytest.raises(SystemExit):
parser.parse_args(["--processing-mode", "exotic"])
def test_summarize_flag_pair(self, parser: argparse.ArgumentParser) -> None:
on = parser.parse_args(["--should-summarize"])
assert on.should_summarize is True
off = parser.parse_args(["--no-summarize"])
assert off.should_summarize is False
def test_vision_flags_mutually_exclusive(
self, parser: argparse.ArgumentParser
) -> None:
with pytest.raises(SystemExit):
parser.parse_args(["--use-vision-llm", "--no-vision-llm"])
def test_full_pipeline(self, parser: argparse.ArgumentParser) -> None:
# Operator passes flags + defaults are reasonable. Merge
# should yield exactly what they asked for.
args = parser.parse_args(
["--use-vision-llm", "--processing-mode", "premium"]
)
defaults = IngestSettings(
use_vision_llm=False, processing_mode="basic", should_summarize=False
)
merged = IngestSettings.merge(defaults, vars(args))
assert merged == IngestSettings(
use_vision_llm=True, processing_mode="premium", should_summarize=False
)
# ---------------------------------------------------------------------------
# Header round-trip + read_settings_header fault tolerance
# ---------------------------------------------------------------------------
class TestHeader:
def test_header_line_round_trip(self, tmp_path: Path) -> None:
s = IngestSettings(use_vision_llm=True, processing_mode="premium")
path = tmp_path / "map.jsonl"
with path.open("w", encoding="utf-8") as fh:
fh.write(settings_header_line(s) + "\n")
fh.write(json.dumps({"case_id": "x", "document_id": 1}) + "\n")
loaded = read_settings_header(path)
assert loaded == s.to_dict()
def test_is_settings_header_recognises(self) -> None:
assert is_settings_header({SETTINGS_HEADER_KEY: {}})
assert not is_settings_header({"case_id": "x"})
def test_missing_file_returns_empty(self, tmp_path: Path) -> None:
assert read_settings_header(tmp_path / "does_not_exist.jsonl") == {}
def test_empty_file_returns_empty(self, tmp_path: Path) -> None:
path = tmp_path / "empty.jsonl"
path.write_text("", encoding="utf-8")
assert read_settings_header(path) == {}
def test_no_header_returns_empty(self, tmp_path: Path) -> None:
path = tmp_path / "legacy.jsonl"
with path.open("w", encoding="utf-8") as fh:
fh.write(json.dumps({"case_id": "x", "document_id": 1}) + "\n")
fh.write(json.dumps({"case_id": "y", "document_id": 2}) + "\n")
assert read_settings_header(path) == {}
def test_malformed_json_returns_empty(self, tmp_path: Path) -> None:
path = tmp_path / "broken.jsonl"
path.write_text("not json\n", encoding="utf-8")
assert read_settings_header(path) == {}
def test_skips_blank_first_lines(self, tmp_path: Path) -> None:
s = IngestSettings(use_vision_llm=True)
path = tmp_path / "padded.jsonl"
with path.open("w", encoding="utf-8") as fh:
fh.write("\n\n")
fh.write(settings_header_line(s) + "\n")
assert read_settings_header(path) == s.to_dict()
# ---------------------------------------------------------------------------
# format_ingest_settings_md
# ---------------------------------------------------------------------------
class TestFormatMd:
def test_full_settings(self) -> None:
out = format_ingest_settings_md(
{"use_vision_llm": True, "processing_mode": "premium", "should_summarize": True}
)
assert "vision_llm=`on`" in out
assert "processing_mode=`premium`" in out
assert "summarize=`on`" in out
def test_default_off(self) -> None:
out = format_ingest_settings_md(
{"use_vision_llm": False, "processing_mode": "basic", "should_summarize": False}
)
assert "vision_llm=`off`" in out
assert "processing_mode=`basic`" in out
assert "summarize=`off`" in out
def test_missing_returns_re_ingest_hint(self) -> None:
# Empty dict + None + non-mapping should all degrade gracefully.
for raw in [None, {}, "not-a-mapping"]:
assert "(not recorded" in format_ingest_settings_md(raw)

View file

@ -0,0 +1,153 @@
"""Metric correctness — Wilson, McNemar, retrieval scores."""
from __future__ import annotations
import math
import pytest
from surfsense_evals.core.metrics import (
accuracy_with_wilson_ci,
bootstrap_delta_ci,
mcnemar_test,
mrr,
ndcg_at_k,
recall_at_k,
score_run,
wilson_ci,
)
# ---------------------------------------------------------------------------
# Wilson
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"k,n,low,high",
[
(80, 100, 0.7111, 0.8666), # cross-checked vs statsmodels.proportion_confint(method='wilson')
(50, 100, 0.4038, 0.5962),
(0, 0, 0.0, 1.0),
(0, 10, 0.0, 0.2775),
(10, 10, 0.7225, 1.0),
],
)
def test_wilson_ci_known_values(k, n, low, high):
result_low, result_high = wilson_ci(k, n)
assert math.isclose(result_low, low, abs_tol=5e-4), (k, n, result_low, low)
assert math.isclose(result_high, high, abs_tol=5e-4), (k, n, result_high, high)
def test_accuracy_with_wilson_ci_object():
res = accuracy_with_wilson_ci(70, 100)
assert res.accuracy == 0.7
assert 0.0 < res.ci_low < res.ci_high < 1.0
def test_invalid_inputs_raise():
with pytest.raises(ValueError):
accuracy_with_wilson_ci(-1, 10)
with pytest.raises(ValueError):
accuracy_with_wilson_ci(11, 10)
# ---------------------------------------------------------------------------
# McNemar
# ---------------------------------------------------------------------------
def test_mcnemar_degenerate_returns_p_value_one():
a = [True, True, False, False]
b = [True, True, False, False]
res = mcnemar_test(a, b)
assert res.b == 0 and res.c == 0
assert res.p_value == 1.0
assert res.method == "degenerate"
def test_mcnemar_exact_branch_strong_signal():
"""B = 0, C = 10 → exact two-sided binomial p == 2 * (1/2)**10."""
a = [True] * 10 + [False] * 10
b = [True] * 10 + [True] * 10 # surfsense beats native on the 10 native-wrong
res = mcnemar_test(a, b)
assert res.b == 0
assert res.c == 10
assert res.method == "exact"
expected = 2 * (0.5 ** 10)
assert math.isclose(res.p_value, expected, rel_tol=1e-9)
def test_mcnemar_chi_square_approx_for_large_discordant():
# Construct b=15, c=5 with continuity-corrected chi^2 = (|10|-1)^2/20 = 4.05.
a = [True] * 15 + [False] * 5 + [True] * 30 + [False] * 30
b = [False] * 15 + [True] * 5 + [True] * 30 + [False] * 30
res = mcnemar_test(a, b)
assert res.method == "chi2_cc"
assert res.b == 15 and res.c == 5
assert math.isclose(res.statistic, ((abs(15 - 5) - 1) ** 2) / 20.0, rel_tol=1e-9)
# p ≈ chi2.sf(4.05, df=1) ≈ 0.04417
assert 0.04 < res.p_value < 0.05
def test_mcnemar_length_mismatch():
with pytest.raises(ValueError):
mcnemar_test([True], [True, False])
# ---------------------------------------------------------------------------
# Bootstrap
# ---------------------------------------------------------------------------
def test_bootstrap_delta_ci_shape_and_determinism():
a = [True, True, False, True, False, False, True, True]
b = [True, True, True, True, True, False, True, False]
res1 = bootstrap_delta_ci(a, b, n_resamples=500, random_state=42)
res2 = bootstrap_delta_ci(a, b, n_resamples=500, random_state=42)
assert res1.delta == res2.delta
assert res1.ci_low == res2.ci_low
assert res1.ci_high == res2.ci_high
assert res1.ci_low <= res1.delta <= res1.ci_high
assert res1.n_resamples == 500
# ---------------------------------------------------------------------------
# Retrieval
# ---------------------------------------------------------------------------
def test_recall_at_k():
retrieved = ["a", "b", "c", "d"]
relevant = ["b", "d", "z"]
assert recall_at_k(retrieved, relevant, k=2) == pytest.approx(1 / 3)
assert recall_at_k(retrieved, relevant, k=4) == pytest.approx(2 / 3)
def test_mrr():
assert mrr(["a", "b", "c"], ["c"]) == pytest.approx(1 / 3)
assert mrr(["x", "y"], ["z"]) == 0.0
def test_ndcg_at_k_perfect_order():
qrels = {"a": 2, "b": 1}
assert ndcg_at_k(["a", "b"], qrels, k=2) == pytest.approx(1.0)
def test_ndcg_at_k_irrelevant_first():
qrels = {"a": 2, "b": 1}
# Wrong order should still be > 0 but < 1
val = ndcg_at_k(["c", "a", "b"], qrels, k=3)
assert 0 < val < 1
def test_score_run_aggregates_across_queries():
scores = score_run(
per_query_retrieved={"q1": ["a", "b"], "q2": ["x", "y", "z"]},
per_query_qrels={"q1": {"a": 1}, "q2": {"z": 2}},
ks=(1, 5),
ndcg_k=5,
)
assert scores.n_queries == 2
assert scores.recall_at_k[1] == pytest.approx((1 + 0) / 2) # q1 hits @1, q2 doesn't
assert scores.mrr == pytest.approx((1.0 + 1 / 3) / 2)

View file

@ -0,0 +1,27 @@
"""Tests for the MCQ answer-letter extractor."""
from __future__ import annotations
import pytest
from surfsense_evals.core.parse import extract_answer_letter
from surfsense_evals.core.parse.answer_letter import AnswerLetterResult
@pytest.mark.parametrize(
"text,expected_letter,expected_strategy",
[
('```json\n{"step_by_step_thinking": "...", "answer_choice": "B"}\n```', "B", "json_envelope"),
('Reasoning... {"step_by_step_thinking": "x", "answer_choice": "C"}', "C", "json_envelope"),
("Long reasoning.\nAnswer: D", "D", "answer_line"),
("The correct answer is (A).", "A", "answer_line"),
("Final answer: e", "E", "answer_line"),
("Long reasoning.\n\nB", "B", "bare_letter"),
("Long reasoning.\n\n(C).", "C", "bare_letter"),
("", None, "none"),
("Just narrative without an answer.", None, "none"),
],
)
def test_extract_answer_letter(text, expected_letter, expected_strategy):
result = extract_answer_letter(text)
assert result == AnswerLetterResult(expected_letter, expected_strategy)

View file

@ -0,0 +1,108 @@
"""Parity tests for the citation regex.
Each row mirrors a case from the canonical TS reference at
``surfsense_web/lib/citations/citation-parser.ts``. If a future PR
loosens or tightens the TS regex, these tests will start failing;
that's the explicit signal to re-port the change.
"""
from __future__ import annotations
import pytest
from surfsense_evals.core.parse import (
CITATION_REGEX,
ChunkCitation,
UrlCitation,
parse_citations,
)
PARITY_TABLE = [
# (input, expected number of matches, expected first-token kind/value)
("Plain text with no citation.", 0, None),
(
"The patient has fever [citation:42] and cough.",
1,
ChunkCitation(chunk_id=42, is_docs_chunk=False),
),
(
"Negative chunk ids work [citation:-7].",
1,
ChunkCitation(chunk_id=-7, is_docs_chunk=False),
),
(
"doc-prefix [citation:doc-12].",
1,
ChunkCitation(chunk_id=12, is_docs_chunk=True),
),
(
"Multi id [citation:1, doc-2, -3].",
3,
ChunkCitation(chunk_id=1, is_docs_chunk=False),
),
(
"URL form [citation:https://x.com/a].",
1,
UrlCitation(url="https://x.com/a"),
),
(
"Chinese brackets【citation:5】.",
1,
ChunkCitation(chunk_id=5, is_docs_chunk=False),
),
(
"ZWSP-decorated [\u200bcitation:9\u200b].",
1,
ChunkCitation(chunk_id=9, is_docs_chunk=False),
),
(
"Whitespace [citation: doc-100 ] tolerated.",
1,
ChunkCitation(chunk_id=100, is_docs_chunk=True),
),
(
# The TS regex's URL char class excludes ']', so a trailing
# bracket isn't swallowed.
"Two URLs [citation:https://a.io] and [citation:https://b.io].",
2,
UrlCitation(url="https://a.io"),
),
(
# Garbled form should match nothing.
"Citation-like but wrong [citation:].",
0,
None,
),
]
@pytest.mark.parametrize("text,n_expected,first", PARITY_TABLE)
def test_citation_regex_parity(text: str, n_expected: int, first):
tokens = parse_citations(text)
assert len(tokens) == n_expected, (text, tokens)
if first is not None:
assert tokens[0] == first, (text, tokens)
def test_regex_pattern_matches_ts_source():
"""Sanity: the compiled pattern carries the exact alternatives the TS source does."""
pattern = CITATION_REGEX.pattern
assert "https?://" in pattern
assert "urlcite" in pattern
assert "doc-" in pattern
assert "\u200B" in pattern
assert "" in pattern and "" in pattern
def test_url_map_resolution():
text = "Inline placeholder [citation:urlcite0]."
tokens = parse_citations(text, url_map={"urlcite0": "https://resolved.example/x"})
assert tokens == [UrlCitation(url="https://resolved.example/x")]
def test_url_map_missing_key_drops_token():
"""Missing urlcite resolution returns no token (TS behaviour)."""
text = "[citation:urlcite99]"
assert parse_citations(text, url_map={}) == []

View file

@ -0,0 +1,73 @@
"""Tests for ``surfsense_evals.core.parse.freeform_answer``."""
from __future__ import annotations
import pytest
from surfsense_evals.core.parse.freeform_answer import extract_freeform_answer
class TestExtractFreeformAnswer:
def test_empty_string_returns_empty(self) -> None:
assert extract_freeform_answer("") == ""
assert extract_freeform_answer(" \n\n ") == ""
def test_simple_answer_marker(self) -> None:
assert extract_freeform_answer("Answer: 42") == "42"
def test_final_answer_marker(self) -> None:
assert extract_freeform_answer("Final answer: Paris") == "Paris"
def test_the_answer_is_marker(self) -> None:
assert extract_freeform_answer("The answer is: not answerable") == "not answerable"
def test_multiline_picks_last_answer_marker(self) -> None:
text = "Let me think...\nAnswer: 5\nAnswer: 7\n"
assert extract_freeform_answer(text) == "7"
def test_falls_back_to_last_nonempty_line(self) -> None:
text = "Some thinking here.\n\n42"
assert extract_freeform_answer(text) == "42"
def test_strips_quotes(self) -> None:
assert extract_freeform_answer('Answer: "Paris"') == "Paris"
assert extract_freeform_answer("Answer: 'Paris'") == "Paris"
def test_strips_backticks(self) -> None:
assert extract_freeform_answer("Answer: `42`") == "42"
def test_uses_fenced_block_when_no_marker(self) -> None:
text = "Here's my response:\n```\nfinal value\n```\n"
assert extract_freeform_answer(text) == "final value"
def test_case_insensitive_markers(self) -> None:
assert extract_freeform_answer("ANSWER: yes") == "yes"
assert extract_freeform_answer("answer: no") == "no"
@pytest.mark.parametrize("text,expected", [
("Answer: 1, 2, 3", "1, 2, 3"),
("Answer: 3.14", "3.14"),
("Answer: spaced ", "spaced"),
])
def test_various_payloads(self, text: str, expected: str) -> None:
assert extract_freeform_answer(text) == expected
def test_inline_answer_after_thinking_trace(self) -> None:
# Agent replies sometimes glue their thinking onto the same
# line as the final "Answer: ..." marker (no newline before it).
# The line-anchored regex misses this; the inline fallback
# should still extract the right value.
text = (
"Need the Charlotte Bronte book title/year and the rank "
"for a 128-foot NYC building.Answer: 128th"
)
assert extract_freeform_answer(text) == "128th"
def test_inline_picks_last_inline_answer(self) -> None:
text = "Thought: maybe Answer: 5 is right? Actually Answer: 7."
assert extract_freeform_answer(text) == "7."
def test_inline_does_not_override_proper_marker(self) -> None:
# When a clean line-anchored "Answer: ..." exists, that wins.
text = "Some preamble.Answer: 99\nAnswer: 42"
assert extract_freeform_answer(text) == "42"

View file

@ -0,0 +1,84 @@
"""Tests for the SSE consumer."""
from __future__ import annotations
import pytest
from surfsense_evals.core.parse import iter_sse_events
async def _alist(it):
out = []
async for x in it:
out.append(x)
return out
async def _astream(lines):
for line in lines:
yield line
@pytest.mark.asyncio
async def test_basic_data_frame():
events = await _alist(
iter_sse_events(_astream([
'data: {"type": "text-delta", "delta": "hi"}',
"",
'data: {"type": "finish"}',
"",
]))
)
assert [e.data for e in events] == [
'{"type": "text-delta", "delta": "hi"}',
'{"type": "finish"}',
]
@pytest.mark.asyncio
async def test_done_sentinel_passes_through():
events = await _alist(
iter_sse_events(_astream([
"data: [DONE]",
"",
]))
)
assert [e.data for e in events] == ["[DONE]"]
@pytest.mark.asyncio
async def test_multiline_data_joins_with_newline():
events = await _alist(
iter_sse_events(_astream([
"data: line1",
"data: line2",
"",
]))
)
assert events[0].data == "line1\nline2"
@pytest.mark.asyncio
async def test_comments_and_other_fields_ignored():
events = await _alist(
iter_sse_events(_astream([
": heartbeat",
"event: foo",
"id: 123",
"data: payload",
"",
]))
)
assert [e.data for e in events] == ["payload"]
@pytest.mark.asyncio
async def test_handles_missing_trailing_blank():
"""Some servers omit the final blank line; the consumer should still emit."""
events = await _alist(
iter_sse_events(_astream([
"data: only-one",
]))
)
assert [e.data for e in events] == ["only-one"]

View file

@ -0,0 +1,51 @@
"""Smoke tests for PDF rendering.
We don't pull a full PDF parser into the test deps; the assertions
are bytes-level (``%PDF`` magic, deterministic CreationDate scrub).
"""
from __future__ import annotations
from pathlib import Path
from surfsense_evals.core.pdf import render_pdf, render_text_files_to_pdf
def test_render_pdf_writes_pdf_with_magic(tmp_path: Path):
out = tmp_path / "out.pdf"
rendered = render_pdf(
title="Test",
sections=[("intro", "Hello world."), ("body", "Line one.\nLine two.")],
output_path=out,
)
assert rendered.path == out
assert out.exists()
assert out.read_bytes().startswith(b"%PDF-")
def test_render_pdf_deterministic_dates(tmp_path: Path):
out_a = tmp_path / "a.pdf"
out_b = tmp_path / "b.pdf"
sections = [("only", "deterministic body content")]
render_pdf(title="Det", sections=sections, output_path=out_a)
render_pdf(title="Det", sections=sections, output_path=out_b)
# CreationDate / ModDate are scrubbed to a fixed value, so the two
# files should compare equal (modulo any other internal randomness
# — reportlab's basic outputs are deterministic given fixed inputs).
assert out_a.read_bytes() == out_b.read_bytes()
def test_render_text_files_uses_filename_as_section(tmp_path: Path):
files_dir = tmp_path / "src"
files_dir.mkdir()
(files_dir / "admission_note.txt").write_text("history of present illness", encoding="utf-8")
(files_dir / "labs.txt").write_text("Na 138, K 4.0", encoding="utf-8")
out = tmp_path / "case.pdf"
rendered = render_text_files_to_pdf(
title="Case 1",
files=[files_dir / "admission_note.txt", files_dir / "labs.txt"],
output_path=out,
)
assert out.exists()
# We don't decode the PDF; the n_chars estimate should reflect both inputs.
assert rendered.n_chars >= len("history of present illness") + len("Na 138, K 4.0")

View file

@ -0,0 +1,73 @@
"""Tests for ``render_pdf_with_images`` — covers image embedding +
deterministic byte output, mirroring ``test_pdf_render.py`` for the
text-only path.
"""
from __future__ import annotations
from pathlib import Path
import pytest
from surfsense_evals.core.pdf import PdfImage, render_pdf_with_images
@pytest.fixture
def tiny_png(tmp_path: Path) -> Path:
"""Generate a real 4x4 PNG via Pillow — embeds cleanly in reportlab.
Hand-crafted PNG headers tend to fail PIL's strict decoder, so we
delegate to Pillow which is already a transitive dep of reportlab.
"""
from PIL import Image as PILImage
p = tmp_path / "pixel.png"
PILImage.new("RGB", (4, 4), color=(128, 128, 128)).save(p, format="PNG")
return p
class TestRenderPdfWithImages:
def test_renders_pdf_with_no_images(self, tmp_path: Path) -> None:
out = tmp_path / "out.pdf"
rendered = render_pdf_with_images(
title="Test",
sections=[("Heading", "Body text here.", None)],
output_path=out,
)
assert rendered.path == out
assert out.exists()
assert out.read_bytes().startswith(b"%PDF-")
def test_renders_pdf_with_one_image(self, tmp_path: Path, tiny_png: Path) -> None:
out = tmp_path / "out.pdf"
render_pdf_with_images(
title="Test",
sections=[("Case", "Body text.", [PdfImage(path=tiny_png, caption="A pixel")])],
output_path=out,
)
assert out.exists()
assert out.stat().st_size > 200 # not empty
def test_deterministic_bytes(self, tmp_path: Path, tiny_png: Path) -> None:
out_a = tmp_path / "a.pdf"
out_b = tmp_path / "b.pdf"
sections = [
("Case", "Some text.", [PdfImage(path=tiny_png, caption="cap")]),
("Options", "A) one\nB) two", None),
]
render_pdf_with_images(title="Test", sections=sections, output_path=out_a)
render_pdf_with_images(title="Test", sections=sections, output_path=out_b)
assert out_a.read_bytes() == out_b.read_bytes()
def test_skips_invalid_image_silently(self, tmp_path: Path) -> None:
"""A bad image path should not abort the whole PDF render."""
out = tmp_path / "out.pdf"
render_pdf_with_images(
title="Test",
sections=[("Case", "Text", [PdfImage(path=tmp_path / "nope.jpg", caption="x")])],
output_path=out,
)
assert out.exists()
assert out.read_bytes().startswith(b"%PDF-")

View file

@ -0,0 +1,121 @@
"""respx-mocked tests for the OpenRouter PDF provider."""
from __future__ import annotations
import base64
import json
from pathlib import Path
import httpx
import pytest
import respx
from surfsense_evals.core.providers.openrouter_pdf import (
OpenRouterPdfProvider,
PdfEngine,
)
_BASE = "https://openrouter.test"
@pytest.fixture
def tiny_pdf(tmp_path: Path) -> Path:
p = tmp_path / "case.pdf"
p.write_bytes(b"%PDF-1.4 minimal content")
return p
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_payload_shape_matches_openrouter_docs(respx_mock, tiny_pdf: Path):
captured = {}
def _capture(request):
captured["body"] = json.loads(request.content)
captured["headers"] = dict(request.headers)
return httpx.Response(
200,
json={
"choices": [{
"message": {"content": "Answer: B"},
"finish_reason": "stop",
}],
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15, "cost": 0.0001},
},
)
respx_mock.post("/chat/completions").mock(side_effect=_capture)
provider = OpenRouterPdfProvider(
api_key="sk-or-test",
base_url=_BASE,
model="anthropic/claude-sonnet-4.5",
engine=PdfEngine.NATIVE,
)
response = await provider.complete(prompt="What is the diagnosis?", pdf_path=tiny_pdf)
body = captured["body"]
assert body["model"] == "anthropic/claude-sonnet-4.5"
assert body["plugins"] == [{"id": "file-parser", "pdf": {"engine": "native"}}]
user = body["messages"][-1]
assert user["role"] == "user"
file_part = user["content"][0]
assert file_part["type"] == "file"
assert file_part["file"]["filename"] == tiny_pdf.name
assert file_part["file"]["file_data"].startswith("data:application/pdf;base64,")
assert (
base64.b64decode(file_part["file"]["file_data"].split(",", 1)[1])
== tiny_pdf.read_bytes() # noqa: ASYNC240 — test fixture, sync read is fine
)
assert user["content"][1] == {"type": "text", "text": "What is the diagnosis?"}
assert captured["headers"]["authorization"] == "Bearer sk-or-test"
assert captured["headers"].get("x-title") == "SurfSense-evals"
assert response.text == "Answer: B"
assert response.input_tokens == 10
assert response.output_tokens == 5
assert response.total_tokens == 15
# cost 0.0001 USD == 100 micros
assert response.cost_micros == 100
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_chat_array_content_concatenates(respx_mock, tiny_pdf: Path):
respx_mock.post("/chat/completions").mock(
return_value=httpx.Response(
200,
json={
"choices": [{
"message": {
"content": [
{"type": "text", "text": "Hello "},
{"type": "text", "text": "world"},
{"type": "image_url", "image_url": "ignored"},
]
}
}],
"usage": {"prompt_tokens": 1, "completion_tokens": 1},
},
)
)
provider = OpenRouterPdfProvider(
api_key="sk-or-test", base_url=_BASE, model="x/y"
)
response = await provider.complete(prompt="hi", pdf_path=tiny_pdf)
assert response.text == "Hello world"
@pytest.mark.asyncio
@respx.mock(base_url=_BASE)
async def test_provider_raises_on_4xx(respx_mock, tiny_pdf: Path):
respx_mock.post("/chat/completions").mock(
return_value=httpx.Response(429, json={"error": {"message": "rate limited"}})
)
provider = OpenRouterPdfProvider(api_key="sk-or-test", base_url=_BASE, model="x/y")
with pytest.raises(httpx.HTTPStatusError):
await provider.complete(prompt="hi", pdf_path=tiny_pdf)
def test_missing_api_key_raises():
with pytest.raises(ValueError):
OpenRouterPdfProvider(api_key="", base_url=_BASE, model="x/y")

View file

@ -0,0 +1,58 @@
"""Registry + auto-discovery tests.
* Auto-discovery skips packages starting with ``_`` (so test fixtures
don't leak into the production catalogue).
* Manually importing a ``_demo`` benchmark fires its ``register(...)``
call and the CLI sees it.
"""
from __future__ import annotations
import importlib
from surfsense_evals.core import registry
def _force_register_demo() -> None:
"""Import (or reload) the demo module so its ``register(...)`` runs.
On a fresh interpreter, ``import_module`` triggers package
initialization. After the first call though, the module is cached
in ``sys.modules`` and a second ``import_module`` is a no-op so
if a previous test already unregistered the entry, we have to
``reload`` to re-execute the module body.
"""
module = importlib.import_module("surfsense_evals.suites._demo.hello")
if ("_demo", "hello") not in registry.snapshot():
importlib.reload(module)
def test_auto_discovery_skips_underscore_prefixed_subpackages():
from surfsense_evals.suites import discover_suites
discovered = discover_suites()
assert all(not part.startswith("_") for full in discovered for part in full.split("."))
# The medical suite's headline benchmark must always discover.
assert any(name.endswith(".medical.medxpertqa") for name in discovered)
def test_demo_benchmark_registers_on_explicit_import():
_force_register_demo()
bench = registry.get("_demo", "hello")
assert bench is not None
assert bench.name == "hello"
assert bench.headline is False
# Cleanup so the test is idempotent under repeated runs.
registry.unregister("_demo", "hello")
def test_register_unregister_roundtrip():
# Make sure no stale entry from a prior test in the session.
if ("_demo", "hello") in registry.snapshot():
registry.unregister("_demo", "hello")
snapshot_before = dict(registry.snapshot())
_force_register_demo()
assert ("_demo", "hello") in registry.snapshot()
registry.unregister("_demo", "hello")
assert dict(registry.snapshot()) == snapshot_before

Some files were not shown because too many files have changed in this diff Show more