mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
chore: evals
This commit is contained in:
parent
2402b730fa
commit
3737118050
122 changed files with 22598 additions and 13 deletions
65
surfsense_evals/.env.example
Normal file
65
surfsense_evals/.env.example
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
# surfsense_evals — environment template.
|
||||
#
|
||||
# Copy this file to `.env` (in the surfsense_evals/ project root or your
|
||||
# CWD) and fill in the values. `python-dotenv` loads it automatically
|
||||
# the first time `core.config` is imported, so every CLI subcommand
|
||||
# (`setup`, `ingest`, `run`, `report`, `teardown`, `models list`, …)
|
||||
# will pick the values up.
|
||||
#
|
||||
# cp .env.example .env
|
||||
# # then edit .env with your values
|
||||
#
|
||||
# `.env` is gitignored — never commit real secrets.
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Backend target — REQUIRED (default works for a local dev backend)
|
||||
# ---------------------------------------------------------------------------
|
||||
SURFSENSE_API_BASE=http://localhost:8000
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. OpenRouter — REQUIRED for any `run` invocation
|
||||
# ---------------------------------------------------------------------------
|
||||
# The `native_pdf` arm calls OpenRouter directly; the `surfsense` arm
|
||||
# routes through SurfSense which uses the same key under the hood.
|
||||
OPENROUTER_API_KEY=sk-or-...
|
||||
|
||||
# Override only if you proxy OpenRouter through a private gateway:
|
||||
# OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
|
||||
|
||||
# Multimodal benchmarks (medxpertqa, mmlongbench) require a vision-capable
|
||||
# slug. Recommended (verify in your catalog with `models list --grep ...`):
|
||||
# anthropic/claude-sonnet-4.5 (default recommendation)
|
||||
# anthropic/claude-opus-4.7 (strongest)
|
||||
# openai/gpt-5 (top-tier vision)
|
||||
# google/gemini-2.5-pro (1M-token context, best for long PDFs)
|
||||
# DO NOT use openai/gpt-5.4-mini for image-bearing benchmarks — it's
|
||||
# text-only on PDF content and the runner emits a warning if pinned.
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Auth — pick EXACTLY ONE of the two modes below
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# --- Mode A: LOCAL (backend started with AUTH_TYPE=LOCAL)
|
||||
# The harness POSTs these to /auth/jwt/login automatically.
|
||||
# SURFSENSE_USER_EMAIL=you@example.com
|
||||
# SURFSENSE_USER_PASSWORD=...
|
||||
|
||||
# --- Mode B: GOOGLE OAuth (or any pre-issued JWT)
|
||||
# Open the SurfSense web UI in your browser, log in via Google, then in
|
||||
# DevTools → Application → Local Storage copy:
|
||||
# surfsense_bearer_token → SURFSENSE_JWT
|
||||
# surfsense_refresh_token → SURFSENSE_REFRESH_TOKEN (optional, enables
|
||||
# auto-refresh on 401)
|
||||
# SURFSENSE_JWT=eyJhbGciOi...
|
||||
# SURFSENSE_REFRESH_TOKEN=eyJhbGciOi...
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Filesystem paths — OPTIONAL (defaults below)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Where datasets, rendered PDFs, ingestion id maps, run outputs, and
|
||||
# state.json live. Default: <surfsense_evals>/data/
|
||||
# EVAL_DATA_DIR=./data
|
||||
|
||||
# Where generated reports (summary.md / summary.json) get written.
|
||||
# Default: <surfsense_evals>/reports/
|
||||
# EVAL_REPORTS_DIR=./reports
|
||||
29
surfsense_evals/.gitignore
vendored
Normal file
29
surfsense_evals/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Python bytecode + caches
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.pyo
|
||||
|
||||
# Editable-install / build artifacts
|
||||
*.egg-info/
|
||||
build/
|
||||
dist/
|
||||
.eggs/
|
||||
|
||||
# Virtual envs (uv venv default + common alternates)
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
|
||||
# Tooling caches
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
.mypy_cache/
|
||||
.coverage
|
||||
.coverage.*
|
||||
htmlcov/
|
||||
|
||||
# Local secrets — keep `.env.example` tracked, never the real `.env`.
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
!.env.example
|
||||
228
surfsense_evals/README.md
Normal file
228
surfsense_evals/README.md
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
# SurfSense Evals
|
||||
|
||||
Domain-agnostic eval harness for SurfSense. Each benchmark is a Python subpackage under `suites/<domain>/<benchmark>/` that self-registers with the CLI; `core/` is the shared infrastructure (HTTP clients, arms, parsers, metrics, report writer, registry). The harness talks to SurfSense over HTTP only — it does **not** import any backend Python module — so it ships in its own venv and never bloats the FastAPI runtime image.
|
||||
|
||||
## Benchmarks
|
||||
|
||||
| Benchmark | Shape | Vision required? | Default ingest |
|
||||
|---------------------------------|--------------------------------------------------|------------------|----------------------------|
|
||||
| `medical/medxpertqa` (headline) | Native PDF vs SurfSense head-to-head, MCQ | yes | `vision=on, mode=basic` |
|
||||
| `medical/mirage` | SurfSense single-arm, MCQ | no | `vision=off, mode=basic` |
|
||||
| `medical/cure` | SurfSense single-arm retrieval (Recall/MRR/nDCG) | no | `vision=off, mode=basic` |
|
||||
| `multimodal_doc/mmlongbench` | Native PDF vs SurfSense head-to-head, open-ended | yes | `vision=on, mode=basic` |
|
||||
|
||||
Future domains (`legal/`, `finance/`, `code/`, `scientific/`) drop into `suites/` without touching `core/` or the CLI.
|
||||
|
||||
## Install + auth
|
||||
|
||||
```bash
|
||||
uv pip install -e ./surfsense_evals
|
||||
cp surfsense_evals/.env.example surfsense_evals/.env
|
||||
# Edit .env: SURFSENSE_API_BASE, OPENROUTER_API_KEY, and ONE of:
|
||||
# LOCAL → SURFSENSE_USER_EMAIL + SURFSENSE_USER_PASSWORD
|
||||
# GOOGLE → SURFSENSE_JWT (+ optional SURFSENSE_REFRESH_TOKEN)
|
||||
# (lift both from browser localStorage after a normal Google login)
|
||||
```
|
||||
|
||||
## Step-by-step: run all four benchmarks
|
||||
|
||||
The medical and multimodal_doc suites each get their own SearchSpace and pinned model, so they're independent — run them in any order. Both head-to-head benchmarks (`medxpertqa`, `mmlongbench`) require a **vision-capable** OpenRouter slug; pinning a text-only one (e.g. `openai/gpt-5.4-mini`) silently drops images and the runner emits a warning.
|
||||
|
||||
Recommended vision slugs (use `models list --grep <name>` to confirm one): `anthropic/claude-sonnet-4.5` (balanced cost), `anthropic/claude-opus-4.7` (strongest reasoning), `openai/gpt-5` (top-tier vision), `google/gemini-2.5-pro` (best for long PDFs, 1M-token context).
|
||||
|
||||
```bash
|
||||
# 0. (optional) discover what's registered
|
||||
python -m surfsense_evals suites list
|
||||
python -m surfsense_evals benchmarks list
|
||||
|
||||
# 1. MEDICAL SUITE — one SearchSpace, three benchmarks
|
||||
python -m surfsense_evals setup --suite medical --provider-model anthropic/claude-sonnet-4.5
|
||||
|
||||
# 1a. headline head-to-head: Native PDF (vision) vs SurfSense (vision RAG)
|
||||
# Downloads dev+test JSONL + images.zip, renders one PDF per question
|
||||
# (case + table + images + 5 options), uploads with use_vision_llm=True.
|
||||
python -m surfsense_evals ingest medical medxpertqa --split test
|
||||
python -m surfsense_evals run medical medxpertqa --concurrency 4
|
||||
|
||||
# 1b. MIRAGE — single-arm SurfSense MCQ accuracy
|
||||
# (MMLU-Med / MedQA-US / MedMCQA / PubMedQA / BioASQ)
|
||||
python -m surfsense_evals ingest medical mirage
|
||||
python -m surfsense_evals run medical mirage
|
||||
|
||||
# 1c. CUREv1 — single-arm SurfSense retrieval (Recall@k / MRR / nDCG@10)
|
||||
python -m surfsense_evals ingest medical cure --lang en
|
||||
python -m surfsense_evals run medical cure --lang en
|
||||
|
||||
# 1d. write reports/medical/<UTC-ts>/summary.{md,json}
|
||||
python -m surfsense_evals report --suite medical
|
||||
|
||||
# 2. MULTIMODAL_DOC SUITE — long PDFs with embedded images, charts, tables
|
||||
python -m surfsense_evals setup --suite multimodal_doc --provider-model google/gemini-2.5-pro
|
||||
python -m surfsense_evals ingest multimodal_doc mmlongbench # ~660MB, resumable
|
||||
python -m surfsense_evals run multimodal_doc mmlongbench --concurrency 4
|
||||
python -m surfsense_evals report --suite multimodal_doc
|
||||
|
||||
# 3. CLEANUP — soft-deletes the SearchSpaces; rendered PDFs stay cached
|
||||
python -m surfsense_evals teardown --suite medical
|
||||
python -m surfsense_evals teardown --suite multimodal_doc
|
||||
```
|
||||
|
||||
## Asymmetric scenarios — the "vision-extract once, answer cheap" play
|
||||
|
||||
The walkthrough above is `--scenario head-to-head` (default): both arms answer with the same vision-capable slug. SurfSense's actual architectural value-prop is that the **ingestion-time vision LLM and the runtime LLM are completely independent** — you can pay a vision LLM *once*, at ingest, to convert every embedded image into text (per-image OCR **and** semantic description, inlined where the image actually appears in the document — see [What `--use-vision-llm` produces](#what---use-vision-llm-produces) below). Then every query is served by a cheap text-only model that sees that extracted text natively. Two extra scenarios make this explicit:
|
||||
|
||||
| `--scenario` | Native arm answers with | SurfSense arm answers with | Question being measured |
|
||||
|--------------------|----------------------------------------|--------------------------------|------------------------------------------------------------------------------------------|
|
||||
| `head-to-head` | `--provider-model` (vision) | `--provider-model` (vision) | Pure RAG quality at parity. (Default.) |
|
||||
| `symmetric-cheap` | `--provider-model` (cheap, text-only) | `--provider-model` (same) | Does pre-extracted image context let a non-vision LLM reason over image-heavy docs? |
|
||||
| `cost-arbitrage` | `--native-arm-model` (vision) | `--provider-model` (cheap) | How close does SurfSense get to a vision-native baseline at a fraction of per-query cost?|
|
||||
|
||||
In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_llm_config_id` (auto-picked from the strongest registered global OpenRouter vision config — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm <slug>`). What changes is which slug the *answering* models hit per arm.
|
||||
|
||||
### Ingest with vision, evaluate with a non-vision LLM (`symmetric-cheap`)
|
||||
|
||||
This is the answer to *"does SurfSense give a non-vision LLM enough context to reason over image-heavy docs?"*. Both arms hit the same cheap text-only slug. The native arm is structurally blind to images (text-only LLM + raw PDFs). The SurfSense arm reads chunks that already contain the per-image OCR and visual descriptions, written there by the vision LLM at ingest time.
|
||||
|
||||
```bash
|
||||
python -m surfsense_evals setup --suite medical \
|
||||
--scenario symmetric-cheap \
|
||||
--provider-model openai/gpt-5.4-mini
|
||||
# vision LLM at ingest = auto-picked (claude-sonnet-4.5 by default)
|
||||
# answer LLM for BOTH arms = openai/gpt-5.4-mini (text-only)
|
||||
|
||||
python -m surfsense_evals ingest medical medxpertqa --split test # vision=on by default
|
||||
python -m surfsense_evals run medical medxpertqa --concurrency 4
|
||||
python -m surfsense_evals report --suite medical
|
||||
# Δ accuracy on image-required MCQs is the headline number; native arm
|
||||
# baseline is "what a text-only LLM gets without seeing the images".
|
||||
```
|
||||
|
||||
### Cheap SurfSense vs vision-native baseline (`cost-arbitrage`)
|
||||
|
||||
```bash
|
||||
python -m surfsense_evals setup --suite medical \
|
||||
--scenario cost-arbitrage \
|
||||
--provider-model openai/gpt-5.4-mini \
|
||||
--native-arm-model anthropic/claude-sonnet-4.5
|
||||
# vision LLM at ingest = auto-picked claude-sonnet-4.5
|
||||
# native arm = sonnet (vision); SurfSense arm = gpt-5.4-mini (text-only)
|
||||
|
||||
python -m surfsense_evals ingest medical medxpertqa --split test
|
||||
python -m surfsense_evals run medical medxpertqa --concurrency 4
|
||||
python -m surfsense_evals report --suite medical
|
||||
# Report header reads:
|
||||
# Scenario: cost-arbitrage — native arm answers with `anthropic/claude-sonnet-4.5`
|
||||
# (vision); SurfSense answers with `openai/gpt-5.4-mini` over chunks vision-extracted
|
||||
# at ingest by `anthropic/claude-sonnet-4.5`.
|
||||
```
|
||||
|
||||
Notes:
|
||||
- `cost-arbitrage` requires both `--provider-model` (the cheap SurfSense slug) AND `--native-arm-model <vision slug>`.
|
||||
- `--vision-llm <slug>` is optional; if omitted the harness queries `GET /api/v1/global-vision-llm-configs` and auto-picks the strongest registered one. Pass `--no-vision-llm-setup` if you want to keep whatever vision config is already attached to the SearchSpace.
|
||||
- The runner's "looks text-only" warning is suppressed (or relabelled as informational) for `symmetric-cheap` so intentional asymmetry doesn't read as a misconfiguration.
|
||||
- All three scenario fields (`scenario`, `provider_model`, `native_arm_model`, `vision_provider_model`) are persisted to `state.json` and recorded in `run_artifact.extra` + the report header — no need to retrace what was set.
|
||||
|
||||
## Per-benchmark useful flags
|
||||
|
||||
`medical/medxpertqa` (`run`):
|
||||
- `--split {test,dev,all}` — pick a subset (default `test`)
|
||||
- `--task "Diagnosis"` / `--body-system "Cardiovascular"` — slice the report
|
||||
- `--require-images` — drop rare rows where every image filename failed to resolve
|
||||
- `--n 100` — quick smoke run
|
||||
- `--no-mentions` — let SurfSense retrieve unscoped ("did the @-mention matter?")
|
||||
|
||||
`multimodal_doc/mmlongbench`:
|
||||
- `--max-docs N` (ingest) — cap downloads at the first N unique PDFs
|
||||
- `--format {str,int,float,list,none}` (run) — slice by answer format; `none` = the ~22% intentionally unanswerable hallucination probes
|
||||
- `--skip-unanswerable` (run) — drop unanswerable questions
|
||||
- `--docs <a.pdf>,<b.pdf>` (run) — scope to specific docs
|
||||
|
||||
## Ingestion knobs (vision LLM, processing mode, summarize)
|
||||
|
||||
The harness exposes `POST /api/v1/documents/fileupload`'s three knobs on every `ingest` subcommand:
|
||||
|
||||
| Flag pair | Effect |
|
||||
|--------------------------------------------|-----------------------------------------------------------------------------------------|
|
||||
| `--use-vision-llm` / `--no-vision-llm` | Walk every embedded image in the PDF and inline image-derived text at the image's position (see below). |
|
||||
| `--processing-mode {basic,premium}` | `premium` carries a 10× page multiplier and routes to a stronger ETL (e.g. LlamaCloud). |
|
||||
| `--should-summarize` / `--no-summarize` | Generate a per-document summary at ingest. |
|
||||
|
||||
The "Default ingest" column in the benchmarks table is what runs if you don't pass any flag. Whatever was actually used is recorded as a `__settings__` header in the doc map (`data/<suite>/maps/<benchmark>_*_map.jsonl`) and as `extra.ingest_settings` in `run_artifact.json`, then surfaced in the report — no need to hunt through CLI history.
|
||||
|
||||
> The backend's `ETL_SERVICE` env var (`DOCLING` | `UNSTRUCTURED` | `LLAMACLOUD`) is **not** per-upload. Restart the backend with a different `ETL_SERVICE` and re-ingest to compare ETLs (route through `--processing-mode premium` if your backend uses that mode for the stronger ETL).
|
||||
|
||||
### What `--use-vision-llm` produces
|
||||
|
||||
When vision is on, the backend's ETL pipeline (`app/etl_pipeline/picture_describer.py`) does, **per embedded image** in the PDF:
|
||||
|
||||
1. Extract the raw image bytes via `pypdf` (deduped by sha256, size-capped to match the vision LLM's per-image limit).
|
||||
2. **Per-image OCR** — re-feed the image as a standalone upload through the configured ETL service (Docling / Azure DI / LlamaCloud) with `vision_llm=None`, so the ETL's OCR engine extracts the literal text-in-image.
|
||||
3. **Visual description** — call the vision LLM on the image with a description-only prompt (it's explicitly told *not* to transcribe text — that's OCR's job). Steps 2 and 3 run in parallel per image.
|
||||
4. Splice a horizontal-rule-delimited section **at the image's original position** in the parser markdown (replacing Docling's `<!-- image -->` placeholder + caption, or the bare `Image: <name>` caption a stripped-image parser leaves behind):
|
||||
|
||||
```markdown
|
||||
---
|
||||
|
||||
**Embedded image:** `MM-130-a.jpeg`
|
||||
|
||||
**OCR text:**
|
||||
Slice 24 / 60
|
||||
L R
|
||||
|
||||
**Visual description:**
|
||||
|
||||
- Axial contrast-enhanced CT showing a large cystic mass in the left upper quadrant.
|
||||
- Mass effect on the adjacent stomach; left kidney displaced inferiorly.
|
||||
|
||||
---
|
||||
```
|
||||
|
||||
This is what makes `--scenario symmetric-cheap` and `--scenario cost-arbitrage` work: a non-vision LLM reading SurfSense's chunks sees the image's text and semantic content as plain markdown, alongside the surrounding case text, in the same retrieved chunk. Without it the cheap LLM would have nothing extra to read.
|
||||
|
||||
### A/B testing the same corpus with different settings
|
||||
|
||||
SurfSense dedupes uploads by `(filename, search_space_id)` — **not** by content hash and **not** by ingestion settings. Re-uploading the same filename to the same SearchSpace with a different `--use-vision-llm` flag silently skips re-processing. Give each variant its own SearchSpace:
|
||||
|
||||
```bash
|
||||
# Baseline arm (vision off)
|
||||
python -m surfsense_evals setup --suite medical --provider-model anthropic/claude-sonnet-4.5
|
||||
python -m surfsense_evals ingest medical medxpertqa --no-vision-llm
|
||||
python -m surfsense_evals run medical medxpertqa --n 100
|
||||
python -m surfsense_evals teardown --suite medical
|
||||
|
||||
# Vision arm (the benchmark default)
|
||||
python -m surfsense_evals setup --suite medical --provider-model anthropic/claude-sonnet-4.5
|
||||
python -m surfsense_evals ingest medical medxpertqa
|
||||
python -m surfsense_evals run medical medxpertqa --n 100
|
||||
python -m surfsense_evals report --suite medical
|
||||
```
|
||||
|
||||
Both runs land in `data/medical/runs/<ts>/medxpertqa/` with their settings recorded; rendered PDFs stay cached under `data/medical/medxpertqa/pdfs/` so the second `ingest` is upload-only.
|
||||
|
||||
## Environment variables
|
||||
|
||||
- `SURFSENSE_API_BASE` (default `http://localhost:8000`)
|
||||
- `OPENROUTER_API_KEY` — required for the `native_pdf` arm and for `models list`
|
||||
- One of `SURFSENSE_USER_EMAIL` + `SURFSENSE_USER_PASSWORD` (LOCAL), **or** `SURFSENSE_JWT` (+ optional `SURFSENSE_REFRESH_TOKEN`) for GOOGLE/pre-issued JWT
|
||||
- `EVAL_DATA_DIR` (default `<project>/data`) — datasets, rendered PDFs, ingestion id maps, run outputs, `state.json`
|
||||
- `EVAL_REPORTS_DIR` (default `<project>/reports`)
|
||||
- `OPENROUTER_BASE_URL` (default `https://openrouter.ai/api/v1`) — only if you proxy OpenRouter
|
||||
|
||||
## Adding a new domain suite
|
||||
|
||||
1. Create `surfsense_evals/src/surfsense_evals/suites/<domain>/<benchmark>/` with `__init__.py`, `ingest.py`, `runner.py`, optional `prompt.py`.
|
||||
2. Implement a `Benchmark` subclass (see `core/registry.py`); compose with `core.clients.*`, `core.arms.*`, `core.parse.*`, `core.metrics.*`.
|
||||
3. Call `register(MyBenchmark())` at the bottom of `<benchmark>/__init__.py`. Auto-discovery picks it up; `setup --suite <domain>` and `ingest/run <domain> <benchmark>` work immediately.
|
||||
|
||||
Each suite gets its own SearchSpace (`eval-<suite>-<UTC-ts>`), `state.json` slot, data dir, reports dir, and pinned LLM. Suites never share a SearchSpace.
|
||||
|
||||
## Out of scope (follow-up PRs)
|
||||
|
||||
- Docker service for `docker compose run evals run medical medxpertqa`.
|
||||
- Multi-model sweeps (one slug per `setup` for now; aggregate reports come later).
|
||||
- A long-context-stuffing arm (give the model the same retrieved chunks SurfSense saw).
|
||||
- LLM-judge grader for MMLongBench-Doc (paper uses GPT-4 as judge; we ship a deterministic rule-based grader).
|
||||
- MedXpertQA-MM accuracy by image modality — dataset doesn't tag modality directly; we slice by `medical_task` and `body_system`.
|
||||
- A `--slot <name>` flag that decouples the state-slot key from the benchmark registry's `suite` attribute, so parallel SearchSpaces with different ingestion settings can coexist on the same benchmark without `teardown` between A/B arms.
|
||||
|
||||
See `c:/Users/91882/.cursor/plans/medical_rag_evals_(mirage_+_curev1)_e797a324.plan.md` for the full design rationale.
|
||||
2
surfsense_evals/data/.gitignore
vendored
Normal file
2
surfsense_evals/data/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
*
|
||||
!.gitignore
|
||||
60
surfsense_evals/pyproject.toml
Normal file
60
surfsense_evals/pyproject.toml
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
[project]
|
||||
name = "surfsense-evals"
|
||||
version = "0.1.0"
|
||||
description = "Domain-agnostic evaluation harness for SurfSense (medical RAG suite ships first; legal/finance/code suites slot in under suites/)."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
license = { text = "Apache-2.0" }
|
||||
authors = [{ name = "SurfSense" }]
|
||||
|
||||
dependencies = [
|
||||
"httpx>=0.27.0",
|
||||
"httpx-sse>=0.4.0",
|
||||
"datasets>=2.21.0",
|
||||
"huggingface_hub>=0.24.0",
|
||||
"reportlab>=4.0.0",
|
||||
"Pillow>=10.0.0",
|
||||
"pyarrow>=15.0.0",
|
||||
"pydantic>=2.6.0",
|
||||
"tqdm>=4.66.0",
|
||||
"numpy>=1.26.0",
|
||||
"scikit-learn>=1.4.0",
|
||||
"scipy>=1.12.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"rich>=13.7.0",
|
||||
"trafilatura>=1.12.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"respx>=0.21.0",
|
||||
"ruff>=0.5.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
surfsense-evals = "surfsense_evals.core.cli:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
include = ["surfsense_evals*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
markers = [
|
||||
"integration: opt-in tests that hit a live SurfSense instance (run with `-m integration`)",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py312"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "B", "UP", "SIM", "ASYNC"]
|
||||
ignore = ["E501"]
|
||||
4
surfsense_evals/reports/.gitignore
vendored
Normal file
4
surfsense_evals/reports/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
*
|
||||
!.gitignore
|
||||
!medical/
|
||||
!medical/sample_summary.md
|
||||
97
surfsense_evals/scripts/download_crag_task3.py
Normal file
97
surfsense_evals/scripts/download_crag_task3.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""Download CRAG Task 3's 4 .tar.bz2 parts in parallel.
|
||||
|
||||
Run once before ``ingest research crag_t3`` to avoid the ingest
|
||||
synchronously blocking on a 7 GB download. Skips parts already
|
||||
present and complete on disk.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(message)s",
|
||||
)
|
||||
log = logging.getLogger("download_task3")
|
||||
|
||||
|
||||
_BASE = (
|
||||
"https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/"
|
||||
"crag_task_3_dev_v4.tar.bz2.part"
|
||||
)
|
||||
_USER_AGENT = "SurfSense-Evals/0.1 (CRAG Task 3 fetch)"
|
||||
|
||||
|
||||
def _expected_size(url: str) -> int:
|
||||
req = urllib.request.Request(url, method="HEAD", headers={"User-Agent": _USER_AGENT})
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
return int(resp.headers.get("content-length", 0))
|
||||
|
||||
|
||||
def download_one(part: int, dest_dir: Path) -> Path:
|
||||
url = f"{_BASE}{part}"
|
||||
dest = dest_dir / f"crag_task_3_dev_v4.tar.bz2.part{part}"
|
||||
expected = _expected_size(url)
|
||||
if dest.exists() and dest.stat().st_size == expected:
|
||||
log.info("part%d: cached (%d bytes)", part, expected)
|
||||
return dest
|
||||
log.info("part%d: downloading %d bytes ...", part, expected)
|
||||
tmp = dest.with_suffix(dest.suffix + ".part_dl")
|
||||
started = time.monotonic()
|
||||
last_log = started
|
||||
with urllib.request.urlopen(
|
||||
urllib.request.Request(url, headers={"User-Agent": _USER_AGENT}),
|
||||
timeout=900,
|
||||
) as resp, tmp.open("wb") as fh:
|
||||
downloaded = 0
|
||||
chunk = resp.read(1 << 20)
|
||||
while chunk:
|
||||
fh.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
now = time.monotonic()
|
||||
if now - last_log > 5.0:
|
||||
pct = 100 * downloaded / expected if expected else 0
|
||||
rate_mb = (downloaded / (now - started)) / (1 << 20)
|
||||
log.info(
|
||||
"part%d: %5.1f%% (%.1f / %.1f MiB at %.1f MiB/s)",
|
||||
part, pct, downloaded / (1 << 20), expected / (1 << 20), rate_mb,
|
||||
)
|
||||
last_log = now
|
||||
chunk = resp.read(1 << 20)
|
||||
tmp.replace(dest)
|
||||
elapsed = time.monotonic() - started
|
||||
log.info(
|
||||
"part%d: done in %.1fs (%.1f MiB/s avg)",
|
||||
part, elapsed, (expected / (1 << 20)) / max(elapsed, 0.001),
|
||||
)
|
||||
return dest
|
||||
|
||||
|
||||
def main() -> int:
|
||||
dest_dir = Path("data/research/crag_t3/.raw_cache")
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 4 parts in parallel — typical residential connection saturates around
|
||||
# 2 streams; GitHub raw serves these fine in parallel.
|
||||
started = time.monotonic()
|
||||
with ThreadPoolExecutor(max_workers=4) as ex:
|
||||
futures = {ex.submit(download_one, i, dest_dir): i for i in range(1, 5)}
|
||||
for fut in as_completed(futures):
|
||||
part = futures[fut]
|
||||
try:
|
||||
fut.result()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.error("part%d failed: %s", part, exc)
|
||||
return 1
|
||||
log.info("All 4 parts downloaded in %.1fs", time.monotonic() - started)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
37
surfsense_evals/scripts/peek_crag_run.py
Normal file
37
surfsense_evals/scripts/peek_crag_run.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""Tiny helper to inspect the latest CRAG run's per-question outputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def main() -> None:
|
||||
raw_path = sorted(glob.glob("data/research/runs/*/crag/raw.jsonl"))[-1]
|
||||
print(f"Reading: {raw_path}")
|
||||
rows = [json.loads(line) for line in open(raw_path, encoding="utf-8") if line.strip()]
|
||||
by_q: dict[str, dict[str, dict]] = defaultdict(dict)
|
||||
for r in rows:
|
||||
by_q[r["qid"]][r["arm"]] = r
|
||||
|
||||
for qid, arms in list(by_q.items()):
|
||||
b = arms.get("bare_llm", {})
|
||||
l = arms.get("long_context", {})
|
||||
s = arms.get("surfsense", {})
|
||||
print(f"\n=== {qid} ({b.get('domain')}/{b.get('question_type')}) ===")
|
||||
print(f" question: {b.get('extra', {}).get('question', '?')!r}")
|
||||
print(f" gold: {b.get('gold')!r}")
|
||||
for arm_name, a in (("bare_llm", b), ("long_context", l), ("surfsense", s)):
|
||||
grade = a.get("graded", {})
|
||||
text = (a.get("raw_text") or "").strip()
|
||||
tail = text[-200:] if text else ""
|
||||
print(
|
||||
f" [{arm_name}] grade={grade.get('grade')} "
|
||||
f"method={grade.get('method')}"
|
||||
)
|
||||
print(f" -> {tail!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
64
surfsense_evals/scripts/peek_disagreements.py
Normal file
64
surfsense_evals/scripts/peek_disagreements.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""Show questions where SurfSense was wrong but long-context was right (and vice versa)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def main() -> None:
|
||||
raw_path = sorted(glob.glob("data/research/runs/*/crag/raw.jsonl"))[-1]
|
||||
print(f"Reading: {raw_path}")
|
||||
rows = [json.loads(line) for line in open(raw_path, encoding="utf-8") if line.strip()]
|
||||
by_q: dict[str, dict[str, dict]] = defaultdict(dict)
|
||||
for r in rows:
|
||||
by_q[r["qid"]][r["arm"]] = r
|
||||
|
||||
surf_wrong_lc_right = []
|
||||
lc_wrong_surf_right = []
|
||||
surf_wrong_bare_right = []
|
||||
for qid, arms in by_q.items():
|
||||
b = arms.get("bare_llm", {}).get("graded", {}).get("grade")
|
||||
lc = arms.get("long_context", {}).get("graded", {}).get("grade")
|
||||
s = arms.get("surfsense", {}).get("graded", {}).get("grade")
|
||||
if s == "incorrect" and lc == "correct":
|
||||
surf_wrong_lc_right.append(qid)
|
||||
if lc == "incorrect" and s == "correct":
|
||||
lc_wrong_surf_right.append(qid)
|
||||
if s == "incorrect" and b == "correct":
|
||||
surf_wrong_bare_right.append(qid)
|
||||
|
||||
print(f"\nSurfSense INCORRECT but Long-Context CORRECT: {len(surf_wrong_lc_right)}")
|
||||
print(f"Long-Context INCORRECT but SurfSense CORRECT: {len(lc_wrong_surf_right)}")
|
||||
print(f"SurfSense INCORRECT but Bare CORRECT: {len(surf_wrong_bare_right)}")
|
||||
|
||||
print("\n=== Where SurfSense is wrong but long-context is right (top 5) ===")
|
||||
for qid in surf_wrong_lc_right[:5]:
|
||||
arms = by_q[qid]
|
||||
b = arms.get("bare_llm", {})
|
||||
print(f"\n[{qid}] domain={b.get('domain')} qtype={b.get('question_type')}")
|
||||
print(f" GOLD: {b.get('gold')!r}")
|
||||
for arm_name in ("bare_llm", "long_context", "surfsense"):
|
||||
a = arms.get(arm_name, {})
|
||||
t = (a.get("raw_text") or "").strip()
|
||||
tail = t[-180:] if t else ""
|
||||
grade = a.get("graded", {})
|
||||
print(f" [{arm_name}] {grade.get('grade')} ({grade.get('method')}): {tail!r}")
|
||||
|
||||
print("\n=== Where Long-Context is wrong but SurfSense is right (top 5) ===")
|
||||
for qid in lc_wrong_surf_right[:5]:
|
||||
arms = by_q[qid]
|
||||
b = arms.get("bare_llm", {})
|
||||
print(f"\n[{qid}] domain={b.get('domain')} qtype={b.get('question_type')}")
|
||||
print(f" GOLD: {b.get('gold')!r}")
|
||||
for arm_name in ("bare_llm", "long_context", "surfsense"):
|
||||
a = arms.get(arm_name, {})
|
||||
t = (a.get("raw_text") or "").strip()
|
||||
tail = t[-180:] if t else ""
|
||||
grade = a.get("graded", {})
|
||||
print(f" [{arm_name}] {grade.get('grade')} ({grade.get('method')}): {tail!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
40
surfsense_evals/scripts/peek_t3_doc_map.py
Normal file
40
surfsense_evals/scripts/peek_t3_doc_map.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Quick sanity-check for the CRAG Task 3 doc map after ingest."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main() -> int:
|
||||
p = Path("data/research/maps/crag_t3_doc_map.jsonl")
|
||||
if not p.exists():
|
||||
print(f"Doc map missing: {p}")
|
||||
return 1
|
||||
rows = []
|
||||
settings = {}
|
||||
for line in p.read_text(encoding="utf-8").splitlines():
|
||||
if not line.strip():
|
||||
continue
|
||||
row = json.loads(line)
|
||||
if "__settings__" in row:
|
||||
settings = row
|
||||
continue
|
||||
rows.append(row)
|
||||
print(f"Settings header: {settings}")
|
||||
print(f"Doc map rows: {len(rows)}")
|
||||
for r in rows:
|
||||
print(f" qid={r['qid']:<10} domain={r['domain']:<8} qtype={r['question_type']}")
|
||||
print(f" question: {r['question'][:90]}")
|
||||
print(f" gold: {r['gold_answer'][:90]}")
|
||||
print(
|
||||
f" pages: {len(r['page_filenames'])} extracted, "
|
||||
f"{len(r['document_ids'])} doc_ids, "
|
||||
f"{len(r['missing_pages'])} missing"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
65
surfsense_evals/scripts/summarise_crag_run.py
Normal file
65
surfsense_evals/scripts/summarise_crag_run.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
"""Render a quick textual summary of the latest CRAG run."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
|
||||
|
||||
def main() -> None:
|
||||
runs = sorted(glob.glob("data/research/runs/*/crag/run_artifact.json"))
|
||||
if not runs:
|
||||
print("(no CRAG runs found)")
|
||||
return
|
||||
m = json.load(open(runs[-1], encoding="utf-8"))
|
||||
metrics = m["metrics"]
|
||||
|
||||
print(f"Reading: {runs[-1]}")
|
||||
print(f"n_questions: {m['extra']['n_questions']}")
|
||||
print()
|
||||
print("=== ARMS ===")
|
||||
for arm in ("bare_llm", "long_context", "surfsense"):
|
||||
d = metrics[arm]
|
||||
print(
|
||||
f"{arm:14s}: "
|
||||
f"acc={d['accuracy']*100:5.1f}% (Wilson 95% CI "
|
||||
f"{d['ci_low']*100:.1f}-{d['ci_high']*100:.1f}) | "
|
||||
f"correct={d['correct_rate']*100:5.1f}% "
|
||||
f"missing={d['missing_rate']*100:5.1f}% "
|
||||
f"incorrect={d['incorrect_rate']*100:5.1f}% | "
|
||||
f"truth={d['truthfulness_score']*100:+5.1f}%"
|
||||
)
|
||||
|
||||
print()
|
||||
print("=== DELTAS ===")
|
||||
for key, d in metrics["deltas"].items():
|
||||
print(
|
||||
f"{key:30s}: acc={d['accuracy_pp']:+5.1f}pp "
|
||||
f"truth={d['truthfulness_score_pp']:+5.1f}pp "
|
||||
f"McNemar p={d['mcnemar_p_value']:.4f} ({d['mcnemar_method']}) "
|
||||
f"bootstrap CI [{d['bootstrap_ci_low']:+.1f}, {d['bootstrap_ci_high']:+.1f}]"
|
||||
)
|
||||
|
||||
print()
|
||||
print("=== PER-QUESTION-TYPE TRUTHFULNESS ===")
|
||||
for qt, row in sorted(metrics["per_question_type"].items()):
|
||||
n = row["n"]
|
||||
pieces = [f"{qt:20s} (n={n:3d}):"]
|
||||
for arm in ("bare_llm", "long_context", "surfsense"):
|
||||
if arm in row:
|
||||
pieces.append(f"{arm}={row[arm]['truthfulness_score']*100:+7.1f}%")
|
||||
print(" ".join(pieces))
|
||||
|
||||
print()
|
||||
print("=== PER-DOMAIN TRUTHFULNESS ===")
|
||||
for dom, row in sorted(metrics["per_domain"].items()):
|
||||
n = row["n"]
|
||||
pieces = [f"{dom:10s} (n={n:3d}):"]
|
||||
for arm in ("bare_llm", "long_context", "surfsense"):
|
||||
if arm in row:
|
||||
pieces.append(f"{arm}={row[arm]['truthfulness_score']*100:+7.1f}%")
|
||||
print(" ".join(pieces))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
10
surfsense_evals/src/surfsense_evals/__init__.py
Normal file
10
surfsense_evals/src/surfsense_evals/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""SurfSense Evals — domain-agnostic eval harness.
|
||||
|
||||
Public entry-point is the ``surfsense_evals`` CLI (``python -m surfsense_evals``).
|
||||
Programmatic embedding is a non-goal for now; everything goes through the CLI
|
||||
+ filesystem outputs (state.json, raw run JSONL, summary.md/json reports).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__version__ = "0.1.0"
|
||||
13
surfsense_evals/src/surfsense_evals/__main__.py
Normal file
13
surfsense_evals/src/surfsense_evals/__main__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
"""Module entry point: ``python -m surfsense_evals ...``.
|
||||
|
||||
Delegates to ``core.cli.main``. ``core.cli`` lazily imports
|
||||
``surfsense_evals.suites`` so every benchmark gets a chance to register
|
||||
before argparse builds its subcommand groups.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from surfsense_evals.core.cli import main
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
raise SystemExit(main())
|
||||
8
surfsense_evals/src/surfsense_evals/core/__init__.py
Normal file
8
surfsense_evals/src/surfsense_evals/core/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
"""Domain-agnostic infrastructure shared by every suite.
|
||||
|
||||
Nothing under ``core/`` knows or cares about a specific evaluation domain.
|
||||
Suites live under ``surfsense_evals.suites.<domain>.<benchmark>`` and
|
||||
register themselves with ``core.registry`` on import.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
44
surfsense_evals/src/surfsense_evals/core/arms/__init__.py
Normal file
44
surfsense_evals/src/surfsense_evals/core/arms/__init__.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Arm protocol + concrete arms shared across suites.
|
||||
|
||||
Concrete arms (``NativePdfArm``, ``SurfSenseArm``, ``BareLlmArm``) are
|
||||
imported lazily via ``__getattr__`` so consumers that only need the
|
||||
protocol — e.g. the registry's ``Arm`` re-export — don't transitively
|
||||
pull in ``httpx`` providers or the SurfSense client unless they
|
||||
actually use those arms.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import Arm, ArmRequest, ArmResult
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .bare_llm import BareLlmArm
|
||||
from .native_pdf import NativePdfArm
|
||||
from .surfsense import SurfSenseArm
|
||||
|
||||
__all__ = [
|
||||
"Arm",
|
||||
"ArmRequest",
|
||||
"ArmResult",
|
||||
"BareLlmArm",
|
||||
"NativePdfArm",
|
||||
"SurfSenseArm",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str): # PEP 562
|
||||
if name == "NativePdfArm":
|
||||
from .native_pdf import NativePdfArm
|
||||
|
||||
return NativePdfArm
|
||||
if name == "SurfSenseArm":
|
||||
from .surfsense import SurfSenseArm
|
||||
|
||||
return SurfSenseArm
|
||||
if name == "BareLlmArm":
|
||||
from .bare_llm import BareLlmArm
|
||||
|
||||
return BareLlmArm
|
||||
raise AttributeError(f"module 'surfsense_evals.core.arms' has no attribute {name!r}")
|
||||
100
surfsense_evals/src/surfsense_evals/core/arms/bare_llm.py
Normal file
100
surfsense_evals/src/surfsense_evals/core/arms/bare_llm.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
"""Bare-LLM arm: chat completion with prompt-only input, no retrieval.
|
||||
|
||||
Pairs with ``SurfSenseArm`` for any benchmark that wants to measure
|
||||
"how much does the model already know without RAG?". For factuality /
|
||||
multi-hop benchmarks (FRAMES, MuSiQue, …) this produces the published
|
||||
"naive prompting" baseline — e.g. FRAMES's 40.8% on Gemini-Pro-1.5.
|
||||
|
||||
Symmetric with ``NativePdfArm`` in shape, but the request carries no
|
||||
``pdf_paths``: the prompt itself is the only input the model gets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ..providers.openrouter_chat import OpenRouterChatProvider
|
||||
from .base import Arm, ArmRequest, ArmResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BareLlmArm(Arm):
|
||||
"""``Arm`` implementation backed by ``OpenRouterChatProvider``.
|
||||
|
||||
``name`` defaults to ``"bare_llm"`` but is overridable per-instance.
|
||||
Suites that want two distinct OpenRouter chat arms (e.g. CRAG's
|
||||
``bare_llm`` vs ``long_context`` — both backed by chat-completions
|
||||
but exercising different prompt strategies) instantiate twice with
|
||||
different names so the metrics aggregator can keep them separate.
|
||||
"""
|
||||
|
||||
name: str = "bare_llm"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider: OpenRouterChatProvider,
|
||||
max_output_tokens: int | None = 1024,
|
||||
system_prompt: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> None:
|
||||
self._provider = provider
|
||||
self._max_output = max_output_tokens
|
||||
self._system_prompt = system_prompt
|
||||
if name:
|
||||
self.name = name
|
||||
|
||||
@classmethod
|
||||
def from_env(
|
||||
cls,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_output_tokens: int | None = 1024,
|
||||
system_prompt: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> BareLlmArm:
|
||||
provider = OpenRouterChatProvider(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
)
|
||||
return cls(
|
||||
provider=provider,
|
||||
max_output_tokens=max_output_tokens,
|
||||
system_prompt=system_prompt,
|
||||
name=name,
|
||||
)
|
||||
|
||||
async def answer(self, request: ArmRequest) -> ArmResult:
|
||||
try:
|
||||
response = await self._provider.complete(
|
||||
prompt=request.prompt,
|
||||
system_prompt=self._system_prompt,
|
||||
max_tokens=self._max_output,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return ArmResult(
|
||||
arm=self.name,
|
||||
question_id=request.question_id,
|
||||
raw_text="",
|
||||
error=f"{type(exc).__name__}: {exc}",
|
||||
)
|
||||
return ArmResult(
|
||||
arm=self.name,
|
||||
question_id=request.question_id,
|
||||
raw_text=response.text,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
cost_micros=response.cost_micros,
|
||||
latency_ms=response.latency_ms,
|
||||
extra={
|
||||
"model": self._provider.model,
|
||||
"finish_reason": response.finish_reason,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["BareLlmArm"]
|
||||
93
surfsense_evals/src/surfsense_evals/core/arms/base.py
Normal file
93
surfsense_evals/src/surfsense_evals/core/arms/base.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""Arm protocol + the value types every arm exchanges with a runner.
|
||||
|
||||
An ``Arm`` is "one way to answer one question". Two ship in this PR:
|
||||
|
||||
* ``NativePdfArm`` — drop the PDF straight into an OpenRouter
|
||||
chat-completions request with ``plugins=[{file-parser, engine:
|
||||
native}]``. Used for the head-to-head "is the model good enough on
|
||||
its own?" measurement.
|
||||
* ``SurfSenseArm`` — POST ``/api/v1/new_chat`` with the question
|
||||
scoped to the relevant ``mentioned_document_ids``; consume the SSE
|
||||
stream and parse citations.
|
||||
|
||||
Both implement the same protocol so a benchmark runner only sees
|
||||
``Arm.answer(request) -> ArmResult``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArmRequest:
|
||||
"""One arm-call worth of input.
|
||||
|
||||
* ``question_id`` is opaque — used for logging and joining results.
|
||||
* ``prompt`` is the fully-formatted text the arm should send. The
|
||||
runner is responsible for prompt construction so head-to-head
|
||||
comparisons use byte-identical text.
|
||||
* ``pdf_paths`` is the per-question source PDFs (used by
|
||||
``NativePdfArm``). Empty for retrieval-only / corpus-wide
|
||||
benchmarks.
|
||||
* ``mentioned_document_ids`` is the SurfSense document scoping list
|
||||
(used by ``SurfSenseArm``). When ``None`` SurfSense retrieves
|
||||
across the whole search space.
|
||||
* ``options`` is a free-form bag of arm-specific overrides
|
||||
(e.g. SurfSense's ``disabled_tools``).
|
||||
"""
|
||||
|
||||
question_id: str
|
||||
prompt: str
|
||||
pdf_paths: list[Path] = field(default_factory=list)
|
||||
mentioned_document_ids: list[int] | None = None
|
||||
options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArmResult:
|
||||
"""Outcome of one ``Arm.answer`` invocation."""
|
||||
|
||||
arm: str
|
||||
question_id: str
|
||||
raw_text: str
|
||||
answer_letter: str | None = None
|
||||
citations: list[dict[str, Any]] = field(default_factory=list)
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cost_micros: int = 0
|
||||
latency_ms: int = 0
|
||||
error: str | None = None
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return self.error is None
|
||||
|
||||
def to_jsonl(self) -> dict[str, Any]:
|
||||
"""Stable dict shape for ``data/<suite>/runs/<ts>/<bench>_raw.jsonl``."""
|
||||
|
||||
return {
|
||||
"arm": self.arm,
|
||||
"question_id": self.question_id,
|
||||
"answer_letter": self.answer_letter,
|
||||
"raw_text": self.raw_text,
|
||||
"citations": self.citations,
|
||||
"input_tokens": self.input_tokens,
|
||||
"output_tokens": self.output_tokens,
|
||||
"cost_micros": self.cost_micros,
|
||||
"latency_ms": self.latency_ms,
|
||||
"error": self.error,
|
||||
"extra": self.extra,
|
||||
}
|
||||
|
||||
|
||||
class Arm(Protocol):
|
||||
"""One concrete way to answer questions for a given run."""
|
||||
|
||||
name: str
|
||||
|
||||
async def answer(self, request: ArmRequest) -> ArmResult: # pragma: no cover - protocol
|
||||
...
|
||||
104
surfsense_evals/src/surfsense_evals/core/arms/native_pdf.py
Normal file
104
surfsense_evals/src/surfsense_evals/core/arms/native_pdf.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Native-PDF arm: drop the PDF straight into OpenRouter chat-completions.
|
||||
|
||||
Generic across suites — a benchmark just supplies the prompt and the
|
||||
single PDF path. Multi-PDF questions concatenate in the runner before
|
||||
calling this arm so each ``answer`` invocation feeds the model exactly
|
||||
one ``data:application/pdf;base64,...`` block (matches the human
|
||||
"drag-and-drop one PDF into Claude" intent).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ..parse.answer_letter import extract_answer_letter
|
||||
from ..providers.openrouter_pdf import OpenRouterPdfProvider, PdfEngine
|
||||
from .base import Arm, ArmRequest, ArmResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NativePdfArm(Arm):
|
||||
"""``Arm`` implementation backed by ``OpenRouterPdfProvider``."""
|
||||
|
||||
name: str = "native_pdf"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider: OpenRouterPdfProvider,
|
||||
max_output_tokens: int | None = 1024,
|
||||
) -> None:
|
||||
self._provider = provider
|
||||
self._max_output = max_output_tokens
|
||||
|
||||
@classmethod
|
||||
def from_env(
|
||||
cls,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str,
|
||||
engine: PdfEngine = PdfEngine.NATIVE,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_output_tokens: int | None = 1024,
|
||||
) -> NativePdfArm:
|
||||
provider = OpenRouterPdfProvider(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
engine=engine,
|
||||
)
|
||||
return cls(provider=provider, max_output_tokens=max_output_tokens)
|
||||
|
||||
async def answer(self, request: ArmRequest) -> ArmResult:
|
||||
if not request.pdf_paths:
|
||||
return ArmResult(
|
||||
arm=self.name,
|
||||
question_id=request.question_id,
|
||||
raw_text="",
|
||||
error="native_pdf arm requires at least one pdf_path",
|
||||
)
|
||||
if len(request.pdf_paths) > 1:
|
||||
# The plan calls out one-PDF-per-question so the head-to-head
|
||||
# is fair; runners are responsible for upstream concatenation.
|
||||
logger.debug(
|
||||
"qid=%s native_pdf got %d pdfs; using first only",
|
||||
request.question_id,
|
||||
len(request.pdf_paths),
|
||||
)
|
||||
pdf = request.pdf_paths[0]
|
||||
try:
|
||||
response = await self._provider.complete(
|
||||
prompt=request.prompt,
|
||||
pdf_path=pdf,
|
||||
max_tokens=self._max_output,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return ArmResult(
|
||||
arm=self.name,
|
||||
question_id=request.question_id,
|
||||
raw_text="",
|
||||
error=f"{type(exc).__name__}: {exc}",
|
||||
)
|
||||
|
||||
letter = extract_answer_letter(response.text)
|
||||
return ArmResult(
|
||||
arm=self.name,
|
||||
question_id=request.question_id,
|
||||
raw_text=response.text,
|
||||
answer_letter=letter.letter,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
cost_micros=response.cost_micros,
|
||||
latency_ms=response.latency_ms,
|
||||
extra={
|
||||
"model": self._provider.model,
|
||||
"engine": self._provider.engine.value,
|
||||
"answer_letter_strategy": letter.strategy,
|
||||
"finish_reason": response.finish_reason,
|
||||
"pdf_filename": pdf.name,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["NativePdfArm"]
|
||||
104
surfsense_evals/src/surfsense_evals/core/arms/surfsense.py
Normal file
104
surfsense_evals/src/surfsense_evals/core/arms/surfsense.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""SurfSense arm: per-question fresh thread + ``/api/v1/new_chat`` stream.
|
||||
|
||||
For every question:
|
||||
|
||||
* Create a fresh ``NewChatThread`` on the suite's pinned SearchSpace.
|
||||
This sidesteps the per-thread ``THREAD_BUSY`` 409 (a single thread
|
||||
serialises turns, see ``surfsense_backend/app/routes/new_chat_routes.py:191-220``).
|
||||
* POST ``/api/v1/new_chat`` with the prompt and the per-question
|
||||
``mentioned_document_ids`` (``surfsense_backend/app/schemas/new_chat.py:241-243``).
|
||||
* Consume the SSE stream via ``NewChatClient.ask`` which accumulates
|
||||
text deltas and returns ``StreamedAnswer``.
|
||||
* Optionally delete the thread (default ON for ephemeral runs).
|
||||
|
||||
Citations are parsed from the streamed assistant text via the
|
||||
canonical regex port; chunk ids are returned in ``ArmResult.citations``
|
||||
for the runner to map back to corpus ids.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ..clients import NewChatClient
|
||||
from ..parse.answer_letter import extract_answer_letter
|
||||
from .base import Arm, ArmRequest, ArmResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SurfSenseArm(Arm):
|
||||
"""``Arm`` implementation backed by ``NewChatClient``."""
|
||||
|
||||
name: str = "surfsense"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
client: NewChatClient,
|
||||
search_space_id: int,
|
||||
ephemeral_threads: bool = True,
|
||||
thread_title_prefix: str = "eval",
|
||||
) -> None:
|
||||
self._client = client
|
||||
self._search_space_id = search_space_id
|
||||
self._ephemeral = ephemeral_threads
|
||||
self._title_prefix = thread_title_prefix
|
||||
|
||||
async def answer(self, request: ArmRequest) -> ArmResult:
|
||||
thread_id: int | None = None
|
||||
try:
|
||||
thread_id = await self._client.create_thread(
|
||||
search_space_id=self._search_space_id,
|
||||
title=f"{self._title_prefix}:{request.question_id}",
|
||||
)
|
||||
answer = await self._client.ask(
|
||||
thread_id=thread_id,
|
||||
search_space_id=self._search_space_id,
|
||||
user_query=request.prompt,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
disabled_tools=request.options.get("disabled_tools"),
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return ArmResult(
|
||||
arm=self.name,
|
||||
question_id=request.question_id,
|
||||
raw_text="",
|
||||
error=f"{type(exc).__name__}: {exc}",
|
||||
extra={"thread_id": thread_id},
|
||||
)
|
||||
finally:
|
||||
if self._ephemeral and thread_id is not None:
|
||||
try:
|
||||
await self._client.delete_thread(thread_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug(
|
||||
"Failed to delete thread %s: %s", thread_id, exc
|
||||
)
|
||||
|
||||
letter = extract_answer_letter(answer.text)
|
||||
return ArmResult(
|
||||
arm=self.name,
|
||||
question_id=request.question_id,
|
||||
raw_text=answer.text,
|
||||
answer_letter=letter.letter,
|
||||
citations=answer.citations,
|
||||
latency_ms=answer.latency_ms,
|
||||
# SurfSense doesn't surface input/output token counts in the
|
||||
# SSE stream today; leaving the cost / token fields at 0
|
||||
# documents that gap. Estimating from the raw text would
|
||||
# bias the comparison against the SurfSense arm.
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"search_space_id": self._search_space_id,
|
||||
"answer_letter_strategy": letter.strategy,
|
||||
"user_message_id": answer.user_message_id,
|
||||
"assistant_message_id": answer.assistant_message_id,
|
||||
"finished_normally": answer.finished_normally,
|
||||
"n_raw_events": len(answer.raw_events),
|
||||
"n_mentioned_documents": len(request.mentioned_document_ids or []),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["SurfSenseArm"]
|
||||
273
surfsense_evals/src/surfsense_evals/core/auth.py
Normal file
273
surfsense_evals/src/surfsense_evals/core/auth.py
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
"""Dual-mode credential resolver + httpx client factory with 401 auto-refresh.
|
||||
|
||||
SurfSense supports ``AUTH_TYPE=LOCAL`` (email + password) and
|
||||
``AUTH_TYPE=GOOGLE`` (Google OAuth → frontend stores JWT in ``localStorage``).
|
||||
There is no headless equivalent of the Google flow, so the harness handles
|
||||
both modes by treating the JWT as the universal credential:
|
||||
|
||||
* **LOCAL**: harness POSTs form-encoded ``username`` + ``password`` to
|
||||
``/auth/jwt/login``, reads ``{access_token, refresh_token}``.
|
||||
* **GOOGLE / pre-issued JWT**: operator pastes their existing JWT (and
|
||||
optionally refresh token) into ``SURFSENSE_JWT`` /
|
||||
``SURFSENSE_REFRESH_TOKEN``; harness skips login.
|
||||
|
||||
Either way ``client_with_auth`` returns one shared
|
||||
``httpx.AsyncClient`` with ``Authorization: Bearer <jwt>`` set and an
|
||||
event hook that, on a 401 with a refresh token in scope, calls
|
||||
``POST /auth/jwt/refresh`` and retries the original request once. JWT
|
||||
lifetime defaults to one day backend-side, so this matters for long
|
||||
MIRAGE runs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialError(RuntimeError):
|
||||
"""Raised when no credential mode is configured."""
|
||||
|
||||
|
||||
_NO_CREDENTIALS_MESSAGE = (
|
||||
"No SurfSense credentials configured. Set ONE of:\n"
|
||||
" (LOCAL) SURFSENSE_USER_EMAIL + SURFSENSE_USER_PASSWORD\n"
|
||||
" (GOOGLE) SURFSENSE_JWT (and optionally SURFSENSE_REFRESH_TOKEN)\n"
|
||||
"For GOOGLE: log in to SurfSense in your browser, open DevTools → "
|
||||
"Application → Local Storage → copy `surfsense_bearer_token` and "
|
||||
"`surfsense_refresh_token` into those env vars."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenBundle:
|
||||
"""Mutable token state — refresh hook updates ``access_token`` in place."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str | None = None
|
||||
# ``mode`` is informational only ("local" or "jwt"); used in error messages.
|
||||
mode: str = "jwt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token acquisition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def acquire_token(config: Config, *, http: httpx.AsyncClient | None = None) -> TokenBundle:
|
||||
"""Resolve credentials → ``TokenBundle``.
|
||||
|
||||
Precedence:
|
||||
|
||||
1. ``SURFSENSE_JWT`` set → use it directly. Refresh token captured if
|
||||
supplied.
|
||||
2. ``SURFSENSE_USER_EMAIL`` + ``SURFSENSE_USER_PASSWORD`` set →
|
||||
form-encoded POST to ``/auth/jwt/login``.
|
||||
3. Neither → raise ``CredentialError``.
|
||||
|
||||
The optional ``http`` argument lets tests inject a mocked client; if
|
||||
omitted a one-shot client is created for the login call only.
|
||||
"""
|
||||
|
||||
if config.has_jwt_mode():
|
||||
return TokenBundle(
|
||||
access_token=config.surfsense_jwt or "",
|
||||
refresh_token=config.surfsense_refresh_token,
|
||||
mode="jwt",
|
||||
)
|
||||
|
||||
if config.has_local_mode():
|
||||
async def _login(client: httpx.AsyncClient) -> TokenBundle:
|
||||
response = await client.post(
|
||||
f"{config.surfsense_api_base}/auth/jwt/login",
|
||||
data={
|
||||
"username": config.surfsense_user_email,
|
||||
"password": config.surfsense_user_password,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise CredentialError(
|
||||
f"LOCAL login failed (HTTP {response.status_code}): "
|
||||
f"{_safe_text(response)}"
|
||||
)
|
||||
payload = response.json()
|
||||
access = payload.get("access_token")
|
||||
if not access:
|
||||
raise CredentialError(
|
||||
f"LOCAL login response missing access_token: {payload!r}"
|
||||
)
|
||||
return TokenBundle(
|
||||
access_token=access,
|
||||
refresh_token=payload.get("refresh_token") or None,
|
||||
mode="local",
|
||||
)
|
||||
|
||||
if http is not None:
|
||||
return await _login(http)
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, connect=10.0)) as client:
|
||||
return await _login(client)
|
||||
|
||||
raise CredentialError(_NO_CREDENTIALS_MESSAGE)
|
||||
|
||||
|
||||
def _safe_text(response: httpx.Response, *, limit: int = 200) -> str:
|
||||
body = response.text or ""
|
||||
if len(body) > limit:
|
||||
return body[:limit] + "…"
|
||||
return body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# httpx client + 401 auto-refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _AuthState:
|
||||
"""Shared mutable holder closed over by the auth event hook.
|
||||
|
||||
Kept private so callers can't accidentally mutate the access token
|
||||
out-of-band; ``client_with_auth`` returns the client directly.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, tokens: TokenBundle) -> None:
|
||||
self.config = config
|
||||
self.tokens = tokens
|
||||
self._refresh_in_flight: bool = False
|
||||
|
||||
|
||||
def _build_auth_request(state: _AuthState, request: httpx.Request) -> None:
|
||||
"""Stamp the current bearer onto ``request`` (request-event hook)."""
|
||||
|
||||
request.headers["Authorization"] = f"Bearer {state.tokens.access_token}"
|
||||
|
||||
|
||||
async def _refresh_access_token(
|
||||
state: _AuthState, transport: httpx.AsyncBaseTransport | None = None
|
||||
) -> bool:
|
||||
"""POST ``/auth/jwt/refresh`` with the current refresh token.
|
||||
|
||||
Returns ``True`` on success and updates ``state.tokens`` in place.
|
||||
Returns ``False`` if no refresh token is configured or the call fails.
|
||||
Recursive 401s are avoided by using a *new* client without the auth
|
||||
hook.
|
||||
"""
|
||||
|
||||
refresh = state.tokens.refresh_token
|
||||
if not refresh:
|
||||
return False
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(15.0, connect=5.0),
|
||||
transport=transport,
|
||||
) as inner:
|
||||
response = await inner.post(
|
||||
f"{state.config.surfsense_api_base}/auth/jwt/refresh",
|
||||
json={"refresh_token": refresh},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
except httpx.HTTPError as exc:
|
||||
logger.warning("Token refresh transport error: %s", exc)
|
||||
return False
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
"Token refresh rejected (HTTP %s): %s",
|
||||
response.status_code,
|
||||
_safe_text(response),
|
||||
)
|
||||
return False
|
||||
payload = response.json()
|
||||
new_access = payload.get("access_token")
|
||||
if not new_access:
|
||||
logger.warning("Refresh response missing access_token: %r", payload)
|
||||
return False
|
||||
state.tokens.access_token = new_access
|
||||
new_refresh = payload.get("refresh_token")
|
||||
if new_refresh:
|
||||
state.tokens.refresh_token = new_refresh
|
||||
return True
|
||||
|
||||
|
||||
def client_with_auth(
|
||||
config: Config,
|
||||
tokens: TokenBundle,
|
||||
*,
|
||||
timeout: float = 60.0,
|
||||
transport: httpx.AsyncBaseTransport | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> httpx.AsyncClient:
|
||||
"""Build a single shared ``httpx.AsyncClient`` for the SurfSense API.
|
||||
|
||||
* Stamps ``Authorization: Bearer <jwt>`` on every outgoing request.
|
||||
* On any 401 response, attempts a single refresh (if a refresh token
|
||||
is configured) and retries the original request once. The retry
|
||||
uses a fresh stamping of the bearer header, so a successful
|
||||
refresh transparently unblocks long runs.
|
||||
* The retry is best-effort — repeated 401s after a refresh attempt
|
||||
are surfaced to the caller so they can re-auth manually.
|
||||
|
||||
Pass ``base_url`` to scope a sub-client (e.g. tests). The default
|
||||
keeps full URLs in calling code, which makes route-spec citations in
|
||||
the codebase easier to grep.
|
||||
"""
|
||||
|
||||
state = _AuthState(config, tokens)
|
||||
|
||||
async def _request_hook(request: httpx.Request) -> None:
|
||||
_build_auth_request(state, request)
|
||||
|
||||
# ``send`` is overridden in ``_AuthAwareClient`` to retry once on 401
|
||||
# after refreshing the bearer. httpx's response event-hook can't
|
||||
# *replace* a response, so we need a subclass to do the replay.
|
||||
client = _AuthAwareClient(
|
||||
state=state,
|
||||
transport=transport,
|
||||
timeout=httpx.Timeout(timeout, connect=10.0),
|
||||
base_url=base_url or "",
|
||||
event_hooks={"request": [_request_hook]},
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
class _AuthAwareClient(httpx.AsyncClient):
|
||||
"""``AsyncClient`` that retries once on 401 after refreshing the token."""
|
||||
|
||||
def __init__(self, *, state: _AuthState, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._auth_state = state
|
||||
|
||||
async def send( # type: ignore[override]
|
||||
self, request: httpx.Request, **kwargs: Any
|
||||
) -> httpx.Response:
|
||||
response = await super().send(request, **kwargs)
|
||||
if response.status_code != 401:
|
||||
return response
|
||||
# Don't refresh while a refresh is itself in flight.
|
||||
if self._auth_state._refresh_in_flight:
|
||||
return response
|
||||
self._auth_state._refresh_in_flight = True
|
||||
try:
|
||||
refreshed = await _refresh_access_token(self._auth_state)
|
||||
finally:
|
||||
self._auth_state._refresh_in_flight = False
|
||||
if not refreshed:
|
||||
return response
|
||||
# Re-stamp and replay once. ``request`` is reusable.
|
||||
await response.aclose()
|
||||
request.headers["Authorization"] = f"Bearer {self._auth_state.tokens.access_token}"
|
||||
return await super().send(request, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CredentialError",
|
||||
"TokenBundle",
|
||||
"acquire_token",
|
||||
"client_with_auth",
|
||||
]
|
||||
790
surfsense_evals/src/surfsense_evals/core/cli.py
Normal file
790
surfsense_evals/src/surfsense_evals/core/cli.py
Normal file
|
|
@ -0,0 +1,790 @@
|
|||
"""Argparse CLI for ``python -m surfsense_evals``.
|
||||
|
||||
Subcommands:
|
||||
|
||||
* ``setup --suite <name> --provider-model <slug> [--agent-llm-id <int>]``
|
||||
* ``teardown --suite <name>``
|
||||
* ``models list [--provider openrouter] [--grep <s>]``
|
||||
* ``suites list``
|
||||
* ``benchmarks list [--suite <name>]``
|
||||
* ``ingest <suite> <benchmark> [benchmark flags]``
|
||||
* ``run <suite> <benchmark> [benchmark flags]``
|
||||
* ``report --suite <name> [--benchmark <name>]``
|
||||
|
||||
The ``ingest`` / ``run`` subparsers are built dynamically from the
|
||||
registry — adding a new benchmark only requires registering it; the
|
||||
CLI surface comes for free. ``add_run_args`` lets each benchmark
|
||||
publish its own flags.
|
||||
|
||||
Design choices worth flagging:
|
||||
|
||||
* ``setup`` rejects ``agent_llm_id == 0`` (Auto / LiteLLM router) so
|
||||
per-question accuracy is reproducible.
|
||||
* ``setup`` validates that the picked LLM config has
|
||||
``provider == "OPENROUTER"`` and ``model_name == --provider-model``
|
||||
before declaring success — both arms of the head-to-head must hit
|
||||
the same OpenRouter slug.
|
||||
* Lifecycle state is keyed by suite, so ``setup --suite legal`` does
|
||||
not touch ``medical``'s SearchSpace, and vice versa.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import sys
|
||||
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
# Windows' legacy console (cp1252) crashes when Rich tries to write characters
|
||||
# outside the active codepage (e.g. '->', em-dashes, box-drawing). Force UTF-8
|
||||
# on stdout/stderr and disable Rich's legacy_windows render path so the file
|
||||
# stream is used directly. Modern Windows (>=10, VS Code terminal, Windows
|
||||
# Terminal, PowerShell, cmd) all interpret ANSI escapes natively.
|
||||
if sys.platform == "win32":
|
||||
for _stream in (sys.stdout, sys.stderr):
|
||||
try:
|
||||
_stream.reconfigure(encoding="utf-8", errors="replace")
|
||||
except (AttributeError, ValueError):
|
||||
pass
|
||||
|
||||
from . import registry
|
||||
from .auth import CredentialError, acquire_token, client_with_auth
|
||||
from .clients import SearchSpaceClient
|
||||
from .clients.search_space import LlmPreferences
|
||||
from .config import (
|
||||
DEFAULT_SCENARIO,
|
||||
SCENARIOS,
|
||||
Config,
|
||||
SuiteState,
|
||||
clear_suite_state,
|
||||
get_suite_state,
|
||||
load_config,
|
||||
set_suite_state,
|
||||
utc_iso_timestamp,
|
||||
)
|
||||
from .vision_llm import VisionConfigError, resolve_vision_llm
|
||||
|
||||
logger = logging.getLogger("surfsense_evals")
|
||||
console = Console(legacy_windows=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _discover_suites() -> list[str]:
|
||||
"""Trigger ``register(...)`` for every benchmark.
|
||||
|
||||
Imported lazily so ``models list`` (which doesn't need any
|
||||
benchmark) still runs fast.
|
||||
"""
|
||||
|
||||
from surfsense_evals.suites import discover_suites
|
||||
|
||||
return discover_suites()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Global LLM config fetcher (used by setup + models list)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlmConfigEntry:
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
model_name: str
|
||||
raw: dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: dict[str, Any]) -> LlmConfigEntry:
|
||||
return cls(
|
||||
id=int(payload["id"]),
|
||||
name=str(payload.get("name", "")),
|
||||
provider=str(payload.get("provider", "")).upper(),
|
||||
model_name=str(payload.get("model_name", "")),
|
||||
raw=payload,
|
||||
)
|
||||
|
||||
|
||||
async def _list_global_llm_configs(http: httpx.AsyncClient, base: str) -> list[LlmConfigEntry]:
|
||||
response = await http.get(
|
||||
f"{base}/api/v1/global-new-llm-configs",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
if not isinstance(payload, list):
|
||||
raise RuntimeError(f"Unexpected /global-new-llm-configs payload: {payload!r}")
|
||||
return [LlmConfigEntry.from_payload(item) for item in payload]
|
||||
|
||||
|
||||
def _resolve_openrouter_id(
|
||||
candidates: list[LlmConfigEntry],
|
||||
provider_model: str,
|
||||
*,
|
||||
explicit_id: int | None,
|
||||
) -> int:
|
||||
"""Resolve the SurfSense LLM id for ``provider_model``.
|
||||
|
||||
Behaviour:
|
||||
|
||||
* If ``explicit_id`` is given: return it directly. The caller is
|
||||
then expected to GET-validate that the row's
|
||||
``provider == "OPENROUTER"`` and ``model_name`` matches the slug.
|
||||
That branch supports positive BYOK ``NewLLMConfig`` rows whose
|
||||
slugs may overlap with global OpenRouter virtuals.
|
||||
* Otherwise: filter to ``provider == "OPENROUTER"`` and
|
||||
``model_name == provider_model``. Expect exactly one match —
|
||||
raise with a friendly message otherwise.
|
||||
"""
|
||||
|
||||
if explicit_id is not None:
|
||||
return explicit_id
|
||||
|
||||
matches = [
|
||||
c for c in candidates if c.provider == "OPENROUTER" and c.model_name == provider_model
|
||||
]
|
||||
if not matches:
|
||||
sample = ", ".join(
|
||||
f"{c.model_name} (id={c.id})" for c in candidates if c.provider == "OPENROUTER"
|
||||
)[:600]
|
||||
raise RuntimeError(
|
||||
f"No OpenRouter config found for slug '{provider_model}'. "
|
||||
"Make sure `openrouter_integration.enabled: true` in "
|
||||
"global_llm_config.yaml and that the Celery worker has "
|
||||
"finished its first refresh (the catalogue is fetched at "
|
||||
"Celery startup per `app/celery_app.py`). "
|
||||
f"Available OpenRouter slugs (sample): {sample or '<none>'}.\n"
|
||||
"Browse with: python -m surfsense_evals models list --grep <substring>"
|
||||
)
|
||||
if len(matches) > 1:
|
||||
listing = "\n".join(f" id={c.id} name={c.name!r}" for c in matches)
|
||||
raise RuntimeError(
|
||||
f"Multiple OpenRouter configs for slug '{provider_model}':\n{listing}\n"
|
||||
"Pass --agent-llm-id <id> to disambiguate."
|
||||
)
|
||||
return matches[0].id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subcommand implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _cmd_setup(args: argparse.Namespace) -> int:
|
||||
suite = args.suite
|
||||
provider_model: str = args.provider_model
|
||||
explicit_id: int | None = args.agent_llm_id
|
||||
scenario: str = args.scenario
|
||||
vision_llm_slug: str | None = args.vision_llm
|
||||
native_arm_model: str | None = args.native_arm_model
|
||||
skip_vision_setup: bool = args.no_vision_llm_setup
|
||||
|
||||
if explicit_id == 0:
|
||||
console.print(
|
||||
"[red]agent_llm_id == 0 (Auto / LiteLLM router) is not allowed — "
|
||||
"results would not be reproducible.[/red]"
|
||||
)
|
||||
return 2
|
||||
|
||||
if scenario not in SCENARIOS:
|
||||
console.print(
|
||||
f"[red]Unknown scenario {scenario!r}. Pick one of: "
|
||||
f"{', '.join(SCENARIOS)}[/red]"
|
||||
)
|
||||
return 2
|
||||
|
||||
# Scenario-specific validation. Each branch documents WHY the rule
|
||||
# exists so the operator's mental model matches what the runner does.
|
||||
if scenario == "cost-arbitrage":
|
||||
if not native_arm_model:
|
||||
console.print(
|
||||
"[red]--scenario cost-arbitrage requires --native-arm-model "
|
||||
"<vision-capable slug>.[/red] The native arm needs a vision "
|
||||
"model to fairly answer image-bearing questions; SurfSense "
|
||||
"answers from already-extracted text via --provider-model."
|
||||
)
|
||||
return 2
|
||||
if native_arm_model == provider_model:
|
||||
console.print(
|
||||
"[yellow]--native-arm-model equals --provider-model in "
|
||||
"cost-arbitrage; that's degenerate (same as head-to-head). "
|
||||
"Pick a different slug or switch to --scenario head-to-head.[/yellow]"
|
||||
)
|
||||
elif scenario in ("head-to-head", "symmetric-cheap"):
|
||||
if native_arm_model:
|
||||
console.print(
|
||||
f"[yellow]--native-arm-model is ignored for --scenario {scenario} "
|
||||
f"(both arms answer with --provider-model={provider_model!r}).[/yellow]"
|
||||
)
|
||||
native_arm_model = None # don't persist a stale value
|
||||
|
||||
config = load_config()
|
||||
try:
|
||||
token = await acquire_token(config)
|
||||
except CredentialError as exc:
|
||||
console.print(f"[red]{exc}[/red]")
|
||||
return 2
|
||||
|
||||
async with client_with_auth(config, token) as http:
|
||||
candidates = await _list_global_llm_configs(http, config.surfsense_api_base)
|
||||
|
||||
try:
|
||||
agent_llm_id = _resolve_openrouter_id(
|
||||
candidates, provider_model, explicit_id=explicit_id
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
console.print(f"[red]{exc}[/red]")
|
||||
return 2
|
||||
|
||||
ss_client = SearchSpaceClient(http, config.surfsense_api_base)
|
||||
existing = get_suite_state(config, suite)
|
||||
if existing is not None:
|
||||
try:
|
||||
row = await ss_client.get(existing.search_space_id)
|
||||
console.print(
|
||||
f"Reusing existing SearchSpace [cyan]{row.name}[/cyan] "
|
||||
f"(id={row.id}) for suite [bold]{suite}[/bold]."
|
||||
)
|
||||
search_space_id = row.id
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 404:
|
||||
console.print(
|
||||
f"[yellow]state.json pointed at SearchSpace id={existing.search_space_id} "
|
||||
f"but backend returned 404; creating a fresh one.[/yellow]"
|
||||
)
|
||||
existing = None
|
||||
else:
|
||||
raise
|
||||
if existing is None:
|
||||
ss_name = f"eval-{suite}-{utc_iso_timestamp()}"
|
||||
row = await ss_client.create(
|
||||
ss_name, description=f"surfsense-evals lifecycle ({suite})"
|
||||
)
|
||||
console.print(
|
||||
f"Created SearchSpace [cyan]{row.name}[/cyan] (id={row.id}) "
|
||||
f"for suite [bold]{suite}[/bold]."
|
||||
)
|
||||
search_space_id = row.id
|
||||
|
||||
# Resolve + attach the vision LLM config (unless explicitly skipped).
|
||||
# Asymmetric scenarios make the vision LLM at ingest a hard
|
||||
# requirement — without it, SurfSense's chunks have no image
|
||||
# content and the entire framing collapses.
|
||||
vision_required = scenario in ("symmetric-cheap", "cost-arbitrage")
|
||||
vision_config_id: int | None = None
|
||||
vision_provider_model: str | None = None
|
||||
if not skip_vision_setup and (vision_required or vision_llm_slug is not None):
|
||||
try:
|
||||
vision_candidates = await ss_client.list_global_vision_llm_configs()
|
||||
resolved = resolve_vision_llm(
|
||||
vision_candidates, explicit_slug=vision_llm_slug
|
||||
)
|
||||
except VisionConfigError as exc:
|
||||
console.print(f"[red]{exc}[/red]")
|
||||
return 2
|
||||
vision_config_id = resolved.config_id
|
||||
vision_provider_model = resolved.provider_model
|
||||
console.print(
|
||||
f"Vision LLM at ingest: [cyan]{vision_provider_model}[/cyan] "
|
||||
f"(id={vision_config_id}, selected_via={resolved.selected_via})."
|
||||
)
|
||||
|
||||
pref_kwargs: dict[str, Any] = {"agent_llm_id": agent_llm_id}
|
||||
if vision_config_id is not None:
|
||||
pref_kwargs["vision_llm_config_id"] = vision_config_id
|
||||
|
||||
await ss_client.set_llm_preferences(search_space_id, **pref_kwargs)
|
||||
prefs = await ss_client.get_llm_preferences(search_space_id)
|
||||
if not _validate_pin(prefs, provider_model):
|
||||
agent = prefs.agent_llm or {}
|
||||
console.print(
|
||||
f"[red]LLM pin validation FAILED.[/red] After PUT, "
|
||||
f"agent_llm.provider={agent.get('provider')!r}, "
|
||||
f"model_name={agent.get('model_name')!r}; expected "
|
||||
f"provider=OPENROUTER, model_name={provider_model!r}."
|
||||
)
|
||||
return 2
|
||||
if vision_config_id is not None and prefs.vision_llm_config_id != vision_config_id:
|
||||
console.print(
|
||||
f"[red]Vision LLM pin validation FAILED.[/red] After PUT, "
|
||||
f"vision_llm_config_id={prefs.vision_llm_config_id!r}; "
|
||||
f"expected {vision_config_id!r}."
|
||||
)
|
||||
return 2
|
||||
|
||||
suite_state = SuiteState(
|
||||
search_space_id=search_space_id,
|
||||
agent_llm_id=agent_llm_id,
|
||||
provider_model=provider_model,
|
||||
created_at=utc_iso_timestamp(),
|
||||
ingestion_maps=existing.ingestion_maps if existing else {},
|
||||
scenario=scenario,
|
||||
vision_llm_config_id=vision_config_id,
|
||||
vision_provider_model=vision_provider_model,
|
||||
native_arm_model=native_arm_model,
|
||||
)
|
||||
set_suite_state(config, suite, suite_state)
|
||||
|
||||
summary_bits = [
|
||||
f"suite={suite!r}",
|
||||
f"scenario={scenario!r}",
|
||||
f"search_space_id={suite_state.search_space_id}",
|
||||
f"agent_llm_id={suite_state.agent_llm_id}",
|
||||
f"provider_model={suite_state.provider_model!r}",
|
||||
]
|
||||
if suite_state.vision_provider_model:
|
||||
summary_bits.append(f"vision_provider_model={suite_state.vision_provider_model!r}")
|
||||
if suite_state.native_arm_model:
|
||||
summary_bits.append(f"native_arm_model={suite_state.native_arm_model!r}")
|
||||
console.print(f"[green]setup OK[/green] {' '.join(summary_bits)}")
|
||||
return 0
|
||||
|
||||
|
||||
def _validate_pin(prefs: LlmPreferences, provider_model: str) -> bool:
|
||||
agent = prefs.agent_llm or {}
|
||||
return (
|
||||
str(agent.get("provider", "")).upper() == "OPENROUTER"
|
||||
and str(agent.get("model_name", "")) == provider_model
|
||||
)
|
||||
|
||||
|
||||
async def _cmd_teardown(args: argparse.Namespace) -> int:
|
||||
suite = args.suite
|
||||
config = load_config()
|
||||
state = get_suite_state(config, suite)
|
||||
if state is None:
|
||||
console.print(f"[yellow]No state for suite {suite!r}; nothing to tear down.[/yellow]")
|
||||
return 0
|
||||
try:
|
||||
token = await acquire_token(config)
|
||||
except CredentialError as exc:
|
||||
console.print(f"[red]{exc}[/red]")
|
||||
return 2
|
||||
async with client_with_auth(config, token) as http:
|
||||
ss_client = SearchSpaceClient(http, config.surfsense_api_base)
|
||||
try:
|
||||
await ss_client.delete(state.search_space_id)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
console.print(
|
||||
f"[yellow]DELETE failed (HTTP {exc.response.status_code}); "
|
||||
"clearing state.json anyway.[/yellow]"
|
||||
)
|
||||
clear_suite_state(config, suite)
|
||||
console.print(
|
||||
f"[green]teardown OK[/green] suite={suite!r} "
|
||||
f"(SearchSpace soft-deleted, state.json slot cleared)."
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
async def _cmd_models_list(args: argparse.Namespace) -> int:
|
||||
config = load_config()
|
||||
try:
|
||||
token = await acquire_token(config)
|
||||
except CredentialError as exc:
|
||||
console.print(f"[red]{exc}[/red]")
|
||||
return 2
|
||||
async with client_with_auth(config, token) as http:
|
||||
entries = await _list_global_llm_configs(http, config.surfsense_api_base)
|
||||
grep = (args.grep or "").lower()
|
||||
provider_filter = (args.provider or "").upper()
|
||||
rows: list[LlmConfigEntry] = []
|
||||
for e in entries:
|
||||
if provider_filter and e.provider != provider_filter:
|
||||
continue
|
||||
if grep and grep not in e.model_name.lower() and grep not in e.name.lower():
|
||||
continue
|
||||
rows.append(e)
|
||||
table = Table(
|
||||
title=f"Global LLM configs ({len(rows)} of {len(entries)})",
|
||||
show_lines=False,
|
||||
)
|
||||
table.add_column("id", justify="right", style="cyan")
|
||||
table.add_column("provider", style="magenta")
|
||||
table.add_column("model_name", style="green")
|
||||
table.add_column("name")
|
||||
for e in sorted(rows, key=lambda x: (x.provider, x.model_name)):
|
||||
table.add_row(str(e.id), e.provider, e.model_name, e.name)
|
||||
console.print(table)
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_suites_list(_args: argparse.Namespace) -> int:
|
||||
_discover_suites()
|
||||
suites = registry.list_suites()
|
||||
if not suites:
|
||||
console.print(
|
||||
"[yellow]No suites registered. Drop a benchmark under "
|
||||
"src/surfsense_evals/suites/<domain>/<benchmark>/.[/yellow]"
|
||||
)
|
||||
return 0
|
||||
table = Table(title=f"Registered suites ({len(suites)})")
|
||||
table.add_column("suite", style="bold")
|
||||
table.add_column("benchmarks", style="green")
|
||||
for suite in suites:
|
||||
names = [b.name for b in registry.list_benchmarks(suite)]
|
||||
table.add_row(suite, ", ".join(names) or "<none>")
|
||||
console.print(table)
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_benchmarks_list(args: argparse.Namespace) -> int:
|
||||
_discover_suites()
|
||||
benchmarks = registry.list_benchmarks(args.suite)
|
||||
if not benchmarks:
|
||||
console.print("[yellow]No benchmarks registered.[/yellow]")
|
||||
return 0
|
||||
table = Table(title=f"Benchmarks ({len(benchmarks)})")
|
||||
table.add_column("suite", style="bold")
|
||||
table.add_column("name", style="cyan")
|
||||
table.add_column("headline", justify="center")
|
||||
table.add_column("description")
|
||||
for b in benchmarks:
|
||||
table.add_row(
|
||||
b.suite,
|
||||
b.name,
|
||||
"yes" if b.headline else "no",
|
||||
getattr(b, "description", ""),
|
||||
)
|
||||
console.print(table)
|
||||
return 0
|
||||
|
||||
|
||||
async def _cmd_ingest(args: argparse.Namespace) -> int:
|
||||
benchmark = registry.get(args.suite, args.benchmark)
|
||||
config = load_config()
|
||||
state = get_suite_state(config, args.suite)
|
||||
if state is None:
|
||||
console.print(
|
||||
f"[red]No setup for suite {args.suite!r}. Run "
|
||||
f"`python -m surfsense_evals setup --suite {args.suite} "
|
||||
f"--provider-model <slug>` first.[/red]"
|
||||
)
|
||||
return 2
|
||||
try:
|
||||
token = await acquire_token(config)
|
||||
except CredentialError as exc:
|
||||
console.print(f"[red]{exc}[/red]")
|
||||
return 2
|
||||
|
||||
# Forward parsed CLI flags into ingest() so a benchmark can honour
|
||||
# its own flags (e.g. MIRAGE's --skip-snippet-filter / --corpus).
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in vars(args).items()
|
||||
if k not in {"_func", "_async", "command", "subcommand", "suite", "benchmark", "log_level"}
|
||||
}
|
||||
async with client_with_auth(config, token) as http:
|
||||
ctx = registry.RunContext(
|
||||
suite=args.suite,
|
||||
benchmark=args.benchmark,
|
||||
config=config,
|
||||
suite_state=state,
|
||||
http=http,
|
||||
)
|
||||
await benchmark.ingest(ctx, **extra_kwargs)
|
||||
console.print(f"[green]ingest OK[/green] {args.suite}/{args.benchmark}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _cmd_run(args: argparse.Namespace) -> int:
|
||||
benchmark = registry.get(args.suite, args.benchmark)
|
||||
config = load_config()
|
||||
state = get_suite_state(config, args.suite)
|
||||
if state is None:
|
||||
console.print(
|
||||
f"[red]No setup for suite {args.suite!r}. Run "
|
||||
f"`python -m surfsense_evals setup --suite {args.suite} "
|
||||
f"--provider-model <slug>` first.[/red]"
|
||||
)
|
||||
return 2
|
||||
try:
|
||||
token = await acquire_token(config)
|
||||
except CredentialError as exc:
|
||||
console.print(f"[red]{exc}[/red]")
|
||||
return 2
|
||||
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in vars(args).items()
|
||||
if k not in {"_func", "_async", "command", "subcommand", "suite", "benchmark", "log_level"}
|
||||
}
|
||||
async with client_with_auth(config, token) as http:
|
||||
ctx = registry.RunContext(
|
||||
suite=args.suite,
|
||||
benchmark=args.benchmark,
|
||||
config=config,
|
||||
suite_state=state,
|
||||
http=http,
|
||||
)
|
||||
artifact = await benchmark.run(ctx, **extra_kwargs)
|
||||
|
||||
console.print(
|
||||
f"[green]run OK[/green] {args.suite}/{args.benchmark} → "
|
||||
f"{artifact.raw_path}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
async def _cmd_report(args: argparse.Namespace) -> int:
|
||||
from .report import write_report
|
||||
|
||||
benchmark_filter = args.benchmark
|
||||
config = load_config()
|
||||
state = get_suite_state(config, args.suite)
|
||||
if state is None:
|
||||
console.print(f"[red]No setup for suite {args.suite!r}.[/red]")
|
||||
return 2
|
||||
benchmarks = registry.list_benchmarks(args.suite)
|
||||
if benchmark_filter:
|
||||
benchmarks = [b for b in benchmarks if b.name == benchmark_filter]
|
||||
if not benchmarks:
|
||||
console.print(
|
||||
f"[red]No registered benchmark named {benchmark_filter!r} in suite {args.suite!r}.[/red]"
|
||||
)
|
||||
return 2
|
||||
|
||||
artifacts = _collect_artifacts(config, args.suite, [b.name for b in benchmarks])
|
||||
if not artifacts:
|
||||
console.print(
|
||||
"[yellow]No run artifacts found under "
|
||||
f"{config.suite_runs_dir(args.suite)}. Run a benchmark first.[/yellow]"
|
||||
)
|
||||
return 1
|
||||
|
||||
grouped: dict[str, list[registry.RunArtifact]] = {}
|
||||
for art in artifacts:
|
||||
grouped.setdefault(art.benchmark, []).append(art)
|
||||
sections: list[registry.ReportSection] = []
|
||||
for benchmark in benchmarks:
|
||||
if benchmark.name not in grouped:
|
||||
continue
|
||||
sections.append(benchmark.report_section(grouped[benchmark.name]))
|
||||
|
||||
summary_path = write_report(
|
||||
config=config,
|
||||
suite=args.suite,
|
||||
sections=sections,
|
||||
run_timestamp=utc_iso_timestamp(),
|
||||
)
|
||||
console.print(f"[green]report OK[/green] → {summary_path}")
|
||||
return 0
|
||||
|
||||
|
||||
def _collect_artifacts(
|
||||
config: Config, suite: str, benchmark_names: list[str]
|
||||
) -> list[registry.RunArtifact]:
|
||||
"""Walk ``data/<suite>/runs/*/<benchmark>/`` for the latest artifacts.
|
||||
|
||||
Reads any ``run_artifact.json`` written by a benchmark runner. The
|
||||
runner is responsible for writing this manifest alongside its raw
|
||||
JSONL so the report writer doesn't have to know benchmark-specific
|
||||
metric shapes.
|
||||
"""
|
||||
|
||||
runs_dir = config.suite_runs_dir(suite)
|
||||
if not runs_dir.exists():
|
||||
return []
|
||||
artifacts: list[registry.RunArtifact] = []
|
||||
by_bench: dict[str, registry.RunArtifact] = {}
|
||||
for ts_dir in sorted(runs_dir.iterdir()):
|
||||
if not ts_dir.is_dir():
|
||||
continue
|
||||
for bench_name in benchmark_names:
|
||||
bench_dir = ts_dir / bench_name
|
||||
manifest = bench_dir / "run_artifact.json"
|
||||
if not manifest.exists():
|
||||
continue
|
||||
try:
|
||||
with manifest.open("r", encoding="utf-8") as fh:
|
||||
payload = json.load(fh)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
continue
|
||||
artifact = registry.RunArtifact(
|
||||
suite=suite,
|
||||
benchmark=bench_name,
|
||||
run_timestamp=ts_dir.name,
|
||||
raw_path=bench_dir / payload.get("raw_path", "raw.jsonl"),
|
||||
metrics=payload.get("metrics", {}),
|
||||
extra=payload.get("extra", {}),
|
||||
)
|
||||
# Latest run wins per benchmark.
|
||||
by_bench[bench_name] = artifact
|
||||
artifacts = list(by_bench.values())
|
||||
return artifacts
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Argparse wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="surfsense-evals",
|
||||
description="SurfSense evaluation harness — domain-agnostic core + pluggable suites.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
p_setup = sub.add_parser("setup", help="Create per-suite SearchSpace + pin LLM.")
|
||||
p_setup.add_argument("--suite", required=True)
|
||||
p_setup.add_argument(
|
||||
"--provider-model",
|
||||
required=True,
|
||||
help=(
|
||||
"OpenRouter slug for the SurfSense answer LLM (and the native arm "
|
||||
"too unless --native-arm-model is set), e.g. "
|
||||
"'anthropic/claude-sonnet-4.5'."
|
||||
),
|
||||
)
|
||||
p_setup.add_argument(
|
||||
"--agent-llm-id",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Optional override for BYOK NewLLMConfig rows.",
|
||||
)
|
||||
p_setup.add_argument(
|
||||
"--scenario",
|
||||
choices=SCENARIOS,
|
||||
default=DEFAULT_SCENARIO,
|
||||
help=(
|
||||
"head-to-head (default): both arms answer with --provider-model; "
|
||||
"symmetric-cheap: both arms use the same cheap text-only slug, "
|
||||
"SurfSense pre-extracted images at ingest with a vision LLM; "
|
||||
"cost-arbitrage: native arm uses --native-arm-model (vision), "
|
||||
"SurfSense uses --provider-model (cheap, text-only) over chunks "
|
||||
"the vision LLM already extracted at ingest."
|
||||
),
|
||||
)
|
||||
p_setup.add_argument(
|
||||
"--vision-llm",
|
||||
default=None,
|
||||
metavar="SLUG",
|
||||
help=(
|
||||
"OpenRouter slug for the vision LLM SurfSense uses at ingest "
|
||||
"when --use-vision-llm is on. If omitted in symmetric-cheap / "
|
||||
"cost-arbitrage, the strongest registered vision config is "
|
||||
"auto-picked (priority: claude-sonnet-4.5 > claude-opus-4.7 > "
|
||||
"gpt-5 > gemini-2.5-pro)."
|
||||
),
|
||||
)
|
||||
p_setup.add_argument(
|
||||
"--native-arm-model",
|
||||
default=None,
|
||||
metavar="SLUG",
|
||||
help=(
|
||||
"Required for --scenario cost-arbitrage. OpenRouter slug used "
|
||||
"by the native_pdf arm only; SurfSense answers with "
|
||||
"--provider-model. Ignored for head-to-head / symmetric-cheap."
|
||||
),
|
||||
)
|
||||
p_setup.add_argument(
|
||||
"--no-vision-llm-setup",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Skip attaching a vision LLM config to the SearchSpace even if "
|
||||
"the scenario would normally require one. Use when you want to "
|
||||
"keep whatever is already attached (e.g. a per-user config)."
|
||||
),
|
||||
)
|
||||
p_setup.set_defaults(_func=_cmd_setup, _async=True)
|
||||
|
||||
p_teardown = sub.add_parser("teardown", help="Soft-delete the suite SearchSpace + clear state slot.")
|
||||
p_teardown.add_argument("--suite", required=True)
|
||||
p_teardown.set_defaults(_func=_cmd_teardown, _async=True)
|
||||
|
||||
p_models = sub.add_parser("models", help="LLM-config discovery helpers.")
|
||||
models_sub = p_models.add_subparsers(dest="subcommand", required=True)
|
||||
p_models_list = models_sub.add_parser("list", help="List global LLM configs.")
|
||||
p_models_list.add_argument("--provider", default=None, help="Filter by provider, e.g. openrouter")
|
||||
p_models_list.add_argument("--grep", default=None, help="Substring filter on name / model_name.")
|
||||
p_models_list.set_defaults(_func=_cmd_models_list, _async=True)
|
||||
|
||||
p_suites = sub.add_parser("suites", help="List registered suites.")
|
||||
suites_sub = p_suites.add_subparsers(dest="subcommand", required=True)
|
||||
p_suites_list = suites_sub.add_parser("list", help="List suites.")
|
||||
p_suites_list.set_defaults(_func=_cmd_suites_list, _async=False)
|
||||
|
||||
p_benchmarks = sub.add_parser("benchmarks", help="List registered benchmarks.")
|
||||
bench_sub = p_benchmarks.add_subparsers(dest="subcommand", required=True)
|
||||
p_bench_list = bench_sub.add_parser("list", help="List benchmarks.")
|
||||
p_bench_list.add_argument("--suite", default=None)
|
||||
p_bench_list.set_defaults(_func=_cmd_benchmarks_list, _async=False)
|
||||
|
||||
# Dynamic ingest / run subcommands need the registry populated, so
|
||||
# discover up-front (cheap on import — modules just register).
|
||||
_discover_suites()
|
||||
|
||||
p_ingest = sub.add_parser("ingest", help="Ingest a benchmark's corpus.")
|
||||
ingest_sub = p_ingest.add_subparsers(dest="suite", required=True)
|
||||
for suite in registry.list_suites():
|
||||
suite_parser = ingest_sub.add_parser(suite, help=f"Ingest a {suite} benchmark.")
|
||||
suite_bench = suite_parser.add_subparsers(dest="benchmark", required=True)
|
||||
for benchmark in registry.list_benchmarks(suite):
|
||||
bp = suite_bench.add_parser(benchmark.name, help=getattr(benchmark, "description", benchmark.name))
|
||||
if hasattr(benchmark, "add_run_args"):
|
||||
benchmark.add_run_args(bp)
|
||||
bp.set_defaults(_func=_cmd_ingest, _async=True)
|
||||
|
||||
p_run = sub.add_parser("run", help="Run a benchmark.")
|
||||
run_sub = p_run.add_subparsers(dest="suite", required=True)
|
||||
for suite in registry.list_suites():
|
||||
suite_parser = run_sub.add_parser(suite, help=f"Run a {suite} benchmark.")
|
||||
suite_bench = suite_parser.add_subparsers(dest="benchmark", required=True)
|
||||
for benchmark in registry.list_benchmarks(suite):
|
||||
bp = suite_bench.add_parser(benchmark.name, help=getattr(benchmark, "description", benchmark.name))
|
||||
if hasattr(benchmark, "add_run_args"):
|
||||
benchmark.add_run_args(bp)
|
||||
bp.set_defaults(_func=_cmd_run, _async=True)
|
||||
|
||||
p_report = sub.add_parser("report", help="Aggregate latest run artifacts into a summary.")
|
||||
p_report.add_argument("--suite", required=True)
|
||||
p_report.add_argument("--benchmark", default=None, help="Optional: report only this benchmark.")
|
||||
p_report.set_defaults(_func=_cmd_report, _async=True)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args(argv)
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level),
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
func = getattr(args, "_func", None)
|
||||
if func is None:
|
||||
parser.print_help()
|
||||
return 2
|
||||
is_async = getattr(args, "_async", False)
|
||||
try:
|
||||
if is_async:
|
||||
return asyncio.run(func(args))
|
||||
return func(args)
|
||||
except KeyboardInterrupt:
|
||||
console.print("[yellow]Interrupted.[/yellow]")
|
||||
return 130
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("CLI command failed")
|
||||
console.print(f"[red]Command failed: {exc}[/red]")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
sys.exit(main())
|
||||
14
surfsense_evals/src/surfsense_evals/core/clients/__init__.py
Normal file
14
surfsense_evals/src/surfsense_evals/core/clients/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""HTTP clients for the SurfSense API. All share one ``httpx.AsyncClient``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .documents import DocumentsClient
|
||||
from .new_chat import NewChatClient, StreamedAnswer
|
||||
from .search_space import SearchSpaceClient
|
||||
|
||||
__all__ = [
|
||||
"DocumentsClient",
|
||||
"NewChatClient",
|
||||
"SearchSpaceClient",
|
||||
"StreamedAnswer",
|
||||
]
|
||||
277
surfsense_evals/src/surfsense_evals/core/clients/documents.py
Normal file
277
surfsense_evals/src/surfsense_evals/core/clients/documents.py
Normal file
|
|
@ -0,0 +1,277 @@
|
|||
"""Client for ``/api/v1/documents/{fileupload,status,{id}/chunks}``.
|
||||
|
||||
Verified against:
|
||||
|
||||
* ``surfsense_backend/app/routes/documents_routes.py:122-292`` (POST fileupload)
|
||||
* ``surfsense_backend/app/routes/documents_routes.py:806-871`` (GET status batch)
|
||||
* ``surfsense_backend/app/routes/documents_routes.py:1062-1128`` (GET {id}/chunks paginated)
|
||||
|
||||
Document processing is asynchronous:
|
||||
* ``POST /documents/fileupload`` returns immediately with
|
||||
``document_ids`` in ``pending``;
|
||||
* a Celery worker moves each through ``processing → ready/failed``;
|
||||
* the harness polls ``GET /documents/status?document_ids=...`` until
|
||||
every doc is ``ready`` (otherwise the retriever sees an empty corpus
|
||||
and accuracy numbers are meaningless).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileUploadResult:
|
||||
"""Mirrors the JSON returned by ``POST /documents/fileupload``."""
|
||||
|
||||
document_ids: list[int]
|
||||
duplicate_document_ids: list[int]
|
||||
total_files: int
|
||||
pending_files: int
|
||||
skipped_duplicates: int
|
||||
message: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: dict[str, Any]) -> FileUploadResult:
|
||||
return cls(
|
||||
document_ids=[int(x) for x in payload.get("document_ids", [])],
|
||||
duplicate_document_ids=[int(x) for x in payload.get("duplicate_document_ids", [])],
|
||||
total_files=int(payload.get("total_files", 0)),
|
||||
pending_files=int(payload.get("pending_files", 0)),
|
||||
skipped_duplicates=int(payload.get("skipped_duplicates", 0)),
|
||||
message=str(payload.get("message", "")),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentStatus:
|
||||
document_id: int
|
||||
title: str
|
||||
document_type: str
|
||||
state: str
|
||||
reason: str | None = None
|
||||
|
||||
@property
|
||||
def is_ready(self) -> bool:
|
||||
return self.state == "ready"
|
||||
|
||||
@property
|
||||
def is_failed(self) -> bool:
|
||||
return self.state == "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkRow:
|
||||
id: int
|
||||
document_id: int
|
||||
content: str = ""
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DocumentProcessingFailed(RuntimeError):
|
||||
"""Raised when a polled document lands in ``failed``."""
|
||||
|
||||
def __init__(self, statuses: Sequence[DocumentStatus]) -> None:
|
||||
details = ", ".join(
|
||||
f"id={s.document_id} ({s.title!r}): {s.reason or 'unknown'}"
|
||||
for s in statuses
|
||||
)
|
||||
super().__init__(f"Document(s) failed to process: {details}")
|
||||
self.statuses = list(statuses)
|
||||
|
||||
|
||||
class DocumentProcessingTimeout(RuntimeError):
|
||||
"""Raised when polling exceeds the per-doc timeout budget."""
|
||||
|
||||
|
||||
class DocumentsClient:
|
||||
"""Document upload + status polling + chunk listing."""
|
||||
|
||||
def __init__(self, http: httpx.AsyncClient, base_url: str) -> None:
|
||||
self._http = http
|
||||
self._base = base_url.rstrip("/")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# upload
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def upload(
|
||||
self,
|
||||
files: Iterable[Path],
|
||||
*,
|
||||
search_space_id: int,
|
||||
should_summarize: bool = False,
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
) -> FileUploadResult:
|
||||
"""Upload files to ``/api/v1/documents/fileupload``.
|
||||
|
||||
``files`` is materialised to a list because we may need to
|
||||
re-read on retry. Caller is responsible for ensuring each path
|
||||
exists and respects the per-file size cap (50 MB backend default).
|
||||
"""
|
||||
|
||||
materialised = [Path(p) for p in files]
|
||||
if not materialised:
|
||||
return FileUploadResult(
|
||||
document_ids=[],
|
||||
duplicate_document_ids=[],
|
||||
total_files=0,
|
||||
pending_files=0,
|
||||
skipped_duplicates=0,
|
||||
message="No files supplied",
|
||||
)
|
||||
|
||||
opened: list[tuple[str, Any]] = []
|
||||
try:
|
||||
for path in materialised:
|
||||
# ``open`` directly — httpx wraps it in MultipartStream.
|
||||
file_obj = path.open("rb")
|
||||
mime, _ = mimetypes.guess_type(path.name)
|
||||
opened.append(
|
||||
(
|
||||
"files",
|
||||
(path.name, file_obj, mime or "application/octet-stream"),
|
||||
)
|
||||
)
|
||||
|
||||
response = await self._http.post(
|
||||
f"{self._base}/api/v1/documents/fileupload",
|
||||
data={
|
||||
"search_space_id": str(search_space_id),
|
||||
"should_summarize": "true" if should_summarize else "false",
|
||||
"use_vision_llm": "true" if use_vision_llm else "false",
|
||||
"processing_mode": processing_mode,
|
||||
},
|
||||
files=opened,
|
||||
# Multipart uploads can be slow for big PDFs; bump per-call.
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
finally:
|
||||
for _, (_, file_obj, _) in opened:
|
||||
try:
|
||||
file_obj.close()
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
response.raise_for_status()
|
||||
return FileUploadResult.from_payload(response.json())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# status polling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_status(
|
||||
self, *, search_space_id: int, document_ids: Sequence[int]
|
||||
) -> list[DocumentStatus]:
|
||||
if not document_ids:
|
||||
return []
|
||||
response = await self._http.get(
|
||||
f"{self._base}/api/v1/documents/status",
|
||||
params={
|
||||
"search_space_id": search_space_id,
|
||||
"document_ids": ",".join(str(d) for d in document_ids),
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
return [
|
||||
DocumentStatus(
|
||||
document_id=int(item["id"]),
|
||||
title=str(item.get("title", "")),
|
||||
document_type=str(item.get("document_type", "")),
|
||||
state=str((item.get("status") or {}).get("state", "ready")),
|
||||
reason=(item.get("status") or {}).get("reason"),
|
||||
)
|
||||
for item in payload.get("items", [])
|
||||
]
|
||||
|
||||
async def wait_until_ready(
|
||||
self,
|
||||
*,
|
||||
search_space_id: int,
|
||||
document_ids: Sequence[int],
|
||||
timeout_s: float = 300.0,
|
||||
initial_poll_s: float = 1.0,
|
||||
max_poll_s: float = 10.0,
|
||||
) -> list[DocumentStatus]:
|
||||
"""Poll ``GET /documents/status`` until every doc is ``ready``.
|
||||
|
||||
Exponential backoff from ``initial_poll_s`` up to ``max_poll_s``.
|
||||
Raises ``DocumentProcessingFailed`` if any doc lands in
|
||||
``failed`` (with the offending document ids), or
|
||||
``DocumentProcessingTimeout`` if the budget is exhausted.
|
||||
"""
|
||||
|
||||
if not document_ids:
|
||||
return []
|
||||
deadline = asyncio.get_event_loop().time() + timeout_s
|
||||
poll = initial_poll_s
|
||||
while True:
|
||||
statuses = await self.get_status(
|
||||
search_space_id=search_space_id, document_ids=document_ids
|
||||
)
|
||||
failed = [s for s in statuses if s.is_failed]
|
||||
if failed:
|
||||
raise DocumentProcessingFailed(failed)
|
||||
ready = [s for s in statuses if s.is_ready]
|
||||
if len(ready) == len(document_ids):
|
||||
return statuses
|
||||
now = asyncio.get_event_loop().time()
|
||||
if now >= deadline:
|
||||
pending = [s for s in statuses if not s.is_ready and not s.is_failed]
|
||||
pending_ids = [s.document_id for s in pending]
|
||||
raise DocumentProcessingTimeout(
|
||||
f"Timed out after {timeout_s:.0f}s waiting for documents "
|
||||
f"(still pending/processing: {pending_ids})"
|
||||
)
|
||||
await asyncio.sleep(min(poll, max(0.1, deadline - now)))
|
||||
poll = min(poll * 1.5, max_poll_s)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# chunks (chunk_id -> document_id map)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def list_chunks(
|
||||
self, document_id: int, *, page_size: int = 100
|
||||
) -> list[ChunkRow]:
|
||||
"""Walk ``GET /documents/{id}/chunks`` until ``has_more=False``.
|
||||
|
||||
Used by ingestion to materialise the ``chunk_id -> document_id``
|
||||
map needed for retrieval scoring (CUREv1).
|
||||
"""
|
||||
|
||||
rows: list[ChunkRow] = []
|
||||
page = 0
|
||||
while True:
|
||||
response = await self._http.get(
|
||||
f"{self._base}/api/v1/documents/{document_id}/chunks",
|
||||
params={"page": page, "page_size": page_size},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
for item in payload.get("items", []):
|
||||
rows.append(
|
||||
ChunkRow(
|
||||
id=int(item["id"]),
|
||||
document_id=document_id,
|
||||
content=str(item.get("content", "")),
|
||||
raw=item,
|
||||
)
|
||||
)
|
||||
if not payload.get("has_more"):
|
||||
break
|
||||
page += 1
|
||||
return rows
|
||||
280
surfsense_evals/src/surfsense_evals/core/clients/new_chat.py
Normal file
280
surfsense_evals/src/surfsense_evals/core/clients/new_chat.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""Client for ``/api/v1/threads`` and ``/api/v1/new_chat`` (SSE).
|
||||
|
||||
Verified against:
|
||||
|
||||
* ``surfsense_backend/app/routes/new_chat_routes.py:793-848`` (POST /threads)
|
||||
* ``surfsense_backend/app/routes/new_chat_routes.py:1073-1142`` (DELETE /threads/{id})
|
||||
* ``surfsense_backend/app/routes/new_chat_routes.py:1689-1800`` (POST /new_chat SSE)
|
||||
* ``surfsense_backend/app/routes/new_chat_routes.py:191-220`` (THREAD_BUSY / TURN_CANCELLING 409)
|
||||
* ``surfsense_backend/app/services/streaming/envelope/sse.py`` (wire framing)
|
||||
* ``surfsense_backend/app/services/streaming/events/text.py`` (text-delta events)
|
||||
* ``surfsense_backend/app/schemas/new_chat.py:234-288`` (NewChatRequest body)
|
||||
|
||||
The wire format is "Vercel AI SDK"-flavoured SSE with one event per
|
||||
``data: <json>\n\n`` block (or the literal ``data: [DONE]\n\n``
|
||||
terminator). Text deltas arrive as ``{"type":"text-delta","id":...,"delta":...}``
|
||||
events; we accumulate them per ``id`` and emit the final concatenated
|
||||
text plus parsed citations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from ..parse import iter_sse_events, parse_citations
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamedAnswer:
|
||||
"""Result of a single ``/new_chat`` turn."""
|
||||
|
||||
text: str
|
||||
raw_events: list[dict[str, Any]] = field(default_factory=list)
|
||||
latency_ms: int = 0
|
||||
user_message_id: str | None = None
|
||||
assistant_message_id: str | None = None
|
||||
finished_normally: bool = False
|
||||
|
||||
@property
|
||||
def citations(self) -> list[dict[str, Any]]:
|
||||
"""Parsed citation tokens (lazy; small enough to recompute)."""
|
||||
|
||||
return [token.to_dict() for token in parse_citations(self.text)]
|
||||
|
||||
|
||||
class ThreadBusyError(RuntimeError):
|
||||
"""Raised after exhausting retries on a 409 ``THREAD_BUSY`` / ``TURN_CANCELLING``."""
|
||||
|
||||
def __init__(self, error_code: str, message: str) -> None:
|
||||
super().__init__(f"{error_code}: {message}")
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class NewChatClient:
|
||||
"""Thread create / delete / SSE ask."""
|
||||
|
||||
def __init__(self, http: httpx.AsyncClient, base_url: str) -> None:
|
||||
self._http = http
|
||||
self._base = base_url.rstrip("/")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# threads
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def create_thread(
|
||||
self,
|
||||
*,
|
||||
search_space_id: int,
|
||||
title: str = "eval",
|
||||
archived: bool = False,
|
||||
visibility: str = "PRIVATE",
|
||||
) -> int:
|
||||
response = await self._http.post(
|
||||
f"{self._base}/api/v1/threads",
|
||||
json={
|
||||
"search_space_id": search_space_id,
|
||||
"title": title,
|
||||
"archived": archived,
|
||||
"visibility": visibility,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
return int(payload["id"])
|
||||
|
||||
async def delete_thread(self, thread_id: int) -> None:
|
||||
response = await self._http.delete(
|
||||
f"{self._base}/api/v1/threads/{thread_id}",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return # idempotent
|
||||
response.raise_for_status()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /new_chat SSE
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def ask(
|
||||
self,
|
||||
*,
|
||||
thread_id: int,
|
||||
search_space_id: int,
|
||||
user_query: str,
|
||||
mentioned_document_ids: Sequence[int] | None = None,
|
||||
disabled_tools: Sequence[str] | None = None,
|
||||
max_busy_retries: int = 4,
|
||||
timeout_s: float = 600.0,
|
||||
) -> StreamedAnswer:
|
||||
"""Stream a single turn and return the accumulated answer.
|
||||
|
||||
Honours backend ``THREAD_BUSY`` / ``TURN_CANCELLING`` 409
|
||||
responses by sleeping for the ``Retry-After`` header (or the
|
||||
``retry-after-ms`` header if present) and replaying. Bounded
|
||||
by ``max_busy_retries`` so a stuck thread never blocks the
|
||||
whole run.
|
||||
"""
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"chat_id": thread_id,
|
||||
"search_space_id": search_space_id,
|
||||
"user_query": user_query,
|
||||
}
|
||||
if mentioned_document_ids:
|
||||
body["mentioned_document_ids"] = list(mentioned_document_ids)
|
||||
if disabled_tools:
|
||||
body["disabled_tools"] = list(disabled_tools)
|
||||
|
||||
attempt = 0
|
||||
while True:
|
||||
try:
|
||||
return await self._stream_once(body=body, timeout_s=timeout_s)
|
||||
except ThreadBusyError as exc:
|
||||
attempt += 1
|
||||
if attempt > max_busy_retries:
|
||||
raise
|
||||
# Cap wait at 30s; backend retry hint is exponential anyway.
|
||||
wait = min(30.0, 0.5 * (2 ** attempt))
|
||||
logger.info(
|
||||
"thread_id=%s busy (%s); retry %d/%d after %.1fs",
|
||||
thread_id,
|
||||
exc.error_code,
|
||||
attempt,
|
||||
max_busy_retries,
|
||||
wait,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
|
||||
async def _stream_once(
|
||||
self,
|
||||
*,
|
||||
body: dict[str, Any],
|
||||
timeout_s: float,
|
||||
) -> StreamedAnswer:
|
||||
# Per-call timeout — the connect should be quick, the read needs
|
||||
# to outlive the longest LLM completion.
|
||||
timeout = httpx.Timeout(timeout_s, connect=10.0)
|
||||
started = time.monotonic()
|
||||
async with self._http.stream(
|
||||
"POST",
|
||||
f"{self._base}/api/v1/new_chat",
|
||||
json=body,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
if response.status_code == 409:
|
||||
detail = await self._extract_busy_detail(response)
|
||||
raise ThreadBusyError(
|
||||
error_code=detail.get("errorCode", "THREAD_BUSY"),
|
||||
message=detail.get("message", "Thread is busy"),
|
||||
)
|
||||
response.raise_for_status()
|
||||
answer = await self._consume_sse(response)
|
||||
answer.latency_ms = int((time.monotonic() - started) * 1000)
|
||||
return answer
|
||||
|
||||
@staticmethod
|
||||
async def _extract_busy_detail(response: httpx.Response) -> dict[str, Any]:
|
||||
try:
|
||||
payload = json.loads(await response.aread())
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return {"errorCode": "THREAD_BUSY", "message": response.text}
|
||||
if isinstance(payload, dict) and isinstance(payload.get("detail"), dict):
|
||||
return payload["detail"]
|
||||
return payload if isinstance(payload, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
async def _consume_sse(response: httpx.Response) -> StreamedAnswer:
|
||||
"""Walk SSE events, accumulate text-delta payloads.
|
||||
|
||||
Backend events of interest:
|
||||
|
||||
* ``{"type": "text-start", "id": ...}``
|
||||
* ``{"type": "text-delta", "id": ..., "delta": ...}``
|
||||
* ``{"type": "text-end", "id": ...}``
|
||||
* ``{"type": "start", "messageId": ...}`` (top-level message id)
|
||||
* ``{"type": "finish"}``
|
||||
* literal ``[DONE]`` sentinel
|
||||
|
||||
Multiple ``text-start`` blocks can interleave — each gets its
|
||||
own ``id`` and we concatenate them in arrival order. That
|
||||
mirrors the AI SDK client behaviour: one continuous assistant
|
||||
message visible to the user.
|
||||
"""
|
||||
|
||||
ordered_text_ids: list[str] = []
|
||||
text_buffers: dict[str, list[str]] = {}
|
||||
raw_events: list[dict[str, Any]] = []
|
||||
user_message_id: str | None = None
|
||||
assistant_message_id: str | None = None
|
||||
finished = False
|
||||
|
||||
async for event in iter_sse_events(_aiter_lines(response)):
|
||||
data = event.data
|
||||
if data == "[DONE]":
|
||||
finished = True
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(data)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.debug("Skipping non-JSON SSE payload: %r", data[:120])
|
||||
continue
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
raw_events.append(payload)
|
||||
ev_type = payload.get("type")
|
||||
if ev_type == "text-delta":
|
||||
tid = str(payload.get("id", ""))
|
||||
delta = payload.get("delta", "")
|
||||
if not isinstance(delta, str):
|
||||
continue
|
||||
if tid not in text_buffers:
|
||||
text_buffers[tid] = []
|
||||
ordered_text_ids.append(tid)
|
||||
text_buffers[tid].append(delta)
|
||||
elif ev_type == "text-start":
|
||||
tid = str(payload.get("id", ""))
|
||||
if tid and tid not in text_buffers:
|
||||
text_buffers[tid] = []
|
||||
ordered_text_ids.append(tid)
|
||||
elif ev_type == "start":
|
||||
msg_id = payload.get("messageId")
|
||||
if isinstance(msg_id, str):
|
||||
user_message_id = user_message_id or msg_id
|
||||
elif ev_type == "data-user-message-id":
|
||||
msg_id = (payload.get("data") or {}).get("id") or payload.get("id")
|
||||
if isinstance(msg_id, str):
|
||||
user_message_id = msg_id
|
||||
elif ev_type == "data-assistant-message-id":
|
||||
msg_id = (payload.get("data") or {}).get("id") or payload.get("id")
|
||||
if isinstance(msg_id, str):
|
||||
assistant_message_id = msg_id
|
||||
elif ev_type == "finish":
|
||||
finished = True
|
||||
|
||||
text = "".join("".join(text_buffers.get(tid, [])) for tid in ordered_text_ids)
|
||||
return StreamedAnswer(
|
||||
text=text,
|
||||
raw_events=raw_events,
|
||||
user_message_id=user_message_id,
|
||||
assistant_message_id=assistant_message_id,
|
||||
finished_normally=finished,
|
||||
)
|
||||
|
||||
|
||||
async def _aiter_lines(response: httpx.Response) -> AsyncIterator[str]:
|
||||
"""Adapter so the parser can consume any line iterator (mockable in tests)."""
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
207
surfsense_evals/src/surfsense_evals/core/clients/search_space.py
Normal file
207
surfsense_evals/src/surfsense_evals/core/clients/search_space.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
"""Client for ``/api/v1/searchspaces`` and ``/api/v1/search-spaces/{id}/llm-preferences``.
|
||||
|
||||
Verified against:
|
||||
|
||||
* ``surfsense_backend/app/routes/search_spaces_routes.py:116`` (POST create)
|
||||
* ``surfsense_backend/app/routes/search_spaces_routes.py:234`` (GET by id)
|
||||
* ``surfsense_backend/app/routes/search_spaces_routes.py:422`` (DELETE soft-delete)
|
||||
* ``surfsense_backend/app/routes/search_spaces_routes.py:698-849`` (GET/PUT llm-preferences)
|
||||
* ``surfsense_backend/app/schemas/search_space.py:14`` (SearchSpaceCreate body)
|
||||
* ``surfsense_backend/app/routes/vision_llm_routes.py:60`` (GET global vision configs)
|
||||
|
||||
Note the inconsistent pluralisation in the backend: ``/searchspaces``
|
||||
(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for the
|
||||
``llm-preferences`` sub-resource. Both are mirrored verbatim here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchSpaceRow:
|
||||
"""Subset of the SearchSpace row we care about."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
description: str | None
|
||||
user_id: str
|
||||
citations_enabled: bool
|
||||
qna_custom_instructions: str | None
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: dict[str, Any]) -> SearchSpaceRow:
|
||||
return cls(
|
||||
id=int(payload["id"]),
|
||||
name=str(payload["name"]),
|
||||
description=payload.get("description"),
|
||||
user_id=str(payload.get("user_id", "")),
|
||||
citations_enabled=bool(payload.get("citations_enabled", True)),
|
||||
qna_custom_instructions=payload.get("qna_custom_instructions"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionLlmConfigEntry:
|
||||
"""Subset of one ``GET /global-vision-llm-configs`` row.
|
||||
|
||||
The backend returns negative ids for global / OpenRouter-derived
|
||||
vision configs and positive ids for per-user BYOK rows. Either is
|
||||
accepted by ``set_llm_preferences(vision_llm_config_id=...)``.
|
||||
"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
model_name: str
|
||||
is_auto_mode: bool
|
||||
raw: dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: dict[str, Any]) -> VisionLlmConfigEntry:
|
||||
return cls(
|
||||
id=int(payload.get("id", 0)),
|
||||
name=str(payload.get("name", "")),
|
||||
provider=str(payload.get("provider", "")).upper(),
|
||||
model_name=str(payload.get("model_name", "")),
|
||||
is_auto_mode=bool(payload.get("is_auto_mode", False)),
|
||||
raw=payload,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlmPreferences:
|
||||
"""Resolved LLM preferences with the embedded full config row.
|
||||
|
||||
Mirrors ``LLMPreferencesRead`` from the backend so the lifecycle
|
||||
command can introspect ``provider`` / ``model_name`` to validate the
|
||||
OpenRouter pin.
|
||||
"""
|
||||
|
||||
agent_llm_id: int | None
|
||||
document_summary_llm_id: int | None
|
||||
image_generation_config_id: int | None
|
||||
vision_llm_config_id: int | None
|
||||
agent_llm: dict[str, Any] | None
|
||||
raw: dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: dict[str, Any]) -> LlmPreferences:
|
||||
return cls(
|
||||
agent_llm_id=payload.get("agent_llm_id"),
|
||||
document_summary_llm_id=payload.get("document_summary_llm_id"),
|
||||
image_generation_config_id=payload.get("image_generation_config_id"),
|
||||
vision_llm_config_id=payload.get("vision_llm_config_id"),
|
||||
agent_llm=payload.get("agent_llm"),
|
||||
raw=payload,
|
||||
)
|
||||
|
||||
|
||||
class SearchSpaceClient:
|
||||
"""Thin wrapper around the SearchSpace + LLM preferences endpoints."""
|
||||
|
||||
def __init__(self, http: httpx.AsyncClient, base_url: str) -> None:
|
||||
self._http = http
|
||||
self._base = base_url.rstrip("/")
|
||||
|
||||
async def create(self, name: str, *, description: str | None = None) -> SearchSpaceRow:
|
||||
body: dict[str, Any] = {"name": name}
|
||||
if description is not None:
|
||||
body["description"] = description
|
||||
# citations_enabled defaults to True backend-side; keep that default.
|
||||
response = await self._http.post(
|
||||
f"{self._base}/api/v1/searchspaces",
|
||||
json=body,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return SearchSpaceRow.from_payload(response.json())
|
||||
|
||||
async def get(self, search_space_id: int) -> SearchSpaceRow:
|
||||
response = await self._http.get(
|
||||
f"{self._base}/api/v1/searchspaces/{search_space_id}",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return SearchSpaceRow.from_payload(response.json())
|
||||
|
||||
async def delete(self, search_space_id: int) -> None:
|
||||
"""Soft-delete: backend prefixes name with ``[DELETING]`` and dispatches a Celery cascade."""
|
||||
|
||||
response = await self._http.delete(
|
||||
f"{self._base}/api/v1/searchspaces/{search_space_id}",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
# 404 means it's already gone — treat as success (idempotent teardown).
|
||||
if response.status_code == 404:
|
||||
return
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_llm_preferences(self, search_space_id: int) -> LlmPreferences:
|
||||
response = await self._http.get(
|
||||
f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return LlmPreferences.from_payload(response.json())
|
||||
|
||||
async def set_llm_preferences(
|
||||
self,
|
||||
search_space_id: int,
|
||||
*,
|
||||
agent_llm_id: int | None = None,
|
||||
document_summary_llm_id: int | None = None,
|
||||
image_generation_config_id: int | None = None,
|
||||
vision_llm_config_id: int | None = None,
|
||||
) -> LlmPreferences:
|
||||
"""PUT a partial update to ``/search-spaces/{id}/llm-preferences``.
|
||||
|
||||
Backend uses ``model_dump(exclude_unset=True)`` so omitted fields
|
||||
are left unchanged.
|
||||
"""
|
||||
|
||||
body: dict[str, Any] = {}
|
||||
if agent_llm_id is not None:
|
||||
body["agent_llm_id"] = agent_llm_id
|
||||
if document_summary_llm_id is not None:
|
||||
body["document_summary_llm_id"] = document_summary_llm_id
|
||||
if image_generation_config_id is not None:
|
||||
body["image_generation_config_id"] = image_generation_config_id
|
||||
if vision_llm_config_id is not None:
|
||||
body["vision_llm_config_id"] = vision_llm_config_id
|
||||
response = await self._http.put(
|
||||
f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences",
|
||||
json=body,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return LlmPreferences.from_payload(response.json())
|
||||
|
||||
async def list_global_vision_llm_configs(self) -> list[VisionLlmConfigEntry]:
|
||||
"""List the registered global vision LLM configs.
|
||||
|
||||
Used by ``setup`` to (a) resolve an explicit ``--vision-llm <slug>``
|
||||
to a config id and (b) auto-pick the strongest registered vision
|
||||
config when the operator doesn't pass one. The ``Auto (Fastest)``
|
||||
entry (``id=0``) is filtered out — accuracy must be reproducible.
|
||||
"""
|
||||
|
||||
response = await self._http.get(
|
||||
f"{self._base}/api/v1/global-vision-llm-configs",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
if not isinstance(payload, list):
|
||||
raise RuntimeError(
|
||||
f"Unexpected /global-vision-llm-configs payload: {payload!r}"
|
||||
)
|
||||
return [
|
||||
VisionLlmConfigEntry.from_payload(item)
|
||||
for item in payload
|
||||
if not bool(item.get("is_auto_mode", False))
|
||||
]
|
||||
279
surfsense_evals/src/surfsense_evals/core/config.py
Normal file
279
surfsense_evals/src/surfsense_evals/core/config.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
"""Environment + filesystem configuration for the harness.
|
||||
|
||||
Two responsibilities:
|
||||
|
||||
1. Load env vars (with sensible defaults) into a single immutable ``Config``
|
||||
so that every other module reads it from one place.
|
||||
2. Read / write ``data/state.json``. State is keyed by suite name so multiple
|
||||
suites can be set up in parallel and torn down independently.
|
||||
|
||||
The pinned ``search_space_id`` lives in ``state.json`` (not env) so re-runs
|
||||
are idempotent without forcing the operator to remember an integer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Resolve once at import time. ``find_dotenv`` walks up; an explicit ``.env``
|
||||
# at the package root or in CWD wins. Silent-no-op if neither exists.
|
||||
load_dotenv()
|
||||
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
"""Resolves to ``surfsense_evals/`` (the package root, not ``src/``)."""
|
||||
|
||||
|
||||
def _project_root() -> Path:
|
||||
"""Return the ``surfsense_evals/`` project root.
|
||||
|
||||
Computed from this file's path: ``src/surfsense_evals/core/config.py`` →
|
||||
walk up four levels. Kept as a function so tests can monkeypatch.
|
||||
"""
|
||||
|
||||
return _PROJECT_ROOT
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Config:
|
||||
"""Immutable runtime configuration."""
|
||||
|
||||
surfsense_api_base: str
|
||||
openrouter_api_key: str | None
|
||||
openrouter_base_url: str
|
||||
|
||||
# Credentials — exactly ONE mode must be supplied.
|
||||
surfsense_jwt: str | None
|
||||
surfsense_refresh_token: str | None
|
||||
surfsense_user_email: str | None
|
||||
surfsense_user_password: str | None
|
||||
|
||||
# Filesystem paths.
|
||||
data_dir: Path
|
||||
reports_dir: Path
|
||||
|
||||
@property
|
||||
def state_path(self) -> Path:
|
||||
return self.data_dir / "state.json"
|
||||
|
||||
def has_jwt_mode(self) -> bool:
|
||||
return bool(self.surfsense_jwt)
|
||||
|
||||
def has_local_mode(self) -> bool:
|
||||
return bool(self.surfsense_user_email and self.surfsense_user_password)
|
||||
|
||||
def credential_mode(self) -> str:
|
||||
"""Return ``"jwt"``, ``"local"``, or ``"none"`` (no credentials supplied)."""
|
||||
|
||||
if self.has_jwt_mode():
|
||||
return "jwt"
|
||||
if self.has_local_mode():
|
||||
return "local"
|
||||
return "none"
|
||||
|
||||
def suite_data_dir(self, suite: str) -> Path:
|
||||
return self.data_dir / suite
|
||||
|
||||
def suite_reports_dir(self, suite: str) -> Path:
|
||||
return self.reports_dir / suite
|
||||
|
||||
def suite_runs_dir(self, suite: str) -> Path:
|
||||
return self.suite_data_dir(suite) / "runs"
|
||||
|
||||
def suite_maps_dir(self, suite: str) -> Path:
|
||||
return self.suite_data_dir(suite) / "maps"
|
||||
|
||||
|
||||
def load_config() -> Config:
|
||||
"""Read the current process env into a ``Config``.
|
||||
|
||||
No validation is performed here; callers (e.g. ``auth.acquire_token``,
|
||||
``cli`` subcommands) decide which fields they require. This keeps
|
||||
``models list`` and ``suites list`` runnable without OpenRouter creds.
|
||||
"""
|
||||
|
||||
project_root = _project_root()
|
||||
data_dir = Path(os.environ.get("EVAL_DATA_DIR") or (project_root / "data")).resolve()
|
||||
reports_dir = Path(os.environ.get("EVAL_REPORTS_DIR") or (project_root / "reports")).resolve()
|
||||
return Config(
|
||||
surfsense_api_base=os.environ.get("SURFSENSE_API_BASE", "http://localhost:8000").rstrip("/"),
|
||||
openrouter_api_key=os.environ.get("OPENROUTER_API_KEY") or None,
|
||||
openrouter_base_url=os.environ.get(
|
||||
"OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"
|
||||
).rstrip("/"),
|
||||
surfsense_jwt=os.environ.get("SURFSENSE_JWT") or None,
|
||||
surfsense_refresh_token=os.environ.get("SURFSENSE_REFRESH_TOKEN") or None,
|
||||
surfsense_user_email=os.environ.get("SURFSENSE_USER_EMAIL") or None,
|
||||
surfsense_user_password=os.environ.get("SURFSENSE_USER_PASSWORD") or None,
|
||||
data_dir=data_dir,
|
||||
reports_dir=reports_dir,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# state.json — per-suite slots
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Scenario names — chosen at ``setup`` time, persisted in ``state.json``.
|
||||
#
|
||||
# * ``head-to-head`` (default, current behaviour): both arms answer with the
|
||||
# SAME slug pinned via ``--provider-model``. Vision LLM at ingest is
|
||||
# optional but recommended for image-bearing benchmarks.
|
||||
# * ``symmetric-cheap``: both arms answer with the SAME (cheap, text-only)
|
||||
# slug; SurfSense pre-extracted images at ingest with a vision LLM.
|
||||
# Measures whether vision-RAG ingestion lets a cheap downstream model
|
||||
# match a vision one. Native arm structurally loses on image questions —
|
||||
# that's the point, and the report labels it accordingly.
|
||||
# * ``cost-arbitrage``: native arm answers with an EXPENSIVE vision slug
|
||||
# (``--native-arm-model``), SurfSense answers with a CHEAP text-only slug
|
||||
# (``--provider-model``) over chunks the vision LLM already extracted at
|
||||
# ingest. Measures how close SurfSense gets to native at a fraction of
|
||||
# the per-query cost. The most compelling "shines" framing.
|
||||
SCENARIOS: tuple[str, ...] = ("head-to-head", "symmetric-cheap", "cost-arbitrage")
|
||||
DEFAULT_SCENARIO: str = "head-to-head"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SuiteState:
|
||||
"""Per-suite persisted state.
|
||||
|
||||
``provider_model`` is the slug pinned to the SearchSpace's
|
||||
``agent_llm`` — what answers SurfSense queries (and what the native
|
||||
arm uses too, unless ``native_arm_model`` is set for cost-arbitrage).
|
||||
|
||||
``vision_provider_model`` is the slug of the OpenRouter vision LLM
|
||||
config attached to the SearchSpace's ``vision_llm_config_id`` — what
|
||||
SurfSense uses to extract image content at ingest time when
|
||||
``use_vision_llm=True``. ``None`` means no vision config was attached
|
||||
at setup (legacy or text-only suite).
|
||||
"""
|
||||
|
||||
search_space_id: int
|
||||
agent_llm_id: int
|
||||
provider_model: str
|
||||
created_at: str
|
||||
ingestion_maps: dict[str, str] = field(default_factory=dict)
|
||||
scenario: str = DEFAULT_SCENARIO
|
||||
vision_llm_config_id: int | None = None
|
||||
vision_provider_model: str | None = None
|
||||
native_arm_model: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"search_space_id": self.search_space_id,
|
||||
"agent_llm_id": self.agent_llm_id,
|
||||
"provider_model": self.provider_model,
|
||||
"created_at": self.created_at,
|
||||
"ingestion_maps": dict(self.ingestion_maps),
|
||||
"scenario": self.scenario,
|
||||
"vision_llm_config_id": self.vision_llm_config_id,
|
||||
"vision_provider_model": self.vision_provider_model,
|
||||
"native_arm_model": self.native_arm_model,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: Mapping[str, Any]) -> SuiteState:
|
||||
# ``scenario`` / vision / native fields default for back-compat with
|
||||
# ``state.json`` written before scenarios shipped.
|
||||
scenario = str(payload.get("scenario") or DEFAULT_SCENARIO)
|
||||
if scenario not in SCENARIOS:
|
||||
scenario = DEFAULT_SCENARIO
|
||||
raw_vision_id = payload.get("vision_llm_config_id")
|
||||
return cls(
|
||||
search_space_id=int(payload["search_space_id"]),
|
||||
agent_llm_id=int(payload["agent_llm_id"]),
|
||||
provider_model=str(payload["provider_model"]),
|
||||
created_at=str(payload.get("created_at") or ""),
|
||||
ingestion_maps=dict(payload.get("ingestion_maps") or {}),
|
||||
scenario=scenario,
|
||||
vision_llm_config_id=int(raw_vision_id) if raw_vision_id is not None else None,
|
||||
vision_provider_model=(
|
||||
str(payload["vision_provider_model"])
|
||||
if payload.get("vision_provider_model")
|
||||
else None
|
||||
),
|
||||
native_arm_model=(
|
||||
str(payload["native_arm_model"])
|
||||
if payload.get("native_arm_model")
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def effective_native_arm_model(self) -> str:
|
||||
"""Slug the native arm should use; falls back to ``provider_model``."""
|
||||
|
||||
return self.native_arm_model or self.provider_model
|
||||
|
||||
|
||||
def _load_state(config: Config) -> dict[str, Any]:
|
||||
if not config.state_path.exists():
|
||||
return {"suites": {}}
|
||||
try:
|
||||
with config.state_path.open("r", encoding="utf-8") as fh:
|
||||
data = json.load(fh)
|
||||
except (OSError, json.JSONDecodeError) as exc:
|
||||
raise RuntimeError(
|
||||
f"Failed to read state file {config.state_path}: {exc!s}. "
|
||||
"Delete it if you want to start fresh."
|
||||
) from exc
|
||||
if not isinstance(data, dict) or "suites" not in data:
|
||||
return {"suites": {}}
|
||||
return data
|
||||
|
||||
|
||||
def _write_state(config: Config, payload: Mapping[str, Any]) -> None:
|
||||
config.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
tmp = config.state_path.with_suffix(".json.tmp")
|
||||
with tmp.open("w", encoding="utf-8") as fh:
|
||||
json.dump(dict(payload), fh, indent=2, sort_keys=True)
|
||||
fh.write("\n")
|
||||
tmp.replace(config.state_path)
|
||||
|
||||
|
||||
def get_suite_state(config: Config, suite: str) -> SuiteState | None:
|
||||
"""Return ``SuiteState`` for ``suite`` or ``None`` if not set up."""
|
||||
|
||||
state = _load_state(config)
|
||||
raw = (state.get("suites") or {}).get(suite)
|
||||
if not raw:
|
||||
return None
|
||||
return SuiteState.from_dict(raw)
|
||||
|
||||
|
||||
def set_suite_state(config: Config, suite: str, suite_state: SuiteState) -> None:
|
||||
"""Persist ``suite_state`` under the suite slot. Other suites are untouched."""
|
||||
|
||||
state = _load_state(config)
|
||||
suites = dict(state.get("suites") or {})
|
||||
suites[suite] = suite_state.to_dict()
|
||||
state["suites"] = suites
|
||||
_write_state(config, state)
|
||||
|
||||
|
||||
def clear_suite_state(config: Config, suite: str) -> bool:
|
||||
"""Remove the slot for ``suite``. Returns ``True`` if removal happened."""
|
||||
|
||||
state = _load_state(config)
|
||||
suites = dict(state.get("suites") or {})
|
||||
if suite not in suites:
|
||||
return False
|
||||
del suites[suite]
|
||||
state["suites"] = suites
|
||||
_write_state(config, state)
|
||||
return True
|
||||
|
||||
|
||||
def utc_iso_timestamp() -> str:
|
||||
"""Filesystem-safe UTC ISO timestamp, e.g. ``2026-05-11T20-30-00Z``."""
|
||||
|
||||
return datetime.now(UTC).strftime("%Y-%m-%dT%H-%M-%SZ")
|
||||
311
surfsense_evals/src/surfsense_evals/core/ingest_settings.py
Normal file
311
surfsense_evals/src/surfsense_evals/core/ingest_settings.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
"""Per-upload ingestion settings shared across every benchmark.
|
||||
|
||||
The SurfSense ``POST /api/v1/documents/fileupload`` endpoint exposes
|
||||
exactly three knobs (verified at
|
||||
``surfsense_backend/app/routes/documents_routes.py`` and
|
||||
``surfsense_backend/app/etl_pipeline/etl_document.py``):
|
||||
|
||||
* ``processing_mode`` — ``"basic"`` (default) | ``"premium"``
|
||||
* ``use_vision_llm`` — ``bool`` (run vision LLM during ingest to
|
||||
extract image content / captions / tables)
|
||||
* ``should_summarize`` — ``bool`` (generate document summary)
|
||||
|
||||
This module gives every benchmark a uniform way to:
|
||||
|
||||
1. Receive sensible per-benchmark defaults (text-only benchmarks
|
||||
default vision off; image-bearing benchmarks default vision on).
|
||||
2. Accept CLI overrides (``--use-vision-llm`` / ``--no-vision-llm``,
|
||||
``--processing-mode {basic,premium}``,
|
||||
``--should-summarize`` / ``--no-summarize``).
|
||||
3. Persist the *actual* settings used into the doc-map manifest and
|
||||
the run artifact so reports can show "vision=ON, mode=premium →
|
||||
65% accuracy" head-to-head with "vision=OFF, mode=basic → 52%".
|
||||
|
||||
A/B testing on the same corpus
|
||||
------------------------------
|
||||
|
||||
SurfSense dedupes uploads by ``(filename, search_space_id)`` — NOT by
|
||||
content hash and NOT by ingestion settings. Re-uploading the same
|
||||
filename to the same SearchSpace with a different ``use_vision_llm``
|
||||
flag will hit the duplicate branch and *not* re-process. To compare
|
||||
two settings combos head-to-head on the same corpus you must give
|
||||
each combo its own SearchSpace, which today means:
|
||||
|
||||
teardown --suite <s>
|
||||
setup --suite <s> ...
|
||||
ingest <s> <bench> --no-vision-llm # baseline run
|
||||
run <s> <bench>
|
||||
teardown --suite <s>
|
||||
setup --suite <s> ...
|
||||
ingest <s> <bench> --use-vision-llm # vision arm
|
||||
run <s> <bench>
|
||||
|
||||
The runs land in different timestamped subdirectories under
|
||||
``data/<suite>/runs/`` and ``report --suite <s>`` aggregates whichever
|
||||
manifest is currently latest per benchmark.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Keep the constant list of valid processing modes here so benchmarks
|
||||
# don't have to re-import from the backend (they don't have access to
|
||||
# the backend package anyway).
|
||||
PROCESSING_MODES: tuple[str, ...] = ("basic", "premium")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IngestSettings:
|
||||
"""Resolved per-upload knobs handed to ``DocumentsClient.upload``.
|
||||
|
||||
Use ``IngestSettings(...)`` directly to define benchmark defaults,
|
||||
or ``IngestSettings.merge(defaults, opts)`` to apply CLI overrides
|
||||
on top of those defaults.
|
||||
"""
|
||||
|
||||
use_vision_llm: bool = False
|
||||
processing_mode: str = "basic"
|
||||
should_summarize: bool = False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"use_vision_llm": self.use_vision_llm,
|
||||
"processing_mode": self.processing_mode,
|
||||
"should_summarize": self.should_summarize,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def merge(cls, defaults: IngestSettings, opts: Mapping[str, Any]) -> IngestSettings:
|
||||
"""Apply CLI overrides on top of ``defaults``.
|
||||
|
||||
``opts`` is the kwargs dict built by ``core.cli`` from the
|
||||
argparse namespace (see ``_cmd_ingest`` / ``_cmd_run``). Keys
|
||||
we look for: ``use_vision_llm`` (bool or None), ``processing_mode``
|
||||
(str or None), ``should_summarize`` (bool or None). Anything
|
||||
else is ignored so benchmarks can pass through their own opts.
|
||||
"""
|
||||
|
||||
return cls(
|
||||
use_vision_llm=_coerce_bool(opts.get("use_vision_llm"), defaults.use_vision_llm),
|
||||
processing_mode=_coerce_mode(opts.get("processing_mode"), defaults.processing_mode),
|
||||
should_summarize=_coerce_bool(opts.get("should_summarize"), defaults.should_summarize),
|
||||
)
|
||||
|
||||
def render_label(self) -> str:
|
||||
"""Human-readable single-line label for reports / log lines."""
|
||||
|
||||
return (
|
||||
f"vision={'on' if self.use_vision_llm else 'off'}, "
|
||||
f"mode={self.processing_mode}, "
|
||||
f"summarize={'on' if self.should_summarize else 'off'}"
|
||||
)
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool) -> bool:
|
||||
"""Argparse with ``BooleanOptionalAction`` yields True/False/None.
|
||||
|
||||
``None`` means the operator didn't pass the flag → fall back to
|
||||
the benchmark default.
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _coerce_mode(value: Any, default: str) -> str:
|
||||
if value is None or value == "":
|
||||
return default
|
||||
val = str(value).strip().lower()
|
||||
if val not in PROCESSING_MODES:
|
||||
raise ValueError(
|
||||
f"Invalid processing_mode {val!r}; must be one of {PROCESSING_MODES}"
|
||||
)
|
||||
return val
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Argparse helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _add_bool_pair(
|
||||
parser: argparse.ArgumentParser,
|
||||
*,
|
||||
dest: str,
|
||||
on_flag: str,
|
||||
off_flag: str,
|
||||
on_help: str,
|
||||
off_help: str,
|
||||
) -> None:
|
||||
"""Add a mutually exclusive ``--foo`` / ``--no-foo`` pair.
|
||||
|
||||
We don't use ``argparse.BooleanOptionalAction`` because it would
|
||||
auto-generate ``--no-use-vision-llm`` rather than the friendlier
|
||||
``--no-vision-llm`` that operators reach for. Default is ``None``
|
||||
so ``IngestSettings.merge`` can distinguish "silent" from
|
||||
"explicit false".
|
||||
"""
|
||||
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument(
|
||||
on_flag,
|
||||
dest=dest,
|
||||
action="store_true",
|
||||
default=None,
|
||||
help=on_help,
|
||||
)
|
||||
group.add_argument(
|
||||
off_flag,
|
||||
dest=dest,
|
||||
action="store_false",
|
||||
default=None,
|
||||
help=off_help,
|
||||
)
|
||||
|
||||
|
||||
def add_ingest_settings_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
*,
|
||||
defaults: IngestSettings,
|
||||
) -> None:
|
||||
"""Attach the three ingest-settings flag pairs to ``parser``.
|
||||
|
||||
Each bool exposes a mutually exclusive ``--foo`` / ``--no-foo``
|
||||
pair so an operator can flip either direction without restating
|
||||
every flag. Default is ``None`` so that "operator didn't pass the
|
||||
flag" is distinguishable from "operator explicitly passed false"
|
||||
— ``IngestSettings.merge`` then folds in the benchmark default
|
||||
only when the operator was silent.
|
||||
"""
|
||||
|
||||
settings_group = parser.add_argument_group(
|
||||
"ingest settings",
|
||||
f"Per-upload knobs (forwarded to /documents/fileupload). "
|
||||
f"Defaults for this benchmark: {defaults.render_label()}.",
|
||||
)
|
||||
_add_bool_pair(
|
||||
settings_group,
|
||||
dest="use_vision_llm",
|
||||
on_flag="--use-vision-llm",
|
||||
off_flag="--no-vision-llm",
|
||||
on_help=(
|
||||
"Run vision LLM during ingest to extract image content "
|
||||
f"(default for this benchmark: "
|
||||
f"{'on' if defaults.use_vision_llm else 'off'})."
|
||||
),
|
||||
off_help="Skip vision LLM during ingest (text-only ETL).",
|
||||
)
|
||||
settings_group.add_argument(
|
||||
"--processing-mode",
|
||||
dest="processing_mode",
|
||||
choices=PROCESSING_MODES,
|
||||
default=None,
|
||||
help=(
|
||||
"SurfSense ETL processing mode (premium uses a 10x page "
|
||||
f"multiplier and typically routes to a stronger ETL). "
|
||||
f"Default for this benchmark: {defaults.processing_mode!r}."
|
||||
),
|
||||
)
|
||||
_add_bool_pair(
|
||||
settings_group,
|
||||
dest="should_summarize",
|
||||
on_flag="--should-summarize",
|
||||
off_flag="--no-summarize",
|
||||
on_help=(
|
||||
"Have SurfSense generate a document summary at ingest "
|
||||
f"(default for this benchmark: "
|
||||
f"{'on' if defaults.should_summarize else 'off'})."
|
||||
),
|
||||
off_help="Skip per-document summary generation.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Doc-map manifest helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Every benchmark writes a doc-map JSONL under ``data/<suite>/maps/`` that
|
||||
# pairs source identifiers (case_id, snippet_id, doc_path, …) to the
|
||||
# SurfSense document_ids returned by the upload. To make the report
|
||||
# self-describing we also write a header line:
|
||||
#
|
||||
# {"__settings__": {"use_vision_llm": ..., "processing_mode": ..., ...}}
|
||||
#
|
||||
# These two helpers centralise that protocol so each benchmark only has to
|
||||
# call ``write_settings_header`` and ``read_settings_header``.
|
||||
|
||||
SETTINGS_HEADER_KEY = "__settings__"
|
||||
|
||||
|
||||
def settings_header_line(settings: IngestSettings) -> str:
|
||||
"""Return the JSON-serialised header line (no trailing newline)."""
|
||||
|
||||
return json.dumps({SETTINGS_HEADER_KEY: settings.to_dict()})
|
||||
|
||||
|
||||
def is_settings_header(row: Mapping[str, Any]) -> bool:
|
||||
return SETTINGS_HEADER_KEY in row
|
||||
|
||||
|
||||
def read_settings_header(map_path: Path) -> dict[str, Any]:
|
||||
"""Read the ``__settings__`` header out of a doc-map JSONL.
|
||||
|
||||
Returns ``{}`` on a missing file, an empty file, an unreadable
|
||||
file, or a file whose first non-blank line is not a settings
|
||||
header (e.g. a corpus ingested before this feature existed).
|
||||
Callers use this purely to surface settings in the report; it
|
||||
must never fail the run.
|
||||
"""
|
||||
|
||||
if not map_path.exists():
|
||||
return {}
|
||||
try:
|
||||
with map_path.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
if isinstance(row, dict) and SETTINGS_HEADER_KEY in row:
|
||||
return dict(row[SETTINGS_HEADER_KEY])
|
||||
return {}
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return {}
|
||||
return {}
|
||||
|
||||
|
||||
def format_ingest_settings_md(settings: Any) -> str:
|
||||
"""Render the resolved settings as a single Markdown bullet line."""
|
||||
|
||||
if not isinstance(settings, Mapping) or not settings:
|
||||
return "- SurfSense ingest settings: (not recorded — re-ingest to capture)"
|
||||
vision = "on" if settings.get("use_vision_llm") else "off"
|
||||
mode = settings.get("processing_mode") or "basic"
|
||||
summarize = "on" if settings.get("should_summarize") else "off"
|
||||
return (
|
||||
f"- SurfSense ingest settings: vision_llm=`{vision}`, "
|
||||
f"processing_mode=`{mode}`, summarize=`{summarize}`"
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PROCESSING_MODES",
|
||||
"SETTINGS_HEADER_KEY",
|
||||
"IngestSettings",
|
||||
"add_ingest_settings_args",
|
||||
"format_ingest_settings_md",
|
||||
"is_settings_header",
|
||||
"read_settings_header",
|
||||
"settings_header_line",
|
||||
]
|
||||
50
surfsense_evals/src/surfsense_evals/core/metrics/__init__.py
Normal file
50
surfsense_evals/src/surfsense_evals/core/metrics/__init__.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""Pure-function metric primitives. Lazy imports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .comparison import McnemarResult, bootstrap_delta_ci, mcnemar_test, paired_aggregate
|
||||
from .mc_accuracy import AccuracyResult, accuracy_with_wilson_ci, wilson_ci
|
||||
from .retrieval import RetrievalScores, mrr, ndcg_at_k, recall_at_k, score_run
|
||||
|
||||
__all__ = [
|
||||
"AccuracyResult",
|
||||
"McnemarResult",
|
||||
"RetrievalScores",
|
||||
"accuracy_with_wilson_ci",
|
||||
"bootstrap_delta_ci",
|
||||
"mcnemar_test",
|
||||
"mrr",
|
||||
"ndcg_at_k",
|
||||
"paired_aggregate",
|
||||
"recall_at_k",
|
||||
"score_run",
|
||||
"wilson_ci",
|
||||
]
|
||||
|
||||
|
||||
_MODULE_FOR = {
|
||||
"AccuracyResult": "mc_accuracy",
|
||||
"accuracy_with_wilson_ci": "mc_accuracy",
|
||||
"wilson_ci": "mc_accuracy",
|
||||
"RetrievalScores": "retrieval",
|
||||
"mrr": "retrieval",
|
||||
"ndcg_at_k": "retrieval",
|
||||
"recall_at_k": "retrieval",
|
||||
"score_run": "retrieval",
|
||||
"McnemarResult": "comparison",
|
||||
"bootstrap_delta_ci": "comparison",
|
||||
"mcnemar_test": "comparison",
|
||||
"paired_aggregate": "comparison",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in _MODULE_FOR:
|
||||
from importlib import import_module
|
||||
|
||||
mod = import_module(f".{_MODULE_FOR[name]}", __name__)
|
||||
return getattr(mod, name)
|
||||
raise AttributeError(f"module 'surfsense_evals.core.metrics' has no attribute {name!r}")
|
||||
258
surfsense_evals/src/surfsense_evals/core/metrics/comparison.py
Normal file
258
surfsense_evals/src/surfsense_evals/core/metrics/comparison.py
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
"""Paired comparison statistics for head-to-head benchmarks.
|
||||
|
||||
In every head-to-head benchmark (currently MedXpertQA-MM and
|
||||
MMLongBench-Doc) each question is answered by both arms (Native PDF
|
||||
and SurfSense). That makes per-question outcomes paired, so
|
||||
``McNemar's test`` on the discordant pairs is the right significance
|
||||
test for "are the two arms different?". We also expose a bootstrap
|
||||
delta CI for visualising effect size.
|
||||
|
||||
Aggregate cost / latency / token deltas are mean-based; the runner
|
||||
slices them by arm before passing them in.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import statistics
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class McnemarResult:
|
||||
"""Discordant pair counts + the test statistics."""
|
||||
|
||||
n_total: int
|
||||
b: int # native correct, surfsense wrong
|
||||
c: int # native wrong, surfsense correct
|
||||
statistic: float
|
||||
p_value: float
|
||||
method: str
|
||||
|
||||
def to_dict(self) -> dict[str, float | int | str]:
|
||||
return {
|
||||
"n_total": self.n_total,
|
||||
"b_native_correct_only": self.b,
|
||||
"c_surfsense_correct_only": self.c,
|
||||
"statistic": self.statistic,
|
||||
"p_value": self.p_value,
|
||||
"method": self.method,
|
||||
}
|
||||
|
||||
|
||||
def mcnemar_test(
|
||||
arm_a_correct: Sequence[bool],
|
||||
arm_b_correct: Sequence[bool],
|
||||
*,
|
||||
use_exact_below: int = 11,
|
||||
) -> McnemarResult:
|
||||
"""Paired McNemar's test on per-question correctness.
|
||||
|
||||
``arm_a_correct`` is treated as the reference arm (typically the
|
||||
"native" arm); ``arm_b_correct`` is the challenger (typically
|
||||
"surfsense"). The test statistic only depends on discordant pairs.
|
||||
|
||||
Default switch-over (``b + c < 11``): for very small discordant
|
||||
samples the exact binomial test is preferred; above that the
|
||||
continuity-corrected chi-square is well-behaved (Edwards 1948).
|
||||
Callers can raise ``use_exact_below`` if they prefer the more
|
||||
conservative ``b + c < 25`` rule.
|
||||
|
||||
No external statistical package is required: scipy is a heavy dep
|
||||
and we only need binomial CDFs / chi-square sf, both implementable
|
||||
in stdlib + numpy without surprises.
|
||||
"""
|
||||
|
||||
if len(arm_a_correct) != len(arm_b_correct):
|
||||
raise ValueError(
|
||||
f"Length mismatch: arm_a={len(arm_a_correct)}, arm_b={len(arm_b_correct)}"
|
||||
)
|
||||
n = len(arm_a_correct)
|
||||
b = sum(1 for a, c in zip(arm_a_correct, arm_b_correct) if a and not c)
|
||||
c = sum(1 for a, cc in zip(arm_a_correct, arm_b_correct) if (not a) and cc)
|
||||
discordant = b + c
|
||||
if discordant == 0:
|
||||
return McnemarResult(
|
||||
n_total=n, b=b, c=c, statistic=0.0, p_value=1.0, method="degenerate"
|
||||
)
|
||||
|
||||
if discordant < use_exact_below:
|
||||
# Exact binomial: under H0 each discordant pair is a Bernoulli(0.5).
|
||||
# p-value = 2 * P(X <= min(b,c) | n=discordant, p=0.5), capped at 1.
|
||||
k = min(b, c)
|
||||
cdf = sum(_binom_pmf(discordant, i) for i in range(k + 1))
|
||||
p_value = min(1.0, 2.0 * cdf)
|
||||
return McnemarResult(
|
||||
n_total=n, b=b, c=c, statistic=float(k), p_value=p_value, method="exact"
|
||||
)
|
||||
|
||||
# Chi-square with continuity correction (McNemar-Edwards).
|
||||
chi = ((abs(b - c) - 1) ** 2) / discordant
|
||||
p_value = _chi2_sf(chi, df=1)
|
||||
return McnemarResult(
|
||||
n_total=n, b=b, c=c, statistic=chi, p_value=p_value, method="chi2_cc"
|
||||
)
|
||||
|
||||
|
||||
def _binom_pmf(n: int, k: int) -> float:
|
||||
return math.comb(n, k) * (0.5 ** n)
|
||||
|
||||
|
||||
def _chi2_sf(x: float, *, df: int) -> float:
|
||||
"""Survival function (1 - CDF) of chi-square; df=1 closed form."""
|
||||
|
||||
if x <= 0:
|
||||
return 1.0
|
||||
if df == 1:
|
||||
# Chi^2(1) = N(0,1)^2; sf(x) = 2 * Phi_complement(sqrt(x))
|
||||
return math.erfc(math.sqrt(x / 2.0))
|
||||
# General fallback via regularized upper incomplete gamma.
|
||||
a = df / 2.0
|
||||
z = x / 2.0
|
||||
return _gammaincc(a, z)
|
||||
|
||||
|
||||
def _gammaincc(a: float, x: float, *, max_iter: int = 200, tol: float = 1e-12) -> float:
|
||||
"""Regularised upper incomplete gamma Q(a, x). Series + continued fraction."""
|
||||
|
||||
if x < 0 or a <= 0:
|
||||
return float("nan")
|
||||
if x == 0:
|
||||
return 1.0
|
||||
if x < a + 1.0:
|
||||
# Series for P(a, x); subtract from 1.
|
||||
p_series = _gammainc_series(a, x, max_iter=max_iter, tol=tol)
|
||||
return 1.0 - p_series
|
||||
return _gammaincc_cf(a, x, max_iter=max_iter, tol=tol)
|
||||
|
||||
|
||||
def _gammainc_series(a: float, x: float, *, max_iter: int, tol: float) -> float:
|
||||
term = 1.0 / a
|
||||
summation = term
|
||||
for n in range(1, max_iter):
|
||||
term *= x / (a + n)
|
||||
summation += term
|
||||
if abs(term) < abs(summation) * tol:
|
||||
break
|
||||
log_pre = -x + a * math.log(x) - math.lgamma(a)
|
||||
return summation * math.exp(log_pre)
|
||||
|
||||
|
||||
def _gammaincc_cf(a: float, x: float, *, max_iter: int, tol: float) -> float:
|
||||
b = x + 1.0 - a
|
||||
c_val = 1.0 / 1e-300
|
||||
d = 1.0 / b
|
||||
h = d
|
||||
for i in range(1, max_iter):
|
||||
an = -i * (i - a)
|
||||
b += 2.0
|
||||
d = an * d + b
|
||||
if abs(d) < 1e-300:
|
||||
d = 1e-300
|
||||
c_val = b + an / c_val
|
||||
if abs(c_val) < 1e-300:
|
||||
c_val = 1e-300
|
||||
d = 1.0 / d
|
||||
delta = d * c_val
|
||||
h *= delta
|
||||
if abs(delta - 1.0) < tol:
|
||||
break
|
||||
log_pre = -x + a * math.log(x) - math.lgamma(a)
|
||||
return h * math.exp(log_pre)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bootstrap delta CI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BootstrapDelta:
|
||||
delta: float
|
||||
ci_low: float
|
||||
ci_high: float
|
||||
n_resamples: int
|
||||
|
||||
def to_dict(self) -> dict[str, float | int]:
|
||||
return {
|
||||
"delta": self.delta,
|
||||
"ci_low": self.ci_low,
|
||||
"ci_high": self.ci_high,
|
||||
"n_resamples": self.n_resamples,
|
||||
}
|
||||
|
||||
|
||||
def bootstrap_delta_ci(
|
||||
arm_a_correct: Sequence[bool],
|
||||
arm_b_correct: Sequence[bool],
|
||||
*,
|
||||
n_resamples: int = 5000,
|
||||
level: float = 0.95,
|
||||
random_state: int | None = 0,
|
||||
) -> BootstrapDelta:
|
||||
"""Paired-sample bootstrap CI for ``mean(arm_b) - mean(arm_a)``.
|
||||
|
||||
Resamples *paired indices* with replacement so the dependency
|
||||
between arms is preserved.
|
||||
"""
|
||||
|
||||
if len(arm_a_correct) != len(arm_b_correct):
|
||||
raise ValueError("paired arms must have the same length")
|
||||
n = len(arm_a_correct)
|
||||
if n == 0:
|
||||
return BootstrapDelta(0.0, 0.0, 0.0, 0)
|
||||
a = np.asarray(arm_a_correct, dtype=np.int8)
|
||||
b = np.asarray(arm_b_correct, dtype=np.int8)
|
||||
delta = float(b.mean() - a.mean())
|
||||
|
||||
rng = np.random.default_rng(random_state)
|
||||
deltas = np.empty(n_resamples, dtype=np.float64)
|
||||
for i in range(n_resamples):
|
||||
idx = rng.integers(0, n, size=n)
|
||||
deltas[i] = b[idx].mean() - a[idx].mean()
|
||||
alpha = (1.0 - level) / 2.0
|
||||
ci_low, ci_high = float(np.quantile(deltas, alpha)), float(np.quantile(deltas, 1 - alpha))
|
||||
return BootstrapDelta(delta=delta, ci_low=ci_low, ci_high=ci_high, n_resamples=n_resamples)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Simple aggregate helpers (cost / latency / tokens)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Aggregate:
|
||||
mean: float
|
||||
median: float
|
||||
p95: float
|
||||
n: int
|
||||
|
||||
def to_dict(self) -> dict[str, float | int]:
|
||||
return {"mean": self.mean, "median": self.median, "p95": self.p95, "n": self.n}
|
||||
|
||||
|
||||
def paired_aggregate(values: Sequence[float]) -> Aggregate:
|
||||
"""Mean / median / p95 of a list of numbers (e.g. cost-per-question)."""
|
||||
|
||||
if not values:
|
||||
return Aggregate(0.0, 0.0, 0.0, 0)
|
||||
arr = np.asarray(values, dtype=np.float64)
|
||||
return Aggregate(
|
||||
mean=float(arr.mean()),
|
||||
median=float(statistics.median(values)),
|
||||
p95=float(np.quantile(arr, 0.95)),
|
||||
n=len(values),
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Aggregate",
|
||||
"BootstrapDelta",
|
||||
"McnemarResult",
|
||||
"bootstrap_delta_ci",
|
||||
"mcnemar_test",
|
||||
"paired_aggregate",
|
||||
]
|
||||
130
surfsense_evals/src/surfsense_evals/core/metrics/mc_accuracy.py
Normal file
130
surfsense_evals/src/surfsense_evals/core/metrics/mc_accuracy.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
"""Multiple-choice accuracy + Wilson 95% confidence intervals.
|
||||
|
||||
Wilson CI is preferred over normal-approximation because MIRAGE's
|
||||
per-task subsets can be small (PubMedQA* and BioASQ-Y/N have a few
|
||||
hundred questions each) and Wilson handles n→0 / p→{0,1} edges
|
||||
gracefully.
|
||||
|
||||
Reference for the closed form: Wilson (1927); identical to the
|
||||
``statsmodels.stats.proportion.proportion_confint(method='wilson')``
|
||||
output and what scikit-learn implements internally for its bounded
|
||||
estimators.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AccuracyResult:
|
||||
"""Per-task accuracy with Wilson CI."""
|
||||
|
||||
n_correct: int
|
||||
n_total: int
|
||||
accuracy: float
|
||||
ci_low: float
|
||||
ci_high: float
|
||||
|
||||
def to_dict(self) -> dict[str, float | int]:
|
||||
return {
|
||||
"n_correct": self.n_correct,
|
||||
"n_total": self.n_total,
|
||||
"accuracy": self.accuracy,
|
||||
"ci_low": self.ci_low,
|
||||
"ci_high": self.ci_high,
|
||||
}
|
||||
|
||||
|
||||
# Two-sided Wilson z values. 1.959964 ≈ z_{0.975}.
|
||||
_Z_FOR_LEVEL: dict[float, float] = {
|
||||
0.90: 1.6448536269514722,
|
||||
0.95: 1.959963984540054,
|
||||
0.99: 2.5758293035489004,
|
||||
}
|
||||
|
||||
|
||||
def wilson_ci(
|
||||
n_correct: int, n_total: int, *, level: float = 0.95
|
||||
) -> tuple[float, float]:
|
||||
"""Two-sided Wilson score confidence interval for a proportion.
|
||||
|
||||
Returns ``(low, high)``. ``n_total == 0`` returns ``(0.0, 1.0)`` —
|
||||
the maximally uncertain interval.
|
||||
"""
|
||||
|
||||
if n_total <= 0:
|
||||
return 0.0, 1.0
|
||||
if level not in _Z_FOR_LEVEL:
|
||||
raise ValueError(f"Unsupported confidence level {level!r}")
|
||||
z = _Z_FOR_LEVEL[level]
|
||||
p = n_correct / n_total
|
||||
n = n_total
|
||||
denom = 1.0 + (z * z) / n
|
||||
centre = (p + (z * z) / (2 * n)) / denom
|
||||
half = (z / denom) * math.sqrt((p * (1 - p) / n) + (z * z) / (4 * n * n))
|
||||
low = max(0.0, centre - half)
|
||||
high = min(1.0, centre + half)
|
||||
return low, high
|
||||
|
||||
|
||||
def accuracy_with_wilson_ci(
|
||||
n_correct: int, n_total: int, *, level: float = 0.95
|
||||
) -> AccuracyResult:
|
||||
if n_total < 0:
|
||||
raise ValueError(f"n_total must be >= 0, got {n_total}")
|
||||
if n_correct < 0 or n_correct > n_total:
|
||||
raise ValueError(
|
||||
f"n_correct must be in [0, n_total]; got n_correct={n_correct}, n_total={n_total}"
|
||||
)
|
||||
accuracy = (n_correct / n_total) if n_total > 0 else 0.0
|
||||
low, high = wilson_ci(n_correct, n_total, level=level)
|
||||
return AccuracyResult(
|
||||
n_correct=n_correct,
|
||||
n_total=n_total,
|
||||
accuracy=accuracy,
|
||||
ci_low=low,
|
||||
ci_high=high,
|
||||
)
|
||||
|
||||
|
||||
def per_task_accuracy(
|
||||
rows: Sequence[Mapping[str, object]],
|
||||
*,
|
||||
task_key: str = "task",
|
||||
correct_key: str = "is_correct",
|
||||
level: float = 0.95,
|
||||
) -> dict[str, AccuracyResult]:
|
||||
"""Group ``rows`` by ``task_key`` and compute per-task ``AccuracyResult``.
|
||||
|
||||
``rows[i][correct_key]`` must be truthy iff the answer was correct.
|
||||
"""
|
||||
|
||||
counts: dict[str, list[int]] = {}
|
||||
for row in rows:
|
||||
task = str(row.get(task_key, ""))
|
||||
bucket = counts.setdefault(task, [0, 0])
|
||||
bucket[1] += 1
|
||||
if row.get(correct_key):
|
||||
bucket[0] += 1
|
||||
return {
|
||||
task: accuracy_with_wilson_ci(c[0], c[1], level=level)
|
||||
for task, c in counts.items()
|
||||
}
|
||||
|
||||
|
||||
def macro_accuracy(per_task: Mapping[str, AccuracyResult]) -> float:
|
||||
if not per_task:
|
||||
return 0.0
|
||||
return sum(r.accuracy for r in per_task.values()) / len(per_task)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AccuracyResult",
|
||||
"accuracy_with_wilson_ci",
|
||||
"macro_accuracy",
|
||||
"per_task_accuracy",
|
||||
"wilson_ci",
|
||||
]
|
||||
132
surfsense_evals/src/surfsense_evals/core/metrics/retrieval.py
Normal file
132
surfsense_evals/src/surfsense_evals/core/metrics/retrieval.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""Retrieval metrics: Recall@k, MRR, nDCG@k.
|
||||
|
||||
Used by CUREv1's runner to score the SurfSense arm against the
|
||||
benchmark's qrels. ``corpus_id`` is the canonical CUREv1 passage id
|
||||
(string); the runner maps SurfSense ``chunk_id`` → ``document_id`` →
|
||||
``corpus_id`` before calling these.
|
||||
|
||||
Graded relevance (CUREv1 uses 0/1/2 grades) is honoured by ``ndcg_at_k``;
|
||||
``recall_at_k`` and ``mrr`` flatten anything > 0 to "relevant".
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RetrievalScores:
|
||||
"""Aggregated retrieval scores."""
|
||||
|
||||
recall_at_k: dict[int, float]
|
||||
mrr: float
|
||||
ndcg_at_10: float
|
||||
n_queries: int
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"recall_at_k": dict(self.recall_at_k),
|
||||
"mrr": self.mrr,
|
||||
"ndcg_at_10": self.ndcg_at_10,
|
||||
"n_queries": self.n_queries,
|
||||
}
|
||||
|
||||
|
||||
def recall_at_k(retrieved: Sequence[str], relevant: Iterable[str], k: int) -> float:
|
||||
"""Fraction of ``relevant`` documents found in ``retrieved[:k]``."""
|
||||
|
||||
if not relevant:
|
||||
return 0.0
|
||||
relevant_set = set(relevant)
|
||||
if not relevant_set:
|
||||
return 0.0
|
||||
top_k = list(retrieved)[:k]
|
||||
hits = sum(1 for doc in top_k if doc in relevant_set)
|
||||
return hits / len(relevant_set)
|
||||
|
||||
|
||||
def mrr(retrieved: Sequence[str], relevant: Iterable[str]) -> float:
|
||||
"""Reciprocal rank of the first relevant doc, 0 if none found."""
|
||||
|
||||
relevant_set = set(relevant)
|
||||
for rank, doc in enumerate(retrieved, start=1):
|
||||
if doc in relevant_set:
|
||||
return 1.0 / rank
|
||||
return 0.0
|
||||
|
||||
|
||||
def _dcg_at_k(grades: Sequence[float], k: int) -> float:
|
||||
s = 0.0
|
||||
for i, grade in enumerate(grades[:k], start=1):
|
||||
# Standard log-base-2 discount; gain = 2^grade - 1 for graded relevance.
|
||||
s += (2.0 ** grade - 1.0) / math.log2(i + 1)
|
||||
return s
|
||||
|
||||
|
||||
def ndcg_at_k(
|
||||
retrieved: Sequence[str],
|
||||
qrels: Mapping[str, float],
|
||||
k: int,
|
||||
) -> float:
|
||||
"""nDCG@k against graded ``qrels`` (``{doc_id: grade}``).
|
||||
|
||||
Unjudged documents in ``retrieved`` contribute zero gain. The
|
||||
ideal ordering is ``qrels`` sorted by grade descending.
|
||||
"""
|
||||
|
||||
if not qrels:
|
||||
return 0.0
|
||||
grades = [float(qrels.get(doc, 0.0)) for doc in retrieved]
|
||||
dcg = _dcg_at_k(grades, k)
|
||||
ideal = sorted(qrels.values(), reverse=True)
|
||||
idcg = _dcg_at_k([float(g) for g in ideal], k)
|
||||
if idcg == 0.0:
|
||||
return 0.0
|
||||
return dcg / idcg
|
||||
|
||||
|
||||
def score_run(
|
||||
*,
|
||||
per_query_retrieved: Mapping[str, Sequence[str]],
|
||||
per_query_qrels: Mapping[str, Mapping[str, float]],
|
||||
ks: Sequence[int] = (1, 5, 10, 32),
|
||||
ndcg_k: int = 10,
|
||||
) -> RetrievalScores:
|
||||
"""Aggregate Recall@k, MRR, nDCG@k across a run.
|
||||
|
||||
``per_query_retrieved`` maps ``query_id -> ordered list of doc ids``.
|
||||
``per_query_qrels`` maps ``query_id -> {doc_id: grade}`` (grade > 0
|
||||
is relevant).
|
||||
|
||||
Queries present in retrieved but not in qrels are skipped. Queries
|
||||
in qrels but missing from retrieved contribute zeros.
|
||||
"""
|
||||
|
||||
qids = set(per_query_qrels.keys()) & set(per_query_retrieved.keys())
|
||||
if not qids:
|
||||
return RetrievalScores(recall_at_k={k: 0.0 for k in ks}, mrr=0.0, ndcg_at_10=0.0, n_queries=0)
|
||||
|
||||
recall_totals = {k: 0.0 for k in ks}
|
||||
mrr_total = 0.0
|
||||
ndcg_total = 0.0
|
||||
for qid in qids:
|
||||
retrieved = list(per_query_retrieved[qid])
|
||||
qrels = per_query_qrels[qid]
|
||||
relevant_docs = [d for d, g in qrels.items() if g > 0]
|
||||
for k in ks:
|
||||
recall_totals[k] += recall_at_k(retrieved, relevant_docs, k)
|
||||
mrr_total += mrr(retrieved, relevant_docs)
|
||||
ndcg_total += ndcg_at_k(retrieved, qrels, ndcg_k)
|
||||
|
||||
n = len(qids)
|
||||
return RetrievalScores(
|
||||
recall_at_k={k: v / n for k, v in recall_totals.items()},
|
||||
mrr=mrr_total / n,
|
||||
ndcg_at_10=ndcg_total / n,
|
||||
n_queries=n,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["RetrievalScores", "mrr", "ndcg_at_k", "recall_at_k", "score_run"]
|
||||
21
surfsense_evals/src/surfsense_evals/core/parse/__init__.py
Normal file
21
surfsense_evals/src/surfsense_evals/core/parse/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
"""Parsers shared across suites: citations, MCQ envelopes, AI-SDK SSE."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .answer_letter import AnswerLetterResult, extract_answer_letter
|
||||
from .citations import CITATION_REGEX, CitationToken, ChunkCitation, UrlCitation, parse_citations
|
||||
from .freeform_answer import extract_freeform_answer
|
||||
from .sse import SseEvent, iter_sse_events
|
||||
|
||||
__all__ = [
|
||||
"CITATION_REGEX",
|
||||
"CitationToken",
|
||||
"ChunkCitation",
|
||||
"UrlCitation",
|
||||
"parse_citations",
|
||||
"AnswerLetterResult",
|
||||
"extract_answer_letter",
|
||||
"extract_freeform_answer",
|
||||
"SseEvent",
|
||||
"iter_sse_events",
|
||||
]
|
||||
122
surfsense_evals/src/surfsense_evals/core/parse/answer_letter.py
Normal file
122
surfsense_evals/src/surfsense_evals/core/parse/answer_letter.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Robust extractor for MCQ answer letters.
|
||||
|
||||
Handles three answer shapes seen in the wild:
|
||||
|
||||
1. **MedRAG envelope** — ``{"step_by_step_thinking": "...", "answer_choice": "A"}``
|
||||
embedded somewhere in the assistant message (often inside ```` ```json ```` /
|
||||
``` ``` ``` fences). The regex grabs the JSON object and reads the
|
||||
``answer_choice`` field.
|
||||
|
||||
2. **Final-line letter** — e.g. ``Answer: B`` or ``The correct answer is (C).``.
|
||||
Falls back to a permissive regex over the last few lines.
|
||||
|
||||
3. **Bare letter** — single uppercase letter at the end of the message.
|
||||
|
||||
The function returns the parsed letter (uppercased) plus a discriminator
|
||||
of which strategy fired so the runner / report can flag suspicious
|
||||
parses (typically zero-confidence parses indicate the model didn't
|
||||
follow the prompt).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
ParserStrategy = Literal["json_envelope", "answer_line", "bare_letter", "none"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AnswerLetterResult:
|
||||
letter: str | None
|
||||
strategy: ParserStrategy
|
||||
|
||||
@property
|
||||
def found(self) -> bool:
|
||||
return self.letter is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_JSON_BLOCK = re.compile(r"\{[^{}]*\"answer_choice\"\s*:\s*\"([A-Za-z])\"[^{}]*\}", re.DOTALL)
|
||||
_FENCED_JSON = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL | re.IGNORECASE)
|
||||
_ANSWER_LINE = re.compile(
|
||||
r"(?:final\s*answer|answer\s*choice|the\s+correct\s+answer\s+is|answer)\s*[:=\-]?\s*"
|
||||
r"\(?\s*([A-Za-z])\s*[\)\.]*\s*$",
|
||||
re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
_BARE_LETTER = re.compile(r"^\s*\(?\s*([A-Za-z])\s*[\)\.]*\s*$", re.MULTILINE)
|
||||
|
||||
|
||||
def _from_json_envelope(text: str) -> str | None:
|
||||
# Try fenced code blocks first (most likely to contain the JSON).
|
||||
for fence in _FENCED_JSON.finditer(text):
|
||||
try:
|
||||
obj = json.loads(fence.group(1))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
if isinstance(obj, dict):
|
||||
choice = obj.get("answer_choice")
|
||||
if isinstance(choice, str) and choice.strip():
|
||||
return choice.strip()[:1].upper()
|
||||
|
||||
# Fall back to a tolerant regex over the whole text (handles
|
||||
# responses that drop the fences).
|
||||
match = _JSON_BLOCK.search(text)
|
||||
if match:
|
||||
return match.group(1).upper()
|
||||
return None
|
||||
|
||||
|
||||
def _from_answer_line(text: str) -> str | None:
|
||||
# Walk lines bottom-up; the answer is almost always near the end.
|
||||
for match in reversed(list(_ANSWER_LINE.finditer(text))):
|
||||
letter = match.group(1).upper()
|
||||
if letter.isalpha():
|
||||
return letter
|
||||
return None
|
||||
|
||||
|
||||
def _from_bare_letter(text: str) -> str | None:
|
||||
# Inspect only the final non-empty lines (avoid grabbing in-prose
|
||||
# mentions of "A" or "I").
|
||||
lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
|
||||
for ln in reversed(lines[-3:]):
|
||||
match = _BARE_LETTER.match(ln)
|
||||
if match:
|
||||
return match.group(1).upper()
|
||||
return None
|
||||
|
||||
|
||||
def extract_answer_letter(text: str) -> AnswerLetterResult:
|
||||
"""Run strategies in order and return the first hit.
|
||||
|
||||
Order: JSON envelope → final-answer-line regex → bare-letter
|
||||
fallback. Empty / whitespace-only text returns
|
||||
``AnswerLetterResult(None, "none")``.
|
||||
"""
|
||||
|
||||
if not text or not text.strip():
|
||||
return AnswerLetterResult(None, "none")
|
||||
|
||||
letter = _from_json_envelope(text)
|
||||
if letter:
|
||||
return AnswerLetterResult(letter, "json_envelope")
|
||||
|
||||
letter = _from_answer_line(text)
|
||||
if letter:
|
||||
return AnswerLetterResult(letter, "answer_line")
|
||||
|
||||
letter = _from_bare_letter(text)
|
||||
if letter:
|
||||
return AnswerLetterResult(letter, "bare_letter")
|
||||
|
||||
return AnswerLetterResult(None, "none")
|
||||
|
||||
|
||||
__all__ = ["AnswerLetterResult", "ParserStrategy", "extract_answer_letter"]
|
||||
110
surfsense_evals/src/surfsense_evals/core/parse/citations.py
Normal file
110
surfsense_evals/src/surfsense_evals/core/parse/citations.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
"""Python port of the canonical citation parser.
|
||||
|
||||
Source of truth: ``surfsense_web/lib/citations/citation-parser.ts:20-21``.
|
||||
The pattern is byte-for-byte identical to the TS export ``CITATION_REGEX``
|
||||
so a SurfSense user reading the web client and a CUREv1 retrieval scorer
|
||||
running here see the same chunk_ids extracted from the same answer.
|
||||
|
||||
The TS reference also handles a ``urlcite{N}`` placeholder produced by
|
||||
``preprocessCitationMarkdown`` — that pre-processing step is web-only
|
||||
(GFM autolink workaround), so the harness sees raw ``[citation:URL]``
|
||||
tokens and ``parse_citations`` returns them as ``UrlCitation`` directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union
|
||||
|
||||
# Pattern preserves the TS source verbatim:
|
||||
# /[\[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*)\s*\u200B?[\]】]/g
|
||||
#
|
||||
# Notes:
|
||||
# * Matches both ASCII ``[]`` and Chinese fullwidth ``【】`` brackets.
|
||||
# * Allows an optional ZWSP (``\u200B``) just inside each bracket.
|
||||
# * ``citation:`` then EITHER a URL (anything not ``]``, ``】``, or ZWSP),
|
||||
# OR a ``urlcite\d+`` placeholder, OR one or more comma-separated
|
||||
# chunk ids (each optionally prefixed with ``doc-`` and optionally
|
||||
# negative).
|
||||
# * URL char class deliberately excludes the closing brackets so a
|
||||
# ``[citation:https://x.com]`` doesn't swallow the ``]``.
|
||||
# The ZWSP must be the actual code-point — the original TS source uses
|
||||
# the regex literal ``\u200B`` which the JS engine interprets as the
|
||||
# character. Python's ``re`` doesn't process the ``\u`` escape inside
|
||||
# the pattern source, so we splice the literal character in via an
|
||||
# f-string. This keeps our pattern functionally identical to the TS
|
||||
# reference and lets ``"\u200B" in CITATION_REGEX.pattern`` succeed.
|
||||
_ZWSP = "\u200B"
|
||||
CITATION_REGEX = re.compile(
|
||||
rf"[\[【]{_ZWSP}?citation:\s*("
|
||||
rf"https?://[^\]】{_ZWSP}]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*"
|
||||
rf")\s*{_ZWSP}?[\]】]"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChunkCitation:
|
||||
chunk_id: int
|
||||
is_docs_chunk: bool
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"kind": "chunk",
|
||||
"chunk_id": self.chunk_id,
|
||||
"is_docs_chunk": self.is_docs_chunk,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UrlCitation:
|
||||
url: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {"kind": "url", "url": self.url}
|
||||
|
||||
|
||||
CitationToken = Union[ChunkCitation, UrlCitation]
|
||||
|
||||
|
||||
def parse_citations(text: str, *, url_map: dict[str, str] | None = None) -> list[CitationToken]:
|
||||
"""Return the citation tokens found in ``text`` in document order.
|
||||
|
||||
``url_map`` is the optional ``urlciteN -> URL`` lookup that the web
|
||||
client builds in its preprocessing step. The harness ordinarily
|
||||
doesn't preprocess (we don't render the markdown, we score it), so
|
||||
the default empty map means ``urlciteN`` placeholders are dropped
|
||||
rather than mis-resolved to a missing URL.
|
||||
|
||||
Multi-id payloads like ``[citation:1, doc-2, -3]`` are flattened
|
||||
into separate ``ChunkCitation`` entries — same as the TS reference.
|
||||
"""
|
||||
|
||||
out: list[CitationToken] = []
|
||||
for match in CITATION_REGEX.finditer(text):
|
||||
captured = match.group(1)
|
||||
if captured.startswith("http://") or captured.startswith("https://"):
|
||||
out.append(UrlCitation(url=captured.strip()))
|
||||
continue
|
||||
if captured.startswith("urlcite"):
|
||||
if url_map and captured in url_map:
|
||||
out.append(UrlCitation(url=url_map[captured]))
|
||||
continue
|
||||
for raw_id in (s.strip() for s in captured.split(",")):
|
||||
is_docs_chunk = raw_id.startswith("doc-")
|
||||
number_part = raw_id[4:] if is_docs_chunk else raw_id
|
||||
try:
|
||||
chunk_id = int(number_part)
|
||||
except ValueError:
|
||||
continue
|
||||
out.append(ChunkCitation(chunk_id=chunk_id, is_docs_chunk=is_docs_chunk))
|
||||
return out
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CITATION_REGEX",
|
||||
"ChunkCitation",
|
||||
"UrlCitation",
|
||||
"CitationToken",
|
||||
"parse_citations",
|
||||
]
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
"""Extract free-form answers from open-ended LLM responses.
|
||||
|
||||
Used by benchmarks that don't have a fixed letter set (MMLongBench-Doc,
|
||||
DocVQA-style benchmarks, future legal/finance suites). The contract:
|
||||
|
||||
* Strip leading "Answer:" / "Final answer:" markers if present.
|
||||
* Drop fenced code blocks if the model wrapped its answer in one.
|
||||
* Trim leading/trailing whitespace.
|
||||
* Return the *last* meaningful chunk — models often think out loud
|
||||
before stating the answer.
|
||||
|
||||
If the message is empty or only contains a fence, return ``""``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
_ANSWER_PREFIX = re.compile(
|
||||
r"^\s*(?:final\s*answer|the\s+answer\s+is|answer)\s*[:=\-]\s*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
# Marker-only regex (no capture group) used to find every "Answer:"
|
||||
# token position. We then slice from the LAST marker's end to the
|
||||
# next newline ourselves — robust to multiple inline answers because
|
||||
# we never let the engine greedy-capture across markers.
|
||||
_ANSWER_MARKER = re.compile(
|
||||
r"(?:final\s*answer|the\s+answer\s+is|answer)\s*[:=\-]\s*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_FENCED_BLOCK = re.compile(r"```[a-zA-Z0-9]*\s*([\s\S]*?)\s*```")
|
||||
|
||||
|
||||
def extract_freeform_answer(text: str) -> str:
|
||||
"""Pull the model's final answer out of a possibly-verbose response."""
|
||||
|
||||
if not text or not text.strip():
|
||||
return ""
|
||||
|
||||
# 1. Find the last line that starts with an Answer: marker. If
|
||||
# nothing matches, walk back to the last non-empty line.
|
||||
lines = [ln.rstrip() for ln in text.strip().splitlines()]
|
||||
candidate = ""
|
||||
for ln in reversed(lines):
|
||||
if not ln.strip():
|
||||
continue
|
||||
if _ANSWER_PREFIX.search(ln):
|
||||
candidate = _ANSWER_PREFIX.sub("", ln, count=1).strip()
|
||||
break
|
||||
|
||||
if not candidate:
|
||||
# 2. Inline match: find every "Answer:" marker position and
|
||||
# slice from the LAST marker's end to the next newline. Robust
|
||||
# to "preamble.Answer: 42" one-liners and multiple inline
|
||||
# markers (we always pick the final, freshest one).
|
||||
marker_matches = list(_ANSWER_MARKER.finditer(text))
|
||||
if marker_matches:
|
||||
last = marker_matches[-1]
|
||||
tail = text[last.end():]
|
||||
nl = tail.find("\n")
|
||||
if nl >= 0:
|
||||
tail = tail[:nl]
|
||||
candidate = tail.strip()
|
||||
|
||||
if not candidate:
|
||||
# 3. No "Answer:" marker — try fenced blocks.
|
||||
fences = _FENCED_BLOCK.findall(text)
|
||||
if fences:
|
||||
candidate = fences[-1].strip()
|
||||
else:
|
||||
# Last non-empty line as a fallback.
|
||||
for ln in reversed(lines):
|
||||
if ln.strip():
|
||||
candidate = ln.strip()
|
||||
break
|
||||
|
||||
# 2. Strip wrapping quotes / parens / trailing punctuation that
|
||||
# confuse the grader without changing meaning.
|
||||
candidate = candidate.strip().strip("`").strip()
|
||||
if candidate.startswith(("\"", "'")) and candidate.endswith(("\"", "'")):
|
||||
candidate = candidate[1:-1].strip()
|
||||
return candidate
|
||||
|
||||
|
||||
__all__ = ["extract_freeform_answer"]
|
||||
72
surfsense_evals/src/surfsense_evals/core/parse/sse.py
Normal file
72
surfsense_evals/src/surfsense_evals/core/parse/sse.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""Minimal SSE consumer compatible with SurfSense's wire format.
|
||||
|
||||
SurfSense uses ``app/services/streaming/envelope/sse.py`` to frame events:
|
||||
|
||||
* ``data: <single-line-string>\\n\\n``
|
||||
* ``data: <json-string>\\n\\n`` (most events)
|
||||
* ``data: [DONE]\\n\\n`` (terminator)
|
||||
|
||||
There is no ``event:``, ``id:``, or ``retry:`` framing in production —
|
||||
``format_sse(payload)`` only emits the ``data:`` line. This implementation
|
||||
is therefore intentionally smaller than ``httpx-sse`` (which we still
|
||||
list as a dep so callers who want richer parsing can opt in): one event
|
||||
per ``data:`` line, separated by blank lines.
|
||||
|
||||
We accept any line iterator (an ``httpx.Response.aiter_lines`` adapter
|
||||
in production, a list in tests) so this is unit-testable without a
|
||||
network mock.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SseEvent:
|
||||
"""A parsed SSE event. Only the ``data`` field is populated.
|
||||
|
||||
Multi-line payloads (``data: a\\ndata: b``) are joined with ``\\n``
|
||||
per the SSE spec, even though SurfSense doesn't currently emit them.
|
||||
"""
|
||||
|
||||
data: str
|
||||
|
||||
|
||||
async def iter_sse_events(lines: AsyncIterator[str]) -> AsyncIterator[SseEvent]:
|
||||
"""Yield one ``SseEvent`` per blank-line-terminated frame.
|
||||
|
||||
Lines that are empty or whitespace flush the buffer. ``data:`` lines
|
||||
are accumulated into the buffer; everything else is ignored
|
||||
(matches the lenient browser EventSource behaviour).
|
||||
"""
|
||||
|
||||
buffer: list[str] = []
|
||||
async for raw in lines:
|
||||
if raw is None:
|
||||
continue
|
||||
line = raw.rstrip("\r")
|
||||
if line == "":
|
||||
if buffer:
|
||||
yield SseEvent(data="\n".join(buffer))
|
||||
buffer.clear()
|
||||
continue
|
||||
if line.startswith(":"):
|
||||
# comment / heartbeat
|
||||
continue
|
||||
if line.startswith("data:"):
|
||||
# spec: optional single space after the colon.
|
||||
payload = line[5:]
|
||||
if payload.startswith(" "):
|
||||
payload = payload[1:]
|
||||
buffer.append(payload)
|
||||
continue
|
||||
# Any other field (event:, id:, retry:) is currently unused.
|
||||
continue
|
||||
|
||||
if buffer:
|
||||
yield SseEvent(data="\n".join(buffer))
|
||||
|
||||
|
||||
__all__ = ["SseEvent", "iter_sse_events"]
|
||||
31
surfsense_evals/src/surfsense_evals/core/pdf/__init__.py
Normal file
31
surfsense_evals/src/surfsense_evals/core/pdf/__init__.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""Domain-agnostic PDF rendering helper. Lazy import."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .render import (
|
||||
PdfImage,
|
||||
render_pdf,
|
||||
render_pdf_with_images,
|
||||
render_text_files_to_pdf,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PdfImage",
|
||||
"render_pdf",
|
||||
"render_pdf_with_images",
|
||||
"render_text_files_to_pdf",
|
||||
]
|
||||
|
||||
|
||||
_LAZY = {"PdfImage", "render_pdf", "render_pdf_with_images", "render_text_files_to_pdf"}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in _LAZY:
|
||||
from . import render as _mod
|
||||
|
||||
return getattr(_mod, name)
|
||||
raise AttributeError(f"module 'surfsense_evals.core.pdf' has no attribute {name!r}")
|
||||
351
surfsense_evals/src/surfsense_evals/core/pdf/render.py
Normal file
351
surfsense_evals/src/surfsense_evals/core/pdf/render.py
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
"""Deterministic ``.txt`` / ``.md`` → single PDF via reportlab.
|
||||
|
||||
Used wherever a benchmark needs the same source bytes fed to both the
|
||||
native-PDF arm and the SurfSense ingestion arm. The head-to-head
|
||||
comparison is fair only if the *same* PDF is the input to both arms,
|
||||
which is why we go to lengths to make the rendering deterministic.
|
||||
|
||||
Determinism notes:
|
||||
|
||||
* We pin the PDF metadata to a fixed creation date and producer
|
||||
(``reportlab`` accepts neither directly, but ``Canvas.setAuthor`` and
|
||||
the absence of an ``info`` mutator means the bytes only differ by
|
||||
``CreationDate`` / ``ModDate``). We post-process the PDF to scrub
|
||||
those if ``deterministic=True`` is passed.
|
||||
* Page size, font, margins, and tab handling are fixed in code so the
|
||||
same input yields the same byte output across machines.
|
||||
* PDF/A is overkill for our use; basic PDF 1.4 is what every model
|
||||
expects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import re
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from reportlab.lib.pagesizes import LETTER
|
||||
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.lib.utils import ImageReader
|
||||
from reportlab.platypus import (
|
||||
Image,
|
||||
KeepTogether,
|
||||
PageBreak,
|
||||
Paragraph,
|
||||
SimpleDocTemplate,
|
||||
Spacer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RenderedPdf:
|
||||
path: Path
|
||||
n_pages_estimate: int
|
||||
n_chars: int
|
||||
|
||||
|
||||
_PDF_DATE_KEY = re.compile(rb"/(?:CreationDate|ModDate)\s*\(D:[^)]*\)")
|
||||
# reportlab also writes a `/ID [<hex1><hex2>]` trailer entry that
|
||||
# embeds a per-run hash. Scrub it so two renders of the same input
|
||||
# produce the same bytes.
|
||||
_PDF_ID_ARRAY = re.compile(rb"/ID\s*\[\s*<[^>]*>\s*<[^>]*>\s*\]")
|
||||
|
||||
|
||||
def _scrub_dates(pdf_bytes: bytes) -> bytes:
|
||||
"""Remove ``CreationDate`` / ``ModDate`` / trailer ``/ID`` so the
|
||||
file is byte-deterministic across runs."""
|
||||
|
||||
pdf_bytes = _PDF_DATE_KEY.sub(b"/CreationDate (D:19700101000000Z)", pdf_bytes)
|
||||
pdf_bytes = _PDF_ID_ARRAY.sub(b"/ID [<00><00>]", pdf_bytes)
|
||||
return pdf_bytes
|
||||
|
||||
|
||||
_DEFAULT_STYLES = getSampleStyleSheet()
|
||||
|
||||
|
||||
def _build_body_style() -> ParagraphStyle:
|
||||
base = _DEFAULT_STYLES["BodyText"]
|
||||
style = ParagraphStyle(
|
||||
"EvalBody",
|
||||
parent=base,
|
||||
fontName="Helvetica",
|
||||
fontSize=10.5,
|
||||
leading=14,
|
||||
spaceAfter=6,
|
||||
spaceBefore=0,
|
||||
)
|
||||
return style
|
||||
|
||||
|
||||
def _build_heading_style() -> ParagraphStyle:
|
||||
base = _DEFAULT_STYLES["Heading2"]
|
||||
style = ParagraphStyle(
|
||||
"EvalHeading",
|
||||
parent=base,
|
||||
fontName="Helvetica-Bold",
|
||||
fontSize=14,
|
||||
leading=18,
|
||||
spaceAfter=10,
|
||||
spaceBefore=8,
|
||||
)
|
||||
return style
|
||||
|
||||
|
||||
def _normalise_paragraphs(text: str) -> list[str]:
|
||||
"""Split a text blob into paragraphs while preserving blank-line structure."""
|
||||
|
||||
blocks: list[list[str]] = [[]]
|
||||
for line in text.splitlines():
|
||||
stripped = line.rstrip()
|
||||
if stripped == "":
|
||||
if blocks[-1]:
|
||||
blocks.append([])
|
||||
continue
|
||||
blocks[-1].append(stripped)
|
||||
paragraphs: list[str] = []
|
||||
for block in blocks:
|
||||
if not block:
|
||||
continue
|
||||
# Join lines within a paragraph with spaces (text-from-PDF style).
|
||||
paragraphs.append(" ".join(block))
|
||||
return paragraphs
|
||||
|
||||
|
||||
def _escape_html(text: str) -> str:
|
||||
return (
|
||||
text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
)
|
||||
|
||||
|
||||
def render_pdf(
|
||||
*,
|
||||
title: str,
|
||||
sections: Sequence[tuple[str | None, str]],
|
||||
output_path: Path,
|
||||
deterministic: bool = True,
|
||||
) -> RenderedPdf:
|
||||
"""Render one PDF from a list of ``(section_heading, section_text)`` tuples.
|
||||
|
||||
``section_heading`` may be ``None`` for an unnamed section. Each
|
||||
section is followed by a page break so the model's PDF parser sees
|
||||
a clean structural boundary between source files.
|
||||
"""
|
||||
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
doc = SimpleDocTemplate(
|
||||
buffer,
|
||||
pagesize=LETTER,
|
||||
leftMargin=0.75 * inch,
|
||||
rightMargin=0.75 * inch,
|
||||
topMargin=0.75 * inch,
|
||||
bottomMargin=0.75 * inch,
|
||||
title=title,
|
||||
author="surfsense-evals",
|
||||
subject="Eval input",
|
||||
creator="surfsense-evals",
|
||||
)
|
||||
|
||||
body_style = _build_body_style()
|
||||
heading_style = _build_heading_style()
|
||||
title_style = ParagraphStyle(
|
||||
"EvalTitle",
|
||||
parent=_DEFAULT_STYLES["Title"],
|
||||
fontName="Helvetica-Bold",
|
||||
fontSize=18,
|
||||
leading=22,
|
||||
spaceAfter=14,
|
||||
)
|
||||
|
||||
flow: list = [Paragraph(_escape_html(title), title_style)]
|
||||
total_chars = 0
|
||||
for index, (heading, text) in enumerate(sections):
|
||||
if index > 0:
|
||||
flow.append(PageBreak())
|
||||
if heading:
|
||||
flow.append(Paragraph(_escape_html(heading), heading_style))
|
||||
for paragraph in _normalise_paragraphs(text):
|
||||
total_chars += len(paragraph)
|
||||
flow.append(Paragraph(_escape_html(paragraph), body_style))
|
||||
flow.append(Spacer(1, 4))
|
||||
|
||||
doc.build(flow)
|
||||
pdf_bytes = buffer.getvalue()
|
||||
if deterministic:
|
||||
pdf_bytes = _scrub_dates(pdf_bytes)
|
||||
output_path.write_bytes(pdf_bytes)
|
||||
|
||||
# Conservative page estimate: ~3000 chars per LETTER page at 10.5pt.
|
||||
n_pages = max(1, total_chars // 3000 + len(sections))
|
||||
return RenderedPdf(path=output_path, n_pages_estimate=n_pages, n_chars=total_chars)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PdfImage:
|
||||
"""One image to embed inside a section.
|
||||
|
||||
``caption`` is rendered below the image (italic). ``max_width_in``
|
||||
caps the rendered width in inches; height auto-scales to preserve
|
||||
aspect ratio (read with PIL).
|
||||
"""
|
||||
|
||||
path: Path
|
||||
caption: str = ""
|
||||
max_width_in: float = 5.5 # default leaves margin for LETTER 8.5"
|
||||
|
||||
|
||||
def _make_image_flowable(image: PdfImage) -> Image:
|
||||
"""Build a reportlab Image flowable scaled to fit page width."""
|
||||
|
||||
reader = ImageReader(str(image.path))
|
||||
iw, ih = reader.getSize()
|
||||
if iw <= 0 or ih <= 0:
|
||||
raise ValueError(f"Invalid image dimensions for {image.path}: {iw}x{ih}")
|
||||
target_w = image.max_width_in * inch
|
||||
target_h = target_w * (ih / iw)
|
||||
# Cap height too — some medical images are extreme portrait.
|
||||
max_h = 7.0 * inch
|
||||
if target_h > max_h:
|
||||
target_h = max_h
|
||||
target_w = target_h * (iw / ih)
|
||||
return Image(str(image.path), width=target_w, height=target_h)
|
||||
|
||||
|
||||
def render_pdf_with_images(
|
||||
*,
|
||||
title: str,
|
||||
sections: Sequence[tuple[str | None, str, Sequence[PdfImage] | None]],
|
||||
output_path: Path,
|
||||
deterministic: bool = True,
|
||||
page_break_between_sections: bool = False,
|
||||
) -> RenderedPdf:
|
||||
"""Render a PDF that mixes text and embedded images.
|
||||
|
||||
Each section is ``(heading, body_text, images)``. Images render
|
||||
inline after the body text, each followed by an italic caption.
|
||||
Set ``page_break_between_sections=True`` if you want explicit
|
||||
structural boundaries (mostly useful for multi-case PDFs); the
|
||||
default keeps everything on one page when possible (so a single
|
||||
MedXpertQA case is one PDF page with case + images + options).
|
||||
"""
|
||||
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
doc = SimpleDocTemplate(
|
||||
buffer,
|
||||
pagesize=LETTER,
|
||||
leftMargin=0.75 * inch,
|
||||
rightMargin=0.75 * inch,
|
||||
topMargin=0.75 * inch,
|
||||
bottomMargin=0.75 * inch,
|
||||
title=title,
|
||||
author="surfsense-evals",
|
||||
subject="Eval input",
|
||||
creator="surfsense-evals",
|
||||
)
|
||||
|
||||
body_style = _build_body_style()
|
||||
heading_style = _build_heading_style()
|
||||
caption_style = ParagraphStyle(
|
||||
"EvalCaption",
|
||||
parent=body_style,
|
||||
fontSize=9,
|
||||
leading=11,
|
||||
textColor="#444",
|
||||
spaceBefore=2,
|
||||
spaceAfter=10,
|
||||
)
|
||||
title_style = ParagraphStyle(
|
||||
"EvalTitle",
|
||||
parent=_DEFAULT_STYLES["Title"],
|
||||
fontName="Helvetica-Bold",
|
||||
fontSize=18,
|
||||
leading=22,
|
||||
spaceAfter=14,
|
||||
)
|
||||
|
||||
flow: list = [Paragraph(_escape_html(title), title_style)]
|
||||
total_chars = 0
|
||||
for index, (heading, text, images) in enumerate(sections):
|
||||
if index > 0 and page_break_between_sections:
|
||||
flow.append(PageBreak())
|
||||
if heading:
|
||||
flow.append(Paragraph(_escape_html(heading), heading_style))
|
||||
for paragraph in _normalise_paragraphs(text):
|
||||
total_chars += len(paragraph)
|
||||
flow.append(Paragraph(_escape_html(paragraph), body_style))
|
||||
flow.append(Spacer(1, 4))
|
||||
for image in images or []:
|
||||
try:
|
||||
img_flow = _make_image_flowable(image)
|
||||
except Exception: # noqa: BLE001 — bad image shouldn't kill PDF
|
||||
continue
|
||||
grouped = [img_flow]
|
||||
if image.caption:
|
||||
grouped.append(Paragraph(_escape_html(image.caption), caption_style))
|
||||
else:
|
||||
grouped.append(Spacer(1, 8))
|
||||
flow.append(KeepTogether(grouped))
|
||||
|
||||
doc.build(flow)
|
||||
pdf_bytes = buffer.getvalue()
|
||||
if deterministic:
|
||||
pdf_bytes = _scrub_dates(pdf_bytes)
|
||||
output_path.write_bytes(pdf_bytes)
|
||||
|
||||
n_pages = max(1, total_chars // 3000 + len(sections))
|
||||
return RenderedPdf(path=output_path, n_pages_estimate=n_pages, n_chars=total_chars)
|
||||
|
||||
|
||||
def render_text_files_to_pdf(
|
||||
*,
|
||||
title: str,
|
||||
files: Iterable[Path],
|
||||
output_path: Path,
|
||||
deterministic: bool = True,
|
||||
) -> RenderedPdf:
|
||||
"""Convenience wrapper: read a list of text files, render to one PDF.
|
||||
|
||||
The heading of each section is the file's name (no extension), so
|
||||
e.g. ``admission_note.txt`` becomes a section header ``admission_note``
|
||||
in the rendered PDF. Useful for any text-only benchmark that ships
|
||||
a corpus as separate ``.txt`` / ``.md`` shards per logical document.
|
||||
"""
|
||||
|
||||
sections: list[tuple[str | None, str]] = []
|
||||
for path in files:
|
||||
path = Path(path)
|
||||
text = path.read_text(encoding="utf-8")
|
||||
sections.append((path.stem, text))
|
||||
return render_pdf(
|
||||
title=title,
|
||||
sections=sections,
|
||||
output_path=output_path,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
|
||||
|
||||
# Tiny self-check — handy when debugging.
|
||||
def _self_test() -> None: # pragma: no cover
|
||||
out = Path("./_render_self_test.pdf")
|
||||
sections = [
|
||||
("intro", "Hello world.\n\nThis is a test."),
|
||||
("body", "Line one.\nLine two."),
|
||||
]
|
||||
rendered = render_pdf(title="Self test", sections=sections, output_path=out)
|
||||
print(f"wrote {rendered.path} ({rendered.n_chars} chars)")
|
||||
|
||||
|
||||
# Importing ``datetime`` keeps the timezone helper handy if a future
|
||||
# benchmark wants to embed a real timestamp without losing determinism.
|
||||
_NOW_FROZEN = datetime(2026, 5, 11, tzinfo=UTC)
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
"""External LLM providers (used by the native arm).
|
||||
|
||||
Lazy imports so the SurfSense-only path doesn't transitively load the
|
||||
OpenRouter client until something actually constructs ``OpenRouterPdfProvider``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .openrouter_pdf import OpenRouterPdfProvider, OpenRouterResponse
|
||||
|
||||
__all__ = ["OpenRouterPdfProvider", "OpenRouterResponse"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in {"OpenRouterPdfProvider", "OpenRouterResponse"}:
|
||||
from . import openrouter_pdf as _mod
|
||||
|
||||
return getattr(_mod, name)
|
||||
raise AttributeError(f"module 'surfsense_evals.core.providers' has no attribute {name!r}")
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
"""Bare OpenRouter ``chat/completions`` provider — no PDF, no plugins.
|
||||
|
||||
Used by ``BareLlmArm`` to measure "what does the model answer with
|
||||
zero retrieval context?". Same wire shape as ``OpenRouterPdfProvider``
|
||||
minus the file-parser plugin and the ``file`` content part:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "openai/gpt-5.4-mini",
|
||||
"messages": [
|
||||
{"role": "system", "content": "<optional>"},
|
||||
{"role": "user", "content": "<prompt>"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The response shape is identical to the PDF provider's, so we re-use
|
||||
``_parse_chat_completion`` from ``openrouter_pdf`` and only specialise
|
||||
the request builder. That keeps cost-extraction, token-counting, and
|
||||
content-array handling in one place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .openrouter_pdf import (
|
||||
OpenRouterResponse,
|
||||
_DEFAULT_HEADERS,
|
||||
_parse_chat_completion,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenRouterChatProvider:
|
||||
"""Stateless bare-chat client. No PDF, no file-parser plugin."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
model: str,
|
||||
timeout_s: float = 600.0,
|
||||
) -> None:
|
||||
if not api_key:
|
||||
raise ValueError("OPENROUTER_API_KEY is required for the bare-LLM arm.")
|
||||
self._api_key = api_key
|
||||
self._base = base_url.rstrip("/")
|
||||
self._model = model
|
||||
self._timeout = httpx.Timeout(timeout_s, connect=15.0)
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._model
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
system_prompt: str | None,
|
||||
max_tokens: int | None,
|
||||
) -> dict[str, Any]:
|
||||
messages: list[dict[str, Any]] = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
body: dict[str, Any] = {"model": self._model, "messages": messages}
|
||||
if max_tokens:
|
||||
body["max_tokens"] = max_tokens
|
||||
return body
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
system_prompt: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
http: httpx.AsyncClient | None = None,
|
||||
) -> OpenRouterResponse:
|
||||
"""Single chat completion. Errors are raised verbatim — caller decides retries."""
|
||||
|
||||
payload = self._build_payload(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
**_DEFAULT_HEADERS,
|
||||
}
|
||||
url = f"{self._base}/chat/completions"
|
||||
started = time.monotonic()
|
||||
if http is not None:
|
||||
response = await http.post(url, json=payload, headers=headers, timeout=self._timeout)
|
||||
else:
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
response = await client.post(
|
||||
url, json=payload, headers=headers, timeout=self._timeout
|
||||
)
|
||||
latency_ms = int((time.monotonic() - started) * 1000)
|
||||
if response.status_code >= 400:
|
||||
raise httpx.HTTPStatusError(
|
||||
f"OpenRouter HTTP {response.status_code}: {response.text[:300]}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
return _parse_chat_completion(response.json(), latency_ms=latency_ms)
|
||||
|
||||
|
||||
__all__ = ["OpenRouterChatProvider"]
|
||||
|
|
@ -0,0 +1,231 @@
|
|||
"""Native-PDF arm provider: OpenRouter ``chat/completions`` with PDF input.
|
||||
|
||||
Per `<https://openrouter.ai/docs/features/multimodal/pdfs>`__ the wire
|
||||
shape is OpenAI-compatible with one PDF-specific extra:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "anthropic/claude-sonnet-4.5",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "file", "file": {"filename": "case.pdf",
|
||||
"file_data": "data:application/pdf;base64,<b64>"}},
|
||||
{"type": "text", "text": "<prompt>"}
|
||||
]
|
||||
}],
|
||||
"plugins": [{"id": "file-parser", "pdf": {"engine": "native"}}]
|
||||
}
|
||||
```
|
||||
|
||||
``engine: "native"`` is the only engine that doesn't pre-OCR the
|
||||
PDF — it forwards raw bytes to PDF-native models (Claude, Gemini),
|
||||
matching what a human user does when "dropping the PDF into Claude".
|
||||
``mistral-ocr`` and ``cloudflare-ai`` are exposed as enum options for
|
||||
non-native models.
|
||||
|
||||
Headers ``HTTP-Referer`` and ``X-Title`` make spend show up cleanly on
|
||||
the OpenRouter dashboard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PdfEngine(str, Enum):
|
||||
NATIVE = "native"
|
||||
MISTRAL_OCR = "mistral-ocr"
|
||||
CLOUDFLARE_AI = "cloudflare-ai"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenRouterResponse:
|
||||
"""Subset of the OpenRouter response we care about for scoring."""
|
||||
|
||||
text: str
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
total_tokens: int
|
||||
cost_micros: int
|
||||
latency_ms: int
|
||||
finish_reason: str | None
|
||||
raw: dict[str, Any]
|
||||
|
||||
|
||||
_DEFAULT_HEADERS = {
|
||||
"HTTP-Referer": "https://github.com/MODSetter/SurfSense",
|
||||
"X-Title": "SurfSense-evals",
|
||||
}
|
||||
|
||||
|
||||
class OpenRouterPdfProvider:
|
||||
"""Thin httpx-based client. Stateless; safe to reuse per arm instance."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
model: str,
|
||||
engine: PdfEngine = PdfEngine.NATIVE,
|
||||
timeout_s: float = 600.0,
|
||||
) -> None:
|
||||
if not api_key:
|
||||
raise ValueError("OPENROUTER_API_KEY is required for the native arm.")
|
||||
self._api_key = api_key
|
||||
self._base = base_url.rstrip("/")
|
||||
self._model = model
|
||||
self._engine = engine
|
||||
self._timeout = httpx.Timeout(timeout_s, connect=15.0)
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def engine(self) -> PdfEngine:
|
||||
return self._engine
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
pdf_path: Path,
|
||||
max_tokens: int | None,
|
||||
extra_messages: list[dict[str, Any]] | None,
|
||||
) -> dict[str, Any]:
|
||||
b64 = base64.b64encode(pdf_path.read_bytes()).decode("ascii")
|
||||
user_content: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "file",
|
||||
"file": {
|
||||
"filename": pdf_path.name,
|
||||
"file_data": f"data:application/pdf;base64,{b64}",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
]
|
||||
messages: list[dict[str, Any]] = list(extra_messages or [])
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
body: dict[str, Any] = {
|
||||
"model": self._model,
|
||||
"messages": messages,
|
||||
"plugins": [
|
||||
{"id": "file-parser", "pdf": {"engine": self._engine.value}}
|
||||
],
|
||||
}
|
||||
if max_tokens:
|
||||
body["max_tokens"] = max_tokens
|
||||
return body
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
pdf_path: Path,
|
||||
max_tokens: int | None = None,
|
||||
extra_messages: list[dict[str, Any]] | None = None,
|
||||
http: httpx.AsyncClient | None = None,
|
||||
) -> OpenRouterResponse:
|
||||
"""Single chat completion. Errors are raised verbatim — runner decides retries."""
|
||||
|
||||
payload = self._build_payload(
|
||||
prompt=prompt,
|
||||
pdf_path=pdf_path,
|
||||
max_tokens=max_tokens,
|
||||
extra_messages=extra_messages,
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
**_DEFAULT_HEADERS,
|
||||
}
|
||||
url = f"{self._base}/chat/completions"
|
||||
started = time.monotonic()
|
||||
if http is not None:
|
||||
response = await http.post(url, json=payload, headers=headers, timeout=self._timeout)
|
||||
else:
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
response = await client.post(
|
||||
url, json=payload, headers=headers, timeout=self._timeout
|
||||
)
|
||||
latency_ms = int((time.monotonic() - started) * 1000)
|
||||
if response.status_code >= 400:
|
||||
raise httpx.HTTPStatusError(
|
||||
f"OpenRouter HTTP {response.status_code}: {response.text[:300]}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
data = response.json()
|
||||
return _parse_chat_completion(data, latency_ms=latency_ms)
|
||||
|
||||
|
||||
def _parse_chat_completion(payload: dict[str, Any], *, latency_ms: int) -> OpenRouterResponse:
|
||||
"""Tolerant parser for OpenRouter / OpenAI chat-completions JSON.
|
||||
|
||||
OpenRouter passes through any provider-specific extras, but the
|
||||
canonical shape is ``choices[0].message.content`` (string OR array
|
||||
of content parts) and ``usage.prompt_tokens / completion_tokens / total_tokens``.
|
||||
Cost lives at the top level (``payload["usage"]["cost"]`` or
|
||||
``payload["x-or-cost"]``) depending on routing.
|
||||
"""
|
||||
|
||||
text = ""
|
||||
finish_reason: str | None = None
|
||||
choices = payload.get("choices") or []
|
||||
if choices:
|
||||
message = (choices[0] or {}).get("message") or {}
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
text = content
|
||||
elif isinstance(content, list):
|
||||
chunks: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") in {"text", "output_text"}:
|
||||
chunks.append(str(part.get("text", "")))
|
||||
text = "".join(chunks)
|
||||
finish_reason = (choices[0] or {}).get("finish_reason") or None
|
||||
|
||||
usage = payload.get("usage") or {}
|
||||
input_tokens = int(usage.get("prompt_tokens") or 0)
|
||||
output_tokens = int(usage.get("completion_tokens") or 0)
|
||||
total_tokens = int(usage.get("total_tokens") or (input_tokens + output_tokens))
|
||||
|
||||
# OpenRouter exposes cost in dollars on `usage.cost` or `cost`. We
|
||||
# convert to integer micros to avoid float-summing surprises across
|
||||
# 7,663 MIRAGE questions.
|
||||
raw_cost = usage.get("cost")
|
||||
if raw_cost is None:
|
||||
raw_cost = payload.get("cost")
|
||||
cost_micros = 0
|
||||
if raw_cost is not None:
|
||||
try:
|
||||
cost_micros = int(round(float(raw_cost) * 1_000_000))
|
||||
except (TypeError, ValueError):
|
||||
cost_micros = 0
|
||||
|
||||
return OpenRouterResponse(
|
||||
text=text,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost_micros=cost_micros,
|
||||
latency_ms=latency_ms,
|
||||
finish_reason=finish_reason,
|
||||
raw=payload,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["OpenRouterPdfProvider", "OpenRouterResponse", "PdfEngine"]
|
||||
265
surfsense_evals/src/surfsense_evals/core/registry.py
Normal file
265
surfsense_evals/src/surfsense_evals/core/registry.py
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
"""Suite + Benchmark protocols and the global registry.
|
||||
|
||||
The extensibility seam: ``core.cli`` walks ``surfsense_evals.suites`` on
|
||||
import, which auto-imports every benchmark subpackage, which calls
|
||||
``register(<benchmark>)`` at module bottom. The CLI then iterates the
|
||||
populated registry to build subcommand groups dynamically.
|
||||
|
||||
Adding a new domain = drop a folder under ``suites/<domain>/<bench>/``
|
||||
that ends in ``register(MyBenchmark())``. No edits anywhere in
|
||||
``core/`` are required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
import httpx
|
||||
|
||||
from .clients import DocumentsClient, NewChatClient, SearchSpaceClient
|
||||
from .config import Config, SuiteState
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run context — what every benchmark.ingest/run receives
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunContext:
|
||||
"""Per-invocation environment threaded into ``ingest`` and ``run``.
|
||||
|
||||
A benchmark uses this to read pinned suite state, build new HTTP
|
||||
clients on the shared ``http`` session, find the right data /
|
||||
reports paths, and discover the active OpenRouter model + key.
|
||||
|
||||
``http`` is the authenticated SurfSense client (auth event hook
|
||||
attached). It is **not** an OpenRouter client — providers create
|
||||
their own short-lived clients because OpenRouter doesn't share the
|
||||
SurfSense bearer.
|
||||
"""
|
||||
|
||||
suite: str
|
||||
benchmark: str
|
||||
config: Config
|
||||
suite_state: SuiteState
|
||||
http: httpx.AsyncClient
|
||||
|
||||
@property
|
||||
def search_space_id(self) -> int:
|
||||
return self.suite_state.search_space_id
|
||||
|
||||
@property
|
||||
def agent_llm_id(self) -> int:
|
||||
return self.suite_state.agent_llm_id
|
||||
|
||||
@property
|
||||
def provider_model(self) -> str:
|
||||
"""Slug used by the SurfSense agent (and the native arm by default).
|
||||
|
||||
For ``cost-arbitrage`` scenarios this is the *cheap, text-only*
|
||||
slug — SurfSense answers from the chunks the vision LLM already
|
||||
extracted at ingest. The native arm should use
|
||||
``native_arm_model`` instead in that scenario.
|
||||
"""
|
||||
|
||||
return self.suite_state.provider_model
|
||||
|
||||
@property
|
||||
def native_arm_model(self) -> str:
|
||||
"""Slug the native_pdf arm should use.
|
||||
|
||||
Defaults to ``provider_model`` (head-to-head / symmetric-cheap);
|
||||
for ``cost-arbitrage`` it returns the explicit
|
||||
``--native-arm-model`` so the native arm can fairly answer
|
||||
image-bearing questions.
|
||||
"""
|
||||
|
||||
return self.suite_state.effective_native_arm_model
|
||||
|
||||
@property
|
||||
def vision_provider_model(self) -> str | None:
|
||||
"""Slug of the OpenRouter vision LLM SurfSense used at ingest.
|
||||
|
||||
``None`` if no vision config was attached at setup (legacy or
|
||||
text-only suite). Used by runners purely to record what was
|
||||
actually used in ``RunArtifact.extra`` and to label reports.
|
||||
"""
|
||||
|
||||
return self.suite_state.vision_provider_model
|
||||
|
||||
@property
|
||||
def scenario(self) -> str:
|
||||
"""Scenario name pinned at setup time (see ``config.SCENARIOS``)."""
|
||||
|
||||
return self.suite_state.scenario
|
||||
|
||||
def search_space_client(self) -> SearchSpaceClient:
|
||||
return SearchSpaceClient(self.http, self.config.surfsense_api_base)
|
||||
|
||||
def documents_client(self) -> DocumentsClient:
|
||||
return DocumentsClient(self.http, self.config.surfsense_api_base)
|
||||
|
||||
def new_chat_client(self) -> NewChatClient:
|
||||
return NewChatClient(self.http, self.config.surfsense_api_base)
|
||||
|
||||
def maps_dir(self) -> Path:
|
||||
path = self.config.suite_maps_dir(self.suite)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
def runs_dir(self, *, run_timestamp: str) -> Path:
|
||||
path = self.config.suite_runs_dir(self.suite) / run_timestamp / self.benchmark
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
def benchmark_data_dir(self) -> Path:
|
||||
path = self.config.suite_data_dir(self.suite) / self.benchmark
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run artifact + report section
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunArtifact:
|
||||
"""Everything a runner persists for the report writer to consume.
|
||||
|
||||
``raw_path`` points at the JSONL of per-question ``ArmResult``
|
||||
rows. ``metrics`` is a free-form dict the benchmark fills in (e.g.
|
||||
``{"native": {...}, "surfsense": {...}, "delta": {...}}``).
|
||||
"""
|
||||
|
||||
suite: str
|
||||
benchmark: str
|
||||
run_timestamp: str
|
||||
raw_path: Path
|
||||
metrics: dict[str, Any] = field(default_factory=dict)
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReportSection:
|
||||
"""One benchmark's slice of the final summary."""
|
||||
|
||||
title: str
|
||||
headline: bool
|
||||
body_md: str
|
||||
body_json: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark protocol + registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Benchmark(Protocol):
|
||||
"""The contract every benchmark module ends with ``register(<x>)``."""
|
||||
|
||||
suite: str
|
||||
name: str
|
||||
headline: bool
|
||||
description: str
|
||||
|
||||
async def ingest(self, ctx: RunContext, **opts: Any) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def add_run_args(self, parser: argparse.ArgumentParser) -> None: # pragma: no cover - protocol
|
||||
"""Add benchmark-specific flags to ``run <suite> <benchmark>``."""
|
||||
|
||||
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry storage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_REGISTRY: dict[tuple[str, str], Benchmark] = {}
|
||||
|
||||
|
||||
def register(benchmark: Benchmark) -> None:
|
||||
"""Add ``benchmark`` to the registry. Last-wins on duplicate keys.
|
||||
|
||||
Duplicate registrations log a warning rather than raising so a
|
||||
benchmark module imported twice (once via auto-discovery, once via
|
||||
a test directly importing it) doesn't blow up the CLI.
|
||||
"""
|
||||
|
||||
key = (benchmark.suite, benchmark.name)
|
||||
if key in _REGISTRY:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"Benchmark %s/%s re-registered (overwriting prior)", *key
|
||||
)
|
||||
_REGISTRY[key] = benchmark
|
||||
|
||||
|
||||
def unregister(suite: str, name: str) -> None:
|
||||
"""Test helper: drop a single benchmark from the registry."""
|
||||
|
||||
_REGISTRY.pop((suite, name), None)
|
||||
|
||||
|
||||
def reset() -> None:
|
||||
"""Test helper: wipe the registry (use with monkeypatched discovery)."""
|
||||
|
||||
_REGISTRY.clear()
|
||||
|
||||
|
||||
def get(suite: str, name: str) -> Benchmark:
|
||||
try:
|
||||
return _REGISTRY[(suite, name)]
|
||||
except KeyError as exc:
|
||||
available = ", ".join(f"{s}/{n}" for s, n in sorted(_REGISTRY)) or "<none>"
|
||||
raise KeyError(
|
||||
f"Unknown benchmark '{suite}/{name}'. Registered: {available}"
|
||||
) from exc
|
||||
|
||||
|
||||
def list_suites() -> list[str]:
|
||||
return sorted({s for s, _ in _REGISTRY})
|
||||
|
||||
|
||||
def list_benchmarks(suite: str | None = None) -> list[Benchmark]:
|
||||
if suite is None:
|
||||
return [_REGISTRY[k] for k in sorted(_REGISTRY)]
|
||||
return [_REGISTRY[k] for k in sorted(_REGISTRY) if k[0] == suite]
|
||||
|
||||
|
||||
def snapshot() -> Mapping[tuple[str, str], Benchmark]:
|
||||
"""Read-only view for diagnostics (e.g. ``benchmarks list`` rendering)."""
|
||||
|
||||
return dict(_REGISTRY)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Arm",
|
||||
"Benchmark",
|
||||
"ReportSection",
|
||||
"RunArtifact",
|
||||
"RunContext",
|
||||
"get",
|
||||
"list_benchmarks",
|
||||
"list_suites",
|
||||
"register",
|
||||
"reset",
|
||||
"snapshot",
|
||||
"unregister",
|
||||
]
|
||||
|
||||
|
||||
# Re-export Arm from arms.base so suites can `from core.registry import Arm`.
|
||||
from .arms.base import Arm # noqa: E402, F401 (deliberate re-export at bottom)
|
||||
18
surfsense_evals/src/surfsense_evals/core/report/__init__.py
Normal file
18
surfsense_evals/src/surfsense_evals/core/report/__init__.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Report writer + section composition primitives. Lazy import."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .writer import write_report
|
||||
|
||||
__all__ = ["write_report"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name == "write_report":
|
||||
from .writer import write_report
|
||||
|
||||
return write_report
|
||||
raise AttributeError(f"module 'surfsense_evals.core.report' has no attribute {name!r}")
|
||||
89
surfsense_evals/src/surfsense_evals/core/report/writer.py
Normal file
89
surfsense_evals/src/surfsense_evals/core/report/writer.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""Report writer — composes per-benchmark sections into one summary.
|
||||
|
||||
Output:
|
||||
|
||||
* ``reports/<suite>/<run-timestamp>/summary.md`` — human-readable.
|
||||
Bullet lists only (no tables) per project's coding-standards.
|
||||
* ``reports/<suite>/<run-timestamp>/summary.json`` — same content as
|
||||
structured JSON for downstream tooling (CI dashboards, regressions).
|
||||
|
||||
Headline benchmarks come first in both outputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
|
||||
from ..config import Config
|
||||
from ..registry import ReportSection
|
||||
|
||||
|
||||
def write_report(
|
||||
*,
|
||||
config: Config,
|
||||
suite: str,
|
||||
sections: Iterable[ReportSection],
|
||||
run_timestamp: str,
|
||||
) -> Path:
|
||||
"""Write ``summary.md`` + ``summary.json``. Returns the path of the .md file."""
|
||||
|
||||
sections_list = list(sections)
|
||||
sections_list.sort(key=lambda s: (not s.headline, s.title.lower()))
|
||||
|
||||
out_dir = config.suite_reports_dir(suite) / run_timestamp
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
md_path = out_dir / "summary.md"
|
||||
json_path = out_dir / "summary.json"
|
||||
|
||||
md_lines: list[str] = [
|
||||
f"# SurfSense evals — suite `{suite}`",
|
||||
"",
|
||||
f"- Run timestamp: `{run_timestamp}`",
|
||||
f"- Sections: {len(sections_list)}",
|
||||
"",
|
||||
]
|
||||
headline = [s for s in sections_list if s.headline]
|
||||
secondary = [s for s in sections_list if not s.headline]
|
||||
if headline:
|
||||
md_lines.append("## Headline")
|
||||
md_lines.append("")
|
||||
for section in headline:
|
||||
md_lines.append(f"### {section.title}")
|
||||
md_lines.append("")
|
||||
md_lines.append(section.body_md.rstrip())
|
||||
md_lines.append("")
|
||||
if secondary:
|
||||
md_lines.append("## Secondary measurements")
|
||||
md_lines.append("")
|
||||
for section in secondary:
|
||||
md_lines.append(f"### {section.title}")
|
||||
md_lines.append("")
|
||||
md_lines.append(section.body_md.rstrip())
|
||||
md_lines.append("")
|
||||
|
||||
md_path.write_text("\n".join(md_lines).rstrip() + "\n", encoding="utf-8")
|
||||
|
||||
json_payload = {
|
||||
"suite": suite,
|
||||
"run_timestamp": run_timestamp,
|
||||
"sections": [
|
||||
{
|
||||
"title": s.title,
|
||||
"headline": s.headline,
|
||||
"body_md": s.body_md,
|
||||
"body_json": s.body_json,
|
||||
}
|
||||
for s in sections_list
|
||||
],
|
||||
}
|
||||
json_path.write_text(
|
||||
json.dumps(json_payload, indent=2, sort_keys=True) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return md_path
|
||||
|
||||
|
||||
__all__ = ["ReportSection", "write_report"]
|
||||
58
surfsense_evals/src/surfsense_evals/core/scenarios.py
Normal file
58
surfsense_evals/src/surfsense_evals/core/scenarios.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Shared scenario formatting helpers for head-to-head benchmark reports.
|
||||
|
||||
The scenario chosen at ``setup`` time (``head-to-head``, ``symmetric-cheap``,
|
||||
``cost-arbitrage``) materially changes how a head-to-head report should be
|
||||
read. This module produces the one-bullet summary every head-to-head
|
||||
runner stamps near the top of its ``report_section`` body so reviewers
|
||||
immediately see the framing — no need to dig into ``run_artifact.json``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
|
||||
def format_scenario_md(extra: Mapping[str, Any] | None) -> str:
|
||||
"""Render a scenario-aware bullet for a benchmark report.
|
||||
|
||||
Reads ``extra["scenario"]`` plus the runtime LLM slugs the runner
|
||||
recorded. Falls back to a sensible "head-to-head" line if the artifact
|
||||
pre-dates scenarios so old runs still render cleanly.
|
||||
"""
|
||||
|
||||
extra = dict(extra or {})
|
||||
scenario = str(extra.get("scenario") or "head-to-head")
|
||||
surf_slug = str(extra.get("provider_model") or "?")
|
||||
native_slug = str(extra.get("native_arm_model") or surf_slug)
|
||||
vision_slug = extra.get("vision_provider_model")
|
||||
|
||||
if scenario == "cost-arbitrage":
|
||||
body = (
|
||||
f"- Scenario: **cost-arbitrage** — native arm answers with "
|
||||
f"`{native_slug}` (vision); SurfSense answers with `{surf_slug}` "
|
||||
f"over chunks vision-extracted at ingest"
|
||||
f"{f' by `{vision_slug}`' if vision_slug else ''}. "
|
||||
"Measures how close SurfSense gets to native at a fraction of "
|
||||
"the per-query cost."
|
||||
)
|
||||
elif scenario == "symmetric-cheap":
|
||||
body = (
|
||||
f"- Scenario: **symmetric-cheap** — both arms answer with "
|
||||
f"`{surf_slug}`; SurfSense pre-extracted images at ingest"
|
||||
f"{f' via `{vision_slug}`' if vision_slug else ''}. "
|
||||
"Native arm structurally loses on image-bearing questions "
|
||||
"(text-only model can't see images) — that's the point."
|
||||
)
|
||||
else:
|
||||
body = (
|
||||
f"- Scenario: head-to-head — both arms answer with `{surf_slug}` "
|
||||
"via OpenRouter."
|
||||
)
|
||||
if vision_slug:
|
||||
body += f" SurfSense ingest VLM: `{vision_slug}`."
|
||||
|
||||
return body
|
||||
|
||||
|
||||
__all__ = ["format_scenario_md"]
|
||||
127
surfsense_evals/src/surfsense_evals/core/vision_llm.py
Normal file
127
surfsense_evals/src/surfsense_evals/core/vision_llm.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Vision LLM resolution + auto-pick logic for the harness's ``setup`` command.
|
||||
|
||||
Two responsibilities:
|
||||
|
||||
1. Resolve an explicit ``--vision-llm <slug>`` to a global OpenRouter
|
||||
vision LLM config id that ``set_llm_preferences(vision_llm_config_id=...)``
|
||||
can accept.
|
||||
2. Auto-pick the strongest registered vision config when the operator
|
||||
doesn't pass ``--vision-llm`` but the scenario / benchmark needs one.
|
||||
|
||||
The priority list mirrors the recommended slugs in the README so the
|
||||
auto-pick is deterministic and reviewable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .clients.search_space import VisionLlmConfigEntry
|
||||
|
||||
# Order matters — first match wins when auto-picking. Keep these in sync
|
||||
# with the "Recommended vision slugs" table in the README so the
|
||||
# auto-pick story is the same one users read about.
|
||||
RECOMMENDED_VISION_PRIORITY: tuple[str, ...] = (
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"anthropic/claude-opus-4.7",
|
||||
"openai/gpt-5",
|
||||
"google/gemini-2.5-pro",
|
||||
)
|
||||
|
||||
|
||||
class VisionConfigError(RuntimeError):
|
||||
"""Raised when no vision config can be resolved (explicit or auto)."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedVisionConfig:
|
||||
"""Result of ``resolve_vision_llm`` — what to attach + a label for logs."""
|
||||
|
||||
config_id: int
|
||||
provider_model: str
|
||||
selected_via: str # "explicit" | "auto-priority" | "auto-fallback"
|
||||
|
||||
|
||||
def _openrouter_only(entries: Iterable[VisionLlmConfigEntry]) -> list[VisionLlmConfigEntry]:
|
||||
return [e for e in entries if e.provider == "OPENROUTER" and not e.is_auto_mode]
|
||||
|
||||
|
||||
def resolve_vision_llm(
|
||||
candidates: list[VisionLlmConfigEntry],
|
||||
*,
|
||||
explicit_slug: str | None,
|
||||
) -> ResolvedVisionConfig:
|
||||
"""Resolve a vision LLM config id from a slug or by auto-picking.
|
||||
|
||||
* If ``explicit_slug`` is given: must match exactly one OpenRouter
|
||||
vision config's ``model_name``. Raises ``VisionConfigError`` with a
|
||||
friendly listing if zero / many match.
|
||||
* Otherwise: walk ``RECOMMENDED_VISION_PRIORITY`` in order and return
|
||||
the first registered one. If none of the recommended slugs are
|
||||
registered, fall back to the first OpenRouter vision config in the
|
||||
list (deterministic by listing order). Raises ``VisionConfigError``
|
||||
if zero are registered at all.
|
||||
"""
|
||||
|
||||
or_vision = _openrouter_only(candidates)
|
||||
|
||||
if explicit_slug is not None:
|
||||
matches = [e for e in or_vision if e.model_name == explicit_slug]
|
||||
if not matches:
|
||||
sample = ", ".join(e.model_name for e in or_vision[:8]) or "<none>"
|
||||
raise VisionConfigError(
|
||||
f"No OpenRouter vision config found for slug '{explicit_slug}'. "
|
||||
"Make sure `openrouter_integration.vision_enabled: true` in "
|
||||
"global_llm_config.yaml and that the Celery worker has finished "
|
||||
"its first refresh. "
|
||||
f"Available OpenRouter vision slugs (sample): {sample}."
|
||||
)
|
||||
if len(matches) > 1:
|
||||
listing = "\n".join(f" id={e.id} name={e.name!r}" for e in matches)
|
||||
raise VisionConfigError(
|
||||
f"Multiple OpenRouter vision configs match '{explicit_slug}':\n{listing}"
|
||||
)
|
||||
only = matches[0]
|
||||
return ResolvedVisionConfig(
|
||||
config_id=only.id,
|
||||
provider_model=only.model_name,
|
||||
selected_via="explicit",
|
||||
)
|
||||
|
||||
if not or_vision:
|
||||
raise VisionConfigError(
|
||||
"No OpenRouter vision LLM configs are registered with this "
|
||||
"SurfSense backend. Either pass `--no-vision-llm` to the ingest "
|
||||
"step (text-only ingestion), or enable "
|
||||
"`openrouter_integration.vision_enabled: true` in "
|
||||
"global_llm_config.yaml so the Celery worker syncs vision-capable "
|
||||
"OpenRouter models on next refresh."
|
||||
)
|
||||
|
||||
by_slug = {e.model_name: e for e in or_vision}
|
||||
for preferred in RECOMMENDED_VISION_PRIORITY:
|
||||
match = by_slug.get(preferred)
|
||||
if match is not None:
|
||||
return ResolvedVisionConfig(
|
||||
config_id=match.id,
|
||||
provider_model=match.model_name,
|
||||
selected_via="auto-priority",
|
||||
)
|
||||
|
||||
# Fallback: first registered OpenRouter vision config. Deterministic
|
||||
# because the backend returns them in a stable order.
|
||||
fallback = or_vision[0]
|
||||
return ResolvedVisionConfig(
|
||||
config_id=fallback.id,
|
||||
provider_model=fallback.model_name,
|
||||
selected_via="auto-fallback",
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RECOMMENDED_VISION_PRIORITY",
|
||||
"ResolvedVisionConfig",
|
||||
"VisionConfigError",
|
||||
"resolve_vision_llm",
|
||||
]
|
||||
66
surfsense_evals/src/surfsense_evals/suites/__init__.py
Normal file
66
surfsense_evals/src/surfsense_evals/suites/__init__.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""Suite registry auto-discovery.
|
||||
|
||||
Importing ``surfsense_evals.suites`` walks every subpackage one level deep
|
||||
(domain like ``medical``) AND its benchmark subpackages
|
||||
(``medical/medxpertqa``, ``medical/mirage``, ``medical/cure``). Each
|
||||
benchmark's ``__init__.py`` is expected to call
|
||||
``core.registry.register(<Benchmark>)`` at module bottom; merely importing
|
||||
the module is enough to populate the registry.
|
||||
|
||||
Adding a new domain is therefore: drop a folder under ``suites/`` with the
|
||||
right structure. No edits anywhere else.
|
||||
|
||||
Subpackages whose name starts with ``_`` are skipped — that's reserved for
|
||||
test fixtures (e.g. ``suites/_demo/``) so they don't accidentally show up
|
||||
in ``benchmarks list``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import pkgutil
|
||||
from typing import Iterable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _iter_subpackages(package) -> Iterable[str]:
|
||||
"""Yield fully-qualified subpackage names one level deep, skipping ``_*``."""
|
||||
|
||||
for module_info in pkgutil.iter_modules(package.__path__, prefix=f"{package.__name__}."):
|
||||
if not module_info.ispkg:
|
||||
continue
|
||||
leaf = module_info.name.rsplit(".", 1)[-1]
|
||||
if leaf.startswith("_"):
|
||||
continue
|
||||
yield module_info.name
|
||||
|
||||
|
||||
def discover_suites() -> list[str]:
|
||||
"""Import every domain + benchmark subpackage so registrations fire.
|
||||
|
||||
Returns the list of fully-qualified benchmark module names that were
|
||||
successfully imported. Failures are logged (not raised) so a single
|
||||
broken benchmark doesn't take down the whole CLI — the operator still
|
||||
sees the working benchmarks via ``benchmarks list``.
|
||||
"""
|
||||
|
||||
import surfsense_evals.suites as _suites # self-import for __path__
|
||||
|
||||
imported: list[str] = []
|
||||
for domain_name in _iter_subpackages(_suites):
|
||||
try:
|
||||
domain_pkg = importlib.import_module(domain_name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Failed to import suite domain %s: %s", domain_name, exc)
|
||||
continue
|
||||
for benchmark_name in _iter_subpackages(domain_pkg):
|
||||
try:
|
||||
importlib.import_module(benchmark_name)
|
||||
imported.append(benchmark_name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Failed to import benchmark %s: %s", benchmark_name, exc
|
||||
)
|
||||
return imported
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
"""Test fixture suite — skipped by the auto-discovery walker (name starts with ``_``).
|
||||
|
||||
Imported explicitly by ``tests/core/test_registry.py`` to prove the
|
||||
register-on-import contract works without polluting the production
|
||||
benchmark list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
"""Demo benchmark — registers on import, used only by the registry tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from typing import Any
|
||||
|
||||
from ....core.registry import (
|
||||
Benchmark,
|
||||
ReportSection,
|
||||
RunArtifact,
|
||||
RunContext,
|
||||
register,
|
||||
)
|
||||
|
||||
|
||||
class HelloBenchmark:
|
||||
suite: str = "_demo"
|
||||
name: str = "hello"
|
||||
headline: bool = False
|
||||
description: str = "Demo benchmark used by the registry test."
|
||||
|
||||
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument("--echo", default="hi")
|
||||
|
||||
async def ingest(self, ctx: RunContext, **_opts: Any) -> None: # pragma: no cover
|
||||
return None
|
||||
|
||||
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact: # pragma: no cover
|
||||
return RunArtifact(
|
||||
suite=self.suite,
|
||||
benchmark=self.name,
|
||||
run_timestamp="0",
|
||||
raw_path=ctx.benchmark_data_dir() / "raw.jsonl",
|
||||
metrics={"echo": opts.get("echo")},
|
||||
)
|
||||
|
||||
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
|
||||
return ReportSection(
|
||||
title="Hello demo",
|
||||
headline=False,
|
||||
body_md="- runs: " + str(len(artifacts)),
|
||||
)
|
||||
|
||||
|
||||
register(HelloBenchmark())
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""Medical RAG benchmarks (MedXpertQA-MM headline + MIRAGE/CUREv1 secondary).
|
||||
|
||||
Subpackages register themselves with ``core.registry`` on import. The
|
||||
``suites/__init__.py`` discovery walker imports them automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""CUREv1 — secondary single-arm SurfSense retrieval measurement.
|
||||
|
||||
Source: https://huggingface.co/datasets/clinia/CUREv1
|
||||
Paper: https://arxiv.org/html/2412.06954v4
|
||||
|
||||
Pure retrieval benchmark — 10 medical disciplines, English/French/Spanish
|
||||
queries, expert-curated qrels (graded 0/1/2). The harness ingests the
|
||||
corpus, runs each query via SurfSense's ``/api/v1/new_chat``, parses
|
||||
chunk citations, maps them back to CUREv1 ``corpus-id``, and scores
|
||||
Recall@k / MRR / nDCG@10 against qrels.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .runner import CureBenchmark
|
||||
from ....core import registry as _registry
|
||||
|
||||
_registry.register(CureBenchmark())
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
"""CUREv1 ingestion.
|
||||
|
||||
For each (lang, discipline) requested, downloads the corpus split via
|
||||
``datasets.load_dataset(path="clinia/CUREv1", name="corpus", split=<discipline>)``,
|
||||
batches passages into ~5 MB markdown bundles, uploads them to
|
||||
SurfSense, polls until ``ready``, and persists the
|
||||
``corpus_id -> document_id`` map under
|
||||
``data/medical/maps/cure_corpus_map_<discipline>.jsonl``. A union map
|
||||
``cure_corpus_map.jsonl`` is also written so the runner can resolve
|
||||
citations across disciplines without juggling per-file paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_BATCH_SIZE_BYTES = 5 * 1024 * 1024
|
||||
|
||||
# 10 disciplines covered by the dataset card. We exhaustively list
|
||||
# them so a smoke test can default to one.
|
||||
DISCIPLINES = (
|
||||
"anesthesiology",
|
||||
"cardiology",
|
||||
"dermatology",
|
||||
"endocrinology",
|
||||
"gastroenterology",
|
||||
"hematology",
|
||||
"nephrology",
|
||||
"neurology",
|
||||
"obstetrics_gynecology",
|
||||
"psychiatry",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CorpusPassage:
|
||||
corpus_id: str
|
||||
title: str
|
||||
text: str
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
title = (self.title or "").strip() or "Untitled"
|
||||
body = (self.text or "").strip()
|
||||
return f"# {title}\n\n_id: `{self.corpus_id}`_\n\n{body}\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PassageBatch:
|
||||
path: Path
|
||||
corpus_ids: list[str]
|
||||
|
||||
|
||||
def _stream_corpus(discipline: str) -> Iterable[CorpusPassage]:
|
||||
"""Stream corpus rows for one discipline via the ``datasets`` library."""
|
||||
|
||||
from datasets import load_dataset # noqa: PLC0415
|
||||
|
||||
logger.info("Loading CUREv1 corpus for discipline=%s", discipline)
|
||||
ds = load_dataset(path="clinia/CUREv1", name="corpus", split=discipline)
|
||||
for row in ds:
|
||||
cid = str(row.get("_id") or "")
|
||||
if not cid:
|
||||
continue
|
||||
yield CorpusPassage(
|
||||
corpus_id=cid,
|
||||
title=str(row.get("title") or ""),
|
||||
text=str(row.get("text") or ""),
|
||||
)
|
||||
|
||||
|
||||
def _write_batches(
|
||||
passages: Iterable[CorpusPassage],
|
||||
*,
|
||||
out_dir: Path,
|
||||
discipline: str,
|
||||
batch_bytes: int = _BATCH_SIZE_BYTES,
|
||||
) -> list[PassageBatch]:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
batches: list[PassageBatch] = []
|
||||
current_buffer = io.StringIO()
|
||||
current_ids: list[str] = []
|
||||
current_bytes = 0
|
||||
batch_idx = 0
|
||||
|
||||
def _flush() -> None:
|
||||
nonlocal current_buffer, current_ids, current_bytes, batch_idx
|
||||
if not current_ids:
|
||||
return
|
||||
path = out_dir / f"cure_{discipline}_{batch_idx:04d}.md"
|
||||
path.write_text(current_buffer.getvalue(), encoding="utf-8")
|
||||
batches.append(PassageBatch(path=path, corpus_ids=current_ids))
|
||||
batch_idx += 1
|
||||
current_buffer = io.StringIO()
|
||||
current_ids = []
|
||||
current_bytes = 0
|
||||
|
||||
for passage in passages:
|
||||
chunk = passage.to_markdown() + "\n---\n\n"
|
||||
chunk_bytes = len(chunk.encode("utf-8"))
|
||||
if current_bytes + chunk_bytes > batch_bytes and current_ids:
|
||||
_flush()
|
||||
current_buffer.write(chunk)
|
||||
current_ids.append(passage.corpus_id)
|
||||
current_bytes += chunk_bytes
|
||||
_flush()
|
||||
return batches
|
||||
|
||||
|
||||
async def run_ingest(
|
||||
ctx: RunContext,
|
||||
*,
|
||||
disciplines: list[str] | None = None,
|
||||
max_per_discipline: int | None = None,
|
||||
settings: IngestSettings | None = None,
|
||||
) -> None:
|
||||
disciplines = disciplines or list(DISCIPLINES)
|
||||
settings = settings or IngestSettings(use_vision_llm=False, processing_mode="basic")
|
||||
bench_dir = ctx.benchmark_data_dir()
|
||||
batches_root = bench_dir / "batches"
|
||||
batches_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
docs_client = ctx.documents_client()
|
||||
union_map_path = ctx.maps_dir() / "cure_corpus_map.jsonl"
|
||||
union_map_fh = union_map_path.open("w", encoding="utf-8")
|
||||
# Header row records the ingest-time settings so the runner can
|
||||
# surface them in the report (see core/ingest_settings.py).
|
||||
union_map_fh.write(settings_header_line(settings) + "\n")
|
||||
try:
|
||||
for discipline in disciplines:
|
||||
try:
|
||||
passages_iter = _stream_corpus(discipline)
|
||||
if max_per_discipline is not None:
|
||||
passages_iter = _take(passages_iter, max_per_discipline)
|
||||
batches = _write_batches(
|
||||
passages_iter,
|
||||
out_dir=batches_root / discipline,
|
||||
discipline=discipline,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Skipping discipline %s: %s", discipline, exc)
|
||||
continue
|
||||
if not batches:
|
||||
logger.warning("Discipline %s produced 0 batches; skipping upload", discipline)
|
||||
continue
|
||||
logger.info(
|
||||
"Uploading %d batches for discipline %s", len(batches), discipline
|
||||
)
|
||||
upload_result = await docs_client.upload(
|
||||
files=[b.path for b in batches],
|
||||
search_space_id=ctx.search_space_id,
|
||||
should_summarize=settings.should_summarize,
|
||||
use_vision_llm=settings.use_vision_llm,
|
||||
processing_mode=settings.processing_mode,
|
||||
)
|
||||
new_doc_ids = list(upload_result.document_ids)
|
||||
if new_doc_ids:
|
||||
await docs_client.wait_until_ready(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=new_doc_ids,
|
||||
timeout_s=3600.0,
|
||||
max_poll_s=15.0,
|
||||
)
|
||||
statuses = await docs_client.get_status(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=new_doc_ids + upload_result.duplicate_document_ids,
|
||||
)
|
||||
title_to_doc = {s.title: s.document_id for s in statuses}
|
||||
|
||||
per_discipline_path = (
|
||||
ctx.maps_dir() / f"cure_corpus_map_{discipline}.jsonl"
|
||||
)
|
||||
with per_discipline_path.open("w", encoding="utf-8") as fh:
|
||||
fh.write(settings_header_line(settings) + "\n")
|
||||
for batch in batches:
|
||||
doc_id = title_to_doc.get(batch.path.name)
|
||||
if doc_id is None:
|
||||
logger.warning("No document_id for batch %s", batch.path.name)
|
||||
continue
|
||||
for cid in batch.corpus_ids:
|
||||
record = {
|
||||
"corpus_id": cid,
|
||||
"document_id": doc_id,
|
||||
"discipline": discipline,
|
||||
}
|
||||
fh.write(json.dumps(record) + "\n")
|
||||
union_map_fh.write(json.dumps(record) + "\n")
|
||||
|
||||
chunks_map_path = ctx.maps_dir() / f"cure_chunk_map_{discipline}.jsonl"
|
||||
with chunks_map_path.open("w", encoding="utf-8") as fh:
|
||||
for doc_id in {title_to_doc.get(b.path.name) for b in batches} - {None}:
|
||||
try:
|
||||
chunks = await docs_client.list_chunks(int(doc_id))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Failed to list chunks for doc_id=%s: %s", doc_id, exc
|
||||
)
|
||||
continue
|
||||
for chunk in chunks:
|
||||
fh.write(
|
||||
json.dumps(
|
||||
{
|
||||
"chunk_id": chunk.id,
|
||||
"document_id": doc_id,
|
||||
"discipline": discipline,
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
finally:
|
||||
union_map_fh.close()
|
||||
|
||||
new_state = ctx.suite_state
|
||||
new_state.ingestion_maps["cure"] = str(union_map_path)
|
||||
set_suite_state(ctx.config, ctx.suite, new_state)
|
||||
logger.info("CUREv1 ingestion complete; union map at %s", union_map_path)
|
||||
|
||||
|
||||
def _take(it: Iterable, n: int) -> Iterable:
|
||||
yielded = 0
|
||||
for x in it:
|
||||
if yielded >= n:
|
||||
return
|
||||
yield x
|
||||
yielded += 1
|
||||
|
||||
|
||||
__all__ = ["DISCIPLINES", "CorpusPassage", "PassageBatch", "run_ingest"]
|
||||
|
|
@ -0,0 +1,397 @@
|
|||
"""CUREv1 runner — single-arm SurfSense retrieval scoring.
|
||||
|
||||
For each query we ask SurfSense via ``/api/v1/new_chat`` (no
|
||||
``mentioned_document_ids``) and parse chunk citations from the
|
||||
streamed answer. Cited ``chunk_id`` → ``document_id`` (chunk map) →
|
||||
``corpus_id`` (corpus map). The resulting ranked list is scored
|
||||
against the dataset's qrels.
|
||||
|
||||
The prompt nudges the model to surface its supporting passages via
|
||||
SurfSense's standard ``[citation:CHUNK_ID]`` format (already required
|
||||
by the agent system prompt), so we recover retrieval ordering from
|
||||
the answer text without needing a separate retrieval API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.arms import ArmRequest, ArmResult, SurfSenseArm
|
||||
from ....core.config import utc_iso_timestamp
|
||||
from ....core.ingest_settings import (
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
is_settings_header,
|
||||
read_settings_header,
|
||||
)
|
||||
from ....core.metrics.retrieval import score_run
|
||||
from ....core.registry import (
|
||||
Benchmark,
|
||||
ReportSection,
|
||||
RunArtifact,
|
||||
RunContext,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_PROMPT = """\
|
||||
You are a medical literature retrieval assistant for the question
|
||||
below. Identify the top passages from the knowledge base that best
|
||||
answer it and cite each one in the standard format
|
||||
[citation:CHUNK_ID]. List as many citations as are useful, ordered
|
||||
from most to least relevant. Provide a one-sentence justification
|
||||
for each citation.
|
||||
|
||||
Query: {query}
|
||||
"""
|
||||
|
||||
|
||||
_DESCRIPTION = "CUREv1 retrieval (single-arm SurfSense): Recall@k / MRR / nDCG@10."
|
||||
|
||||
# CUREv1 corpus is text-only markdown bundles; vision LLM at ingest
|
||||
# is wasted by default but the operator can flip it via CLI for an
|
||||
# A/B comparison.
|
||||
_DEFAULT_INGEST_SETTINGS = IngestSettings(
|
||||
use_vision_llm=False,
|
||||
processing_mode="basic",
|
||||
should_summarize=False,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CureQuery:
|
||||
qid: str
|
||||
text: str
|
||||
discipline: str
|
||||
|
||||
|
||||
def _load_chunk_map(maps_dir: Path) -> dict[int, int]:
|
||||
"""Union all ``cure_chunk_map_<discipline>.jsonl`` into one dict."""
|
||||
|
||||
out: dict[int, int] = {}
|
||||
for path in sorted(maps_dir.glob("cure_chunk_map_*.jsonl")):
|
||||
with path.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
if not line.strip():
|
||||
continue
|
||||
row = json.loads(line)
|
||||
if is_settings_header(row):
|
||||
continue
|
||||
try:
|
||||
out[int(row["chunk_id"])] = int(row["document_id"])
|
||||
except (KeyError, TypeError, ValueError):
|
||||
continue
|
||||
return out
|
||||
|
||||
|
||||
def _load_doc_to_corpus(maps_dir: Path) -> dict[int, list[str]]:
|
||||
"""Map ``document_id -> [corpus_id, ...]`` from the union map.
|
||||
|
||||
Multiple corpus passages may live in one batched markdown
|
||||
document, so each doc_id maps to a list. Citation ordering of the
|
||||
first occurrence is preserved.
|
||||
"""
|
||||
|
||||
out: dict[int, list[str]] = defaultdict(list)
|
||||
union_path = maps_dir / "cure_corpus_map.jsonl"
|
||||
if not union_path.exists():
|
||||
return out
|
||||
with union_path.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
if not line.strip():
|
||||
continue
|
||||
row = json.loads(line)
|
||||
if is_settings_header(row):
|
||||
continue
|
||||
try:
|
||||
out[int(row["document_id"])].append(str(row["corpus_id"]))
|
||||
except (KeyError, TypeError, ValueError):
|
||||
continue
|
||||
return out
|
||||
|
||||
|
||||
def _load_queries(*, lang: str, disciplines: list[str], sample_n: int | None) -> list[CureQuery]:
|
||||
from datasets import load_dataset # noqa: PLC0415
|
||||
|
||||
out: list[CureQuery] = []
|
||||
for discipline in disciplines:
|
||||
try:
|
||||
ds = load_dataset(path="clinia/CUREv1", name=f"queries-{lang}", split=discipline)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Skipping queries for %s/%s: %s", lang, discipline, exc)
|
||||
continue
|
||||
for row in ds:
|
||||
qid = str(row.get("_id") or "")
|
||||
text = str(row.get("text") or "")
|
||||
if not qid or not text:
|
||||
continue
|
||||
out.append(CureQuery(qid=qid, text=text, discipline=discipline))
|
||||
out.sort(key=lambda q: (q.discipline, q.qid))
|
||||
if sample_n is not None and sample_n > 0:
|
||||
# Stratified-by-discipline slice.
|
||||
per_d = max(1, sample_n // max(1, len(disciplines)))
|
||||
sliced: list[CureQuery] = []
|
||||
counter: dict[str, int] = defaultdict(int)
|
||||
for q in out:
|
||||
if counter[q.discipline] >= per_d:
|
||||
continue
|
||||
sliced.append(q)
|
||||
counter[q.discipline] += 1
|
||||
if len(sliced) >= sample_n:
|
||||
break
|
||||
out = sliced
|
||||
return out
|
||||
|
||||
|
||||
def _load_qrels(*, disciplines: list[str]) -> dict[str, dict[str, float]]:
|
||||
from datasets import load_dataset # noqa: PLC0415
|
||||
|
||||
out: dict[str, dict[str, float]] = defaultdict(dict)
|
||||
for discipline in disciplines:
|
||||
try:
|
||||
ds = load_dataset(path="clinia/CUREv1", name="qrels", split=discipline)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Skipping qrels for %s: %s", discipline, exc)
|
||||
continue
|
||||
for row in ds:
|
||||
qid = str(row.get("query-id") or row.get("query_id") or "")
|
||||
cid = str(row.get("corpus-id") or row.get("corpus_id") or "")
|
||||
score = row.get("score")
|
||||
if not qid or not cid or score is None:
|
||||
continue
|
||||
try:
|
||||
out[qid][cid] = float(score)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
return out
|
||||
|
||||
|
||||
async def _gather_with_limit(coros, *, concurrency: int) -> list[Any]:
|
||||
sem = asyncio.Semaphore(max(1, concurrency))
|
||||
|
||||
async def _wrap(c):
|
||||
async with sem:
|
||||
return await c
|
||||
|
||||
return await asyncio.gather(*(_wrap(c) for c in coros))
|
||||
|
||||
|
||||
class CureBenchmark:
|
||||
suite: str = "medical"
|
||||
name: str = "cure"
|
||||
headline: bool = False
|
||||
description: str = _DESCRIPTION
|
||||
|
||||
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument("--lang", default="en", choices=("en", "es", "fr"))
|
||||
parser.add_argument("--discipline", default=None,
|
||||
help="Restrict to one discipline (default: all ingested).")
|
||||
parser.add_argument("--n", dest="sample_n", type=int, default=None)
|
||||
parser.add_argument("--concurrency", type=int, default=4)
|
||||
parser.add_argument(
|
||||
"--max-passages-per-discipline", type=int, default=None,
|
||||
help="(ingest only) cap corpus rows per discipline for smoke testing.",
|
||||
)
|
||||
# Per-upload knobs forwarded to /documents/fileupload at ingest;
|
||||
# ignored at run-time (runner reads resolved settings from the
|
||||
# union-map header).
|
||||
add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
|
||||
|
||||
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
|
||||
from .ingest import DISCIPLINES, run_ingest
|
||||
|
||||
settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
|
||||
await run_ingest(
|
||||
ctx,
|
||||
disciplines=list(DISCIPLINES),
|
||||
max_per_discipline=opts.get("max_passages_per_discipline"),
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
|
||||
lang = opts.get("lang") or "en"
|
||||
discipline_filter = opts.get("discipline")
|
||||
sample_n = opts.get("sample_n")
|
||||
concurrency = int(opts.get("concurrency") or 4)
|
||||
|
||||
maps_dir = ctx.maps_dir()
|
||||
chunk_to_doc = _load_chunk_map(maps_dir)
|
||||
doc_to_corpus = _load_doc_to_corpus(maps_dir)
|
||||
ingest_settings = read_settings_header(maps_dir / "cure_corpus_map.jsonl")
|
||||
if not chunk_to_doc or not doc_to_corpus:
|
||||
raise RuntimeError(
|
||||
"CUREv1 not ingested for this suite. Run "
|
||||
"`python -m surfsense_evals ingest medical cure` first."
|
||||
)
|
||||
|
||||
# Disciplines to query are determined by the per-discipline maps
|
||||
# actually present (either user-filtered or whatever was ingested).
|
||||
ingested_disciplines = sorted({
|
||||
row_disc
|
||||
for path in maps_dir.glob("cure_corpus_map_*.jsonl")
|
||||
for row_disc in [path.stem[len("cure_corpus_map_"):]]
|
||||
})
|
||||
if discipline_filter:
|
||||
disciplines = [discipline_filter]
|
||||
else:
|
||||
disciplines = ingested_disciplines or ["dermatology"]
|
||||
|
||||
queries = _load_queries(lang=lang, disciplines=disciplines, sample_n=sample_n)
|
||||
if not queries:
|
||||
raise RuntimeError(
|
||||
f"No CUREv1 queries matched lang={lang!r} disciplines={disciplines!r}."
|
||||
)
|
||||
qrels = _load_qrels(disciplines=disciplines)
|
||||
logger.info(
|
||||
"CUREv1: %d queries / %d qrels across disciplines %s",
|
||||
len(queries),
|
||||
len(qrels),
|
||||
disciplines,
|
||||
)
|
||||
|
||||
arm = SurfSenseArm(
|
||||
client=ctx.new_chat_client(),
|
||||
search_space_id=ctx.search_space_id,
|
||||
ephemeral_threads=True,
|
||||
)
|
||||
|
||||
async def _ask(q: CureQuery) -> ArmResult:
|
||||
return await arm.answer(
|
||||
ArmRequest(
|
||||
question_id=f"{q.discipline}::{q.qid}",
|
||||
prompt=_PROMPT.format(query=q.text.strip()),
|
||||
)
|
||||
)
|
||||
|
||||
results: list[ArmResult] = await _gather_with_limit(
|
||||
(_ask(q) for q in queries), concurrency=concurrency
|
||||
)
|
||||
|
||||
per_query_retrieved: dict[str, list[str]] = {}
|
||||
for q, res in zip(queries, results):
|
||||
chunk_ids: list[int] = []
|
||||
seen: set[int] = set()
|
||||
for citation in res.citations:
|
||||
if citation.get("kind") != "chunk":
|
||||
continue
|
||||
cid = int(citation.get("chunk_id"))
|
||||
if cid in seen:
|
||||
continue
|
||||
chunk_ids.append(cid)
|
||||
seen.add(cid)
|
||||
corpus_ids: list[str] = []
|
||||
seen_corpus: set[str] = set()
|
||||
for cid in chunk_ids:
|
||||
doc_id = chunk_to_doc.get(cid)
|
||||
if doc_id is None:
|
||||
continue
|
||||
for corpus_id in doc_to_corpus.get(doc_id, []):
|
||||
if corpus_id in seen_corpus:
|
||||
continue
|
||||
corpus_ids.append(corpus_id)
|
||||
seen_corpus.add(corpus_id)
|
||||
per_query_retrieved[q.qid] = corpus_ids
|
||||
|
||||
scores = score_run(
|
||||
per_query_retrieved=per_query_retrieved,
|
||||
per_query_qrels=qrels,
|
||||
ks=(1, 5, 10, 32),
|
||||
ndcg_k=10,
|
||||
)
|
||||
|
||||
run_timestamp = utc_iso_timestamp()
|
||||
run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
|
||||
raw_path = run_dir / "raw.jsonl"
|
||||
with raw_path.open("w", encoding="utf-8") as fh:
|
||||
for q, res in zip(queries, results):
|
||||
fh.write(
|
||||
json.dumps(
|
||||
{
|
||||
"discipline": q.discipline,
|
||||
"qid": q.qid,
|
||||
"lang": lang,
|
||||
"retrieved_corpus_ids": per_query_retrieved.get(q.qid, []),
|
||||
**res.to_jsonl(),
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
metrics = scores.to_dict()
|
||||
metrics["lang"] = lang
|
||||
metrics["disciplines"] = disciplines
|
||||
|
||||
artifact = RunArtifact(
|
||||
suite=self.suite,
|
||||
benchmark=self.name,
|
||||
run_timestamp=run_timestamp,
|
||||
raw_path=raw_path,
|
||||
metrics=metrics,
|
||||
extra={
|
||||
"n_queries": len(queries),
|
||||
"lang": lang,
|
||||
"disciplines": disciplines,
|
||||
"concurrency": concurrency,
|
||||
"provider_model": ctx.provider_model,
|
||||
"ingest_settings": ingest_settings,
|
||||
},
|
||||
)
|
||||
manifest_path = run_dir / "run_artifact.json"
|
||||
manifest_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"suite": self.suite,
|
||||
"benchmark": self.name,
|
||||
"raw_path": "raw.jsonl",
|
||||
"metrics": metrics,
|
||||
"extra": artifact.extra,
|
||||
},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return artifact
|
||||
|
||||
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
|
||||
if not artifacts:
|
||||
return ReportSection(
|
||||
title="CUREv1 — single-arm SurfSense retrieval",
|
||||
headline=False,
|
||||
body_md="(no run artifacts found)",
|
||||
body_json={},
|
||||
)
|
||||
latest = max(artifacts, key=lambda a: a.run_timestamp)
|
||||
m = latest.metrics
|
||||
recall = m.get("recall_at_k", {})
|
||||
lines: list[str] = [
|
||||
format_ingest_settings_md(latest.extra.get("ingest_settings")),
|
||||
f"- Language: {m.get('lang', '?')}",
|
||||
f"- Disciplines: {', '.join(m.get('disciplines', []) or ['?'])}",
|
||||
f"- n_queries (after qrels intersection): {m.get('n_queries', 0)}",
|
||||
]
|
||||
for k in (1, 5, 10, 32):
|
||||
v = recall.get(str(k), recall.get(k))
|
||||
if v is not None:
|
||||
lines.append(f"- Recall@{k}: {float(v):.3f}")
|
||||
lines.append(f"- MRR: {float(m.get('mrr', 0.0)):.3f}")
|
||||
lines.append(f"- nDCG@10: {float(m.get('ndcg_at_10', 0.0)):.3f}")
|
||||
return ReportSection(
|
||||
title="CUREv1 — single-arm SurfSense retrieval",
|
||||
headline=False,
|
||||
body_md="\n".join(lines),
|
||||
body_json=m,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["CureBenchmark", "CureQuery"]
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
"""MedXpertQA-MM — multimodal medical exam head-to-head (medical suite headline).
|
||||
|
||||
Source: https://huggingface.co/datasets/TsinghuaC3I/MedXpertQA
|
||||
Paper: https://arxiv.org/abs/2501.18362 (ICML 2025)
|
||||
|
||||
* MM subset: ~2,000 expert-level exam questions with diverse medical
|
||||
images (radiology, dermatology, pathology, ECGs, gross specimens,
|
||||
fundus photos) and structured patient information embedded in the
|
||||
question stem.
|
||||
* 5 answer choices per MM question (A–E).
|
||||
* USMLE / COMLEX / 17 specialty board sources; rigorously filtered
|
||||
and reviewed by physicians.
|
||||
|
||||
Real diagnostic images carry signal that text-only patient charts
|
||||
cannot (e.g. CT scans, dermoscopy), so this benchmark exercises the
|
||||
full vision RAG pipeline end-to-end against a vision-capable model
|
||||
fed the same PDF natively.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import registry as _registry
|
||||
from .runner import MedXpertQAMMBenchmark
|
||||
|
||||
_registry.register(MedXpertQAMMBenchmark())
|
||||
|
|
@ -0,0 +1,394 @@
|
|||
"""MedXpertQA-MM ingestion.
|
||||
|
||||
Steps:
|
||||
|
||||
1. Pull ``MM/test.jsonl`` (and optionally ``MM/dev.jsonl``) plus
|
||||
``images.zip`` from
|
||||
``hf://datasets/TsinghuaC3I/MedXpertQA``. Cache under
|
||||
``<data_dir>/medical/medxpertqa/``.
|
||||
2. Extract ``images.zip`` once into ``<data_dir>/medical/medxpertqa/images/``.
|
||||
3. Render one PDF per MM question (text question + structured patient
|
||||
info embedded in the question stem + each image flowable + answer
|
||||
options). Output: ``<data_dir>/medical/medxpertqa/pdfs/<id>.pdf``.
|
||||
4. Upload each PDF to SurfSense with ``use_vision_llm=True``; persist
|
||||
``id -> document_id`` in
|
||||
``<data_dir>/medical/maps/medxpertqa_doc_map.jsonl``.
|
||||
|
||||
Both arms then receive byte-identical PDFs. The native arm sends the
|
||||
PDF directly to OpenRouter; SurfSense ingests via its own vision
|
||||
pipeline and the runner queries with ``mentioned_document_ids=[...]``
|
||||
to scope retrieval to the question's PDF.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.pdf import PdfImage, render_pdf_with_images
|
||||
from ....core.registry import RunContext
|
||||
from .prompt import format_options
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HF_REPO_ID = "TsinghuaC3I/MedXpertQA"
|
||||
HF_REPO_TYPE = "dataset"
|
||||
|
||||
|
||||
def _hf_hub_download(*args, **kwargs):
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
return hf_hub_download(*args, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class MedXpertQuestion:
|
||||
qid: str # e.g. "MM-26"
|
||||
question: str # full question text (case + ask)
|
||||
options: dict[str, str] # A-E
|
||||
label: str # "A".."E"
|
||||
image_files: list[str] # filenames inside images.zip
|
||||
medical_task: str
|
||||
body_system: str
|
||||
question_type: str
|
||||
split: str # "test" or "dev"
|
||||
|
||||
|
||||
def _load_jsonl(path: Path, *, split: str) -> list[MedXpertQuestion]:
|
||||
out: list[MedXpertQuestion] = []
|
||||
with path.open("r", encoding="utf-8") as fh:
|
||||
for raw_line in fh:
|
||||
line = raw_line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
qid = str(row.get("id") or "").strip()
|
||||
question = str(row.get("question") or "").strip()
|
||||
options = row.get("options") or {}
|
||||
label = str(row.get("label") or "").strip().upper()
|
||||
if not qid or not question or not isinstance(options, dict) or not label:
|
||||
continue
|
||||
opts = {str(k).strip().upper(): str(v).strip() for k, v in options.items()}
|
||||
images = row.get("images") or []
|
||||
if not isinstance(images, list):
|
||||
images = []
|
||||
out.append(MedXpertQuestion(
|
||||
qid=qid,
|
||||
question=question,
|
||||
options=opts,
|
||||
label=label,
|
||||
image_files=[str(x).strip() for x in images if str(x).strip()],
|
||||
medical_task=str(row.get("medical_task") or "").strip(),
|
||||
body_system=str(row.get("body_system") or "").strip(),
|
||||
question_type=str(row.get("question_type") or "").strip(),
|
||||
split=split,
|
||||
))
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image archive helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_images_extracted(images_zip: Path, images_dir: Path) -> None:
|
||||
"""Extract images.zip once, tolerantly handle re-runs."""
|
||||
|
||||
marker = images_dir / ".extracted_ok"
|
||||
if marker.exists():
|
||||
return
|
||||
images_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info("Extracting MedXpertQA images.zip -> %s", images_dir)
|
||||
with zipfile.ZipFile(images_zip) as zf:
|
||||
zf.extractall(images_dir)
|
||||
marker.write_text("ok\n", encoding="utf-8")
|
||||
|
||||
|
||||
def _resolve_image_path(image_filename: str, images_dir: Path) -> Path | None:
|
||||
"""Find a question's image in the (possibly nested) extract directory.
|
||||
|
||||
The zip layout sometimes nests under ``images/`` and sometimes
|
||||
flat — handle both.
|
||||
"""
|
||||
|
||||
direct = images_dir / image_filename
|
||||
if direct.exists():
|
||||
return direct
|
||||
nested = images_dir / "images" / image_filename
|
||||
if nested.exists():
|
||||
return nested
|
||||
# Last-ditch: glob recursively (slow but correct for unusual layouts).
|
||||
matches = list(images_dir.rglob(image_filename))
|
||||
return matches[0] if matches else None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PDF rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _render_question_pdf(
|
||||
q: MedXpertQuestion,
|
||||
*,
|
||||
images_dir: Path,
|
||||
pdfs_dir: Path,
|
||||
) -> tuple[Path, list[str]]:
|
||||
"""Render one MedXpertQA question into a PDF.
|
||||
|
||||
Layout:
|
||||
Title: MedXpertQA — <qid> (medical_task / body_system)
|
||||
Section 1 (case): full question text
|
||||
Section 1 images: each image flowable + caption
|
||||
Section 2 (options): A) ... B) ... C) ... D) ... E) ...
|
||||
|
||||
Returns (pdf_path, missing_images) so the caller can warn on
|
||||
questions where some image files weren't found.
|
||||
"""
|
||||
|
||||
out_path = pdfs_dir / f"{q.qid}.pdf"
|
||||
images: list[PdfImage] = []
|
||||
missing: list[str] = []
|
||||
for fname in q.image_files:
|
||||
resolved = _resolve_image_path(fname, images_dir)
|
||||
if resolved is None:
|
||||
missing.append(fname)
|
||||
continue
|
||||
images.append(PdfImage(path=resolved, caption=f"Image: {fname}", max_width_in=5.5))
|
||||
|
||||
title_meta_parts = []
|
||||
if q.medical_task:
|
||||
title_meta_parts.append(q.medical_task)
|
||||
if q.body_system:
|
||||
title_meta_parts.append(q.body_system)
|
||||
if q.question_type:
|
||||
title_meta_parts.append(q.question_type)
|
||||
title_suffix = f" ({' / '.join(title_meta_parts)})" if title_meta_parts else ""
|
||||
|
||||
sections = [
|
||||
("Clinical case", q.question, images),
|
||||
("Answer choices", format_options(q.options), None),
|
||||
]
|
||||
render_pdf_with_images(
|
||||
title=f"MedXpertQA-MM {q.qid}{title_suffix}",
|
||||
sections=sections,
|
||||
output_path=out_path,
|
||||
)
|
||||
return out_path, missing
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _upload_pdfs(
|
||||
ctx: RunContext,
|
||||
pdf_paths: Iterable[Path],
|
||||
*,
|
||||
batch_size: int,
|
||||
settings: IngestSettings,
|
||||
) -> dict[str, int]:
|
||||
docs_client = ctx.documents_client()
|
||||
name_to_id: dict[str, int] = {}
|
||||
pdf_list = list(pdf_paths)
|
||||
for batch_start in range(0, len(pdf_list), batch_size):
|
||||
batch = pdf_list[batch_start:batch_start + batch_size]
|
||||
result = await docs_client.upload(
|
||||
files=batch,
|
||||
search_space_id=ctx.search_space_id,
|
||||
should_summarize=settings.should_summarize,
|
||||
use_vision_llm=settings.use_vision_llm,
|
||||
processing_mode=settings.processing_mode,
|
||||
)
|
||||
all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
|
||||
if all_ids:
|
||||
await docs_client.wait_until_ready(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=result.document_ids,
|
||||
timeout_s=1800.0,
|
||||
)
|
||||
statuses = await docs_client.get_status(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=all_ids,
|
||||
)
|
||||
for s in statuses:
|
||||
name_to_id[s.title] = s.document_id
|
||||
logger.info(
|
||||
"Uploaded MedXpertQA batch %d-%d: %d new, %d duplicate",
|
||||
batch_start, batch_start + len(batch),
|
||||
len(result.document_ids), len(result.duplicate_document_ids),
|
||||
)
|
||||
return name_to_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_ingest(
|
||||
ctx: RunContext,
|
||||
*,
|
||||
split: str = "test",
|
||||
max_questions: int | None = None,
|
||||
upload_batch_size: int = 8,
|
||||
skip_upload: bool = False,
|
||||
include_dev: bool = False,
|
||||
settings: IngestSettings | None = None,
|
||||
) -> None:
|
||||
"""Ingest MedXpertQA-MM into the medical suite.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
split : 'test' (default), 'dev', or 'both'
|
||||
Which subset to render + upload.
|
||||
max_questions : int | None
|
||||
Cap on number of questions ingested (handy for fast iteration).
|
||||
upload_batch_size : int
|
||||
PDFs per ``fileupload`` call.
|
||||
skip_upload : bool
|
||||
Render PDFs locally but don't push to SurfSense.
|
||||
include_dev : bool
|
||||
Convenience: equivalent to ``split='both'``.
|
||||
"""
|
||||
|
||||
settings = settings or IngestSettings(use_vision_llm=True, processing_mode="basic")
|
||||
bench_dir = ctx.benchmark_data_dir()
|
||||
images_zip_local = bench_dir / "images.zip"
|
||||
images_dir = bench_dir / "images"
|
||||
pdfs_dir = bench_dir / "pdfs"
|
||||
pdfs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_cache = bench_dir / ".hf_cache"
|
||||
hf_cache.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Step 1: download jsonl(s)
|
||||
splits_to_load: list[str] = []
|
||||
if split == "both" or include_dev:
|
||||
splits_to_load = ["dev", "test"]
|
||||
elif split in {"dev", "test"}:
|
||||
splits_to_load = [split]
|
||||
else:
|
||||
raise ValueError(f"Unknown split {split!r}; use 'test' / 'dev' / 'both'")
|
||||
|
||||
questions: list[MedXpertQuestion] = []
|
||||
for sp in splits_to_load:
|
||||
rel = f"MM/{sp}.jsonl"
|
||||
local = _hf_hub_download(
|
||||
repo_id=HF_REPO_ID,
|
||||
filename=rel,
|
||||
repo_type=HF_REPO_TYPE,
|
||||
cache_dir=str(hf_cache),
|
||||
)
|
||||
loaded = _load_jsonl(Path(local), split=sp)
|
||||
questions.extend(loaded)
|
||||
logger.info("Loaded %d MedXpertQA-MM questions from %s split", len(loaded), sp)
|
||||
|
||||
if max_questions is not None and max_questions > 0:
|
||||
questions = questions[:max_questions]
|
||||
if not questions:
|
||||
raise RuntimeError("No MedXpertQA-MM questions loaded; check the split argument.")
|
||||
|
||||
# Step 2: download images.zip + extract once
|
||||
if not images_zip_local.exists():
|
||||
local_zip = _hf_hub_download(
|
||||
repo_id=HF_REPO_ID,
|
||||
filename="images.zip",
|
||||
repo_type=HF_REPO_TYPE,
|
||||
cache_dir=str(hf_cache),
|
||||
)
|
||||
# Materialise into bench_dir so the path is stable.
|
||||
try:
|
||||
from os import link as _link
|
||||
_link(local_zip, images_zip_local)
|
||||
except OSError:
|
||||
from shutil import copy2
|
||||
copy2(local_zip, images_zip_local)
|
||||
_ensure_images_extracted(images_zip_local, images_dir)
|
||||
|
||||
# Step 3: render PDFs
|
||||
pdf_paths: dict[str, Path] = {}
|
||||
missing_image_count = 0
|
||||
for i, q in enumerate(questions, start=1):
|
||||
try:
|
||||
pdf, missing = _render_question_pdf(q, images_dir=images_dir, pdfs_dir=pdfs_dir)
|
||||
pdf_paths[q.qid] = pdf
|
||||
if missing:
|
||||
missing_image_count += len(missing)
|
||||
logger.debug("qid=%s missing %d images: %s", q.qid, len(missing), missing)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Failed to render MedXpertQA PDF for %s: %s", q.qid, exc)
|
||||
if i % 50 == 0:
|
||||
logger.info(" ... rendered %d / %d PDFs", i, len(questions))
|
||||
if missing_image_count:
|
||||
logger.warning(
|
||||
"MedXpertQA: %d image references could not be resolved on disk "
|
||||
"(rendered PDFs may be missing some images).",
|
||||
missing_image_count,
|
||||
)
|
||||
|
||||
# Step 4: upload
|
||||
name_to_id: dict[str, int] = {}
|
||||
if skip_upload:
|
||||
logger.info("MedXpertQA: --skip-upload set; skipping SurfSense ingestion")
|
||||
else:
|
||||
logger.info("MedXpertQA upload settings: %s", settings.render_label())
|
||||
name_to_id = await _upload_pdfs(
|
||||
ctx,
|
||||
pdf_paths.values(),
|
||||
batch_size=upload_batch_size,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
# Step 5: persist manifest + questions
|
||||
questions_jsonl = bench_dir / "questions.jsonl"
|
||||
with questions_jsonl.open("w", encoding="utf-8") as fh:
|
||||
for q in questions:
|
||||
fh.write(json.dumps({
|
||||
"qid": q.qid,
|
||||
"question": q.question,
|
||||
"options": q.options,
|
||||
"label": q.label,
|
||||
"image_files": q.image_files,
|
||||
"medical_task": q.medical_task,
|
||||
"body_system": q.body_system,
|
||||
"question_type": q.question_type,
|
||||
"split": q.split,
|
||||
}) + "\n")
|
||||
logger.info("Wrote %d MedXpertQA questions to %s", len(questions), questions_jsonl)
|
||||
|
||||
map_path = ctx.maps_dir() / "medxpertqa_doc_map.jsonl"
|
||||
with map_path.open("w", encoding="utf-8") as fh:
|
||||
# Header line records the resolved ingest settings
|
||||
# (see core/ingest_settings.py).
|
||||
fh.write(settings_header_line(settings) + "\n")
|
||||
for q in questions:
|
||||
local = pdf_paths.get(q.qid)
|
||||
if local is None:
|
||||
continue
|
||||
fh.write(json.dumps({
|
||||
"qid": q.qid,
|
||||
"document_id": name_to_id.get(local.name),
|
||||
"pdf_path": str(local),
|
||||
"n_images": len(q.image_files),
|
||||
"split": q.split,
|
||||
}) + "\n")
|
||||
logger.info("Wrote MedXpertQA doc map to %s", map_path)
|
||||
|
||||
new_state = ctx.suite_state
|
||||
new_state.ingestion_maps["medxpertqa"] = str(map_path)
|
||||
set_suite_state(ctx.config, ctx.suite, new_state)
|
||||
|
||||
|
||||
__all__ = ["MedXpertQuestion", "run_ingest"]
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"""MedXpertQA-MM prompt.
|
||||
|
||||
Mirrors the upstream paper's evaluation prompt (Zuo et al., ICML 2025
|
||||
§3.4): present case + 5 options A-E, ask for a single letter answer.
|
||||
We also instruct the model to use the embedded images explicitly,
|
||||
since the whole point of the MM subset is that the answer depends on
|
||||
visual evidence (radiology / dermoscopy / pathology / ECG, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
ANSWER_LETTERS = ("A", "B", "C", "D", "E")
|
||||
|
||||
|
||||
_PROMPT = """\
|
||||
You are a board-certified physician. The following exam question
|
||||
includes a clinical case and one or more medical images (radiology,
|
||||
dermatology, pathology, ECG, etc.). Use BOTH the text and the images
|
||||
to choose the best answer. Do not rely on memorisation of the case;
|
||||
read the images carefully — they often determine the correct answer.
|
||||
|
||||
Case + question:
|
||||
{question}
|
||||
|
||||
Answer choices:
|
||||
{options_block}
|
||||
|
||||
Respond on a single line in the format `Answer: X` where X is one of
|
||||
A, B, C, D, or E.
|
||||
"""
|
||||
|
||||
|
||||
def format_options(options: Mapping[str, str]) -> str:
|
||||
"""Render the ``A) ... E) ...`` options block."""
|
||||
|
||||
parts: list[str] = []
|
||||
for letter in ANSWER_LETTERS:
|
||||
text = options.get(letter)
|
||||
if text is None or str(text).strip() == "":
|
||||
continue
|
||||
parts.append(f"{letter}) {str(text).strip()}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def build_prompt(question: str, options: Mapping[str, str]) -> str:
|
||||
return _PROMPT.format(
|
||||
question=question.strip(),
|
||||
options_block=format_options(options),
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ANSWER_LETTERS", "build_prompt", "format_options"]
|
||||
|
|
@ -0,0 +1,681 @@
|
|||
"""MedXpertQA-MM runner — Native PDF (vision) vs SurfSense (vision RAG).
|
||||
|
||||
Headline benchmark for the medical suite.
|
||||
|
||||
* Native arm reads the rendered PDF (case + images + options) via
|
||||
OpenRouter ``chat/completions`` + the file-parser plugin.
|
||||
* SurfSense arm queries ``POST /api/v1/new_chat`` scoped via
|
||||
``mentioned_document_ids=[doc_id]`` to the same per-question PDF.
|
||||
|
||||
Operational notes:
|
||||
|
||||
* PDFs contain real images (radiology, dermoscopy, pathology, ECGs).
|
||||
Operator must pin a vision-capable model via
|
||||
``setup --provider-model anthropic/claude-sonnet-4.5`` (or similar);
|
||||
the runner emits a warning if a known text-only slug is pinned.
|
||||
* MedXpertQA tags ``medical_task`` (Diagnosis / Treatment / Basic
|
||||
Medicine) and ``body_system`` (Cardiovascular / Lymphatic / …)
|
||||
directly on every row; we slice the report by both.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.arms import ArmRequest, ArmResult, NativePdfArm, SurfSenseArm
|
||||
from ....core.config import utc_iso_timestamp
|
||||
from ....core.ingest_settings import (
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
is_settings_header,
|
||||
)
|
||||
from ....core.metrics.comparison import (
|
||||
bootstrap_delta_ci,
|
||||
mcnemar_test,
|
||||
paired_aggregate,
|
||||
)
|
||||
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci
|
||||
from ....core.providers.openrouter_pdf import OpenRouterPdfProvider, PdfEngine
|
||||
from ....core.registry import (
|
||||
ReportSection,
|
||||
RunArtifact,
|
||||
RunContext,
|
||||
)
|
||||
from ....core.scenarios import format_scenario_md
|
||||
from .prompt import ANSWER_LETTERS, build_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_TEXT_ONLY_HINTS = ("gpt-5.4-mini", "gpt-3.5", "text-only", "instruct-")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MXQuestion:
|
||||
qid: str
|
||||
question: str
|
||||
options: dict[str, str]
|
||||
label: str
|
||||
medical_task: str
|
||||
body_system: str
|
||||
question_type: str
|
||||
split: str
|
||||
n_images: int
|
||||
pdf_path: Path
|
||||
document_id: int | None
|
||||
|
||||
|
||||
def _load_doc_map(map_path: Path) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
|
||||
"""Read the doc map JSONL.
|
||||
|
||||
Returns ``(rows, settings)`` where ``settings`` is the
|
||||
``__settings__`` header blob (or ``{}`` for legacy maps).
|
||||
"""
|
||||
|
||||
rows: dict[str, dict[str, Any]] = {}
|
||||
settings: dict[str, Any] = {}
|
||||
with map_path.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
if is_settings_header(row):
|
||||
settings = dict(row["__settings__"])
|
||||
continue
|
||||
rows[str(row["qid"])] = row
|
||||
return rows, settings
|
||||
|
||||
|
||||
def _load_questions(
|
||||
questions_jsonl: Path,
|
||||
doc_map: dict[str, dict[str, Any]],
|
||||
*,
|
||||
split_filter: str | None,
|
||||
task_filter: str | None,
|
||||
body_filter: str | None,
|
||||
require_images: bool,
|
||||
sample_n: int | None,
|
||||
) -> list[MXQuestion]:
|
||||
out: list[MXQuestion] = []
|
||||
with questions_jsonl.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
qid = str(row.get("qid") or "").strip()
|
||||
if not qid:
|
||||
continue
|
||||
if split_filter and split_filter != "all" and row.get("split") != split_filter:
|
||||
continue
|
||||
if task_filter and task_filter != "all" and row.get("medical_task") != task_filter:
|
||||
continue
|
||||
if body_filter and body_filter != "all" and row.get("body_system") != body_filter:
|
||||
continue
|
||||
map_row = doc_map.get(qid)
|
||||
if map_row is None:
|
||||
logger.debug("No doc-map entry for %s; skipping", qid)
|
||||
continue
|
||||
n_images = int(map_row.get("n_images", 0))
|
||||
if require_images and n_images <= 0:
|
||||
continue
|
||||
out.append(MXQuestion(
|
||||
qid=qid,
|
||||
question=str(row.get("question") or ""),
|
||||
options={str(k).upper(): str(v) for k, v in (row.get("options") or {}).items()},
|
||||
label=str(row.get("label") or "").strip().upper(),
|
||||
medical_task=str(row.get("medical_task") or "").strip(),
|
||||
body_system=str(row.get("body_system") or "").strip(),
|
||||
question_type=str(row.get("question_type") or "").strip(),
|
||||
split=str(row.get("split") or ""),
|
||||
n_images=n_images,
|
||||
pdf_path=Path(map_row["pdf_path"]),
|
||||
document_id=map_row.get("document_id"),
|
||||
))
|
||||
out.sort(key=lambda q: (q.split, q.qid))
|
||||
if sample_n is not None and sample_n > 0:
|
||||
out = out[:sample_n]
|
||||
return out
|
||||
|
||||
|
||||
async def _gather_with_limit(coros: Iterable, *, concurrency: int) -> list[Any]:
|
||||
sem = asyncio.Semaphore(max(1, concurrency))
|
||||
|
||||
async def _wrap(coro):
|
||||
async with sem:
|
||||
return await coro
|
||||
|
||||
return await asyncio.gather(*(_wrap(c) for c in coros))
|
||||
|
||||
|
||||
_DESCRIPTION = (
|
||||
"MedXpertQA-MM (~2,000 multimodal medical exam questions, 5 options, with images) — "
|
||||
"Native PDF (vision) vs SurfSense (vision RAG) head-to-head."
|
||||
)
|
||||
|
||||
# MedXpertQA-MM PDFs embed clinical images; vision LLM at ingest is
|
||||
# the whole point. Operators can flip ``--no-vision-llm`` to measure
|
||||
# how much we degrade without it (likely material).
|
||||
_DEFAULT_INGEST_SETTINGS = IngestSettings(
|
||||
use_vision_llm=True,
|
||||
processing_mode="basic",
|
||||
should_summarize=False,
|
||||
)
|
||||
|
||||
|
||||
class MedXpertQAMMBenchmark:
|
||||
"""Multimodal medical exam head-to-head."""
|
||||
|
||||
suite: str = "medical"
|
||||
name: str = "medxpertqa"
|
||||
headline: bool = True # The medical suite headline.
|
||||
description: str = _DESCRIPTION
|
||||
|
||||
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--split", default="test", choices=["test", "dev", "all"],
|
||||
help="Which MedXpertQA-MM split to run (default: test).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task", default="all",
|
||||
help="Filter by medical_task value (e.g. Diagnosis, Treatment, Basic Medicine).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--body-system", dest="body_filter", default="all",
|
||||
help="Filter by body_system value (e.g. Cardiovascular, Lymphatic).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--require-images", dest="require_images", action="store_true",
|
||||
help="Skip rare MM rows that ended up with zero resolvable images.",
|
||||
)
|
||||
parser.add_argument("--n", dest="sample_n", type=int, default=None,
|
||||
help="Run only the first N questions after filters apply.")
|
||||
parser.add_argument("--concurrency", type=int, default=4,
|
||||
help="Parallel question workers per arm.")
|
||||
parser.add_argument("--no-mentions", dest="no_mentions", action="store_true",
|
||||
help="SurfSense arm: skip mentioned_document_ids (unscoped retrieval).")
|
||||
parser.add_argument(
|
||||
"--pdf-engine", default="native",
|
||||
choices=[e.value for e in PdfEngine],
|
||||
help="OpenRouter file-parser engine for the native arm.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-output-tokens", type=int, default=512,
|
||||
help="Cap on completion length for both arms.",
|
||||
)
|
||||
# Ingest-only knobs (forwarded by the CLI to ingest.run_ingest).
|
||||
parser.add_argument(
|
||||
"--max-questions", dest="max_questions", type=int, default=None,
|
||||
help="(ingest only) cap on number of MM questions to render + upload.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-batch-size", dest="upload_batch_size", type=int, default=8,
|
||||
help="(ingest only) PDFs per fileupload call.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-upload", dest="skip_upload", action="store_true",
|
||||
help="(ingest only) render PDFs locally but don't push to SurfSense.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-dev", dest="include_dev", action="store_true",
|
||||
help="(ingest only) shorthand for --split all.",
|
||||
)
|
||||
# Per-upload knobs forwarded to /documents/fileupload at ingest;
|
||||
# ignored at run-time (runner reads the resolved settings out of
|
||||
# the doc-map manifest header).
|
||||
add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
|
||||
|
||||
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
|
||||
from .ingest import run_ingest
|
||||
|
||||
settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
|
||||
await run_ingest(
|
||||
ctx,
|
||||
split=opts.get("split") or "test",
|
||||
max_questions=opts.get("max_questions"),
|
||||
upload_batch_size=int(opts.get("upload_batch_size") or 8),
|
||||
skip_upload=bool(opts.get("skip_upload", False)),
|
||||
include_dev=bool(opts.get("include_dev", False)),
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
|
||||
split_filter = opts.get("split") or "test"
|
||||
task_filter = opts.get("task") or "all"
|
||||
body_filter = opts.get("body_filter") or "all"
|
||||
require_images = bool(opts.get("require_images"))
|
||||
sample_n = opts.get("sample_n")
|
||||
concurrency = int(opts.get("concurrency") or 4)
|
||||
no_mentions = bool(opts.get("no_mentions"))
|
||||
pdf_engine_name = opts.get("pdf_engine") or "native"
|
||||
max_output_tokens = int(opts.get("max_output_tokens") or 512)
|
||||
|
||||
bench_dir = ctx.benchmark_data_dir()
|
||||
questions_jsonl = bench_dir / "questions.jsonl"
|
||||
map_path = ctx.maps_dir() / "medxpertqa_doc_map.jsonl"
|
||||
if not questions_jsonl.exists() or not map_path.exists():
|
||||
raise RuntimeError(
|
||||
"MedXpertQA-MM not ingested for this suite. Run "
|
||||
"`python -m surfsense_evals ingest medical medxpertqa` first."
|
||||
)
|
||||
|
||||
doc_map, ingest_settings = _load_doc_map(map_path)
|
||||
questions = _load_questions(
|
||||
questions_jsonl, doc_map,
|
||||
split_filter=split_filter,
|
||||
task_filter=task_filter if task_filter != "all" else None,
|
||||
body_filter=body_filter if body_filter != "all" else None,
|
||||
require_images=require_images,
|
||||
sample_n=sample_n,
|
||||
)
|
||||
if not questions:
|
||||
raise RuntimeError(
|
||||
"No MedXpertQA-MM questions matched the filters; broaden --split/--task/--body-system/--n."
|
||||
)
|
||||
logger.info("MedXpertQA-MM: scheduled %d questions", len(questions))
|
||||
|
||||
api_key = os.environ.get("OPENROUTER_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENROUTER_API_KEY env var is required for the native arm.")
|
||||
|
||||
# Native arm slug differs from SurfSense slug only in cost-arbitrage
|
||||
# scenario; otherwise both arms answer with provider_model.
|
||||
native_arm_model = ctx.native_arm_model
|
||||
if any(hint in native_arm_model.lower() for hint in _TEXT_ONLY_HINTS):
|
||||
if ctx.scenario == "symmetric-cheap":
|
||||
logger.info(
|
||||
"symmetric-cheap: native arm pinned to text-only %r as "
|
||||
"intended; expect it to lose on image-bearing questions "
|
||||
"(SurfSense answers from vision-extracted chunks).",
|
||||
native_arm_model,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Native arm slug %r looks text-only; image content in "
|
||||
"MedXpertQA PDFs will be ignored. Re-pin via "
|
||||
"`setup --provider-model anthropic/claude-sonnet-4.5` "
|
||||
"(or pass --native-arm-model and --scenario cost-arbitrage "
|
||||
"to make this asymmetry explicit).",
|
||||
native_arm_model,
|
||||
)
|
||||
|
||||
provider = OpenRouterPdfProvider(
|
||||
api_key=api_key,
|
||||
base_url=ctx.config.openrouter_base_url,
|
||||
model=native_arm_model,
|
||||
engine=PdfEngine(pdf_engine_name),
|
||||
)
|
||||
native_arm = NativePdfArm(provider=provider, max_output_tokens=max_output_tokens)
|
||||
surf_arm = SurfSenseArm(
|
||||
client=ctx.new_chat_client(),
|
||||
search_space_id=ctx.search_space_id,
|
||||
ephemeral_threads=True,
|
||||
)
|
||||
|
||||
run_timestamp = utc_iso_timestamp()
|
||||
run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
|
||||
raw_path = run_dir / "raw.jsonl"
|
||||
|
||||
async def _native_one(q: MXQuestion) -> ArmResult:
|
||||
return await native_arm.answer(_make_native_request(q, max_output_tokens))
|
||||
|
||||
async def _surf_one(q: MXQuestion) -> ArmResult:
|
||||
return await surf_arm.answer(_make_surfsense_request(q, no_mentions=no_mentions))
|
||||
|
||||
native_results, surf_results = await asyncio.gather(
|
||||
_gather_with_limit((_native_one(q) for q in questions), concurrency=concurrency),
|
||||
_gather_with_limit((_surf_one(q) for q in questions), concurrency=concurrency),
|
||||
)
|
||||
|
||||
with raw_path.open("w", encoding="utf-8") as fh:
|
||||
for q, n_res, s_res in zip(questions, native_results, surf_results, strict=False):
|
||||
meta = {
|
||||
"qid": q.qid,
|
||||
"split": q.split,
|
||||
"medical_task": q.medical_task,
|
||||
"body_system": q.body_system,
|
||||
"question_type": q.question_type,
|
||||
"n_images": q.n_images,
|
||||
"correct": q.label,
|
||||
"document_id": q.document_id,
|
||||
}
|
||||
fh.write(json.dumps({**meta, **n_res.to_jsonl()}) + "\n")
|
||||
fh.write(json.dumps({**meta, **s_res.to_jsonl()}) + "\n")
|
||||
|
||||
metrics = _compute_metrics(questions, native_results, surf_results)
|
||||
artifact = RunArtifact(
|
||||
suite=self.suite,
|
||||
benchmark=self.name,
|
||||
run_timestamp=run_timestamp,
|
||||
raw_path=raw_path,
|
||||
metrics=metrics,
|
||||
extra={
|
||||
"n_questions": len(questions),
|
||||
"concurrency": concurrency,
|
||||
"split_filter": split_filter,
|
||||
"task_filter": task_filter,
|
||||
"body_filter": body_filter,
|
||||
"require_images": require_images,
|
||||
"no_mentions": no_mentions,
|
||||
"pdf_engine": pdf_engine_name,
|
||||
"scenario": ctx.scenario,
|
||||
"provider_model": ctx.provider_model,
|
||||
"native_arm_model": native_arm_model,
|
||||
"vision_provider_model": ctx.vision_provider_model,
|
||||
"agent_llm_id": ctx.agent_llm_id,
|
||||
"ingest_settings": ingest_settings,
|
||||
},
|
||||
)
|
||||
|
||||
manifest_path = run_dir / "run_artifact.json"
|
||||
manifest_path.write_text(
|
||||
json.dumps({
|
||||
"suite": self.suite,
|
||||
"benchmark": self.name,
|
||||
"raw_path": "raw.jsonl",
|
||||
"metrics": metrics,
|
||||
"extra": artifact.extra,
|
||||
}, indent=2, sort_keys=True) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return artifact
|
||||
|
||||
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
|
||||
if not artifacts:
|
||||
return ReportSection(
|
||||
title="MedXpertQA-MM — Native PDF (vision) vs SurfSense (vision RAG)",
|
||||
headline=False,
|
||||
body_md="(no run artifacts found)",
|
||||
body_json={},
|
||||
)
|
||||
latest = max(artifacts, key=lambda a: a.run_timestamp)
|
||||
m = latest.metrics
|
||||
native = m.get("native", {})
|
||||
surf = m.get("surfsense", {})
|
||||
delta = m.get("delta", {})
|
||||
per_task = m.get("per_task", {})
|
||||
per_body = m.get("per_body_system", {})
|
||||
extra = latest.extra
|
||||
|
||||
body_lines: list[str] = []
|
||||
body_lines.append(
|
||||
f"- Sample size: {extra.get('n_questions', '?')} questions "
|
||||
f"(split: `{extra.get('split_filter', 'test')}`, "
|
||||
f"task: `{extra.get('task_filter', 'all')}`, "
|
||||
f"body: `{extra.get('body_filter', 'all')}`, "
|
||||
f"engine: `{extra.get('pdf_engine', 'native')}`)."
|
||||
)
|
||||
body_lines.append(format_scenario_md(extra))
|
||||
body_lines.append(format_ingest_settings_md(extra.get("ingest_settings")))
|
||||
body_lines.append(
|
||||
"- Native arm (OpenRouter `chat/completions` + file plugin, "
|
||||
f"`{extra.get('native_arm_model') or extra.get('provider_model', '?')}`):"
|
||||
)
|
||||
body_lines.append(_arm_summary_lines(native, indent=" "))
|
||||
body_lines.append(
|
||||
"- SurfSense arm (`POST /api/v1/new_chat`, vision RAG over chunks, "
|
||||
f"`{extra.get('provider_model', '?')}`):"
|
||||
)
|
||||
body_lines.append(_arm_summary_lines(surf, indent=" "))
|
||||
body_lines.append("- Delta (paired):")
|
||||
body_lines.append(
|
||||
f" - Accuracy: SurfSense {_pp(delta.get('accuracy_pp'))} pp "
|
||||
f"(McNemar p={_fmt(delta.get('mcnemar_p_value'), 4)}, "
|
||||
f"method={delta.get('mcnemar_method')})"
|
||||
)
|
||||
body_lines.append(
|
||||
f" - Bootstrap 95% CI on delta: "
|
||||
f"[{_pp(delta.get('bootstrap_ci_low'))}pp, {_pp(delta.get('bootstrap_ci_high'))}pp]"
|
||||
)
|
||||
body_lines.append(
|
||||
f" - Cost / question: native ${_dollars(native.get('cost_micros_mean'))}, "
|
||||
f"surfsense ${_dollars(surf.get('cost_micros_mean'))} "
|
||||
f"(SurfSense delta {_pct_change(delta.get('cost_micros_pct'))})"
|
||||
)
|
||||
body_lines.append(
|
||||
f" - Latency p50: native {_ms_to_s(native.get('latency_ms_median'))}, "
|
||||
f"surfsense {_ms_to_s(surf.get('latency_ms_median'))} "
|
||||
f"(SurfSense delta {_pct_change(delta.get('latency_ms_pct'))})"
|
||||
)
|
||||
if per_task:
|
||||
body_lines.append("- Per-medical_task split:")
|
||||
for task_name, vals in sorted(per_task.items()):
|
||||
body_lines.append(
|
||||
f" - {task_name}: SurfSense {_pp(vals.get('delta_accuracy_pp'))} pp "
|
||||
f"(n={vals.get('n')})"
|
||||
)
|
||||
if per_body:
|
||||
body_lines.append("- Per-body_system split (top 5 by sample size):")
|
||||
top = sorted(per_body.items(), key=lambda kv: -kv[1].get("n", 0))[:5]
|
||||
for body_name, vals in top:
|
||||
body_lines.append(
|
||||
f" - {body_name}: SurfSense {_pp(vals.get('delta_accuracy_pp'))} pp "
|
||||
f"(n={vals.get('n')})"
|
||||
)
|
||||
|
||||
return ReportSection(
|
||||
title="MedXpertQA-MM — Native PDF (vision) vs SurfSense (vision RAG)",
|
||||
headline=False,
|
||||
body_md="\n".join(body_lines),
|
||||
body_json=m,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-question helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_native_request(q: MXQuestion, max_tokens: int) -> ArmRequest:
|
||||
prompt = build_prompt(q.question, q.options)
|
||||
return ArmRequest(
|
||||
question_id=q.qid,
|
||||
prompt=prompt,
|
||||
pdf_paths=[q.pdf_path],
|
||||
options={"max_tokens": max_tokens},
|
||||
)
|
||||
|
||||
|
||||
def _make_surfsense_request(q: MXQuestion, *, no_mentions: bool) -> ArmRequest:
|
||||
prompt = build_prompt(q.question, q.options)
|
||||
mentions: list[int] | None = None
|
||||
if not no_mentions and q.document_id is not None:
|
||||
mentions = [int(q.document_id)]
|
||||
return ArmRequest(
|
||||
question_id=q.qid,
|
||||
prompt=prompt,
|
||||
mentioned_document_ids=mentions,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_metrics(
|
||||
questions: list[MXQuestion],
|
||||
native_results: list[ArmResult],
|
||||
surf_results: list[ArmResult],
|
||||
) -> dict[str, Any]:
|
||||
native_correct: list[bool] = []
|
||||
surf_correct: list[bool] = []
|
||||
for q, n_res, s_res in zip(questions, native_results, surf_results, strict=False):
|
||||
gold = q.label
|
||||
n_ok = (n_res.answer_letter or "").upper() == gold and gold in ANSWER_LETTERS
|
||||
s_ok = (s_res.answer_letter or "").upper() == gold and gold in ANSWER_LETTERS
|
||||
native_correct.append(n_ok)
|
||||
surf_correct.append(s_ok)
|
||||
|
||||
native_costs = [float(r.cost_micros) for r in native_results]
|
||||
surf_costs = [float(r.cost_micros) for r in surf_results]
|
||||
native_lats = [float(r.latency_ms) for r in native_results]
|
||||
surf_lats = [float(r.latency_ms) for r in surf_results]
|
||||
native_in = [float(r.input_tokens) for r in native_results]
|
||||
native_out = [float(r.output_tokens) for r in native_results]
|
||||
|
||||
native_acc = accuracy_with_wilson_ci(sum(native_correct), len(native_correct))
|
||||
surf_acc = accuracy_with_wilson_ci(sum(surf_correct), len(surf_correct))
|
||||
mc = mcnemar_test(native_correct, surf_correct)
|
||||
boot = bootstrap_delta_ci(native_correct, surf_correct, n_resamples=2000)
|
||||
|
||||
native_cost_agg = paired_aggregate(native_costs)
|
||||
surf_cost_agg = paired_aggregate(surf_costs)
|
||||
native_lat_agg = paired_aggregate(native_lats)
|
||||
surf_lat_agg = paired_aggregate(surf_lats)
|
||||
|
||||
cost_pct = _safe_pct(surf_cost_agg.mean, native_cost_agg.mean)
|
||||
lat_pct = _safe_pct(surf_lat_agg.median, native_lat_agg.median)
|
||||
|
||||
per_task = _per_field(questions, native_correct, surf_correct, key=lambda q: q.medical_task or "unknown")
|
||||
per_body = _per_field(questions, native_correct, surf_correct, key=lambda q: q.body_system or "unknown")
|
||||
|
||||
return {
|
||||
"native": {
|
||||
**native_acc.to_dict(),
|
||||
"cost_micros_mean": native_cost_agg.mean,
|
||||
"cost_micros_median": native_cost_agg.median,
|
||||
"latency_ms_mean": native_lat_agg.mean,
|
||||
"latency_ms_median": native_lat_agg.median,
|
||||
"latency_ms_p95": native_lat_agg.p95,
|
||||
"input_tokens_mean": (sum(native_in) / len(native_in)) if native_in else 0.0,
|
||||
"output_tokens_mean": (sum(native_out) / len(native_out)) if native_out else 0.0,
|
||||
},
|
||||
"surfsense": {
|
||||
**surf_acc.to_dict(),
|
||||
"cost_micros_mean": surf_cost_agg.mean,
|
||||
"cost_micros_median": surf_cost_agg.median,
|
||||
"latency_ms_mean": surf_lat_agg.mean,
|
||||
"latency_ms_median": surf_lat_agg.median,
|
||||
"latency_ms_p95": surf_lat_agg.p95,
|
||||
},
|
||||
"delta": {
|
||||
"accuracy_pp": 100.0 * (surf_acc.accuracy - native_acc.accuracy),
|
||||
"mcnemar_p_value": mc.p_value,
|
||||
"mcnemar_method": mc.method,
|
||||
"mcnemar_b_native_only": mc.b,
|
||||
"mcnemar_c_surfsense_only": mc.c,
|
||||
"bootstrap_ci_low": 100.0 * boot.ci_low,
|
||||
"bootstrap_ci_high": 100.0 * boot.ci_high,
|
||||
"cost_micros_pct": cost_pct,
|
||||
"latency_ms_pct": lat_pct,
|
||||
},
|
||||
"per_task": per_task,
|
||||
"per_body_system": per_body,
|
||||
}
|
||||
|
||||
|
||||
def _per_field(
|
||||
questions: list[MXQuestion],
|
||||
native_correct: list[bool],
|
||||
surf_correct: list[bool],
|
||||
*,
|
||||
key,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
bucket: dict[str, list[tuple[bool, bool]]] = {}
|
||||
for q, n_ok, s_ok in zip(questions, native_correct, surf_correct, strict=False):
|
||||
bucket.setdefault(key(q), []).append((n_ok, s_ok))
|
||||
out: dict[str, dict[str, Any]] = {}
|
||||
for k, pairs in bucket.items():
|
||||
n_correct = [a for a, _ in pairs]
|
||||
s_correct = [b for _, b in pairs]
|
||||
out[k] = {
|
||||
"n": len(pairs),
|
||||
"native_accuracy": (sum(n_correct) / len(pairs)) if pairs else 0.0,
|
||||
"surfsense_accuracy": (sum(s_correct) / len(pairs)) if pairs else 0.0,
|
||||
"delta_accuracy_pp": (
|
||||
100.0 * (sum(s_correct) - sum(n_correct)) / len(pairs)
|
||||
if pairs else 0.0
|
||||
),
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def _safe_pct(numerator: float, denominator: float) -> float | None:
|
||||
if denominator == 0:
|
||||
return None
|
||||
return 100.0 * (numerator - denominator) / denominator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Formatters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _arm_summary_lines(d: dict[str, Any], *, indent: str) -> str:
|
||||
if not d:
|
||||
return f"{indent}(no data)"
|
||||
acc = d.get("accuracy", 0.0)
|
||||
low = d.get("ci_low", 0.0)
|
||||
high = d.get("ci_high", 0.0)
|
||||
lines = [
|
||||
f"{indent}- Accuracy: {acc * 100:.1f}% (Wilson 95% CI: {low * 100:.1f}% – {high * 100:.1f}%)",
|
||||
f"{indent}- Cost / question: ${_dollars(d.get('cost_micros_mean'))} (mean), "
|
||||
f"${_dollars(d.get('cost_micros_median'))} (median)",
|
||||
f"{indent}- Latency: p50 {_ms_to_s(d.get('latency_ms_median'))}, "
|
||||
f"p95 {_ms_to_s(d.get('latency_ms_p95'))}",
|
||||
]
|
||||
if "input_tokens_mean" in d:
|
||||
lines.append(
|
||||
f"{indent}- Mean tokens / question: in {d.get('input_tokens_mean', 0):.0f}, "
|
||||
f"out {d.get('output_tokens_mean', 0):.0f}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _dollars(micros: Any) -> str:
|
||||
if micros is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{(float(micros) / 1_000_000):.4f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _ms_to_s(ms: Any) -> str:
|
||||
if ms is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(ms) / 1000:.1f}s"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _pp(value: Any) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):+.1f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _pct_change(value: Any) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):+.0f}%"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _fmt(value: Any, ndigits: int) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):.{ndigits}f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
__all__ = ["MedXpertQAMMBenchmark", "MXQuestion"]
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
"""MIRAGE — secondary single-arm SurfSense MCQ measurement.
|
||||
|
||||
Source: https://github.com/Teddy-XiongGZ/MIRAGE, paper
|
||||
https://aclanthology.org/2024.findings-acl.372/. 7,663 questions
|
||||
across MMLU-Med, MedQA-US, MedMCQA, PubMedQA*, BioASQ-Y/N.
|
||||
|
||||
This is a SurfSense-only measurement (not a head-to-head); native
|
||||
PDF-in-LLM doesn't apply because there is no per-question discrete
|
||||
document — the corpus is millions of biomedical snippets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .runner import MirageBenchmark
|
||||
from ....core import registry as _registry
|
||||
|
||||
_registry.register(MirageBenchmark())
|
||||
|
|
@ -0,0 +1,548 @@
|
|||
"""MIRAGE ingestion.
|
||||
|
||||
Downloads:
|
||||
|
||||
* ``benchmark.json`` (≈ 4 MB; questions for the 5 sub-tasks).
|
||||
* ``retrieved_snippets_10k.zip`` (the union of top-10k snippet ids
|
||||
retrieved by every retriever in the MedRAG paper, per task — a
|
||||
recall ceiling that avoids needing the full 23.9M-doc PubMed mirror).
|
||||
|
||||
Snippet *content* lives in the MedRAG HF mirrors
|
||||
(``MedRAG/textbooks``, ``MedRAG/pubmed``, ``MedRAG/statpearls``,
|
||||
``MedRAG/wikipedia``). We default to ``MedRAG/textbooks`` (212 MB,
|
||||
125k snippets) which is the smallest and covers the majority of
|
||||
``MedQA-US`` and the medical examination subsets. Operators can
|
||||
opt into larger corpora with ``--corpus``.
|
||||
|
||||
Each snippet is written as one markdown file then batched into
|
||||
``~5 MB`` markdown bundles for SurfSense's file upload (smaller
|
||||
than backend default ``MAX_FILE_SIZE_BYTES`` and avoids the per-call
|
||||
overhead of one HTTP request per snippet).
|
||||
|
||||
The ingestion produces two maps under ``data/medical/maps/``:
|
||||
|
||||
* ``mirage_snippet_map.jsonl`` — ``{snippet_id, document_id, batch_path}``
|
||||
* ``mirage_chunk_map.jsonl`` — ``{chunk_id, document_id, snippet_id?}``
|
||||
(best-effort; chunk text is heuristically attributed to the
|
||||
snippet it overlaps when the SurfSense chunker splits a batched
|
||||
markdown).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MIRAGE_BENCHMARK_URL = (
|
||||
"https://raw.githubusercontent.com/Teddy-XiongGZ/MIRAGE/main/benchmark.json"
|
||||
)
|
||||
# Upstream only ships ONE zip — top-10k retrievals across 5 retrievers,
|
||||
# ~16 GB. We default to skipping it (see `--skip-snippet-filter`) and
|
||||
# ingesting the chosen corpus in full; this URL is only fetched when
|
||||
# the operator explicitly opts in.
|
||||
MIRAGE_SNIPPETS_ZIP_URL = (
|
||||
"https://virginia.box.com/shared/static/cxq17th6eisl2pn04vp0x723zczlvlzc.zip"
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_CORPUS = "MedRAG/textbooks"
|
||||
_BATCH_SIZE_BYTES = 5 * 1024 * 1024
|
||||
# 2 GB safety cap. Anything larger requires --allow-large-download.
|
||||
# Set high enough that ``benchmark.json`` and small zips pass through
|
||||
# untouched but the 16 GB MIRAGE retrievals zip trips the guard.
|
||||
_LARGE_DOWNLOAD_BYTES = 2 * 1024 * 1024 * 1024
|
||||
_DOWNLOAD_RETRIES = 5
|
||||
_RETRYABLE_NET_EXC: tuple[type[BaseException], ...] = (
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.ReadError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnippetRow:
|
||||
snippet_id: str
|
||||
title: str
|
||||
content: str
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
title = (self.title or "").strip() or "Untitled"
|
||||
body = (self.content or "").strip()
|
||||
return f"# {title}\n\n_id: `{self.snippet_id}`_\n\n{body}\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Download helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _fetch_to_path(
|
||||
url: str,
|
||||
*,
|
||||
dest: Path,
|
||||
label: str,
|
||||
timeout_s: float = 600.0,
|
||||
allow_large_download: bool = False,
|
||||
expect_zip: bool = False,
|
||||
) -> Path:
|
||||
"""Download ``url`` to ``dest`` with retry, atomic-rename, and
|
||||
HTTP ``Range`` resume.
|
||||
|
||||
Operational properties:
|
||||
|
||||
* If ``dest`` already exists *and* (when ``expect_zip`` is True) the
|
||||
cached file is a valid ZIP, returns it immediately. A corrupt ZIP
|
||||
is removed and re-downloaded — this is the safety net for the
|
||||
`box.com truncated 16 GB zip` failure mode where the previous
|
||||
run wrote a half-completed file then exited with an exception.
|
||||
* Bytes are written to ``<dest>.partial`` and renamed only after the
|
||||
stream completes cleanly (and, for zips, only after a quick
|
||||
central-directory check). A failure mid-download leaves the
|
||||
``.partial`` file in place so the next attempt can resume from
|
||||
where it stopped via an HTTP ``Range`` header.
|
||||
* Retries on transient network errors (``RemoteProtocolError``,
|
||||
``ReadError``, ``ReadTimeout``, ``ConnectError``,
|
||||
``ConnectTimeout``) with exponential backoff, up to
|
||||
``_DOWNLOAD_RETRIES``.
|
||||
* Aborts before downloading if the ``Content-Length`` (or already-
|
||||
downloaded ``.partial`` size) is over ``_LARGE_DOWNLOAD_BYTES``
|
||||
and ``allow_large_download`` is False, to keep an operator from
|
||||
surprise-grabbing 16 GB on a slow link.
|
||||
"""
|
||||
|
||||
if dest.exists():
|
||||
if expect_zip and not _is_valid_zip(dest):
|
||||
logger.warning(
|
||||
"Cached %s at %s failed ZIP validation (size=%d B); deleting "
|
||||
"and re-downloading.",
|
||||
label,
|
||||
dest,
|
||||
dest.stat().st_size,
|
||||
)
|
||||
dest.unlink(missing_ok=True)
|
||||
else:
|
||||
logger.info("Using cached %s at %s", label, dest)
|
||||
return dest
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
partial = dest.with_suffix(dest.suffix + ".partial")
|
||||
last_exc: BaseException | None = None
|
||||
|
||||
for attempt in range(1, _DOWNLOAD_RETRIES + 1):
|
||||
existing_bytes = partial.stat().st_size if partial.exists() else 0
|
||||
headers: dict[str, str] = {}
|
||||
if existing_bytes:
|
||||
headers["Range"] = f"bytes={existing_bytes}-"
|
||||
logger.info(
|
||||
"Resuming %s from byte %d (attempt %d/%d)",
|
||||
label,
|
||||
existing_bytes,
|
||||
attempt,
|
||||
_DOWNLOAD_RETRIES,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Downloading %s from %s (attempt %d/%d)",
|
||||
label,
|
||||
url,
|
||||
attempt,
|
||||
_DOWNLOAD_RETRIES,
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(timeout_s, connect=20.0),
|
||||
follow_redirects=True,
|
||||
) as client:
|
||||
async with client.stream("GET", url, headers=headers) as response:
|
||||
if existing_bytes and response.status_code == 200:
|
||||
logger.warning(
|
||||
"Server ignored Range header for %s; restarting from 0.",
|
||||
label,
|
||||
)
|
||||
partial.unlink(missing_ok=True)
|
||||
existing_bytes = 0
|
||||
elif response.status_code == 416:
|
||||
# Range not satisfiable — the .partial is at or
|
||||
# past the end. Treat as "already downloaded";
|
||||
# validate by closing and re-opening for atomic
|
||||
# rename below.
|
||||
logger.info(
|
||||
"Server reports %s already complete (HTTP 416).",
|
||||
label,
|
||||
)
|
||||
elif response.status_code not in (200, 206):
|
||||
response.raise_for_status()
|
||||
|
||||
total_size = _planned_total_size(response, existing_bytes)
|
||||
if (
|
||||
total_size is not None
|
||||
and total_size > _LARGE_DOWNLOAD_BYTES
|
||||
and not allow_large_download
|
||||
):
|
||||
raise _LargeDownloadAbort(label, total_size)
|
||||
|
||||
mode = "ab" if existing_bytes else "wb"
|
||||
with partial.open(mode) as fh:
|
||||
async for chunk in response.aiter_bytes(chunk_size=1 << 18):
|
||||
fh.write(chunk)
|
||||
# Optional content sanity check before promoting to dest.
|
||||
if expect_zip and not _is_valid_zip(partial):
|
||||
raise zipfile.BadZipFile(
|
||||
f"{label} downloaded to {partial} but failed central-"
|
||||
"directory check; will retry."
|
||||
)
|
||||
partial.replace(dest)
|
||||
return dest
|
||||
except _LargeDownloadAbort:
|
||||
raise
|
||||
except _RETRYABLE_NET_EXC as exc:
|
||||
last_exc = exc
|
||||
wait = min(60.0, 2.0 ** attempt)
|
||||
logger.warning(
|
||||
"Network error fetching %s (%s: %s); retrying in %.0fs.",
|
||||
label,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
wait,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
except zipfile.BadZipFile as exc:
|
||||
last_exc = exc
|
||||
# Truncated body — drop the partial and retry from scratch.
|
||||
partial.unlink(missing_ok=True)
|
||||
wait = min(60.0, 2.0 ** attempt)
|
||||
logger.warning(
|
||||
"Truncated ZIP for %s; restarting from byte 0 in %.0fs.",
|
||||
label,
|
||||
wait,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Failed to download {label} after {_DOWNLOAD_RETRIES} attempts: {last_exc!s}"
|
||||
)
|
||||
|
||||
|
||||
def _planned_total_size(response: httpx.Response, existing_bytes: int) -> int | None:
|
||||
"""Best-effort total size including any already-buffered .partial bytes."""
|
||||
|
||||
cl = response.headers.get("Content-Length")
|
||||
if not cl:
|
||||
return None
|
||||
try:
|
||||
remaining = int(cl)
|
||||
except ValueError:
|
||||
return None
|
||||
return existing_bytes + remaining
|
||||
|
||||
|
||||
def _is_valid_zip(path: Path) -> bool:
|
||||
"""Cheap ZIP validity check via central-directory parse."""
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(path) as zf:
|
||||
# ``namelist`` forces the central directory to be parsed.
|
||||
zf.namelist()
|
||||
return True
|
||||
except (zipfile.BadZipFile, OSError):
|
||||
return False
|
||||
|
||||
|
||||
class _LargeDownloadAbort(RuntimeError):
|
||||
"""Raised when a download exceeds the safety threshold without opt-in."""
|
||||
|
||||
def __init__(self, label: str, size_bytes: int) -> None:
|
||||
gb = size_bytes / (1024 ** 3)
|
||||
super().__init__(
|
||||
f"{label} would download ~{gb:.1f} GB, above the {_LARGE_DOWNLOAD_BYTES / (1024 ** 3):.0f} GB safety cap. "
|
||||
"Re-run with `--allow-large-download` to acknowledge, or use "
|
||||
"`--skip-snippet-filter` to bypass this download entirely and "
|
||||
"ingest the full corpus instead."
|
||||
)
|
||||
|
||||
|
||||
def _read_snippet_ids(zip_path: Path, *, tasks: list[str]) -> dict[str, set[str]]:
|
||||
"""Walk the ZIP for files whose path contains any task name.
|
||||
|
||||
Each MedRAG retriever produces one JSON file per task in the zip;
|
||||
we union all retrievers' top-K ids. The exact directory layout has
|
||||
historically been ``<retriever>/<task>.json`` mapping
|
||||
``question_id -> [snippet_id, ...]``.
|
||||
"""
|
||||
|
||||
out: dict[str, set[str]] = {t: set() for t in tasks}
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
for member in zf.namelist():
|
||||
if not member.lower().endswith(".json"):
|
||||
continue
|
||||
stem = Path(member).stem.lower()
|
||||
for task in tasks:
|
||||
if task.lower() in stem:
|
||||
try:
|
||||
with zf.open(member) as fh:
|
||||
payload = json.loads(fh.read().decode("utf-8"))
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
for ids in payload.values():
|
||||
if isinstance(ids, list):
|
||||
for sid in ids:
|
||||
if isinstance(sid, str):
|
||||
out[task].add(sid)
|
||||
elif isinstance(sid, dict) and "id" in sid:
|
||||
out[task].add(str(sid["id"]))
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def _load_corpus(
|
||||
corpus_name: str, snippet_ids: set[str] | None
|
||||
) -> Iterable[SnippetRow]:
|
||||
"""Stream rows from a MedRAG HF corpus.
|
||||
|
||||
* ``snippet_ids=None`` → yield every row (full-corpus ingestion path).
|
||||
* ``snippet_ids={...}`` → filter to the requested ids.
|
||||
|
||||
Imported lazily — ``datasets`` is a heavyweight dep.
|
||||
"""
|
||||
|
||||
if snippet_ids is not None and not snippet_ids:
|
||||
return iter(())
|
||||
from datasets import load_dataset # noqa: PLC0415
|
||||
|
||||
logger.info("Loading corpus %s (this may take a while)", corpus_name)
|
||||
ds = load_dataset(corpus_name, split="train", streaming=True)
|
||||
for row in ds:
|
||||
sid = str(row.get("id") or "")
|
||||
if snippet_ids is not None and sid not in snippet_ids:
|
||||
continue
|
||||
yield SnippetRow(
|
||||
snippet_id=sid,
|
||||
title=str(row.get("title") or ""),
|
||||
content=str(row.get("content") or row.get("contents") or ""),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batching + upload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnippetBatch:
|
||||
path: Path
|
||||
snippet_ids: list[str]
|
||||
|
||||
|
||||
def _write_batches(
|
||||
snippets: Iterable[SnippetRow],
|
||||
*,
|
||||
out_dir: Path,
|
||||
batch_bytes: int = _BATCH_SIZE_BYTES,
|
||||
prefix: str = "mirage",
|
||||
) -> list[SnippetBatch]:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
batches: list[SnippetBatch] = []
|
||||
current_buffer = io.StringIO()
|
||||
current_ids: list[str] = []
|
||||
current_bytes = 0
|
||||
batch_idx = 0
|
||||
|
||||
def _flush() -> None:
|
||||
nonlocal current_buffer, current_ids, current_bytes, batch_idx
|
||||
if not current_ids:
|
||||
return
|
||||
path = out_dir / f"{prefix}_{batch_idx:04d}.md"
|
||||
path.write_text(current_buffer.getvalue(), encoding="utf-8")
|
||||
batches.append(SnippetBatch(path=path, snippet_ids=current_ids))
|
||||
batch_idx += 1
|
||||
current_buffer = io.StringIO()
|
||||
current_ids = []
|
||||
current_bytes = 0
|
||||
|
||||
for snippet in snippets:
|
||||
chunk = snippet.to_markdown() + "\n---\n\n"
|
||||
chunk_bytes = len(chunk.encode("utf-8"))
|
||||
if current_bytes + chunk_bytes > batch_bytes and current_ids:
|
||||
_flush()
|
||||
current_buffer.write(chunk)
|
||||
current_ids.append(snippet.snippet_id)
|
||||
current_bytes += chunk_bytes
|
||||
_flush()
|
||||
return batches
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_ingest(
|
||||
ctx: RunContext,
|
||||
*,
|
||||
tasks: list[str] | None = None,
|
||||
corpus: str = _DEFAULT_CORPUS,
|
||||
max_snippets_per_task: int | None = None,
|
||||
skip_snippet_filter: bool = True,
|
||||
allow_large_download: bool = False,
|
||||
settings: IngestSettings | None = None,
|
||||
) -> None:
|
||||
"""Ingest a MedRAG corpus into the suite SearchSpace.
|
||||
|
||||
By default (``skip_snippet_filter=True``) we ingest the **entire**
|
||||
chosen corpus and let SurfSense's own retriever do the work. The
|
||||
upstream MIRAGE retrieval zip is ~16 GB and only useful when you
|
||||
want to pre-filter the corpus to the set of snippets some other
|
||||
retriever surfaced; for ``MedRAG/textbooks`` (212 MB / 125k snippets)
|
||||
that pre-filter is unnecessary overhead and routinely fails to
|
||||
download (box.com truncates the stream). Set
|
||||
``skip_snippet_filter=False`` (CLI: ``--use-snippet-filter``) only
|
||||
if you specifically want the upstream filter — and budget the
|
||||
16 GB zip transfer.
|
||||
"""
|
||||
|
||||
tasks = tasks or ["mmlu", "medqa", "medmcqa", "pubmedqa", "bioasq"]
|
||||
settings = settings or IngestSettings(use_vision_llm=False, processing_mode="basic")
|
||||
|
||||
bench_path = ctx.benchmark_data_dir() / "benchmark.json"
|
||||
await _fetch_to_path(MIRAGE_BENCHMARK_URL, dest=bench_path, label="MIRAGE benchmark.json")
|
||||
|
||||
if skip_snippet_filter:
|
||||
logger.info(
|
||||
"Skipping retrieved_snippets_10k.zip (skip_snippet_filter=True); "
|
||||
"ingesting entire corpus %s.",
|
||||
corpus,
|
||||
)
|
||||
snippets = list(_load_corpus(corpus, snippet_ids=None))
|
||||
else:
|
||||
zip_path = ctx.benchmark_data_dir() / "retrieved_snippets_10k.zip"
|
||||
await _fetch_to_path(
|
||||
MIRAGE_SNIPPETS_ZIP_URL,
|
||||
dest=zip_path,
|
||||
label="MIRAGE retrieved_snippets_10k.zip",
|
||||
allow_large_download=allow_large_download,
|
||||
expect_zip=True,
|
||||
)
|
||||
|
||||
by_task = _read_snippet_ids(zip_path, tasks=tasks)
|
||||
if max_snippets_per_task is not None:
|
||||
by_task = {k: set(list(v)[:max_snippets_per_task]) for k, v in by_task.items()}
|
||||
|
||||
union_ids = set().union(*by_task.values())
|
||||
logger.info(
|
||||
"MIRAGE: tasks=%s, snippet ids per task: %s, union=%d",
|
||||
tasks,
|
||||
{k: len(v) for k, v in by_task.items()},
|
||||
len(union_ids),
|
||||
)
|
||||
if not union_ids:
|
||||
raise RuntimeError(
|
||||
f"No snippet ids parsed for tasks {tasks!r} from {zip_path}. "
|
||||
"Check the zip layout (the upstream archive may have changed)."
|
||||
)
|
||||
|
||||
snippets = list(_load_corpus(corpus, snippet_ids=union_ids))
|
||||
logger.info(
|
||||
"Loaded %d / %d requested snippets from corpus %s",
|
||||
len(snippets),
|
||||
len(union_ids),
|
||||
corpus,
|
||||
)
|
||||
if not snippets:
|
||||
raise RuntimeError(
|
||||
f"Corpus {corpus} returned 0 matching rows. Either the snippet "
|
||||
"ids reference a different corpus (e.g. PubMed) or the HF mirror "
|
||||
"is unavailable. Pass --corpus to override."
|
||||
)
|
||||
|
||||
batches_dir = ctx.benchmark_data_dir() / "batches"
|
||||
batches = _write_batches(snippets, out_dir=batches_dir)
|
||||
logger.info("Wrote %d snippet batches to %s", len(batches), batches_dir)
|
||||
|
||||
docs_client = ctx.documents_client()
|
||||
upload_result = await docs_client.upload(
|
||||
files=[b.path for b in batches],
|
||||
search_space_id=ctx.search_space_id,
|
||||
should_summarize=settings.should_summarize,
|
||||
use_vision_llm=settings.use_vision_llm,
|
||||
processing_mode=settings.processing_mode,
|
||||
)
|
||||
logger.info("MIRAGE upload settings: %s", settings.render_label())
|
||||
new_doc_ids = list(upload_result.document_ids)
|
||||
if new_doc_ids:
|
||||
await docs_client.wait_until_ready(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=new_doc_ids,
|
||||
timeout_s=3600.0,
|
||||
max_poll_s=15.0,
|
||||
)
|
||||
|
||||
statuses = await docs_client.get_status(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=new_doc_ids + upload_result.duplicate_document_ids,
|
||||
)
|
||||
title_to_doc = {s.title: s.document_id for s in statuses}
|
||||
|
||||
snippet_map_path = ctx.maps_dir() / "mirage_snippet_map.jsonl"
|
||||
chunk_map_path = ctx.maps_dir() / "mirage_chunk_map.jsonl"
|
||||
with snippet_map_path.open("w", encoding="utf-8") as fh:
|
||||
# Header line records the ingest-time settings (see
|
||||
# core/ingest_settings.py for the protocol).
|
||||
fh.write(settings_header_line(settings) + "\n")
|
||||
for batch in batches:
|
||||
doc_id = title_to_doc.get(batch.path.name)
|
||||
if doc_id is None:
|
||||
logger.warning("No document_id for batch %s", batch.path.name)
|
||||
continue
|
||||
for sid in batch.snippet_ids:
|
||||
fh.write(
|
||||
json.dumps(
|
||||
{
|
||||
"snippet_id": sid,
|
||||
"document_id": doc_id,
|
||||
"batch_path": str(batch.path),
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
# Best-effort chunk map. SurfSense doesn't expose snippet attribution
|
||||
# per chunk, so we just record (chunk_id -> document_id) here; the
|
||||
# MIRAGE runner only needs document_id for accuracy scoring.
|
||||
with chunk_map_path.open("w", encoding="utf-8") as fh:
|
||||
for doc_id in {b.path.name and title_to_doc.get(b.path.name) for b in batches} - {None}:
|
||||
try:
|
||||
chunks = await docs_client.list_chunks(int(doc_id))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Failed to list chunks for doc_id=%s: %s", doc_id, exc)
|
||||
continue
|
||||
for chunk in chunks:
|
||||
fh.write(
|
||||
json.dumps({"chunk_id": chunk.id, "document_id": doc_id})
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
new_state = ctx.suite_state
|
||||
new_state.ingestion_maps["mirage"] = str(snippet_map_path)
|
||||
set_suite_state(ctx.config, ctx.suite, new_state)
|
||||
logger.info("Wrote MIRAGE maps to %s and %s", snippet_map_path, chunk_map_path)
|
||||
|
||||
|
||||
__all__ = ["run_ingest", "SnippetRow", "SnippetBatch"]
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
"""MedRAG ``{step_by_step_thinking, answer_choice}`` MCQ prompt.
|
||||
|
||||
Mirrors the MedRAG paper's prompt format so accuracy numbers are
|
||||
comparable to the published MIRAGE leaderboard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
_PROMPT_TEMPLATE = """\
|
||||
You are a helpful medical expert. Answer the following multiple-choice
|
||||
question using the relevant medical knowledge available to you (and any
|
||||
retrieved context, if provided).
|
||||
|
||||
Respond with a JSON object on a single line:
|
||||
{{"step_by_step_thinking": "<your reasoning>", "answer_choice": "<letter>"}}
|
||||
|
||||
Question: {question}
|
||||
|
||||
Options:
|
||||
{options_block}
|
||||
"""
|
||||
|
||||
|
||||
def _options_block(options: Mapping[str, str]) -> str:
|
||||
parts: list[str] = []
|
||||
for letter in sorted(options.keys()):
|
||||
text = options.get(letter)
|
||||
if text is None or text == "":
|
||||
continue
|
||||
parts.append(f"{letter}) {text}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def build_prompt(question: str, options: Mapping[str, str]) -> str:
|
||||
return _PROMPT_TEMPLATE.format(
|
||||
question=question.strip(),
|
||||
options_block=_options_block(options),
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["build_prompt"]
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
"""MIRAGE runner: SurfSense-only per-task accuracy.
|
||||
|
||||
The benchmark file format is one top-level dict per task (``mmlu``,
|
||||
``medqa``, ``medmcqa``, ``pubmedqa``, ``bioasq``); each task value is
|
||||
``{question_id: {question, options, answer}}``.
|
||||
|
||||
We restrict retrieval to the suite SearchSpace's full corpus (no
|
||||
``mentioned_document_ids`` — MIRAGE has no per-question ground-truth
|
||||
document; retrieval *is* the test). Accuracy is paired against the
|
||||
``answer`` letter from the dataset.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from ....core.arms import ArmRequest, ArmResult, SurfSenseArm
|
||||
from ....core.config import utc_iso_timestamp
|
||||
from ....core.ingest_settings import (
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
read_settings_header,
|
||||
)
|
||||
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci, macro_accuracy
|
||||
from ....core.registry import (
|
||||
Benchmark,
|
||||
ReportSection,
|
||||
RunArtifact,
|
||||
RunContext,
|
||||
)
|
||||
from .prompt import build_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_TASKS = ("mmlu", "medqa", "medmcqa", "pubmedqa", "bioasq")
|
||||
_DESCRIPTION = "MIRAGE (7,663 medical MCQs) — single-arm SurfSense per-task accuracy."
|
||||
|
||||
# MIRAGE corpus is text-only (textbook + abstract markdown). Vision
|
||||
# LLM at ingest is wasted compute by default; flip ``--use-vision-llm``
|
||||
# to measure cost.
|
||||
_DEFAULT_INGEST_SETTINGS = IngestSettings(
|
||||
use_vision_llm=False,
|
||||
processing_mode="basic",
|
||||
should_summarize=False,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MirageQuestion:
|
||||
task: str
|
||||
qid: str
|
||||
question: str
|
||||
options: dict[str, str]
|
||||
correct: str
|
||||
|
||||
@property
|
||||
def question_id(self) -> str:
|
||||
return f"{self.task}::{self.qid}"
|
||||
|
||||
|
||||
def _load_questions(
|
||||
benchmark: dict[str, Any],
|
||||
*,
|
||||
tasks: list[str],
|
||||
sample_n: int | None,
|
||||
) -> list[MirageQuestion]:
|
||||
out: list[MirageQuestion] = []
|
||||
for task in tasks:
|
||||
rows = benchmark.get(task) or {}
|
||||
if not isinstance(rows, dict):
|
||||
continue
|
||||
for qid, raw in rows.items():
|
||||
if not isinstance(raw, dict):
|
||||
continue
|
||||
options = raw.get("options") or {}
|
||||
if not isinstance(options, dict):
|
||||
continue
|
||||
answer_raw = str(raw.get("answer") or "").strip()
|
||||
if not answer_raw:
|
||||
continue
|
||||
answer_letter = answer_raw[:1].upper()
|
||||
out.append(
|
||||
MirageQuestion(
|
||||
task=task,
|
||||
qid=str(qid),
|
||||
question=str(raw.get("question", "")),
|
||||
options={str(k): str(v) for k, v in options.items() if v},
|
||||
correct=answer_letter,
|
||||
)
|
||||
)
|
||||
out.sort(key=lambda q: (q.task, q.qid))
|
||||
if sample_n is not None and sample_n > 0:
|
||||
# Stratified-by-task slice so smoke runs cover every task.
|
||||
per_task = max(1, sample_n // max(1, len(tasks)))
|
||||
sliced: list[MirageQuestion] = []
|
||||
per_task_counter: dict[str, int] = {}
|
||||
for q in out:
|
||||
n = per_task_counter.get(q.task, 0)
|
||||
if n >= per_task:
|
||||
continue
|
||||
sliced.append(q)
|
||||
per_task_counter[q.task] = n + 1
|
||||
if len(sliced) >= sample_n:
|
||||
break
|
||||
out = sliced
|
||||
return out
|
||||
|
||||
|
||||
async def _gather_with_limit(coros, *, concurrency: int) -> list[Any]:
|
||||
sem = asyncio.Semaphore(max(1, concurrency))
|
||||
|
||||
async def _wrap(c):
|
||||
async with sem:
|
||||
return await c
|
||||
|
||||
return await asyncio.gather(*(_wrap(c) for c in coros))
|
||||
|
||||
|
||||
class MirageBenchmark:
|
||||
suite: str = "medical"
|
||||
name: str = "mirage"
|
||||
headline: bool = False
|
||||
description: str = _DESCRIPTION
|
||||
|
||||
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
default="all",
|
||||
choices=("all", *_TASKS),
|
||||
help="Run a single task or all (default: all).",
|
||||
)
|
||||
parser.add_argument("--n", dest="sample_n", type=int, default=None,
|
||||
help="Stratified sample size across tasks.")
|
||||
parser.add_argument("--concurrency", type=int, default=4)
|
||||
parser.add_argument(
|
||||
"--corpus", default="MedRAG/textbooks",
|
||||
help="HF MedRAG corpus to ingest from (default: MedRAG/textbooks).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-snippets-per-task", type=int, default=None,
|
||||
help="Cap the per-task ingestion to N snippets (smoke).",
|
||||
)
|
||||
# Mutually exclusive: by default we skip the upstream 16 GB
|
||||
# retrievals zip and ingest the entire corpus. Operators who
|
||||
# want the upstream pre-filter pass --use-snippet-filter (and,
|
||||
# if their corpus mismatch warrants the 16 GB transfer,
|
||||
# --allow-large-download).
|
||||
snippet_group = parser.add_mutually_exclusive_group()
|
||||
snippet_group.add_argument(
|
||||
"--use-snippet-filter", dest="use_snippet_filter", action="store_true",
|
||||
default=False,
|
||||
help="Download retrieved_snippets_10k.zip (~16 GB) and "
|
||||
"filter the corpus to those ids before ingest. "
|
||||
"Default: skip and ingest entire corpus.",
|
||||
)
|
||||
snippet_group.add_argument(
|
||||
"--skip-snippet-filter", dest="use_snippet_filter", action="store_false",
|
||||
help="(Default) Skip the 16 GB upstream zip; ingest entire corpus.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow-large-download", action="store_true", default=False,
|
||||
help="Permit downloads larger than 2 GB (e.g. retrieved_snippets_10k.zip).",
|
||||
)
|
||||
# Per-upload knobs; ignored at run-time (runner reads the
|
||||
# resolved settings out of the snippet-map manifest header).
|
||||
add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
|
||||
|
||||
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
|
||||
from .ingest import run_ingest
|
||||
|
||||
settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
|
||||
await run_ingest(
|
||||
ctx,
|
||||
corpus=str(opts.get("corpus") or "MedRAG/textbooks"),
|
||||
max_snippets_per_task=opts.get("max_snippets_per_task"),
|
||||
skip_snippet_filter=not bool(opts.get("use_snippet_filter")),
|
||||
allow_large_download=bool(opts.get("allow_large_download")),
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
|
||||
task_filter = opts.get("task") or "all"
|
||||
tasks = list(_TASKS) if task_filter == "all" else [task_filter]
|
||||
sample_n = opts.get("sample_n")
|
||||
concurrency = int(opts.get("concurrency") or 4)
|
||||
|
||||
bench_path = ctx.benchmark_data_dir() / "benchmark.json"
|
||||
if not bench_path.exists():
|
||||
raise RuntimeError(
|
||||
"MIRAGE benchmark.json missing. Run "
|
||||
"`python -m surfsense_evals ingest medical mirage` first."
|
||||
)
|
||||
benchmark = json.loads(bench_path.read_text(encoding="utf-8"))
|
||||
ingest_settings = read_settings_header(
|
||||
ctx.maps_dir() / "mirage_snippet_map.jsonl"
|
||||
)
|
||||
questions = _load_questions(benchmark, tasks=tasks, sample_n=sample_n)
|
||||
if not questions:
|
||||
raise RuntimeError(
|
||||
f"No MIRAGE questions matched task={task_filter!r} sample_n={sample_n!r}."
|
||||
)
|
||||
logger.info("MIRAGE: scheduled %d questions across tasks %s",
|
||||
len(questions), tasks)
|
||||
|
||||
arm = SurfSenseArm(
|
||||
client=ctx.new_chat_client(),
|
||||
search_space_id=ctx.search_space_id,
|
||||
ephemeral_threads=True,
|
||||
)
|
||||
|
||||
async def _ask(q: MirageQuestion) -> ArmResult:
|
||||
request = ArmRequest(
|
||||
question_id=q.question_id,
|
||||
prompt=build_prompt(q.question, q.options),
|
||||
)
|
||||
return await arm.answer(request)
|
||||
|
||||
results: list[ArmResult] = await _gather_with_limit(
|
||||
(_ask(q) for q in questions), concurrency=concurrency
|
||||
)
|
||||
|
||||
run_timestamp = utc_iso_timestamp()
|
||||
run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
|
||||
raw_path = run_dir / "raw.jsonl"
|
||||
with raw_path.open("w", encoding="utf-8") as fh:
|
||||
for q, res in zip(questions, results):
|
||||
fh.write(
|
||||
json.dumps(
|
||||
{
|
||||
"task": q.task,
|
||||
"qid": q.qid,
|
||||
"correct": q.correct,
|
||||
**res.to_jsonl(),
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
per_task_acc: dict[str, dict[str, Any]] = {}
|
||||
for task in tasks:
|
||||
n_correct = 0
|
||||
n_total = 0
|
||||
for q, res in zip(questions, results):
|
||||
if q.task != task:
|
||||
continue
|
||||
n_total += 1
|
||||
if (res.answer_letter or "").upper() == q.correct:
|
||||
n_correct += 1
|
||||
acc = accuracy_with_wilson_ci(n_correct, n_total)
|
||||
per_task_acc[task] = acc.to_dict()
|
||||
|
||||
macro = macro_accuracy(
|
||||
{t: accuracy_with_wilson_ci(d["n_correct"], d["n_total"]) for t, d in per_task_acc.items()}
|
||||
)
|
||||
metrics = {"per_task": per_task_acc, "macro_accuracy": macro}
|
||||
|
||||
artifact = RunArtifact(
|
||||
suite=self.suite,
|
||||
benchmark=self.name,
|
||||
run_timestamp=run_timestamp,
|
||||
raw_path=raw_path,
|
||||
metrics=metrics,
|
||||
extra={
|
||||
"n_questions": len(questions),
|
||||
"task_filter": task_filter,
|
||||
"concurrency": concurrency,
|
||||
"provider_model": ctx.provider_model,
|
||||
"ingest_settings": ingest_settings,
|
||||
},
|
||||
)
|
||||
manifest_path = run_dir / "run_artifact.json"
|
||||
manifest_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"suite": self.suite,
|
||||
"benchmark": self.name,
|
||||
"raw_path": "raw.jsonl",
|
||||
"metrics": metrics,
|
||||
"extra": artifact.extra,
|
||||
},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return artifact
|
||||
|
||||
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
|
||||
if not artifacts:
|
||||
return ReportSection(
|
||||
title="MIRAGE — single-arm SurfSense per-task accuracy",
|
||||
headline=False,
|
||||
body_md="(no run artifacts found)",
|
||||
body_json={},
|
||||
)
|
||||
latest = max(artifacts, key=lambda a: a.run_timestamp)
|
||||
per_task = latest.metrics.get("per_task", {})
|
||||
macro = latest.metrics.get("macro_accuracy", 0.0)
|
||||
lines: list[str] = []
|
||||
lines.append(format_ingest_settings_md(latest.extra.get("ingest_settings")))
|
||||
for task in _TASKS:
|
||||
row = per_task.get(task)
|
||||
if not row:
|
||||
continue
|
||||
acc = row.get("accuracy", 0.0)
|
||||
low = row.get("ci_low", 0.0)
|
||||
high = row.get("ci_high", 0.0)
|
||||
lines.append(
|
||||
f"- {task}: {acc * 100:.1f}% "
|
||||
f"(Wilson 95% CI: {low * 100:.1f}% – {high * 100:.1f}%, "
|
||||
f"n={row.get('n_total', '?')})"
|
||||
)
|
||||
if not lines:
|
||||
lines.append("- (no per-task results)")
|
||||
lines.append(f"- Macro accuracy: {macro * 100:.2f}%")
|
||||
return ReportSection(
|
||||
title="MIRAGE — single-arm SurfSense per-task accuracy",
|
||||
headline=False,
|
||||
body_md="\n".join(lines),
|
||||
body_json=latest.metrics,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["MirageBenchmark", "MirageQuestion"]
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
"""Multimodal long-document benchmarks (PDFs with embedded images/charts/tables).
|
||||
|
||||
Distinct from the medical suite because these documents are domain-mixed
|
||||
(research reports, financials, manuals, government, brochures, papers).
|
||||
The hypothesis being tested here is *general*: does SurfSense's
|
||||
chunking-based vision RAG preserve information that lives in pixels —
|
||||
across long PDFs, across pages — versus feeding the same PDF directly
|
||||
to a vision-capable model?
|
||||
|
||||
Subpackages register themselves with ``core.registry`` on import. The
|
||||
``suites/__init__.py`` discovery walker imports them automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""MMLongBench-Doc — head-to-head Native PDF (vision) vs SurfSense (vision RAG).
|
||||
|
||||
Source: https://huggingface.co/datasets/yubo2333/MMLongBench-Doc
|
||||
Paper: https://arxiv.org/abs/2407.01523 (NeurIPS 2024 D&B Track)
|
||||
|
||||
* 135 long PDFs (avg 47 pages, multi-modal: text, images, charts, tables)
|
||||
* 1,091 expert-annotated questions
|
||||
* 33% require evidence from multiple pages
|
||||
* ~22% intentionally unanswerable (tests hallucination resistance)
|
||||
* 7 document types: research report, tutorial/workshop, academic paper,
|
||||
financial report, brochure, government, manuals
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import registry as _registry
|
||||
from .runner import MMLongBenchDocBenchmark
|
||||
|
||||
_registry.register(MMLongBenchDocBenchmark())
|
||||
|
|
@ -0,0 +1,236 @@
|
|||
"""Format-aware grader for MMLongBench-Doc answers.
|
||||
|
||||
The dataset ships with five ``answer_format`` values per question:
|
||||
|
||||
* ``Str`` — short factoid string
|
||||
* ``Int`` — integer count / year
|
||||
* ``Float`` — decimal number (often with units stripped)
|
||||
* ``List`` — comma- or semicolon-separated bag of items
|
||||
* ``None`` — gold answer is literally "Not answerable" (hallucination probe)
|
||||
|
||||
The official MMLongBench-Doc paper grades with GPT-4 as judge. We
|
||||
implement a *deterministic* rule-based grader as the default (so two
|
||||
researchers running the same harness get the same number); an
|
||||
LLM-judge mode is exposed via ``--judge gpt5`` and routed through the
|
||||
same OpenRouter key the arms use, but is opt-in to keep cost down.
|
||||
|
||||
Returned by every grading call:
|
||||
|
||||
* ``correct: bool`` — final pass/fail used for accuracy + McNemar
|
||||
* ``f1: float`` — token-level F1 (continuous credit, useful when
|
||||
comparing arms that get *most* of a list right)
|
||||
* ``method: str`` — which path graded the row (one of
|
||||
``str_norm`` / ``int_eq`` / ``float_tol`` / ``list_set`` /
|
||||
``none_match`` / ``llm_judge``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import string
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class GradeResult:
|
||||
correct: bool
|
||||
f1: float
|
||||
method: str
|
||||
normalised_pred: str = ""
|
||||
normalised_gold: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Normalisation helpers (shared)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PUNCT_TABLE = str.maketrans({c: " " for c in string.punctuation})
|
||||
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.IGNORECASE)
|
||||
_WS = re.compile(r"\s+")
|
||||
_NOT_ANSWERABLE_TOKENS = {
|
||||
"not answerable",
|
||||
"cannot be answered",
|
||||
"cannot answer",
|
||||
"no answer",
|
||||
"unknown",
|
||||
"none",
|
||||
"not specified",
|
||||
"not mentioned",
|
||||
"not provided",
|
||||
"the answer is not in the document",
|
||||
}
|
||||
|
||||
# Abbreviations that should be matched literally on the lowercased
|
||||
# prediction (because normalisation strips their punctuation and
|
||||
# leaves them too short to be safe as substring tokens).
|
||||
_NOT_ANSWERABLE_LITERAL = {"n/a", "na/", "n.a.", "n a"}
|
||||
|
||||
|
||||
def _normalise_text(s: str) -> str:
|
||||
"""SQuAD-style normalisation: lowercase, drop punctuation/articles, squash whitespace."""
|
||||
|
||||
s = s.lower()
|
||||
s = s.translate(_PUNCT_TABLE)
|
||||
s = _ARTICLES.sub(" ", s)
|
||||
s = _WS.sub(" ", s).strip()
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-format graders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _grade_str(pred: str, gold: str) -> GradeResult:
|
||||
p = _normalise_text(pred)
|
||||
g = _normalise_text(gold)
|
||||
if not p:
|
||||
return GradeResult(False, 0.0, "str_norm", p, g)
|
||||
if p == g:
|
||||
return GradeResult(True, 1.0, "str_norm", p, g)
|
||||
# Substring match in either direction = correct (handles the common
|
||||
# "model emits a fuller sentence containing the gold" case).
|
||||
if g and (g in p or p in g):
|
||||
return GradeResult(True, _f1_tokens(p, g), "str_norm", p, g)
|
||||
return GradeResult(False, _f1_tokens(p, g), "str_norm", p, g)
|
||||
|
||||
|
||||
_INT_RE = re.compile(r"-?\d[\d,]*")
|
||||
|
||||
|
||||
def _grade_int(pred: str, gold: str) -> GradeResult:
|
||||
g_match = _INT_RE.search(gold)
|
||||
if g_match is None:
|
||||
return _grade_str(pred, gold)
|
||||
g_val = int(g_match.group(0).replace(",", ""))
|
||||
p_match = _INT_RE.search(pred)
|
||||
if p_match is None:
|
||||
return GradeResult(False, 0.0, "int_eq", str(p_match), str(g_val))
|
||||
p_val = int(p_match.group(0).replace(",", ""))
|
||||
return GradeResult(p_val == g_val, 1.0 if p_val == g_val else 0.0,
|
||||
"int_eq", str(p_val), str(g_val))
|
||||
|
||||
|
||||
_FLOAT_RE = re.compile(r"-?\d+(?:[.,]\d+)?")
|
||||
|
||||
|
||||
def _grade_float(pred: str, gold: str, *, rel_tol: float = 1e-2) -> GradeResult:
|
||||
g_match = _FLOAT_RE.search(gold)
|
||||
if g_match is None:
|
||||
return _grade_str(pred, gold)
|
||||
g_val = float(g_match.group(0).replace(",", "."))
|
||||
p_match = _FLOAT_RE.search(pred)
|
||||
if p_match is None:
|
||||
return GradeResult(False, 0.0, "float_tol", "", str(g_val))
|
||||
p_val = float(p_match.group(0).replace(",", "."))
|
||||
# Tolerance: 1% relative or 0.01 absolute, whichever is looser.
|
||||
abs_diff = abs(p_val - g_val)
|
||||
tol = max(abs(g_val) * rel_tol, 0.01)
|
||||
ok = abs_diff <= tol
|
||||
return GradeResult(ok, 1.0 if ok else 0.0, "float_tol", str(p_val), str(g_val))
|
||||
|
||||
|
||||
_LIST_SPLIT = re.compile(r"[;,\n]")
|
||||
|
||||
|
||||
def _grade_list(pred: str, gold: str) -> GradeResult:
|
||||
g_items = {_normalise_text(x) for x in _LIST_SPLIT.split(gold) if x.strip()}
|
||||
p_items = {_normalise_text(x) for x in _LIST_SPLIT.split(pred) if x.strip()}
|
||||
if not g_items:
|
||||
return _grade_str(pred, gold)
|
||||
inter = g_items & p_items
|
||||
if not inter:
|
||||
return GradeResult(False, 0.0, "list_set",
|
||||
", ".join(sorted(p_items)),
|
||||
", ".join(sorted(g_items)))
|
||||
precision = len(inter) / len(p_items) if p_items else 0.0
|
||||
recall = len(inter) / len(g_items)
|
||||
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return GradeResult(f1 >= 0.999, f1, "list_set",
|
||||
", ".join(sorted(p_items)),
|
||||
", ".join(sorted(g_items)))
|
||||
|
||||
|
||||
def _grade_none(pred: str, gold: str) -> GradeResult:
|
||||
"""Gold == 'Not answerable'. The arm earns credit if its prediction
|
||||
expresses inability to answer.
|
||||
|
||||
Two passes:
|
||||
|
||||
1. Literal-substring check on the lowercased+stripped pred for
|
||||
ambiguous abbreviations like ``n/a`` (since normalisation
|
||||
strips the punctuation and would over-match).
|
||||
2. Word-boundary substring check on the normalised pred for the
|
||||
multi-word phrases (``cannot answer``, ``not specified`` etc.).
|
||||
"""
|
||||
|
||||
raw_lower = (pred or "").strip().lower()
|
||||
p = _normalise_text(pred)
|
||||
expressed_unknown = False
|
||||
|
||||
# Pass 1: literal abbreviation hits on the raw lowercased text.
|
||||
if any(lit in raw_lower for lit in _NOT_ANSWERABLE_LITERAL):
|
||||
expressed_unknown = True
|
||||
|
||||
# Pass 2: word-boundary check on normalised tokens.
|
||||
if not expressed_unknown:
|
||||
p_padded = f" {p} "
|
||||
for tok_raw in _NOT_ANSWERABLE_TOKENS:
|
||||
tok = _normalise_text(tok_raw)
|
||||
if not tok or len(tok) < 3:
|
||||
continue
|
||||
if f" {tok} " in p_padded:
|
||||
expressed_unknown = True
|
||||
break
|
||||
return GradeResult(
|
||||
expressed_unknown, 1.0 if expressed_unknown else 0.0,
|
||||
"none_match", p, _normalise_text(gold),
|
||||
)
|
||||
|
||||
|
||||
def _f1_tokens(pred: str, gold: str) -> float:
|
||||
p_tok = pred.split()
|
||||
g_tok = gold.split()
|
||||
if not p_tok or not g_tok:
|
||||
return 0.0
|
||||
common = Counter(p_tok) & Counter(g_tok)
|
||||
overlap = sum(common.values())
|
||||
if overlap == 0:
|
||||
return 0.0
|
||||
precision = overlap / len(p_tok)
|
||||
recall = overlap / len(g_tok)
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_FORMAT_DISPATCH = {
|
||||
"str": _grade_str,
|
||||
"int": _grade_int,
|
||||
"float": _grade_float,
|
||||
"list": _grade_list,
|
||||
"none": _grade_none,
|
||||
}
|
||||
|
||||
|
||||
def grade(*, pred: str, gold: str, answer_format: str) -> GradeResult:
|
||||
"""Grade a single (prediction, gold) pair.
|
||||
|
||||
``answer_format`` is the dataset's ``answer_format`` column value.
|
||||
Unknown / blank values fall through to string grading.
|
||||
"""
|
||||
|
||||
fmt = (answer_format or "").strip().lower()
|
||||
fn = _FORMAT_DISPATCH.get(fmt, _grade_str)
|
||||
return fn(pred or "", gold or "")
|
||||
|
||||
|
||||
__all__ = ["GradeResult", "grade"]
|
||||
|
|
@ -0,0 +1,365 @@
|
|||
"""MMLongBench-Doc ingestion.
|
||||
|
||||
Steps:
|
||||
|
||||
1. Pull the questions parquet from
|
||||
``hf://datasets/yubo2333/MMLongBench-Doc/data/`` and cache locally.
|
||||
2. Resolve the unique set of ``doc_id`` referenced by questions, and
|
||||
download each PDF from
|
||||
``hf://datasets/yubo2333/MMLongBench-Doc/documents/<doc_id>``.
|
||||
``huggingface_hub.hf_hub_download`` is resumable + content-hash
|
||||
verifying; we cache PDFs under ``<data_dir>/multimodal_doc/mmlongbench/pdfs/``.
|
||||
3. Upload every PDF to SurfSense via ``DocumentsClient.upload`` with
|
||||
``use_vision_llm=True`` so SurfSense's Pillow + LiteLLM vision
|
||||
pipeline extracts captions / OCR for embedded images, charts, and
|
||||
tables.
|
||||
4. Wait for ``processed`` status and persist
|
||||
``doc_id -> document_id`` in
|
||||
``<data_dir>/multimodal_doc/maps/mmlongbench_doc_map.jsonl``.
|
||||
|
||||
By default we ingest **all** 135 PDFs (~660 MB, totally manageable).
|
||||
Operators can scope to a subset with ``--max-docs N`` if iterating on
|
||||
a slow vision pipeline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HF_REPO_ID = "yubo2333/MMLongBench-Doc"
|
||||
HF_REPO_TYPE = "dataset"
|
||||
|
||||
# Lazy import: huggingface_hub + pyarrow are heavyweight; keep the
|
||||
# benchmark module importable on machines that have only the core
|
||||
# install (e.g. CI lint jobs).
|
||||
def _hf_hub_download(*args, **kwargs):
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
return hf_hub_download(*args, **kwargs)
|
||||
|
||||
|
||||
def _list_repo_files() -> list[str]:
|
||||
from huggingface_hub import list_repo_files
|
||||
|
||||
return list_repo_files(repo_id=HF_REPO_ID, repo_type=HF_REPO_TYPE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question parquet -> Python rows
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class MMLongBenchQuestion:
|
||||
doc_id: str # filename inside the documents/ folder
|
||||
doc_type: str
|
||||
question: str
|
||||
answer: str
|
||||
answer_format: str # Str / Int / Float / List / None
|
||||
evidence_pages: list[int]
|
||||
evidence_sources: list[str]
|
||||
|
||||
|
||||
def _load_questions_from_parquet(parquet_path: Path) -> list[MMLongBenchQuestion]:
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
table = pq.read_table(parquet_path)
|
||||
rows = table.to_pylist()
|
||||
out: list[MMLongBenchQuestion] = []
|
||||
for row in rows:
|
||||
doc_id = str(row.get("doc_id") or "").strip()
|
||||
if not doc_id:
|
||||
continue
|
||||
question = str(row.get("question") or "").strip()
|
||||
if not question:
|
||||
continue
|
||||
out.append(
|
||||
MMLongBenchQuestion(
|
||||
doc_id=doc_id,
|
||||
doc_type=str(row.get("doc_type") or "").strip(),
|
||||
question=question,
|
||||
answer=str(row.get("answer") or "").strip(),
|
||||
answer_format=str(row.get("answer_format") or "").strip(),
|
||||
evidence_pages=_parse_int_list(row.get("evidence_pages")),
|
||||
evidence_sources=_parse_str_list(row.get("evidence_sources")),
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _parse_int_list(raw) -> list[int]:
|
||||
if raw is None:
|
||||
return []
|
||||
if isinstance(raw, list):
|
||||
out = []
|
||||
for x in raw:
|
||||
try:
|
||||
out.append(int(x))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
return out
|
||||
text = str(raw).strip().strip("[]")
|
||||
if not text:
|
||||
return []
|
||||
out: list[int] = []
|
||||
for tok in text.split(","):
|
||||
tok = tok.strip().strip("'\"")
|
||||
if tok.isdigit():
|
||||
out.append(int(tok))
|
||||
return out
|
||||
|
||||
|
||||
def _parse_str_list(raw) -> list[str]:
|
||||
if raw is None:
|
||||
return []
|
||||
if isinstance(raw, list):
|
||||
return [str(x).strip().strip("'\"") for x in raw if str(x).strip()]
|
||||
text = str(raw).strip().strip("[]")
|
||||
if not text:
|
||||
return []
|
||||
return [tok.strip().strip("'\"") for tok in text.split(",") if tok.strip()]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Download helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _download_questions_parquet(cache_dir: Path) -> Path:
|
||||
"""Download every parquet under ``data/`` and concatenate.
|
||||
|
||||
The HF dataset usually publishes a single ``train`` split, but we
|
||||
enumerate to be robust to repo restructuring.
|
||||
"""
|
||||
|
||||
parquet_paths: list[Path] = []
|
||||
files = _list_repo_files()
|
||||
data_files = [f for f in files if f.startswith("data/") and f.endswith(".parquet")]
|
||||
if not data_files:
|
||||
raise RuntimeError(
|
||||
f"No parquet files found under data/ in {HF_REPO_ID}; "
|
||||
f"upstream repo may have been restructured."
|
||||
)
|
||||
for rel in sorted(data_files):
|
||||
local = _hf_hub_download(
|
||||
repo_id=HF_REPO_ID,
|
||||
filename=rel,
|
||||
repo_type=HF_REPO_TYPE,
|
||||
cache_dir=str(cache_dir),
|
||||
)
|
||||
parquet_paths.append(Path(local))
|
||||
logger.info("Cached MMLongBench parquet shard %s -> %s", rel, local)
|
||||
return parquet_paths[0] if len(parquet_paths) == 1 else _merge_parquets(parquet_paths, cache_dir)
|
||||
|
||||
|
||||
def _merge_parquets(paths: list[Path], cache_dir: Path) -> Path:
|
||||
"""Combine multiple parquet shards into one (rare branch, but correct)."""
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
tables = [pq.read_table(p) for p in paths]
|
||||
merged = pa.concat_tables(tables, promote_options="default")
|
||||
out = cache_dir / "merged_questions.parquet"
|
||||
pq.write_table(merged, out)
|
||||
return out
|
||||
|
||||
|
||||
def _download_pdf(doc_id: str, cache_dir: Path, pdfs_dir: Path) -> Path:
|
||||
"""Download a single PDF (resumable via huggingface_hub cache)."""
|
||||
|
||||
rel = f"documents/{doc_id}"
|
||||
local = _hf_hub_download(
|
||||
repo_id=HF_REPO_ID,
|
||||
filename=rel,
|
||||
repo_type=HF_REPO_TYPE,
|
||||
cache_dir=str(cache_dir),
|
||||
)
|
||||
# Materialise to a stable path inside our data/ tree so the runner
|
||||
# has a deterministic location regardless of HF cache internals.
|
||||
dest = pdfs_dir / doc_id
|
||||
if not dest.exists() or dest.stat().st_size != Path(local).stat().st_size:
|
||||
# Use a hardlink when possible (cheap), fall back to copy.
|
||||
try:
|
||||
if dest.exists():
|
||||
dest.unlink()
|
||||
os.link(local, dest)
|
||||
except OSError:
|
||||
from shutil import copy2
|
||||
|
||||
copy2(local, dest)
|
||||
return dest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _upload_pdfs(
|
||||
ctx: RunContext,
|
||||
pdf_paths: Iterable[Path],
|
||||
*,
|
||||
batch_size: int,
|
||||
settings: IngestSettings,
|
||||
) -> dict[str, int]:
|
||||
"""Upload PDFs in batches, return ``filename -> document_id`` map."""
|
||||
|
||||
docs_client = ctx.documents_client()
|
||||
name_to_id: dict[str, int] = {}
|
||||
pdf_list = list(pdf_paths)
|
||||
for batch_start in range(0, len(pdf_list), batch_size):
|
||||
batch = pdf_list[batch_start:batch_start + batch_size]
|
||||
result = await docs_client.upload(
|
||||
files=batch,
|
||||
search_space_id=ctx.search_space_id,
|
||||
should_summarize=settings.should_summarize,
|
||||
use_vision_llm=settings.use_vision_llm,
|
||||
processing_mode=settings.processing_mode,
|
||||
)
|
||||
all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
|
||||
if all_ids:
|
||||
await docs_client.wait_until_ready(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=result.document_ids, # only newly added need polling
|
||||
timeout_s=1800.0, # vision pipeline is slow on long PDFs
|
||||
)
|
||||
statuses = await docs_client.get_status(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=all_ids,
|
||||
)
|
||||
for s in statuses:
|
||||
name_to_id[s.title] = s.document_id
|
||||
logger.info(
|
||||
"Uploaded MMLongBench batch %d-%d: %d new, %d duplicate",
|
||||
batch_start, batch_start + len(batch),
|
||||
len(result.document_ids), len(result.duplicate_document_ids),
|
||||
)
|
||||
return name_to_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_ingest(
|
||||
ctx: RunContext,
|
||||
*,
|
||||
max_docs: int | None = None,
|
||||
upload_batch_size: int = 8,
|
||||
skip_upload: bool = False,
|
||||
settings: IngestSettings | None = None,
|
||||
) -> None:
|
||||
"""Ingest MMLongBench-Doc into the multimodal_doc suite.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_docs : int | None
|
||||
Cap the number of PDFs to download + upload. ``None`` = all 135.
|
||||
Useful when iterating on the runner without paying for the full
|
||||
vision pipeline pass each time.
|
||||
upload_batch_size : int
|
||||
How many PDFs to send per ``fileupload`` call. Smaller batches
|
||||
recover faster from individual failures; larger batches reduce
|
||||
round-trip overhead.
|
||||
skip_upload : bool
|
||||
Download + cache PDFs locally but skip SurfSense ingestion.
|
||||
Useful for testing the native arm in isolation.
|
||||
"""
|
||||
|
||||
settings = settings or IngestSettings(use_vision_llm=True, processing_mode="basic")
|
||||
bench_dir = ctx.benchmark_data_dir()
|
||||
pdfs_dir = bench_dir / "pdfs"
|
||||
pdfs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_cache = bench_dir / ".hf_cache"
|
||||
hf_cache.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Step 1: questions
|
||||
parquet_path = _download_questions_parquet(hf_cache)
|
||||
questions = _load_questions_from_parquet(parquet_path)
|
||||
if not questions:
|
||||
raise RuntimeError(
|
||||
"MMLongBench-Doc parquet contains no parseable questions. "
|
||||
"Upstream may have changed schema."
|
||||
)
|
||||
|
||||
# Persist a copy alongside the PDFs so the runner has one place to read.
|
||||
questions_jsonl = bench_dir / "questions.jsonl"
|
||||
with questions_jsonl.open("w", encoding="utf-8") as fh:
|
||||
for q in questions:
|
||||
fh.write(json.dumps({
|
||||
"doc_id": q.doc_id,
|
||||
"doc_type": q.doc_type,
|
||||
"question": q.question,
|
||||
"answer": q.answer,
|
||||
"answer_format": q.answer_format,
|
||||
"evidence_pages": q.evidence_pages,
|
||||
"evidence_sources": q.evidence_sources,
|
||||
}) + "\n")
|
||||
logger.info("Wrote %d MMLongBench questions to %s", len(questions), questions_jsonl)
|
||||
|
||||
# Step 2: download unique PDFs
|
||||
unique_doc_ids = sorted({q.doc_id for q in questions})
|
||||
if max_docs is not None and max_docs > 0:
|
||||
unique_doc_ids = unique_doc_ids[:max_docs]
|
||||
logger.info("MMLongBench: downloading %d unique PDFs", len(unique_doc_ids))
|
||||
|
||||
pdf_paths: dict[str, Path] = {}
|
||||
for i, doc_id in enumerate(unique_doc_ids, start=1):
|
||||
try:
|
||||
pdf_paths[doc_id] = _download_pdf(doc_id, hf_cache, pdfs_dir)
|
||||
if i % 10 == 0:
|
||||
logger.info(" ... %d / %d PDFs cached", i, len(unique_doc_ids))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Failed to download MMLongBench PDF %s: %s", doc_id, exc)
|
||||
|
||||
# Step 3: upload to SurfSense
|
||||
name_to_id: dict[str, int] = {}
|
||||
if skip_upload:
|
||||
logger.info("MMLongBench: --skip-upload set; skipping SurfSense ingestion")
|
||||
else:
|
||||
logger.info("MMLongBench upload settings: %s", settings.render_label())
|
||||
name_to_id = await _upload_pdfs(
|
||||
ctx,
|
||||
pdf_paths.values(),
|
||||
batch_size=upload_batch_size,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
# Step 4: persist doc_id -> document_id manifest
|
||||
map_path = ctx.maps_dir() / "mmlongbench_doc_map.jsonl"
|
||||
with map_path.open("w", encoding="utf-8") as fh:
|
||||
# Header line records the resolved ingest settings
|
||||
# (see core/ingest_settings.py).
|
||||
fh.write(settings_header_line(settings) + "\n")
|
||||
for doc_id in unique_doc_ids:
|
||||
local = pdf_paths.get(doc_id)
|
||||
if local is None:
|
||||
continue
|
||||
fh.write(json.dumps({
|
||||
"doc_id": doc_id,
|
||||
"document_id": name_to_id.get(local.name),
|
||||
"pdf_path": str(local),
|
||||
"n_questions": sum(1 for q in questions if q.doc_id == doc_id),
|
||||
}) + "\n")
|
||||
logger.info("Wrote MMLongBench doc map to %s", map_path)
|
||||
|
||||
new_state = ctx.suite_state
|
||||
new_state.ingestion_maps["mmlongbench"] = str(map_path)
|
||||
set_suite_state(ctx.config, ctx.suite, new_state)
|
||||
|
||||
|
||||
__all__ = ["MMLongBenchQuestion", "run_ingest"]
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
"""MMLongBench-Doc prompt template.
|
||||
|
||||
Both arms get the same prompt — only the document delivery channel
|
||||
differs (native PDF embedded in the OpenRouter request vs SurfSense
|
||||
RAG retrieval). The format hint in the prompt mirrors what the
|
||||
upstream paper uses so the grader's regex can reliably extract the
|
||||
answer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-format hint blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FORMAT_HINTS: dict[str, str] = {
|
||||
"str": (
|
||||
"Respond with the answer as a short phrase, no full sentence. "
|
||||
"Format your final line as `Answer: <text>`."
|
||||
),
|
||||
"int": (
|
||||
"Respond with a single integer only. "
|
||||
"Format your final line as `Answer: <integer>`."
|
||||
),
|
||||
"float": (
|
||||
"Respond with a single decimal number only (no units). "
|
||||
"Format your final line as `Answer: <number>`."
|
||||
),
|
||||
"list": (
|
||||
"Respond with a comma-separated list of items, no extra text. "
|
||||
"Format your final line as `Answer: item1, item2, item3`."
|
||||
),
|
||||
"none": (
|
||||
"If the answer cannot be determined from the document, say so explicitly. "
|
||||
"Format your final line as `Answer: Not answerable`."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
_PROMPT = """\
|
||||
You are a document-understanding assistant. Use ONLY the provided
|
||||
document to answer the question. The document may contain text,
|
||||
tables, charts, figures, and images. If the answer is in a chart or
|
||||
image, read it carefully. Do not use external knowledge.
|
||||
|
||||
Question: {question}
|
||||
|
||||
{format_hint}
|
||||
"""
|
||||
|
||||
|
||||
def build_prompt(question: str, *, answer_format: str) -> str:
|
||||
"""Assemble the full prompt for one MMLongBench question."""
|
||||
|
||||
fmt = (answer_format or "str").strip().lower()
|
||||
hint = _FORMAT_HINTS.get(fmt, _FORMAT_HINTS["str"])
|
||||
return _PROMPT.format(question=question.strip(), format_hint=hint)
|
||||
|
||||
|
||||
__all__ = ["build_prompt"]
|
||||
|
|
@ -0,0 +1,704 @@
|
|||
"""MMLongBench-Doc runner — head-to-head Native PDF (vision) vs SurfSense (vision RAG).
|
||||
|
||||
Differences from a typical MCQ head-to-head:
|
||||
|
||||
* Open-ended answers (Str / Int / Float / List / Not-answerable) — uses
|
||||
``extract_freeform_answer`` instead of ``extract_answer_letter``.
|
||||
* Format-aware grader (see ``.grader``) returns both binary correctness
|
||||
(for accuracy / McNemar) and continuous F1 (for nuanced reporting).
|
||||
* Native arm requires a vision-capable model — we don't enforce this
|
||||
in code (operator's choice via ``setup --provider-model``) but we
|
||||
emit a warning if the pinned slug looks text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.arms import ArmRequest, ArmResult, NativePdfArm, SurfSenseArm
|
||||
from ....core.config import utc_iso_timestamp
|
||||
from ....core.ingest_settings import (
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
is_settings_header,
|
||||
)
|
||||
from ....core.metrics.comparison import (
|
||||
bootstrap_delta_ci,
|
||||
mcnemar_test,
|
||||
paired_aggregate,
|
||||
)
|
||||
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci
|
||||
from ....core.parse.freeform_answer import extract_freeform_answer
|
||||
from ....core.providers.openrouter_pdf import OpenRouterPdfProvider, PdfEngine
|
||||
from ....core.registry import (
|
||||
ReportSection,
|
||||
RunArtifact,
|
||||
RunContext,
|
||||
)
|
||||
from ....core.scenarios import format_scenario_md
|
||||
from .grader import GradeResult, grade
|
||||
from .prompt import build_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question + map row shapes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class MMLBQuestion:
|
||||
qid: str # synthesised from doc_id + index
|
||||
doc_id: str # filename inside the documents/ folder
|
||||
doc_type: str
|
||||
question: str
|
||||
gold_answer: str
|
||||
answer_format: str
|
||||
evidence_pages: list[int]
|
||||
evidence_sources: list[str]
|
||||
pdf_path: Path
|
||||
document_id: int | None # SurfSense doc id (None if upload skipped)
|
||||
|
||||
|
||||
def _load_doc_map(map_path: Path) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
|
||||
"""Read the doc map JSONL.
|
||||
|
||||
Returns ``(rows, settings)`` where ``settings`` is the
|
||||
``__settings__`` header blob (or ``{}`` for legacy maps).
|
||||
"""
|
||||
|
||||
rows: dict[str, dict[str, Any]] = {}
|
||||
settings: dict[str, Any] = {}
|
||||
with map_path.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
if is_settings_header(row):
|
||||
settings = dict(row["__settings__"])
|
||||
continue
|
||||
rows[str(row["doc_id"])] = row
|
||||
return rows, settings
|
||||
|
||||
|
||||
def _load_questions(
|
||||
questions_jsonl: Path,
|
||||
doc_map: dict[str, dict[str, Any]],
|
||||
*,
|
||||
doc_filter: list[str] | None,
|
||||
format_filter: str | None,
|
||||
sample_n: int | None,
|
||||
skip_unanswerable: bool,
|
||||
) -> list[MMLBQuestion]:
|
||||
out: list[MMLBQuestion] = []
|
||||
per_doc_counter: dict[str, int] = {}
|
||||
with questions_jsonl.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
doc_id = str(row.get("doc_id") or "").strip()
|
||||
if not doc_id:
|
||||
continue
|
||||
if doc_filter and doc_id not in doc_filter:
|
||||
continue
|
||||
map_row = doc_map.get(doc_id)
|
||||
if map_row is None:
|
||||
logger.debug("No doc-map entry for %s; skipping", doc_id)
|
||||
continue
|
||||
answer_format = str(row.get("answer_format") or "").strip().lower()
|
||||
if format_filter and format_filter != "all" and format_filter != answer_format:
|
||||
continue
|
||||
gold = str(row.get("answer") or "").strip()
|
||||
if skip_unanswerable and answer_format == "none":
|
||||
continue
|
||||
idx = per_doc_counter.get(doc_id, 0)
|
||||
per_doc_counter[doc_id] = idx + 1
|
||||
out.append(MMLBQuestion(
|
||||
qid=f"{doc_id}::Q{idx:03d}",
|
||||
doc_id=doc_id,
|
||||
doc_type=str(row.get("doc_type") or "").strip(),
|
||||
question=str(row.get("question") or "").strip(),
|
||||
gold_answer=gold,
|
||||
answer_format=answer_format,
|
||||
evidence_pages=list(row.get("evidence_pages") or []),
|
||||
evidence_sources=list(row.get("evidence_sources") or []),
|
||||
pdf_path=Path(map_row["pdf_path"]),
|
||||
document_id=map_row.get("document_id"),
|
||||
))
|
||||
out.sort(key=lambda q: (q.doc_id, q.qid))
|
||||
if sample_n is not None and sample_n > 0:
|
||||
out = out[:sample_n]
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bounded concurrency helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _gather_with_limit(coros: Iterable, *, concurrency: int) -> list[Any]:
|
||||
sem = asyncio.Semaphore(max(1, concurrency))
|
||||
|
||||
async def _wrap(coro):
|
||||
async with sem:
|
||||
return await coro
|
||||
|
||||
return await asyncio.gather(*(_wrap(c) for c in coros))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_DESCRIPTION = (
|
||||
"MMLongBench-Doc (135 long PDFs, 1,091 multimodal questions) — "
|
||||
"Native PDF (vision) vs SurfSense (vision RAG) head-to-head."
|
||||
)
|
||||
|
||||
|
||||
_TEXT_ONLY_HINTS = ("gpt-5.4-mini", "gpt-3.5", "text-only", "instruct-")
|
||||
|
||||
# MMLongBench-Doc PDFs are long documents with figures, charts, and
|
||||
# tables. Vision LLM at ingest is the whole point; flip --no-vision-llm
|
||||
# to measure how much SurfSense degrades on real document images.
|
||||
_DEFAULT_INGEST_SETTINGS = IngestSettings(
|
||||
use_vision_llm=True,
|
||||
processing_mode="basic",
|
||||
should_summarize=False,
|
||||
)
|
||||
|
||||
|
||||
class MMLongBenchDocBenchmark:
|
||||
"""Long-document multimodal RAG vs native vision."""
|
||||
|
||||
suite: str = "multimodal_doc"
|
||||
name: str = "mmlongbench"
|
||||
headline: bool = True
|
||||
description: str = _DESCRIPTION
|
||||
|
||||
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--docs",
|
||||
default=None,
|
||||
help="Comma-separated doc_ids (filenames) to run (default: all).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
default="all",
|
||||
choices=["all", "str", "int", "float", "list", "none"],
|
||||
help="Filter to one answer format. 'none' = unanswerable probes only.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n", dest="sample_n", type=int, default=None,
|
||||
help="Run only the first N questions after filters apply.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-unanswerable", dest="skip_unanswerable", action="store_true",
|
||||
help="Drop ~22%% unanswerable questions (use to compare against baselines that don't include them).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concurrency", type=int, default=4,
|
||||
help="Parallel question workers per arm.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-mentions", dest="no_mentions", action="store_true",
|
||||
help="SurfSense arm: skip mentioned_document_ids (unscoped retrieval).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pdf-engine", default="native",
|
||||
choices=[e.value for e in PdfEngine],
|
||||
help="OpenRouter file-parser engine for the native arm.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-output-tokens", type=int, default=512,
|
||||
help="Cap on completion length for both arms.",
|
||||
)
|
||||
# Ingest-only knobs (forwarded by the CLI to ingest.run_ingest).
|
||||
parser.add_argument(
|
||||
"--max-docs", dest="max_docs", type=int, default=None,
|
||||
help="(ingest only) cap on number of unique PDFs to download + upload.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-batch-size", dest="upload_batch_size", type=int, default=8,
|
||||
help="(ingest only) PDFs per fileupload call.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-upload", dest="skip_upload", action="store_true",
|
||||
help="(ingest only) cache PDFs locally but don't push to SurfSense.",
|
||||
)
|
||||
# Per-upload knobs forwarded to /documents/fileupload at ingest;
|
||||
# ignored at run-time (runner reads the resolved settings out of
|
||||
# the doc-map manifest header).
|
||||
add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
|
||||
|
||||
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
|
||||
from .ingest import run_ingest
|
||||
|
||||
settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
|
||||
await run_ingest(
|
||||
ctx,
|
||||
max_docs=opts.get("max_docs"),
|
||||
upload_batch_size=int(opts.get("upload_batch_size") or 8),
|
||||
skip_upload=bool(opts.get("skip_upload", False)),
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
|
||||
docs_raw: str | None = opts.get("docs")
|
||||
doc_filter = [d.strip() for d in docs_raw.split(",")] if docs_raw else None
|
||||
format_filter = opts.get("format") or "all"
|
||||
sample_n = opts.get("sample_n")
|
||||
skip_unanswerable = bool(opts.get("skip_unanswerable"))
|
||||
concurrency = int(opts.get("concurrency") or 4)
|
||||
no_mentions = bool(opts.get("no_mentions"))
|
||||
pdf_engine_name = opts.get("pdf_engine") or "native"
|
||||
max_output_tokens = int(opts.get("max_output_tokens") or 512)
|
||||
|
||||
bench_dir = ctx.benchmark_data_dir()
|
||||
questions_jsonl = bench_dir / "questions.jsonl"
|
||||
map_path = ctx.maps_dir() / "mmlongbench_doc_map.jsonl"
|
||||
if not questions_jsonl.exists() or not map_path.exists():
|
||||
raise RuntimeError(
|
||||
"MMLongBench-Doc not ingested for this suite. Run "
|
||||
"`python -m surfsense_evals ingest multimodal_doc mmlongbench` first."
|
||||
)
|
||||
|
||||
doc_map, ingest_settings = _load_doc_map(map_path)
|
||||
questions = _load_questions(
|
||||
questions_jsonl, doc_map,
|
||||
doc_filter=doc_filter,
|
||||
format_filter=None if format_filter == "all" else format_filter,
|
||||
sample_n=sample_n,
|
||||
skip_unanswerable=skip_unanswerable,
|
||||
)
|
||||
if not questions:
|
||||
raise RuntimeError(
|
||||
"No MMLongBench questions matched the filters; broaden --docs/--format/--n."
|
||||
)
|
||||
logger.info("MMLongBench-Doc: scheduled %d questions", len(questions))
|
||||
|
||||
api_key = os.environ.get("OPENROUTER_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError(
|
||||
"OPENROUTER_API_KEY env var is required for the native arm."
|
||||
)
|
||||
|
||||
# Native arm slug differs from SurfSense slug only in cost-arbitrage
|
||||
# scenario; otherwise both arms answer with provider_model.
|
||||
native_arm_model = ctx.native_arm_model
|
||||
if any(hint in native_arm_model.lower() for hint in _TEXT_ONLY_HINTS):
|
||||
if ctx.scenario == "symmetric-cheap":
|
||||
logger.info(
|
||||
"symmetric-cheap: native arm pinned to text-only %r as "
|
||||
"intended; expect it to lose on image-bearing pages "
|
||||
"(SurfSense answers from vision-extracted chunks).",
|
||||
native_arm_model,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Native arm slug %r looks text-only; image content in "
|
||||
"PDFs will be ignored. Re-pin via "
|
||||
"`setup --provider-model anthropic/claude-sonnet-4.5` "
|
||||
"(or pass --native-arm-model and --scenario cost-arbitrage "
|
||||
"to make this asymmetry explicit).",
|
||||
native_arm_model,
|
||||
)
|
||||
|
||||
provider = OpenRouterPdfProvider(
|
||||
api_key=api_key,
|
||||
base_url=ctx.config.openrouter_base_url,
|
||||
model=native_arm_model,
|
||||
engine=PdfEngine(pdf_engine_name),
|
||||
)
|
||||
native_arm = NativePdfArm(provider=provider, max_output_tokens=max_output_tokens)
|
||||
surf_arm = SurfSenseArm(
|
||||
client=ctx.new_chat_client(),
|
||||
search_space_id=ctx.search_space_id,
|
||||
ephemeral_threads=True,
|
||||
)
|
||||
|
||||
run_timestamp = utc_iso_timestamp()
|
||||
run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
|
||||
raw_path = run_dir / "raw.jsonl"
|
||||
|
||||
async def _native_one(q: MMLBQuestion) -> ArmResult:
|
||||
return await native_arm.answer(_make_native_request(q, max_output_tokens))
|
||||
|
||||
async def _surf_one(q: MMLBQuestion) -> ArmResult:
|
||||
return await surf_arm.answer(_make_surfsense_request(q, no_mentions=no_mentions))
|
||||
|
||||
native_results, surf_results = await asyncio.gather(
|
||||
_gather_with_limit((_native_one(q) for q in questions), concurrency=concurrency),
|
||||
_gather_with_limit((_surf_one(q) for q in questions), concurrency=concurrency),
|
||||
)
|
||||
|
||||
native_grades = [_grade_one(q, r) for q, r in zip(questions, native_results, strict=False)]
|
||||
surf_grades = [_grade_one(q, r) for q, r in zip(questions, surf_results, strict=False)]
|
||||
|
||||
with raw_path.open("w", encoding="utf-8") as fh:
|
||||
for q, n_res, s_res, n_g, s_g in zip(
|
||||
questions, native_results, surf_results, native_grades, surf_grades, strict=False
|
||||
):
|
||||
meta = {
|
||||
"qid": q.qid,
|
||||
"doc_id": q.doc_id,
|
||||
"doc_type": q.doc_type,
|
||||
"answer_format": q.answer_format,
|
||||
"gold": q.gold_answer,
|
||||
"evidence_pages": q.evidence_pages,
|
||||
"evidence_sources": q.evidence_sources,
|
||||
"document_id": q.document_id,
|
||||
}
|
||||
fh.write(json.dumps({
|
||||
**meta,
|
||||
**n_res.to_jsonl(),
|
||||
"graded": _grade_to_jsonl(n_g),
|
||||
}) + "\n")
|
||||
fh.write(json.dumps({
|
||||
**meta,
|
||||
**s_res.to_jsonl(),
|
||||
"graded": _grade_to_jsonl(s_g),
|
||||
}) + "\n")
|
||||
|
||||
metrics = _compute_metrics(questions, native_results, surf_results, native_grades, surf_grades)
|
||||
artifact = RunArtifact(
|
||||
suite=self.suite,
|
||||
benchmark=self.name,
|
||||
run_timestamp=run_timestamp,
|
||||
raw_path=raw_path,
|
||||
metrics=metrics,
|
||||
extra={
|
||||
"n_questions": len(questions),
|
||||
"concurrency": concurrency,
|
||||
"format_filter": format_filter,
|
||||
"skip_unanswerable": skip_unanswerable,
|
||||
"no_mentions": no_mentions,
|
||||
"pdf_engine": pdf_engine_name,
|
||||
"scenario": ctx.scenario,
|
||||
"provider_model": ctx.provider_model,
|
||||
"native_arm_model": native_arm_model,
|
||||
"vision_provider_model": ctx.vision_provider_model,
|
||||
"agent_llm_id": ctx.agent_llm_id,
|
||||
"ingest_settings": ingest_settings,
|
||||
},
|
||||
)
|
||||
|
||||
manifest_path = run_dir / "run_artifact.json"
|
||||
manifest_path.write_text(
|
||||
json.dumps({
|
||||
"suite": self.suite,
|
||||
"benchmark": self.name,
|
||||
"raw_path": "raw.jsonl",
|
||||
"metrics": metrics,
|
||||
"extra": artifact.extra,
|
||||
}, indent=2, sort_keys=True) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return artifact
|
||||
|
||||
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
|
||||
if not artifacts:
|
||||
return ReportSection(
|
||||
title="MMLongBench-Doc — Native PDF (vision) vs SurfSense (vision RAG)",
|
||||
headline=True,
|
||||
body_md="(no run artifacts found)",
|
||||
body_json={},
|
||||
)
|
||||
latest = max(artifacts, key=lambda a: a.run_timestamp)
|
||||
m = latest.metrics
|
||||
native = m.get("native", {})
|
||||
surf = m.get("surfsense", {})
|
||||
delta = m.get("delta", {})
|
||||
per_format = m.get("per_format", {})
|
||||
extra = latest.extra
|
||||
|
||||
body_lines: list[str] = []
|
||||
body_lines.append(
|
||||
f"- Sample size: {extra.get('n_questions', '?')} questions "
|
||||
f"(format filter: `{extra.get('format_filter', 'all')}`, "
|
||||
f"skip-unanswerable: `{extra.get('skip_unanswerable', False)}`, "
|
||||
f"engine: `{extra.get('pdf_engine', 'native')}`)."
|
||||
)
|
||||
body_lines.append(format_scenario_md(extra))
|
||||
body_lines.append(format_ingest_settings_md(extra.get("ingest_settings")))
|
||||
body_lines.append(
|
||||
"- Native arm (OpenRouter `chat/completions` + file plugin, "
|
||||
f"`{extra.get('native_arm_model') or extra.get('provider_model', '?')}`):"
|
||||
)
|
||||
body_lines.append(_arm_summary_lines(native, indent=" "))
|
||||
body_lines.append(
|
||||
"- SurfSense arm (`POST /api/v1/new_chat`, vision RAG over chunks, "
|
||||
f"`{extra.get('provider_model', '?')}`):"
|
||||
)
|
||||
body_lines.append(_arm_summary_lines(surf, indent=" "))
|
||||
body_lines.append("- Delta (paired):")
|
||||
body_lines.append(
|
||||
f" - Accuracy: SurfSense {_pp(delta.get('accuracy_pp'))} pp "
|
||||
f"(McNemar p={_fmt(delta.get('mcnemar_p_value'), 4)}, "
|
||||
f"method={delta.get('mcnemar_method')})"
|
||||
)
|
||||
body_lines.append(
|
||||
f" - F1 (mean): SurfSense {_pp(delta.get('f1_pp'))} pp"
|
||||
)
|
||||
body_lines.append(
|
||||
f" - Bootstrap 95% CI on accuracy delta: "
|
||||
f"[{_pp(delta.get('bootstrap_ci_low'))}pp, {_pp(delta.get('bootstrap_ci_high'))}pp]"
|
||||
)
|
||||
body_lines.append(
|
||||
f" - Cost / question: native ${_dollars(native.get('cost_micros_mean'))}, "
|
||||
f"surfsense ${_dollars(surf.get('cost_micros_mean'))} "
|
||||
f"(SurfSense delta {_pct_change(delta.get('cost_micros_pct'))})"
|
||||
)
|
||||
body_lines.append(
|
||||
f" - Latency p50: native {_ms_to_s(native.get('latency_ms_median'))}, "
|
||||
f"surfsense {_ms_to_s(surf.get('latency_ms_median'))} "
|
||||
f"(SurfSense delta {_pct_change(delta.get('latency_ms_pct'))})"
|
||||
)
|
||||
if per_format:
|
||||
body_lines.append("- Per-format split (accuracy delta in pp):")
|
||||
for fmt, vals in sorted(per_format.items()):
|
||||
body_lines.append(
|
||||
f" - {fmt}: SurfSense {_pp(vals.get('delta_accuracy_pp'))} pp "
|
||||
f"(n={vals.get('n')}, native acc={vals.get('native_accuracy', 0)*100:.1f}%, "
|
||||
f"surf acc={vals.get('surfsense_accuracy', 0)*100:.1f}%)"
|
||||
)
|
||||
|
||||
return ReportSection(
|
||||
title="MMLongBench-Doc — Native PDF (vision) vs SurfSense (vision RAG)",
|
||||
headline=True,
|
||||
body_md="\n".join(body_lines),
|
||||
body_json=m,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-question helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_native_request(q: MMLBQuestion, max_tokens: int) -> ArmRequest:
|
||||
prompt = build_prompt(q.question, answer_format=q.answer_format)
|
||||
return ArmRequest(
|
||||
question_id=q.qid,
|
||||
prompt=prompt,
|
||||
pdf_paths=[q.pdf_path],
|
||||
options={"max_tokens": max_tokens},
|
||||
)
|
||||
|
||||
|
||||
def _make_surfsense_request(q: MMLBQuestion, *, no_mentions: bool) -> ArmRequest:
|
||||
prompt = build_prompt(q.question, answer_format=q.answer_format)
|
||||
mentions: list[int] | None = None
|
||||
if not no_mentions and q.document_id is not None:
|
||||
mentions = [int(q.document_id)]
|
||||
return ArmRequest(
|
||||
question_id=q.qid,
|
||||
prompt=prompt,
|
||||
mentioned_document_ids=mentions,
|
||||
)
|
||||
|
||||
|
||||
def _grade_one(q: MMLBQuestion, result: ArmResult) -> GradeResult:
|
||||
pred_text = extract_freeform_answer(result.raw_text or "")
|
||||
return grade(pred=pred_text, gold=q.gold_answer, answer_format=q.answer_format)
|
||||
|
||||
|
||||
def _grade_to_jsonl(g: GradeResult) -> dict[str, Any]:
|
||||
return {
|
||||
"correct": g.correct,
|
||||
"f1": g.f1,
|
||||
"method": g.method,
|
||||
"normalised_pred": g.normalised_pred,
|
||||
"normalised_gold": g.normalised_gold,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metrics aggregation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_metrics(
|
||||
questions: list[MMLBQuestion],
|
||||
native_results: list[ArmResult],
|
||||
surf_results: list[ArmResult],
|
||||
native_grades: list[GradeResult],
|
||||
surf_grades: list[GradeResult],
|
||||
) -> dict[str, Any]:
|
||||
native_correct = [g.correct for g in native_grades]
|
||||
surf_correct = [g.correct for g in surf_grades]
|
||||
native_f1 = [g.f1 for g in native_grades]
|
||||
surf_f1 = [g.f1 for g in surf_grades]
|
||||
|
||||
native_costs = [float(r.cost_micros) for r in native_results]
|
||||
surf_costs = [float(r.cost_micros) for r in surf_results]
|
||||
native_latencies = [float(r.latency_ms) for r in native_results]
|
||||
surf_latencies = [float(r.latency_ms) for r in surf_results]
|
||||
native_in_tokens = [float(r.input_tokens) for r in native_results]
|
||||
native_out_tokens = [float(r.output_tokens) for r in native_results]
|
||||
|
||||
native_acc = accuracy_with_wilson_ci(sum(native_correct), len(native_correct))
|
||||
surf_acc = accuracy_with_wilson_ci(sum(surf_correct), len(surf_correct))
|
||||
mc = mcnemar_test(native_correct, surf_correct)
|
||||
boot = bootstrap_delta_ci(native_correct, surf_correct, n_resamples=2000)
|
||||
|
||||
native_cost_agg = paired_aggregate(native_costs)
|
||||
surf_cost_agg = paired_aggregate(surf_costs)
|
||||
native_latency_agg = paired_aggregate(native_latencies)
|
||||
surf_latency_agg = paired_aggregate(surf_latencies)
|
||||
|
||||
cost_pct = _safe_pct(surf_cost_agg.mean, native_cost_agg.mean)
|
||||
latency_pct = _safe_pct(surf_latency_agg.median, native_latency_agg.median)
|
||||
|
||||
per_format_pairs: dict[str, list[tuple[bool, bool]]] = {}
|
||||
for q, n_ok, s_ok in zip(questions, native_correct, surf_correct, strict=False):
|
||||
per_format_pairs.setdefault(q.answer_format or "unknown", []).append((n_ok, s_ok))
|
||||
|
||||
per_format: dict[str, dict[str, Any]] = {}
|
||||
for fmt, pairs in per_format_pairs.items():
|
||||
n_correct = [a for a, _ in pairs]
|
||||
s_correct = [b for _, b in pairs]
|
||||
per_format[fmt] = {
|
||||
"n": len(pairs),
|
||||
"native_accuracy": (sum(n_correct) / len(pairs)) if pairs else 0.0,
|
||||
"surfsense_accuracy": (sum(s_correct) / len(pairs)) if pairs else 0.0,
|
||||
"delta_accuracy_pp": (
|
||||
100.0 * (sum(s_correct) - sum(n_correct)) / len(pairs)
|
||||
if pairs else 0.0
|
||||
),
|
||||
}
|
||||
|
||||
native_f1_mean = sum(native_f1) / len(native_f1) if native_f1 else 0.0
|
||||
surf_f1_mean = sum(surf_f1) / len(surf_f1) if surf_f1 else 0.0
|
||||
|
||||
return {
|
||||
"native": {
|
||||
**native_acc.to_dict(),
|
||||
"f1_mean": native_f1_mean,
|
||||
"cost_micros_mean": native_cost_agg.mean,
|
||||
"cost_micros_median": native_cost_agg.median,
|
||||
"latency_ms_mean": native_latency_agg.mean,
|
||||
"latency_ms_median": native_latency_agg.median,
|
||||
"latency_ms_p95": native_latency_agg.p95,
|
||||
"input_tokens_mean": (sum(native_in_tokens) / len(native_in_tokens)) if native_in_tokens else 0.0,
|
||||
"output_tokens_mean": (sum(native_out_tokens) / len(native_out_tokens)) if native_out_tokens else 0.0,
|
||||
},
|
||||
"surfsense": {
|
||||
**surf_acc.to_dict(),
|
||||
"f1_mean": surf_f1_mean,
|
||||
"cost_micros_mean": surf_cost_agg.mean,
|
||||
"cost_micros_median": surf_cost_agg.median,
|
||||
"latency_ms_mean": surf_latency_agg.mean,
|
||||
"latency_ms_median": surf_latency_agg.median,
|
||||
"latency_ms_p95": surf_latency_agg.p95,
|
||||
},
|
||||
"delta": {
|
||||
"accuracy_pp": 100.0 * (surf_acc.accuracy - native_acc.accuracy),
|
||||
"f1_pp": 100.0 * (surf_f1_mean - native_f1_mean),
|
||||
"mcnemar_p_value": mc.p_value,
|
||||
"mcnemar_method": mc.method,
|
||||
"mcnemar_b_native_only": mc.b,
|
||||
"mcnemar_c_surfsense_only": mc.c,
|
||||
"bootstrap_ci_low": 100.0 * boot.ci_low,
|
||||
"bootstrap_ci_high": 100.0 * boot.ci_high,
|
||||
"cost_micros_pct": cost_pct,
|
||||
"latency_ms_pct": latency_pct,
|
||||
},
|
||||
"per_format": per_format,
|
||||
}
|
||||
|
||||
|
||||
def _safe_pct(numerator: float, denominator: float) -> float | None:
|
||||
if denominator == 0:
|
||||
return None
|
||||
return 100.0 * (numerator - denominator) / denominator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tiny formatting helpers used by report_section
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _arm_summary_lines(d: dict[str, Any], *, indent: str) -> str:
|
||||
if not d:
|
||||
return f"{indent}(no data)"
|
||||
acc = d.get("accuracy", 0.0)
|
||||
low = d.get("ci_low", 0.0)
|
||||
high = d.get("ci_high", 0.0)
|
||||
f1 = d.get("f1_mean", 0.0)
|
||||
lines = [
|
||||
f"{indent}- Accuracy: {acc * 100:.1f}% (Wilson 95% CI: {low * 100:.1f}% – {high * 100:.1f}%)",
|
||||
f"{indent}- F1 (token-level mean): {f1 * 100:.1f}%",
|
||||
f"{indent}- Cost / question: ${_dollars(d.get('cost_micros_mean'))} (mean), "
|
||||
f"${_dollars(d.get('cost_micros_median'))} (median)",
|
||||
f"{indent}- Latency: p50 {_ms_to_s(d.get('latency_ms_median'))}, "
|
||||
f"p95 {_ms_to_s(d.get('latency_ms_p95'))}",
|
||||
]
|
||||
if "input_tokens_mean" in d:
|
||||
lines.append(
|
||||
f"{indent}- Mean tokens / question: in {d.get('input_tokens_mean', 0):.0f}, "
|
||||
f"out {d.get('output_tokens_mean', 0):.0f}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _dollars(micros: Any) -> str:
|
||||
if micros is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{(float(micros) / 1_000_000):.4f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _ms_to_s(ms: Any) -> str:
|
||||
if ms is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(ms) / 1000:.1f}s"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _pp(value: Any) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):+.1f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _pct_change(value: Any) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):+.0f}%"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _fmt(value: Any, ndigits: int) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):.{ndigits}f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
__all__ = ["MMLBQuestion", "MMLongBenchDocBenchmark"]
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""Research / multi-document RAG benchmarks.
|
||||
|
||||
Distinct from ``multimodal_doc`` (PDF-bound) and ``medical`` (one
|
||||
question = one source PDF). Benchmarks here put *retrieval and
|
||||
reasoning across many documents* in the critical path — the regime
|
||||
where SurfSense's chunk-level RAG should shine versus "pour the
|
||||
entire document into the LLM" or "ask the LLM cold".
|
||||
|
||||
* ``frames`` (google/frames-benchmark) — 824 multi-hop Wikipedia
|
||||
questions; tests bare-LLM vs SurfSense over a shared ~330-doc
|
||||
corpus.
|
||||
* ``crag`` (facebookresearch/CRAG, KDD Cup 2024) — 2,706 web QA
|
||||
pairs with 5 pre-retrieved HTML pages each; tests bare-LLM vs
|
||||
long-context-stuffed LLM vs SurfSense over the question's 5
|
||||
scoped pages — the closest comparison to a competing RAG product.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
"""CRAG — Comprehensive RAG Benchmark (Yang et al., Meta, KDD Cup 2024).
|
||||
|
||||
Source: https://github.com/facebookresearch/CRAG (Tasks 1, 2, and 3)
|
||||
Paper: https://arxiv.org/abs/2406.04744
|
||||
|
||||
This package registers two siblings:
|
||||
|
||||
* ``crag`` — Tasks 1 & 2: 5 candidate pages per question.
|
||||
* ``crag_t3`` — Task 3: 50 candidate pages per question. The
|
||||
long-context arm is capped to the top-5 (the realistic "naive
|
||||
RAG = pick top-K results" baseline); SurfSense retrieves over
|
||||
all 50, where its rerank becomes the entire contribution.
|
||||
|
||||
Both share the grader, prompt, runner, and report code; only the
|
||||
ingest path differs (single bz2 vs 4-part tar.bz2 streamed).
|
||||
|
||||
CRAG ships ~2,706 factual QA pairs, each paired with **5 full HTML
|
||||
pages** retrieved as the top-5 of a real web search at ``query_time``
|
||||
(50 in Task 3).
|
||||
The benchmark spans 5 domains (finance, music, movie, sports, open)
|
||||
and 8 question types (simple, comparison, aggregation, set, multi-hop,
|
||||
post-processing, false_premise, simple_w_condition) — heads/torsos/
|
||||
tails of entity popularity — and an explicit static→real-time
|
||||
freshness axis.
|
||||
|
||||
Why CRAG demonstrates SurfSense more clearly than FRAMES
|
||||
--------------------------------------------------------
|
||||
FRAMES tested SurfSense vs. *no retrieval at all* — a fair "naive
|
||||
prompting" baseline (the published 40.8% number) but not a competing
|
||||
RAG product. CRAG enables a three-way comparison:
|
||||
|
||||
* ``bare_llm`` — chat completion with the question only. CRAG
|
||||
paper: ≤34% accuracy ("LLM cold").
|
||||
* ``long_context`` — stuff all 5 extracted page texts straight into
|
||||
the prompt (the "naive RAG" / "straightforward RAG" arm in the
|
||||
paper). Published baseline: ~44%.
|
||||
* ``surfsense`` — POST ``/api/v1/new_chat`` with retrieval scoped
|
||||
to the question's 5 ingested pages (``mentioned_document_ids``).
|
||||
|
||||
So the headline becomes "SurfSense vs. context-stuffed long-context
|
||||
LLM, both fed the same 5 pages" — a head-to-head against the simplest
|
||||
realistic RAG strategy, not against an unarmed model.
|
||||
|
||||
Scoring follows the CRAG paper: each prediction is graded as
|
||||
**correct** (+1), **missing/I-don't-know** (0), or **incorrect** (-1),
|
||||
and the headline metric is the *Truthfulness Score*:
|
||||
``(#correct - #incorrect) / total`` — penalising hallucinations
|
||||
relative to refusals.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import registry as _registry
|
||||
from .runner import CragBenchmark, CragTask3Benchmark
|
||||
|
||||
_registry.register(CragBenchmark())
|
||||
_registry.register(CragTask3Benchmark())
|
||||
|
|
@ -0,0 +1,335 @@
|
|||
"""CRAG dataset loader — download ``crag_task_1_and_2_dev_v4.jsonl.bz2`` and parse.
|
||||
|
||||
The CRAG repo (``facebookresearch/CRAG``) ships Tasks 1 & 2 as a
|
||||
single bzip2-compressed JSONL on GitHub raw. Each row carries:
|
||||
|
||||
* ``interaction_id`` — opaque per-question id (we keep verbatim)
|
||||
* ``query_time`` — wall clock of the original web search
|
||||
* ``domain`` — finance | music | movie | sports | open
|
||||
* ``question_type`` — simple | comparison | aggregation | set |
|
||||
multi-hop | post-processing | false_premise |
|
||||
simple_w_condition
|
||||
* ``static_or_dynamic`` — static | slow-changing | fast-changing | real-time
|
||||
* ``query`` — the question
|
||||
* ``answer`` — gold short answer
|
||||
* ``alt_ans`` — list[str] of alternative valid answers
|
||||
(paraphrases / synonyms / unit variants)
|
||||
* ``split`` — 0 = validation, 1 = public test
|
||||
* ``popularity`` — head | torso | tail (KG questions); empty for web
|
||||
* ``search_results`` — list of up to 5 ``{page_name, page_url,
|
||||
page_snippet, page_result, page_last_modified}``;
|
||||
``page_result`` is full HTML.
|
||||
|
||||
We materialise this into ``CragQuestion`` objects keeping ``pages`` as
|
||||
a list of ``CragPage`` so downstream ingest can save each as its own
|
||||
file and SurfSense can dedupe on filename.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bz2
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import urllib.request
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tasks 1 & 2 share the same JSONL on the public CRAG repo.
|
||||
CRAG_TASK_1_2_URL = (
|
||||
"https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/"
|
||||
"crag_task_1_and_2_dev_v4.jsonl.bz2"
|
||||
)
|
||||
CRAG_TASK_1_2_FILENAME = "crag_task_1_and_2_dev_v4.jsonl.bz2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question / page dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class CragPage:
|
||||
"""One of the up-to-5 pre-retrieved web pages for a CRAG question."""
|
||||
|
||||
page_name: str
|
||||
page_url: str
|
||||
page_snippet: str
|
||||
page_html: str
|
||||
page_last_modified: str | None = None
|
||||
|
||||
@property
|
||||
def url_hash(self) -> str:
|
||||
"""Stable 12-hex digest of the page URL for filename keys.
|
||||
|
||||
We can't use the raw URL as a filename (slashes, query strings,
|
||||
unicode), and we *do* want collision-safety across the whole
|
||||
ingest sample. ``sha1[:12]`` gives us 48 bits of namespace
|
||||
which is overkill for a corpus capped at a few thousand pages.
|
||||
"""
|
||||
|
||||
return hashlib.sha1(self.page_url.encode("utf-8")).hexdigest()[:12]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CragQuestion:
|
||||
"""One row of CRAG (Tasks 1 & 2)."""
|
||||
|
||||
qid: str # synthesised "C00000".."C02705"
|
||||
interaction_id: str
|
||||
query_time: str
|
||||
query: str
|
||||
gold_answer: str
|
||||
alt_answers: list[str]
|
||||
domain: str
|
||||
question_type: str
|
||||
static_or_dynamic: str
|
||||
popularity: str # may be "" for web-sourced questions
|
||||
split: int # 0=validation, 1=public_test
|
||||
raw_index: int # row index in the source JSONL
|
||||
pages: list[CragPage] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"qid": self.qid,
|
||||
"interaction_id": self.interaction_id,
|
||||
"query_time": self.query_time,
|
||||
"query": self.query,
|
||||
"gold_answer": self.gold_answer,
|
||||
"alt_answers": list(self.alt_answers),
|
||||
"domain": self.domain,
|
||||
"question_type": self.question_type,
|
||||
"static_or_dynamic": self.static_or_dynamic,
|
||||
"popularity": self.popularity,
|
||||
"split": self.split,
|
||||
"raw_index": self.raw_index,
|
||||
"n_pages": len(self.pages),
|
||||
"page_urls": [p.page_url for p in self.pages],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Download + decompress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def download_task_1_2(cache_dir: Path) -> Path:
|
||||
"""Download the bz2 archive into ``cache_dir`` (skip if cached).
|
||||
|
||||
Returns the path to the local ``.jsonl.bz2``. We use stdlib
|
||||
``urllib`` rather than ``httpx`` to keep the download synchronous
|
||||
and trivially resumable (re-running the function is a no-op once
|
||||
the file is on disk and non-empty).
|
||||
"""
|
||||
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest = cache_dir / CRAG_TASK_1_2_FILENAME
|
||||
if dest.exists() and dest.stat().st_size > 0:
|
||||
logger.debug("CRAG bz2 already cached at %s", dest)
|
||||
return dest
|
||||
|
||||
logger.info("Downloading CRAG (Tasks 1 & 2) from %s ...", CRAG_TASK_1_2_URL)
|
||||
tmp = dest.with_suffix(dest.suffix + ".part")
|
||||
req = urllib.request.Request(
|
||||
CRAG_TASK_1_2_URL,
|
||||
headers={"User-Agent": "SurfSense-Evals/0.1 (CRAG dataset fetch)"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=600) as response, tmp.open("wb") as fh:
|
||||
chunk = response.read(1 << 20)
|
||||
while chunk:
|
||||
fh.write(chunk)
|
||||
chunk = response.read(1 << 20)
|
||||
tmp.replace(dest)
|
||||
logger.info("CRAG bz2 downloaded: %s (%.1f MiB)", dest, dest.stat().st_size / 1024 / 1024)
|
||||
return dest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parse
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_pages(raw_search_results: Any) -> list[CragPage]:
|
||||
if not isinstance(raw_search_results, list):
|
||||
return []
|
||||
pages: list[CragPage] = []
|
||||
for entry in raw_search_results:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
url = str(entry.get("page_url") or "").strip()
|
||||
html = str(entry.get("page_result") or "")
|
||||
if not url or not html.strip():
|
||||
# No URL or empty HTML => useless for retrieval.
|
||||
continue
|
||||
pages.append(CragPage(
|
||||
page_name=str(entry.get("page_name") or "").strip(),
|
||||
page_url=url,
|
||||
page_snippet=str(entry.get("page_snippet") or "").strip(),
|
||||
page_html=html,
|
||||
page_last_modified=(
|
||||
str(entry.get("page_last_modified")).strip()
|
||||
if entry.get("page_last_modified") else None
|
||||
),
|
||||
))
|
||||
return pages
|
||||
|
||||
|
||||
def _parse_alt_answers(raw: Any) -> list[str]:
|
||||
if isinstance(raw, list):
|
||||
return [str(x).strip() for x in raw if str(x).strip()]
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
return [raw.strip()]
|
||||
return []
|
||||
|
||||
|
||||
def iter_questions(jsonl_bz2_path: Path) -> list[CragQuestion]:
|
||||
"""Stream-decompress + parse the CRAG JSONL into ``CragQuestion`` objects.
|
||||
|
||||
The bz2 expansion ratio is ~10x and the decompressed file is
|
||||
multi-GB; we therefore decompress *line by line* via
|
||||
``bz2.open(..., "rt")``. Each row is a single (potentially very
|
||||
large, due to embedded HTML) JSON object. We keep the entire row
|
||||
in memory because we materialise the pages to disk immediately
|
||||
after parsing in the ingest pipeline — the runner never holds
|
||||
more than the current sample's worth of HTML.
|
||||
"""
|
||||
|
||||
out: list[CragQuestion] = []
|
||||
with bz2.open(jsonl_bz2_path, mode="rt", encoding="utf-8") as fh:
|
||||
for raw_idx, line in enumerate(fh):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
row = json.loads(line)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.warning("Skipping malformed CRAG row %d: %s", raw_idx, exc)
|
||||
continue
|
||||
query = str(row.get("query") or "").strip()
|
||||
answer = str(row.get("answer") or "").strip()
|
||||
if not query or not answer:
|
||||
logger.debug("Skipping CRAG row %d with missing query/answer", raw_idx)
|
||||
continue
|
||||
interaction_id = str(row.get("interaction_id") or "").strip()
|
||||
pages = _parse_pages(row.get("search_results"))
|
||||
out.append(CragQuestion(
|
||||
qid=f"C{raw_idx:05d}",
|
||||
interaction_id=interaction_id,
|
||||
query_time=str(row.get("query_time") or "").strip(),
|
||||
query=query,
|
||||
gold_answer=answer,
|
||||
alt_answers=_parse_alt_answers(row.get("alt_ans")),
|
||||
domain=str(row.get("domain") or "").strip().lower(),
|
||||
question_type=str(row.get("question_type") or "").strip().lower(),
|
||||
static_or_dynamic=str(row.get("static_or_dynamic") or "").strip().lower(),
|
||||
popularity=str(row.get("popularity") or "").strip().lower(),
|
||||
split=int(row.get("split") or 0),
|
||||
raw_index=raw_idx,
|
||||
pages=pages,
|
||||
))
|
||||
return out
|
||||
|
||||
|
||||
def stratified_sample(
|
||||
questions: list[CragQuestion],
|
||||
*,
|
||||
n: int,
|
||||
seed: int = 17,
|
||||
) -> list[CragQuestion]:
|
||||
"""Take ``n`` questions that roughly preserve the domain × question-type mix.
|
||||
|
||||
CRAG is only ~2.7k rows so naive head-of-list sampling badly
|
||||
over-weights ``finance`` (because the dataset isn't shuffled by
|
||||
domain). We bucket on ``(domain, question_type)`` and round-robin
|
||||
pick from each bucket until we hit ``n`` — this gives every
|
||||
bucket a fair shot and keeps the sample composition stable across
|
||||
re-runs (deterministic via the seeded shuffle inside each bucket).
|
||||
"""
|
||||
|
||||
if n <= 0 or n >= len(questions):
|
||||
return list(questions)
|
||||
import random
|
||||
|
||||
rng = random.Random(seed)
|
||||
buckets: dict[tuple[str, str], list[CragQuestion]] = {}
|
||||
for q in questions:
|
||||
buckets.setdefault((q.domain, q.question_type), []).append(q)
|
||||
for items in buckets.values():
|
||||
rng.shuffle(items)
|
||||
|
||||
keys = sorted(buckets.keys())
|
||||
chosen: list[CragQuestion] = []
|
||||
cursor = 0
|
||||
while len(chosen) < n and any(buckets[k] for k in keys):
|
||||
key = keys[cursor % len(keys)]
|
||||
cursor += 1
|
||||
if buckets[key]:
|
||||
chosen.append(buckets[key].pop())
|
||||
chosen.sort(key=lambda q: q.raw_index)
|
||||
return chosen
|
||||
|
||||
|
||||
def write_questions_jsonl(questions: list[CragQuestion], dest: Path) -> None:
|
||||
"""Persist a parsed copy (without page HTML) under the benchmark data dir."""
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
with dest.open("w", encoding="utf-8") as fh:
|
||||
for q in questions:
|
||||
fh.write(json.dumps(q.to_dict()) + "\n")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reading the lightweight questions.jsonl back
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_questions_jsonl(path: Path) -> list[dict[str, Any]]:
|
||||
"""Re-load the lightweight (no-HTML) questions JSONL from disk."""
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
if not path.exists():
|
||||
return out
|
||||
with path.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
out.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience: decompress a snippet to memory for tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def decompress_to_memory(jsonl_bz2_path: Path) -> io.StringIO:
|
||||
"""For tests / one-off scripts: read the whole bz2 into a StringIO.
|
||||
|
||||
Avoids leaking gigabytes; use ``iter_questions`` in production.
|
||||
"""
|
||||
|
||||
with bz2.open(jsonl_bz2_path, mode="rb") as fh:
|
||||
return io.StringIO(fh.read().decode("utf-8"))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CRAG_TASK_1_2_FILENAME",
|
||||
"CRAG_TASK_1_2_URL",
|
||||
"CragPage",
|
||||
"CragQuestion",
|
||||
"decompress_to_memory",
|
||||
"download_task_1_2",
|
||||
"iter_questions",
|
||||
"load_questions_jsonl",
|
||||
"stratified_sample",
|
||||
"write_questions_jsonl",
|
||||
]
|
||||
|
|
@ -0,0 +1,263 @@
|
|||
"""CRAG Task 3 dataset loader — 4-part tar.bz2 → streaming JSONL.
|
||||
|
||||
Task 3 ships ~7 GB of compressed data split into 4 parts on GitHub:
|
||||
|
||||
crag_task_3_dev_v4.tar.bz2.part1 (≈2 GB)
|
||||
crag_task_3_dev_v4.tar.bz2.part2 (≈2 GB)
|
||||
crag_task_3_dev_v4.tar.bz2.part3 (≈2 GB)
|
||||
crag_task_3_dev_v4.tar.bz2.part4 (≈1.3 GB)
|
||||
|
||||
Concatenated, they form a tar archive containing a single JSONL file.
|
||||
Decompressed, that JSONL is on the order of 30-50 GB because each row
|
||||
embeds 50 full HTML pages (vs 5 in Tasks 1 & 2).
|
||||
|
||||
Materialising the JSONL would blow the disk budget (we have ~50 GB
|
||||
free at the time of writing), so we stream the whole thing instead:
|
||||
|
||||
1. Download parts (idempotent; ``scripts/download_crag_task3.py``).
|
||||
2. Concat them into a virtual file via ``_MultiPartReader``.
|
||||
3. Wrap in ``bz2.BZ2File`` for on-the-fly decompression.
|
||||
4. Wrap in ``tarfile.open(fileobj=..., mode="r|")`` for streaming
|
||||
tar member iteration.
|
||||
5. For the JSONL member inside, ``tar.extractfile()`` returns a
|
||||
binary file-like; we iterate lines and yield parsed dicts.
|
||||
|
||||
The caller can ``break`` out as soon as they have enough samples —
|
||||
nothing past the consumed point is decompressed.
|
||||
|
||||
Schema is identical to Tasks 1 & 2 (see ``dataset.py``); only
|
||||
``search_results`` is bigger (50 entries instead of 5).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bz2
|
||||
import json
|
||||
import logging
|
||||
import tarfile
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import IO
|
||||
|
||||
from .dataset import (
|
||||
CragPage,
|
||||
CragQuestion,
|
||||
_parse_alt_answers,
|
||||
_parse_pages,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CRAG_TASK_3_PART_URLS: tuple[str, ...] = tuple(
|
||||
"https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/"
|
||||
f"crag_task_3_dev_v4.tar.bz2.part{i}"
|
||||
for i in (1, 2, 3, 4)
|
||||
)
|
||||
CRAG_TASK_3_PART_NAMES: tuple[str, ...] = tuple(
|
||||
f"crag_task_3_dev_v4.tar.bz2.part{i}" for i in (1, 2, 3, 4)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-part virtual file (concatenates N files transparently)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _MultiPartReader:
|
||||
"""Read N files end-to-end as if they were one big file.
|
||||
|
||||
Implements just enough of the file protocol for ``bz2.BZ2File``
|
||||
to consume it: ``read(n)``, ``readable()``, ``close()``.
|
||||
Doesn't implement ``seek`` — the bz2 + tarfile streaming path
|
||||
is forward-only, which is what we want here.
|
||||
"""
|
||||
|
||||
def __init__(self, paths: list[Path]) -> None:
|
||||
if not paths:
|
||||
raise ValueError("_MultiPartReader needs at least one path")
|
||||
for p in paths:
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(p)
|
||||
self._paths = list(paths)
|
||||
self._idx = 0
|
||||
self._fh: IO[bytes] | None = self._paths[0].open("rb")
|
||||
self._closed = False
|
||||
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
if self._closed:
|
||||
raise ValueError("read of closed _MultiPartReader")
|
||||
if n is None or n < 0:
|
||||
chunks: list[bytes] = []
|
||||
while self._fh is not None:
|
||||
chunks.append(self._fh.read())
|
||||
self._advance()
|
||||
return b"".join(chunks)
|
||||
out: list[bytes] = []
|
||||
remaining = n
|
||||
while remaining > 0 and self._fh is not None:
|
||||
chunk = self._fh.read(remaining)
|
||||
if not chunk:
|
||||
self._advance()
|
||||
continue
|
||||
out.append(chunk)
|
||||
remaining -= len(chunk)
|
||||
return b"".join(out)
|
||||
|
||||
def _advance(self) -> None:
|
||||
if self._fh is not None:
|
||||
self._fh.close()
|
||||
self._fh = None
|
||||
self._idx += 1
|
||||
if self._idx < len(self._paths):
|
||||
self._fh = self._paths[self._idx].open("rb")
|
||||
|
||||
def readable(self) -> bool:
|
||||
return not self._closed
|
||||
|
||||
def close(self) -> None:
|
||||
if self._fh is not None:
|
||||
self._fh.close()
|
||||
self._fh = None
|
||||
self._closed = True
|
||||
|
||||
def __enter__(self) -> _MultiPartReader:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[no-untyped-def]
|
||||
self.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stream the JSONL inside the tar.bz2
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _is_jsonl_member(name: str) -> bool:
|
||||
return name.endswith(".jsonl") or name.endswith(".jsonl.txt")
|
||||
|
||||
|
||||
def iter_questions_task3(
|
||||
parts_dir: Path,
|
||||
*,
|
||||
max_questions: int | None = None,
|
||||
) -> list[CragQuestion]:
|
||||
"""Stream-parse Task 3 rows into ``CragQuestion`` objects.
|
||||
|
||||
The Task 3 archive ships its 2,706 questions sharded across
|
||||
multiple JSONL files inside the tar (e.g.
|
||||
``crag_task_3_dev_v4_0.jsonl``, ``..._1.jsonl``, …). We iterate
|
||||
members in-stream, parse every JSONL one we encounter, and stop
|
||||
as soon as ``max_questions`` is reached — at which point we
|
||||
don't decompress any further members.
|
||||
|
||||
For a typical n=50 sample at ~3 MB per row we touch ~150 MB of
|
||||
decompressed JSONL — almost always inside the first shard.
|
||||
"""
|
||||
|
||||
parts = [parts_dir / name for name in CRAG_TASK_3_PART_NAMES]
|
||||
multi = _MultiPartReader(parts)
|
||||
bz = bz2.BZ2File(multi, mode="rb")
|
||||
tar = tarfile.open(fileobj=bz, mode="r|")
|
||||
out: list[CragQuestion] = []
|
||||
raw_idx = 0
|
||||
found_jsonl = False
|
||||
try:
|
||||
for member in tar:
|
||||
if not member.isfile() or not _is_jsonl_member(member.name):
|
||||
continue
|
||||
found_jsonl = True
|
||||
logger.info(
|
||||
"CRAG Task 3: streaming JSONL shard %s (size: %d bytes)",
|
||||
member.name, member.size,
|
||||
)
|
||||
fh = tar.extractfile(member)
|
||||
if fh is None:
|
||||
logger.warning("tar.extractfile returned None for %s; skipping", member.name)
|
||||
continue
|
||||
try:
|
||||
for raw_line in fh:
|
||||
line = raw_line.decode("utf-8", errors="replace").strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
row = json.loads(line)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.warning(
|
||||
"Skipping malformed CRAG Task 3 row %d in %s: %s",
|
||||
raw_idx, member.name, exc,
|
||||
)
|
||||
raw_idx += 1
|
||||
continue
|
||||
query = str(row.get("query") or "").strip()
|
||||
answer = str(row.get("answer") or "").strip()
|
||||
if not query or not answer:
|
||||
raw_idx += 1
|
||||
continue
|
||||
out.append(CragQuestion(
|
||||
qid=f"T3_{raw_idx:05d}",
|
||||
interaction_id=str(row.get("interaction_id") or "").strip(),
|
||||
query_time=str(row.get("query_time") or "").strip(),
|
||||
query=query,
|
||||
gold_answer=answer,
|
||||
alt_answers=_parse_alt_answers(row.get("alt_ans")),
|
||||
domain=str(row.get("domain") or "").strip().lower(),
|
||||
question_type=str(row.get("question_type") or "").strip().lower(),
|
||||
static_or_dynamic=str(row.get("static_or_dynamic") or "").strip().lower(),
|
||||
popularity=str(row.get("popularity") or "").strip().lower(),
|
||||
split=int(row.get("split") or 0),
|
||||
raw_index=raw_idx,
|
||||
pages=_parse_pages(row.get("search_results")),
|
||||
))
|
||||
raw_idx += 1
|
||||
if max_questions is not None and len(out) >= max_questions:
|
||||
return out
|
||||
finally:
|
||||
try:
|
||||
fh.close()
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
if not found_jsonl:
|
||||
raise RuntimeError(
|
||||
"No JSONL member found inside Task 3 tar.bz2 archive; "
|
||||
"schema may have changed upstream."
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
tar.close()
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
try:
|
||||
bz.close()
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
try:
|
||||
multi.close()
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
return out
|
||||
|
||||
|
||||
def parts_present(parts_dir: Path) -> bool:
|
||||
"""``True`` iff all 4 parts exist on disk and are non-empty."""
|
||||
|
||||
for name in CRAG_TASK_3_PART_NAMES:
|
||||
p = parts_dir / name
|
||||
if not p.exists() or p.stat().st_size == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Re-exports for convenience
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CRAG_TASK_3_PART_NAMES",
|
||||
"CRAG_TASK_3_PART_URLS",
|
||||
"CragPage",
|
||||
"CragQuestion",
|
||||
"iter_questions_task3",
|
||||
"parts_present",
|
||||
]
|
||||
|
|
@ -0,0 +1,540 @@
|
|||
"""CRAG 3-class grader: ``correct`` (+1) / ``missing`` (0) / ``incorrect`` (-1).
|
||||
|
||||
The CRAG paper's headline metric is the **Truthfulness Score**:
|
||||
|
||||
score = (#correct - #incorrect) / total
|
||||
|
||||
which rewards calibrated abstention — refusing to answer is neutral
|
||||
(0), guessing wrong is negative (-1). Grading is therefore a 3-class
|
||||
problem rather than the 2-class accuracy used for FRAMES.
|
||||
|
||||
Pipeline per (pred, gold, alt_ans, question_type):
|
||||
|
||||
1. Detect refusal first (``Answer: I don't know`` / "I don't know" /
|
||||
"no information") → ``missing`` (deterministic, never billed).
|
||||
2. ``false_premise`` questions: gold is canonically "the question
|
||||
contains a false premise" — reward any answer that flags the
|
||||
false premise (substring "false premise" / "incorrect premise" /
|
||||
"no such") as correct.
|
||||
3. Run the FRAMES-style deterministic shortcut (exact / numeric /
|
||||
substring) on ``pred`` against ``gold ∪ alt_ans``. Hit → correct.
|
||||
4. Fall through to the LLM judge (if configured), which returns one
|
||||
of ``{correct, missing, incorrect}`` — verbatim CRAG protocol.
|
||||
5. No judge configured → record ``incorrect`` (pessimistic but at
|
||||
least monotone with the deterministic grader).
|
||||
|
||||
The judge is throttled by an asyncio.Semaphore so it doesn't outrun
|
||||
the OpenRouter rate limit; the pre-judge deterministic pass keeps
|
||||
the bill bounded (most easy "Beyoncé"-vs-"Beyoncé Knowles" cases
|
||||
short-circuit before we burn judge tokens).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import string
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from ....core.providers.openrouter_chat import OpenRouterChatProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
GradeClass = Literal["correct", "missing", "incorrect"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class CragGradeResult:
|
||||
"""One graded (pred, gold) pair under CRAG's 3-class rubric."""
|
||||
|
||||
grade: GradeClass
|
||||
score: int # +1 / 0 / -1
|
||||
method: str # exact, numeric, substring, refusal,
|
||||
# false_premise_correct, false_premise_miss,
|
||||
# llm_judge, lexical_miss, ...
|
||||
normalised_pred: str = ""
|
||||
normalised_gold: str = ""
|
||||
judge_rationale: str = ""
|
||||
|
||||
@property
|
||||
def correct(self) -> bool:
|
||||
return self.grade == "correct"
|
||||
|
||||
@property
|
||||
def missing(self) -> bool:
|
||||
return self.grade == "missing"
|
||||
|
||||
@property
|
||||
def incorrect(self) -> bool:
|
||||
return self.grade == "incorrect"
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"grade": self.grade,
|
||||
"score": self.score,
|
||||
"method": self.method,
|
||||
"normalised_pred": self.normalised_pred,
|
||||
"normalised_gold": self.normalised_gold,
|
||||
"judge_rationale": self.judge_rationale,
|
||||
}
|
||||
|
||||
|
||||
def _grade_to_score(grade: GradeClass) -> int:
|
||||
return {"correct": 1, "missing": 0, "incorrect": -1}[grade]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Normalisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_PUNCT_TABLE = str.maketrans({c: " " for c in string.punctuation})
|
||||
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.IGNORECASE)
|
||||
_WS = re.compile(r"\s+")
|
||||
|
||||
|
||||
def _normalise(s: str) -> str:
|
||||
s = (s or "").lower()
|
||||
s = s.translate(_PUNCT_TABLE)
|
||||
s = _ARTICLES.sub(" ", s)
|
||||
s = _WS.sub(" ", s).strip()
|
||||
return s
|
||||
|
||||
|
||||
_WORD_NUMBERS = {
|
||||
"zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
|
||||
"six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11,
|
||||
"twelve": 12, "thirteen": 13, "fourteen": 14, "fifteen": 15, "sixteen": 16,
|
||||
"seventeen": 17, "eighteen": 18, "nineteen": 19, "twenty": 20,
|
||||
}
|
||||
|
||||
_NUMERIC_RE = re.compile(r"-?\d+(?:[.,]\d+)?")
|
||||
|
||||
|
||||
def _maybe_number(s: str) -> float | None:
|
||||
"""Extract a single numeric value from raw lowercased text."""
|
||||
|
||||
raw = (s or "").strip().lower()
|
||||
if not raw:
|
||||
return None
|
||||
match = _NUMERIC_RE.search(raw)
|
||||
if match:
|
||||
try:
|
||||
return float(match.group(0).replace(",", ""))
|
||||
except ValueError:
|
||||
pass
|
||||
for tok in _normalise(s).split():
|
||||
if tok in _WORD_NUMBERS:
|
||||
return float(_WORD_NUMBERS[tok])
|
||||
return None
|
||||
|
||||
|
||||
def _whole_word_substring(haystack: str, needle: str) -> bool:
|
||||
if not needle:
|
||||
return False
|
||||
return f" {needle} " in f" {haystack} "
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Refusal detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_REFUSAL_PATTERNS = [
|
||||
re.compile(r"\bi\s+don'?t\s+know\b", re.IGNORECASE),
|
||||
re.compile(r"\bi\s+do\s+not\s+know\b", re.IGNORECASE),
|
||||
re.compile(r"\bnot\s+enough\s+information\b", re.IGNORECASE),
|
||||
re.compile(r"\binsufficient\s+information\b", re.IGNORECASE),
|
||||
re.compile(r"\bcannot\s+(?:be\s+)?(?:answered|determined)\b", re.IGNORECASE),
|
||||
re.compile(r"\bunable\s+to\s+(?:answer|determine)\b", re.IGNORECASE),
|
||||
re.compile(r"\bno\s+(?:information|data|evidence)\b", re.IGNORECASE),
|
||||
]
|
||||
|
||||
|
||||
def _is_refusal(pred: str) -> bool:
|
||||
"""Cheap deterministic check for "I don't know" -shaped responses."""
|
||||
|
||||
if not pred or not pred.strip():
|
||||
return True # empty answer is a de facto refusal
|
||||
return any(p.search(pred) for p in _REFUSAL_PATTERNS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# False-premise handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_FALSE_PREMISE_PATTERNS = [
|
||||
re.compile(r"false\s+premise", re.IGNORECASE),
|
||||
re.compile(r"incorrect\s+premise", re.IGNORECASE),
|
||||
re.compile(r"premise\s+(?:is|of)\s+the\s+question", re.IGNORECASE),
|
||||
re.compile(r"\bno\s+such\b", re.IGNORECASE),
|
||||
re.compile(r"never\s+(?:happened|occurred|existed)", re.IGNORECASE),
|
||||
re.compile(r"\bdid\s+not\s+(?:happen|occur|exist)\b", re.IGNORECASE),
|
||||
re.compile(r"\bdoes\s+not\s+exist\b", re.IGNORECASE),
|
||||
re.compile(r"is\s+not\s+(?:true|correct|accurate)", re.IGNORECASE),
|
||||
re.compile(r"\bisn'?t\s+(?:true|correct|accurate)\b", re.IGNORECASE),
|
||||
re.compile(r"\binvalid\s+(?:premise|question|assumption)\b", re.IGNORECASE),
|
||||
]
|
||||
|
||||
|
||||
def _flags_false_premise(pred: str) -> bool:
|
||||
return any(p.search(pred) for p in _FALSE_PREMISE_PATTERNS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deterministic grader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def grade_deterministic(
|
||||
*,
|
||||
pred: str,
|
||||
gold: str,
|
||||
alt_answers: Sequence[str] = (),
|
||||
question_type: str = "",
|
||||
) -> CragGradeResult:
|
||||
"""Try to grade without the LLM judge. Returns a final result.
|
||||
|
||||
Always returns *some* result — the caller checks ``method`` to
|
||||
decide whether the LLM judge should overturn it. ``lexical_miss``
|
||||
and ``false_premise_unclear`` are the two methods that trigger the
|
||||
judge fallback.
|
||||
"""
|
||||
|
||||
qtype = (question_type or "").lower()
|
||||
n_pred = _normalise(pred)
|
||||
n_gold = _normalise(gold)
|
||||
|
||||
if _is_refusal(pred):
|
||||
# CRAG protocol: refusal is *missing* (0), even on false-premise
|
||||
# questions where one might argue refusal == correct. We
|
||||
# follow the paper's grading literally.
|
||||
return CragGradeResult(
|
||||
grade="missing",
|
||||
score=0,
|
||||
method="refusal",
|
||||
normalised_pred=n_pred,
|
||||
normalised_gold=n_gold,
|
||||
)
|
||||
|
||||
# Empty-gold guard (shouldn't happen, but defensively):
|
||||
if not n_gold:
|
||||
return CragGradeResult(
|
||||
grade="incorrect",
|
||||
score=-1,
|
||||
method="empty_gold",
|
||||
normalised_pred=n_pred,
|
||||
normalised_gold=n_gold,
|
||||
)
|
||||
|
||||
# False-premise questions: gold is typically "the question contains
|
||||
# a false premise" / "no such X" / similar. Any answer that
|
||||
# explicitly flags the false premise is correct.
|
||||
if qtype == "false_premise":
|
||||
if _flags_false_premise(pred):
|
||||
return CragGradeResult(
|
||||
grade="correct",
|
||||
score=1,
|
||||
method="false_premise_flagged",
|
||||
normalised_pred=n_pred,
|
||||
normalised_gold=n_gold,
|
||||
)
|
||||
# If the model commits to *any* concrete answer on a false-
|
||||
# premise question without flagging the premise, it is wrong.
|
||||
# But we don't classify ourselves — let the judge decide on
|
||||
# the off chance the gold itself is e.g. "no" and the pred
|
||||
# is "no" without explicit "false premise" wording.
|
||||
return CragGradeResult(
|
||||
grade="incorrect",
|
||||
score=-1,
|
||||
method="false_premise_unclear",
|
||||
normalised_pred=n_pred,
|
||||
normalised_gold=n_gold,
|
||||
)
|
||||
|
||||
# All non-false-premise questions: try the standard chain against
|
||||
# gold and each alt answer. First match wins.
|
||||
candidates = [gold, *list(alt_answers)]
|
||||
for candidate in candidates:
|
||||
if not candidate or not str(candidate).strip():
|
||||
continue
|
||||
cand_norm = _normalise(candidate)
|
||||
if not cand_norm:
|
||||
continue
|
||||
if n_pred == cand_norm:
|
||||
return CragGradeResult(
|
||||
grade="correct", score=1, method="exact",
|
||||
normalised_pred=n_pred, normalised_gold=cand_norm,
|
||||
)
|
||||
p_num = _maybe_number(pred)
|
||||
c_num = _maybe_number(candidate)
|
||||
if p_num is not None and c_num is not None:
|
||||
# Pure 1% relative tolerance for CRAG (currency, counts,
|
||||
# ratios). Unlike FRAMES (which uses a 0.5 absolute floor
|
||||
# for year-shaped answers), CRAG's numeric questions are
|
||||
# often small-value (stock prices, percentages) where a
|
||||
# 0.5 floor would let "$2.05" match "$2.17". The judge is
|
||||
# the safety net for borderline rounding cases.
|
||||
tol = abs(c_num) * 0.01
|
||||
if abs(p_num - c_num) <= tol:
|
||||
return CragGradeResult(
|
||||
grade="correct", score=1, method="numeric",
|
||||
normalised_pred=n_pred, normalised_gold=cand_norm,
|
||||
)
|
||||
# Numeric question with different numbers — keep looking
|
||||
# at other candidates rather than declaring miss now;
|
||||
# alt answers may include word forms that pass.
|
||||
if _whole_word_substring(n_pred, cand_norm):
|
||||
return CragGradeResult(
|
||||
grade="correct", score=1, method="substring",
|
||||
normalised_pred=n_pred, normalised_gold=cand_norm,
|
||||
)
|
||||
if _whole_word_substring(cand_norm, n_pred) and len(n_pred) >= 3:
|
||||
return CragGradeResult(
|
||||
grade="correct", score=1, method="substring_reverse",
|
||||
normalised_pred=n_pred, normalised_gold=cand_norm,
|
||||
)
|
||||
|
||||
return CragGradeResult(
|
||||
grade="incorrect",
|
||||
score=-1,
|
||||
method="lexical_miss",
|
||||
normalised_pred=n_pred,
|
||||
normalised_gold=n_gold,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM-as-judge (3-class)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_JUDGE_SYSTEM = (
|
||||
"You are an impartial grader for short-answer factual questions, "
|
||||
"following the CRAG benchmark rubric. Given a question, the gold "
|
||||
"answer (and any alternative valid answers), and a model's "
|
||||
"prediction, classify the prediction into exactly one of three "
|
||||
"categories:\n\n"
|
||||
"* \"correct\" — the prediction expresses the same factual "
|
||||
"content as the gold answer (paraphrasing OK; numbers as words "
|
||||
"OK; partial-but-correct names OK; non-contradictory extra "
|
||||
"detail OK).\n"
|
||||
"* \"missing\" — the prediction explicitly refuses, says \"I "
|
||||
"don't know\", says there is insufficient information, or hedges "
|
||||
"without committing.\n"
|
||||
"* \"incorrect\" — the prediction commits to a fact that is "
|
||||
"different from the gold answer, or fails to flag a false "
|
||||
"premise when the question contains one.\n\n"
|
||||
"Special case: if the question contains a false premise and the "
|
||||
"gold answer says so, then a prediction that flags the false "
|
||||
"premise is \"correct\".\n\n"
|
||||
"Respond with ONLY a JSON object on a single line:\n"
|
||||
'{\"grade\": \"correct\"|\"missing\"|\"incorrect\", \"rationale\": \"<one short sentence>\"}'
|
||||
)
|
||||
|
||||
|
||||
_JUDGE_TEMPLATE = """\
|
||||
Question: {question}
|
||||
Question type: {question_type}
|
||||
Gold answer: {gold}
|
||||
{alt_block}Model prediction: {pred}
|
||||
|
||||
Decide whether the prediction is correct, missing, or incorrect.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CragJudgeConfig:
|
||||
api_key: str
|
||||
model: str = "anthropic/claude-sonnet-4.5"
|
||||
base_url: str = "https://openrouter.ai/api/v1"
|
||||
max_tokens: int = 200
|
||||
concurrency: int = 4
|
||||
|
||||
|
||||
class CragLlmJudge:
|
||||
"""Async LLM judge over OpenRouter chat completions, 3-class output."""
|
||||
|
||||
def __init__(self, *, config: CragJudgeConfig) -> None:
|
||||
self._config = config
|
||||
self._provider = OpenRouterChatProvider(
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url,
|
||||
model=config.model,
|
||||
)
|
||||
self._sem = asyncio.Semaphore(max(1, config.concurrency))
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._config.model
|
||||
|
||||
async def judge(
|
||||
self,
|
||||
*,
|
||||
question: str,
|
||||
gold: str,
|
||||
alt_answers: Sequence[str],
|
||||
pred: str,
|
||||
question_type: str = "",
|
||||
) -> tuple[GradeClass, str]:
|
||||
"""Return ``(grade, rationale)``. Errors return incorrect + reason."""
|
||||
|
||||
alt_block = ""
|
||||
if alt_answers:
|
||||
alt_lines = "\n".join(f" - {a}" for a in alt_answers if a)
|
||||
if alt_lines:
|
||||
alt_block = f"Alternative valid answers:\n{alt_lines}\n"
|
||||
prompt = _JUDGE_TEMPLATE.format(
|
||||
question=question,
|
||||
question_type=question_type or "unknown",
|
||||
gold=gold,
|
||||
alt_block=alt_block,
|
||||
pred=pred,
|
||||
)
|
||||
try:
|
||||
async with self._sem:
|
||||
response = await self._provider.complete(
|
||||
prompt=prompt,
|
||||
system_prompt=_JUDGE_SYSTEM,
|
||||
max_tokens=self._config.max_tokens,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return "incorrect", f"judge_error: {type(exc).__name__}: {exc}"
|
||||
return _parse_judge_response(response.text)
|
||||
|
||||
|
||||
def _parse_judge_response(text: str) -> tuple[GradeClass, str]:
|
||||
"""Parse the judge reply into a 3-class label + rationale."""
|
||||
|
||||
if not text or not text.strip():
|
||||
return "incorrect", "judge_returned_empty"
|
||||
match = re.search(r"\{[^{}]*\}", text, flags=re.DOTALL)
|
||||
candidate = match.group(0) if match else text
|
||||
try:
|
||||
data = json.loads(candidate)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
lowered = text.strip().lower()
|
||||
if "correct" in lowered and "incorrect" not in lowered:
|
||||
return "correct", "yes (parser_fallback)"
|
||||
if "missing" in lowered or "i don" in lowered:
|
||||
return "missing", "missing (parser_fallback)"
|
||||
return "incorrect", f"unparseable_judge_response: {text[:200]}"
|
||||
raw_grade = str(data.get("grade") or "").strip().lower()
|
||||
rationale = str(data.get("rationale", "")).strip()[:280]
|
||||
if raw_grade in {"correct", "missing", "incorrect"}:
|
||||
return raw_grade, rationale # type: ignore[return-value]
|
||||
return "incorrect", f"unknown_grade={raw_grade!r}; {rationale}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Combined grader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Methods that should *not* trigger the LLM judge — the deterministic
|
||||
# verdict is conclusive (refusal, exact match, numeric mismatch, etc.).
|
||||
_TERMINAL_METHODS = frozenset({
|
||||
"refusal",
|
||||
"exact",
|
||||
"numeric",
|
||||
"substring",
|
||||
"substring_reverse",
|
||||
"false_premise_flagged",
|
||||
"empty_gold",
|
||||
})
|
||||
|
||||
|
||||
async def grade_with_judge(
|
||||
*,
|
||||
pred: str,
|
||||
gold: str,
|
||||
alt_answers: Sequence[str],
|
||||
question: str,
|
||||
question_type: str,
|
||||
judge: CragLlmJudge | None,
|
||||
) -> CragGradeResult:
|
||||
"""One row → deterministic shortcut → optional LLM judge fallback."""
|
||||
|
||||
det = grade_deterministic(
|
||||
pred=pred,
|
||||
gold=gold,
|
||||
alt_answers=alt_answers,
|
||||
question_type=question_type,
|
||||
)
|
||||
if det.method in _TERMINAL_METHODS:
|
||||
return det
|
||||
if judge is None:
|
||||
return det # ``lexical_miss`` / ``false_premise_unclear`` → keep as-is
|
||||
grade, rationale = await judge.judge(
|
||||
question=question,
|
||||
gold=gold,
|
||||
alt_answers=alt_answers,
|
||||
pred=pred,
|
||||
question_type=question_type,
|
||||
)
|
||||
return CragGradeResult(
|
||||
grade=grade,
|
||||
score=_grade_to_score(grade),
|
||||
method="llm_judge",
|
||||
normalised_pred=det.normalised_pred,
|
||||
normalised_gold=det.normalised_gold,
|
||||
judge_rationale=rationale,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CragGradeRow:
|
||||
"""One row to grade. Mirrors the FRAMES grader's tuple but typed."""
|
||||
|
||||
qid: str
|
||||
question: str
|
||||
gold: str
|
||||
alt_answers: list[str]
|
||||
pred: str
|
||||
question_type: str = ""
|
||||
|
||||
|
||||
async def grade_many(
|
||||
*,
|
||||
rows: Sequence[CragGradeRow],
|
||||
judge: CragLlmJudge | None,
|
||||
) -> list[CragGradeResult]:
|
||||
"""Grade every row concurrently. Judge enforces its own concurrency cap."""
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
coros = [
|
||||
grade_with_judge(
|
||||
pred=r.pred,
|
||||
gold=r.gold,
|
||||
alt_answers=r.alt_answers,
|
||||
question=r.question,
|
||||
question_type=r.question_type,
|
||||
judge=judge,
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
return list(await asyncio.gather(*coros))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CragGradeResult",
|
||||
"CragGradeRow",
|
||||
"CragJudgeConfig",
|
||||
"CragLlmJudge",
|
||||
"GradeClass",
|
||||
"grade_deterministic",
|
||||
"grade_many",
|
||||
"grade_with_judge",
|
||||
]
|
||||
|
|
@ -0,0 +1,206 @@
|
|||
"""HTML → markdown for CRAG pages, with boilerplate removal.
|
||||
|
||||
Each CRAG page is a *full* HTML document (nav, ads, recommended-for-
|
||||
you, footer, ...). Without removing that boilerplate, retrieval over
|
||||
the chunks would surface menu items and "subscribe to our newsletter"
|
||||
boxes instead of the actual page content. We use ``trafilatura``,
|
||||
which is purpose-built for main-content extraction (the same library
|
||||
Common Crawl downstream pipelines use). It outputs clean prose with
|
||||
section headers, lists, and tables preserved.
|
||||
|
||||
Extraction policy:
|
||||
1. ``trafilatura.extract`` with ``output_format="markdown"`` — main
|
||||
content only, headers preserved, tables kept.
|
||||
2. If extraction fails or returns < 200 chars (paywalled / JS-only
|
||||
page / extraction confused), fall back to a plain stdlib
|
||||
``HTMLParser`` that strips tags and collapses whitespace. Some
|
||||
text is better than no text — SurfSense's chunker handles noisy
|
||||
prose.
|
||||
|
||||
We *intentionally* keep the page name and URL as visible H1 / link
|
||||
metadata so the SurfSense chunker preserves doc identity at the top of
|
||||
the first chunk (mirrors what we do for FRAMES Wikipedia pages).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from html.parser import HTMLParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_MIN_TRAFILATURA_LENGTH = 200
|
||||
_MAX_OUTPUT_CHARS = 200_000 # cap to keep upload payloads sane
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionResult:
|
||||
"""Outcome of converting one HTML blob to plain markdown."""
|
||||
|
||||
text: str
|
||||
method: str # "trafilatura" | "fallback_strip" | "empty"
|
||||
n_chars: int
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return self.n_chars > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trafilatura wrapper (lazy import so tests / small scripts don't pay)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _trafilatura_extract(html_text: str, *, url: str) -> str | None:
|
||||
try:
|
||||
import trafilatura
|
||||
except ImportError: # pragma: no cover - dependency is required
|
||||
logger.warning("trafilatura not installed; falling back to strip-tags only")
|
||||
return None
|
||||
try:
|
||||
text = trafilatura.extract(
|
||||
html_text,
|
||||
url=url or None,
|
||||
output_format="markdown",
|
||||
include_links=False,
|
||||
include_images=False,
|
||||
include_tables=True,
|
||||
favor_recall=True,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 - trafilatura raises a zoo
|
||||
logger.debug("trafilatura.extract crashed for %s: %s", url, exc)
|
||||
return None
|
||||
if not text:
|
||||
return None
|
||||
return text.strip()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stdlib fallback: strip HTML tags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StripHTMLParser(HTMLParser):
|
||||
"""Collect text content, treating block tags as paragraph breaks.
|
||||
|
||||
We deliberately drop ``<script>``, ``<style>``, ``<nav>``,
|
||||
``<header>``, ``<footer>``, and ``<aside>`` content — these are
|
||||
almost always boilerplate and they are the dominant source of
|
||||
noise SurfSense ends up retrieving against if not removed.
|
||||
"""
|
||||
|
||||
_SKIP_TAGS = frozenset({"script", "style", "nav", "header", "footer", "aside", "svg"})
|
||||
_BLOCK_TAGS = frozenset({
|
||||
"p", "div", "section", "article", "li", "ul", "ol",
|
||||
"h1", "h2", "h3", "h4", "h5", "h6", "br", "tr",
|
||||
"td", "th", "table", "blockquote", "pre",
|
||||
})
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(convert_charrefs=True)
|
||||
self._buffer: list[str] = []
|
||||
self._skip_depth: int = 0
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list) -> None: # noqa: ARG002
|
||||
if tag in self._SKIP_TAGS:
|
||||
self._skip_depth += 1
|
||||
if tag in self._BLOCK_TAGS:
|
||||
self._buffer.append("\n")
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if tag in self._SKIP_TAGS and self._skip_depth > 0:
|
||||
self._skip_depth -= 1
|
||||
if tag in self._BLOCK_TAGS:
|
||||
self._buffer.append("\n")
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
if self._skip_depth:
|
||||
return
|
||||
self._buffer.append(data)
|
||||
|
||||
def get_text(self) -> str:
|
||||
text = "".join(self._buffer)
|
||||
# Decode any leftover entities and collapse whitespace.
|
||||
text = html.unescape(text)
|
||||
text = re.sub(r"[ \t]+", " ", text)
|
||||
text = re.sub(r"\n[ \t]+", "\n", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _strip_tags(html_text: str) -> str:
|
||||
parser = _StripHTMLParser()
|
||||
try:
|
||||
parser.feed(html_text)
|
||||
except Exception as exc: # noqa: BLE001 - HTMLParser is fragile on garbage input
|
||||
logger.debug("HTMLParser failed; using regex strip: %s", exc)
|
||||
no_tags = re.sub(r"<[^>]+>", " ", html_text)
|
||||
return re.sub(r"\s+", " ", html.unescape(no_tags)).strip()
|
||||
return parser.get_text()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extract_main_content(
|
||||
html_text: str,
|
||||
*,
|
||||
url: str = "",
|
||||
page_name: str = "",
|
||||
last_modified: str | None = None,
|
||||
) -> ExtractionResult:
|
||||
"""Convert one HTML blob into clean markdown for ingest.
|
||||
|
||||
The returned ``text`` is prefixed with a small metadata header
|
||||
(``# {page_name}\\n\\nSource: {url}\\n``) so that:
|
||||
|
||||
* SurfSense's chunker has a stable doc-identity anchor at the top
|
||||
of the first chunk (matches what we do for FRAMES Wikipedia).
|
||||
* The retrieval-augmented arm sees the URL inline, which the LLM
|
||||
can surface as a citation if the prompt asks for one.
|
||||
"""
|
||||
|
||||
body = ""
|
||||
method = "empty"
|
||||
if html_text and html_text.strip():
|
||||
body = _trafilatura_extract(html_text, url=url) or ""
|
||||
if body and len(body) >= _MIN_TRAFILATURA_LENGTH:
|
||||
method = "trafilatura"
|
||||
else:
|
||||
stripped = _strip_tags(html_text)
|
||||
# Prefer trafilatura output even if short, but only if it
|
||||
# contained any prose at all — empty trafilatura fall-through
|
||||
# to the stripped form.
|
||||
if body and stripped and len(stripped) > len(body) * 1.5:
|
||||
body = stripped
|
||||
method = "fallback_strip"
|
||||
elif not body and stripped:
|
||||
body = stripped
|
||||
method = "fallback_strip"
|
||||
elif body:
|
||||
method = "trafilatura"
|
||||
|
||||
body = body.strip()
|
||||
if len(body) > _MAX_OUTPUT_CHARS:
|
||||
body = body[:_MAX_OUTPUT_CHARS] + "\n\n[...truncated...]"
|
||||
|
||||
if not body:
|
||||
return ExtractionResult(text="", method="empty", n_chars=0)
|
||||
|
||||
title_line = (page_name or url or "Untitled").strip()
|
||||
header_lines = [f"# {title_line}"]
|
||||
if url:
|
||||
header_lines.append(f"Source: {url}")
|
||||
if last_modified:
|
||||
header_lines.append(f"Last modified: {last_modified}")
|
||||
final = "\n".join(header_lines) + "\n\n" + body + "\n"
|
||||
return ExtractionResult(text=final, method=method, n_chars=len(final))
|
||||
|
||||
|
||||
__all__ = ["ExtractionResult", "extract_main_content"]
|
||||
|
|
@ -0,0 +1,447 @@
|
|||
"""CRAG ingestion: download → extract → upload → per-question doc map.
|
||||
|
||||
Steps:
|
||||
|
||||
1. Download ``crag_task_1_and_2_dev_v4.jsonl.bz2`` from
|
||||
``facebookresearch/CRAG`` (skip if cached).
|
||||
2. Stream-parse into ``CragQuestion`` objects.
|
||||
3. Optionally cap to ``--n-questions N`` (and *stratified* sample
|
||||
across ``(domain, question_type)`` so the smoke / partial run
|
||||
isn't dominated by ``finance`` or ``simple``).
|
||||
4. For each question, extract the 5 web pages to clean markdown via
|
||||
``trafilatura`` and write them to
|
||||
``<bench_dir>/pages/<qid>__<page_idx>__<url_hash>.md``. The
|
||||
filename is unique across the whole sample (so SurfSense's
|
||||
``(filename, search_space)`` dedup never collides between
|
||||
questions) and round-trippable (the ``<qid>__`` prefix lets the
|
||||
ingest infer doc-membership at the title level even before we
|
||||
land on a stable status response).
|
||||
5. Upload all extracted pages to SurfSense in batches with text-only
|
||||
ETL (``use_vision_llm=False, processing_mode="basic"``) — these
|
||||
are extracted plaintext, no images involved.
|
||||
6. Persist a doc map at
|
||||
``<suite_data>/maps/crag_doc_map.jsonl`` with one row per question:
|
||||
|
||||
{"qid": "C00042",
|
||||
"interaction_id": "<uuid>",
|
||||
"question": "<text>",
|
||||
"gold_answer": "<text>",
|
||||
"alt_answers": [...],
|
||||
"domain": "...", "question_type": "...",
|
||||
"static_or_dynamic": "...", "popularity": "...",
|
||||
"query_time": "...",
|
||||
"page_filenames": ["C00042__0__abc123.md", ...],
|
||||
"document_ids": [42101, 42102, ...],
|
||||
"missing_pages": [...] # filenames whose upload failed
|
||||
}
|
||||
|
||||
The runner uses ``document_ids`` to scope SurfSense retrieval to
|
||||
exactly the 5 pages of the question (matches CRAG protocol — the
|
||||
benchmark explicitly hands over its own retrieved pages).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.clients.documents import (
|
||||
DocumentProcessingFailed,
|
||||
DocumentProcessingTimeout,
|
||||
)
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
from .dataset import (
|
||||
CragPage,
|
||||
CragQuestion,
|
||||
download_task_1_2,
|
||||
iter_questions,
|
||||
stratified_sample,
|
||||
write_questions_jsonl,
|
||||
)
|
||||
from .html_extract import extract_main_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_FILENAME_SAFE = re.compile(r"[^A-Za-z0-9._\-]+")
|
||||
|
||||
|
||||
def _page_filename(qid: str, page_idx: int, page: CragPage) -> str:
|
||||
"""Filesystem-safe, globally unique markdown filename for a CRAG page.
|
||||
|
||||
Format: ``<qid>__<idx>__<url_hash>.md``. Both the qid (``C00042``)
|
||||
and the URL-hash (``[:12]``) are alphanumeric so we don't need to
|
||||
sanitise them, but we strip anything else just in case.
|
||||
"""
|
||||
|
||||
qid_safe = _FILENAME_SAFE.sub("_", qid)
|
||||
return f"{qid_safe}__{page_idx:02d}__{page.url_hash}.md"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class _IngestStats:
|
||||
n_questions: int
|
||||
n_pages_total: int
|
||||
n_pages_extracted: int
|
||||
n_pages_empty: int
|
||||
n_uploaded: int
|
||||
n_existing: int
|
||||
bench_dir: Path
|
||||
map_path: Path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Page extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _materialise_pages(
|
||||
questions: list[CragQuestion],
|
||||
*,
|
||||
pages_dir: Path,
|
||||
overwrite: bool = False,
|
||||
) -> tuple[dict[str, list[str]], dict[str, str]]:
|
||||
"""Extract every page in every question to ``pages_dir`` as markdown.
|
||||
|
||||
Returns:
|
||||
* ``qid -> [filename, filename, ...]`` (in page order, only
|
||||
successful extractions)
|
||||
* ``filename -> source_url`` for diagnostics
|
||||
|
||||
Empty extractions (paywall / JS / parse-fail with no fallback
|
||||
output) are skipped — better to retrieve from 4 pages than feed
|
||||
SurfSense's chunker an empty file.
|
||||
"""
|
||||
|
||||
pages_dir.mkdir(parents=True, exist_ok=True)
|
||||
qid_to_files: dict[str, list[str]] = {}
|
||||
file_to_url: dict[str, str] = {}
|
||||
method_counts: dict[str, int] = {}
|
||||
n_empty = 0
|
||||
|
||||
for q in questions:
|
||||
names: list[str] = []
|
||||
for idx, page in enumerate(q.pages):
|
||||
filename = _page_filename(q.qid, idx, page)
|
||||
dest = pages_dir / filename
|
||||
if dest.exists() and dest.stat().st_size > 0 and not overwrite:
|
||||
method_counts["cache_hit"] = method_counts.get("cache_hit", 0) + 1
|
||||
names.append(filename)
|
||||
file_to_url[filename] = page.page_url
|
||||
continue
|
||||
result = extract_main_content(
|
||||
page.page_html,
|
||||
url=page.page_url,
|
||||
page_name=page.page_name,
|
||||
last_modified=page.page_last_modified,
|
||||
)
|
||||
method_counts[result.method] = method_counts.get(result.method, 0) + 1
|
||||
if not result.ok:
|
||||
n_empty += 1
|
||||
continue
|
||||
dest.write_text(result.text, encoding="utf-8")
|
||||
names.append(filename)
|
||||
file_to_url[filename] = page.page_url
|
||||
qid_to_files[q.qid] = names
|
||||
|
||||
logger.info(
|
||||
"CRAG page extraction: %s; empty=%d, total_files=%d across %d questions",
|
||||
method_counts, n_empty, len(file_to_url), len(qid_to_files),
|
||||
)
|
||||
return qid_to_files, file_to_url
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _upload_pages(
|
||||
ctx: RunContext,
|
||||
*,
|
||||
pages_dir: Path,
|
||||
filenames: list[str],
|
||||
batch_size: int,
|
||||
settings: IngestSettings,
|
||||
) -> dict[str, int]:
|
||||
"""Upload ``filenames`` (already on disk under ``pages_dir``) and return name → doc_id."""
|
||||
|
||||
if not filenames:
|
||||
return {}
|
||||
docs_client = ctx.documents_client()
|
||||
name_to_id: dict[str, int] = {}
|
||||
paths = [pages_dir / fn for fn in filenames if (pages_dir / fn).exists()]
|
||||
|
||||
for batch_start in range(0, len(paths), batch_size):
|
||||
batch = paths[batch_start : batch_start + batch_size]
|
||||
result = await docs_client.upload(
|
||||
files=batch,
|
||||
search_space_id=ctx.search_space_id,
|
||||
should_summarize=settings.should_summarize,
|
||||
use_vision_llm=settings.use_vision_llm,
|
||||
processing_mode=settings.processing_mode,
|
||||
)
|
||||
all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
|
||||
if result.document_ids:
|
||||
try:
|
||||
await docs_client.wait_until_ready(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=result.document_ids,
|
||||
timeout_s=900.0,
|
||||
)
|
||||
except (DocumentProcessingFailed, DocumentProcessingTimeout) as exc:
|
||||
logger.warning("CRAG batch processing issue: %s", exc)
|
||||
if all_ids:
|
||||
statuses = await docs_client.get_status(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=all_ids,
|
||||
)
|
||||
for s in statuses:
|
||||
stem = Path(s.title).stem if s.title.endswith(".md") else s.title
|
||||
name_to_id[stem] = s.document_id
|
||||
name_to_id[s.title] = s.document_id
|
||||
if not s.title.endswith(".md"):
|
||||
name_to_id[f"{s.title}.md"] = s.document_id
|
||||
logger.info(
|
||||
"CRAG upload batch %d-%d: %d new, %d duplicate",
|
||||
batch_start, batch_start + len(batch),
|
||||
len(result.document_ids), len(result.duplicate_document_ids),
|
||||
)
|
||||
return name_to_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Doc map writer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_question_doc_ids(
|
||||
questions: list[CragQuestion],
|
||||
qid_to_files: dict[str, list[str]],
|
||||
name_to_id: dict[str, int],
|
||||
) -> list[dict[str, Any]]:
|
||||
rows: list[dict[str, Any]] = []
|
||||
for q in questions:
|
||||
filenames = qid_to_files.get(q.qid, [])
|
||||
doc_ids: list[int] = []
|
||||
missing: list[str] = []
|
||||
for fn in filenames:
|
||||
stem = Path(fn).stem
|
||||
doc_id = name_to_id.get(stem) or name_to_id.get(fn)
|
||||
if doc_id is not None and doc_id not in doc_ids:
|
||||
doc_ids.append(doc_id)
|
||||
else:
|
||||
missing.append(fn)
|
||||
rows.append({
|
||||
"qid": q.qid,
|
||||
"interaction_id": q.interaction_id,
|
||||
"raw_index": q.raw_index,
|
||||
"question": q.query,
|
||||
"gold_answer": q.gold_answer,
|
||||
"alt_answers": list(q.alt_answers),
|
||||
"domain": q.domain,
|
||||
"question_type": q.question_type,
|
||||
"static_or_dynamic": q.static_or_dynamic,
|
||||
"popularity": q.popularity,
|
||||
"query_time": q.query_time,
|
||||
"split": q.split,
|
||||
"page_filenames": filenames,
|
||||
"document_ids": doc_ids,
|
||||
"missing_pages": missing,
|
||||
"n_pages": len(filenames),
|
||||
})
|
||||
return rows
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_ingest(
|
||||
ctx: RunContext,
|
||||
*,
|
||||
n_questions: int | None = None,
|
||||
upload_batch_size: int = 16,
|
||||
skip_upload: bool = False,
|
||||
overwrite_extract: bool = False,
|
||||
settings: IngestSettings | None = None,
|
||||
sample_seed: int = 17,
|
||||
) -> None:
|
||||
"""Ingest the CRAG benchmark (Tasks 1 & 2) into the research suite.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_questions
|
||||
Cap on the number of CRAG questions to materialise.
|
||||
``None`` = all 2,706 (~13,500 pages — large; smoke runs
|
||||
should pass 10-20 and full runs ~200).
|
||||
upload_batch_size
|
||||
Markdown files per ``/documents/fileupload`` call.
|
||||
skip_upload
|
||||
Extract + cache markdown locally but don't push to SurfSense
|
||||
(useful for debugging the extraction step).
|
||||
overwrite_extract
|
||||
Re-run trafilatura even when a cached markdown file exists.
|
||||
Default False so re-running ingest is idempotent.
|
||||
settings
|
||||
Override per-upload knobs. CRAG defaults to text-only basic
|
||||
ETL — these are *extracted* plaintext, no images.
|
||||
sample_seed
|
||||
RNG seed for ``stratified_sample``. Pin this for reproducibility.
|
||||
"""
|
||||
|
||||
settings = settings or IngestSettings(
|
||||
use_vision_llm=False,
|
||||
processing_mode="basic",
|
||||
should_summarize=False,
|
||||
)
|
||||
bench_dir = ctx.benchmark_data_dir()
|
||||
pages_dir = bench_dir / "pages"
|
||||
raw_cache = bench_dir / ".raw_cache"
|
||||
raw_cache.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
bz2_path = download_task_1_2(raw_cache)
|
||||
logger.info("CRAG: parsing %s ...", bz2_path.name)
|
||||
all_questions = iter_questions(bz2_path)
|
||||
if not all_questions:
|
||||
raise RuntimeError(
|
||||
"CRAG JSONL contained no parseable rows; upstream may have changed schema."
|
||||
)
|
||||
logger.info("CRAG: parsed %d total questions", len(all_questions))
|
||||
|
||||
if n_questions is not None and n_questions > 0:
|
||||
questions = stratified_sample(all_questions, n=n_questions, seed=sample_seed)
|
||||
logger.info(
|
||||
"CRAG: stratified sample of %d questions across %d (domain, qtype) buckets",
|
||||
len(questions),
|
||||
len({(q.domain, q.question_type) for q in questions}),
|
||||
)
|
||||
else:
|
||||
questions = all_questions
|
||||
|
||||
questions_jsonl = bench_dir / "questions.jsonl"
|
||||
write_questions_jsonl(questions, questions_jsonl)
|
||||
|
||||
n_pages_total = sum(len(q.pages) for q in questions)
|
||||
logger.info(
|
||||
"CRAG: extracting up to %d pages across %d questions ...",
|
||||
n_pages_total, len(questions),
|
||||
)
|
||||
qid_to_files, file_to_url = _materialise_pages(
|
||||
questions, pages_dir=pages_dir, overwrite=overwrite_extract,
|
||||
)
|
||||
n_pages_extracted = sum(len(v) for v in qid_to_files.values())
|
||||
|
||||
name_to_id: dict[str, int] = {}
|
||||
if skip_upload:
|
||||
logger.info("CRAG: --skip-upload; skipping SurfSense ingestion")
|
||||
else:
|
||||
all_filenames = sorted({fn for fns in qid_to_files.values() for fn in fns})
|
||||
logger.info("CRAG: uploading %d unique pages ...", len(all_filenames))
|
||||
name_to_id = await _upload_pages(
|
||||
ctx,
|
||||
pages_dir=pages_dir,
|
||||
filenames=all_filenames,
|
||||
batch_size=upload_batch_size,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
doc_rows = _resolve_question_doc_ids(questions, qid_to_files, name_to_id)
|
||||
map_path = ctx.maps_dir() / "crag_doc_map.jsonl"
|
||||
with map_path.open("w", encoding="utf-8") as fh:
|
||||
fh.write(settings_header_line(settings) + "\n")
|
||||
for row in doc_rows:
|
||||
fh.write(json.dumps(row) + "\n")
|
||||
logger.info("Wrote CRAG doc map to %s (%d rows)", map_path, len(doc_rows))
|
||||
|
||||
new_state = ctx.suite_state
|
||||
new_state.ingestion_maps["crag"] = str(map_path)
|
||||
set_suite_state(ctx.config, ctx.suite, new_state)
|
||||
|
||||
stats = _IngestStats(
|
||||
n_questions=len(questions),
|
||||
n_pages_total=n_pages_total,
|
||||
n_pages_extracted=n_pages_extracted,
|
||||
n_pages_empty=n_pages_total - n_pages_extracted,
|
||||
n_uploaded=len(name_to_id),
|
||||
n_existing=0,
|
||||
bench_dir=bench_dir,
|
||||
map_path=map_path,
|
||||
)
|
||||
logger.info("CRAG ingest done: %s", stats)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# For runner: read extracted page text back from disk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_page_markdown(bench_dir: Path, filename: str) -> str | None:
|
||||
"""Return the on-disk markdown body for a previously-extracted page.
|
||||
|
||||
Used by the long-context runner arm to assemble the prompt at
|
||||
inference time — we don't keep all 5×N pages in memory between
|
||||
ingest and run.
|
||||
"""
|
||||
|
||||
path = bench_dir / "pages" / filename
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return path.read_text(encoding="utf-8")
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
async def _retry_upload_idempotent( # noqa: D401 - hidden helper
|
||||
ctx: RunContext,
|
||||
*,
|
||||
pages_dir: Path,
|
||||
filenames: list[str],
|
||||
batch_size: int,
|
||||
settings: IngestSettings,
|
||||
max_attempts: int = 2,
|
||||
) -> dict[str, int]:
|
||||
"""Future-proofing hook (unused today): retry the ingest upload pass."""
|
||||
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
return await _upload_pages(
|
||||
ctx,
|
||||
pages_dir=pages_dir,
|
||||
filenames=filenames,
|
||||
batch_size=batch_size,
|
||||
settings=settings,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
last_exc = exc
|
||||
logger.warning("CRAG upload attempt %d failed: %s", attempt + 1, exc)
|
||||
await asyncio.sleep(2.0 * (attempt + 1))
|
||||
if last_exc is not None:
|
||||
raise last_exc
|
||||
return {}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"_IngestStats",
|
||||
"_materialise_pages",
|
||||
"_page_filename",
|
||||
"_resolve_question_doc_ids",
|
||||
"_upload_pages",
|
||||
"read_page_markdown",
|
||||
"run_ingest",
|
||||
]
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
"""CRAG Task 3 ingestion: 4-part download → streaming JSONL → upload.
|
||||
|
||||
Same flow as ``ingest.run_ingest`` for Tasks 1 & 2 (extract HTML →
|
||||
upload markdown → resolve doc_ids → write doc map), but:
|
||||
|
||||
* Source: 4 .tar.bz2 parts streamed via ``dataset_task3``.
|
||||
* Page count: 50 per question instead of 5 — the whole point of
|
||||
Task 3 (the long-context arm now structurally has to choose what
|
||||
to keep, while SurfSense's retrieval becomes mandatory).
|
||||
* Stratified sampling re-uses the Task 1 helper since the question
|
||||
schema is identical.
|
||||
|
||||
Doc map lands at ``<suite_data>/maps/crag_t3_doc_map.jsonl`` with the
|
||||
same row shape as Task 1's map (so the runner only needs to know
|
||||
which file to load; everything else is shared).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
from .dataset import stratified_sample, write_questions_jsonl
|
||||
from .dataset_task3 import (
|
||||
CRAG_TASK_3_PART_NAMES,
|
||||
iter_questions_task3,
|
||||
parts_present,
|
||||
)
|
||||
from .ingest import (
|
||||
_IngestStats,
|
||||
_materialise_pages,
|
||||
_resolve_question_doc_ids,
|
||||
_upload_pages,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_INSTRUCTIONS_TO_DOWNLOAD = (
|
||||
"Run `python scripts/download_crag_task3.py` first to fetch the "
|
||||
"4 tar.bz2 parts (~7 GB total) into "
|
||||
"data/research/crag_t3/.raw_cache/. The downloader is idempotent "
|
||||
"and parallel."
|
||||
)
|
||||
|
||||
|
||||
async def run_ingest_task3(
|
||||
ctx: RunContext,
|
||||
*,
|
||||
n_questions: int | None = None,
|
||||
upload_batch_size: int = 16,
|
||||
skip_upload: bool = False,
|
||||
overwrite_extract: bool = False,
|
||||
settings: IngestSettings | None = None,
|
||||
sample_seed: int = 17,
|
||||
parse_cap: int | None = None,
|
||||
) -> None:
|
||||
"""Ingest CRAG Task 3 (50 pages per question) into the research suite.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_questions
|
||||
Cap on the post-stratified-sample question count. ``None`` =
|
||||
"use whatever ``parse_cap`` produced". For real runs aim for
|
||||
50 (~2,500 pages) — n=200 (10k pages) is doable but slow.
|
||||
parse_cap
|
||||
Hard cap on how many rows we *parse* from the streaming
|
||||
archive before stratified sampling. Defaults to
|
||||
``max(400, 6*n_questions)`` — enough to cover all (domain,
|
||||
question_type) buckets ~5x but small enough to fit in the
|
||||
first shard or two (each shard is ≈5 GB decompressed and
|
||||
holds ~300 rows; bz2 throughput is ~50 MB/s). Lowering this
|
||||
is the only knob that bounds streaming cost since we can
|
||||
``break`` out of the JSONL stream early without decompressing
|
||||
the rest of the ~50 GB archive body.
|
||||
upload_batch_size
|
||||
Markdown files per ``/documents/fileupload`` call.
|
||||
skip_upload
|
||||
Extract markdown locally, don't push to SurfSense.
|
||||
overwrite_extract
|
||||
Re-run trafilatura even when a cached markdown is present.
|
||||
settings
|
||||
Per-upload knobs override (default: text-only basic ETL).
|
||||
sample_seed
|
||||
RNG seed for stratified sampling (deterministic).
|
||||
"""
|
||||
|
||||
settings = settings or IngestSettings(
|
||||
use_vision_llm=False,
|
||||
processing_mode="basic",
|
||||
should_summarize=False,
|
||||
)
|
||||
bench_dir = ctx.benchmark_data_dir()
|
||||
pages_dir = bench_dir / "pages"
|
||||
raw_cache = bench_dir / ".raw_cache"
|
||||
raw_cache.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not parts_present(raw_cache):
|
||||
missing = [
|
||||
n for n in CRAG_TASK_3_PART_NAMES
|
||||
if not (raw_cache / n).exists()
|
||||
]
|
||||
raise RuntimeError(
|
||||
f"CRAG Task 3 parts missing from {raw_cache}: {missing}. "
|
||||
f"{_INSTRUCTIONS_TO_DOWNLOAD}"
|
||||
)
|
||||
|
||||
# 1. Stream-parse (capped). For n=50 we don't need the full 2,706
|
||||
# rows — just enough that the stratified sampler can balance.
|
||||
# Each tar shard ~5 GB / ~300 rows / ~2 min decompress, so
|
||||
# 400-500 rows = shard 0 + a slice of shard 1 ≈ 3-4 min.
|
||||
parse_cap = parse_cap or (
|
||||
max(400, 6 * (n_questions or 50)) if n_questions else None
|
||||
)
|
||||
logger.info(
|
||||
"CRAG Task 3: streaming JSONL (parse_cap=%s) ...",
|
||||
parse_cap if parse_cap else "no-cap",
|
||||
)
|
||||
all_questions = iter_questions_task3(raw_cache, max_questions=parse_cap)
|
||||
logger.info("CRAG Task 3: parsed %d rows", len(all_questions))
|
||||
|
||||
if not all_questions:
|
||||
raise RuntimeError("CRAG Task 3 streaming returned 0 rows; check archive integrity.")
|
||||
|
||||
if n_questions is not None and n_questions > 0:
|
||||
questions = stratified_sample(all_questions, n=n_questions, seed=sample_seed)
|
||||
logger.info(
|
||||
"CRAG Task 3: stratified sample of %d questions across %d (domain, qtype) buckets",
|
||||
len(questions),
|
||||
len({(q.domain, q.question_type) for q in questions}),
|
||||
)
|
||||
else:
|
||||
questions = all_questions
|
||||
|
||||
questions_jsonl = bench_dir / "questions.jsonl"
|
||||
write_questions_jsonl(questions, questions_jsonl)
|
||||
|
||||
n_pages_total = sum(len(q.pages) for q in questions)
|
||||
logger.info(
|
||||
"CRAG Task 3: extracting up to %d pages across %d questions ...",
|
||||
n_pages_total, len(questions),
|
||||
)
|
||||
qid_to_files, _file_to_url = _materialise_pages(
|
||||
questions, pages_dir=pages_dir, overwrite=overwrite_extract,
|
||||
)
|
||||
n_pages_extracted = sum(len(v) for v in qid_to_files.values())
|
||||
|
||||
name_to_id: dict[str, int] = {}
|
||||
if skip_upload:
|
||||
logger.info("CRAG Task 3: --skip-upload; skipping SurfSense ingestion")
|
||||
else:
|
||||
all_filenames = sorted({fn for fns in qid_to_files.values() for fn in fns})
|
||||
logger.info("CRAG Task 3: uploading %d unique pages ...", len(all_filenames))
|
||||
name_to_id = await _upload_pages(
|
||||
ctx,
|
||||
pages_dir=pages_dir,
|
||||
filenames=all_filenames,
|
||||
batch_size=upload_batch_size,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
doc_rows = _resolve_question_doc_ids(questions, qid_to_files, name_to_id)
|
||||
map_path = ctx.maps_dir() / "crag_t3_doc_map.jsonl"
|
||||
with map_path.open("w", encoding="utf-8") as fh:
|
||||
fh.write(settings_header_line(settings) + "\n")
|
||||
for row in doc_rows:
|
||||
fh.write(json.dumps(row) + "\n")
|
||||
logger.info("Wrote CRAG Task 3 doc map to %s (%d rows)", map_path, len(doc_rows))
|
||||
|
||||
new_state = ctx.suite_state
|
||||
new_state.ingestion_maps["crag_t3"] = str(map_path)
|
||||
set_suite_state(ctx.config, ctx.suite, new_state)
|
||||
|
||||
stats = _IngestStats(
|
||||
n_questions=len(questions),
|
||||
n_pages_total=n_pages_total,
|
||||
n_pages_extracted=n_pages_extracted,
|
||||
n_pages_empty=n_pages_total - n_pages_extracted,
|
||||
n_uploaded=len(name_to_id),
|
||||
n_existing=0,
|
||||
bench_dir=bench_dir,
|
||||
map_path=map_path,
|
||||
)
|
||||
logger.info("CRAG Task 3 ingest done: %s", stats)
|
||||
|
||||
|
||||
__all__ = ["run_ingest_task3"]
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
"""CRAG prompt templates for the three competing arms.
|
||||
|
||||
The CRAG paper grades each prediction as one of:
|
||||
|
||||
* **correct** — answer matches gold (with paraphrasing tolerance)
|
||||
* **missing** — model refuses or says "I don't know"
|
||||
* **incorrect** — model commits to a wrong answer (hallucination)
|
||||
|
||||
The truthfulness score `(correct - incorrect) / total` rewards
|
||||
calibrated abstention, so the prompts below explicitly *invite* the
|
||||
model to refuse when it isn't confident — otherwise the bare-LLM arm
|
||||
gets penalised twice (no docs *and* a no-refusal prompt) and the
|
||||
comparison stops being fair to the LLM-only baseline.
|
||||
|
||||
Three templates, byte-identical instructions:
|
||||
|
||||
* ``build_bare_prompt(q)`` — question-only.
|
||||
* ``build_long_context_prompt(q, contexts)`` — question + concatenated
|
||||
page extracts, all stuffed into the user message. Mirrors the
|
||||
paper's "straightforward RAG" baseline.
|
||||
* ``build_surfsense_prompt(q)`` — question + a hint that retrieval
|
||||
over the question's 5 ingested pages is available; the SurfSense
|
||||
agent itself owns the retrieval step.
|
||||
|
||||
The ``Answer:`` line at the end is parsed by ``extract_freeform_answer``
|
||||
in the runner, so the format is mandatory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
_BASE_INSTRUCTIONS = (
|
||||
"You are a careful question-answering assistant. The question is a "
|
||||
"real-world factual question that may be about finance, music, "
|
||||
"movies, sports, or any other domain.\n\n"
|
||||
"Important rules:\n"
|
||||
"1. If the question contains a false premise (an assumption that "
|
||||
"is factually wrong), say so explicitly in your final answer "
|
||||
"rather than answering as if the premise were true.\n"
|
||||
"2. If you are not confident in an answer, prefer saying \"I don't "
|
||||
"know\" over guessing. A wrong commit is penalised more than a "
|
||||
"refusal.\n"
|
||||
"3. Keep the final answer short — a name, a number, a date, a "
|
||||
"phrase. Do not repeat the question.\n\n"
|
||||
"Format your final line EXACTLY as:\n"
|
||||
"Answer: <short answer>\n"
|
||||
"If you don't know, write `Answer: I don't know`."
|
||||
)
|
||||
|
||||
|
||||
_BARE_TEMPLATE = """\
|
||||
{instructions}
|
||||
|
||||
Question: {question}
|
||||
Question time: {query_time}
|
||||
"""
|
||||
|
||||
|
||||
_SURFSENSE_TEMPLATE = """\
|
||||
{instructions}
|
||||
|
||||
You have access to a search index of up to 5 web pages that were
|
||||
retrieved for this question. Use the retrieval tool to look up any
|
||||
facts you are not confident about. The pages may be partially or fully
|
||||
relevant; some may contradict each other (prefer the more authoritative
|
||||
or more recent source).
|
||||
|
||||
Question: {question}
|
||||
Question time: {query_time}
|
||||
"""
|
||||
|
||||
|
||||
_LONG_CONTEXT_TEMPLATE = """\
|
||||
{instructions}
|
||||
|
||||
You are given the full text of {n_contexts} web pages that were
|
||||
retrieved for this question. Read all of them, then answer. The
|
||||
pages may be partially or fully relevant; some may contradict each
|
||||
other (prefer the more authoritative or more recent source).
|
||||
|
||||
{contexts}
|
||||
|
||||
Question: {question}
|
||||
Question time: {query_time}
|
||||
"""
|
||||
|
||||
|
||||
def build_bare_prompt(question: str, *, query_time: str = "") -> str:
|
||||
"""Prompt for the no-retrieval baseline arm."""
|
||||
|
||||
return _BARE_TEMPLATE.format(
|
||||
instructions=_BASE_INSTRUCTIONS,
|
||||
question=question.strip(),
|
||||
query_time=query_time.strip() or "unknown",
|
||||
)
|
||||
|
||||
|
||||
def build_surfsense_prompt(question: str, *, query_time: str = "") -> str:
|
||||
"""Prompt for the SurfSense arm (agent does retrieval itself)."""
|
||||
|
||||
return _SURFSENSE_TEMPLATE.format(
|
||||
instructions=_BASE_INSTRUCTIONS,
|
||||
question=question.strip(),
|
||||
query_time=query_time.strip() or "unknown",
|
||||
)
|
||||
|
||||
|
||||
def build_long_context_prompt(
|
||||
question: str,
|
||||
*,
|
||||
contexts: list[tuple[str, str]],
|
||||
query_time: str = "",
|
||||
per_page_char_cap: int = 12_000,
|
||||
) -> str:
|
||||
"""Prompt for the "stuff all pages into the prompt" baseline.
|
||||
|
||||
``contexts`` is a list of ``(page_title_or_url, page_text)`` pairs.
|
||||
Each page is truncated at ``per_page_char_cap`` (default 12k chars
|
||||
≈ 3k tokens) so a 5-page CRAG question fits well under any
|
||||
modern long-context window with room for the question + reasoning.
|
||||
"""
|
||||
|
||||
blocks: list[str] = []
|
||||
for idx, (title, text) in enumerate(contexts, start=1):
|
||||
body = (text or "").strip()
|
||||
if len(body) > per_page_char_cap:
|
||||
body = body[:per_page_char_cap].rstrip() + "\n[...truncated...]"
|
||||
title_clean = (title or f"page_{idx}").strip().replace("\n", " ")
|
||||
blocks.append(
|
||||
f"--- PAGE {idx}: {title_clean} ---\n{body}\n"
|
||||
)
|
||||
contexts_block = "\n".join(blocks) if blocks else "(no pages retrieved)"
|
||||
return _LONG_CONTEXT_TEMPLATE.format(
|
||||
instructions=_BASE_INSTRUCTIONS,
|
||||
n_contexts=len(contexts),
|
||||
contexts=contexts_block,
|
||||
question=question.strip(),
|
||||
query_time=query_time.strip() or "unknown",
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"build_bare_prompt",
|
||||
"build_long_context_prompt",
|
||||
"build_surfsense_prompt",
|
||||
]
|
||||
1053
surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py
Normal file
1053
surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,29 @@
|
|||
"""FRAMES — multi-hop Wikipedia retrieval & reasoning (google/frames-benchmark).
|
||||
|
||||
Source: https://huggingface.co/datasets/google/frames-benchmark
|
||||
Paper: https://arxiv.org/abs/2409.12941 (Krishna et al., 2024)
|
||||
|
||||
* 824 multi-hop questions, each requiring 2-15 Wikipedia articles
|
||||
* 5 reasoning types: numerical, tabular, multiple constraints,
|
||||
temporal, post-processing
|
||||
* Published Gemini-Pro-1.5 baselines:
|
||||
- Naive prompting (no retrieval): 40.8%
|
||||
- BM25, top-4: 47.4%
|
||||
- Multi-step retrieval & reasoning: 66.0%
|
||||
- Oracle retrieval (gold articles): 72.9%
|
||||
|
||||
This is the benchmark that *finally* puts SurfSense's strongest claim
|
||||
on trial: cross-document iterative retrieval. The harness ingests
|
||||
every Wikipedia article referenced by any question in the run sample
|
||||
into a single SearchSpace; SurfSense answers without
|
||||
``mentioned_document_ids`` so its agent has to actually retrieve.
|
||||
The bare-LLM arm answers from the prompt only (the published 40.8%
|
||||
baseline number).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import registry as _registry
|
||||
from .runner import FramesBenchmark
|
||||
|
||||
_registry.register(FramesBenchmark())
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
"""FRAMES dataset loader — download ``test.tsv`` from HF and parse rows.
|
||||
|
||||
The HF repo (``google/frames-benchmark``) ships a single tab-separated
|
||||
file at ``test.tsv`` (824 rows). Columns of interest for us:
|
||||
|
||||
* unnamed first column → row index (``id`` we synthesise as ``Q000``..)
|
||||
* ``Prompt`` → the question (free-text, often multi-clause)
|
||||
* ``Answer`` → gold answer (short string: name, number, year, ...)
|
||||
* ``wikipedia_link_1`` ... ``wikipedia_link_11+`` → sparse per-question
|
||||
link cells (we ignore in favour of the consolidated column below).
|
||||
* ``reasoning_types`` → pipe-separated tags (``"Numerical reasoning |
|
||||
Tabular reasoning | Multiple constraints"``)
|
||||
* ``wiki_links`` → Python-list literal of every URL the question relies
|
||||
on, e.g. ``"['https://en.wikipedia.org/wiki/...', '...']"``
|
||||
|
||||
We use ``wiki_links`` (already deduplicated per row) and
|
||||
``ast.literal_eval`` to materialise it. The legacy
|
||||
``wikipedia_link_*`` columns are kept around only so a curious
|
||||
operator can compare cell-vs-list if upstream ever drift apart.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HF_REPO_ID = "google/frames-benchmark"
|
||||
HF_REPO_TYPE = "dataset"
|
||||
HF_TEST_FILE = "test.tsv"
|
||||
|
||||
|
||||
def _hf_hub_download(*args: Any, **kwargs: Any) -> str:
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
return hf_hub_download(*args, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class FramesQuestion:
|
||||
"""One row of FRAMES (post-parse)."""
|
||||
|
||||
qid: str # synthesised "Q000" .. "Q823"
|
||||
question: str
|
||||
gold_answer: str
|
||||
wiki_urls: list[str] # deduped, in original order
|
||||
reasoning_types: list[str] # split on "|"
|
||||
raw_index: int # row index from the TSV (for debugging)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"qid": self.qid,
|
||||
"question": self.question,
|
||||
"gold_answer": self.gold_answer,
|
||||
"wiki_urls": list(self.wiki_urls),
|
||||
"reasoning_types": list(self.reasoning_types),
|
||||
"raw_index": self.raw_index,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Download + parse
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def download_test_tsv(cache_dir: Path) -> Path:
|
||||
"""Resumable download of ``test.tsv`` via ``huggingface_hub``."""
|
||||
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
local = _hf_hub_download(
|
||||
repo_id=HF_REPO_ID,
|
||||
filename=HF_TEST_FILE,
|
||||
repo_type=HF_REPO_TYPE,
|
||||
cache_dir=str(cache_dir),
|
||||
)
|
||||
return Path(local)
|
||||
|
||||
|
||||
def _parse_wiki_links(raw: Any) -> list[str]:
|
||||
"""Convert the ``wiki_links`` cell (Python list literal) to ``list[str]``."""
|
||||
|
||||
if not raw:
|
||||
return []
|
||||
if isinstance(raw, list):
|
||||
return [str(x).strip() for x in raw if str(x).strip()]
|
||||
text = str(raw).strip()
|
||||
if not text:
|
||||
return []
|
||||
try:
|
||||
parsed = ast.literal_eval(text)
|
||||
except (SyntaxError, ValueError):
|
||||
# Fall back: maybe it's a comma-separated string with no quotes.
|
||||
return [tok.strip() for tok in text.strip("[]").split(",") if tok.strip()]
|
||||
if isinstance(parsed, (list, tuple)):
|
||||
return [str(x).strip() for x in parsed if str(x).strip()]
|
||||
return [str(parsed).strip()]
|
||||
|
||||
|
||||
def _parse_reasoning_types(raw: Any) -> list[str]:
|
||||
if not raw:
|
||||
return []
|
||||
text = str(raw).strip()
|
||||
if not text:
|
||||
return []
|
||||
return [tok.strip() for tok in text.split("|") if tok.strip()]
|
||||
|
||||
|
||||
def load_questions(tsv_path: Path) -> list[FramesQuestion]:
|
||||
"""Read FRAMES rows from disk into ``FramesQuestion`` objects.
|
||||
|
||||
Uses pandas for robust TSV parsing (tabs inside quoted strings are
|
||||
rare in this dataset but pandas handles them; the stdlib ``csv``
|
||||
module is fine too if pandas ever becomes a problem). We pin
|
||||
``index_col=0`` because the upstream TSV uses the first unnamed
|
||||
column as the row index.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.read_csv(tsv_path, sep="\t", index_col=0, keep_default_na=False)
|
||||
out: list[FramesQuestion] = []
|
||||
for raw_idx, row in df.iterrows():
|
||||
prompt = str(row.get("Prompt") or "").strip()
|
||||
answer = str(row.get("Answer") or "").strip()
|
||||
if not prompt or not answer:
|
||||
logger.debug("Skipping FRAMES row %s with missing Prompt/Answer", raw_idx)
|
||||
continue
|
||||
urls = _parse_wiki_links(row.get("wiki_links"))
|
||||
if not urls:
|
||||
# Fall back to the per-cell ``wikipedia_link_*`` columns.
|
||||
urls = []
|
||||
for col in row.index:
|
||||
if col.startswith("wikipedia_link"):
|
||||
val = str(row.get(col) or "").strip()
|
||||
if val and val not in urls:
|
||||
urls.append(val)
|
||||
reasoning = _parse_reasoning_types(row.get("reasoning_types"))
|
||||
out.append(FramesQuestion(
|
||||
qid=f"Q{int(raw_idx):03d}",
|
||||
question=prompt,
|
||||
gold_answer=answer,
|
||||
wiki_urls=urls,
|
||||
reasoning_types=reasoning,
|
||||
raw_index=int(raw_idx),
|
||||
))
|
||||
return out
|
||||
|
||||
|
||||
def write_questions_jsonl(questions: list[FramesQuestion], dest: Path) -> None:
|
||||
"""Persist a parsed copy under the benchmark data dir."""
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
with dest.open("w", encoding="utf-8") as fh:
|
||||
for q in questions:
|
||||
fh.write(json.dumps(q.to_dict()) + "\n")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FramesQuestion",
|
||||
"download_test_tsv",
|
||||
"load_questions",
|
||||
"write_questions_jsonl",
|
||||
]
|
||||
|
|
@ -0,0 +1,341 @@
|
|||
"""FRAMES grader: deterministic shortcut + LLM-as-judge fallback.
|
||||
|
||||
FRAMES gold answers are short factoids (a name, a year, an ordinal,
|
||||
a count). The published paper uses an LLM judge for grading, citing
|
||||
the long tail of paraphrasing ("Jane Ballou" vs "Mrs. Ballou (Jane)";
|
||||
"5" vs "five"; "London, England" vs "London"). We replicate that
|
||||
faithfully *but* avoid burning judge tokens on the obvious cases.
|
||||
|
||||
Pipeline per (pred, gold):
|
||||
|
||||
1. Normalise both sides (SQuAD-style).
|
||||
2. If normalised pred == normalised gold → CORRECT (``method=exact``).
|
||||
3. Numeric path: if both extract to a single number and the values
|
||||
match within 1% relative tolerance → CORRECT (``method=numeric``).
|
||||
4. Substring path: if normalised gold appears as a *whole-word phrase*
|
||||
inside normalised pred (or vice versa) → CORRECT
|
||||
(``method=substring``).
|
||||
5. Otherwise → call the LLM judge if a judge is wired; the judge
|
||||
returns yes/no with a one-line rationale.
|
||||
6. If no judge is configured, fall through to ``False``
|
||||
(``method=lexical_miss``).
|
||||
|
||||
The judge is called *concurrently* across the run via a semaphore (so
|
||||
it doesn't outrun the upstream rate limit). Cached on
|
||||
``(arm, qid)`` so re-running ``report`` doesn't re-judge.
|
||||
|
||||
Returned shape mirrors ``mmlongbench.grader.GradeResult`` to keep
|
||||
report writers uniform across benchmarks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import string
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from ....core.providers.openrouter_chat import OpenRouterChatProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class GradeResult:
|
||||
"""Shape mirrors mmlongbench.grader.GradeResult for report uniformity."""
|
||||
|
||||
correct: bool
|
||||
f1: float
|
||||
method: str
|
||||
normalised_pred: str = ""
|
||||
normalised_gold: str = ""
|
||||
judge_rationale: str = ""
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"correct": self.correct,
|
||||
"f1": self.f1,
|
||||
"method": self.method,
|
||||
"normalised_pred": self.normalised_pred,
|
||||
"normalised_gold": self.normalised_gold,
|
||||
"judge_rationale": self.judge_rationale,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Normalisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_PUNCT_TABLE = str.maketrans({c: " " for c in string.punctuation})
|
||||
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.IGNORECASE)
|
||||
_WS = re.compile(r"\s+")
|
||||
|
||||
|
||||
def _normalise(s: str) -> str:
|
||||
s = (s or "").lower()
|
||||
s = s.translate(_PUNCT_TABLE)
|
||||
s = _ARTICLES.sub(" ", s)
|
||||
s = _WS.sub(" ", s).strip()
|
||||
return s
|
||||
|
||||
|
||||
_WORD_NUMBERS = {
|
||||
"zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
|
||||
"six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11,
|
||||
"twelve": 12, "thirteen": 13, "fourteen": 14, "fifteen": 15, "sixteen": 16,
|
||||
"seventeen": 17, "eighteen": 18, "nineteen": 19, "twenty": 20,
|
||||
}
|
||||
|
||||
_NUMERIC_RE = re.compile(r"-?\d+(?:[.,]\d+)?")
|
||||
|
||||
|
||||
def _maybe_number(s: str) -> float | None:
|
||||
"""Extract a single numeric value, recognising digit and word forms.
|
||||
|
||||
Operates on the lowercased *raw* text (rather than the
|
||||
punctuation-stripped normalisation) so that thousands separators
|
||||
like ``1,234`` are preserved through the regex and parsed
|
||||
correctly. We only fall back to ``_normalise`` for the word-number
|
||||
pass, which doesn't care about punctuation.
|
||||
"""
|
||||
|
||||
raw = (s or "").strip().lower()
|
||||
if not raw:
|
||||
return None
|
||||
match = _NUMERIC_RE.search(raw)
|
||||
if match:
|
||||
try:
|
||||
return float(match.group(0).replace(",", ""))
|
||||
except ValueError:
|
||||
pass
|
||||
for tok in _normalise(s).split():
|
||||
if tok in _WORD_NUMBERS:
|
||||
return float(_WORD_NUMBERS[tok])
|
||||
return None
|
||||
|
||||
|
||||
def _whole_word_substring(haystack: str, needle: str) -> bool:
|
||||
"""Is ``needle`` present as a whole-word phrase in ``haystack``?"""
|
||||
|
||||
if not needle:
|
||||
return False
|
||||
pad_h = f" {haystack} "
|
||||
pad_n = f" {needle} "
|
||||
return pad_n in pad_h
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deterministic shortcut
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def grade_deterministic(*, pred: str, gold: str) -> GradeResult:
|
||||
"""Try to grade without the LLM judge. Returns a final-result object.
|
||||
|
||||
A ``False`` result with ``method == "lexical_miss"`` is the signal
|
||||
to the caller that the LLM judge should be consulted (if available).
|
||||
"""
|
||||
|
||||
if not (pred or "").strip():
|
||||
return GradeResult(False, 0.0, "empty_pred", "", _normalise(gold))
|
||||
|
||||
p = _normalise(pred)
|
||||
g = _normalise(gold)
|
||||
if not g:
|
||||
# Defensively: gold should never be empty; if it is, we can't grade.
|
||||
return GradeResult(False, 0.0, "empty_gold", p, g)
|
||||
|
||||
if p == g:
|
||||
return GradeResult(True, 1.0, "exact", p, g)
|
||||
|
||||
p_num = _maybe_number(pred)
|
||||
g_num = _maybe_number(gold)
|
||||
if p_num is not None and g_num is not None:
|
||||
# 1% relative tolerance, 0.5 absolute floor (handles year-ish answers).
|
||||
tol = max(abs(g_num) * 0.01, 0.5)
|
||||
if abs(p_num - g_num) <= tol:
|
||||
return GradeResult(True, 1.0, "numeric", p, g)
|
||||
return GradeResult(False, 0.0, "numeric_miss", p, g)
|
||||
|
||||
if _whole_word_substring(p, g):
|
||||
return GradeResult(True, 1.0, "substring", p, g)
|
||||
if _whole_word_substring(g, p) and len(p) >= 3:
|
||||
# Be conservative the other direction — only credit if pred is
|
||||
# at least 3 normalised chars (avoids "John" matching gold
|
||||
# "John F. Kennedy" as correct).
|
||||
return GradeResult(True, 1.0, "substring_reverse", p, g)
|
||||
|
||||
return GradeResult(False, 0.0, "lexical_miss", p, g)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM-as-judge
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_JUDGE_SYSTEM = (
|
||||
"You are an impartial grader for short-answer factual questions. "
|
||||
"Given a question, the gold answer, and a model's prediction, "
|
||||
"decide whether the prediction is correct. The prediction is "
|
||||
"correct if it expresses the same factual content as the gold "
|
||||
"answer, allowing for paraphrasing, surface-level differences "
|
||||
"(numbers as words, names with/without titles), and additional "
|
||||
"non-contradictory detail. The prediction is incorrect if it "
|
||||
"expresses a different fact, omits the central answer, or hedges "
|
||||
"without committing.\n\n"
|
||||
"Respond with ONLY a JSON object on a single line:\n"
|
||||
'{\"correct\": true|false, \"rationale\": \"<one short sentence>\"}'
|
||||
)
|
||||
|
||||
|
||||
_JUDGE_TEMPLATE = """\
|
||||
Question: {question}
|
||||
Gold answer: {gold}
|
||||
Model prediction: {pred}
|
||||
|
||||
Decide whether the prediction is correct.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class JudgeConfig:
|
||||
"""Configuration handed to ``LlmJudge`` at construction time."""
|
||||
|
||||
api_key: str
|
||||
model: str = "anthropic/claude-sonnet-4.5"
|
||||
base_url: str = "https://openrouter.ai/api/v1"
|
||||
max_tokens: int = 200
|
||||
concurrency: int = 4
|
||||
|
||||
|
||||
class LlmJudge:
|
||||
"""Async LLM judge over OpenRouter chat completions."""
|
||||
|
||||
def __init__(self, *, config: JudgeConfig) -> None:
|
||||
self._config = config
|
||||
self._provider = OpenRouterChatProvider(
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url,
|
||||
model=config.model,
|
||||
)
|
||||
self._sem = asyncio.Semaphore(max(1, config.concurrency))
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._config.model
|
||||
|
||||
async def judge(
|
||||
self,
|
||||
*,
|
||||
question: str,
|
||||
gold: str,
|
||||
pred: str,
|
||||
) -> tuple[bool, str]:
|
||||
"""Return ``(is_correct, rationale)``. Errors return False + reason."""
|
||||
|
||||
prompt = _JUDGE_TEMPLATE.format(question=question, gold=gold, pred=pred)
|
||||
try:
|
||||
async with self._sem:
|
||||
response = await self._provider.complete(
|
||||
prompt=prompt,
|
||||
system_prompt=_JUDGE_SYSTEM,
|
||||
max_tokens=self._config.max_tokens,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return False, f"judge_error: {type(exc).__name__}: {exc}"
|
||||
return _parse_judge_response(response.text)
|
||||
|
||||
|
||||
def _parse_judge_response(text: str) -> tuple[bool, str]:
|
||||
"""Pull ``correct`` + ``rationale`` out of the judge's reply."""
|
||||
|
||||
if not text or not text.strip():
|
||||
return False, "judge_returned_empty"
|
||||
# Accept JSON anywhere in the message; some models prepend prose.
|
||||
match = re.search(r"\{[^{}]*\}", text, flags=re.DOTALL)
|
||||
candidate = match.group(0) if match else text
|
||||
try:
|
||||
data = json.loads(candidate)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Fallback: yes/no parsing.
|
||||
lowered = text.strip().lower()
|
||||
if lowered.startswith("yes") or "correct: yes" in lowered or '"correct": true' in lowered:
|
||||
return True, "yes (parser_fallback)"
|
||||
if lowered.startswith("no") or "correct: no" in lowered or '"correct": false' in lowered:
|
||||
return False, "no (parser_fallback)"
|
||||
return False, f"unparseable_judge_response: {text[:200]}"
|
||||
correct = bool(data.get("correct"))
|
||||
rationale = str(data.get("rationale", "")).strip()[:280]
|
||||
return correct, rationale
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Combined grader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def grade_with_judge(
|
||||
*,
|
||||
pred: str,
|
||||
gold: str,
|
||||
question: str,
|
||||
judge: LlmJudge | None,
|
||||
) -> GradeResult:
|
||||
"""Grade one row: deterministic shortcut → optional LLM judge fallback."""
|
||||
|
||||
det = grade_deterministic(pred=pred, gold=gold)
|
||||
if det.correct or det.method != "lexical_miss":
|
||||
return det
|
||||
if judge is None:
|
||||
return det
|
||||
is_correct, rationale = await judge.judge(question=question, gold=gold, pred=pred)
|
||||
return GradeResult(
|
||||
correct=is_correct,
|
||||
f1=1.0 if is_correct else 0.0,
|
||||
method="llm_judge",
|
||||
normalised_pred=det.normalised_pred,
|
||||
normalised_gold=det.normalised_gold,
|
||||
judge_rationale=rationale,
|
||||
)
|
||||
|
||||
|
||||
async def grade_many(
|
||||
*,
|
||||
rows: Sequence[tuple[str, str, str, str]],
|
||||
judge: LlmJudge | None,
|
||||
) -> list[GradeResult]:
|
||||
"""Grade ``[(qid, question, gold, pred), ...]`` concurrently.
|
||||
|
||||
The judge already enforces its own concurrency cap; this just
|
||||
schedules everything via ``asyncio.gather``. ``qid`` is unused
|
||||
inside the grader but threaded through so callers can correlate
|
||||
results back to their rows.
|
||||
"""
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
coros = [
|
||||
grade_with_judge(pred=p, gold=g, question=q, judge=judge)
|
||||
for _qid, q, g, p in rows
|
||||
]
|
||||
return list(await asyncio.gather(*coros))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GradeResult",
|
||||
"JudgeConfig",
|
||||
"LlmJudge",
|
||||
"grade_deterministic",
|
||||
"grade_many",
|
||||
"grade_with_judge",
|
||||
]
|
||||
|
|
@ -0,0 +1,341 @@
|
|||
"""FRAMES ingestion: download → fetch Wikipedia → upload markdown.
|
||||
|
||||
Steps:
|
||||
|
||||
1. Download ``test.tsv`` from ``hf://datasets/google/frames-benchmark``.
|
||||
2. Parse rows into ``FramesQuestion`` objects.
|
||||
3. Optionally cap to the first ``--max-questions N`` so a smoke run
|
||||
doesn't trigger a 1k-article fetch.
|
||||
4. Build the **deduplicated** set of Wikipedia URLs across the chosen
|
||||
sample (questions share many articles — Q1 and Q42 might both
|
||||
reference ``James_A._Garfield``).
|
||||
5. Fetch each unique article via ``WikiFetcher`` (polite 2 RPS) into
|
||||
``<bench_dir>/wiki/<title>.md``.
|
||||
6. Upload the resulting markdown files to SurfSense in batches with
|
||||
``use_vision_llm=False, processing_mode="basic"`` (text-only — no
|
||||
reason to pay vision LLM costs on Wikipedia plaintext).
|
||||
7. Persist a doc map at
|
||||
``<suite_data>/maps/frames_doc_map.jsonl`` with one row per question
|
||||
listing its ``document_ids`` (so the runner *could* scope retrieval
|
||||
if requested, though by default we don't — see ``runner.py``).
|
||||
|
||||
The doc map row shape:
|
||||
|
||||
{"qid": "Q000",
|
||||
"wiki_titles": ["President of the United States", "James Buchanan", ...],
|
||||
"document_ids": [123, 124, ...],
|
||||
"missing_titles": []}
|
||||
|
||||
We resolve titles → SurfSense document_ids via the post-upload
|
||||
``DocumentStatus.title`` field. SurfSense's title is the uploaded
|
||||
filename (without extension), so we round-trip via
|
||||
``cache_filename_for_title`` to match.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.clients.documents import (
|
||||
DocumentProcessingFailed,
|
||||
DocumentProcessingTimeout,
|
||||
)
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
from .dataset import (
|
||||
download_test_tsv,
|
||||
load_questions,
|
||||
write_questions_jsonl,
|
||||
)
|
||||
from .wiki_fetch import (
|
||||
WikiArticle,
|
||||
WikiFetcher,
|
||||
cache_filename_for_title,
|
||||
title_from_url,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class _IngestStats:
|
||||
n_questions: int
|
||||
n_unique_urls: int
|
||||
n_fetched: int
|
||||
n_cached_hits: int
|
||||
n_missing: int
|
||||
n_uploaded: int
|
||||
n_existing: int
|
||||
bench_dir: Path
|
||||
map_path: Path
|
||||
|
||||
|
||||
async def _fetch_articles(
|
||||
fetcher: WikiFetcher,
|
||||
urls: list[str],
|
||||
) -> tuple[dict[str, WikiArticle], list[str]]:
|
||||
"""Fetch each URL serially (the WikiFetcher's rate-limiter serialises anyway).
|
||||
|
||||
Returns ``(url -> WikiArticle, missing_urls)``. Missing means
|
||||
Wikipedia reported the title doesn't exist, the URL was non-wiki,
|
||||
or the API returned an empty extract.
|
||||
"""
|
||||
|
||||
fetched: dict[str, WikiArticle] = {}
|
||||
missing: list[str] = []
|
||||
n_total = len(urls)
|
||||
for i, url in enumerate(urls, start=1):
|
||||
try:
|
||||
article = await fetcher.fetch(url)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("FRAMES wiki fetch %s failed: %s", url, exc)
|
||||
missing.append(url)
|
||||
continue
|
||||
if article is None:
|
||||
missing.append(url)
|
||||
continue
|
||||
fetched[url] = article
|
||||
if i % 25 == 0 or i == n_total:
|
||||
logger.info(" ... fetched %d / %d Wikipedia articles", i, n_total)
|
||||
return fetched, missing
|
||||
|
||||
|
||||
async def _upload_markdowns(
|
||||
ctx: RunContext,
|
||||
articles: list[WikiArticle],
|
||||
*,
|
||||
batch_size: int,
|
||||
settings: IngestSettings,
|
||||
) -> dict[str, int]:
|
||||
"""Upload deduplicated markdown files. Returns ``filename -> document_id``.
|
||||
|
||||
SurfSense dedupes uploads on ``(filename, search_space_id)``, so
|
||||
re-running ingest after a crash is idempotent — duplicates land in
|
||||
``duplicate_document_ids`` and we still recover their ids via the
|
||||
status endpoint.
|
||||
"""
|
||||
|
||||
if not articles:
|
||||
return {}
|
||||
docs_client = ctx.documents_client()
|
||||
name_to_id: dict[str, int] = {}
|
||||
paths = [a.markdown_path for a in articles]
|
||||
for batch_start in range(0, len(paths), batch_size):
|
||||
batch = paths[batch_start : batch_start + batch_size]
|
||||
result = await docs_client.upload(
|
||||
files=batch,
|
||||
search_space_id=ctx.search_space_id,
|
||||
should_summarize=settings.should_summarize,
|
||||
use_vision_llm=settings.use_vision_llm,
|
||||
processing_mode=settings.processing_mode,
|
||||
)
|
||||
all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
|
||||
if result.document_ids:
|
||||
try:
|
||||
await docs_client.wait_until_ready(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=result.document_ids,
|
||||
timeout_s=900.0,
|
||||
)
|
||||
except (DocumentProcessingFailed, DocumentProcessingTimeout) as exc:
|
||||
logger.warning("FRAMES batch processing issue: %s", exc)
|
||||
if all_ids:
|
||||
statuses = await docs_client.get_status(
|
||||
search_space_id=ctx.search_space_id,
|
||||
document_ids=all_ids,
|
||||
)
|
||||
for s in statuses:
|
||||
# SurfSense stores the uploaded filename as ``title`` (no extension).
|
||||
stem = Path(s.title).stem if s.title.endswith(".md") else s.title
|
||||
name_to_id[stem] = s.document_id
|
||||
name_to_id[s.title] = s.document_id
|
||||
logger.info(
|
||||
"FRAMES upload batch %d-%d: %d new, %d duplicate",
|
||||
batch_start, batch_start + len(batch),
|
||||
len(result.document_ids), len(result.duplicate_document_ids),
|
||||
)
|
||||
return name_to_id
|
||||
|
||||
|
||||
def _resolve_question_doc_ids(
|
||||
questions: list[Any],
|
||||
fetched: dict[str, WikiArticle],
|
||||
name_to_id: dict[str, int],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""For each question, list the document_ids of its (fetched) wiki articles."""
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for q in questions:
|
||||
doc_ids: list[int] = []
|
||||
titles: list[str] = []
|
||||
missing: list[str] = []
|
||||
for url in q.wiki_urls:
|
||||
article = fetched.get(url)
|
||||
if article is None:
|
||||
missing.append(url)
|
||||
continue
|
||||
titles.append(article.title)
|
||||
stem = Path(cache_filename_for_title(article.title)).stem
|
||||
doc_id = name_to_id.get(stem) or name_to_id.get(article.markdown_path.name)
|
||||
if doc_id is not None and doc_id not in doc_ids:
|
||||
doc_ids.append(doc_id)
|
||||
rows.append({
|
||||
"qid": q.qid,
|
||||
"raw_index": q.raw_index,
|
||||
"n_wiki_urls": len(q.wiki_urls),
|
||||
"wiki_titles": titles,
|
||||
"document_ids": doc_ids,
|
||||
"missing_urls": missing,
|
||||
})
|
||||
return rows
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_ingest(
|
||||
ctx: RunContext,
|
||||
*,
|
||||
max_questions: int | None = None,
|
||||
upload_batch_size: int = 16,
|
||||
skip_upload: bool = False,
|
||||
fetch_rate_limit_rps: float = 2.0,
|
||||
settings: IngestSettings | None = None,
|
||||
) -> None:
|
||||
"""Ingest the FRAMES benchmark into the research suite.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_questions : int | None
|
||||
Cap on the number of FRAMES questions to materialise. ``None`` =
|
||||
all 824 (≈300+ unique articles). Smoke runs should pass 5-10.
|
||||
upload_batch_size : int
|
||||
Markdown files per ``/documents/fileupload`` call. Larger
|
||||
batches reduce round-trip overhead; smaller batches recover
|
||||
faster from individual processing failures.
|
||||
skip_upload : bool
|
||||
Fetch + cache Wikipedia articles locally but don't push to
|
||||
SurfSense. Useful for debugging the fetcher in isolation.
|
||||
fetch_rate_limit_rps : float
|
||||
Maximum requests-per-second to the Wikipedia API. Default 2.0
|
||||
is a polite ceiling; raise cautiously.
|
||||
settings : IngestSettings | None
|
||||
Override per-upload knobs. FRAMES defaults to text-only
|
||||
(no vision LLM, basic mode) — the corpus is plain wikitext.
|
||||
"""
|
||||
|
||||
settings = settings or IngestSettings(
|
||||
use_vision_llm=False,
|
||||
processing_mode="basic",
|
||||
should_summarize=False,
|
||||
)
|
||||
bench_dir = ctx.benchmark_data_dir()
|
||||
wiki_cache = bench_dir / "wiki"
|
||||
wiki_cache.mkdir(parents=True, exist_ok=True)
|
||||
hf_cache = bench_dir / ".hf_cache"
|
||||
hf_cache.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 1. Download + parse questions.
|
||||
tsv_path = download_test_tsv(hf_cache)
|
||||
questions = load_questions(tsv_path)
|
||||
if not questions:
|
||||
raise RuntimeError(
|
||||
"FRAMES test.tsv contained no parseable rows; upstream may "
|
||||
"have changed schema."
|
||||
)
|
||||
logger.info("FRAMES: parsed %d questions from %s", len(questions), tsv_path.name)
|
||||
if max_questions is not None and max_questions > 0:
|
||||
questions = questions[:max_questions]
|
||||
logger.info("FRAMES: capped to first %d questions", len(questions))
|
||||
|
||||
questions_jsonl = bench_dir / "questions.jsonl"
|
||||
write_questions_jsonl(questions, questions_jsonl)
|
||||
|
||||
# 2. Build deduplicated URL set (preserving first-seen order).
|
||||
seen_urls: dict[str, None] = {}
|
||||
for q in questions:
|
||||
for url in q.wiki_urls:
|
||||
seen_urls.setdefault(url, None)
|
||||
unique_urls = list(seen_urls.keys())
|
||||
logger.info(
|
||||
"FRAMES: %d unique Wikipedia URLs across %d questions",
|
||||
len(unique_urls), len(questions),
|
||||
)
|
||||
|
||||
# 3. Fetch (with cache).
|
||||
fetcher = WikiFetcher(cache_dir=wiki_cache, rate_limit_rps=fetch_rate_limit_rps)
|
||||
n_cached = sum(
|
||||
1 for url in unique_urls
|
||||
if (wiki_cache / cache_filename_for_title(_safe_title(url))).exists()
|
||||
)
|
||||
fetched, missing_urls = await _fetch_articles(fetcher, unique_urls)
|
||||
logger.info(
|
||||
"FRAMES: fetched=%d, cache_hits=%d, missing=%d",
|
||||
len(fetched), n_cached, len(missing_urls),
|
||||
)
|
||||
|
||||
# 4. Upload to SurfSense (deduped by filename).
|
||||
name_to_id: dict[str, int] = {}
|
||||
if skip_upload:
|
||||
logger.info("FRAMES: --skip-upload; skipping SurfSense ingestion")
|
||||
else:
|
||||
unique_articles = list({a.markdown_path: a for a in fetched.values()}.values())
|
||||
name_to_id = await _upload_markdowns(
|
||||
ctx,
|
||||
unique_articles,
|
||||
batch_size=upload_batch_size,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
# 5. Persist per-question doc map.
|
||||
doc_rows = _resolve_question_doc_ids(questions, fetched, name_to_id)
|
||||
|
||||
map_path = ctx.maps_dir() / "frames_doc_map.jsonl"
|
||||
with map_path.open("w", encoding="utf-8") as fh:
|
||||
fh.write(settings_header_line(settings) + "\n")
|
||||
for row in doc_rows:
|
||||
fh.write(json.dumps(row) + "\n")
|
||||
logger.info("Wrote FRAMES doc map to %s (%d rows)", map_path, len(doc_rows))
|
||||
|
||||
# 6. Update suite state.
|
||||
new_state = ctx.suite_state
|
||||
new_state.ingestion_maps["frames"] = str(map_path)
|
||||
set_suite_state(ctx.config, ctx.suite, new_state)
|
||||
|
||||
stats = _IngestStats(
|
||||
n_questions=len(questions),
|
||||
n_unique_urls=len(unique_urls),
|
||||
n_fetched=len(fetched),
|
||||
n_cached_hits=n_cached,
|
||||
n_missing=len(missing_urls),
|
||||
n_uploaded=len(name_to_id),
|
||||
n_existing=0,
|
||||
bench_dir=bench_dir,
|
||||
map_path=map_path,
|
||||
)
|
||||
logger.info("FRAMES ingest done: %s", stats)
|
||||
|
||||
|
||||
def _safe_title(url: str) -> str:
|
||||
"""Pre-cache title resolution; returns ``""`` on bad URL."""
|
||||
|
||||
try:
|
||||
return title_from_url(url)
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
|
||||
__all__ = ["run_ingest"]
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
"""FRAMES prompt templates.
|
||||
|
||||
Two templates: one for the bare-LLM arm (no retrieval), one for
|
||||
SurfSense (the agent retrieves; we mostly just instruct it on
|
||||
output format). Both arms must use byte-identical *content* for the
|
||||
question itself so the head-to-head is fair — the wrappers diverge
|
||||
only in framing.
|
||||
|
||||
Format expectations (mirrors the FRAMES paper, section 4):
|
||||
|
||||
* Short factual answer — names, dates, numbers, ordinals
|
||||
* No extra explanation in the final line; we anchor on
|
||||
``Answer: <text>`` for deterministic extraction
|
||||
* Free-text reasoning is *allowed* before the final ``Answer:`` line —
|
||||
multi-hop questions often benefit from it. We just don't grade it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
_BASE_INSTRUCTIONS = (
|
||||
"You are a careful question-answering assistant. The question may "
|
||||
"require combining facts from multiple sources, doing arithmetic, "
|
||||
"or reasoning about dates. Think step by step if needed, then give "
|
||||
"the final answer.\n\n"
|
||||
"Format your final line EXACTLY as:\n"
|
||||
"Answer: <short answer>\n\n"
|
||||
"The answer should be as short as possible — a name, a number, a "
|
||||
"date, a single phrase. Do not repeat the question. Do not include "
|
||||
"punctuation at the end unless it is part of the answer."
|
||||
)
|
||||
|
||||
|
||||
_BARE_TEMPLATE = """\
|
||||
{instructions}
|
||||
|
||||
Question: {question}
|
||||
"""
|
||||
|
||||
|
||||
_SURFSENSE_TEMPLATE = """\
|
||||
{instructions}
|
||||
|
||||
You have access to a Wikipedia knowledge base via retrieval. Use it
|
||||
to look up any facts you are not confident about. The corpus contains
|
||||
the Wikipedia articles needed to answer this question, but you must
|
||||
retrieve them yourself — they are not pre-selected.
|
||||
|
||||
Question: {question}
|
||||
"""
|
||||
|
||||
|
||||
def build_bare_prompt(question: str) -> str:
|
||||
"""Prompt for the no-retrieval baseline arm."""
|
||||
|
||||
return _BARE_TEMPLATE.format(
|
||||
instructions=_BASE_INSTRUCTIONS,
|
||||
question=question.strip(),
|
||||
)
|
||||
|
||||
|
||||
def build_surfsense_prompt(question: str) -> str:
|
||||
"""Prompt for the SurfSense arm (retrieval-augmented)."""
|
||||
|
||||
return _SURFSENSE_TEMPLATE.format(
|
||||
instructions=_BASE_INSTRUCTIONS,
|
||||
question=question.strip(),
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["build_bare_prompt", "build_surfsense_prompt"]
|
||||
|
|
@ -0,0 +1,686 @@
|
|||
"""FRAMES runner — Bare LLM (no retrieval) vs SurfSense (multi-hop RAG).
|
||||
|
||||
Two arms run paired on every question in the sample:
|
||||
|
||||
1. ``BareLlmArm`` — OpenRouter chat completion with the question only.
|
||||
Reproduces the published "naive prompting" baseline (40.8% on
|
||||
Gemini-Pro-1.5).
|
||||
2. ``SurfSenseArm`` — POST ``/api/v1/new_chat`` with **no**
|
||||
``mentioned_document_ids`` so the agent retrieves over the entire
|
||||
ingested Wikipedia corpus. This is the "multi-step retrieval &
|
||||
reasoning" cell in the FRAMES paper.
|
||||
|
||||
Open-ended grading: deterministic shortcut + optional LLM-as-judge
|
||||
(``--no-judge`` to disable). Cost / latency / token aggregates are
|
||||
collected per arm. Paired stats (McNemar, bootstrap CI) for the
|
||||
accuracy delta. Per-reasoning-type breakdown to surface where one
|
||||
arm beats the other (numerical vs temporal vs multi-constraint, ...).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.arms import ArmRequest, ArmResult, BareLlmArm, SurfSenseArm
|
||||
from ....core.config import utc_iso_timestamp
|
||||
from ....core.ingest_settings import (
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
is_settings_header,
|
||||
)
|
||||
from ....core.metrics.comparison import (
|
||||
bootstrap_delta_ci,
|
||||
mcnemar_test,
|
||||
paired_aggregate,
|
||||
)
|
||||
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci
|
||||
from ....core.parse.freeform_answer import extract_freeform_answer
|
||||
from ....core.providers.openrouter_chat import OpenRouterChatProvider
|
||||
from ....core.registry import ReportSection, RunArtifact, RunContext
|
||||
from ....core.scenarios import format_scenario_md
|
||||
from .grader import GradeResult, JudgeConfig, LlmJudge, grade_many
|
||||
from .prompt import build_bare_prompt, build_surfsense_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class FramesRunnerQuestion:
|
||||
qid: str
|
||||
raw_index: int
|
||||
question: str
|
||||
gold_answer: str
|
||||
reasoning_types: list[str]
|
||||
document_ids: list[int] # subset of corpus relevant to this Q (may be empty)
|
||||
n_wiki_urls: int
|
||||
missing_urls: list[str]
|
||||
|
||||
|
||||
def _load_doc_map(map_path: Path) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
|
||||
rows: dict[str, dict[str, Any]] = {}
|
||||
settings: dict[str, Any] = {}
|
||||
with map_path.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
if is_settings_header(row):
|
||||
settings = dict(row["__settings__"])
|
||||
continue
|
||||
rows[str(row["qid"])] = row
|
||||
return rows, settings
|
||||
|
||||
|
||||
def _load_questions(
|
||||
questions_jsonl: Path,
|
||||
doc_map: dict[str, dict[str, Any]],
|
||||
*,
|
||||
sample_n: int | None,
|
||||
reasoning_filter: str | None,
|
||||
) -> list[FramesRunnerQuestion]:
|
||||
out: list[FramesRunnerQuestion] = []
|
||||
with questions_jsonl.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
qid = str(row.get("qid") or "").strip()
|
||||
if not qid:
|
||||
continue
|
||||
map_row = doc_map.get(qid, {})
|
||||
reasoning = list(row.get("reasoning_types") or [])
|
||||
if reasoning_filter and reasoning_filter not in [r.lower() for r in reasoning]:
|
||||
continue
|
||||
out.append(FramesRunnerQuestion(
|
||||
qid=qid,
|
||||
raw_index=int(row.get("raw_index") or 0),
|
||||
question=str(row.get("question") or "").strip(),
|
||||
gold_answer=str(row.get("gold_answer") or "").strip(),
|
||||
reasoning_types=reasoning,
|
||||
document_ids=list(map_row.get("document_ids") or []),
|
||||
n_wiki_urls=int(map_row.get("n_wiki_urls") or 0),
|
||||
missing_urls=list(map_row.get("missing_urls") or []),
|
||||
))
|
||||
out.sort(key=lambda q: q.raw_index)
|
||||
if sample_n is not None and sample_n > 0:
|
||||
out = out[:sample_n]
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bounded concurrency helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _gather_with_limit(coros: Iterable, *, concurrency: int) -> list[Any]:
|
||||
sem = asyncio.Semaphore(max(1, concurrency))
|
||||
|
||||
async def _wrap(coro):
|
||||
async with sem:
|
||||
return await coro
|
||||
|
||||
return await asyncio.gather(*(_wrap(c) for c in coros))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_DESCRIPTION = (
|
||||
"FRAMES (824 multi-hop Wikipedia questions, 5 reasoning types) — "
|
||||
"Bare LLM (no retrieval) vs SurfSense (multi-step RAG over the "
|
||||
"Wikipedia corpus). Tests cross-document retrieval + reasoning."
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_INGEST_SETTINGS = IngestSettings(
|
||||
use_vision_llm=False,
|
||||
processing_mode="basic",
|
||||
should_summarize=False,
|
||||
)
|
||||
|
||||
|
||||
class FramesBenchmark:
|
||||
"""Multi-hop Wikipedia RAG vs naive prompting."""
|
||||
|
||||
suite: str = "research"
|
||||
name: str = "frames"
|
||||
headline: bool = True
|
||||
description: str = _DESCRIPTION
|
||||
|
||||
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--n", dest="sample_n", type=int, default=None,
|
||||
help="Run only the first N questions after filters (default: all 824).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reasoning",
|
||||
dest="reasoning_filter",
|
||||
default=None,
|
||||
help=(
|
||||
"Filter to questions tagged with this reasoning type "
|
||||
"(e.g. 'numerical reasoning', 'temporal reasoning'). "
|
||||
"Case-insensitive substring against the upstream tags."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concurrency", type=int, default=4,
|
||||
help="Parallel question workers per arm.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scope-mentions", dest="scope_mentions", action="store_true",
|
||||
help=(
|
||||
"SurfSense arm: scope retrieval to the per-question "
|
||||
"document_ids (oracle-retrieval upper bound). Default "
|
||||
"is full-corpus retrieval (the realistic FRAMES setting)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-output-tokens", type=int, default=512,
|
||||
help="Cap on completion length for both arms.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-judge", dest="no_judge", action="store_true",
|
||||
help=(
|
||||
"Disable LLM-as-judge fallback grading; use only the "
|
||||
"deterministic grader (faster but more pessimistic)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--judge-model",
|
||||
dest="judge_model",
|
||||
default="anthropic/claude-sonnet-4.5",
|
||||
help="OpenRouter slug for the LLM judge (default: claude-sonnet-4.5).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--judge-concurrency",
|
||||
dest="judge_concurrency",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Parallel judge calls (default: 4).",
|
||||
)
|
||||
# Ingest-only knobs.
|
||||
parser.add_argument(
|
||||
"--max-questions", dest="max_questions", type=int, default=None,
|
||||
help="(ingest only) cap on number of questions to materialise + ingest.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-batch-size", dest="upload_batch_size", type=int, default=16,
|
||||
help="(ingest only) markdown files per fileupload call.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-upload", dest="skip_upload", action="store_true",
|
||||
help="(ingest only) cache wiki articles locally but don't push to SurfSense.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fetch-rps", dest="fetch_rate_limit_rps", type=float, default=2.0,
|
||||
help="(ingest only) max requests/second to the Wikipedia API.",
|
||||
)
|
||||
add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
|
||||
|
||||
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
|
||||
from .ingest import run_ingest
|
||||
|
||||
settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
|
||||
await run_ingest(
|
||||
ctx,
|
||||
max_questions=opts.get("max_questions"),
|
||||
upload_batch_size=int(opts.get("upload_batch_size") or 16),
|
||||
skip_upload=bool(opts.get("skip_upload", False)),
|
||||
fetch_rate_limit_rps=float(opts.get("fetch_rate_limit_rps") or 2.0),
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
|
||||
sample_n = opts.get("sample_n")
|
||||
reasoning_filter = opts.get("reasoning_filter")
|
||||
if reasoning_filter:
|
||||
reasoning_filter = reasoning_filter.strip().lower() or None
|
||||
concurrency = int(opts.get("concurrency") or 4)
|
||||
scope_mentions = bool(opts.get("scope_mentions"))
|
||||
max_output_tokens = int(opts.get("max_output_tokens") or 512)
|
||||
no_judge = bool(opts.get("no_judge"))
|
||||
judge_model = str(opts.get("judge_model") or "anthropic/claude-sonnet-4.5")
|
||||
judge_concurrency = int(opts.get("judge_concurrency") or 4)
|
||||
|
||||
bench_dir = ctx.benchmark_data_dir()
|
||||
questions_jsonl = bench_dir / "questions.jsonl"
|
||||
map_path = ctx.maps_dir() / "frames_doc_map.jsonl"
|
||||
if not questions_jsonl.exists() or not map_path.exists():
|
||||
raise RuntimeError(
|
||||
"FRAMES not ingested for this suite. Run "
|
||||
"`python -m surfsense_evals ingest research frames` first."
|
||||
)
|
||||
|
||||
doc_map, ingest_settings = _load_doc_map(map_path)
|
||||
questions = _load_questions(
|
||||
questions_jsonl, doc_map,
|
||||
sample_n=sample_n,
|
||||
reasoning_filter=reasoning_filter,
|
||||
)
|
||||
if not questions:
|
||||
raise RuntimeError(
|
||||
"No FRAMES questions matched the filters; broaden --reasoning/--n."
|
||||
)
|
||||
logger.info("FRAMES: scheduled %d questions", len(questions))
|
||||
|
||||
api_key = os.environ.get("OPENROUTER_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError(
|
||||
"OPENROUTER_API_KEY env var is required for the bare-LLM arm."
|
||||
)
|
||||
|
||||
bare_provider = OpenRouterChatProvider(
|
||||
api_key=api_key,
|
||||
base_url=ctx.config.openrouter_base_url,
|
||||
model=ctx.native_arm_model,
|
||||
)
|
||||
bare_arm = BareLlmArm(
|
||||
provider=bare_provider,
|
||||
max_output_tokens=max_output_tokens,
|
||||
)
|
||||
surf_arm = SurfSenseArm(
|
||||
client=ctx.new_chat_client(),
|
||||
search_space_id=ctx.search_space_id,
|
||||
ephemeral_threads=True,
|
||||
)
|
||||
|
||||
judge: LlmJudge | None = None
|
||||
if not no_judge:
|
||||
judge = LlmJudge(config=JudgeConfig(
|
||||
api_key=api_key,
|
||||
model=judge_model,
|
||||
base_url=ctx.config.openrouter_base_url,
|
||||
concurrency=judge_concurrency,
|
||||
))
|
||||
|
||||
run_timestamp = utc_iso_timestamp()
|
||||
run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
|
||||
raw_path = run_dir / "raw.jsonl"
|
||||
|
||||
async def _bare_one(q: FramesRunnerQuestion) -> ArmResult:
|
||||
return await bare_arm.answer(_make_bare_request(q, max_output_tokens))
|
||||
|
||||
async def _surf_one(q: FramesRunnerQuestion) -> ArmResult:
|
||||
return await surf_arm.answer(
|
||||
_make_surfsense_request(q, scope_mentions=scope_mentions)
|
||||
)
|
||||
|
||||
bare_results, surf_results = await asyncio.gather(
|
||||
_gather_with_limit((_bare_one(q) for q in questions), concurrency=concurrency),
|
||||
_gather_with_limit((_surf_one(q) for q in questions), concurrency=concurrency),
|
||||
)
|
||||
|
||||
bare_grades = await _grade_results(questions, bare_results, judge=judge)
|
||||
surf_grades = await _grade_results(questions, surf_results, judge=judge)
|
||||
|
||||
with raw_path.open("w", encoding="utf-8") as fh:
|
||||
for q, b_res, s_res, b_g, s_g in zip(
|
||||
questions, bare_results, surf_results, bare_grades, surf_grades, strict=False
|
||||
):
|
||||
meta = {
|
||||
"qid": q.qid,
|
||||
"raw_index": q.raw_index,
|
||||
"reasoning_types": q.reasoning_types,
|
||||
"n_wiki_urls": q.n_wiki_urls,
|
||||
"n_resolved_doc_ids": len(q.document_ids),
|
||||
"n_missing_urls": len(q.missing_urls),
|
||||
"gold": q.gold_answer,
|
||||
}
|
||||
fh.write(json.dumps({
|
||||
**meta,
|
||||
**b_res.to_jsonl(),
|
||||
"graded": b_g.to_dict(),
|
||||
}) + "\n")
|
||||
fh.write(json.dumps({
|
||||
**meta,
|
||||
**s_res.to_jsonl(),
|
||||
"graded": s_g.to_dict(),
|
||||
}) + "\n")
|
||||
|
||||
metrics = _compute_metrics(questions, bare_results, surf_results, bare_grades, surf_grades)
|
||||
artifact = RunArtifact(
|
||||
suite=self.suite,
|
||||
benchmark=self.name,
|
||||
run_timestamp=run_timestamp,
|
||||
raw_path=raw_path,
|
||||
metrics=metrics,
|
||||
extra={
|
||||
"n_questions": len(questions),
|
||||
"concurrency": concurrency,
|
||||
"reasoning_filter": reasoning_filter,
|
||||
"scope_mentions": scope_mentions,
|
||||
"no_judge": no_judge,
|
||||
"judge_model": judge_model if not no_judge else None,
|
||||
"scenario": ctx.scenario,
|
||||
"provider_model": ctx.provider_model,
|
||||
"native_arm_model": ctx.native_arm_model,
|
||||
"vision_provider_model": ctx.vision_provider_model,
|
||||
"agent_llm_id": ctx.agent_llm_id,
|
||||
"ingest_settings": ingest_settings,
|
||||
"bare_arm_label": "bare_llm",
|
||||
},
|
||||
)
|
||||
|
||||
manifest_path = run_dir / "run_artifact.json"
|
||||
manifest_path.write_text(
|
||||
json.dumps({
|
||||
"suite": self.suite,
|
||||
"benchmark": self.name,
|
||||
"raw_path": "raw.jsonl",
|
||||
"metrics": metrics,
|
||||
"extra": artifact.extra,
|
||||
}, indent=2, sort_keys=True) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return artifact
|
||||
|
||||
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
|
||||
if not artifacts:
|
||||
return ReportSection(
|
||||
title="FRAMES — Bare LLM vs SurfSense (multi-hop Wikipedia RAG)",
|
||||
headline=True,
|
||||
body_md="(no run artifacts found)",
|
||||
body_json={},
|
||||
)
|
||||
latest = max(artifacts, key=lambda a: a.run_timestamp)
|
||||
m = latest.metrics
|
||||
bare = m.get("bare", {})
|
||||
surf = m.get("surfsense", {})
|
||||
delta = m.get("delta", {})
|
||||
per_reasoning = m.get("per_reasoning", {})
|
||||
extra = latest.extra
|
||||
|
||||
body_lines: list[str] = []
|
||||
body_lines.append(
|
||||
f"- Sample size: {extra.get('n_questions', '?')} questions "
|
||||
f"(reasoning filter: `{extra.get('reasoning_filter') or 'none'}`, "
|
||||
f"scope-mentions: `{extra.get('scope_mentions', False)}`, "
|
||||
f"judge: `{extra.get('judge_model') or 'deterministic-only'}`)."
|
||||
)
|
||||
body_lines.append(format_scenario_md(extra))
|
||||
body_lines.append(format_ingest_settings_md(extra.get("ingest_settings")))
|
||||
body_lines.append(
|
||||
"- Bare LLM arm (OpenRouter chat, no retrieval, "
|
||||
f"`{extra.get('native_arm_model') or extra.get('provider_model', '?')}`):"
|
||||
)
|
||||
body_lines.append(_arm_summary_lines(bare, indent=" "))
|
||||
body_lines.append(
|
||||
"- SurfSense arm (`POST /api/v1/new_chat`, multi-step RAG, "
|
||||
f"`{extra.get('provider_model', '?')}`):"
|
||||
)
|
||||
body_lines.append(_arm_summary_lines(surf, indent=" "))
|
||||
body_lines.append("- Delta (paired):")
|
||||
body_lines.append(
|
||||
f" - Accuracy: SurfSense {_pp(delta.get('accuracy_pp'))} pp "
|
||||
f"(McNemar p={_fmt(delta.get('mcnemar_p_value'), 4)}, "
|
||||
f"method={delta.get('mcnemar_method')})"
|
||||
)
|
||||
body_lines.append(
|
||||
f" - Bootstrap 95% CI on accuracy delta: "
|
||||
f"[{_pp(delta.get('bootstrap_ci_low'))}pp, {_pp(delta.get('bootstrap_ci_high'))}pp]"
|
||||
)
|
||||
body_lines.append(
|
||||
f" - Cost / question: bare ${_dollars(bare.get('cost_micros_mean'))}, "
|
||||
f"surfsense ${_dollars(surf.get('cost_micros_mean'))} "
|
||||
f"(SurfSense delta {_pct_change(delta.get('cost_micros_pct'))})"
|
||||
)
|
||||
body_lines.append(
|
||||
f" - Latency p50: bare {_ms_to_s(bare.get('latency_ms_median'))}, "
|
||||
f"surfsense {_ms_to_s(surf.get('latency_ms_median'))} "
|
||||
f"(SurfSense delta {_pct_change(delta.get('latency_ms_pct'))})"
|
||||
)
|
||||
if per_reasoning:
|
||||
body_lines.append("- Per-reasoning-type split (accuracy delta in pp):")
|
||||
for tag, vals in sorted(per_reasoning.items()):
|
||||
body_lines.append(
|
||||
f" - {tag}: SurfSense {_pp(vals.get('delta_accuracy_pp'))} pp "
|
||||
f"(n={vals.get('n')}, bare acc={vals.get('bare_accuracy', 0)*100:.1f}%, "
|
||||
f"surf acc={vals.get('surfsense_accuracy', 0)*100:.1f}%)"
|
||||
)
|
||||
|
||||
return ReportSection(
|
||||
title="FRAMES — Bare LLM vs SurfSense (multi-hop Wikipedia RAG)",
|
||||
headline=True,
|
||||
body_md="\n".join(body_lines),
|
||||
body_json=m,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-question helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bare_request(q: FramesRunnerQuestion, max_tokens: int) -> ArmRequest:
|
||||
return ArmRequest(
|
||||
question_id=q.qid,
|
||||
prompt=build_bare_prompt(q.question),
|
||||
options={"max_tokens": max_tokens},
|
||||
)
|
||||
|
||||
|
||||
def _make_surfsense_request(q: FramesRunnerQuestion, *, scope_mentions: bool) -> ArmRequest:
|
||||
mentions: list[int] | None = None
|
||||
if scope_mentions and q.document_ids:
|
||||
mentions = list(q.document_ids)
|
||||
return ArmRequest(
|
||||
question_id=q.qid,
|
||||
prompt=build_surfsense_prompt(q.question),
|
||||
mentioned_document_ids=mentions,
|
||||
)
|
||||
|
||||
|
||||
async def _grade_results(
|
||||
questions: list[FramesRunnerQuestion],
|
||||
results: list[ArmResult],
|
||||
*,
|
||||
judge: LlmJudge | None,
|
||||
) -> list[GradeResult]:
|
||||
rows: list[tuple[str, str, str, str]] = []
|
||||
for q, r in zip(questions, results, strict=False):
|
||||
pred = extract_freeform_answer(r.raw_text or "")
|
||||
rows.append((q.qid, q.question, q.gold_answer, pred))
|
||||
return await grade_many(rows=rows, judge=judge)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metrics aggregation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_metrics(
|
||||
questions: list[FramesRunnerQuestion],
|
||||
bare_results: list[ArmResult],
|
||||
surf_results: list[ArmResult],
|
||||
bare_grades: list[GradeResult],
|
||||
surf_grades: list[GradeResult],
|
||||
) -> dict[str, Any]:
|
||||
bare_correct = [g.correct for g in bare_grades]
|
||||
surf_correct = [g.correct for g in surf_grades]
|
||||
|
||||
bare_costs = [float(r.cost_micros) for r in bare_results]
|
||||
surf_costs = [float(r.cost_micros) for r in surf_results]
|
||||
bare_latencies = [float(r.latency_ms) for r in bare_results]
|
||||
surf_latencies = [float(r.latency_ms) for r in surf_results]
|
||||
bare_in_tokens = [float(r.input_tokens) for r in bare_results]
|
||||
bare_out_tokens = [float(r.output_tokens) for r in bare_results]
|
||||
|
||||
bare_acc = accuracy_with_wilson_ci(sum(bare_correct), len(bare_correct))
|
||||
surf_acc = accuracy_with_wilson_ci(sum(surf_correct), len(surf_correct))
|
||||
mc = mcnemar_test(bare_correct, surf_correct)
|
||||
boot = bootstrap_delta_ci(bare_correct, surf_correct, n_resamples=2000)
|
||||
|
||||
bare_cost_agg = paired_aggregate(bare_costs)
|
||||
surf_cost_agg = paired_aggregate(surf_costs)
|
||||
bare_latency_agg = paired_aggregate(bare_latencies)
|
||||
surf_latency_agg = paired_aggregate(surf_latencies)
|
||||
cost_pct = _safe_pct(surf_cost_agg.mean, bare_cost_agg.mean)
|
||||
latency_pct = _safe_pct(surf_latency_agg.median, bare_latency_agg.median)
|
||||
|
||||
# Per-reasoning-type breakdown. Each question may carry multiple
|
||||
# reasoning tags; we count it under each tag (so totals don't
|
||||
# equal len(questions) — the reader is expected to look at the
|
||||
# per-tag ``n``).
|
||||
per_reasoning_pairs: dict[str, list[tuple[bool, bool]]] = {}
|
||||
for q, b_ok, s_ok in zip(questions, bare_correct, surf_correct, strict=False):
|
||||
tags = q.reasoning_types or ["(untagged)"]
|
||||
for tag in tags:
|
||||
per_reasoning_pairs.setdefault(tag, []).append((b_ok, s_ok))
|
||||
|
||||
per_reasoning: dict[str, dict[str, Any]] = {}
|
||||
for tag, pairs in per_reasoning_pairs.items():
|
||||
b_correct = [a for a, _ in pairs]
|
||||
s_correct = [b for _, b in pairs]
|
||||
per_reasoning[tag] = {
|
||||
"n": len(pairs),
|
||||
"bare_accuracy": (sum(b_correct) / len(pairs)) if pairs else 0.0,
|
||||
"surfsense_accuracy": (sum(s_correct) / len(pairs)) if pairs else 0.0,
|
||||
"delta_accuracy_pp": (
|
||||
100.0 * (sum(s_correct) - sum(b_correct)) / len(pairs)
|
||||
if pairs else 0.0
|
||||
),
|
||||
}
|
||||
|
||||
grader_methods = {
|
||||
"bare": _count_methods(bare_grades),
|
||||
"surfsense": _count_methods(surf_grades),
|
||||
}
|
||||
|
||||
return {
|
||||
"bare": {
|
||||
**bare_acc.to_dict(),
|
||||
"cost_micros_mean": bare_cost_agg.mean,
|
||||
"cost_micros_median": bare_cost_agg.median,
|
||||
"latency_ms_mean": bare_latency_agg.mean,
|
||||
"latency_ms_median": bare_latency_agg.median,
|
||||
"latency_ms_p95": bare_latency_agg.p95,
|
||||
"input_tokens_mean": (sum(bare_in_tokens) / len(bare_in_tokens)) if bare_in_tokens else 0.0,
|
||||
"output_tokens_mean": (sum(bare_out_tokens) / len(bare_out_tokens)) if bare_out_tokens else 0.0,
|
||||
},
|
||||
"surfsense": {
|
||||
**surf_acc.to_dict(),
|
||||
"cost_micros_mean": surf_cost_agg.mean,
|
||||
"cost_micros_median": surf_cost_agg.median,
|
||||
"latency_ms_mean": surf_latency_agg.mean,
|
||||
"latency_ms_median": surf_latency_agg.median,
|
||||
"latency_ms_p95": surf_latency_agg.p95,
|
||||
},
|
||||
"delta": {
|
||||
"accuracy_pp": 100.0 * (surf_acc.accuracy - bare_acc.accuracy),
|
||||
"mcnemar_p_value": mc.p_value,
|
||||
"mcnemar_method": mc.method,
|
||||
"mcnemar_b_bare_only": mc.b,
|
||||
"mcnemar_c_surfsense_only": mc.c,
|
||||
"bootstrap_ci_low": 100.0 * boot.ci_low,
|
||||
"bootstrap_ci_high": 100.0 * boot.ci_high,
|
||||
"cost_micros_pct": cost_pct,
|
||||
"latency_ms_pct": latency_pct,
|
||||
},
|
||||
"per_reasoning": per_reasoning,
|
||||
"grader_methods": grader_methods,
|
||||
}
|
||||
|
||||
|
||||
def _count_methods(grades: list[GradeResult]) -> dict[str, int]:
|
||||
out: dict[str, int] = {}
|
||||
for g in grades:
|
||||
out[g.method] = out.get(g.method, 0) + 1
|
||||
return out
|
||||
|
||||
|
||||
def _safe_pct(numerator: float, denominator: float) -> float | None:
|
||||
if denominator == 0:
|
||||
return None
|
||||
return 100.0 * (numerator - denominator) / denominator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tiny formatting helpers used by report_section
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _arm_summary_lines(d: dict[str, Any], *, indent: str) -> str:
|
||||
if not d:
|
||||
return f"{indent}(no data)"
|
||||
acc = d.get("accuracy", 0.0)
|
||||
low = d.get("ci_low", 0.0)
|
||||
high = d.get("ci_high", 0.0)
|
||||
lines = [
|
||||
f"{indent}- Accuracy: {acc * 100:.1f}% (Wilson 95% CI: {low * 100:.1f}% – {high * 100:.1f}%)",
|
||||
f"{indent}- Cost / question: ${_dollars(d.get('cost_micros_mean'))} (mean), "
|
||||
f"${_dollars(d.get('cost_micros_median'))} (median)",
|
||||
f"{indent}- Latency: p50 {_ms_to_s(d.get('latency_ms_median'))}, "
|
||||
f"p95 {_ms_to_s(d.get('latency_ms_p95'))}",
|
||||
]
|
||||
if "input_tokens_mean" in d:
|
||||
lines.append(
|
||||
f"{indent}- Mean tokens / question: in {d.get('input_tokens_mean', 0):.0f}, "
|
||||
f"out {d.get('output_tokens_mean', 0):.0f}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _dollars(micros: Any) -> str:
|
||||
if micros is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{(float(micros) / 1_000_000):.4f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _ms_to_s(ms: Any) -> str:
|
||||
if ms is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(ms) / 1000:.1f}s"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _pp(value: Any) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):+.1f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _pct_change(value: Any) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):+.0f}%"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _fmt(value: Any, ndigits: int) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):.{ndigits}f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
__all__ = ["FramesBenchmark", "FramesRunnerQuestion"]
|
||||
|
|
@ -0,0 +1,241 @@
|
|||
"""Wikipedia article fetcher → plain-text markdown, with disk cache.
|
||||
|
||||
We hit the MediaWiki action API for *plain text* extracts:
|
||||
|
||||
GET https://en.wikipedia.org/w/api.php
|
||||
?action=query&prop=extracts&explaintext=true
|
||||
&redirects=1&titles=<Title>&format=json&formatversion=2
|
||||
|
||||
This avoids HTML→markdown conversion (and its many edge cases). The
|
||||
``explaintext=true`` mode strips infoboxes / templates / wikilinks
|
||||
and returns clean section-headered prose, which is exactly what we
|
||||
want SurfSense to chunk + embed. We prepend ``# <Title>\n\n`` so the
|
||||
markdown has a visible H1 (helps SurfSense's chunker preserve doc
|
||||
identity at the top of the first chunk).
|
||||
|
||||
Caching: every fetched article lands in
|
||||
``<bench_dir>/wiki/<sanitised-title>.md`` and is reused on subsequent
|
||||
runs. The cache key is the URL-decoded title (e.g.
|
||||
``Charlotte_Brontë`` regardless of source URL casing or
|
||||
percent-encoding).
|
||||
|
||||
Politeness: 2 RPS rate limit + a descriptive User-Agent (Wikimedia
|
||||
asks for one). We don't parallelise above 2 RPS — this is a courtesy
|
||||
to Wikipedia and only ~300 articles for the n=100 sample.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
WIKI_API = "https://en.wikipedia.org/w/api.php"
|
||||
USER_AGENT = (
|
||||
"SurfSense-Evals/0.1 (https://github.com/MODSetter/SurfSense; "
|
||||
"research-benchmark fetch; respects 2 RPS rate limit)"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WikiArticle:
|
||||
"""One fetched article + metadata."""
|
||||
|
||||
title: str # canonical title returned by MW (post-redirect)
|
||||
source_url: str # the URL we were asked to fetch
|
||||
markdown_path: Path # where the cached body lives on disk
|
||||
n_chars: int # length of the body (post-prepend H1)
|
||||
redirected_from: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Title <-> URL helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_WIKI_PATH_RE = re.compile(r"^/wiki/(?P<title>[^?#]+)$")
|
||||
|
||||
|
||||
def title_from_url(url: str) -> str:
|
||||
"""Pull the page title out of a wiki URL.
|
||||
|
||||
``https://en.wikipedia.org/wiki/Charlotte_Bront%C3%AB`` → ``Charlotte Brontë``.
|
||||
Spaces are preserved (the API accepts spaces and underscores
|
||||
interchangeably; we use spaces to keep cache filenames human-readable).
|
||||
"""
|
||||
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
if parsed.netloc and "wikipedia.org" not in parsed.netloc:
|
||||
raise ValueError(f"Not a Wikipedia URL: {url!r}")
|
||||
match = _WIKI_PATH_RE.match(parsed.path)
|
||||
if not match:
|
||||
raise ValueError(f"Unrecognised wiki path: {parsed.path!r}")
|
||||
raw_title = urllib.parse.unquote(match.group("title"))
|
||||
# MW treats underscores and spaces as equivalent; spaces are friendlier.
|
||||
return raw_title.replace("_", " ").strip()
|
||||
|
||||
|
||||
_FILENAME_SAFE = re.compile(r"[^A-Za-z0-9._\- ]")
|
||||
|
||||
|
||||
def cache_filename_for_title(title: str) -> str:
|
||||
"""Map a title to a filesystem-safe filename.
|
||||
|
||||
Replaces every non-(alnum / ``._- `` / space) character with ``_``.
|
||||
Title collisions are rare (FRAMES only has English Wikipedia titles)
|
||||
and a final ``hash(title)[:8]`` would obscure the otherwise-readable
|
||||
filenames; we accept the (vanishingly small) collision risk.
|
||||
"""
|
||||
|
||||
safe = _FILENAME_SAFE.sub("_", title)
|
||||
safe = safe.strip().replace(" ", "_")
|
||||
return f"{safe}.md"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async fetcher with rate limiting + retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WikiFetcher:
|
||||
"""Polite fetch + disk cache + redirect handling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cache_dir: Path,
|
||||
rate_limit_rps: float = 2.0,
|
||||
timeout_s: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
) -> None:
|
||||
self._cache_dir = Path(cache_dir)
|
||||
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._min_interval = 1.0 / max(rate_limit_rps, 0.1)
|
||||
self._last_request_at = 0.0
|
||||
self._rate_lock = asyncio.Lock()
|
||||
self._timeout = httpx.Timeout(timeout_s, connect=10.0)
|
||||
self._max_retries = max_retries
|
||||
|
||||
async def _throttle(self) -> None:
|
||||
async with self._rate_lock:
|
||||
now = asyncio.get_event_loop().time()
|
||||
wait = self._last_request_at + self._min_interval - now
|
||||
if wait > 0:
|
||||
await asyncio.sleep(wait)
|
||||
self._last_request_at = asyncio.get_event_loop().time()
|
||||
|
||||
async def fetch(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
http: httpx.AsyncClient | None = None,
|
||||
) -> WikiArticle | None:
|
||||
"""Fetch one article. Returns ``None`` only if MW reports the title is missing.
|
||||
|
||||
Raises on transport errors after retries. Caller decides
|
||||
whether to abort the whole ingest or continue with the
|
||||
successfully-fetched subset.
|
||||
"""
|
||||
|
||||
try:
|
||||
title = title_from_url(url)
|
||||
except ValueError as exc:
|
||||
logger.warning("Skipping non-wiki URL %s: %s", url, exc)
|
||||
return None
|
||||
|
||||
cache_path = self._cache_dir / cache_filename_for_title(title)
|
||||
if cache_path.exists() and cache_path.stat().st_size > 0:
|
||||
return WikiArticle(
|
||||
title=title,
|
||||
source_url=url,
|
||||
markdown_path=cache_path,
|
||||
n_chars=cache_path.stat().st_size,
|
||||
)
|
||||
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(self._max_retries):
|
||||
try:
|
||||
await self._throttle()
|
||||
payload = await self._fetch_extract(title, http=http)
|
||||
break
|
||||
except (httpx.HTTPError, RuntimeError) as exc:
|
||||
last_exc = exc
|
||||
wait = 1.0 * (2 ** attempt)
|
||||
logger.warning(
|
||||
"wiki fetch %r attempt %d failed: %s; retry in %.1fs",
|
||||
title, attempt + 1, exc, wait,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
else:
|
||||
assert last_exc is not None
|
||||
raise last_exc
|
||||
|
||||
page = payload.get("page") or {}
|
||||
if not page or page.get("missing"):
|
||||
logger.warning("Wikipedia reports missing page for %r (url=%s)", title, url)
|
||||
return None
|
||||
|
||||
canonical_title = str(page.get("title") or title).strip()
|
||||
body = str(page.get("extract") or "").strip()
|
||||
if not body:
|
||||
logger.warning("Wikipedia returned empty extract for %r", title)
|
||||
return None
|
||||
markdown = f"# {canonical_title}\n\n{body}\n"
|
||||
cache_path.write_text(markdown, encoding="utf-8")
|
||||
return WikiArticle(
|
||||
title=canonical_title,
|
||||
source_url=url,
|
||||
markdown_path=cache_path,
|
||||
n_chars=len(markdown),
|
||||
redirected_from=title if canonical_title != title else None,
|
||||
)
|
||||
|
||||
async def _fetch_extract(
|
||||
self,
|
||||
title: str,
|
||||
*,
|
||||
http: httpx.AsyncClient | None,
|
||||
) -> dict:
|
||||
"""One MW API call. Returns ``{'page': {...}}`` (formatversion=2)."""
|
||||
|
||||
params = {
|
||||
"action": "query",
|
||||
"prop": "extracts",
|
||||
"explaintext": "true",
|
||||
"redirects": "1",
|
||||
"format": "json",
|
||||
"formatversion": "2",
|
||||
"titles": title,
|
||||
}
|
||||
headers = {"User-Agent": USER_AGENT, "Accept": "application/json"}
|
||||
if http is not None:
|
||||
response = await http.get(WIKI_API, params=params, headers=headers, timeout=self._timeout)
|
||||
else:
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
response = await client.get(WIKI_API, params=params, headers=headers, timeout=self._timeout)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if "error" in data:
|
||||
raise RuntimeError(f"MediaWiki API error: {data['error']!r}")
|
||||
pages = (data.get("query") or {}).get("pages") or []
|
||||
if not pages:
|
||||
return {"page": {}}
|
||||
return {"page": pages[0]}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WIKI_API",
|
||||
"USER_AGENT",
|
||||
"WikiArticle",
|
||||
"WikiFetcher",
|
||||
"cache_filename_for_title",
|
||||
"title_from_url",
|
||||
]
|
||||
1
surfsense_evals/tests/__init__.py
Normal file
1
surfsense_evals/tests/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
34
surfsense_evals/tests/conftest.py
Normal file
34
surfsense_evals/tests/conftest.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""Shared pytest fixtures for surfsense-evals."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.core.config import Config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_env(monkeypatch, tmp_path: Path) -> Path:
|
||||
"""Isolate env vars + filesystem state per test.
|
||||
|
||||
Wipes every ``SURFSENSE_*`` / ``OPENROUTER_*`` / ``EVAL_*`` var so a
|
||||
test that wants a specific credential mode can ``monkeypatch.setenv``
|
||||
just what it needs without leakage from the caller's shell.
|
||||
"""
|
||||
|
||||
for key in list(os.environ):
|
||||
if key.startswith(("SURFSENSE_", "OPENROUTER_", "EVAL_")):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
monkeypatch.setenv("EVAL_DATA_DIR", str(tmp_path / "data"))
|
||||
monkeypatch.setenv("EVAL_REPORTS_DIR", str(tmp_path / "reports"))
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_config(tmp_env: Path) -> Config: # noqa: ARG001
|
||||
from surfsense_evals.core.config import load_config
|
||||
|
||||
return load_config()
|
||||
1
surfsense_evals/tests/core/__init__.py
Normal file
1
surfsense_evals/tests/core/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
95
surfsense_evals/tests/core/test_auth.py
Normal file
95
surfsense_evals/tests/core/test_auth.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""Auth credential resolution + 401 refresh hook."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from surfsense_evals.core.auth import (
|
||||
CredentialError,
|
||||
acquire_token,
|
||||
client_with_auth,
|
||||
)
|
||||
from surfsense_evals.core.config import Config
|
||||
|
||||
|
||||
def _make_config(**overrides) -> Config:
|
||||
base = {
|
||||
"surfsense_api_base": "http://test",
|
||||
"openrouter_api_key": None,
|
||||
"openrouter_base_url": "https://openrouter.ai/api/v1",
|
||||
"surfsense_jwt": None,
|
||||
"surfsense_refresh_token": None,
|
||||
"surfsense_user_email": None,
|
||||
"surfsense_user_password": None,
|
||||
"data_dir": None,
|
||||
"reports_dir": None,
|
||||
}
|
||||
base.update(overrides)
|
||||
# Path objects required by Config; tests don't touch the FS.
|
||||
from pathlib import Path
|
||||
|
||||
base["data_dir"] = base["data_dir"] or Path("/tmp/eval_test_data")
|
||||
base["reports_dir"] = base["reports_dir"] or Path("/tmp/eval_test_reports")
|
||||
return Config(**base)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_token_jwt_mode_short_circuits():
|
||||
config = _make_config(surfsense_jwt="abc", surfsense_refresh_token="ref")
|
||||
bundle = await acquire_token(config)
|
||||
assert bundle.access_token == "abc"
|
||||
assert bundle.refresh_token == "ref"
|
||||
assert bundle.mode == "jwt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_acquire_token_local_mode_posts_form():
|
||||
respx.post("http://test/auth/jwt/login").mock(
|
||||
return_value=httpx.Response(
|
||||
200, json={"access_token": "T", "refresh_token": "R", "token_type": "bearer"}
|
||||
)
|
||||
)
|
||||
config = _make_config(
|
||||
surfsense_user_email="u@example.com", surfsense_user_password="pw"
|
||||
)
|
||||
bundle = await acquire_token(config)
|
||||
assert bundle.access_token == "T"
|
||||
assert bundle.refresh_token == "R"
|
||||
assert bundle.mode == "local"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_token_no_credentials():
|
||||
config = _make_config()
|
||||
with pytest.raises(CredentialError) as exc:
|
||||
await acquire_token(config)
|
||||
assert "SURFSENSE_USER_EMAIL" in str(exc.value)
|
||||
assert "SURFSENSE_JWT" in str(exc.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_client_with_auth_refreshes_on_401():
|
||||
config = _make_config(surfsense_jwt="old", surfsense_refresh_token="ref")
|
||||
bundle = await acquire_token(config)
|
||||
|
||||
respx.post("http://test/auth/jwt/refresh").mock(
|
||||
return_value=httpx.Response(200, json={"access_token": "new", "refresh_token": "ref2"})
|
||||
)
|
||||
# First call returns 401; the retry (post-refresh) returns 200.
|
||||
respx.get("http://test/api/v1/searchspaces").mock(
|
||||
side_effect=[
|
||||
httpx.Response(401, json={"detail": "expired"}),
|
||||
httpx.Response(200, json=[]),
|
||||
]
|
||||
)
|
||||
|
||||
async with client_with_auth(config, bundle) as client:
|
||||
response = await client.get("http://test/api/v1/searchspaces")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert bundle.access_token == "new"
|
||||
assert bundle.refresh_token == "ref2"
|
||||
262
surfsense_evals/tests/core/test_clients.py
Normal file
262
surfsense_evals/tests/core/test_clients.py
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
"""respx-mocked tests for the SurfSense HTTP clients."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from surfsense_evals.core.clients import (
|
||||
DocumentsClient,
|
||||
NewChatClient,
|
||||
SearchSpaceClient,
|
||||
)
|
||||
from surfsense_evals.core.clients.new_chat import ThreadBusyError
|
||||
|
||||
_BASE = "http://test"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def http() -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(base_url=_BASE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SearchSpaceClient
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_create_search_space_returns_row(respx_mock, http):
|
||||
respx_mock.post("/api/v1/searchspaces").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"id": 99,
|
||||
"name": "eval-medical-2026",
|
||||
"description": None,
|
||||
"user_id": "user-x",
|
||||
"citations_enabled": True,
|
||||
"qna_custom_instructions": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
client = SearchSpaceClient(http, _BASE)
|
||||
row = await client.create("eval-medical-2026")
|
||||
assert row.id == 99
|
||||
assert row.name == "eval-medical-2026"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_delete_search_space_idempotent_on_404(respx_mock, http):
|
||||
respx_mock.delete("/api/v1/searchspaces/42").mock(
|
||||
return_value=httpx.Response(404, json={"detail": "gone"})
|
||||
)
|
||||
client = SearchSpaceClient(http, _BASE)
|
||||
await client.delete(42) # must not raise
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_set_llm_preferences_partial_update(respx_mock, http):
|
||||
route = respx_mock.put("/api/v1/search-spaces/42/llm-preferences").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"agent_llm_id": -10042,
|
||||
"document_summary_llm_id": None,
|
||||
"image_generation_config_id": None,
|
||||
"vision_llm_config_id": None,
|
||||
"agent_llm": {
|
||||
"id": -10042,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "anthropic/claude-sonnet-4.5",
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
client = SearchSpaceClient(http, _BASE)
|
||||
prefs = await client.set_llm_preferences(42, agent_llm_id=-10042)
|
||||
assert prefs.agent_llm_id == -10042
|
||||
assert prefs.agent_llm["provider"] == "OPENROUTER"
|
||||
sent_body = json.loads(route.calls[-1].request.content)
|
||||
assert sent_body == {"agent_llm_id": -10042}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocumentsClient
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_documents_status_parses_state(respx_mock, http):
|
||||
respx_mock.get("/api/v1/documents/status").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"items": [
|
||||
{"id": 1, "title": "a.pdf", "document_type": "FILE",
|
||||
"status": {"state": "ready", "reason": None}},
|
||||
{"id": 2, "title": "b.pdf", "document_type": "FILE",
|
||||
"status": {"state": "failed", "reason": "ETL boom"}},
|
||||
]
|
||||
},
|
||||
)
|
||||
)
|
||||
client = DocumentsClient(http, _BASE)
|
||||
statuses = await client.get_status(search_space_id=1, document_ids=[1, 2])
|
||||
assert {s.document_id for s in statuses} == {1, 2}
|
||||
assert {s.is_ready for s in statuses} == {True, False}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_documents_upload_returns_payload(respx_mock, http, tmp_path: Path):
|
||||
f1 = tmp_path / "a.pdf"
|
||||
f1.write_bytes(b"%PDF-1.4 small")
|
||||
respx_mock.post("/api/v1/documents/fileupload").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"message": "Files uploaded",
|
||||
"document_ids": [101],
|
||||
"duplicate_document_ids": [],
|
||||
"total_files": 1,
|
||||
"pending_files": 1,
|
||||
"skipped_duplicates": 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
client = DocumentsClient(http, _BASE)
|
||||
result = await client.upload(files=[f1], search_space_id=7)
|
||||
assert result.document_ids == [101]
|
||||
assert result.pending_files == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_documents_list_chunks_paginated(respx_mock, http):
|
||||
respx_mock.get("/api/v1/documents/5/chunks").mock(
|
||||
side_effect=[
|
||||
httpx.Response(200, json={
|
||||
"items": [{"id": 1, "content": "a"}, {"id": 2, "content": "b"}],
|
||||
"total": 3, "page": 0, "page_size": 2, "has_more": True,
|
||||
}),
|
||||
httpx.Response(200, json={
|
||||
"items": [{"id": 3, "content": "c"}],
|
||||
"total": 3, "page": 1, "page_size": 2, "has_more": False,
|
||||
}),
|
||||
]
|
||||
)
|
||||
client = DocumentsClient(http, _BASE)
|
||||
rows = await client.list_chunks(5, page_size=2)
|
||||
assert [r.id for r in rows] == [1, 2, 3]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NewChatClient
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_create_thread_returns_id(respx_mock, http):
|
||||
respx_mock.post("/api/v1/threads").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"id": 555,
|
||||
"title": "eval",
|
||||
"archived": False,
|
||||
"visibility": "PRIVATE",
|
||||
"search_space_id": 1,
|
||||
"messages": [],
|
||||
"created_at": "2026-05-11T00:00:00Z",
|
||||
"updated_at": "2026-05-11T00:00:00Z",
|
||||
},
|
||||
)
|
||||
)
|
||||
client = NewChatClient(http, _BASE)
|
||||
tid = await client.create_thread(search_space_id=1)
|
||||
assert tid == 555
|
||||
|
||||
|
||||
def _sse_body(events: list[dict]) -> bytes:
|
||||
parts = []
|
||||
for ev in events:
|
||||
parts.append(f"data: {json.dumps(ev)}\n\n")
|
||||
parts.append("data: [DONE]\n\n")
|
||||
return "".join(parts).encode("utf-8")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_ask_accumulates_text_deltas(respx_mock, http):
|
||||
body = _sse_body([
|
||||
{"type": "start", "messageId": "m1"},
|
||||
{"type": "text-start", "id": "t1"},
|
||||
{"type": "text-delta", "id": "t1", "delta": "Answer "},
|
||||
{"type": "text-delta", "id": "t1", "delta": "is "},
|
||||
{"type": "text-delta", "id": "t1", "delta": "B [citation:42]."},
|
||||
{"type": "text-end", "id": "t1"},
|
||||
{"type": "finish"},
|
||||
])
|
||||
respx_mock.post("/api/v1/new_chat").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
content=body,
|
||||
headers={"Content-Type": "text/event-stream"},
|
||||
)
|
||||
)
|
||||
client = NewChatClient(http, _BASE)
|
||||
answer = await client.ask(
|
||||
thread_id=1, search_space_id=2, user_query="What is the answer?"
|
||||
)
|
||||
assert answer.text == "Answer is B [citation:42]."
|
||||
assert answer.finished_normally is True
|
||||
assert any(c["chunk_id"] == 42 for c in answer.citations)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_ask_409_thread_busy_retries(respx_mock, http):
|
||||
body = _sse_body([
|
||||
{"type": "text-delta", "id": "t1", "delta": "ok"},
|
||||
{"type": "finish"},
|
||||
])
|
||||
busy = httpx.Response(
|
||||
409,
|
||||
json={"detail": {"errorCode": "THREAD_BUSY", "message": "busy"}},
|
||||
headers={"Retry-After": "1"},
|
||||
)
|
||||
success = httpx.Response(
|
||||
200, content=body, headers={"Content-Type": "text/event-stream"}
|
||||
)
|
||||
respx_mock.post("/api/v1/new_chat").mock(side_effect=[busy, success])
|
||||
client = NewChatClient(http, _BASE)
|
||||
answer = await client.ask(
|
||||
thread_id=1, search_space_id=2, user_query="hi", max_busy_retries=2
|
||||
)
|
||||
assert answer.text == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_ask_409_exhausts_retries(respx_mock, http):
|
||||
busy = httpx.Response(
|
||||
409,
|
||||
json={"detail": {"errorCode": "TURN_CANCELLING", "message": "wait"}},
|
||||
headers={"Retry-After": "1"},
|
||||
)
|
||||
respx_mock.post("/api/v1/new_chat").mock(return_value=busy)
|
||||
client = NewChatClient(http, _BASE)
|
||||
with pytest.raises(ThreadBusyError):
|
||||
await client.ask(
|
||||
thread_id=1, search_space_id=2, user_query="hi", max_busy_retries=1
|
||||
)
|
||||
160
surfsense_evals/tests/core/test_config.py
Normal file
160
surfsense_evals/tests/core/test_config.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
"""Tests for env loading + state.json read/write."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from surfsense_evals.core.config import (
|
||||
DEFAULT_SCENARIO,
|
||||
SCENARIOS,
|
||||
SuiteState,
|
||||
clear_suite_state,
|
||||
get_suite_state,
|
||||
load_config,
|
||||
set_suite_state,
|
||||
)
|
||||
|
||||
|
||||
def test_load_config_defaults_to_localhost(tmp_env): # noqa: ARG001
|
||||
config = load_config()
|
||||
assert config.surfsense_api_base == "http://localhost:8000"
|
||||
assert config.has_jwt_mode() is False
|
||||
assert config.has_local_mode() is False
|
||||
assert config.credential_mode() == "none"
|
||||
|
||||
|
||||
def test_load_config_picks_up_jwt_env(tmp_env, monkeypatch): # noqa: ARG001
|
||||
monkeypatch.setenv("SURFSENSE_JWT", "tok")
|
||||
config = load_config()
|
||||
assert config.credential_mode() == "jwt"
|
||||
|
||||
|
||||
def test_load_config_picks_up_local_env(tmp_env, monkeypatch): # noqa: ARG001
|
||||
monkeypatch.setenv("SURFSENSE_USER_EMAIL", "u@x.com")
|
||||
monkeypatch.setenv("SURFSENSE_USER_PASSWORD", "pw")
|
||||
config = load_config()
|
||||
assert config.credential_mode() == "local"
|
||||
|
||||
|
||||
def test_state_roundtrip_per_suite(tmp_env): # noqa: ARG001
|
||||
config = load_config()
|
||||
assert get_suite_state(config, "medical") is None
|
||||
state = SuiteState(
|
||||
search_space_id=1,
|
||||
agent_llm_id=-10042,
|
||||
provider_model="anthropic/claude-sonnet-4.5",
|
||||
created_at="2026-05-11T20-30-00Z",
|
||||
)
|
||||
set_suite_state(config, "medical", state)
|
||||
legal = SuiteState(
|
||||
search_space_id=2,
|
||||
agent_llm_id=-1,
|
||||
provider_model="openai/gpt-5",
|
||||
created_at="2026-05-11T21-00-00Z",
|
||||
)
|
||||
set_suite_state(config, "legal", legal)
|
||||
|
||||
fetched = get_suite_state(config, "medical")
|
||||
assert fetched.search_space_id == 1
|
||||
assert fetched.provider_model == "anthropic/claude-sonnet-4.5"
|
||||
|
||||
# Other suite untouched after teardown.
|
||||
cleared = clear_suite_state(config, "medical")
|
||||
assert cleared is True
|
||||
assert get_suite_state(config, "medical") is None
|
||||
assert get_suite_state(config, "legal").search_space_id == 2
|
||||
|
||||
raw = json.loads(config.state_path.read_text(encoding="utf-8"))
|
||||
assert "medical" not in raw["suites"]
|
||||
assert "legal" in raw["suites"]
|
||||
|
||||
|
||||
def test_paths_are_per_suite(tmp_env): # noqa: ARG001
|
||||
config = load_config()
|
||||
a = config.suite_data_dir("medical")
|
||||
b = config.suite_data_dir("legal")
|
||||
assert a != b
|
||||
assert config.suite_reports_dir("medical").parent == config.reports_dir
|
||||
assert config.suite_runs_dir("medical").name == "runs"
|
||||
assert config.suite_maps_dir("medical").name == "maps"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario state — back-compat + new fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_legacy_state_back_compat_defaults_to_head_to_head():
|
||||
"""state.json files written before scenarios shipped must still load.
|
||||
|
||||
Missing ``scenario`` / ``vision_*`` / ``native_arm_model`` keys all
|
||||
default to ``head-to-head`` / ``None`` so old setups keep working
|
||||
after upgrade — the runner's behaviour exactly mirrors the legacy
|
||||
one (both arms answer with ``provider_model``).
|
||||
"""
|
||||
|
||||
legacy = {
|
||||
"search_space_id": 7,
|
||||
"agent_llm_id": -123,
|
||||
"provider_model": "anthropic/claude-sonnet-4.5",
|
||||
"created_at": "2026-05-11T20-30-00Z",
|
||||
"ingestion_maps": {},
|
||||
}
|
||||
state = SuiteState.from_dict(legacy)
|
||||
assert state.scenario == DEFAULT_SCENARIO == "head-to-head"
|
||||
assert state.vision_llm_config_id is None
|
||||
assert state.vision_provider_model is None
|
||||
assert state.native_arm_model is None
|
||||
# The native arm should still answer with the same slug as SurfSense.
|
||||
assert state.effective_native_arm_model == state.provider_model
|
||||
|
||||
|
||||
def test_unknown_scenario_falls_back_to_default():
|
||||
"""Garbage scenario in state.json → default, not crash.
|
||||
|
||||
Defensive: we'd rather a stale state file render with the safe
|
||||
head-to-head behaviour than break the whole run with a KeyError.
|
||||
"""
|
||||
|
||||
payload = {
|
||||
"search_space_id": 1,
|
||||
"agent_llm_id": -1,
|
||||
"provider_model": "openai/gpt-5",
|
||||
"scenario": "unknown-scenario-name",
|
||||
}
|
||||
state = SuiteState.from_dict(payload)
|
||||
assert state.scenario == DEFAULT_SCENARIO
|
||||
|
||||
|
||||
def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG001
|
||||
config = load_config()
|
||||
state = SuiteState(
|
||||
search_space_id=42,
|
||||
agent_llm_id=-1,
|
||||
provider_model="openai/gpt-5.4-mini",
|
||||
created_at="2026-05-11T20-30-00Z",
|
||||
scenario="cost-arbitrage",
|
||||
vision_llm_config_id=-101,
|
||||
vision_provider_model="anthropic/claude-sonnet-4.5",
|
||||
native_arm_model="anthropic/claude-sonnet-4.5",
|
||||
)
|
||||
set_suite_state(config, "medical", state)
|
||||
|
||||
fetched = get_suite_state(config, "medical")
|
||||
assert fetched.scenario == "cost-arbitrage"
|
||||
assert fetched.vision_llm_config_id == -101
|
||||
assert fetched.vision_provider_model == "anthropic/claude-sonnet-4.5"
|
||||
assert fetched.native_arm_model == "anthropic/claude-sonnet-4.5"
|
||||
# Cost arbitrage's whole point: native arm slug != surfsense slug.
|
||||
assert fetched.effective_native_arm_model != fetched.provider_model
|
||||
assert fetched.effective_native_arm_model == "anthropic/claude-sonnet-4.5"
|
||||
|
||||
raw = json.loads(config.state_path.read_text(encoding="utf-8"))
|
||||
assert raw["suites"]["medical"]["scenario"] == "cost-arbitrage"
|
||||
|
||||
|
||||
def test_scenario_constants_are_stable():
|
||||
"""Pin the public scenario list; runners + tests key off these strings."""
|
||||
|
||||
assert SCENARIOS == ("head-to-head", "symmetric-cheap", "cost-arbitrage")
|
||||
assert DEFAULT_SCENARIO == "head-to-head"
|
||||
269
surfsense_evals/tests/core/test_ingest_settings.py
Normal file
269
surfsense_evals/tests/core/test_ingest_settings.py
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
"""Unit tests for ``surfsense_evals.core.ingest_settings``.
|
||||
|
||||
Covers:
|
||||
|
||||
* ``IngestSettings.merge`` honours operator overrides and falls back
|
||||
to per-benchmark defaults when the operator is silent.
|
||||
* ``add_ingest_settings_args`` exposes the three flag pairs and
|
||||
argparse defaults of ``None`` correctly distinguish "not passed"
|
||||
from "explicitly false".
|
||||
* ``settings_header_line`` / ``read_settings_header`` round-trip
|
||||
through a JSONL file.
|
||||
* ``read_settings_header`` is fault-tolerant: missing files, missing
|
||||
header, malformed JSON.
|
||||
* ``format_ingest_settings_md`` produces a stable Markdown bullet.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.core.ingest_settings import (
|
||||
PROCESSING_MODES,
|
||||
SETTINGS_HEADER_KEY,
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
is_settings_header,
|
||||
read_settings_header,
|
||||
settings_header_line,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# IngestSettings.merge
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMerge:
|
||||
def test_silent_operator_uses_defaults(self) -> None:
|
||||
defaults = IngestSettings(use_vision_llm=True, processing_mode="basic", should_summarize=True)
|
||||
merged = IngestSettings.merge(defaults, {})
|
||||
assert merged == defaults
|
||||
|
||||
def test_explicit_false_overrides_default_true(self) -> None:
|
||||
defaults = IngestSettings(use_vision_llm=True)
|
||||
merged = IngestSettings.merge(
|
||||
defaults, {"use_vision_llm": False}
|
||||
)
|
||||
assert merged.use_vision_llm is False
|
||||
|
||||
def test_explicit_true_overrides_default_false(self) -> None:
|
||||
defaults = IngestSettings(use_vision_llm=False)
|
||||
merged = IngestSettings.merge(
|
||||
defaults, {"use_vision_llm": True}
|
||||
)
|
||||
assert merged.use_vision_llm is True
|
||||
|
||||
def test_none_means_silent(self) -> None:
|
||||
# Argparse with BooleanOptionalAction yields None when the
|
||||
# operator passed neither --use-vision-llm nor --no-vision-llm.
|
||||
defaults = IngestSettings(use_vision_llm=True)
|
||||
merged = IngestSettings.merge(
|
||||
defaults, {"use_vision_llm": None}
|
||||
)
|
||||
assert merged.use_vision_llm is True
|
||||
|
||||
def test_processing_mode_override(self) -> None:
|
||||
defaults = IngestSettings(processing_mode="basic")
|
||||
merged = IngestSettings.merge(
|
||||
defaults, {"processing_mode": "premium"}
|
||||
)
|
||||
assert merged.processing_mode == "premium"
|
||||
|
||||
def test_processing_mode_invalid_raises(self) -> None:
|
||||
defaults = IngestSettings(processing_mode="basic")
|
||||
with pytest.raises(ValueError, match="Invalid processing_mode"):
|
||||
IngestSettings.merge(defaults, {"processing_mode": "exotic"})
|
||||
|
||||
def test_processing_mode_blank_falls_back(self) -> None:
|
||||
defaults = IngestSettings(processing_mode="basic")
|
||||
merged = IngestSettings.merge(defaults, {"processing_mode": ""})
|
||||
assert merged.processing_mode == "basic"
|
||||
|
||||
def test_string_truthy_coerced(self) -> None:
|
||||
defaults = IngestSettings(use_vision_llm=False)
|
||||
merged = IngestSettings.merge(defaults, {"use_vision_llm": "yes"})
|
||||
assert merged.use_vision_llm is True
|
||||
|
||||
def test_string_falsy_coerced(self) -> None:
|
||||
defaults = IngestSettings(use_vision_llm=True)
|
||||
merged = IngestSettings.merge(defaults, {"use_vision_llm": "false"})
|
||||
assert merged.use_vision_llm is False
|
||||
|
||||
def test_other_keys_ignored(self) -> None:
|
||||
# Benchmarks pass the whole opts dict; merge must tolerate
|
||||
# unrelated keys without crashing.
|
||||
defaults = IngestSettings(use_vision_llm=True, processing_mode="basic")
|
||||
merged = IngestSettings.merge(
|
||||
defaults,
|
||||
{
|
||||
"use_vision_llm": False,
|
||||
"concurrency": 4,
|
||||
"task_filter": "all",
|
||||
"no_mentions": True,
|
||||
},
|
||||
)
|
||||
assert merged.use_vision_llm is False
|
||||
assert merged.processing_mode == "basic"
|
||||
|
||||
def test_to_dict_round_trips(self) -> None:
|
||||
s = IngestSettings(use_vision_llm=True, processing_mode="premium", should_summarize=False)
|
||||
d = s.to_dict()
|
||||
assert d == {
|
||||
"use_vision_llm": True,
|
||||
"processing_mode": "premium",
|
||||
"should_summarize": False,
|
||||
}
|
||||
|
||||
def test_render_label_format(self) -> None:
|
||||
s = IngestSettings(use_vision_llm=True, processing_mode="premium", should_summarize=True)
|
||||
assert s.render_label() == "vision=on, mode=premium, summarize=on"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# add_ingest_settings_args
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAddArgs:
|
||||
@pytest.fixture
|
||||
def parser(self) -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser()
|
||||
add_ingest_settings_args(
|
||||
p,
|
||||
defaults=IngestSettings(
|
||||
use_vision_llm=False, processing_mode="basic", should_summarize=False
|
||||
),
|
||||
)
|
||||
return p
|
||||
|
||||
def test_silent_invocation_yields_none(self, parser: argparse.ArgumentParser) -> None:
|
||||
args = parser.parse_args([])
|
||||
assert args.use_vision_llm is None
|
||||
assert args.processing_mode is None
|
||||
assert args.should_summarize is None
|
||||
|
||||
def test_use_vision_llm_flag(self, parser: argparse.ArgumentParser) -> None:
|
||||
args = parser.parse_args(["--use-vision-llm"])
|
||||
assert args.use_vision_llm is True
|
||||
|
||||
def test_no_vision_llm_flag(self, parser: argparse.ArgumentParser) -> None:
|
||||
args = parser.parse_args(["--no-vision-llm"])
|
||||
assert args.use_vision_llm is False
|
||||
|
||||
def test_processing_mode_choices(self, parser: argparse.ArgumentParser) -> None:
|
||||
for mode in PROCESSING_MODES:
|
||||
args = parser.parse_args(["--processing-mode", mode])
|
||||
assert args.processing_mode == mode
|
||||
|
||||
def test_processing_mode_rejects_unknown(
|
||||
self, parser: argparse.ArgumentParser
|
||||
) -> None:
|
||||
with pytest.raises(SystemExit):
|
||||
parser.parse_args(["--processing-mode", "exotic"])
|
||||
|
||||
def test_summarize_flag_pair(self, parser: argparse.ArgumentParser) -> None:
|
||||
on = parser.parse_args(["--should-summarize"])
|
||||
assert on.should_summarize is True
|
||||
off = parser.parse_args(["--no-summarize"])
|
||||
assert off.should_summarize is False
|
||||
|
||||
def test_vision_flags_mutually_exclusive(
|
||||
self, parser: argparse.ArgumentParser
|
||||
) -> None:
|
||||
with pytest.raises(SystemExit):
|
||||
parser.parse_args(["--use-vision-llm", "--no-vision-llm"])
|
||||
|
||||
def test_full_pipeline(self, parser: argparse.ArgumentParser) -> None:
|
||||
# Operator passes flags + defaults are reasonable. Merge
|
||||
# should yield exactly what they asked for.
|
||||
args = parser.parse_args(
|
||||
["--use-vision-llm", "--processing-mode", "premium"]
|
||||
)
|
||||
defaults = IngestSettings(
|
||||
use_vision_llm=False, processing_mode="basic", should_summarize=False
|
||||
)
|
||||
merged = IngestSettings.merge(defaults, vars(args))
|
||||
assert merged == IngestSettings(
|
||||
use_vision_llm=True, processing_mode="premium", should_summarize=False
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Header round-trip + read_settings_header fault tolerance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeader:
|
||||
def test_header_line_round_trip(self, tmp_path: Path) -> None:
|
||||
s = IngestSettings(use_vision_llm=True, processing_mode="premium")
|
||||
path = tmp_path / "map.jsonl"
|
||||
with path.open("w", encoding="utf-8") as fh:
|
||||
fh.write(settings_header_line(s) + "\n")
|
||||
fh.write(json.dumps({"case_id": "x", "document_id": 1}) + "\n")
|
||||
loaded = read_settings_header(path)
|
||||
assert loaded == s.to_dict()
|
||||
|
||||
def test_is_settings_header_recognises(self) -> None:
|
||||
assert is_settings_header({SETTINGS_HEADER_KEY: {}})
|
||||
assert not is_settings_header({"case_id": "x"})
|
||||
|
||||
def test_missing_file_returns_empty(self, tmp_path: Path) -> None:
|
||||
assert read_settings_header(tmp_path / "does_not_exist.jsonl") == {}
|
||||
|
||||
def test_empty_file_returns_empty(self, tmp_path: Path) -> None:
|
||||
path = tmp_path / "empty.jsonl"
|
||||
path.write_text("", encoding="utf-8")
|
||||
assert read_settings_header(path) == {}
|
||||
|
||||
def test_no_header_returns_empty(self, tmp_path: Path) -> None:
|
||||
path = tmp_path / "legacy.jsonl"
|
||||
with path.open("w", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps({"case_id": "x", "document_id": 1}) + "\n")
|
||||
fh.write(json.dumps({"case_id": "y", "document_id": 2}) + "\n")
|
||||
assert read_settings_header(path) == {}
|
||||
|
||||
def test_malformed_json_returns_empty(self, tmp_path: Path) -> None:
|
||||
path = tmp_path / "broken.jsonl"
|
||||
path.write_text("not json\n", encoding="utf-8")
|
||||
assert read_settings_header(path) == {}
|
||||
|
||||
def test_skips_blank_first_lines(self, tmp_path: Path) -> None:
|
||||
s = IngestSettings(use_vision_llm=True)
|
||||
path = tmp_path / "padded.jsonl"
|
||||
with path.open("w", encoding="utf-8") as fh:
|
||||
fh.write("\n\n")
|
||||
fh.write(settings_header_line(s) + "\n")
|
||||
assert read_settings_header(path) == s.to_dict()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# format_ingest_settings_md
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatMd:
|
||||
def test_full_settings(self) -> None:
|
||||
out = format_ingest_settings_md(
|
||||
{"use_vision_llm": True, "processing_mode": "premium", "should_summarize": True}
|
||||
)
|
||||
assert "vision_llm=`on`" in out
|
||||
assert "processing_mode=`premium`" in out
|
||||
assert "summarize=`on`" in out
|
||||
|
||||
def test_default_off(self) -> None:
|
||||
out = format_ingest_settings_md(
|
||||
{"use_vision_llm": False, "processing_mode": "basic", "should_summarize": False}
|
||||
)
|
||||
assert "vision_llm=`off`" in out
|
||||
assert "processing_mode=`basic`" in out
|
||||
assert "summarize=`off`" in out
|
||||
|
||||
def test_missing_returns_re_ingest_hint(self) -> None:
|
||||
# Empty dict + None + non-mapping should all degrade gracefully.
|
||||
for raw in [None, {}, "not-a-mapping"]:
|
||||
assert "(not recorded" in format_ingest_settings_md(raw)
|
||||
153
surfsense_evals/tests/core/test_metrics.py
Normal file
153
surfsense_evals/tests/core/test_metrics.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
"""Metric correctness — Wilson, McNemar, retrieval scores."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.core.metrics import (
|
||||
accuracy_with_wilson_ci,
|
||||
bootstrap_delta_ci,
|
||||
mcnemar_test,
|
||||
mrr,
|
||||
ndcg_at_k,
|
||||
recall_at_k,
|
||||
score_run,
|
||||
wilson_ci,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wilson
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,n,low,high",
|
||||
[
|
||||
(80, 100, 0.7111, 0.8666), # cross-checked vs statsmodels.proportion_confint(method='wilson')
|
||||
(50, 100, 0.4038, 0.5962),
|
||||
(0, 0, 0.0, 1.0),
|
||||
(0, 10, 0.0, 0.2775),
|
||||
(10, 10, 0.7225, 1.0),
|
||||
],
|
||||
)
|
||||
def test_wilson_ci_known_values(k, n, low, high):
|
||||
result_low, result_high = wilson_ci(k, n)
|
||||
assert math.isclose(result_low, low, abs_tol=5e-4), (k, n, result_low, low)
|
||||
assert math.isclose(result_high, high, abs_tol=5e-4), (k, n, result_high, high)
|
||||
|
||||
|
||||
def test_accuracy_with_wilson_ci_object():
|
||||
res = accuracy_with_wilson_ci(70, 100)
|
||||
assert res.accuracy == 0.7
|
||||
assert 0.0 < res.ci_low < res.ci_high < 1.0
|
||||
|
||||
|
||||
def test_invalid_inputs_raise():
|
||||
with pytest.raises(ValueError):
|
||||
accuracy_with_wilson_ci(-1, 10)
|
||||
with pytest.raises(ValueError):
|
||||
accuracy_with_wilson_ci(11, 10)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# McNemar
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_mcnemar_degenerate_returns_p_value_one():
|
||||
a = [True, True, False, False]
|
||||
b = [True, True, False, False]
|
||||
res = mcnemar_test(a, b)
|
||||
assert res.b == 0 and res.c == 0
|
||||
assert res.p_value == 1.0
|
||||
assert res.method == "degenerate"
|
||||
|
||||
|
||||
def test_mcnemar_exact_branch_strong_signal():
|
||||
"""B = 0, C = 10 → exact two-sided binomial p == 2 * (1/2)**10."""
|
||||
|
||||
a = [True] * 10 + [False] * 10
|
||||
b = [True] * 10 + [True] * 10 # surfsense beats native on the 10 native-wrong
|
||||
res = mcnemar_test(a, b)
|
||||
assert res.b == 0
|
||||
assert res.c == 10
|
||||
assert res.method == "exact"
|
||||
expected = 2 * (0.5 ** 10)
|
||||
assert math.isclose(res.p_value, expected, rel_tol=1e-9)
|
||||
|
||||
|
||||
def test_mcnemar_chi_square_approx_for_large_discordant():
|
||||
# Construct b=15, c=5 with continuity-corrected chi^2 = (|10|-1)^2/20 = 4.05.
|
||||
a = [True] * 15 + [False] * 5 + [True] * 30 + [False] * 30
|
||||
b = [False] * 15 + [True] * 5 + [True] * 30 + [False] * 30
|
||||
res = mcnemar_test(a, b)
|
||||
assert res.method == "chi2_cc"
|
||||
assert res.b == 15 and res.c == 5
|
||||
assert math.isclose(res.statistic, ((abs(15 - 5) - 1) ** 2) / 20.0, rel_tol=1e-9)
|
||||
# p ≈ chi2.sf(4.05, df=1) ≈ 0.04417
|
||||
assert 0.04 < res.p_value < 0.05
|
||||
|
||||
|
||||
def test_mcnemar_length_mismatch():
|
||||
with pytest.raises(ValueError):
|
||||
mcnemar_test([True], [True, False])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bootstrap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bootstrap_delta_ci_shape_and_determinism():
|
||||
a = [True, True, False, True, False, False, True, True]
|
||||
b = [True, True, True, True, True, False, True, False]
|
||||
res1 = bootstrap_delta_ci(a, b, n_resamples=500, random_state=42)
|
||||
res2 = bootstrap_delta_ci(a, b, n_resamples=500, random_state=42)
|
||||
assert res1.delta == res2.delta
|
||||
assert res1.ci_low == res2.ci_low
|
||||
assert res1.ci_high == res2.ci_high
|
||||
assert res1.ci_low <= res1.delta <= res1.ci_high
|
||||
assert res1.n_resamples == 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Retrieval
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_recall_at_k():
|
||||
retrieved = ["a", "b", "c", "d"]
|
||||
relevant = ["b", "d", "z"]
|
||||
assert recall_at_k(retrieved, relevant, k=2) == pytest.approx(1 / 3)
|
||||
assert recall_at_k(retrieved, relevant, k=4) == pytest.approx(2 / 3)
|
||||
|
||||
|
||||
def test_mrr():
|
||||
assert mrr(["a", "b", "c"], ["c"]) == pytest.approx(1 / 3)
|
||||
assert mrr(["x", "y"], ["z"]) == 0.0
|
||||
|
||||
|
||||
def test_ndcg_at_k_perfect_order():
|
||||
qrels = {"a": 2, "b": 1}
|
||||
assert ndcg_at_k(["a", "b"], qrels, k=2) == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_ndcg_at_k_irrelevant_first():
|
||||
qrels = {"a": 2, "b": 1}
|
||||
# Wrong order should still be > 0 but < 1
|
||||
val = ndcg_at_k(["c", "a", "b"], qrels, k=3)
|
||||
assert 0 < val < 1
|
||||
|
||||
|
||||
def test_score_run_aggregates_across_queries():
|
||||
scores = score_run(
|
||||
per_query_retrieved={"q1": ["a", "b"], "q2": ["x", "y", "z"]},
|
||||
per_query_qrels={"q1": {"a": 1}, "q2": {"z": 2}},
|
||||
ks=(1, 5),
|
||||
ndcg_k=5,
|
||||
)
|
||||
assert scores.n_queries == 2
|
||||
assert scores.recall_at_k[1] == pytest.approx((1 + 0) / 2) # q1 hits @1, q2 doesn't
|
||||
assert scores.mrr == pytest.approx((1.0 + 1 / 3) / 2)
|
||||
27
surfsense_evals/tests/core/test_parse_answer_letter.py
Normal file
27
surfsense_evals/tests/core/test_parse_answer_letter.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Tests for the MCQ answer-letter extractor."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.core.parse import extract_answer_letter
|
||||
from surfsense_evals.core.parse.answer_letter import AnswerLetterResult
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text,expected_letter,expected_strategy",
|
||||
[
|
||||
('```json\n{"step_by_step_thinking": "...", "answer_choice": "B"}\n```', "B", "json_envelope"),
|
||||
('Reasoning... {"step_by_step_thinking": "x", "answer_choice": "C"}', "C", "json_envelope"),
|
||||
("Long reasoning.\nAnswer: D", "D", "answer_line"),
|
||||
("The correct answer is (A).", "A", "answer_line"),
|
||||
("Final answer: e", "E", "answer_line"),
|
||||
("Long reasoning.\n\nB", "B", "bare_letter"),
|
||||
("Long reasoning.\n\n(C).", "C", "bare_letter"),
|
||||
("", None, "none"),
|
||||
("Just narrative without an answer.", None, "none"),
|
||||
],
|
||||
)
|
||||
def test_extract_answer_letter(text, expected_letter, expected_strategy):
|
||||
result = extract_answer_letter(text)
|
||||
assert result == AnswerLetterResult(expected_letter, expected_strategy)
|
||||
108
surfsense_evals/tests/core/test_parse_citations.py
Normal file
108
surfsense_evals/tests/core/test_parse_citations.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
"""Parity tests for the citation regex.
|
||||
|
||||
Each row mirrors a case from the canonical TS reference at
|
||||
``surfsense_web/lib/citations/citation-parser.ts``. If a future PR
|
||||
loosens or tightens the TS regex, these tests will start failing;
|
||||
that's the explicit signal to re-port the change.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.core.parse import (
|
||||
CITATION_REGEX,
|
||||
ChunkCitation,
|
||||
UrlCitation,
|
||||
parse_citations,
|
||||
)
|
||||
|
||||
PARITY_TABLE = [
|
||||
# (input, expected number of matches, expected first-token kind/value)
|
||||
("Plain text with no citation.", 0, None),
|
||||
(
|
||||
"The patient has fever [citation:42] and cough.",
|
||||
1,
|
||||
ChunkCitation(chunk_id=42, is_docs_chunk=False),
|
||||
),
|
||||
(
|
||||
"Negative chunk ids work [citation:-7].",
|
||||
1,
|
||||
ChunkCitation(chunk_id=-7, is_docs_chunk=False),
|
||||
),
|
||||
(
|
||||
"doc-prefix [citation:doc-12].",
|
||||
1,
|
||||
ChunkCitation(chunk_id=12, is_docs_chunk=True),
|
||||
),
|
||||
(
|
||||
"Multi id [citation:1, doc-2, -3].",
|
||||
3,
|
||||
ChunkCitation(chunk_id=1, is_docs_chunk=False),
|
||||
),
|
||||
(
|
||||
"URL form [citation:https://x.com/a].",
|
||||
1,
|
||||
UrlCitation(url="https://x.com/a"),
|
||||
),
|
||||
(
|
||||
"Chinese brackets【citation:5】.",
|
||||
1,
|
||||
ChunkCitation(chunk_id=5, is_docs_chunk=False),
|
||||
),
|
||||
(
|
||||
"ZWSP-decorated [\u200bcitation:9\u200b].",
|
||||
1,
|
||||
ChunkCitation(chunk_id=9, is_docs_chunk=False),
|
||||
),
|
||||
(
|
||||
"Whitespace [citation: doc-100 ] tolerated.",
|
||||
1,
|
||||
ChunkCitation(chunk_id=100, is_docs_chunk=True),
|
||||
),
|
||||
(
|
||||
# The TS regex's URL char class excludes ']', so a trailing
|
||||
# bracket isn't swallowed.
|
||||
"Two URLs [citation:https://a.io] and [citation:https://b.io].",
|
||||
2,
|
||||
UrlCitation(url="https://a.io"),
|
||||
),
|
||||
(
|
||||
# Garbled form should match nothing.
|
||||
"Citation-like but wrong [citation:].",
|
||||
0,
|
||||
None,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text,n_expected,first", PARITY_TABLE)
|
||||
def test_citation_regex_parity(text: str, n_expected: int, first):
|
||||
tokens = parse_citations(text)
|
||||
assert len(tokens) == n_expected, (text, tokens)
|
||||
if first is not None:
|
||||
assert tokens[0] == first, (text, tokens)
|
||||
|
||||
|
||||
def test_regex_pattern_matches_ts_source():
|
||||
"""Sanity: the compiled pattern carries the exact alternatives the TS source does."""
|
||||
|
||||
pattern = CITATION_REGEX.pattern
|
||||
assert "https?://" in pattern
|
||||
assert "urlcite" in pattern
|
||||
assert "doc-" in pattern
|
||||
assert "\u200B" in pattern
|
||||
assert "【" in pattern and "】" in pattern
|
||||
|
||||
|
||||
def test_url_map_resolution():
|
||||
text = "Inline placeholder [citation:urlcite0]."
|
||||
tokens = parse_citations(text, url_map={"urlcite0": "https://resolved.example/x"})
|
||||
assert tokens == [UrlCitation(url="https://resolved.example/x")]
|
||||
|
||||
|
||||
def test_url_map_missing_key_drops_token():
|
||||
"""Missing urlcite resolution returns no token (TS behaviour)."""
|
||||
|
||||
text = "[citation:urlcite99]"
|
||||
assert parse_citations(text, url_map={}) == []
|
||||
73
surfsense_evals/tests/core/test_parse_freeform_answer.py
Normal file
73
surfsense_evals/tests/core/test_parse_freeform_answer.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""Tests for ``surfsense_evals.core.parse.freeform_answer``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.core.parse.freeform_answer import extract_freeform_answer
|
||||
|
||||
|
||||
class TestExtractFreeformAnswer:
|
||||
def test_empty_string_returns_empty(self) -> None:
|
||||
assert extract_freeform_answer("") == ""
|
||||
assert extract_freeform_answer(" \n\n ") == ""
|
||||
|
||||
def test_simple_answer_marker(self) -> None:
|
||||
assert extract_freeform_answer("Answer: 42") == "42"
|
||||
|
||||
def test_final_answer_marker(self) -> None:
|
||||
assert extract_freeform_answer("Final answer: Paris") == "Paris"
|
||||
|
||||
def test_the_answer_is_marker(self) -> None:
|
||||
assert extract_freeform_answer("The answer is: not answerable") == "not answerable"
|
||||
|
||||
def test_multiline_picks_last_answer_marker(self) -> None:
|
||||
text = "Let me think...\nAnswer: 5\nAnswer: 7\n"
|
||||
assert extract_freeform_answer(text) == "7"
|
||||
|
||||
def test_falls_back_to_last_nonempty_line(self) -> None:
|
||||
text = "Some thinking here.\n\n42"
|
||||
assert extract_freeform_answer(text) == "42"
|
||||
|
||||
def test_strips_quotes(self) -> None:
|
||||
assert extract_freeform_answer('Answer: "Paris"') == "Paris"
|
||||
assert extract_freeform_answer("Answer: 'Paris'") == "Paris"
|
||||
|
||||
def test_strips_backticks(self) -> None:
|
||||
assert extract_freeform_answer("Answer: `42`") == "42"
|
||||
|
||||
def test_uses_fenced_block_when_no_marker(self) -> None:
|
||||
text = "Here's my response:\n```\nfinal value\n```\n"
|
||||
assert extract_freeform_answer(text) == "final value"
|
||||
|
||||
def test_case_insensitive_markers(self) -> None:
|
||||
assert extract_freeform_answer("ANSWER: yes") == "yes"
|
||||
assert extract_freeform_answer("answer: no") == "no"
|
||||
|
||||
@pytest.mark.parametrize("text,expected", [
|
||||
("Answer: 1, 2, 3", "1, 2, 3"),
|
||||
("Answer: 3.14", "3.14"),
|
||||
("Answer: spaced ", "spaced"),
|
||||
])
|
||||
def test_various_payloads(self, text: str, expected: str) -> None:
|
||||
assert extract_freeform_answer(text) == expected
|
||||
|
||||
def test_inline_answer_after_thinking_trace(self) -> None:
|
||||
# Agent replies sometimes glue their thinking onto the same
|
||||
# line as the final "Answer: ..." marker (no newline before it).
|
||||
# The line-anchored regex misses this; the inline fallback
|
||||
# should still extract the right value.
|
||||
text = (
|
||||
"Need the Charlotte Bronte book title/year and the rank "
|
||||
"for a 128-foot NYC building.Answer: 128th"
|
||||
)
|
||||
assert extract_freeform_answer(text) == "128th"
|
||||
|
||||
def test_inline_picks_last_inline_answer(self) -> None:
|
||||
text = "Thought: maybe Answer: 5 is right? Actually Answer: 7."
|
||||
assert extract_freeform_answer(text) == "7."
|
||||
|
||||
def test_inline_does_not_override_proper_marker(self) -> None:
|
||||
# When a clean line-anchored "Answer: ..." exists, that wins.
|
||||
text = "Some preamble.Answer: 99\nAnswer: 42"
|
||||
assert extract_freeform_answer(text) == "42"
|
||||
84
surfsense_evals/tests/core/test_parse_sse.py
Normal file
84
surfsense_evals/tests/core/test_parse_sse.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
"""Tests for the SSE consumer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.core.parse import iter_sse_events
|
||||
|
||||
|
||||
async def _alist(it):
|
||||
out = []
|
||||
async for x in it:
|
||||
out.append(x)
|
||||
return out
|
||||
|
||||
|
||||
async def _astream(lines):
|
||||
for line in lines:
|
||||
yield line
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_data_frame():
|
||||
events = await _alist(
|
||||
iter_sse_events(_astream([
|
||||
'data: {"type": "text-delta", "delta": "hi"}',
|
||||
"",
|
||||
'data: {"type": "finish"}',
|
||||
"",
|
||||
]))
|
||||
)
|
||||
assert [e.data for e in events] == [
|
||||
'{"type": "text-delta", "delta": "hi"}',
|
||||
'{"type": "finish"}',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_done_sentinel_passes_through():
|
||||
events = await _alist(
|
||||
iter_sse_events(_astream([
|
||||
"data: [DONE]",
|
||||
"",
|
||||
]))
|
||||
)
|
||||
assert [e.data for e in events] == ["[DONE]"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiline_data_joins_with_newline():
|
||||
events = await _alist(
|
||||
iter_sse_events(_astream([
|
||||
"data: line1",
|
||||
"data: line2",
|
||||
"",
|
||||
]))
|
||||
)
|
||||
assert events[0].data == "line1\nline2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_comments_and_other_fields_ignored():
|
||||
events = await _alist(
|
||||
iter_sse_events(_astream([
|
||||
": heartbeat",
|
||||
"event: foo",
|
||||
"id: 123",
|
||||
"data: payload",
|
||||
"",
|
||||
]))
|
||||
)
|
||||
assert [e.data for e in events] == ["payload"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_missing_trailing_blank():
|
||||
"""Some servers omit the final blank line; the consumer should still emit."""
|
||||
|
||||
events = await _alist(
|
||||
iter_sse_events(_astream([
|
||||
"data: only-one",
|
||||
]))
|
||||
)
|
||||
assert [e.data for e in events] == ["only-one"]
|
||||
51
surfsense_evals/tests/core/test_pdf_render.py
Normal file
51
surfsense_evals/tests/core/test_pdf_render.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
"""Smoke tests for PDF rendering.
|
||||
|
||||
We don't pull a full PDF parser into the test deps; the assertions
|
||||
are bytes-level (``%PDF`` magic, deterministic CreationDate scrub).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from surfsense_evals.core.pdf import render_pdf, render_text_files_to_pdf
|
||||
|
||||
|
||||
def test_render_pdf_writes_pdf_with_magic(tmp_path: Path):
|
||||
out = tmp_path / "out.pdf"
|
||||
rendered = render_pdf(
|
||||
title="Test",
|
||||
sections=[("intro", "Hello world."), ("body", "Line one.\nLine two.")],
|
||||
output_path=out,
|
||||
)
|
||||
assert rendered.path == out
|
||||
assert out.exists()
|
||||
assert out.read_bytes().startswith(b"%PDF-")
|
||||
|
||||
|
||||
def test_render_pdf_deterministic_dates(tmp_path: Path):
|
||||
out_a = tmp_path / "a.pdf"
|
||||
out_b = tmp_path / "b.pdf"
|
||||
sections = [("only", "deterministic body content")]
|
||||
render_pdf(title="Det", sections=sections, output_path=out_a)
|
||||
render_pdf(title="Det", sections=sections, output_path=out_b)
|
||||
# CreationDate / ModDate are scrubbed to a fixed value, so the two
|
||||
# files should compare equal (modulo any other internal randomness
|
||||
# — reportlab's basic outputs are deterministic given fixed inputs).
|
||||
assert out_a.read_bytes() == out_b.read_bytes()
|
||||
|
||||
|
||||
def test_render_text_files_uses_filename_as_section(tmp_path: Path):
|
||||
files_dir = tmp_path / "src"
|
||||
files_dir.mkdir()
|
||||
(files_dir / "admission_note.txt").write_text("history of present illness", encoding="utf-8")
|
||||
(files_dir / "labs.txt").write_text("Na 138, K 4.0", encoding="utf-8")
|
||||
out = tmp_path / "case.pdf"
|
||||
rendered = render_text_files_to_pdf(
|
||||
title="Case 1",
|
||||
files=[files_dir / "admission_note.txt", files_dir / "labs.txt"],
|
||||
output_path=out,
|
||||
)
|
||||
assert out.exists()
|
||||
# We don't decode the PDF; the n_chars estimate should reflect both inputs.
|
||||
assert rendered.n_chars >= len("history of present illness") + len("Na 138, K 4.0")
|
||||
73
surfsense_evals/tests/core/test_pdf_render_with_images.py
Normal file
73
surfsense_evals/tests/core/test_pdf_render_with_images.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""Tests for ``render_pdf_with_images`` — covers image embedding +
|
||||
deterministic byte output, mirroring ``test_pdf_render.py`` for the
|
||||
text-only path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.core.pdf import PdfImage, render_pdf_with_images
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tiny_png(tmp_path: Path) -> Path:
|
||||
"""Generate a real 4x4 PNG via Pillow — embeds cleanly in reportlab.
|
||||
|
||||
Hand-crafted PNG headers tend to fail PIL's strict decoder, so we
|
||||
delegate to Pillow which is already a transitive dep of reportlab.
|
||||
"""
|
||||
|
||||
from PIL import Image as PILImage
|
||||
|
||||
p = tmp_path / "pixel.png"
|
||||
PILImage.new("RGB", (4, 4), color=(128, 128, 128)).save(p, format="PNG")
|
||||
return p
|
||||
|
||||
|
||||
class TestRenderPdfWithImages:
|
||||
def test_renders_pdf_with_no_images(self, tmp_path: Path) -> None:
|
||||
out = tmp_path / "out.pdf"
|
||||
rendered = render_pdf_with_images(
|
||||
title="Test",
|
||||
sections=[("Heading", "Body text here.", None)],
|
||||
output_path=out,
|
||||
)
|
||||
assert rendered.path == out
|
||||
assert out.exists()
|
||||
assert out.read_bytes().startswith(b"%PDF-")
|
||||
|
||||
def test_renders_pdf_with_one_image(self, tmp_path: Path, tiny_png: Path) -> None:
|
||||
out = tmp_path / "out.pdf"
|
||||
render_pdf_with_images(
|
||||
title="Test",
|
||||
sections=[("Case", "Body text.", [PdfImage(path=tiny_png, caption="A pixel")])],
|
||||
output_path=out,
|
||||
)
|
||||
assert out.exists()
|
||||
assert out.stat().st_size > 200 # not empty
|
||||
|
||||
def test_deterministic_bytes(self, tmp_path: Path, tiny_png: Path) -> None:
|
||||
out_a = tmp_path / "a.pdf"
|
||||
out_b = tmp_path / "b.pdf"
|
||||
sections = [
|
||||
("Case", "Some text.", [PdfImage(path=tiny_png, caption="cap")]),
|
||||
("Options", "A) one\nB) two", None),
|
||||
]
|
||||
render_pdf_with_images(title="Test", sections=sections, output_path=out_a)
|
||||
render_pdf_with_images(title="Test", sections=sections, output_path=out_b)
|
||||
assert out_a.read_bytes() == out_b.read_bytes()
|
||||
|
||||
def test_skips_invalid_image_silently(self, tmp_path: Path) -> None:
|
||||
"""A bad image path should not abort the whole PDF render."""
|
||||
|
||||
out = tmp_path / "out.pdf"
|
||||
render_pdf_with_images(
|
||||
title="Test",
|
||||
sections=[("Case", "Text", [PdfImage(path=tmp_path / "nope.jpg", caption="x")])],
|
||||
output_path=out,
|
||||
)
|
||||
assert out.exists()
|
||||
assert out.read_bytes().startswith(b"%PDF-")
|
||||
121
surfsense_evals/tests/core/test_provider_openrouter.py
Normal file
121
surfsense_evals/tests/core/test_provider_openrouter.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""respx-mocked tests for the OpenRouter PDF provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from surfsense_evals.core.providers.openrouter_pdf import (
|
||||
OpenRouterPdfProvider,
|
||||
PdfEngine,
|
||||
)
|
||||
|
||||
_BASE = "https://openrouter.test"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tiny_pdf(tmp_path: Path) -> Path:
|
||||
p = tmp_path / "case.pdf"
|
||||
p.write_bytes(b"%PDF-1.4 minimal content")
|
||||
return p
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_payload_shape_matches_openrouter_docs(respx_mock, tiny_pdf: Path):
|
||||
captured = {}
|
||||
|
||||
def _capture(request):
|
||||
captured["body"] = json.loads(request.content)
|
||||
captured["headers"] = dict(request.headers)
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"choices": [{
|
||||
"message": {"content": "Answer: B"},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15, "cost": 0.0001},
|
||||
},
|
||||
)
|
||||
|
||||
respx_mock.post("/chat/completions").mock(side_effect=_capture)
|
||||
|
||||
provider = OpenRouterPdfProvider(
|
||||
api_key="sk-or-test",
|
||||
base_url=_BASE,
|
||||
model="anthropic/claude-sonnet-4.5",
|
||||
engine=PdfEngine.NATIVE,
|
||||
)
|
||||
response = await provider.complete(prompt="What is the diagnosis?", pdf_path=tiny_pdf)
|
||||
body = captured["body"]
|
||||
assert body["model"] == "anthropic/claude-sonnet-4.5"
|
||||
assert body["plugins"] == [{"id": "file-parser", "pdf": {"engine": "native"}}]
|
||||
user = body["messages"][-1]
|
||||
assert user["role"] == "user"
|
||||
file_part = user["content"][0]
|
||||
assert file_part["type"] == "file"
|
||||
assert file_part["file"]["filename"] == tiny_pdf.name
|
||||
assert file_part["file"]["file_data"].startswith("data:application/pdf;base64,")
|
||||
assert (
|
||||
base64.b64decode(file_part["file"]["file_data"].split(",", 1)[1])
|
||||
== tiny_pdf.read_bytes() # noqa: ASYNC240 — test fixture, sync read is fine
|
||||
)
|
||||
assert user["content"][1] == {"type": "text", "text": "What is the diagnosis?"}
|
||||
assert captured["headers"]["authorization"] == "Bearer sk-or-test"
|
||||
assert captured["headers"].get("x-title") == "SurfSense-evals"
|
||||
|
||||
assert response.text == "Answer: B"
|
||||
assert response.input_tokens == 10
|
||||
assert response.output_tokens == 5
|
||||
assert response.total_tokens == 15
|
||||
# cost 0.0001 USD == 100 micros
|
||||
assert response.cost_micros == 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_chat_array_content_concatenates(respx_mock, tiny_pdf: Path):
|
||||
respx_mock.post("/chat/completions").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello "},
|
||||
{"type": "text", "text": "world"},
|
||||
{"type": "image_url", "image_url": "ignored"},
|
||||
]
|
||||
}
|
||||
}],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1},
|
||||
},
|
||||
)
|
||||
)
|
||||
provider = OpenRouterPdfProvider(
|
||||
api_key="sk-or-test", base_url=_BASE, model="x/y"
|
||||
)
|
||||
response = await provider.complete(prompt="hi", pdf_path=tiny_pdf)
|
||||
assert response.text == "Hello world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(base_url=_BASE)
|
||||
async def test_provider_raises_on_4xx(respx_mock, tiny_pdf: Path):
|
||||
respx_mock.post("/chat/completions").mock(
|
||||
return_value=httpx.Response(429, json={"error": {"message": "rate limited"}})
|
||||
)
|
||||
provider = OpenRouterPdfProvider(api_key="sk-or-test", base_url=_BASE, model="x/y")
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await provider.complete(prompt="hi", pdf_path=tiny_pdf)
|
||||
|
||||
|
||||
def test_missing_api_key_raises():
|
||||
with pytest.raises(ValueError):
|
||||
OpenRouterPdfProvider(api_key="", base_url=_BASE, model="x/y")
|
||||
58
surfsense_evals/tests/core/test_registry.py
Normal file
58
surfsense_evals/tests/core/test_registry.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Registry + auto-discovery tests.
|
||||
|
||||
* Auto-discovery skips packages starting with ``_`` (so test fixtures
|
||||
don't leak into the production catalogue).
|
||||
* Manually importing a ``_demo`` benchmark fires its ``register(...)``
|
||||
call and the CLI sees it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
|
||||
from surfsense_evals.core import registry
|
||||
|
||||
|
||||
def _force_register_demo() -> None:
|
||||
"""Import (or reload) the demo module so its ``register(...)`` runs.
|
||||
|
||||
On a fresh interpreter, ``import_module`` triggers package
|
||||
initialization. After the first call though, the module is cached
|
||||
in ``sys.modules`` and a second ``import_module`` is a no-op — so
|
||||
if a previous test already unregistered the entry, we have to
|
||||
``reload`` to re-execute the module body.
|
||||
"""
|
||||
|
||||
module = importlib.import_module("surfsense_evals.suites._demo.hello")
|
||||
if ("_demo", "hello") not in registry.snapshot():
|
||||
importlib.reload(module)
|
||||
|
||||
|
||||
def test_auto_discovery_skips_underscore_prefixed_subpackages():
|
||||
from surfsense_evals.suites import discover_suites
|
||||
|
||||
discovered = discover_suites()
|
||||
assert all(not part.startswith("_") for full in discovered for part in full.split("."))
|
||||
# The medical suite's headline benchmark must always discover.
|
||||
assert any(name.endswith(".medical.medxpertqa") for name in discovered)
|
||||
|
||||
|
||||
def test_demo_benchmark_registers_on_explicit_import():
|
||||
_force_register_demo()
|
||||
bench = registry.get("_demo", "hello")
|
||||
assert bench is not None
|
||||
assert bench.name == "hello"
|
||||
assert bench.headline is False
|
||||
# Cleanup so the test is idempotent under repeated runs.
|
||||
registry.unregister("_demo", "hello")
|
||||
|
||||
|
||||
def test_register_unregister_roundtrip():
|
||||
# Make sure no stale entry from a prior test in the session.
|
||||
if ("_demo", "hello") in registry.snapshot():
|
||||
registry.unregister("_demo", "hello")
|
||||
snapshot_before = dict(registry.snapshot())
|
||||
_force_register_demo()
|
||||
assert ("_demo", "hello") in registry.snapshot()
|
||||
registry.unregister("_demo", "hello")
|
||||
assert dict(registry.snapshot()) == snapshot_before
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue