mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-27 19:25:15 +02:00
chore: evals
This commit is contained in:
parent
2402b730fa
commit
3737118050
122 changed files with 22598 additions and 13 deletions
10
surfsense_evals/src/surfsense_evals/__init__.py
Normal file
10
surfsense_evals/src/surfsense_evals/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""SurfSense Evals — domain-agnostic eval harness.
|
||||
|
||||
Public entry-point is the ``surfsense_evals`` CLI (``python -m surfsense_evals``).
|
||||
Programmatic embedding is a non-goal for now; everything goes through the CLI
|
||||
+ filesystem outputs (state.json, raw run JSONL, summary.md/json reports).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__version__ = "0.1.0"
|
||||
13
surfsense_evals/src/surfsense_evals/__main__.py
Normal file
13
surfsense_evals/src/surfsense_evals/__main__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
"""Module entry point: ``python -m surfsense_evals ...``.
|
||||
|
||||
Delegates to ``core.cli.main``. ``core.cli`` lazily imports
|
||||
``surfsense_evals.suites`` so every benchmark gets a chance to register
|
||||
before argparse builds its subcommand groups.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from surfsense_evals.core.cli import main
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
raise SystemExit(main())
|
||||
8
surfsense_evals/src/surfsense_evals/core/__init__.py
Normal file
8
surfsense_evals/src/surfsense_evals/core/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
"""Domain-agnostic infrastructure shared by every suite.
|
||||
|
||||
Nothing under ``core/`` knows or cares about a specific evaluation domain.
|
||||
Suites live under ``surfsense_evals.suites.<domain>.<benchmark>`` and
|
||||
register themselves with ``core.registry`` on import.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
44
surfsense_evals/src/surfsense_evals/core/arms/__init__.py
Normal file
44
surfsense_evals/src/surfsense_evals/core/arms/__init__.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Arm protocol + concrete arms shared across suites.
|
||||
|
||||
Concrete arms (``NativePdfArm``, ``SurfSenseArm``, ``BareLlmArm``) are
|
||||
imported lazily via ``__getattr__`` so consumers that only need the
|
||||
protocol — e.g. the registry's ``Arm`` re-export — don't transitively
|
||||
pull in ``httpx`` providers or the SurfSense client unless they
|
||||
actually use those arms.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import Arm, ArmRequest, ArmResult
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .bare_llm import BareLlmArm
|
||||
from .native_pdf import NativePdfArm
|
||||
from .surfsense import SurfSenseArm
|
||||
|
||||
__all__ = [
|
||||
"Arm",
|
||||
"ArmRequest",
|
||||
"ArmResult",
|
||||
"BareLlmArm",
|
||||
"NativePdfArm",
|
||||
"SurfSenseArm",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str):  # PEP 562
    """Resolve a concrete arm class the first time it is requested.

    Each concrete arm lives in its own submodule, which is imported only
    when the matching attribute name is looked up on this package. Any
    other name raises ``AttributeError``.
    """
    if name == "NativePdfArm":
        from .native_pdf import NativePdfArm as resolved
    elif name == "SurfSenseArm":
        from .surfsense import SurfSenseArm as resolved
    elif name == "BareLlmArm":
        from .bare_llm import BareLlmArm as resolved
    else:
        raise AttributeError(f"module 'surfsense_evals.core.arms' has no attribute {name!r}")
    return resolved
|
||||
100
surfsense_evals/src/surfsense_evals/core/arms/bare_llm.py
Normal file
100
surfsense_evals/src/surfsense_evals/core/arms/bare_llm.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
"""Bare-LLM arm: chat completion with prompt-only input, no retrieval.
|
||||
|
||||
Pairs with ``SurfSenseArm`` for any benchmark that wants to measure
|
||||
"how much does the model already know without RAG?". For factuality /
|
||||
multi-hop benchmarks (FRAMES, MuSiQue, …) this produces the published
|
||||
"naive prompting" baseline — e.g. FRAMES's 40.8% on Gemini-Pro-1.5.
|
||||
|
||||
Symmetric with ``NativePdfArm`` in shape, but the request carries no
|
||||
``pdf_paths``: the prompt itself is the only input the model gets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ..providers.openrouter_chat import OpenRouterChatProvider
|
||||
from .base import Arm, ArmRequest, ArmResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BareLlmArm(Arm):
    """Prompt-only arm that forwards each request to an ``OpenRouterChatProvider``.

    The class-level ``name`` is ``"bare_llm"``; passing a non-empty
    ``name`` to the constructor overrides it on that instance only, so a
    suite can run several chat-backed arms side by side and still tell
    their results apart by ``ArmResult.arm``.
    """

    name: str = "bare_llm"

    def __init__(
        self,
        *,
        provider: OpenRouterChatProvider,
        max_output_tokens: int | None = 1024,
        system_prompt: str | None = None,
        name: str | None = None,
    ) -> None:
        # Only a truthy name shadows the class attribute; ``None`` / ""
        # keep the default.
        if name:
            self.name = name
        self._system_prompt = system_prompt
        self._max_output = max_output_tokens
        self._provider = provider

    @classmethod
    def from_env(
        cls,
        *,
        api_key: str,
        model: str,
        base_url: str = "https://openrouter.ai/api/v1",
        max_output_tokens: int | None = 1024,
        system_prompt: str | None = None,
        name: str | None = None,
    ) -> BareLlmArm:
        """Build the provider from raw settings and wrap it in an arm."""
        return cls(
            provider=OpenRouterChatProvider(
                api_key=api_key,
                base_url=base_url,
                model=model,
            ),
            max_output_tokens=max_output_tokens,
            system_prompt=system_prompt,
            name=name,
        )

    async def answer(self, request: ArmRequest) -> ArmResult:
        """Send ``request.prompt`` to the provider and wrap the reply.

        Any exception raised by the provider call is captured and
        reported through ``ArmResult.error`` (with empty ``raw_text``)
        instead of propagating to the runner.
        """
        qid = request.question_id
        try:
            completion = await self._provider.complete(
                prompt=request.prompt,
                system_prompt=self._system_prompt,
                max_tokens=self._max_output,
            )
        except Exception as exc:  # noqa: BLE001
            failure = f"{type(exc).__name__}: {exc}"
            return ArmResult(arm=self.name, question_id=qid, raw_text="", error=failure)
        return ArmResult(
            arm=self.name,
            question_id=qid,
            raw_text=completion.text,
            input_tokens=completion.input_tokens,
            output_tokens=completion.output_tokens,
            cost_micros=completion.cost_micros,
            latency_ms=completion.latency_ms,
            extra={
                "model": self._provider.model,
                "finish_reason": completion.finish_reason,
            },
        )
|
||||
|
||||
|
||||
__all__ = ["BareLlmArm"]
|
||||
93
surfsense_evals/src/surfsense_evals/core/arms/base.py
Normal file
93
surfsense_evals/src/surfsense_evals/core/arms/base.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""Arm protocol + the value types every arm exchanges with a runner.
|
||||
|
||||
An ``Arm`` is "one way to answer one question". Two of the concrete arms are:
|
||||
|
||||
* ``NativePdfArm`` — drop the PDF straight into an OpenRouter
|
||||
chat-completions request with ``plugins=[{file-parser, engine:
|
||||
native}]``. Used for the head-to-head "is the model good enough on
|
||||
its own?" measurement.
|
||||
* ``SurfSenseArm`` — POST ``/api/v1/new_chat`` with the question
|
||||
scoped to the relevant ``mentioned_document_ids``; consume the SSE
|
||||
stream and parse citations.
|
||||
|
||||
Both implement the same protocol so a benchmark runner only sees
|
||||
``Arm.answer(request) -> ArmResult``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
@dataclass
class ArmRequest:
    """Input handed to a single ``Arm.answer`` call.

    Fields:

    * ``question_id`` — opaque identifier, used for logging and for
      joining results back to questions.
    * ``prompt`` — the fully formatted text the arm sends. Prompt
      construction is the runner's job so that arms compared
      head-to-head receive byte-identical text.
    * ``pdf_paths`` — per-question source PDFs (consumed by
      ``NativePdfArm``); empty for retrieval-only / corpus-wide
      benchmarks.
    * ``mentioned_document_ids`` — SurfSense document scoping list
      (consumed by ``SurfSenseArm``); ``None`` means retrieval spans the
      whole search space.
    * ``options`` — free-form bag of arm-specific overrides, e.g.
      SurfSense's ``disabled_tools``.
    """

    question_id: str
    prompt: str
    pdf_paths: list[Path] = field(default_factory=list)
    mentioned_document_ids: list[int] | None = None
    options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class ArmResult:
    """What one ``Arm.answer`` invocation produced.

    A result counts as successful exactly when ``error`` is ``None``;
    the numeric usage fields default to ``0`` and the container fields
    default to fresh empty containers.
    """

    arm: str
    question_id: str
    raw_text: str
    answer_letter: str | None = None
    citations: list[dict[str, Any]] = field(default_factory=list)
    input_tokens: int = 0
    output_tokens: int = 0
    cost_micros: int = 0
    latency_ms: int = 0
    error: str | None = None
    extra: dict[str, Any] = field(default_factory=dict)

    @property
    def ok(self) -> bool:
        """``True`` when no error was recorded."""
        if self.error is None:
            return True
        return False

    def to_jsonl(self) -> dict[str, Any]:
        """Stable dict shape for ``data/<suite>/runs/<ts>/<bench>_raw.jsonl``.

        Key order is fixed (note ``answer_letter`` precedes ``raw_text``)
        and values are the attribute objects themselves, not copies.
        """
        ordered_keys = (
            "arm",
            "question_id",
            "answer_letter",
            "raw_text",
            "citations",
            "input_tokens",
            "output_tokens",
            "cost_micros",
            "latency_ms",
            "error",
            "extra",
        )
        return {key: getattr(self, key) for key in ordered_keys}
|
||||
|
||||
|
||||
class Arm(Protocol):
    """Structural interface for one concrete way of answering questions.

    An implementation exposes a ``name`` string and an async ``answer``
    that maps an ``ArmRequest`` to an ``ArmResult``.
    """

    name: str

    async def answer(self, request: ArmRequest) -> ArmResult:  # pragma: no cover - protocol
        ...
|
||||
104
surfsense_evals/src/surfsense_evals/core/arms/native_pdf.py
Normal file
104
surfsense_evals/src/surfsense_evals/core/arms/native_pdf.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Native-PDF arm: drop the PDF straight into OpenRouter chat-completions.
|
||||
|
||||
Generic across suites — a benchmark just supplies the prompt and the
|
||||
single PDF path. Multi-PDF questions concatenate in the runner before
|
||||
calling this arm so each ``answer`` invocation feeds the model exactly
|
||||
one ``data:application/pdf;base64,...`` block (matches the human
|
||||
"drag-and-drop one PDF into Claude" intent).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ..parse.answer_letter import extract_answer_letter
|
||||
from ..providers.openrouter_pdf import OpenRouterPdfProvider, PdfEngine
|
||||
from .base import Arm, ArmRequest, ArmResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NativePdfArm(Arm):
    """Arm that sends one PDF plus the prompt through ``OpenRouterPdfProvider``.

    ``answer`` uses only the first entry of ``request.pdf_paths``; a
    request with no paths produces an error result without calling the
    provider.
    """

    name: str = "native_pdf"

    def __init__(
        self,
        *,
        provider: OpenRouterPdfProvider,
        max_output_tokens: int | None = 1024,
    ) -> None:
        self._max_output = max_output_tokens
        self._provider = provider

    @classmethod
    def from_env(
        cls,
        *,
        api_key: str,
        model: str,
        engine: PdfEngine = PdfEngine.NATIVE,
        base_url: str = "https://openrouter.ai/api/v1",
        max_output_tokens: int | None = 1024,
    ) -> NativePdfArm:
        """Build the provider from raw settings and wrap it in an arm."""
        return cls(
            provider=OpenRouterPdfProvider(
                api_key=api_key,
                base_url=base_url,
                model=model,
                engine=engine,
            ),
            max_output_tokens=max_output_tokens,
        )

    async def answer(self, request: ArmRequest) -> ArmResult:
        """Answer ``request`` using its first PDF.

        Provider exceptions are captured into ``ArmResult.error``. On
        success the answer letter is extracted from the response text
        and the extraction strategy is recorded under ``extra``.
        """
        qid = request.question_id
        paths = request.pdf_paths
        if not paths:
            return ArmResult(
                arm=self.name,
                question_id=qid,
                raw_text="",
                error="native_pdf arm requires at least one pdf_path",
            )

        n_paths = len(paths)
        if n_paths > 1:
            # The plan calls out one-PDF-per-question so the head-to-head
            # is fair; runners are responsible for upstream concatenation.
            logger.debug("qid=%s native_pdf got %d pdfs; using first only", qid, n_paths)

        pdf = paths[0]
        try:
            completion = await self._provider.complete(
                prompt=request.prompt,
                pdf_path=pdf,
                max_tokens=self._max_output,
            )
        except Exception as exc:  # noqa: BLE001
            failure = f"{type(exc).__name__}: {exc}"
            return ArmResult(arm=self.name, question_id=qid, raw_text="", error=failure)

        parsed = extract_answer_letter(completion.text)
        return ArmResult(
            arm=self.name,
            question_id=qid,
            raw_text=completion.text,
            answer_letter=parsed.letter,
            input_tokens=completion.input_tokens,
            output_tokens=completion.output_tokens,
            cost_micros=completion.cost_micros,
            latency_ms=completion.latency_ms,
            extra={
                "model": self._provider.model,
                "engine": self._provider.engine.value,
                "answer_letter_strategy": parsed.strategy,
                "finish_reason": completion.finish_reason,
                "pdf_filename": pdf.name,
            },
        )
|
||||
|
||||
|
||||
__all__ = ["NativePdfArm"]
|
||||
104
surfsense_evals/src/surfsense_evals/core/arms/surfsense.py
Normal file
104
surfsense_evals/src/surfsense_evals/core/arms/surfsense.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""SurfSense arm: per-question fresh thread + ``/api/v1/new_chat`` stream.
|
||||
|
||||
For every question:
|
||||
|
||||
* Create a fresh ``NewChatThread`` on the suite's pinned SearchSpace.
|
||||
This sidesteps the per-thread ``THREAD_BUSY`` 409 (a single thread
|
||||
serialises turns, see ``surfsense_backend/app/routes/new_chat_routes.py:191-220``).
|
||||
* POST ``/api/v1/new_chat`` with the prompt and the per-question
|
||||
``mentioned_document_ids`` (``surfsense_backend/app/schemas/new_chat.py:241-243``).
|
||||
* Consume the SSE stream via ``NewChatClient.ask`` which accumulates
|
||||
text deltas and returns ``StreamedAnswer``.
|
||||
* Optionally delete the thread (default ON for ephemeral runs).
|
||||
|
||||
Citations are parsed from the streamed assistant text via the
|
||||
canonical regex port; chunk ids are returned in ``ArmResult.citations``
|
||||
for the runner to map back to corpus ids.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ..clients import NewChatClient
|
||||
from ..parse.answer_letter import extract_answer_letter
|
||||
from .base import Arm, ArmRequest, ArmResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SurfSenseArm(Arm):
    """Arm that answers through SurfSense via a ``NewChatClient``.

    Every ``answer`` call creates its own thread on the configured
    search space, asks the question on it and, when
    ``ephemeral_threads`` is true, deletes the thread again afterwards
    whether or not the question succeeded.
    """

    name: str = "surfsense"

    def __init__(
        self,
        *,
        client: NewChatClient,
        search_space_id: int,
        ephemeral_threads: bool = True,
        thread_title_prefix: str = "eval",
    ) -> None:
        self._title_prefix = thread_title_prefix
        self._ephemeral = ephemeral_threads
        self._search_space_id = search_space_id
        self._client = client

    async def _discard_thread(self, thread_id: int | None) -> None:
        """Best-effort removal of an ephemeral thread; failures are only logged."""
        if not self._ephemeral or thread_id is None:
            return
        try:
            await self._client.delete_thread(thread_id)
        except Exception as exc:  # noqa: BLE001
            logger.debug("Failed to delete thread %s: %s", thread_id, exc)

    async def answer(self, request: ArmRequest) -> ArmResult:
        """Create a thread, ask the question on it, and package the reply.

        Exceptions from thread creation or from the chat call are
        captured into ``ArmResult.error``; ``extra["thread_id"]`` is
        ``None`` if the thread was never created.
        """
        thread_id: int | None = None
        try:
            thread_id = await self._client.create_thread(
                search_space_id=self._search_space_id,
                title=f"{self._title_prefix}:{request.question_id}",
            )
            streamed = await self._client.ask(
                thread_id=thread_id,
                search_space_id=self._search_space_id,
                user_query=request.prompt,
                mentioned_document_ids=request.mentioned_document_ids,
                disabled_tools=request.options.get("disabled_tools"),
            )
        except Exception as exc:  # noqa: BLE001
            return ArmResult(
                arm=self.name,
                question_id=request.question_id,
                raw_text="",
                error=f"{type(exc).__name__}: {exc}",
                extra={"thread_id": thread_id},
            )
        finally:
            # Runs on both the success and the failure path.
            await self._discard_thread(thread_id)

        parsed = extract_answer_letter(streamed.text)
        return ArmResult(
            arm=self.name,
            question_id=request.question_id,
            raw_text=streamed.text,
            answer_letter=parsed.letter,
            citations=streamed.citations,
            latency_ms=streamed.latency_ms,
            # SurfSense doesn't surface input/output token counts in the
            # SSE stream today; leaving the cost / token fields at 0
            # documents that gap. Estimating from the raw text would
            # bias the comparison against the SurfSense arm.
            extra={
                "thread_id": thread_id,
                "search_space_id": self._search_space_id,
                "answer_letter_strategy": parsed.strategy,
                "user_message_id": streamed.user_message_id,
                "assistant_message_id": streamed.assistant_message_id,
                "finished_normally": streamed.finished_normally,
                "n_raw_events": len(streamed.raw_events),
                "n_mentioned_documents": len(request.mentioned_document_ids or []),
            },
        )
|
||||
|
||||
|
||||
__all__ = ["SurfSenseArm"]
|
||||
273
surfsense_evals/src/surfsense_evals/core/auth.py
Normal file
273
surfsense_evals/src/surfsense_evals/core/auth.py
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
"""Dual-mode credential resolver + httpx client factory with 401 auto-refresh.
|
||||
|
||||
SurfSense supports ``AUTH_TYPE=LOCAL`` (email + password) and
|
||||
``AUTH_TYPE=GOOGLE`` (Google OAuth → frontend stores JWT in ``localStorage``).
|
||||
There is no headless equivalent of the Google flow, so the harness handles
|
||||
both modes by treating the JWT as the universal credential:
|
||||
|
||||
* **LOCAL**: harness POSTs form-encoded ``username`` + ``password`` to
|
||||
``/auth/jwt/login``, reads ``{access_token, refresh_token}``.
|
||||
* **GOOGLE / pre-issued JWT**: operator pastes their existing JWT (and
|
||||
optionally refresh token) into ``SURFSENSE_JWT`` /
|
||||
``SURFSENSE_REFRESH_TOKEN``; harness skips login.
|
||||
|
||||
Either way ``client_with_auth`` returns one shared
|
||||
``httpx.AsyncClient`` with ``Authorization: Bearer <jwt>`` set and an
|
||||
event hook that, on a 401 with a refresh token in scope, calls
|
||||
``POST /auth/jwt/refresh`` and retries the original request once. JWT
|
||||
lifetime defaults to one day backend-side, so this matters for long
|
||||
MIRAGE runs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialError(RuntimeError):
    """Signals that SurfSense credentials are missing or were rejected at login."""
|
||||
|
||||
|
||||
# Text of the ``CredentialError`` raised when neither credential mode is
# configured; it lists both supported ways of supplying credentials.
_NO_CREDENTIALS_MESSAGE = (
    "No SurfSense credentials configured. Set ONE of:\n"
    " (LOCAL) SURFSENSE_USER_EMAIL + SURFSENSE_USER_PASSWORD\n"
    " (GOOGLE) SURFSENSE_JWT (and optionally SURFSENSE_REFRESH_TOKEN)\n"
    "For GOOGLE: log in to SurfSense in your browser, open DevTools → "
    "Application → Local Storage → copy `surfsense_bearer_token` and "
    "`surfsense_refresh_token` into those env vars."
)
|
||||
|
||||
|
||||
@dataclass
class TokenBundle:
    """Mutable access/refresh token pair.

    The refresh path overwrites ``access_token`` (and ``refresh_token``
    when a new one is issued) on the same instance, so every holder of
    the bundle sees the updated values.
    """

    access_token: str
    refresh_token: str | None = None
    # Informational only ("local" or "jwt"); used in error messages.
    mode: str = "jwt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token acquisition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def acquire_token(config: Config, *, http: httpx.AsyncClient | None = None) -> TokenBundle:
    """Resolve credentials → ``TokenBundle``.

    Precedence:

    1. ``SURFSENSE_JWT`` set → use it directly. Refresh token captured if
       supplied.
    2. ``SURFSENSE_USER_EMAIL`` + ``SURFSENSE_USER_PASSWORD`` set →
       form-encoded POST to ``/auth/jwt/login``.
    3. Neither → raise ``CredentialError``.

    The optional ``http`` argument lets tests inject a mocked client; if
    omitted a one-shot client is created for the login call only.

    Raises:
        CredentialError: no credential mode is configured, the login
            endpoint answered with a non-200 status, or its body was not
            a JSON object carrying ``access_token``.
    """

    if config.has_jwt_mode():
        return TokenBundle(
            access_token=config.surfsense_jwt or "",
            refresh_token=config.surfsense_refresh_token,
            mode="jwt",
        )

    if config.has_local_mode():

        async def _login(client: httpx.AsyncClient) -> TokenBundle:
            response = await client.post(
                f"{config.surfsense_api_base}/auth/jwt/login",
                data={
                    "username": config.surfsense_user_email,
                    "password": config.surfsense_user_password,
                },
                headers={"Accept": "application/json"},
            )
            if response.status_code != 200:
                raise CredentialError(
                    f"LOCAL login failed (HTTP {response.status_code}): "
                    f"{_safe_text(response)}"
                )
            # A 200 with a non-JSON body (e.g. an HTML page from a proxy)
            # must surface as a credential problem, not a raw decode error.
            try:
                payload = response.json()
            except ValueError as exc:
                raise CredentialError(
                    f"LOCAL login returned a non-JSON body: {_safe_text(response)}"
                ) from exc
            # Guard against JSON that is valid but not an object.
            access = payload.get("access_token") if isinstance(payload, dict) else None
            if not access:
                raise CredentialError(
                    f"LOCAL login response missing access_token: {payload!r}"
                )
            return TokenBundle(
                access_token=access,
                refresh_token=payload.get("refresh_token") or None,
                mode="local",
            )

        if http is not None:
            return await _login(http)
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, connect=10.0)) as client:
            return await _login(client)

    raise CredentialError(_NO_CREDENTIALS_MESSAGE)
|
||||
|
||||
|
||||
def _safe_text(response: httpx.Response, *, limit: int = 200) -> str:
    """Return the response body, cut to ``limit`` characters plus "…" if longer.

    A falsy body is reported as the empty string.
    """
    text = response.text or ""
    return text if len(text) <= limit else text[:limit] + "…"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# httpx client + 401 auto-refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _AuthState:
    """Mutable holder shared between the request hook and the retrying client.

    It is private so callers cannot change the access token out-of-band;
    ``client_with_auth`` hands back only the client itself.
    """

    def __init__(self, config: Config, tokens: TokenBundle) -> None:
        # Set while a refresh call is running; starts cleared.
        self._refresh_in_flight: bool = False
        self.tokens = tokens
        self.config = config
|
||||
|
||||
|
||||
def _build_auth_request(state: _AuthState, request: httpx.Request) -> None:
    """Write the current access token into ``request``'s ``Authorization`` header."""
    bearer = state.tokens.access_token
    request.headers["Authorization"] = f"Bearer {bearer}"
|
||||
|
||||
|
||||
async def _refresh_access_token(
    state: _AuthState, transport: httpx.AsyncBaseTransport | None = None
) -> bool:
    """POST ``/auth/jwt/refresh`` with the current refresh token.

    Returns ``True`` on success and updates ``state.tokens`` in place.
    Returns ``False`` if no refresh token is configured or the call fails
    — including a transport error, a non-200 status, a body that is not
    JSON, or a JSON body without ``access_token``.
    Recursive 401s are avoided by using a *new* client without the auth
    hook.
    """

    refresh = state.tokens.refresh_token
    if not refresh:
        return False
    try:
        async with httpx.AsyncClient(
            timeout=httpx.Timeout(15.0, connect=5.0),
            transport=transport,
        ) as inner:
            response = await inner.post(
                f"{state.config.surfsense_api_base}/auth/jwt/refresh",
                json={"refresh_token": refresh},
                headers={"Accept": "application/json"},
            )
    except httpx.HTTPError as exc:
        logger.warning("Token refresh transport error: %s", exc)
        return False
    if response.status_code != 200:
        logger.warning(
            "Token refresh rejected (HTTP %s): %s",
            response.status_code,
            _safe_text(response),
        )
        return False
    # A malformed body must count as a failed refresh rather than raise
    # out of the caller's request path.
    try:
        payload = response.json()
    except ValueError:
        logger.warning("Token refresh returned a non-JSON body: %s", _safe_text(response))
        return False
    # Valid JSON that is not an object cannot carry a token either.
    new_access = payload.get("access_token") if isinstance(payload, dict) else None
    if not new_access:
        logger.warning("Refresh response missing access_token: %r", payload)
        return False
    state.tokens.access_token = new_access
    # Keep the old refresh token unless the backend issued a new one.
    new_refresh = payload.get("refresh_token")
    if new_refresh:
        state.tokens.refresh_token = new_refresh
    return True
|
||||
|
||||
|
||||
def client_with_auth(
    config: Config,
    tokens: TokenBundle,
    *,
    timeout: float = 60.0,
    transport: httpx.AsyncBaseTransport | None = None,
    base_url: str | None = None,
) -> httpx.AsyncClient:
    """Create the shared, bearer-authenticated client for the SurfSense API.

    * Every outgoing request gets ``Authorization: Bearer <jwt>`` from a
      request event hook that reads the current token.
    * A 401 response triggers one refresh attempt (when a refresh token
      is configured) followed by one replay of the original request
      with the new bearer, so a successful refresh transparently
      unblocks long runs.
    * The replay is best-effort: if the refresh fails, or the replay is
      rejected again, that response goes back to the caller so they can
      re-authenticate manually.

    ``base_url`` scopes a sub-client (e.g. in tests). Leaving it unset
    keeps full URLs in calling code, which makes route-spec citations in
    the codebase easier to grep.
    """

    auth_state = _AuthState(config, tokens)

    async def _stamp_bearer(outgoing: httpx.Request) -> None:
        _build_auth_request(auth_state, outgoing)

    # httpx's response event hook cannot *replace* a response, so the
    # retry-on-401 logic lives in ``_AuthAwareClient.send`` instead.
    return _AuthAwareClient(
        state=auth_state,
        transport=transport,
        timeout=httpx.Timeout(timeout, connect=10.0),
        base_url=base_url or "",
        event_hooks={"request": [_stamp_bearer]},
    )
|
||||
|
||||
|
||||
class _AuthAwareClient(httpx.AsyncClient):
    """``AsyncClient`` that refreshes the token and replays once after a 401."""

    def __init__(self, *, state: _AuthState, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._auth_state = state

    async def send(  # type: ignore[override]
        self, request: httpx.Request, **kwargs: Any
    ) -> httpx.Response:
        """Send ``request``; on a 401 try one token refresh and one replay.

        The first response is returned unchanged when it is not a 401,
        when another refresh is already marked as in flight, or when the
        refresh attempt fails.
        """
        state = self._auth_state
        first = await super().send(request, **kwargs)
        # NOTE(review): a 401 that arrives while another request's
        # refresh is in flight is handed back without a retry — confirm
        # this is intended for concurrent runs.
        if first.status_code != 401 or state._refresh_in_flight:
            return first

        state._refresh_in_flight = True
        try:
            token_renewed = await _refresh_access_token(state)
        finally:
            state._refresh_in_flight = False
        if not token_renewed:
            return first

        # Release the rejected response, re-stamp the bearer, replay once.
        await first.aclose()
        request.headers["Authorization"] = f"Bearer {state.tokens.access_token}"
        return await super().send(request, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CredentialError",
|
||||
"TokenBundle",
|
||||
"acquire_token",
|
||||
"client_with_auth",
|
||||
]
|
||||
790
surfsense_evals/src/surfsense_evals/core/cli.py
Normal file
790
surfsense_evals/src/surfsense_evals/core/cli.py
Normal file
|
|
@ -0,0 +1,790 @@
|
|||
"""Argparse CLI for ``python -m surfsense_evals``.
|
||||
|
||||
Subcommands:
|
||||
|
||||
* ``setup --suite <name> --provider-model <slug> [--agent-llm-id <int>]``
|
||||
* ``teardown --suite <name>``
|
||||
* ``models list [--provider openrouter] [--grep <s>]``
|
||||
* ``suites list``
|
||||
* ``benchmarks list [--suite <name>]``
|
||||
* ``ingest <suite> <benchmark> [benchmark flags]``
|
||||
* ``run <suite> <benchmark> [benchmark flags]``
|
||||
* ``report --suite <name> [--benchmark <name>]``
|
||||
|
||||
The ``ingest`` / ``run`` subparsers are built dynamically from the
|
||||
registry — adding a new benchmark only requires registering it; the
|
||||
CLI surface comes for free. ``add_run_args`` lets each benchmark
|
||||
publish its own flags.
|
||||
|
||||
Design choices worth flagging:
|
||||
|
||||
* ``setup`` rejects ``agent_llm_id == 0`` (Auto / LiteLLM router) so
|
||||
per-question accuracy is reproducible.
|
||||
* ``setup`` validates that the picked LLM config has
|
||||
``provider == "OPENROUTER"`` and ``model_name == --provider-model``
|
||||
before declaring success — both arms of the head-to-head must hit
|
||||
the same OpenRouter slug.
|
||||
* Lifecycle state is keyed by suite, so ``setup --suite legal`` does
|
||||
not touch ``medical``'s SearchSpace, and vice versa.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import sys
|
||||
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
# Windows' legacy console (cp1252) crashes when Rich tries to write characters
|
||||
# outside the active codepage (e.g. '->', em-dashes, box-drawing). Force UTF-8
|
||||
# on stdout/stderr and disable Rich's legacy_windows render path so the file
|
||||
# stream is used directly. Modern Windows (>=10, VS Code terminal, Windows
|
||||
# Terminal, PowerShell, cmd) all interpret ANSI escapes natively.
|
||||
if sys.platform == "win32":
    for _stream in (sys.stdout, sys.stderr):
        try:
            _stream.reconfigure(encoding="utf-8", errors="replace")
        except (AttributeError, ValueError):
            # Best-effort: a stream that cannot be reconfigured is left
            # as it is.
            pass
|
||||
|
||||
from . import registry
|
||||
from .auth import CredentialError, acquire_token, client_with_auth
|
||||
from .clients import SearchSpaceClient
|
||||
from .clients.search_space import LlmPreferences
|
||||
from .config import (
|
||||
DEFAULT_SCENARIO,
|
||||
SCENARIOS,
|
||||
Config,
|
||||
SuiteState,
|
||||
clear_suite_state,
|
||||
get_suite_state,
|
||||
load_config,
|
||||
set_suite_state,
|
||||
utc_iso_timestamp,
|
||||
)
|
||||
from .vision_llm import VisionConfigError, resolve_vision_llm
|
||||
|
||||
logger = logging.getLogger("surfsense_evals")
|
||||
console = Console(legacy_windows=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _discover_suites() -> list[str]:
    """Import the suites package and run its discovery, returning the result.

    The import is deferred to call time so commands that need no
    benchmark (e.g. ``models list``) do not pay for loading every suite.
    """

    from surfsense_evals.suites import discover_suites as _run_discovery

    return _run_discovery()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Global LLM config fetcher (used by setup + models list)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class LlmConfigEntry:
    """One entry of the ``/api/v1/global-new-llm-configs`` listing.

    ``provider`` is stored upper-cased so callers can compare it against
    constants such as ``"OPENROUTER"`` without re-normalising. ``raw`` keeps
    the untouched payload the entry was built from.
    """

    id: int
    name: str
    provider: str
    model_name: str
    raw: dict[str, Any]

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> LlmConfigEntry:
        """Build an entry from one decoded JSON object.

        Args:
            payload: Mapping with a mandatory ``id`` and optional ``name``,
                ``provider`` and ``model_name`` keys.

        Returns:
            The populated entry.

        Raises:
            KeyError: If ``id`` is missing.
            TypeError, ValueError: If ``id`` cannot be converted to ``int``.
        """

        def _text(key: str) -> str:
            # A JSON ``null`` must read as "absent"; ``str(None)`` would
            # otherwise leak the literal text "None" into the listing and
            # make it match filters such as ``--grep none``.
            value = payload.get(key)
            return "" if value is None else str(value)

        return cls(
            id=int(payload["id"]),
            name=_text("name"),
            provider=_text("provider").upper(),
            model_name=_text("model_name"),
            raw=payload,
        )
|
||||
|
||||
|
||||
async def _list_global_llm_configs(http: httpx.AsyncClient, base: str) -> list[LlmConfigEntry]:
    """Fetch and parse the global LLM config listing.

    Args:
        http: Client used to issue the GET request.
        base: API base URL; the endpoint path is appended to it verbatim.

    Returns:
        One ``LlmConfigEntry`` per element of the JSON array in the response.

    Raises:
        httpx.HTTPStatusError: Propagated from ``raise_for_status()``.
        RuntimeError: If the decoded body is not a JSON array.
    """
    endpoint = f"{base}/api/v1/global-new-llm-configs"
    reply = await http.get(endpoint, headers={"Accept": "application/json"})
    reply.raise_for_status()
    body = reply.json()
    if isinstance(body, list):
        return [LlmConfigEntry.from_payload(row) for row in body]
    raise RuntimeError(f"Unexpected /global-new-llm-configs payload: {body!r}")
|
||||
|
||||
|
||||
def _resolve_openrouter_id(
|
||||
candidates: list[LlmConfigEntry],
|
||||
provider_model: str,
|
||||
*,
|
||||
explicit_id: int | None,
|
||||
) -> int:
|
||||
"""Resolve the SurfSense LLM id for ``provider_model``.
|
||||
|
||||
Behaviour:
|
||||
|
||||
* If ``explicit_id`` is given: return it directly. The caller is
|
||||
then expected to GET-validate that the row's
|
||||
``provider == "OPENROUTER"`` and ``model_name`` matches the slug.
|
||||
That branch supports positive BYOK ``NewLLMConfig`` rows whose
|
||||
slugs may overlap with global OpenRouter virtuals.
|
||||
* Otherwise: filter to ``provider == "OPENROUTER"`` and
|
||||
``model_name == provider_model``. Expect exactly one match —
|
||||
raise with a friendly message otherwise.
|
||||
"""
|
||||
|
||||
if explicit_id is not None:
|
||||
return explicit_id
|
||||
|
||||
matches = [
|
||||
c for c in candidates if c.provider == "OPENROUTER" and c.model_name == provider_model
|
||||
]
|
||||
if not matches:
|
||||
sample = ", ".join(
|
||||
f"{c.model_name} (id={c.id})" for c in candidates if c.provider == "OPENROUTER"
|
||||
)[:600]
|
||||
raise RuntimeError(
|
||||
f"No OpenRouter config found for slug '{provider_model}'. "
|
||||
"Make sure `openrouter_integration.enabled: true` in "
|
||||
"global_llm_config.yaml and that the Celery worker has "
|
||||
"finished its first refresh (the catalogue is fetched at "
|
||||
"Celery startup per `app/celery_app.py`). "
|
||||
f"Available OpenRouter slugs (sample): {sample or '<none>'}.\n"
|
||||
"Browse with: python -m surfsense_evals models list --grep <substring>"
|
||||
)
|
||||
if len(matches) > 1:
|
||||
listing = "\n".join(f" id={c.id} name={c.name!r}" for c in matches)
|
||||
raise RuntimeError(
|
||||
f"Multiple OpenRouter configs for slug '{provider_model}':\n{listing}\n"
|
||||
"Pass --agent-llm-id <id> to disambiguate."
|
||||
)
|
||||
return matches[0].id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subcommand implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _cmd_setup(args: argparse.Namespace) -> int:
    """Implement ``setup``: create or reuse the suite SearchSpace and pin its LLMs.

    Steps, in order: validate flags, authenticate, resolve the agent LLM id,
    reuse the SearchSpace recorded in state (or create one when there is no
    state or the backend answers 404), optionally resolve a vision LLM
    config, write the LLM preferences, read them back to verify the pin,
    and finally persist a ``SuiteState``.

    Returns:
        ``0`` on success, ``2`` on any validation, credential, resolution or
        pin-verification failure.

    Raises:
        httpx.HTTPStatusError: For a non-404 failure while fetching the
            previously recorded SearchSpace.
    """
    suite = args.suite
    provider_model: str = args.provider_model
    explicit_id: int | None = args.agent_llm_id
    scenario: str = args.scenario
    vision_llm_slug: str | None = args.vision_llm
    native_arm_model: str | None = args.native_arm_model
    skip_vision_setup: bool = args.no_vision_llm_setup

    if explicit_id == 0:
        console.print(
            "[red]agent_llm_id == 0 (Auto / LiteLLM router) is not allowed — "
            "results would not be reproducible.[/red]"
        )
        return 2

    if scenario not in SCENARIOS:
        console.print(
            f"[red]Unknown scenario {scenario!r}. Pick one of: "
            f"{', '.join(SCENARIOS)}[/red]"
        )
        return 2

    # Scenario-specific flag rules. cost-arbitrage is the only scenario in
    # which the native arm answers with its own model.
    is_cost_arbitrage = scenario == "cost-arbitrage"
    if is_cost_arbitrage and not native_arm_model:
        console.print(
            "[red]--scenario cost-arbitrage requires --native-arm-model "
            "<vision-capable slug>.[/red] The native arm needs a vision "
            "model to fairly answer image-bearing questions; SurfSense "
            "answers from already-extracted text via --provider-model."
        )
        return 2
    if is_cost_arbitrage and native_arm_model == provider_model:
        # Warn only: the run is still allowed to proceed.
        console.print(
            "[yellow]--native-arm-model equals --provider-model in "
            "cost-arbitrage; that's degenerate (same as head-to-head). "
            "Pick a different slug or switch to --scenario head-to-head.[/yellow]"
        )
    if (
        not is_cost_arbitrage
        and scenario in ("head-to-head", "symmetric-cheap")
        and native_arm_model
    ):
        console.print(
            f"[yellow]--native-arm-model is ignored for --scenario {scenario} "
            f"(both arms answer with --provider-model={provider_model!r}).[/yellow]"
        )
        native_arm_model = None  # don't persist a stale value

    config = load_config()
    try:
        token = await acquire_token(config)
    except CredentialError as exc:
        console.print(f"[red]{exc}[/red]")
        return 2

    async with client_with_auth(config, token) as http:
        llm_rows = await _list_global_llm_configs(http, config.surfsense_api_base)

        try:
            agent_llm_id = _resolve_openrouter_id(
                llm_rows, provider_model, explicit_id=explicit_id
            )
        except RuntimeError as exc:
            console.print(f"[red]{exc}[/red]")
            return 2

        spaces = SearchSpaceClient(http, config.surfsense_api_base)
        previous = get_suite_state(config, suite)
        if previous is not None:
            try:
                space = await spaces.get(previous.search_space_id)
            except httpx.HTTPStatusError as exc:
                if exc.response.status_code != 404:
                    raise
                console.print(
                    f"[yellow]state.json pointed at SearchSpace id={previous.search_space_id} "
                    f"but backend returned 404; creating a fresh one.[/yellow]"
                )
                # Forget the stale state so a new space is created below and
                # its ingestion maps are not carried over.
                previous = None
            else:
                console.print(
                    f"Reusing existing SearchSpace [cyan]{space.name}[/cyan] "
                    f"(id={space.id}) for suite [bold]{suite}[/bold]."
                )
                search_space_id = space.id
        if previous is None:
            space_name = f"eval-{suite}-{utc_iso_timestamp()}"
            space = await spaces.create(
                space_name, description=f"surfsense-evals lifecycle ({suite})"
            )
            console.print(
                f"Created SearchSpace [cyan]{space.name}[/cyan] (id={space.id}) "
                f"for suite [bold]{suite}[/bold]."
            )
            search_space_id = space.id

        # Resolve + attach the vision LLM config (unless explicitly skipped).
        # The asymmetric scenarios need it even when no slug was passed.
        vision_required = scenario in ("symmetric-cheap", "cost-arbitrage")
        vision_config_id: int | None = None
        vision_provider_model: str | None = None
        wants_vision = vision_required or vision_llm_slug is not None
        if wants_vision and not skip_vision_setup:
            try:
                vision_rows = await spaces.list_global_vision_llm_configs()
                resolved = resolve_vision_llm(vision_rows, explicit_slug=vision_llm_slug)
            except VisionConfigError as exc:
                console.print(f"[red]{exc}[/red]")
                return 2
            vision_config_id = resolved.config_id
            vision_provider_model = resolved.provider_model
            console.print(
                f"Vision LLM at ingest: [cyan]{vision_provider_model}[/cyan] "
                f"(id={vision_config_id}, selected_via={resolved.selected_via})."
            )

        pref_kwargs: dict[str, Any] = {"agent_llm_id": agent_llm_id}
        if vision_config_id is not None:
            pref_kwargs["vision_llm_config_id"] = vision_config_id

        # Write the preferences, then read them back to verify the pin.
        await spaces.set_llm_preferences(search_space_id, **pref_kwargs)
        prefs = await spaces.get_llm_preferences(search_space_id)
        if not _validate_pin(prefs, provider_model):
            pinned = prefs.agent_llm or {}
            console.print(
                f"[red]LLM pin validation FAILED.[/red] After PUT, "
                f"agent_llm.provider={pinned.get('provider')!r}, "
                f"model_name={pinned.get('model_name')!r}; expected "
                f"provider=OPENROUTER, model_name={provider_model!r}."
            )
            return 2
        if vision_config_id is not None and prefs.vision_llm_config_id != vision_config_id:
            console.print(
                f"[red]Vision LLM pin validation FAILED.[/red] After PUT, "
                f"vision_llm_config_id={prefs.vision_llm_config_id!r}; "
                f"expected {vision_config_id!r}."
            )
            return 2

        suite_state = SuiteState(
            search_space_id=search_space_id,
            agent_llm_id=agent_llm_id,
            provider_model=provider_model,
            created_at=utc_iso_timestamp(),
            ingestion_maps=previous.ingestion_maps if previous else {},
            scenario=scenario,
            vision_llm_config_id=vision_config_id,
            vision_provider_model=vision_provider_model,
            native_arm_model=native_arm_model,
        )
        set_suite_state(config, suite, suite_state)

        summary_bits = [
            f"suite={suite!r}",
            f"scenario={scenario!r}",
            f"search_space_id={suite_state.search_space_id}",
            f"agent_llm_id={suite_state.agent_llm_id}",
            f"provider_model={suite_state.provider_model!r}",
        ]
        if suite_state.vision_provider_model:
            summary_bits.append(f"vision_provider_model={suite_state.vision_provider_model!r}")
        if suite_state.native_arm_model:
            summary_bits.append(f"native_arm_model={suite_state.native_arm_model!r}")
        console.print(f"[green]setup OK[/green] {' '.join(summary_bits)}")
        return 0
|
||||
|
||||
|
||||
def _validate_pin(prefs: LlmPreferences, provider_model: str) -> bool:
|
||||
agent = prefs.agent_llm or {}
|
||||
return (
|
||||
str(agent.get("provider", "")).upper() == "OPENROUTER"
|
||||
and str(agent.get("model_name", "")) == provider_model
|
||||
)
|
||||
|
||||
|
||||
async def _cmd_teardown(args: argparse.Namespace) -> int:
    """Implement ``teardown``: delete the suite SearchSpace and clear its state.

    The state slot is cleared even when the DELETE request fails, so a stale
    id never blocks a later ``setup``. The closing message states whether the
    backend delete actually succeeded instead of always claiming it did.

    Returns:
        ``0`` when the state was cleared (or there was nothing to clear),
        ``2`` when credentials could not be acquired.
    """
    suite = args.suite
    config = load_config()
    state = get_suite_state(config, suite)
    if state is None:
        console.print(f"[yellow]No state for suite {suite!r}; nothing to tear down.[/yellow]")
        return 0
    try:
        token = await acquire_token(config)
    except CredentialError as exc:
        console.print(f"[red]{exc}[/red]")
        return 2
    async with client_with_auth(config, token) as http:
        ss_client = SearchSpaceClient(http, config.surfsense_api_base)
        deleted = True
        try:
            await ss_client.delete(state.search_space_id)
        except httpx.HTTPStatusError as exc:
            deleted = False
            console.print(
                f"[yellow]DELETE failed (HTTP {exc.response.status_code}); "
                "clearing state.json anyway.[/yellow]"
            )
        clear_suite_state(config, suite)
        if deleted:
            outcome = "(SearchSpace soft-deleted, state.json slot cleared)."
        else:
            # Do not claim a delete that the backend rejected.
            outcome = (
                f"(state.json slot cleared; SearchSpace id={state.search_space_id} "
                "was NOT deleted on the backend)."
            )
        console.print(f"[green]teardown OK[/green] suite={suite!r} {outcome}")
        return 0
|
||||
|
||||
|
||||
async def _cmd_models_list(args: argparse.Namespace) -> int:
    """Implement ``models list``: print the global LLM configs as a table.

    ``--provider`` is matched exactly against the upper-cased provider and
    ``--grep`` is a case-insensitive substring test on ``model_name`` or
    ``name``. Rows are sorted by ``(provider, model_name)``.

    Returns:
        ``0`` after printing, ``2`` when credentials could not be acquired.
    """
    config = load_config()
    try:
        token = await acquire_token(config)
    except CredentialError as exc:
        console.print(f"[red]{exc}[/red]")
        return 2

    async with client_with_auth(config, token) as http:
        entries = await _list_global_llm_configs(http, config.surfsense_api_base)

        needle = (args.grep or "").lower()
        wanted_provider = (args.provider or "").upper()

        def _selected(entry: LlmConfigEntry) -> bool:
            # Both filters are optional; an empty filter accepts everything.
            if wanted_provider and entry.provider != wanted_provider:
                return False
            if not needle:
                return True
            return needle in entry.model_name.lower() or needle in entry.name.lower()

        rows: list[LlmConfigEntry] = [entry for entry in entries if _selected(entry)]

        table = Table(
            title=f"Global LLM configs ({len(rows)} of {len(entries)})",
            show_lines=False,
        )
        table.add_column("id", justify="right", style="cyan")
        table.add_column("provider", style="magenta")
        table.add_column("model_name", style="green")
        table.add_column("name")
        for entry in sorted(rows, key=lambda item: (item.provider, item.model_name)):
            table.add_row(str(entry.id), entry.provider, entry.model_name, entry.name)
        console.print(table)
        return 0
|
||||
|
||||
|
||||
def _cmd_suites_list(_args: argparse.Namespace) -> int:
    """Implement ``suites list``: print each registered suite and its benchmarks.

    Always returns ``0``; with nothing registered, a hint is printed instead
    of a table.
    """
    _discover_suites()
    suite_names = registry.list_suites()
    if suite_names:
        table = Table(title=f"Registered suites ({len(suite_names)})")
        table.add_column("suite", style="bold")
        table.add_column("benchmarks", style="green")
        for suite_name in suite_names:
            joined = ", ".join(b.name for b in registry.list_benchmarks(suite_name))
            table.add_row(suite_name, joined or "<none>")
        console.print(table)
    else:
        console.print(
            "[yellow]No suites registered. Drop a benchmark under "
            "src/surfsense_evals/suites/<domain>/<benchmark>/.[/yellow]"
        )
    return 0
|
||||
|
||||
|
||||
def _cmd_benchmarks_list(args: argparse.Namespace) -> int:
    """Implement ``benchmarks list``: print registered benchmarks as a table.

    ``args.suite`` (``None`` when the flag is omitted) is passed straight to
    ``registry.list_benchmarks``. Always returns ``0``.
    """
    _discover_suites()
    found = registry.list_benchmarks(args.suite)
    if not found:
        console.print("[yellow]No benchmarks registered.[/yellow]")
        return 0

    table = Table(title=f"Benchmarks ({len(found)})")
    for header, options in (
        ("suite", {"style": "bold"}),
        ("name", {"style": "cyan"}),
        ("headline", {"justify": "center"}),
        ("description", {}),
    ):
        table.add_column(header, **options)
    for bench in found:
        headline_flag = "yes" if bench.headline else "no"
        # ``description`` is optional on a benchmark object.
        table.add_row(bench.suite, bench.name, headline_flag, getattr(bench, "description", ""))
    console.print(table)
    return 0
|
||||
|
||||
|
||||
async def _cmd_ingest(args: argparse.Namespace) -> int:
    """Implement ``ingest <suite> <benchmark>``: run the benchmark's ingest step.

    Every parsed flag except the CLI plumbing keys is forwarded to
    ``benchmark.ingest`` as keyword arguments, so a benchmark can honour its
    own flags.

    Returns:
        ``0`` on success, ``2`` when the suite has no setup state or
        credentials could not be acquired.
    """
    benchmark = registry.get(args.suite, args.benchmark)
    config = load_config()
    state = get_suite_state(config, args.suite)
    if state is None:
        console.print(
            f"[red]No setup for suite {args.suite!r}. Run "
            f"`python -m surfsense_evals setup --suite {args.suite} "
            f"--provider-model <slug>` first.[/red]"
        )
        return 2
    try:
        token = await acquire_token(config)
    except CredentialError as exc:
        console.print(f"[red]{exc}[/red]")
        return 2

    # Keys owned by the CLI itself; everything else belongs to the benchmark.
    plumbing = {"_func", "_async", "command", "subcommand", "suite", "benchmark", "log_level"}
    extra_kwargs = {key: value for key, value in vars(args).items() if key not in plumbing}

    async with client_with_auth(config, token) as http:
        ctx = registry.RunContext(
            suite=args.suite,
            benchmark=args.benchmark,
            config=config,
            suite_state=state,
            http=http,
        )
        await benchmark.ingest(ctx, **extra_kwargs)
        console.print(f"[green]ingest OK[/green] {args.suite}/{args.benchmark}")
        return 0
|
||||
|
||||
|
||||
async def _cmd_run(args: argparse.Namespace) -> int:
    """Implement ``run <suite> <benchmark>``: execute a benchmark.

    Every parsed flag except the CLI plumbing keys is forwarded to
    ``benchmark.run`` as keyword arguments. The ``raw_path`` of the returned
    artifact is printed on success.

    Returns:
        ``0`` on success, ``2`` when the suite has no setup state or
        credentials could not be acquired.
    """
    benchmark = registry.get(args.suite, args.benchmark)
    config = load_config()
    state = get_suite_state(config, args.suite)
    if state is None:
        console.print(
            f"[red]No setup for suite {args.suite!r}. Run "
            f"`python -m surfsense_evals setup --suite {args.suite} "
            f"--provider-model <slug>` first.[/red]"
        )
        return 2
    try:
        token = await acquire_token(config)
    except CredentialError as exc:
        console.print(f"[red]{exc}[/red]")
        return 2

    # Keys owned by the CLI itself; everything else belongs to the benchmark.
    plumbing = {"_func", "_async", "command", "subcommand", "suite", "benchmark", "log_level"}
    extra_kwargs = {key: value for key, value in vars(args).items() if key not in plumbing}

    async with client_with_auth(config, token) as http:
        ctx = registry.RunContext(
            suite=args.suite,
            benchmark=args.benchmark,
            config=config,
            suite_state=state,
            http=http,
        )
        artifact = await benchmark.run(ctx, **extra_kwargs)

    console.print(
        f"[green]run OK[/green] {args.suite}/{args.benchmark} → "
        f"{artifact.raw_path}"
    )
    return 0
|
||||
|
||||
|
||||
async def _cmd_report(args: argparse.Namespace) -> int:
    """Implement ``report``: fold the latest run artifacts into a summary.

    For each registered benchmark of the suite (optionally narrowed by
    ``--benchmark``) that has a collected artifact, the benchmark's
    ``report_section`` is called, and the sections are handed to
    ``write_report``.

    Returns:
        ``0`` when the report was written, ``1`` when no run artifacts exist,
        ``2`` when the suite has no setup state or the benchmark filter
        matches nothing.
    """
    from .report import write_report

    wanted = args.benchmark
    config = load_config()
    if get_suite_state(config, args.suite) is None:
        console.print(f"[red]No setup for suite {args.suite!r}.[/red]")
        return 2

    benchmarks = registry.list_benchmarks(args.suite)
    if wanted:
        benchmarks = [bench for bench in benchmarks if bench.name == wanted]
        if not benchmarks:
            console.print(
                f"[red]No registered benchmark named {wanted!r} in suite {args.suite!r}.[/red]"
            )
            return 2

    artifacts = _collect_artifacts(config, args.suite, [bench.name for bench in benchmarks])
    if not artifacts:
        console.print(
            "[yellow]No run artifacts found under "
            f"{config.suite_runs_dir(args.suite)}. Run a benchmark first.[/yellow]"
        )
        return 1

    # Bucket artifacts by benchmark name, then build the sections in
    # registry order.
    grouped: dict[str, list[registry.RunArtifact]] = {}
    for artifact in artifacts:
        if artifact.benchmark not in grouped:
            grouped[artifact.benchmark] = []
        grouped[artifact.benchmark].append(artifact)
    sections: list[registry.ReportSection] = [
        bench.report_section(grouped[bench.name])
        for bench in benchmarks
        if bench.name in grouped
    ]

    summary_path = write_report(
        config=config,
        suite=args.suite,
        sections=sections,
        run_timestamp=utc_iso_timestamp(),
    )
    console.print(f"[green]report OK[/green] → {summary_path}")
    return 0
|
||||
|
||||
|
||||
def _collect_artifacts(
    config: Config, suite: str, benchmark_names: list[str]
) -> list[registry.RunArtifact]:
    """Collect the newest run artifact per benchmark for ``suite``.

    Walks ``<runs_dir>/<run>/<benchmark>/run_artifact.json`` for every run
    directory and every requested benchmark. Run directories are visited in
    sorted order and a later one overwrites an earlier one, so "latest wins"
    holds as long as directory names sort chronologically (presumably they
    are timestamps — verify against the runner).

    A manifest that cannot be read, is not valid JSON, or does not hold a
    JSON object is skipped with a warning rather than aborting the report.
    ``null`` values for ``raw_path`` / ``metrics`` / ``extra`` fall back to
    the same defaults as missing keys.

    Args:
        config: Provides ``suite_runs_dir(suite)``.
        suite: Suite whose runs are scanned.
        benchmark_names: Benchmarks to look for inside each run directory.

    Returns:
        At most one ``RunArtifact`` per benchmark; empty if the runs
        directory does not exist.
    """
    runs_dir = config.suite_runs_dir(suite)
    if not runs_dir.exists():
        return []

    by_bench: dict[str, registry.RunArtifact] = {}
    for ts_dir in sorted(runs_dir.iterdir()):
        if not ts_dir.is_dir():
            continue
        for bench_name in benchmark_names:
            bench_dir = ts_dir / bench_name
            manifest = bench_dir / "run_artifact.json"
            if not manifest.exists():
                continue
            try:
                with manifest.open("r", encoding="utf-8") as fh:
                    payload = json.load(fh)
            except (OSError, json.JSONDecodeError) as exc:
                logger.warning("Skipping unreadable run manifest %s: %s", manifest, exc)
                continue
            if not isinstance(payload, dict):
                # Valid JSON of the wrong shape (list, string, ...) would
                # otherwise crash on ``.get``.
                logger.warning(
                    "Skipping run manifest %s: expected a JSON object, got %s",
                    manifest,
                    type(payload).__name__,
                )
                continue
            # Latest run wins per benchmark.
            by_bench[bench_name] = registry.RunArtifact(
                suite=suite,
                benchmark=bench_name,
                run_timestamp=ts_dir.name,
                raw_path=bench_dir / (payload.get("raw_path") or "raw.jsonl"),
                metrics=payload.get("metrics") or {},
                extra=payload.get("extra") or {},
            )
    return list(by_bench.values())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Argparse wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
    """Build the full ``surfsense-evals`` argument parser.

    The static subcommands (``setup``, ``teardown``, ``models``, ``suites``,
    ``benchmarks``, ``report``) are wired directly. ``ingest`` and ``run``
    get one nested subparser per registered suite and benchmark, so suite
    discovery is triggered before they are built. Each leaf parser stores its
    handler in ``_func`` and whether that handler is a coroutine function in
    ``_async``.
    """
    parser = argparse.ArgumentParser(
        prog="surfsense-evals",
        description="SurfSense evaluation harness — domain-agnostic core + pluggable suites.",
    )
    parser.add_argument(
        "--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # --- setup -----------------------------------------------------------
    setup_parser = sub.add_parser("setup", help="Create per-suite SearchSpace + pin LLM.")
    setup_parser.add_argument("--suite", required=True)
    setup_parser.add_argument(
        "--provider-model",
        required=True,
        help=(
            "OpenRouter slug for the SurfSense answer LLM (and the native arm "
            "too unless --native-arm-model is set), e.g. "
            "'anthropic/claude-sonnet-4.5'."
        ),
    )
    setup_parser.add_argument(
        "--agent-llm-id",
        type=int,
        default=None,
        help="Optional override for BYOK NewLLMConfig rows.",
    )
    setup_parser.add_argument(
        "--scenario",
        choices=SCENARIOS,
        default=DEFAULT_SCENARIO,
        help=(
            "head-to-head (default): both arms answer with --provider-model; "
            "symmetric-cheap: both arms use the same cheap text-only slug, "
            "SurfSense pre-extracted images at ingest with a vision LLM; "
            "cost-arbitrage: native arm uses --native-arm-model (vision), "
            "SurfSense uses --provider-model (cheap, text-only) over chunks "
            "the vision LLM already extracted at ingest."
        ),
    )
    setup_parser.add_argument(
        "--vision-llm",
        default=None,
        metavar="SLUG",
        help=(
            "OpenRouter slug for the vision LLM SurfSense uses at ingest "
            "when --use-vision-llm is on. If omitted in symmetric-cheap / "
            "cost-arbitrage, the strongest registered vision config is "
            "auto-picked (priority: claude-sonnet-4.5 > claude-opus-4.7 > "
            "gpt-5 > gemini-2.5-pro)."
        ),
    )
    setup_parser.add_argument(
        "--native-arm-model",
        default=None,
        metavar="SLUG",
        help=(
            "Required for --scenario cost-arbitrage. OpenRouter slug used "
            "by the native_pdf arm only; SurfSense answers with "
            "--provider-model. Ignored for head-to-head / symmetric-cheap."
        ),
    )
    setup_parser.add_argument(
        "--no-vision-llm-setup",
        action="store_true",
        help=(
            "Skip attaching a vision LLM config to the SearchSpace even if "
            "the scenario would normally require one. Use when you want to "
            "keep whatever is already attached (e.g. a per-user config)."
        ),
    )
    setup_parser.set_defaults(_func=_cmd_setup, _async=True)

    # --- teardown --------------------------------------------------------
    teardown_parser = sub.add_parser(
        "teardown", help="Soft-delete the suite SearchSpace + clear state slot."
    )
    teardown_parser.add_argument("--suite", required=True)
    teardown_parser.set_defaults(_func=_cmd_teardown, _async=True)

    # --- models list -----------------------------------------------------
    models_parser = sub.add_parser("models", help="LLM-config discovery helpers.")
    models_sub = models_parser.add_subparsers(dest="subcommand", required=True)
    models_list = models_sub.add_parser("list", help="List global LLM configs.")
    models_list.add_argument("--provider", default=None, help="Filter by provider, e.g. openrouter")
    models_list.add_argument("--grep", default=None, help="Substring filter on name / model_name.")
    models_list.set_defaults(_func=_cmd_models_list, _async=True)

    # --- suites list -----------------------------------------------------
    suites_parser = sub.add_parser("suites", help="List registered suites.")
    suites_sub = suites_parser.add_subparsers(dest="subcommand", required=True)
    suites_list = suites_sub.add_parser("list", help="List suites.")
    suites_list.set_defaults(_func=_cmd_suites_list, _async=False)

    # --- benchmarks list -------------------------------------------------
    benchmarks_parser = sub.add_parser("benchmarks", help="List registered benchmarks.")
    benchmarks_sub = benchmarks_parser.add_subparsers(dest="subcommand", required=True)
    benchmarks_list = benchmarks_sub.add_parser("list", help="List benchmarks.")
    benchmarks_list.add_argument("--suite", default=None)
    benchmarks_list.set_defaults(_func=_cmd_benchmarks_list, _async=False)

    # Dynamic ingest / run subcommands need the registry populated, so
    # discover up-front (cheap on import — modules just register).
    _discover_suites()

    # ingest and run share the same <command> <suite> <benchmark> shape and
    # differ only in help text and handler.
    for command, verb, summary, handler in (
        ("ingest", "Ingest", "Ingest a benchmark's corpus.", _cmd_ingest),
        ("run", "Run", "Run a benchmark.", _cmd_run),
    ):
        command_parser = sub.add_parser(command, help=summary)
        per_suite = command_parser.add_subparsers(dest="suite", required=True)
        for suite in registry.list_suites():
            suite_parser = per_suite.add_parser(suite, help=f"{verb} a {suite} benchmark.")
            per_benchmark = suite_parser.add_subparsers(dest="benchmark", required=True)
            for benchmark in registry.list_benchmarks(suite):
                leaf = per_benchmark.add_parser(
                    benchmark.name, help=getattr(benchmark, "description", benchmark.name)
                )
                # Benchmarks may contribute their own flags; the same hook is
                # used for both commands.
                if hasattr(benchmark, "add_run_args"):
                    benchmark.add_run_args(leaf)
                leaf.set_defaults(_func=handler, _async=True)

    # --- report ----------------------------------------------------------
    report_parser = sub.add_parser("report", help="Aggregate latest run artifacts into a summary.")
    report_parser.add_argument("--suite", required=True)
    report_parser.add_argument("--benchmark", default=None, help="Optional: report only this benchmark.")
    report_parser.set_defaults(_func=_cmd_report, _async=True)

    return parser
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """Parse ``argv``, configure logging and dispatch to the chosen command.

    Args:
        argv: Argument list; ``None`` lets argparse read ``sys.argv``.

    Returns:
        The handler's own exit code, or ``2`` if no handler is attached,
        ``130`` on ``KeyboardInterrupt``, ``1`` on any other exception
        (which is also logged with its traceback).
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )

    handler = getattr(args, "_func", None)
    if handler is None:
        parser.print_help()
        return 2

    # Handlers flagged ``_async`` are coroutine functions and need a loop.
    run_in_loop = getattr(args, "_async", False)
    try:
        return asyncio.run(handler(args)) if run_in_loop else handler(args)
    except KeyboardInterrupt:
        console.print("[yellow]Interrupted.[/yellow]")
        return 130
    except Exception as exc:  # noqa: BLE001
        # Top-level boundary: log the traceback, show a short message.
        logger.exception("CLI command failed")
        console.print(f"[red]Command failed: {exc}[/red]")
        return 1
|
||||
|
||||
|
||||
# Allow running this module directly; the exit status is main()'s return value.
if __name__ == "__main__":  # pragma: no cover
    sys.exit(main())
|
||||
14
surfsense_evals/src/surfsense_evals/core/clients/__init__.py
Normal file
14
surfsense_evals/src/surfsense_evals/core/clients/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""HTTP clients for the SurfSense API. All share one ``httpx.AsyncClient``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .documents import DocumentsClient
|
||||
from .new_chat import NewChatClient, StreamedAnswer
|
||||
from .search_space import SearchSpaceClient
|
||||
|
||||
# Names re-exported from the client submodules imported above.
__all__ = [
    "DocumentsClient",
    "NewChatClient",
    "SearchSpaceClient",
    "StreamedAnswer",
]
|
||||
277
surfsense_evals/src/surfsense_evals/core/clients/documents.py
Normal file
277
surfsense_evals/src/surfsense_evals/core/clients/documents.py
Normal file
|
|
@ -0,0 +1,277 @@
|
|||
"""Client for ``/api/v1/documents/{fileupload,status,{id}/chunks}``.
|
||||
|
||||
Verified against:
|
||||
|
||||
* ``surfsense_backend/app/routes/documents_routes.py:122-292`` (POST fileupload)
|
||||
* ``surfsense_backend/app/routes/documents_routes.py:806-871`` (GET status batch)
|
||||
* ``surfsense_backend/app/routes/documents_routes.py:1062-1128`` (GET {id}/chunks paginated)
|
||||
|
||||
Document processing is asynchronous:
|
||||
* ``POST /documents/fileupload`` returns immediately with
|
||||
``document_ids`` in ``pending``;
|
||||
* a Celery worker moves each through ``processing → ready/failed``;
|
||||
* the harness polls ``GET /documents/status?document_ids=...`` until
|
||||
every doc is ``ready`` (otherwise the retriever sees an empty corpus
|
||||
and accuracy numbers are meaningless).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class FileUploadResult:
    """Typed view of the JSON body returned by ``POST /documents/fileupload``."""

    document_ids: list[int]
    duplicate_document_ids: list[int]
    total_files: int
    pending_files: int
    skipped_duplicates: int
    message: str = ""

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> FileUploadResult:
        """Build a result from the decoded response body.

        Missing keys fall back to an empty list, ``0`` or ``""``; present
        values are coerced with ``int`` / ``str``.
        """

        def _id_list(key: str) -> list[int]:
            return [int(item) for item in payload.get(key, [])]

        def _count(key: str) -> int:
            return int(payload.get(key, 0))

        return cls(
            document_ids=_id_list("document_ids"),
            duplicate_document_ids=_id_list("duplicate_document_ids"),
            total_files=_count("total_files"),
            pending_files=_count("pending_files"),
            skipped_duplicates=_count("skipped_duplicates"),
            message=str(payload.get("message", "")),
        )
|
||||
|
||||
|
||||
@dataclass
class DocumentStatus:
    """Processing status of a single document."""

    document_id: int
    title: str
    document_type: str
    # Lifecycle state; this class only interprets "ready" and "failed".
    state: str
    # Failure detail, if any was provided.
    reason: str | None = None

    @property
    def is_ready(self) -> bool:
        """``True`` when ``state`` is exactly ``"ready"``."""
        return self.state == "ready"

    @property
    def is_failed(self) -> bool:
        """``True`` when ``state`` is exactly ``"failed"``."""
        return self.state == "failed"
|
||||
|
||||
|
||||
@dataclass
class ChunkRow:
    """A single chunk belonging to a document."""

    id: int
    document_id: int
    content: str = ""
    # Per-instance dict (default_factory avoids a shared mutable default);
    # presumably the untouched API payload — confirm where rows are built.
    raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DocumentProcessingFailed(RuntimeError):
    """Signals that one or more polled documents ended up in ``failed``.

    The message lists every offending document with its id, title and
    reason (``unknown`` when no reason is set). The statuses themselves are
    kept on ``self.statuses``.
    """

    def __init__(self, statuses: Sequence[DocumentStatus]) -> None:
        per_document = [
            f"id={status.document_id} ({status.title!r}): {status.reason or 'unknown'}"
            for status in statuses
        ]
        summary = ", ".join(per_document)
        super().__init__(f"Document(s) failed to process: {summary}")
        self.statuses = list(statuses)
|
||||
|
||||
|
||||
class DocumentProcessingTimeout(RuntimeError):
    """Raised when polling exceeds the per-doc timeout budget."""
|
||||
|
||||
|
||||
class DocumentsClient:
|
||||
"""Document upload + status polling + chunk listing."""
|
||||
|
||||
def __init__(self, http: httpx.AsyncClient, base_url: str) -> None:
    """Store the HTTP client and the API base URL.

    Args:
        http: Client used for every request made by this object.
        base_url: API base URL; trailing slashes are stripped so endpoint
            paths can be appended with a single ``/``.
    """
    self._http = http
    self._base = base_url.rstrip("/")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# upload
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def upload(
|
||||
self,
|
||||
files: Iterable[Path],
|
||||
*,
|
||||
search_space_id: int,
|
||||
should_summarize: bool = False,
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
) -> FileUploadResult:
|
||||
"""Upload files to ``/api/v1/documents/fileupload``.
|
||||
|
||||
``files`` is materialised to a list because we may need to
|
||||
re-read on retry. Caller is responsible for ensuring each path
|
||||
exists and respects the per-file size cap (50 MB backend default).
|
||||
"""
|
||||
|
||||
materialised = [Path(p) for p in files]
|
||||
if not materialised:
|
||||
return FileUploadResult(
|
||||
document_ids=[],
|
||||
duplicate_document_ids=[],
|
||||
total_files=0,
|
||||
pending_files=0,
|
||||
skipped_duplicates=0,
|
||||
message="No files supplied",
|
||||
)
|
||||
|
||||
opened: list[tuple[str, Any]] = []
|
||||
try:
|
||||
for path in materialised:
|
||||
# ``open`` directly — httpx wraps it in MultipartStream.
|
||||
file_obj = path.open("rb")
|
||||
mime, _ = mimetypes.guess_type(path.name)
|
||||
opened.append(
|
||||
(
|
||||
"files",
|
||||
(path.name, file_obj, mime or "application/octet-stream"),
|
||||
)
|
||||
)
|
||||
|
||||
response = await self._http.post(
|
||||
f"{self._base}/api/v1/documents/fileupload",
|
||||
data={
|
||||
"search_space_id": str(search_space_id),
|
||||
"should_summarize": "true" if should_summarize else "false",
|
||||
"use_vision_llm": "true" if use_vision_llm else "false",
|
||||
"processing_mode": processing_mode,
|
||||
},
|
||||
files=opened,
|
||||
# Multipart uploads can be slow for big PDFs; bump per-call.
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
finally:
|
||||
for _, (_, file_obj, _) in opened:
|
||||
try:
|
||||
file_obj.close()
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
response.raise_for_status()
|
||||
return FileUploadResult.from_payload(response.json())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# status polling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_status(
|
||||
self, *, search_space_id: int, document_ids: Sequence[int]
|
||||
) -> list[DocumentStatus]:
|
||||
if not document_ids:
|
||||
return []
|
||||
response = await self._http.get(
|
||||
f"{self._base}/api/v1/documents/status",
|
||||
params={
|
||||
"search_space_id": search_space_id,
|
||||
"document_ids": ",".join(str(d) for d in document_ids),
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
return [
|
||||
DocumentStatus(
|
||||
document_id=int(item["id"]),
|
||||
title=str(item.get("title", "")),
|
||||
document_type=str(item.get("document_type", "")),
|
||||
state=str((item.get("status") or {}).get("state", "ready")),
|
||||
reason=(item.get("status") or {}).get("reason"),
|
||||
)
|
||||
for item in payload.get("items", [])
|
||||
]
|
||||
|
||||
async def wait_until_ready(
|
||||
self,
|
||||
*,
|
||||
search_space_id: int,
|
||||
document_ids: Sequence[int],
|
||||
timeout_s: float = 300.0,
|
||||
initial_poll_s: float = 1.0,
|
||||
max_poll_s: float = 10.0,
|
||||
) -> list[DocumentStatus]:
|
||||
"""Poll ``GET /documents/status`` until every doc is ``ready``.
|
||||
|
||||
Exponential backoff from ``initial_poll_s`` up to ``max_poll_s``.
|
||||
Raises ``DocumentProcessingFailed`` if any doc lands in
|
||||
``failed`` (with the offending document ids), or
|
||||
``DocumentProcessingTimeout`` if the budget is exhausted.
|
||||
"""
|
||||
|
||||
if not document_ids:
|
||||
return []
|
||||
deadline = asyncio.get_event_loop().time() + timeout_s
|
||||
poll = initial_poll_s
|
||||
while True:
|
||||
statuses = await self.get_status(
|
||||
search_space_id=search_space_id, document_ids=document_ids
|
||||
)
|
||||
failed = [s for s in statuses if s.is_failed]
|
||||
if failed:
|
||||
raise DocumentProcessingFailed(failed)
|
||||
ready = [s for s in statuses if s.is_ready]
|
||||
if len(ready) == len(document_ids):
|
||||
return statuses
|
||||
now = asyncio.get_event_loop().time()
|
||||
if now >= deadline:
|
||||
pending = [s for s in statuses if not s.is_ready and not s.is_failed]
|
||||
pending_ids = [s.document_id for s in pending]
|
||||
raise DocumentProcessingTimeout(
|
||||
f"Timed out after {timeout_s:.0f}s waiting for documents "
|
||||
f"(still pending/processing: {pending_ids})"
|
||||
)
|
||||
await asyncio.sleep(min(poll, max(0.1, deadline - now)))
|
||||
poll = min(poll * 1.5, max_poll_s)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# chunks (chunk_id -> document_id map)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def list_chunks(
|
||||
self, document_id: int, *, page_size: int = 100
|
||||
) -> list[ChunkRow]:
|
||||
"""Walk ``GET /documents/{id}/chunks`` until ``has_more=False``.
|
||||
|
||||
Used by ingestion to materialise the ``chunk_id -> document_id``
|
||||
map needed for retrieval scoring (CUREv1).
|
||||
"""
|
||||
|
||||
rows: list[ChunkRow] = []
|
||||
page = 0
|
||||
while True:
|
||||
response = await self._http.get(
|
||||
f"{self._base}/api/v1/documents/{document_id}/chunks",
|
||||
params={"page": page, "page_size": page_size},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
for item in payload.get("items", []):
|
||||
rows.append(
|
||||
ChunkRow(
|
||||
id=int(item["id"]),
|
||||
document_id=document_id,
|
||||
content=str(item.get("content", "")),
|
||||
raw=item,
|
||||
)
|
||||
)
|
||||
if not payload.get("has_more"):
|
||||
break
|
||||
page += 1
|
||||
return rows
|
||||
280
surfsense_evals/src/surfsense_evals/core/clients/new_chat.py
Normal file
280
surfsense_evals/src/surfsense_evals/core/clients/new_chat.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""Client for ``/api/v1/threads`` and ``/api/v1/new_chat`` (SSE).
|
||||
|
||||
Verified against:
|
||||
|
||||
* ``surfsense_backend/app/routes/new_chat_routes.py:793-848`` (POST /threads)
|
||||
* ``surfsense_backend/app/routes/new_chat_routes.py:1073-1142`` (DELETE /threads/{id})
|
||||
* ``surfsense_backend/app/routes/new_chat_routes.py:1689-1800`` (POST /new_chat SSE)
|
||||
* ``surfsense_backend/app/routes/new_chat_routes.py:191-220`` (THREAD_BUSY / TURN_CANCELLING 409)
|
||||
* ``surfsense_backend/app/services/streaming/envelope/sse.py`` (wire framing)
|
||||
* ``surfsense_backend/app/services/streaming/events/text.py`` (text-delta events)
|
||||
* ``surfsense_backend/app/schemas/new_chat.py:234-288`` (NewChatRequest body)
|
||||
|
||||
The wire format is "Vercel AI SDK"-flavoured SSE with one event per
|
||||
``data: <json>\n\n`` block (or the literal ``data: [DONE]\n\n``
|
||||
terminator). Text deltas arrive as ``{"type":"text-delta","id":...,"delta":...}``
|
||||
events; we accumulate them per ``id`` and emit the final concatenated
|
||||
text plus parsed citations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from ..parse import iter_sse_events, parse_citations
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class StreamedAnswer:
    """Result of a single ``/new_chat`` turn."""

    # Concatenated assistant text for the turn.
    text: str
    # Every JSON-object event that was decoded from the stream, in order.
    raw_events: list[dict[str, Any]] = field(default_factory=list)
    latency_ms: int = 0
    user_message_id: str | None = None
    assistant_message_id: str | None = None
    finished_normally: bool = False

    @property
    def citations(self) -> list[dict[str, Any]]:
        """Citation tokens parsed from ``text``, as plain dicts.

        Recomputed on every access rather than cached.
        """

        parsed: list[dict[str, Any]] = []
        for token in parse_citations(self.text):
            parsed.append(token.to_dict())
        return parsed
|
||||
|
||||
|
||||
class ThreadBusyError(RuntimeError):
    """Raised after exhausting retries on a 409 ``THREAD_BUSY`` / ``TURN_CANCELLING``.

    The exception text is ``"<error_code>: <message>"``; the bare code is
    also exposed as ``error_code`` for programmatic checks and logging.
    """

    def __init__(self, error_code: str, message: str) -> None:
        self.error_code = error_code
        rendered = f"{error_code}: {message}"
        super().__init__(rendered)
|
||||
|
||||
|
||||
class NewChatClient:
    """Thread create / delete / SSE ask."""

    def __init__(self, http: httpx.AsyncClient, base_url: str) -> None:
        self._http = http
        # Strip trailing slashes so the URL f-strings below never produce "//".
        self._base = base_url.rstrip("/")

    # ------------------------------------------------------------------
    # threads
    # ------------------------------------------------------------------

    async def create_thread(
        self,
        *,
        search_space_id: int,
        title: str = "eval",
        archived: bool = False,
        visibility: str = "PRIVATE",
    ) -> int:
        """POST ``/api/v1/threads`` and return the new thread's ``id``.

        Raises:
            httpx.HTTPStatusError: If the backend answers with an error status.
        """

        response = await self._http.post(
            f"{self._base}/api/v1/threads",
            json={
                "search_space_id": search_space_id,
                "title": title,
                "archived": archived,
                "visibility": visibility,
            },
            headers={"Accept": "application/json"},
        )
        response.raise_for_status()
        payload = response.json()
        return int(payload["id"])

    async def delete_thread(self, thread_id: int) -> None:
        """DELETE ``/api/v1/threads/{thread_id}``; a 404 counts as success.

        Raises:
            httpx.HTTPStatusError: For any error status other than 404.
        """

        response = await self._http.delete(
            f"{self._base}/api/v1/threads/{thread_id}",
            headers={"Accept": "application/json"},
        )
        if response.status_code == 404:
            return  # idempotent
        response.raise_for_status()

    # ------------------------------------------------------------------
    # /new_chat SSE
    # ------------------------------------------------------------------

    async def ask(
        self,
        *,
        thread_id: int,
        search_space_id: int,
        user_query: str,
        mentioned_document_ids: Sequence[int] | None = None,
        disabled_tools: Sequence[str] | None = None,
        max_busy_retries: int = 4,
        timeout_s: float = 600.0,
    ) -> StreamedAnswer:
        """Stream a single turn and return the accumulated answer.

        Honours backend ``THREAD_BUSY`` / ``TURN_CANCELLING`` 409
        responses by sleeping for the ``retry-after-ms`` header if present,
        else the ``Retry-After`` header, else an exponential backoff of
        ``0.5 * 2**attempt`` seconds, and then replaying. Every wait is
        capped at 30s. Bounded by ``max_busy_retries`` so a stuck thread
        never blocks the whole run.

        ``mentioned_document_ids`` and ``disabled_tools`` are only added to
        the request body when non-empty.

        Raises:
            ThreadBusyError: When the thread is still busy after
                ``max_busy_retries`` retries.
            httpx.HTTPStatusError: For any other error status.
        """

        body: dict[str, Any] = {
            "chat_id": thread_id,
            "search_space_id": search_space_id,
            "user_query": user_query,
        }
        if mentioned_document_ids:
            body["mentioned_document_ids"] = list(mentioned_document_ids)
        if disabled_tools:
            body["disabled_tools"] = list(disabled_tools)

        attempt = 0
        while True:
            try:
                return await self._stream_once(body=body, timeout_s=timeout_s)
            except ThreadBusyError as exc:
                attempt += 1
                if attempt > max_busy_retries:
                    raise
                # ``_stream_once`` attaches the server's hint (seconds) when
                # the 409 carried one; ``getattr`` keeps this safe for busy
                # errors raised without it.
                hint = getattr(exc, "retry_after_s", None)
                backoff = 0.5 * (2 ** attempt)
                # Cap wait at 30s whichever source the delay came from.
                wait = min(30.0, hint if hint is not None else backoff)
                logger.info(
                    "thread_id=%s busy (%s); retry %d/%d after %.1fs",
                    thread_id,
                    exc.error_code,
                    attempt,
                    max_busy_retries,
                    wait,
                )
                await asyncio.sleep(wait)

    async def _stream_once(
        self,
        *,
        body: dict[str, Any],
        timeout_s: float,
    ) -> StreamedAnswer:
        """Issue one ``POST /api/v1/new_chat`` and consume its SSE stream.

        A 409 is converted into ``ThreadBusyError`` whose ``retry_after_s``
        attribute holds the delay parsed from the response headers (or
        ``None``). ``latency_ms`` on the result spans request start to
        stream end.
        """

        # Per-call timeout — the connect should be quick, the read needs
        # to outlive the longest LLM completion.
        timeout = httpx.Timeout(timeout_s, connect=10.0)
        started = time.monotonic()
        async with self._http.stream(
            "POST",
            f"{self._base}/api/v1/new_chat",
            json=body,
            headers={"Accept": "text/event-stream"},
            timeout=timeout,
        ) as response:
            if response.status_code == 409:
                detail = await self._extract_busy_detail(response)
                busy = ThreadBusyError(
                    error_code=detail.get("errorCode", "THREAD_BUSY"),
                    message=detail.get("message", "Thread is busy"),
                )
                # Carry the server's retry hint to ``ask`` on the exception.
                busy.retry_after_s = self._retry_hint_seconds(response.headers)
                raise busy
            response.raise_for_status()
            answer = await self._consume_sse(response)
        answer.latency_ms = int((time.monotonic() - started) * 1000)
        return answer

    @staticmethod
    def _retry_hint_seconds(headers: httpx.Headers) -> float | None:
        """Parse a retry delay in seconds from 409 response headers.

        ``retry-after-ms`` (milliseconds) wins over ``retry-after``
        (seconds). Values that are not non-negative numbers — for example
        the HTTP-date form of ``Retry-After`` — are ignored. Returns
        ``None`` when no usable hint is present.
        """

        for name, scale in (("retry-after-ms", 0.001), ("retry-after", 1.0)):
            raw = headers.get(name)
            if raw is None:
                continue
            try:
                seconds = float(raw) * scale
            except (TypeError, ValueError):
                continue
            # ``nan >= 0`` is False, so NaN is rejected along with negatives.
            if seconds >= 0:
                return seconds
        return None

    @staticmethod
    async def _extract_busy_detail(response: httpx.Response) -> dict[str, Any]:
        """Read a 409 body and return its error detail as a dict.

        Prefers ``payload["detail"]`` when that is a dict, else the payload
        itself when it is a dict, else ``{}``. A non-JSON body yields a
        synthetic ``THREAD_BUSY`` detail carrying the response text.
        """

        try:
            payload = json.loads(await response.aread())
        except (json.JSONDecodeError, ValueError):
            return {"errorCode": "THREAD_BUSY", "message": response.text}
        if isinstance(payload, dict) and isinstance(payload.get("detail"), dict):
            return payload["detail"]
        return payload if isinstance(payload, dict) else {}

    @staticmethod
    async def _consume_sse(response: httpx.Response) -> StreamedAnswer:
        """Walk SSE events, accumulate text-delta payloads.

        Backend events of interest:

        * ``{"type": "text-start", "id": ...}``
        * ``{"type": "text-delta", "id": ..., "delta": ...}``
        * ``{"type": "text-end", "id": ...}``
        * ``{"type": "start", "messageId": ...}`` (top-level message id)
        * ``{"type": "finish"}``
        * literal ``[DONE]`` sentinel

        Multiple ``text-start`` blocks can interleave — each gets its
        own ``id`` and we concatenate them in arrival order. That
        mirrors the AI SDK client behaviour: one continuous assistant
        message visible to the user.

        Payloads that are not JSON objects are skipped and not recorded in
        ``raw_events``. ``finished_normally`` is set by either ``[DONE]`` or
        a ``finish`` event.
        """

        ordered_text_ids: list[str] = []
        text_buffers: dict[str, list[str]] = {}
        raw_events: list[dict[str, Any]] = []
        user_message_id: str | None = None
        assistant_message_id: str | None = None
        finished = False

        async for event in iter_sse_events(_aiter_lines(response)):
            data = event.data
            if data == "[DONE]":
                finished = True
                continue
            try:
                payload = json.loads(data)
            except (json.JSONDecodeError, ValueError):
                logger.debug("Skipping non-JSON SSE payload: %r", data[:120])
                continue
            if not isinstance(payload, dict):
                continue
            raw_events.append(payload)
            ev_type = payload.get("type")
            if ev_type == "text-delta":
                tid = str(payload.get("id", ""))
                delta = payload.get("delta", "")
                if not isinstance(delta, str):
                    continue
                # A delta may arrive without a preceding ``text-start``;
                # register its id on first sight so order is preserved.
                if tid not in text_buffers:
                    text_buffers[tid] = []
                    ordered_text_ids.append(tid)
                text_buffers[tid].append(delta)
            elif ev_type == "text-start":
                tid = str(payload.get("id", ""))
                if tid and tid not in text_buffers:
                    text_buffers[tid] = []
                    ordered_text_ids.append(tid)
            elif ev_type == "start":
                msg_id = payload.get("messageId")
                if isinstance(msg_id, str):
                    # Only a fallback: never overrides an id already seen.
                    user_message_id = user_message_id or msg_id
            elif ev_type == "data-user-message-id":
                msg_id = (payload.get("data") or {}).get("id") or payload.get("id")
                if isinstance(msg_id, str):
                    user_message_id = msg_id
            elif ev_type == "data-assistant-message-id":
                msg_id = (payload.get("data") or {}).get("id") or payload.get("id")
                if isinstance(msg_id, str):
                    assistant_message_id = msg_id
            elif ev_type == "finish":
                finished = True

        # Every id in ``ordered_text_ids`` was registered in ``text_buffers``
        # at the same moment, so direct indexing is safe.
        text = "".join(
            piece for tid in ordered_text_ids for piece in text_buffers[tid]
        )
        return StreamedAnswer(
            text=text,
            raw_events=raw_events,
            user_message_id=user_message_id,
            assistant_message_id=assistant_message_id,
            finished_normally=finished,
        )
|
||||
|
||||
|
||||
async def _aiter_lines(response: httpx.Response) -> AsyncIterator[str]:
    """Re-yield ``response.aiter_lines()`` one line at a time.

    Exists so the SSE parser depends on a plain async line iterator, which
    tests can substitute without building a real HTTP response.
    """

    line_source = response.aiter_lines()
    async for text_line in line_source:
        yield text_line
|
||||
207
surfsense_evals/src/surfsense_evals/core/clients/search_space.py
Normal file
207
surfsense_evals/src/surfsense_evals/core/clients/search_space.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
"""Client for ``/api/v1/searchspaces`` and ``/api/v1/search-spaces/{id}/llm-preferences``.
|
||||
|
||||
Verified against:
|
||||
|
||||
* ``surfsense_backend/app/routes/search_spaces_routes.py:116`` (POST create)
|
||||
* ``surfsense_backend/app/routes/search_spaces_routes.py:234`` (GET by id)
|
||||
* ``surfsense_backend/app/routes/search_spaces_routes.py:422`` (DELETE soft-delete)
|
||||
* ``surfsense_backend/app/routes/search_spaces_routes.py:698-849`` (GET/PUT llm-preferences)
|
||||
* ``surfsense_backend/app/schemas/search_space.py:14`` (SearchSpaceCreate body)
|
||||
* ``surfsense_backend/app/routes/vision_llm_routes.py:60`` (GET global vision configs)
|
||||
|
||||
Note the inconsistent pluralisation in the backend: ``/searchspaces``
|
||||
(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for the
|
||||
``llm-preferences`` sub-resource. Both are mirrored verbatim here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@dataclass
class SearchSpaceRow:
    """Subset of the SearchSpace row we care about."""

    id: int
    name: str
    description: str | None
    user_id: str
    citations_enabled: bool
    qna_custom_instructions: str | None

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> SearchSpaceRow:
        """Build a row from a decoded JSON object.

        ``id`` and ``name`` are mandatory (``KeyError`` if absent). Missing
        optional keys fall back to ``None`` / ``""`` / ``True`` as coded
        below.
        """

        lookup = payload.get
        row_id = int(payload["id"])
        row_name = str(payload["name"])
        return cls(
            id=row_id,
            name=row_name,
            description=lookup("description"),
            user_id=str(lookup("user_id", "")),
            citations_enabled=bool(lookup("citations_enabled", True)),
            qna_custom_instructions=lookup("qna_custom_instructions"),
        )
|
||||
|
||||
|
||||
@dataclass
class VisionLlmConfigEntry:
    """Subset of one ``GET /global-vision-llm-configs`` row.

    The backend returns negative ids for global / OpenRouter-derived
    vision configs and positive ids for per-user BYOK rows. Either is
    accepted by ``set_llm_preferences(vision_llm_config_id=...)``.
    """

    id: int
    name: str
    provider: str
    model_name: str
    is_auto_mode: bool
    # The undecoded source object, kept for fields not modelled here.
    raw: dict[str, Any]

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> VisionLlmConfigEntry:
        """Build an entry from a decoded JSON object.

        Every key is optional; absent keys become ``0`` / ``""`` /
        ``False``. ``provider`` is upper-cased.
        """

        provider_name = str(payload.get("provider", ""))
        auto_flag = payload.get("is_auto_mode", False)
        return cls(
            id=int(payload.get("id", 0)),
            name=str(payload.get("name", "")),
            provider=provider_name.upper(),
            model_name=str(payload.get("model_name", "")),
            is_auto_mode=bool(auto_flag),
            raw=payload,
        )
|
||||
|
||||
|
||||
@dataclass
class LlmPreferences:
    """Resolved LLM preferences with the embedded full config row.

    Mirrors ``LLMPreferencesRead`` from the backend so the lifecycle
    command can introspect ``provider`` / ``model_name`` to validate the
    OpenRouter pin.
    """

    agent_llm_id: int | None
    document_summary_llm_id: int | None
    image_generation_config_id: int | None
    vision_llm_config_id: int | None
    agent_llm: dict[str, Any] | None
    # The undecoded source object, kept for fields not modelled here.
    raw: dict[str, Any]

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> LlmPreferences:
        """Build preferences from a decoded JSON object.

        Each modelled key is copied as-is (``None`` when absent); no type
        coercion is applied.
        """

        copied_keys = (
            "agent_llm_id",
            "document_summary_llm_id",
            "image_generation_config_id",
            "vision_llm_config_id",
            "agent_llm",
        )
        values = {key: payload.get(key) for key in copied_keys}
        return cls(raw=payload, **values)
|
||||
|
||||
|
||||
class SearchSpaceClient:
    """Thin wrapper around the SearchSpace + LLM preferences endpoints."""

    def __init__(self, http: httpx.AsyncClient, base_url: str) -> None:
        self._http = http
        self._base = base_url.rstrip("/")

    @staticmethod
    def _json_headers() -> dict[str, str]:
        """Return a fresh ``Accept: application/json`` header dict."""
        return {"Accept": "application/json"}

    async def create(self, name: str, *, description: str | None = None) -> SearchSpaceRow:
        """POST ``/api/v1/searchspaces`` and return the created row.

        ``description`` is only sent when it is not ``None``.
        """

        # citations_enabled defaults to True backend-side; keep that default.
        body: dict[str, Any] = {"name": name}
        if description is not None:
            body["description"] = description
        url = f"{self._base}/api/v1/searchspaces"
        response = await self._http.post(url, json=body, headers=self._json_headers())
        response.raise_for_status()
        return SearchSpaceRow.from_payload(response.json())

    async def get(self, search_space_id: int) -> SearchSpaceRow:
        """GET ``/api/v1/searchspaces/{id}`` and return the row."""

        url = f"{self._base}/api/v1/searchspaces/{search_space_id}"
        response = await self._http.get(url, headers=self._json_headers())
        response.raise_for_status()
        return SearchSpaceRow.from_payload(response.json())

    async def delete(self, search_space_id: int) -> None:
        """Soft-delete: backend prefixes name with ``[DELETING]`` and dispatches a Celery cascade."""

        url = f"{self._base}/api/v1/searchspaces/{search_space_id}"
        response = await self._http.delete(url, headers=self._json_headers())
        # 404 means it's already gone — treat as success (idempotent teardown).
        if response.status_code != 404:
            response.raise_for_status()

    async def get_llm_preferences(self, search_space_id: int) -> LlmPreferences:
        """GET ``/api/v1/search-spaces/{id}/llm-preferences``."""

        url = f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences"
        response = await self._http.get(url, headers=self._json_headers())
        response.raise_for_status()
        return LlmPreferences.from_payload(response.json())

    async def set_llm_preferences(
        self,
        search_space_id: int,
        *,
        agent_llm_id: int | None = None,
        document_summary_llm_id: int | None = None,
        image_generation_config_id: int | None = None,
        vision_llm_config_id: int | None = None,
    ) -> LlmPreferences:
        """PUT a partial update to ``/search-spaces/{id}/llm-preferences``.

        Backend uses ``model_dump(exclude_unset=True)`` so omitted fields
        are left unchanged. Arguments left at ``None`` are omitted from the
        request body.
        """

        candidates = (
            ("agent_llm_id", agent_llm_id),
            ("document_summary_llm_id", document_summary_llm_id),
            ("image_generation_config_id", image_generation_config_id),
            ("vision_llm_config_id", vision_llm_config_id),
        )
        body: dict[str, Any] = {
            key: value for key, value in candidates if value is not None
        }
        url = f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences"
        response = await self._http.put(url, json=body, headers=self._json_headers())
        response.raise_for_status()
        return LlmPreferences.from_payload(response.json())

    async def list_global_vision_llm_configs(self) -> list[VisionLlmConfigEntry]:
        """List the registered global vision LLM configs.

        Used by ``setup`` to (a) resolve an explicit ``--vision-llm <slug>``
        to a config id and (b) auto-pick the strongest registered vision
        config when the operator doesn't pass one. The ``Auto (Fastest)``
        entry (``id=0``) is filtered out — accuracy must be reproducible.

        Raises:
            RuntimeError: If the response body is not a JSON list.
        """

        url = f"{self._base}/api/v1/global-vision-llm-configs"
        response = await self._http.get(url, headers=self._json_headers())
        response.raise_for_status()
        payload = response.json()
        if not isinstance(payload, list):
            raise RuntimeError(
                f"Unexpected /global-vision-llm-configs payload: {payload!r}"
            )
        entries: list[VisionLlmConfigEntry] = []
        for item in payload:
            # Rows flagged ``is_auto_mode`` are dropped.
            if bool(item.get("is_auto_mode", False)):
                continue
            entries.append(VisionLlmConfigEntry.from_payload(item))
        return entries
|
||||
279
surfsense_evals/src/surfsense_evals/core/config.py
Normal file
279
surfsense_evals/src/surfsense_evals/core/config.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
"""Environment + filesystem configuration for the harness.
|
||||
|
||||
Two responsibilities:
|
||||
|
||||
1. Load env vars (with sensible defaults) into a single immutable ``Config``
|
||||
so that every other module reads it from one place.
|
||||
2. Read / write ``data/state.json``. State is keyed by suite name so multiple
|
||||
suites can be set up in parallel and torn down independently.
|
||||
|
||||
The pinned ``search_space_id`` lives in ``state.json`` (not env) so re-runs
|
||||
are idempotent without forcing the operator to remember an integer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Resolve once at import time. ``find_dotenv`` walks up; an explicit ``.env``
|
||||
# at the package root or in CWD wins. Silent-no-op if neither exists.
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Resolves to ``surfsense_evals/`` (the package root, not ``src/``).
_THIS_FILE = Path(__file__).resolve()
_PROJECT_ROOT = _THIS_FILE.parents[3]


def _project_root() -> Path:
    """Return the ``surfsense_evals/`` project root.

    Derived from this file's own location,
    ``src/surfsense_evals/core/config.py``, by going four directories up.
    Exposed through a function (rather than the bare constant) so tests can
    monkeypatch it.
    """

    return _PROJECT_ROOT
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Config:
    """Immutable runtime configuration."""

    surfsense_api_base: str
    openrouter_api_key: str | None
    openrouter_base_url: str

    # Credentials — exactly ONE mode must be supplied.
    surfsense_jwt: str | None
    surfsense_refresh_token: str | None
    surfsense_user_email: str | None
    surfsense_user_password: str | None

    # Filesystem paths.
    data_dir: Path
    reports_dir: Path

    @property
    def state_path(self) -> Path:
        """Location of ``state.json`` inside ``data_dir``."""
        return self.data_dir.joinpath("state.json")

    def has_jwt_mode(self) -> bool:
        """``True`` when a non-empty JWT is configured."""
        return bool(self.surfsense_jwt)

    def has_local_mode(self) -> bool:
        """``True`` when both email and password are non-empty."""
        if not self.surfsense_user_email:
            return False
        return bool(self.surfsense_user_password)

    def credential_mode(self) -> str:
        """Return ``"jwt"``, ``"local"``, or ``"none"`` (no credentials supplied).

        JWT takes precedence when both kinds of credentials are present.
        """

        if self.has_jwt_mode():
            return "jwt"
        return "local" if self.has_local_mode() else "none"

    def suite_data_dir(self, suite: str) -> Path:
        """Per-suite directory under ``data_dir``."""
        return self.data_dir.joinpath(suite)

    def suite_reports_dir(self, suite: str) -> Path:
        """Per-suite directory under ``reports_dir``."""
        return self.reports_dir.joinpath(suite)

    def suite_runs_dir(self, suite: str) -> Path:
        """``runs`` sub-directory of the suite's data directory."""
        return self.suite_data_dir(suite).joinpath("runs")

    def suite_maps_dir(self, suite: str) -> Path:
        """``maps`` sub-directory of the suite's data directory."""
        return self.suite_data_dir(suite).joinpath("maps")
|
||||
|
||||
|
||||
def load_config() -> Config:
    """Read the current process env into a ``Config``.

    No validation is performed here; callers (e.g. ``auth.acquire_token``,
    ``cli`` subcommands) decide which fields they require. This keeps
    ``models list`` and ``suites list`` runnable without OpenRouter creds.

    An environment variable that is set but empty is treated exactly like
    an unset one for every field, including the two base URLs (so
    ``SURFSENSE_API_BASE=`` in a ``.env`` falls back to the default instead
    of producing an empty base URL). Base URLs have trailing slashes
    stripped; directory paths are resolved to absolute paths.
    """

    def _env(name: str) -> str | None:
        # Collapse "unset" and "set to the empty string" into ``None``.
        return os.environ.get(name) or None

    project_root = _project_root()
    data_dir = Path(_env("EVAL_DATA_DIR") or (project_root / "data")).resolve()
    reports_dir = Path(_env("EVAL_REPORTS_DIR") or (project_root / "reports")).resolve()
    api_base = _env("SURFSENSE_API_BASE") or "http://localhost:8000"
    openrouter_base = _env("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
    return Config(
        surfsense_api_base=api_base.rstrip("/"),
        openrouter_api_key=_env("OPENROUTER_API_KEY"),
        openrouter_base_url=openrouter_base.rstrip("/"),
        surfsense_jwt=_env("SURFSENSE_JWT"),
        surfsense_refresh_token=_env("SURFSENSE_REFRESH_TOKEN"),
        surfsense_user_email=_env("SURFSENSE_USER_EMAIL"),
        surfsense_user_password=_env("SURFSENSE_USER_PASSWORD"),
        data_dir=data_dir,
        reports_dir=reports_dir,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# state.json — per-suite slots
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Scenario names — chosen at ``setup`` time, persisted in ``state.json``.
|
||||
#
|
||||
# * ``head-to-head`` (default, current behaviour): both arms answer with the
|
||||
# SAME slug pinned via ``--provider-model``. Vision LLM at ingest is
|
||||
# optional but recommended for image-bearing benchmarks.
|
||||
# * ``symmetric-cheap``: both arms answer with the SAME (cheap, text-only)
|
||||
# slug; SurfSense pre-extracted images at ingest with a vision LLM.
|
||||
# Measures whether vision-RAG ingestion lets a cheap downstream model
|
||||
# match a vision one. Native arm structurally loses on image questions —
|
||||
# that's the point, and the report labels it accordingly.
|
||||
# * ``cost-arbitrage``: native arm answers with an EXPENSIVE vision slug
|
||||
# (``--native-arm-model``), SurfSense answers with a CHEAP text-only slug
|
||||
# (``--provider-model``) over chunks the vision LLM already extracted at
|
||||
# ingest. Measures how close SurfSense gets to native at a fraction of
|
||||
# the per-query cost. The most compelling "shines" framing.
|
||||
SCENARIOS: tuple[str, ...] = ("head-to-head", "symmetric-cheap", "cost-arbitrage")
DEFAULT_SCENARIO: str = "head-to-head"


@dataclass
class SuiteState:
    """Per-suite persisted state.

    ``provider_model`` is the slug pinned to the SearchSpace's
    ``agent_llm`` — what answers SurfSense queries (and what the native
    arm uses too, unless ``native_arm_model`` is set for cost-arbitrage).

    ``vision_provider_model`` is the slug of the OpenRouter vision LLM
    config attached to the SearchSpace's ``vision_llm_config_id`` — what
    SurfSense uses to extract image content at ingest time when
    ``use_vision_llm=True``. ``None`` means no vision config was attached
    at setup (legacy or text-only suite).
    """

    search_space_id: int
    agent_llm_id: int
    provider_model: str
    created_at: str
    ingestion_maps: dict[str, str] = field(default_factory=dict)
    scenario: str = DEFAULT_SCENARIO
    vision_llm_config_id: int | None = None
    vision_provider_model: str | None = None
    native_arm_model: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a plain dict; ``ingestion_maps`` is shallow-copied."""

        serialised: dict[str, Any] = {}
        serialised["search_space_id"] = self.search_space_id
        serialised["agent_llm_id"] = self.agent_llm_id
        serialised["provider_model"] = self.provider_model
        serialised["created_at"] = self.created_at
        serialised["ingestion_maps"] = dict(self.ingestion_maps)
        serialised["scenario"] = self.scenario
        serialised["vision_llm_config_id"] = self.vision_llm_config_id
        serialised["vision_provider_model"] = self.vision_provider_model
        serialised["native_arm_model"] = self.native_arm_model
        return serialised

    @classmethod
    def from_dict(cls, payload: Mapping[str, Any]) -> SuiteState:
        """Rebuild state from a mapping such as one produced by ``to_dict``.

        ``search_space_id``, ``agent_llm_id`` and ``provider_model`` are
        required. ``scenario`` / vision / native fields default for
        back-compat with ``state.json`` written before scenarios shipped;
        an unrecognised scenario name is replaced by the default.
        """

        def _optional_slug(key: str) -> str | None:
            # Missing, ``None`` and empty values all collapse to ``None``.
            value = payload.get(key)
            return str(value) if value else None

        chosen_scenario = str(payload.get("scenario") or DEFAULT_SCENARIO)
        if chosen_scenario not in SCENARIOS:
            chosen_scenario = DEFAULT_SCENARIO

        vision_id = payload.get("vision_llm_config_id")
        if vision_id is not None:
            vision_id = int(vision_id)

        return cls(
            search_space_id=int(payload["search_space_id"]),
            agent_llm_id=int(payload["agent_llm_id"]),
            provider_model=str(payload["provider_model"]),
            created_at=str(payload.get("created_at") or ""),
            ingestion_maps=dict(payload.get("ingestion_maps") or {}),
            scenario=chosen_scenario,
            vision_llm_config_id=vision_id,
            vision_provider_model=_optional_slug("vision_provider_model"),
            native_arm_model=_optional_slug("native_arm_model"),
        )

    @property
    def effective_native_arm_model(self) -> str:
        """Slug the native arm should use; falls back to ``provider_model``."""

        if self.native_arm_model:
            return self.native_arm_model
        return self.provider_model
|
||||
|
||||
|
||||
def _load_state(config: Config) -> dict[str, Any]:
    """Load the raw state mapping stored at ``config.state_path``.

    Returns ``{"suites": {}}`` when the file does not exist or when its
    top-level JSON value is not a dict containing a ``"suites"`` key.

    Raises:
        RuntimeError: the file exists but cannot be read or is not valid JSON.
    """

    state_file = config.state_path
    if not state_file.exists():
        return {"suites": {}}
    try:
        with state_file.open("r", encoding="utf-8") as handle:
            loaded = json.load(handle)
    except (OSError, json.JSONDecodeError) as exc:
        raise RuntimeError(
            f"Failed to read state file {config.state_path}: {exc!s}. "
            "Delete it if you want to start fresh."
        ) from exc
    well_formed = isinstance(loaded, dict) and "suites" in loaded
    return loaded if well_formed else {"suites": {}}
|
||||
|
||||
|
||||
def _write_state(config: Config, payload: Mapping[str, Any]) -> None:
    """Serialise ``payload`` as JSON to ``config.state_path``.

    The JSON is written to a sibling ``*.json.tmp`` file first and then
    moved over the target with ``Path.replace`` so readers never observe
    a half-written state file. ``config.data_dir`` is created on demand.
    """

    config.data_dir.mkdir(parents=True, exist_ok=True)
    destination = config.state_path
    scratch = destination.with_suffix(".json.tmp")
    with scratch.open("w", encoding="utf-8") as handle:
        # Sorted keys + fixed indent keep the file diff-friendly.
        json.dump(dict(payload), handle, indent=2, sort_keys=True)
        handle.write("\n")
    scratch.replace(destination)
|
||||
|
||||
|
||||
def get_suite_state(config: Config, suite: str) -> SuiteState | None:
    """Look up the persisted ``SuiteState`` for ``suite``.

    Returns ``None`` when the suite has no (or an empty) slot in the
    state file, i.e. it has not been set up.
    """

    suites = _load_state(config).get("suites") or {}
    entry = suites.get(suite)
    return SuiteState.from_dict(entry) if entry else None
|
||||
|
||||
|
||||
def set_suite_state(config: Config, suite: str, suite_state: SuiteState) -> None:
    """Store ``suite_state`` in the slot for ``suite`` and rewrite the state file.

    Slots belonging to other suites are carried over unchanged.
    """

    state = _load_state(config)
    # Work on a copy of the suites mapping, then swap it in.
    updated = dict(state.get("suites") or {})
    updated.update({suite: suite_state.to_dict()})
    state["suites"] = updated
    _write_state(config, state)
|
||||
|
||||
|
||||
def clear_suite_state(config: Config, suite: str) -> bool:
    """Drop the slot for ``suite`` from the state file.

    Returns ``True`` when a slot was removed (and the file rewritten),
    ``False`` when the suite had no slot, in which case nothing is written.
    """

    state = _load_state(config)
    remaining = dict(state.get("suites") or {})
    try:
        del remaining[suite]
    except KeyError:
        return False
    state["suites"] = remaining
    _write_state(config, state)
    return True
|
||||
|
||||
|
||||
def utc_iso_timestamp() -> str:
|
||||
"""Filesystem-safe UTC ISO timestamp, e.g. ``2026-05-11T20-30-00Z``."""
|
||||
|
||||
return datetime.now(UTC).strftime("%Y-%m-%dT%H-%M-%SZ")
|
||||
311
surfsense_evals/src/surfsense_evals/core/ingest_settings.py
Normal file
311
surfsense_evals/src/surfsense_evals/core/ingest_settings.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
"""Per-upload ingestion settings shared across every benchmark.
|
||||
|
||||
The SurfSense ``POST /api/v1/documents/fileupload`` endpoint exposes
|
||||
exactly three knobs (verified at
|
||||
``surfsense_backend/app/routes/documents_routes.py`` and
|
||||
``surfsense_backend/app/etl_pipeline/etl_document.py``):
|
||||
|
||||
* ``processing_mode`` — ``"basic"`` (default) | ``"premium"``
|
||||
* ``use_vision_llm`` — ``bool`` (run vision LLM during ingest to
|
||||
extract image content / captions / tables)
|
||||
* ``should_summarize`` — ``bool`` (generate document summary)
|
||||
|
||||
This module gives every benchmark a uniform way to:
|
||||
|
||||
1. Receive sensible per-benchmark defaults (text-only benchmarks
|
||||
default vision off; image-bearing benchmarks default vision on).
|
||||
2. Accept CLI overrides (``--use-vision-llm`` / ``--no-vision-llm``,
|
||||
``--processing-mode {basic,premium}``,
|
||||
``--should-summarize`` / ``--no-summarize``).
|
||||
3. Persist the *actual* settings used into the doc-map manifest and
|
||||
the run artifact so reports can show "vision=ON, mode=premium →
|
||||
65% accuracy" head-to-head with "vision=OFF, mode=basic → 52%".
|
||||
|
||||
A/B testing on the same corpus
|
||||
------------------------------
|
||||
|
||||
SurfSense dedupes uploads by ``(filename, search_space_id)`` — NOT by
|
||||
content hash and NOT by ingestion settings. Re-uploading the same
|
||||
filename to the same SearchSpace with a different ``use_vision_llm``
|
||||
flag will hit the duplicate branch and *not* re-process. To compare
|
||||
two settings combos head-to-head on the same corpus you must give
|
||||
each combo its own SearchSpace, which today means:
|
||||
|
||||
teardown --suite <s>
|
||||
setup --suite <s> ...
|
||||
ingest <s> <bench> --no-vision-llm # baseline run
|
||||
run <s> <bench>
|
||||
teardown --suite <s>
|
||||
setup --suite <s> ...
|
||||
ingest <s> <bench> --use-vision-llm # vision arm
|
||||
run <s> <bench>
|
||||
|
||||
The runs land in different timestamped subdirectories under
|
||||
``data/<suite>/runs/`` and ``report --suite <s>`` aggregates whichever
|
||||
manifest is currently latest per benchmark.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Keep the constant list of valid processing modes here so benchmarks
|
||||
# don't have to re-import from the backend (they don't have access to
|
||||
# the backend package anyway).
|
||||
PROCESSING_MODES: tuple[str, ...] = ("basic", "premium")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class IngestSettings:
    """Resolved per-upload knobs handed to ``DocumentsClient.upload``.

    Instantiate ``IngestSettings(...)`` directly to declare a benchmark's
    defaults; call ``IngestSettings.merge(defaults, opts)`` to layer CLI
    overrides on top of those defaults.
    """

    use_vision_llm: bool = False
    processing_mode: str = "basic"
    should_summarize: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Return the three knobs as a plain dict."""

        return dict(
            use_vision_llm=self.use_vision_llm,
            processing_mode=self.processing_mode,
            should_summarize=self.should_summarize,
        )

    @classmethod
    def merge(cls, defaults: IngestSettings, opts: Mapping[str, Any]) -> IngestSettings:
        """Apply CLI overrides on top of ``defaults``.

        ``opts`` is the kwargs dict built by ``core.cli`` from the
        argparse namespace (see ``_cmd_ingest`` / ``_cmd_run``). Only
        ``use_vision_llm`` (bool or None), ``processing_mode`` (str or
        None) and ``should_summarize`` (bool or None) are consulted; any
        other key is ignored so benchmarks can pass their own opts through.
        """

        vision = _coerce_bool(opts.get("use_vision_llm"), defaults.use_vision_llm)
        mode = _coerce_mode(opts.get("processing_mode"), defaults.processing_mode)
        summarize = _coerce_bool(opts.get("should_summarize"), defaults.should_summarize)
        return cls(use_vision_llm=vision, processing_mode=mode, should_summarize=summarize)

    def render_label(self) -> str:
        """One-line human-readable label for reports and log lines."""

        vision = "on" if self.use_vision_llm else "off"
        summarize = "on" if self.should_summarize else "off"
        return f"vision={vision}, mode={self.processing_mode}, summarize={summarize}"
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool) -> bool:
    """Resolve a tri-state CLI flag value to a concrete ``bool``.

    ``None`` means the operator did not pass the flag, so the benchmark
    ``default`` wins. Strings are true only when they spell ``1``,
    ``true``, ``yes`` or ``on`` (case-insensitive, surrounding whitespace
    ignored); every other value goes through ``bool()``.
    """

    if value is None:
        return default
    if isinstance(value, str):
        truthy_words = {"1", "true", "yes", "on"}
        return value.strip().lower() in truthy_words
    # Covers real bools as well as any other truthy/falsy object.
    return bool(value)
|
||||
|
||||
|
||||
def _coerce_mode(value: Any, default: str) -> str:
    """Resolve the ``processing_mode`` override, falling back to ``default``.

    ``None`` and the empty string mean "not provided". Anything else is
    stringified, stripped and lower-cased, and must then be one of
    ``PROCESSING_MODES``.

    Raises:
        ValueError: the normalised value is not a known processing mode.
    """

    if value is None or value == "":
        return default
    normalised = str(value).strip().lower()
    if normalised in PROCESSING_MODES:
        return normalised
    raise ValueError(
        f"Invalid processing_mode {normalised!r}; must be one of {PROCESSING_MODES}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Argparse helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _add_bool_pair(
    parser: argparse.ArgumentParser,
    *,
    dest: str,
    on_flag: str,
    off_flag: str,
    on_help: str,
    off_help: str,
) -> None:
    """Register a mutually exclusive on/off flag pair writing to ``dest``.

    ``argparse.BooleanOptionalAction`` is deliberately avoided: it would
    auto-generate ``--no-use-vision-llm`` instead of the friendlier
    ``--no-vision-llm`` operators reach for. Both flags default to
    ``None`` so ``IngestSettings.merge`` can tell "flag not passed" apart
    from "explicitly false".
    """

    exclusive = parser.add_mutually_exclusive_group()
    flag_specs = (
        (on_flag, "store_true", on_help),
        (off_flag, "store_false", off_help),
    )
    for flag, action, help_text in flag_specs:
        exclusive.add_argument(
            flag,
            dest=dest,
            action=action,
            default=None,
            help=help_text,
        )
|
||||
|
||||
|
||||
def add_ingest_settings_args(
    parser: argparse.ArgumentParser,
    *,
    defaults: IngestSettings,
) -> None:
    """Attach the ingest-settings flags to ``parser``.

    Adds an "ingest settings" argument group holding a
    ``--use-vision-llm`` / ``--no-vision-llm`` pair, a ``--processing-mode``
    choice and a ``--should-summarize`` / ``--no-summarize`` pair. Every
    option defaults to ``None`` so that "operator did not pass the flag"
    stays distinguishable from "operator explicitly passed false";
    ``IngestSettings.merge`` folds in the benchmark default only when the
    operator was silent. ``defaults`` is used solely to render help text.
    """

    vision_default = "on" if defaults.use_vision_llm else "off"
    summarize_default = "on" if defaults.should_summarize else "off"

    settings_group = parser.add_argument_group(
        "ingest settings",
        "Per-upload knobs (forwarded to /documents/fileupload). "
        f"Defaults for this benchmark: {defaults.render_label()}.",
    )

    _add_bool_pair(
        settings_group,
        dest="use_vision_llm",
        on_flag="--use-vision-llm",
        off_flag="--no-vision-llm",
        on_help=(
            "Run vision LLM during ingest to extract image content "
            f"(default for this benchmark: {vision_default})."
        ),
        off_help="Skip vision LLM during ingest (text-only ETL).",
    )

    settings_group.add_argument(
        "--processing-mode",
        dest="processing_mode",
        choices=PROCESSING_MODES,
        default=None,
        help=(
            "SurfSense ETL processing mode (premium uses a 10x page "
            "multiplier and typically routes to a stronger ETL). "
            f"Default for this benchmark: {defaults.processing_mode!r}."
        ),
    )

    _add_bool_pair(
        settings_group,
        dest="should_summarize",
        on_flag="--should-summarize",
        off_flag="--no-summarize",
        on_help=(
            "Have SurfSense generate a document summary at ingest "
            f"(default for this benchmark: {summarize_default})."
        ),
        off_help="Skip per-document summary generation.",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Doc-map manifest helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Every benchmark writes a doc-map JSONL under ``data/<suite>/maps/`` that
|
||||
# pairs source identifiers (case_id, snippet_id, doc_path, …) to the
|
||||
# SurfSense document_ids returned by the upload. To make the report
|
||||
# self-describing we also write a header line:
|
||||
#
|
||||
# {"__settings__": {"use_vision_llm": ..., "processing_mode": ..., ...}}
|
||||
#
|
||||
# These two helpers centralise that protocol so each benchmark only has to
|
||||
# call ``write_settings_header`` and ``read_settings_header``.
|
||||
|
||||
SETTINGS_HEADER_KEY = "__settings__"


def settings_header_line(settings: IngestSettings) -> str:
    """Serialise ``settings`` as the doc-map header line.

    The result is a single JSON object keyed by ``SETTINGS_HEADER_KEY``
    and carries no trailing newline.
    """

    header = {SETTINGS_HEADER_KEY: settings.to_dict()}
    return json.dumps(header)


def is_settings_header(row: Mapping[str, Any]) -> bool:
    """Return ``True`` when ``row`` contains the ``SETTINGS_HEADER_KEY`` key."""

    has_key = SETTINGS_HEADER_KEY in row
    return has_key
|
||||
|
||||
|
||||
def read_settings_header(map_path: Path) -> dict[str, Any]:
    """Read the ``__settings__`` header out of a doc-map JSONL.

    Only the first non-blank line is inspected. Returns ``{}`` for a
    missing file, an empty file, an unreadable or non-UTF-8 file, invalid
    JSON, a first line that is not a settings header (e.g. a corpus
    ingested before this feature existed), or a header whose value is not
    a JSON object. Callers use this purely to surface settings in the
    report, so it never raises for bad file content.
    """

    try:
        with map_path.open("r", encoding="utf-8") as fh:
            for raw_line in fh:
                line = raw_line.strip()
                if not line:
                    continue
                row = json.loads(line)
                if not isinstance(row, dict):
                    return {}
                settings = row.get(SETTINGS_HEADER_KEY)
                # A malformed header (non-object value) counts as "not recorded".
                return dict(settings) if isinstance(settings, dict) else {}
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError (non-UTF-8 bytes).
        return {}
    return {}
|
||||
|
||||
|
||||
def format_ingest_settings_md(settings: Any) -> str:
    """Render resolved ingest settings as one Markdown bullet line.

    Anything that is not a non-empty mapping yields a "not recorded"
    placeholder bullet. A missing or falsy ``processing_mode`` is shown
    as ``basic``.
    """

    if not (isinstance(settings, Mapping) and settings):
        return "- SurfSense ingest settings: (not recorded — re-ingest to capture)"

    def on_off(key: str) -> str:
        return "on" if settings.get(key) else "off"

    vision = on_off("use_vision_llm")
    mode = settings.get("processing_mode") or "basic"
    summarize = on_off("should_summarize")
    parts = (
        f"vision_llm=`{vision}`",
        f"processing_mode=`{mode}`",
        f"summarize=`{summarize}`",
    )
    return "- SurfSense ingest settings: " + ", ".join(parts)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PROCESSING_MODES",
|
||||
"SETTINGS_HEADER_KEY",
|
||||
"IngestSettings",
|
||||
"add_ingest_settings_args",
|
||||
"format_ingest_settings_md",
|
||||
"is_settings_header",
|
||||
"read_settings_header",
|
||||
"settings_header_line",
|
||||
]
|
||||
50
surfsense_evals/src/surfsense_evals/core/metrics/__init__.py
Normal file
50
surfsense_evals/src/surfsense_evals/core/metrics/__init__.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""Pure-function metric primitives. Lazy imports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .comparison import McnemarResult, bootstrap_delta_ci, mcnemar_test, paired_aggregate
|
||||
from .mc_accuracy import AccuracyResult, accuracy_with_wilson_ci, wilson_ci
|
||||
from .retrieval import RetrievalScores, mrr, ndcg_at_k, recall_at_k, score_run
|
||||
|
||||
__all__ = [
|
||||
"AccuracyResult",
|
||||
"McnemarResult",
|
||||
"RetrievalScores",
|
||||
"accuracy_with_wilson_ci",
|
||||
"bootstrap_delta_ci",
|
||||
"mcnemar_test",
|
||||
"mrr",
|
||||
"ndcg_at_k",
|
||||
"paired_aggregate",
|
||||
"recall_at_k",
|
||||
"score_run",
|
||||
"wilson_ci",
|
||||
]
|
||||
|
||||
|
||||
# Public name -> submodule (relative to this package) that defines it.
_MODULE_FOR = {
    public_name: submodule
    for submodule, public_names in (
        ("mc_accuracy", ("AccuracyResult", "accuracy_with_wilson_ci", "wilson_ci")),
        ("retrieval", ("RetrievalScores", "mrr", "ndcg_at_k", "recall_at_k", "score_run")),
        (
            "comparison",
            ("McnemarResult", "bootstrap_delta_ci", "mcnemar_test", "paired_aggregate"),
        ),
    )
    for public_name in public_names
}


def __getattr__(name: str):
    """Module-level attribute hook that imports metric submodules lazily.

    A name listed in ``_MODULE_FOR`` triggers an import of its submodule
    and returns the attribute from it; any other name raises
    ``AttributeError``.
    """

    submodule = _MODULE_FOR.get(name)
    if submodule is None:
        raise AttributeError(f"module 'surfsense_evals.core.metrics' has no attribute {name!r}")

    from importlib import import_module

    loaded = import_module(f".{submodule}", __name__)
    return getattr(loaded, name)
|
||||
258
surfsense_evals/src/surfsense_evals/core/metrics/comparison.py
Normal file
258
surfsense_evals/src/surfsense_evals/core/metrics/comparison.py
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
"""Paired comparison statistics for head-to-head benchmarks.
|
||||
|
||||
In every head-to-head benchmark (currently MedXpertQA-MM and
|
||||
MMLongBench-Doc) each question is answered by both arms (Native PDF
|
||||
and SurfSense). That makes per-question outcomes paired, so
|
||||
``McNemar's test`` on the discordant pairs is the right significance
|
||||
test for "are the two arms different?". We also expose a bootstrap
|
||||
delta CI for visualising effect size.
|
||||
|
||||
Aggregate cost / latency / token deltas are mean-based; the runner
|
||||
slices them by arm before passing them in.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import statistics
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class McnemarResult:
    """Discordant pair counts plus the McNemar test statistics.

    ``method`` is ``"degenerate"`` (no discordant pairs), ``"exact"``
    (binomial test) or ``"chi2_cc"`` (continuity-corrected chi-square).
    """

    n_total: int
    b: int  # native correct, surfsense wrong
    c: int  # native wrong, surfsense correct
    statistic: float
    p_value: float
    method: str

    def to_dict(self) -> dict[str, float | int | str]:
        """Return the result as a plain dict with self-describing count keys."""

        return {
            "n_total": self.n_total,
            "b_native_correct_only": self.b,
            "c_surfsense_correct_only": self.c,
            "statistic": self.statistic,
            "p_value": self.p_value,
            "method": self.method,
        }


def mcnemar_test(
    arm_a_correct: Sequence[bool],
    arm_b_correct: Sequence[bool],
    *,
    use_exact_below: int = 11,
) -> McnemarResult:
    """Paired McNemar's test on per-question correctness.

    ``arm_a_correct`` is treated as the reference arm (typically the
    "native" arm); ``arm_b_correct`` is the challenger (typically
    "surfsense"). The test statistic only depends on discordant pairs:
    ``b`` counts questions only the reference arm got right, ``c`` those
    only the challenger got right.

    Default switch-over (``b + c < 11``): for very small discordant
    samples the exact binomial test is preferred; above that the
    continuity-corrected chi-square is well-behaved (Edwards 1948).
    Callers can raise ``use_exact_below`` if they prefer the more
    conservative ``b + c < 25`` rule.

    The continuity correction is clamped at zero, so perfectly tied
    discordant counts (``b == c``) yield ``statistic == 0`` and
    ``p_value == 1`` in the chi-square branch, matching the exact branch.

    No external statistical package is required: scipy is a heavy dep
    and we only need binomial CDFs / chi-square sf, both implementable
    in stdlib + numpy without surprises.

    Raises:
        ValueError: the two sequences differ in length.
    """

    if len(arm_a_correct) != len(arm_b_correct):
        raise ValueError(
            f"Length mismatch: arm_a={len(arm_a_correct)}, arm_b={len(arm_b_correct)}"
        )
    n = len(arm_a_correct)
    b = 0
    c = 0
    for reference_ok, challenger_ok in zip(arm_a_correct, arm_b_correct):
        if reference_ok and not challenger_ok:
            b += 1
        elif challenger_ok and not reference_ok:
            c += 1
    discordant = b + c
    if discordant == 0:
        return McnemarResult(
            n_total=n, b=b, c=c, statistic=0.0, p_value=1.0, method="degenerate"
        )

    if discordant < use_exact_below:
        # Exact binomial: under H0 each discordant pair is a Bernoulli(0.5).
        # p-value = 2 * P(X <= min(b,c) | n=discordant, p=0.5), capped at 1.
        k = min(b, c)
        cdf = sum(_binom_pmf(discordant, i) for i in range(k + 1))
        p_value = min(1.0, 2.0 * cdf)
        return McnemarResult(
            n_total=n, b=b, c=c, statistic=float(k), p_value=p_value, method="exact"
        )

    # Chi-square with continuity correction (McNemar-Edwards). The
    # correction shrinks |b - c| towards zero but must not overshoot it:
    # without the clamp a perfect tie would score (0 - 1)^2 = 1.
    corrected_gap = max(abs(b - c) - 1, 0)
    chi = (corrected_gap ** 2) / discordant
    p_value = _chi2_sf(chi, df=1)
    return McnemarResult(
        n_total=n, b=b, c=c, statistic=chi, p_value=p_value, method="chi2_cc"
    )


def _binom_pmf(n: int, k: int) -> float:
    """Probability of exactly ``k`` successes in ``n`` fair (p=0.5) trials."""

    return math.comb(n, k) * (0.5 ** n)


def _chi2_sf(x: float, *, df: int) -> float:
    """Survival function (1 - CDF) of chi-square; df=1 closed form.

    Non-positive ``x`` returns 1.0. Other degrees of freedom go through
    the regularised upper incomplete gamma function.
    """

    if x <= 0:
        return 1.0
    if df == 1:
        # Chi^2(1) = N(0,1)^2; sf(x) = 2 * (1 - Phi(sqrt(x))) = erfc(sqrt(x / 2)).
        return math.erfc(math.sqrt(x / 2.0))
    # General fallback via regularized upper incomplete gamma.
    a = df / 2.0
    z = x / 2.0
    return _gammaincc(a, z)
|
||||
|
||||
|
||||
def _gammaincc(a: float, x: float, *, max_iter: int = 200, tol: float = 1e-12) -> float:
    """Regularised upper incomplete gamma ``Q(a, x)``.

    Returns NaN for ``x < 0`` or ``a <= 0`` and 1.0 for ``x == 0``. For
    ``x < a + 1`` the series for ``P(a, x)`` is evaluated and subtracted
    from one; otherwise the continued fraction for ``Q`` is used directly.
    """

    if x < 0 or a <= 0:
        return float("nan")
    if x == 0:
        return 1.0
    use_series = x < a + 1.0
    if not use_series:
        return _gammaincc_cf(a, x, max_iter=max_iter, tol=tol)
    # Series yields the lower function P(a, x); Q = 1 - P.
    return 1.0 - _gammainc_series(a, x, max_iter=max_iter, tol=tol)
|
||||
|
||||
|
||||
def _gammainc_series(a: float, x: float, *, max_iter: int, tol: float) -> float:
    """Series evaluation of the regularised lower incomplete gamma ``P(a, x)``.

    Terms are accumulated until the latest one is smaller than ``tol``
    relative to the running sum, or ``max_iter - 1`` terms have been added.
    """

    current = 1.0 / a
    total = current
    for step in range(1, max_iter):
        current *= x / (a + step)
        total += current
        if abs(current) < abs(total) * tol:
            break
    # Prefactor x^a * e^-x / Gamma(a), computed in log space.
    log_prefactor = -x + a * math.log(x) - math.lgamma(a)
    return total * math.exp(log_prefactor)
|
||||
|
||||
|
||||
def _gammaincc_cf(a: float, x: float, *, max_iter: int, tol: float) -> float:
    """Continued-fraction evaluation of the regularised upper gamma ``Q(a, x)``.

    Iterates until the multiplicative update is within ``tol`` of one or
    ``max_iter - 1`` steps have run. Intermediate ratios whose magnitude
    drops below 1e-300 are replaced by 1e-300 to avoid dividing by zero.
    """

    tiny = 1e-300
    b_term = x + 1.0 - a
    c_ratio = 1.0 / tiny
    d_ratio = 1.0 / b_term
    fraction = d_ratio
    for step in range(1, max_iter):
        a_term = -step * (step - a)
        b_term += 2.0
        d_ratio = a_term * d_ratio + b_term
        if abs(d_ratio) < tiny:
            d_ratio = tiny
        c_ratio = b_term + a_term / c_ratio
        if abs(c_ratio) < tiny:
            c_ratio = tiny
        d_ratio = 1.0 / d_ratio
        factor = d_ratio * c_ratio
        fraction *= factor
        if abs(factor - 1.0) < tol:
            break
    # Prefactor x^a * e^-x / Gamma(a), computed in log space.
    log_prefactor = -x + a * math.log(x) - math.lgamma(a)
    return fraction * math.exp(log_prefactor)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bootstrap delta CI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BootstrapDelta:
    """Observed accuracy delta plus its bootstrap confidence interval."""

    delta: float
    ci_low: float
    ci_high: float
    n_resamples: int

    def to_dict(self) -> dict[str, float | int]:
        """Return the four fields as a plain dict."""

        return dict(
            delta=self.delta,
            ci_low=self.ci_low,
            ci_high=self.ci_high,
            n_resamples=self.n_resamples,
        )


def bootstrap_delta_ci(
    arm_a_correct: Sequence[bool],
    arm_b_correct: Sequence[bool],
    *,
    n_resamples: int = 5000,
    level: float = 0.95,
    random_state: int | None = 0,
) -> BootstrapDelta:
    """Paired-sample bootstrap CI for ``mean(arm_b) - mean(arm_a)``.

    Each resample draws question indices with replacement and applies the
    same indices to both arms, so the pairing between arms is preserved.
    The interval bounds are the ``(1 - level) / 2`` and
    ``1 - (1 - level) / 2`` quantiles of the resampled deltas. Empty
    input returns an all-zero result with ``n_resamples == 0``.

    Raises:
        ValueError: the two sequences differ in length.
    """

    if len(arm_a_correct) != len(arm_b_correct):
        raise ValueError("paired arms must have the same length")
    n = len(arm_a_correct)
    if n == 0:
        return BootstrapDelta(0.0, 0.0, 0.0, 0)

    reference = np.asarray(arm_a_correct, dtype=np.int8)
    challenger = np.asarray(arm_b_correct, dtype=np.int8)
    observed = float(challenger.mean() - reference.mean())

    rng = np.random.default_rng(random_state)
    resampled = np.empty(n_resamples, dtype=np.float64)
    for slot in range(n_resamples):
        picks = rng.integers(0, n, size=n)
        resampled[slot] = challenger[picks].mean() - reference[picks].mean()

    tail = (1.0 - level) / 2.0
    lower = float(np.quantile(resampled, tail))
    upper = float(np.quantile(resampled, 1 - tail))
    return BootstrapDelta(delta=observed, ci_low=lower, ci_high=upper, n_resamples=n_resamples)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Simple aggregate helpers (cost / latency / tokens)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Aggregate:
    """Summary statistics (mean / median / 95th percentile) over ``n`` values."""

    mean: float
    median: float
    p95: float
    n: int

    def to_dict(self) -> dict[str, float | int]:
        """Return the four fields as a plain dict."""

        return dict(mean=self.mean, median=self.median, p95=self.p95, n=self.n)


def paired_aggregate(values: Sequence[float]) -> Aggregate:
    """Mean / median / p95 of a list of numbers (e.g. cost-per-question).

    An empty sequence yields an all-zero ``Aggregate`` with ``n == 0``.
    """

    if not values:
        return Aggregate(0.0, 0.0, 0.0, 0)
    samples = np.asarray(values, dtype=np.float64)
    mean_value = float(samples.mean())
    median_value = float(statistics.median(values))
    p95_value = float(np.quantile(samples, 0.95))
    return Aggregate(mean=mean_value, median=median_value, p95=p95_value, n=len(values))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Aggregate",
|
||||
"BootstrapDelta",
|
||||
"McnemarResult",
|
||||
"bootstrap_delta_ci",
|
||||
"mcnemar_test",
|
||||
"paired_aggregate",
|
||||
]
|
||||
130
surfsense_evals/src/surfsense_evals/core/metrics/mc_accuracy.py
Normal file
130
surfsense_evals/src/surfsense_evals/core/metrics/mc_accuracy.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
"""Multiple-choice accuracy + Wilson 95% confidence intervals.
|
||||
|
||||
Wilson CI is preferred over normal-approximation because MIRAGE's
|
||||
per-task subsets can be small (PubMedQA* and BioASQ-Y/N have a few
|
||||
hundred questions each) and Wilson handles n→0 / p→{0,1} edges
|
||||
gracefully.
|
||||
|
||||
Reference for the closed form: Wilson (1927); identical to the
|
||||
``statsmodels.stats.proportion.proportion_confint(method='wilson')``
|
||||
output and what scikit-learn implements internally for its bounded
|
||||
estimators.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AccuracyResult:
    """Accuracy for one task together with its Wilson confidence bounds."""

    n_correct: int
    n_total: int
    accuracy: float
    ci_low: float
    ci_high: float

    def to_dict(self) -> dict[str, float | int]:
        """Return the five fields as a plain dict."""

        return dict(
            n_correct=self.n_correct,
            n_total=self.n_total,
            accuracy=self.accuracy,
            ci_low=self.ci_low,
            ci_high=self.ci_high,
        )
|
||||
|
||||
|
||||
# Two-sided Wilson z values. 1.959964 ≈ z_{0.975}.
_Z_FOR_LEVEL: dict[float, float] = {
    0.90: 1.6448536269514722,
    0.95: 1.959963984540054,
    0.99: 2.5758293035489004,
}


def wilson_ci(
    n_correct: int, n_total: int, *, level: float = 0.95
) -> tuple[float, float]:
    """Two-sided Wilson score confidence interval for a proportion.

    Returns ``(low, high)`` clamped to ``[0, 1]``. ``n_total <= 0``
    returns ``(0.0, 1.0)``, the maximally uncertain interval, without
    looking at ``level``.

    Raises:
        ValueError: ``level`` is not one of the supported levels
            (0.90, 0.95, 0.99).
    """

    if n_total <= 0:
        return 0.0, 1.0
    if level not in _Z_FOR_LEVEL:
        raise ValueError(f"Unsupported confidence level {level!r}")

    z = _Z_FOR_LEVEL[level]
    n = n_total
    p = n_correct / n_total
    denom = 1.0 + (z * z) / n
    centre = (p + (z * z) / (2 * n)) / denom
    spread = math.sqrt((p * (1 - p) / n) + (z * z) / (4 * n * n))
    half = (z / denom) * spread
    return max(0.0, centre - half), min(1.0, centre + half)
|
||||
|
||||
|
||||
def accuracy_with_wilson_ci(
    n_correct: int, n_total: int, *, level: float = 0.95
) -> AccuracyResult:
    """Bundle raw counts, accuracy and the Wilson interval into an ``AccuracyResult``.

    Accuracy is reported as 0.0 when ``n_total`` is zero.

    Raises:
        ValueError: ``n_total`` is negative, or ``n_correct`` lies outside
            ``[0, n_total]``.
    """

    if n_total < 0:
        raise ValueError(f"n_total must be >= 0, got {n_total}")
    if not 0 <= n_correct <= n_total:
        raise ValueError(
            f"n_correct must be in [0, n_total]; got n_correct={n_correct}, n_total={n_total}"
        )
    if n_total > 0:
        accuracy = n_correct / n_total
    else:
        accuracy = 0.0
    bounds = wilson_ci(n_correct, n_total, level=level)
    return AccuracyResult(
        n_correct=n_correct,
        n_total=n_total,
        accuracy=accuracy,
        ci_low=bounds[0],
        ci_high=bounds[1],
    )
|
||||
|
||||
|
||||
def per_task_accuracy(
    rows: Sequence[Mapping[str, object]],
    *,
    task_key: str = "task",
    correct_key: str = "is_correct",
    level: float = 0.95,
) -> dict[str, AccuracyResult]:
    """Group ``rows`` by ``task_key`` and compute per-task ``AccuracyResult``.

    A row counts as correct iff ``row[correct_key]`` is truthy. Rows
    without ``task_key`` are grouped under the empty string; task values
    are stringified. Tasks appear in first-seen order.
    """

    totals: dict[str, int] = {}
    hits: dict[str, int] = {}
    for row in rows:
        task = str(row.get(task_key, ""))
        totals[task] = totals.get(task, 0) + 1
        hits[task] = hits.get(task, 0) + (1 if row.get(correct_key) else 0)
    return {
        task: accuracy_with_wilson_ci(hits[task], total, level=level)
        for task, total in totals.items()
    }
|
||||
|
||||
|
||||
def macro_accuracy(per_task: Mapping[str, AccuracyResult]) -> float:
    """Unweighted mean of the per-task accuracies; 0.0 when there are no tasks."""

    task_count = len(per_task)
    if task_count == 0:
        return 0.0
    accuracy_sum = sum(result.accuracy for result in per_task.values())
    return accuracy_sum / task_count
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AccuracyResult",
|
||||
"accuracy_with_wilson_ci",
|
||||
"macro_accuracy",
|
||||
"per_task_accuracy",
|
||||
"wilson_ci",
|
||||
]
|
||||
132
surfsense_evals/src/surfsense_evals/core/metrics/retrieval.py
Normal file
132
surfsense_evals/src/surfsense_evals/core/metrics/retrieval.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""Retrieval metrics: Recall@k, MRR, nDCG@k.
|
||||
|
||||
Used by CUREv1's runner to score the SurfSense arm against the
|
||||
benchmark's qrels. ``corpus_id`` is the canonical CUREv1 passage id
|
||||
(string); the runner maps SurfSense ``chunk_id`` → ``document_id`` →
|
||||
``corpus_id`` before calling these.
|
||||
|
||||
Graded relevance (CUREv1 uses 0/1/2 grades) is honoured by ``ndcg_at_k``;
|
||||
``recall_at_k`` and ``mrr`` flatten anything > 0 to "relevant".
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RetrievalScores:
    """Retrieval scores aggregated over ``n_queries`` queries.

    ``recall_at_k`` maps each cutoff ``k`` to its recall value.
    """

    recall_at_k: dict[int, float]
    mrr: float
    ndcg_at_10: float
    n_queries: int

    def to_dict(self) -> dict:
        """Return the scores as a plain dict; ``recall_at_k`` is shallow-copied."""

        return dict(
            recall_at_k=dict(self.recall_at_k),
            mrr=self.mrr,
            ndcg_at_10=self.ndcg_at_10,
            n_queries=self.n_queries,
        )
|
||||
|
||||
|
||||
def recall_at_k(retrieved: Sequence[str], relevant: Iterable[str], k: int) -> float:
    """Fraction of distinct ``relevant`` documents found in ``retrieved[:k]``.

    Args:
        retrieved: Ranked document ids, best first. May contain repeats.
        relevant: Ids of the relevant documents; duplicates are collapsed.
        k: Rank cutoff. Values ``<= 0`` look at nothing and score ``0.0``.

    Returns:
        A value in ``[0.0, 1.0]``; ``0.0`` when there are no relevant
        documents.
    """

    relevant_set = set(relevant) if relevant is not None else set()
    if not relevant_set or k <= 0:
        return 0.0
    # Intersect instead of counting hits: a relevant id that appears
    # several times in the top-k must only be credited once, otherwise
    # recall can exceed 1.0.
    found = relevant_set.intersection(list(retrieved)[:k])
    return len(found) / len(relevant_set)
|
||||
|
||||
|
||||
def mrr(retrieved: Sequence[str], relevant: Iterable[str]) -> float:
    """Return ``1 / rank`` of the first relevant hit (ranks start at 1).

    Yields ``0.0`` when no entry of ``retrieved`` is in ``relevant``.
    """

    targets = set(relevant)
    reciprocal_ranks = (
        1.0 / position
        for position, doc_id in enumerate(retrieved, start=1)
        if doc_id in targets
    )
    # The generator is lazy, so only the first hit is ever computed.
    return next(reciprocal_ranks, 0.0)
|
||||
|
||||
|
||||
def _dcg_at_k(grades: Sequence[float], k: int) -> float:
    """Discounted cumulative gain over the first ``k`` entries of ``grades``."""
    total = 0.0
    for rank, grade in enumerate(grades[:k], start=1):
        # Standard log-base-2 discount; gain = 2^grade - 1 for graded relevance.
        total += (2.0 ** grade - 1.0) / math.log2(rank + 1)
    return total


def ndcg_at_k(
    retrieved: Sequence[str],
    qrels: Mapping[str, float],
    k: int,
) -> float:
    """nDCG@k against graded ``qrels`` (``{doc_id: grade}``).

    Unjudged documents in ``retrieved`` contribute zero gain. A document
    id that appears more than once in ``retrieved`` earns its gain only
    at its first (best) rank; later repeats still occupy a rank but
    score zero, so the result cannot exceed ``1.0``. The ideal ordering
    is ``qrels`` sorted by grade descending.

    Returns ``0.0`` when ``qrels`` is empty, when ``k <= 0``, or when no
    judged document has positive gain.
    """

    if not qrels or k <= 0:
        return 0.0

    seen: set[str] = set()
    grades: list[float] = []
    for doc in retrieved[:k]:
        if doc in seen:
            # Repeat of an already-credited document: no additional gain.
            grades.append(0.0)
            continue
        seen.add(doc)
        grades.append(float(qrels.get(doc, 0.0)))

    dcg = _dcg_at_k(grades, k)
    ideal = sorted((float(g) for g in qrels.values()), reverse=True)
    idcg = _dcg_at_k(ideal, k)
    if idcg == 0.0:
        return 0.0
    return dcg / idcg
|
||||
|
||||
|
||||
def score_run(
    *,
    per_query_retrieved: Mapping[str, Sequence[str]],
    per_query_qrels: Mapping[str, Mapping[str, float]],
    ks: Sequence[int] = (1, 5, 10, 32),
    ndcg_k: int = 10,
) -> RetrievalScores:
    """Aggregate Recall@k, MRR, nDCG@k across a run.

    ``per_query_retrieved`` maps ``query_id -> ordered list of doc ids``.
    ``per_query_qrels`` maps ``query_id -> {doc_id: grade}`` (grade > 0
    is relevant).

    Only queries present in BOTH mappings are scored: a query without
    qrels is skipped, and so is a query that has qrels but no retrieved
    list. ``n_queries`` reports how many queries were actually scored,
    and every metric is the mean over exactly those queries. With no
    overlap, all metrics are ``0.0`` and ``n_queries`` is ``0``.

    The nDCG mean is stored in the ``ndcg_at_10`` field even when
    ``ndcg_k`` is not 10.
    """

    # Walk the qrels mapping in its own order rather than a set
    # intersection: set iteration order over str keys varies between
    # processes, which would make the float totals (and therefore the
    # reported means) differ in their last bits from run to run.
    qids = [qid for qid in per_query_qrels if qid in per_query_retrieved]
    if not qids:
        return RetrievalScores(recall_at_k={k: 0.0 for k in ks}, mrr=0.0, ndcg_at_10=0.0, n_queries=0)

    recall_totals = {k: 0.0 for k in ks}
    mrr_total = 0.0
    ndcg_total = 0.0
    for qid in qids:
        retrieved = list(per_query_retrieved[qid])
        qrels = per_query_qrels[qid]
        relevant_docs = [d for d, g in qrels.items() if g > 0]
        # Iterate the dict keys, not ``ks``: a cutoff listed twice in
        # ``ks`` must not be accumulated twice.
        for k in recall_totals:
            recall_totals[k] += recall_at_k(retrieved, relevant_docs, k)
        mrr_total += mrr(retrieved, relevant_docs)
        ndcg_total += ndcg_at_k(retrieved, qrels, ndcg_k)

    n = len(qids)
    return RetrievalScores(
        recall_at_k={k: v / n for k, v in recall_totals.items()},
        mrr=mrr_total / n,
        ndcg_at_10=ndcg_total / n,
        n_queries=n,
    )
|
||||
|
||||
|
||||
__all__ = ["RetrievalScores", "mrr", "ndcg_at_k", "recall_at_k", "score_run"]
|
||||
21
surfsense_evals/src/surfsense_evals/core/parse/__init__.py
Normal file
21
surfsense_evals/src/surfsense_evals/core/parse/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
"""Parsers shared across suites: citations, MCQ envelopes, AI-SDK SSE."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .answer_letter import AnswerLetterResult, extract_answer_letter
|
||||
from .citations import CITATION_REGEX, CitationToken, ChunkCitation, UrlCitation, parse_citations
|
||||
from .freeform_answer import extract_freeform_answer
|
||||
from .sse import SseEvent, iter_sse_events
|
||||
|
||||
__all__ = [
|
||||
"CITATION_REGEX",
|
||||
"CitationToken",
|
||||
"ChunkCitation",
|
||||
"UrlCitation",
|
||||
"parse_citations",
|
||||
"AnswerLetterResult",
|
||||
"extract_answer_letter",
|
||||
"extract_freeform_answer",
|
||||
"SseEvent",
|
||||
"iter_sse_events",
|
||||
]
|
||||
122
surfsense_evals/src/surfsense_evals/core/parse/answer_letter.py
Normal file
122
surfsense_evals/src/surfsense_evals/core/parse/answer_letter.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Robust extractor for MCQ answer letters.
|
||||
|
||||
Handles three answer shapes seen in the wild:
|
||||
|
||||
1. **MedRAG envelope** — ``{"step_by_step_thinking": "...", "answer_choice": "A"}``
|
||||
embedded somewhere in the assistant message (often inside ```` ```json ```` /
|
||||
``` ``` ``` fences). The regex grabs the JSON object and reads the
|
||||
``answer_choice`` field.
|
||||
|
||||
2. **Final-line letter** — e.g. ``Answer: B`` or ``The correct answer is (C).``.
|
||||
Falls back to a permissive regex over the last few lines.
|
||||
|
||||
3. **Bare letter** — single uppercase letter at the end of the message.
|
||||
|
||||
The function returns the parsed letter (uppercased) plus a discriminator
|
||||
of which strategy fired so the runner / report can flag suspicious
|
||||
parses (typically zero-confidence parses indicate the model didn't
|
||||
follow the prompt).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
ParserStrategy = Literal["json_envelope", "answer_line", "bare_letter", "none"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AnswerLetterResult:
    """Outcome of one answer-letter extraction attempt."""

    # The extracted letter, or ``None`` when nothing was found.
    letter: str | None
    # Which parsing strategy produced ``letter``.
    strategy: ParserStrategy

    @property
    def found(self) -> bool:
        """``True`` when a letter was extracted."""
        return not (self.letter is None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Unfenced ``{... "answer_choice": "X" ...}`` object with no nested braces.
_JSON_BLOCK = re.compile(r"\{[^{}]*\"answer_choice\"\s*:\s*\"([A-Za-z])\"[^{}]*\}", re.DOTALL)
# A fenced block (optionally tagged ``json``) wrapping a ``{...}`` object.
_FENCED_JSON = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL | re.IGNORECASE)
# "Answer: B" / "The correct answer is (C)." with the letter ending the line.
_ANSWER_LINE = re.compile(
    r"(?:final\s*answer|answer\s*choice|the\s+correct\s+answer\s+is|answer)\s*[:=\-]?\s*"
    r"\(?\s*([A-Za-z])\s*[\)\.]*\s*$",
    re.IGNORECASE | re.MULTILINE,
)
# A line consisting of a single letter, optionally parenthesised / dotted.
_BARE_LETTER = re.compile(r"^\s*\(?\s*([A-Za-z])\s*[\)\.]*\s*$", re.MULTILINE)


def _from_json_envelope(text: str) -> str | None:
    """Read the answer letter from a JSON ``answer_choice`` field.

    Fenced code blocks are tried first and parsed with ``json.loads``.
    If none yields a usable letter, a tolerant regex scans the whole
    text for an unfenced object.

    Returns the uppercased letter, or ``None`` when no envelope with a
    letter-valued ``answer_choice`` is found.
    """

    # Try fenced code blocks first (most likely to contain the JSON).
    for fence in _FENCED_JSON.finditer(text):
        try:
            obj = json.loads(fence.group(1))
        except (json.JSONDecodeError, ValueError):
            continue
        if not isinstance(obj, dict):
            continue
        choice = obj.get("answer_choice")
        if not isinstance(choice, str):
            continue
        # Skip wrappers such as "(B)" or "**B**" before taking the first
        # character, and only accept an ASCII letter: a value like "7"
        # must fall through to the other strategies rather than be
        # reported as the answer.
        letter = choice.strip().lstrip("([*").lstrip()[:1]
        if letter.isascii() and letter.isalpha():
            return letter.upper()

    # Fall back to a tolerant regex over the whole text (handles
    # responses that drop the fences).
    match = _JSON_BLOCK.search(text)
    if match:
        return match.group(1).upper()
    return None
|
||||
|
||||
|
||||
def _from_answer_line(text: str) -> str | None:
    """Return the letter from the last ``Answer: X``-style line, if any."""
    # ``findall`` returns the captured letters in document order; scan
    # them from the end because the final statement is the answer.
    captured_letters = _ANSWER_LINE.findall(text)
    for raw_letter in captured_letters[::-1]:
        candidate = raw_letter.upper()
        if candidate.isalpha():
            return candidate
    return None
|
||||
|
||||
|
||||
def _from_bare_letter(text: str) -> str | None:
    """Return a letter standing alone on one of the last three non-empty lines."""
    # Restricting the search to the tail avoids grabbing in-prose
    # mentions of "A" or "I" earlier in the message.
    stripped_lines = (line.strip() for line in text.splitlines())
    tail = [line for line in stripped_lines if line][-3:]
    for line in tail[::-1]:
        hit = _BARE_LETTER.match(line)
        if hit is not None:
            return hit.group(1).upper()
    return None
|
||||
|
||||
|
||||
def extract_answer_letter(text: str) -> AnswerLetterResult:
    """Try each parsing strategy in turn and report the first success.

    Strategies run in the order JSON envelope, final-answer line, bare
    letter. The result records which one matched. Empty or
    whitespace-only text, and text no strategy can parse, both yield
    ``AnswerLetterResult(None, "none")``.
    """

    if not text or not text.strip():
        return AnswerLetterResult(None, "none")

    strategies = (
        ("json_envelope", _from_json_envelope),
        ("answer_line", _from_answer_line),
        ("bare_letter", _from_bare_letter),
    )
    for strategy_name, parse in strategies:
        letter = parse(text)
        if letter:
            return AnswerLetterResult(letter, strategy_name)

    return AnswerLetterResult(None, "none")
|
||||
|
||||
|
||||
__all__ = ["AnswerLetterResult", "ParserStrategy", "extract_answer_letter"]
|
||||
110
surfsense_evals/src/surfsense_evals/core/parse/citations.py
Normal file
110
surfsense_evals/src/surfsense_evals/core/parse/citations.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
"""Python port of the canonical citation parser.
|
||||
|
||||
Source of truth: ``surfsense_web/lib/citations/citation-parser.ts:20-21``.
|
||||
The pattern is byte-for-byte identical to the TS export ``CITATION_REGEX``
|
||||
so a SurfSense user reading the web client and a CUREv1 retrieval scorer
|
||||
running here see the same chunk_ids extracted from the same answer.
|
||||
|
||||
The TS reference also handles a ``urlcite{N}`` placeholder produced by
|
||||
``preprocessCitationMarkdown`` — that pre-processing step is web-only
|
||||
(GFM autolink workaround), so the harness sees raw ``[citation:URL]``
|
||||
tokens and ``parse_citations`` returns them as ``UrlCitation`` directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union
|
||||
|
||||
# Pattern preserves the TS source verbatim:
|
||||
# /[\[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*)\s*\u200B?[\]】]/g
|
||||
#
|
||||
# Notes:
|
||||
# * Matches both ASCII ``[]`` and Chinese fullwidth ``【】`` brackets.
|
||||
# * Allows an optional ZWSP (``\u200B``) just inside each bracket.
|
||||
# * ``citation:`` then EITHER a URL (anything not ``]``, ``】``, or ZWSP),
|
||||
# OR a ``urlcite\d+`` placeholder, OR one or more comma-separated
|
||||
# chunk ids (each optionally prefixed with ``doc-`` and optionally
|
||||
# negative).
|
||||
# * URL char class deliberately excludes the closing brackets so a
|
||||
# ``[citation:https://x.com]`` doesn't swallow the ``]``.
|
||||
# The ZWSP is spliced into the pattern as the actual code-point via an
# f-string rather than written as a ``\u200B`` escape. Python's ``re``
# would also accept the escape in a ``str`` pattern, but embedding the
# literal character keeps the pattern functionally identical to the TS
# reference and lets ``"\u200B" in CITATION_REGEX.pattern`` succeed.
|
||||
_ZWSP = "\u200B"

# The pattern is assembled from named fragments; the concatenation is
# character-for-character the same pattern string as a single literal.
_OPEN_BRACKET = r"[\[【]"
_CLOSE_BRACKET = r"[\]】]"
_OPTIONAL_ZWSP = f"{_ZWSP}?"
# A URL runs until a closing bracket or a ZWSP.
_URL_TARGET = rf"https?://[^\]】{_ZWSP}]+"
_URLCITE_TARGET = r"urlcite\d+"
# One chunk id: optional ``doc-`` prefix, optional minus sign, digits.
_CHUNK_ID = r"(?:doc-)?-?\d+"
_CHUNK_ID_LIST = rf"{_CHUNK_ID}(?:\s*,\s*{_CHUNK_ID})*"

CITATION_REGEX = re.compile(
    _OPEN_BRACKET
    + _OPTIONAL_ZWSP
    + r"citation:\s*("
    + "|".join((_URL_TARGET, _URLCITE_TARGET, _CHUNK_ID_LIST))
    + r")\s*"
    + _OPTIONAL_ZWSP
    + _CLOSE_BRACKET
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ChunkCitation:
    """A citation that points at a chunk by numeric id."""

    chunk_id: int
    # True when the id was written with the ``doc-`` prefix.
    is_docs_chunk: bool

    def to_dict(self) -> dict[str, Any]:
        """Return a plain ``dict`` tagged with ``kind="chunk"``."""
        return dict(
            kind="chunk",
            chunk_id=self.chunk_id,
            is_docs_chunk=self.is_docs_chunk,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class UrlCitation:
    """A citation that points at a URL."""

    url: str

    def to_dict(self) -> dict[str, Any]:
        """Return a plain ``dict`` tagged with ``kind="url"``."""
        return dict(kind="url", url=self.url)
|
||||
|
||||
|
||||
CitationToken = Union[ChunkCitation, UrlCitation]
|
||||
|
||||
|
||||
def parse_citations(text: str, *, url_map: dict[str, str] | None = None) -> list[CitationToken]:
    """Extract every citation token from ``text``, in order of appearance.

    * A URL payload becomes one ``UrlCitation``.
    * A ``urlciteN`` placeholder becomes a ``UrlCitation`` only when
      ``url_map`` contains it; otherwise it is dropped.
    * A chunk payload such as ``1, doc-2, -3`` becomes one
      ``ChunkCitation`` per id; a ``doc-`` prefix sets ``is_docs_chunk``.
      Pieces that do not parse as an integer are skipped.
    """

    tokens: list[CitationToken] = []
    for hit in CITATION_REGEX.finditer(text):
        payload = hit.group(1)

        if payload.startswith(("http://", "https://")):
            tokens.append(UrlCitation(url=payload.strip()))
        elif payload.startswith("urlcite"):
            # Unknown placeholders are dropped rather than mis-resolved.
            if url_map and payload in url_map:
                tokens.append(UrlCitation(url=url_map[payload]))
        else:
            for piece in payload.split(","):
                identifier = piece.strip()
                docs_prefixed = identifier.startswith("doc-")
                digits = identifier[4:] if docs_prefixed else identifier
                try:
                    numeric_id = int(digits)
                except ValueError:
                    continue
                tokens.append(ChunkCitation(chunk_id=numeric_id, is_docs_chunk=docs_prefixed))
    return tokens
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CITATION_REGEX",
|
||||
"ChunkCitation",
|
||||
"UrlCitation",
|
||||
"CitationToken",
|
||||
"parse_citations",
|
||||
]
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
"""Extract free-form answers from open-ended LLM responses.
|
||||
|
||||
Used by benchmarks that don't have a fixed letter set (MMLongBench-Doc,
|
||||
DocVQA-style benchmarks, future legal/finance suites). The contract:
|
||||
|
||||
* Strip leading "Answer:" / "Final answer:" markers if present.
|
||||
* Drop fenced code blocks if the model wrapped its answer in one.
|
||||
* Trim leading/trailing whitespace.
|
||||
* Return the *last* meaningful chunk — models often think out loud
|
||||
before stating the answer.
|
||||
|
||||
If the message is empty or only contains a fence, return ``""``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
# Answer marker anchored at the start of a line.
_ANSWER_PREFIX = re.compile(
    r"^\s*(?:final\s*answer|the\s+answer\s+is|answer)\s*[:=\-]\s*",
    re.IGNORECASE,
)
# Marker-only regex (no capture group) used to find every "Answer:"
# token position. We then slice from the LAST marker's end to the
# next newline ourselves — robust to multiple inline answers because
# we never let the engine greedy-capture across markers.
_ANSWER_MARKER = re.compile(
    r"(?:final\s*answer|the\s+answer\s+is|answer)\s*[:=\-]\s*",
    re.IGNORECASE,
)
_FENCED_BLOCK = re.compile(r"```[a-zA-Z0-9]*\s*([\s\S]*?)\s*```")


def _strip_emphasis(fragment: str) -> str:
    """Trim whitespace plus wrapping backticks and Markdown ``*`` emphasis."""
    return fragment.strip().strip("`*").strip()


def extract_freeform_answer(text: str) -> str:
    """Pull the model's final answer out of a possibly-verbose response.

    Candidates are tried in this order, stopping at the first non-empty one:

    1. The text after the marker on the last line that *starts* with an
       ``Answer:`` / ``Final answer:`` / ``The answer is:`` marker.
    2. The text between the last such marker anywhere in the message
       and the next newline.
    3. The contents of the last fenced code block, or, when there is no
       fence, the last non-empty line.

    Wrapping backticks, Markdown ``*`` emphasis and one pair of
    surrounding quotes are removed, so ``**Answer:** 42`` and
    ``Answer: **42**`` both yield ``42``. Empty or whitespace-only
    input returns ``""``.
    """

    if not text or not text.strip():
        return ""

    lines = [ln.rstrip() for ln in text.strip().splitlines()]
    candidate = ""

    # 1. Last line that starts with an answer marker.
    for ln in reversed(lines):
        if not ln.strip():
            continue
        if _ANSWER_PREFIX.search(ln):
            candidate = _strip_emphasis(_ANSWER_PREFIX.sub("", ln, count=1))
            break

    if not candidate:
        # 2. Inline match: slice from the LAST marker's end to the next
        # newline. Handles "preamble.Answer: 42" one-liners, markers
        # wrapped in Markdown bold, and multiple inline markers (the
        # final one wins).
        marker_matches = list(_ANSWER_MARKER.finditer(text))
        if marker_matches:
            tail = text[marker_matches[-1].end():]
            nl = tail.find("\n")
            if nl >= 0:
                tail = tail[:nl]
            # Cleaning here (not only at the end) lets a tail that is
            # nothing but emphasis markers fall through to step 3.
            candidate = _strip_emphasis(tail)

    if not candidate:
        # 3. No usable marker — try fenced blocks, then the last
        # non-empty line.
        fences = _FENCED_BLOCK.findall(text)
        if fences:
            candidate = fences[-1].strip()
        else:
            for ln in reversed(lines):
                if ln.strip():
                    candidate = ln.strip()
                    break

    # 4. Strip wrappers that confuse the grader without changing meaning.
    candidate = _strip_emphasis(candidate)
    if candidate.startswith(("\"", "'")) and candidate.endswith(("\"", "'")):
        candidate = candidate[1:-1].strip()
    return candidate
|
||||
|
||||
|
||||
__all__ = ["extract_freeform_answer"]
|
||||
72
surfsense_evals/src/surfsense_evals/core/parse/sse.py
Normal file
72
surfsense_evals/src/surfsense_evals/core/parse/sse.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""Minimal SSE consumer compatible with SurfSense's wire format.
|
||||
|
||||
SurfSense uses ``app/services/streaming/envelope/sse.py`` to frame events:
|
||||
|
||||
* ``data: <single-line-string>\\n\\n``
|
||||
* ``data: <json-string>\\n\\n`` (most events)
|
||||
* ``data: [DONE]\\n\\n`` (terminator)
|
||||
|
||||
There is no ``event:``, ``id:``, or ``retry:`` framing in production —
|
||||
``format_sse(payload)`` only emits the ``data:`` line. This implementation
|
||||
is therefore intentionally smaller than ``httpx-sse`` (which we still
|
||||
list as a dep so callers who want richer parsing can opt in): one event
|
||||
per ``data:`` line, separated by blank lines.
|
||||
|
||||
We accept any line iterator (an ``httpx.Response.aiter_lines`` adapter
|
||||
in production, a list in tests) so this is unit-testable without a
|
||||
network mock.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SseEvent:
    """A parsed SSE event. Only the ``data`` field is populated.

    Multi-line payloads (``data: a\\ndata: b``) are joined with ``\\n``
    per the SSE spec.
    """

    data: str


async def iter_sse_events(lines: AsyncIterator[str]) -> AsyncIterator[SseEvent]:
    """Yield one ``SseEvent`` per blank-line-terminated frame.

    Each incoming line has any trailing ``\\r`` / ``\\n`` removed first,
    so both pre-split lines and newline-terminated lines are accepted.
    An empty line then flushes the buffered ``data:`` payloads as one
    event. Lines starting with ``:`` are comments and are skipped, as
    is every field other than ``data:``. ``None`` items are ignored.
    A frame still buffered when the iterator ends is emitted as a final
    event.
    """

    buffer: list[str] = []
    async for raw in lines:
        if raw is None:
            continue
        # Strip the terminator, not just "\r": if the source yields
        # "\n"-terminated lines, the separator would otherwise never
        # compare equal to "" and every frame would merge into one event.
        line = raw.rstrip("\r\n")
        if line == "":
            if buffer:
                yield SseEvent(data="\n".join(buffer))
                buffer = []
            continue
        if line.startswith(":"):
            # comment / heartbeat
            continue
        if line.startswith("data:"):
            # spec: optional single space after the colon.
            payload = line[5:]
            if payload.startswith(" "):
                payload = payload[1:]
            buffer.append(payload)
            continue
        # Any other field (event:, id:, retry:) is currently unused.

    if buffer:
        yield SseEvent(data="\n".join(buffer))
|
||||
|
||||
|
||||
__all__ = ["SseEvent", "iter_sse_events"]
|
||||
31
surfsense_evals/src/surfsense_evals/core/pdf/__init__.py
Normal file
31
surfsense_evals/src/surfsense_evals/core/pdf/__init__.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""Domain-agnostic PDF rendering helper. Lazy import."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .render import (
|
||||
PdfImage,
|
||||
render_pdf,
|
||||
render_pdf_with_images,
|
||||
render_text_files_to_pdf,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PdfImage",
|
||||
"render_pdf",
|
||||
"render_pdf_with_images",
|
||||
"render_text_files_to_pdf",
|
||||
]
|
||||
|
||||
|
||||
# Names that are resolved from ``.render`` on first attribute access.
_LAZY = {"PdfImage", "render_pdf", "render_pdf_with_images", "render_text_files_to_pdf"}


def __getattr__(name: str):
    """Resolve the lazily exported names from the ``render`` submodule.

    Any other name raises ``AttributeError``, matching normal module
    attribute lookup.
    """
    if name not in _LAZY:
        raise AttributeError(f"module 'surfsense_evals.core.pdf' has no attribute {name!r}")

    # Imported here, not at module top, so ``render`` is only loaded
    # once one of its names is actually requested.
    from . import render as _mod

    return getattr(_mod, name)
|
||||
351
surfsense_evals/src/surfsense_evals/core/pdf/render.py
Normal file
351
surfsense_evals/src/surfsense_evals/core/pdf/render.py
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
"""Deterministic ``.txt`` / ``.md`` → single PDF via reportlab.
|
||||
|
||||
Used wherever a benchmark needs the same source bytes fed to both the
|
||||
native-PDF arm and the SurfSense ingestion arm. The head-to-head
|
||||
comparison is fair only if the *same* PDF is the input to both arms,
|
||||
which is why we go to lengths to make the rendering deterministic.
|
||||
|
||||
Determinism notes:
|
||||
|
||||
* We pin the PDF metadata to a fixed creation date and producer
|
||||
(``reportlab`` accepts neither directly, but ``Canvas.setAuthor`` and
|
||||
the absence of an ``info`` mutator means the bytes only differ by
|
||||
``CreationDate`` / ``ModDate``). We post-process the PDF to scrub
|
||||
those if ``deterministic=True`` is passed.
|
||||
* Page size, font, margins, and tab handling are fixed in code so the
|
||||
same input yields the same byte output across machines.
|
||||
* PDF/A is overkill for our use; basic PDF 1.4 is what every model
|
||||
expects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import re
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from reportlab.lib.pagesizes import LETTER
|
||||
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.lib.utils import ImageReader
|
||||
from reportlab.platypus import (
|
||||
Image,
|
||||
KeepTogether,
|
||||
PageBreak,
|
||||
Paragraph,
|
||||
SimpleDocTemplate,
|
||||
Spacer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class RenderedPdf:
    """Summary of one rendered PDF."""

    # Where the PDF bytes were written.
    path: Path
    # Heuristic page count derived from the text length, not measured.
    n_pages_estimate: int
    # Total characters of body-paragraph text placed in the document.
    n_chars: int
|
||||
|
||||
|
||||
_PDF_DATE_KEY = re.compile(rb"/(?:CreationDate|ModDate)\s*\(D:[^)]*\)")
# reportlab also writes a `/ID [<hex1><hex2>]` trailer entry that
# embeds a per-run hash. Scrub it so two renders of the same input
# produce the same bytes.
_PDF_ID_ARRAY = re.compile(rb"/ID\s*\[\s*<[^>]*>\s*<[^>]*>\s*\]")
_HEX_DIGIT = re.compile(rb"[0-9A-Fa-f]")


def _pin_date_entry(match: re.Match[bytes]) -> bytes:
    """Rewrite one date entry to the epoch, keeping its key and byte length."""
    original = match.group(0)
    key = b"/ModDate" if original.startswith(b"/ModDate") else b"/CreationDate"
    # Prefer the fullest spelling that still fits in the original span.
    spellings = (
        key + b" (D:19700101000000Z)",
        key + b" (D:19700101000000)",
        key + b"(D:19700101000000)",
    )
    for pinned in spellings:
        if len(pinned) <= len(original):
            # Pad after the closing paren: whitespace between
            # dictionary entries is insignificant in PDF syntax.
            return pinned.ljust(len(original))
    # The original date was truncated; its length cannot be preserved.
    return spellings[0]


def _zero_id_array(match: re.Match[bytes]) -> bytes:
    """Zero every hex digit inside the ``/ID`` array, keeping its length."""
    head, bracket, tail = match.group(0).partition(b"[")
    # Only touch what follows "[": the "D" of "/ID" is itself a hex digit.
    return head + bracket + _HEX_DIGIT.sub(b"0", tail)


def _scrub_dates(pdf_bytes: bytes) -> bytes:
    """Pin ``CreationDate`` / ``ModDate`` / trailer ``/ID`` so the file is
    byte-deterministic across runs.

    Each entry keeps its own key (``/ModDate`` stays ``/ModDate``) and is
    replaced by a value of the same byte length. Keeping the length
    stable matters because a PDF's cross-reference table stores absolute
    byte offsets; shrinking an entry would shift everything after it
    and leave those offsets stale.
    """

    pdf_bytes = _PDF_DATE_KEY.sub(_pin_date_entry, pdf_bytes)
    pdf_bytes = _PDF_ID_ARRAY.sub(_zero_id_array, pdf_bytes)
    return pdf_bytes
|
||||
|
||||
|
||||
# reportlab's sample stylesheet, used as the parent for the styles below.
_DEFAULT_STYLES = getSampleStyleSheet()


def _build_body_style() -> ParagraphStyle:
    """Return the body-text style: 10.5pt Helvetica on 14pt leading."""
    return ParagraphStyle(
        "EvalBody",
        parent=_DEFAULT_STYLES["BodyText"],
        fontName="Helvetica",
        fontSize=10.5,
        leading=14,
        spaceAfter=6,
        spaceBefore=0,
    )
|
||||
|
||||
|
||||
def _build_heading_style() -> ParagraphStyle:
    """Return the section-heading style: 14pt Helvetica-Bold on 18pt leading."""
    return ParagraphStyle(
        "EvalHeading",
        parent=_DEFAULT_STYLES["Heading2"],
        fontName="Helvetica-Bold",
        fontSize=14,
        leading=18,
        spaceAfter=10,
        spaceBefore=8,
    )
|
||||
|
||||
|
||||
def _normalise_paragraphs(text: str) -> list[str]:
    """Split ``text`` into paragraphs at blank lines.

    Trailing whitespace is removed from every line. Consecutive
    non-blank lines are joined with a single space, and any run of
    blank (or whitespace-only) lines acts as one separator. Leading
    whitespace on a line is kept.
    """

    paragraphs: list[str] = []
    pending: list[str] = []
    for raw_line in text.splitlines():
        content = raw_line.rstrip()
        if content:
            pending.append(content)
        elif pending:
            # A blank line closes the paragraph collected so far.
            paragraphs.append(" ".join(pending))
            pending = []
    if pending:
        paragraphs.append(" ".join(pending))
    return paragraphs
|
||||
|
||||
|
||||
def _escape_html(text: str) -> str:
    """Escape the markup-significant characters in ``text``.

    The ampersand, less-than and greater-than characters are replaced
    by their named character entities (ampersand first, so the entities
    introduced for the other two are not escaped again). Quotes are
    left untouched. Callers pass the result to ``Paragraph``, which
    parses its input as markup, so raw source text must not reach it
    unescaped.
    """
    # Function-scope import: the module does not import ``html`` at
    # top level.
    import html

    return html.escape(text, quote=False)
|
||||
|
||||
|
||||
def render_pdf(
    *,
    title: str,
    sections: Sequence[tuple[str | None, str]],
    output_path: Path,
    deterministic: bool = True,
) -> RenderedPdf:
    """Render ``(section_heading, section_text)`` tuples into one PDF file.

    The document opens with ``title``. A falsy ``section_heading``
    produces a section without a heading paragraph. Every section after
    the first starts on a new page. Parent directories of
    ``output_path`` are created as needed. With ``deterministic=True``
    the bytes are passed through ``_scrub_dates`` before being written.

    Returns a ``RenderedPdf`` carrying the output path, the body
    character count, and a heuristic page estimate.
    """

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    sink = io.BytesIO()
    margin = 0.75 * inch
    template = SimpleDocTemplate(
        sink,
        pagesize=LETTER,
        leftMargin=margin,
        rightMargin=margin,
        topMargin=margin,
        bottomMargin=margin,
        title=title,
        author="surfsense-evals",
        subject="Eval input",
        creator="surfsense-evals",
    )

    body_style = _build_body_style()
    heading_style = _build_heading_style()
    title_style = ParagraphStyle(
        "EvalTitle",
        parent=_DEFAULT_STYLES["Title"],
        fontName="Helvetica-Bold",
        fontSize=18,
        leading=22,
        spaceAfter=14,
    )

    story: list = [Paragraph(_escape_html(title), title_style)]
    char_count = 0
    for position, (heading, text) in enumerate(sections):
        if position:
            story.append(PageBreak())
        if heading:
            story.append(Paragraph(_escape_html(heading), heading_style))
        for paragraph in _normalise_paragraphs(text):
            char_count += len(paragraph)
            story.extend((Paragraph(_escape_html(paragraph), body_style), Spacer(1, 4)))

    template.build(story)
    payload = sink.getvalue()
    destination.write_bytes(_scrub_dates(payload) if deterministic else payload)

    # Conservative page estimate: ~3000 chars per LETTER page at 10.5pt.
    page_estimate = max(1, char_count // 3000 + len(sections))
    return RenderedPdf(path=destination, n_pages_estimate=page_estimate, n_chars=char_count)
|
||||
|
||||
|
||||
@dataclass
class PdfImage:
    """One image to embed inside a section.

    ``caption`` is optional text shown with the image. ``max_width_in``
    is the upper bound, in inches, on the rendered width.
    """

    # Location of the image file.
    path: Path
    caption: str = ""
    max_width_in: float = 5.5  # default leaves margin for LETTER 8.5"
|
||||
|
||||
|
||||
def _make_image_flowable(image: PdfImage) -> Image:
    """Build a reportlab ``Image`` sized to fit within the width and height caps.

    The width starts at ``image.max_width_in`` inches and the height
    follows the source aspect ratio. If that height exceeds 7 inches,
    the height is pinned to 7 inches and the width is scaled down to
    match.

    Raises:
        ValueError: if the image reports a non-positive width or height.
    """

    source = str(image.path)
    pixel_w, pixel_h = ImageReader(source).getSize()
    if pixel_w <= 0 or pixel_h <= 0:
        raise ValueError(f"Invalid image dimensions for {image.path}: {pixel_w}x{pixel_h}")

    width = image.max_width_in * inch
    height = width * (pixel_h / pixel_w)

    # Cap height too — some medical images are extreme portrait.
    height_cap = 7.0 * inch
    if height > height_cap:
        height = height_cap
        width = height * (pixel_w / pixel_h)

    return Image(source, width=width, height=height)
|
||||
|
||||
|
||||
def render_pdf_with_images(
    *,
    title: str,
    sections: Sequence[tuple[str | None, str, Sequence[PdfImage] | None]],
    output_path: Path,
    deterministic: bool = True,
    page_break_between_sections: bool = False,
) -> RenderedPdf:
    """Render a PDF whose sections mix text and embedded images.

    Each section is ``(heading, body_text, images)``. Within a section
    the heading (if any) comes first, then the body paragraphs, then the
    images. Every image is kept together with its caption paragraph, or
    with a small spacer when it has no caption. An image that cannot be
    turned into a flowable is skipped rather than aborting the render.

    A page break is inserted before each section after the first only
    when ``page_break_between_sections`` is true. With
    ``deterministic=True`` the bytes are passed through
    ``_scrub_dates`` before being written to ``output_path``.
    """

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    sink = io.BytesIO()
    margin = 0.75 * inch
    template = SimpleDocTemplate(
        sink,
        pagesize=LETTER,
        leftMargin=margin,
        rightMargin=margin,
        topMargin=margin,
        bottomMargin=margin,
        title=title,
        author="surfsense-evals",
        subject="Eval input",
        creator="surfsense-evals",
    )

    body_style = _build_body_style()
    heading_style = _build_heading_style()
    caption_style = ParagraphStyle(
        "EvalCaption",
        parent=body_style,
        fontSize=9,
        leading=11,
        textColor="#444",
        spaceBefore=2,
        spaceAfter=10,
    )
    title_style = ParagraphStyle(
        "EvalTitle",
        parent=_DEFAULT_STYLES["Title"],
        fontName="Helvetica-Bold",
        fontSize=18,
        leading=22,
        spaceAfter=14,
    )

    story: list = [Paragraph(_escape_html(title), title_style)]
    char_count = 0
    for position, (heading, text, images) in enumerate(sections):
        if page_break_between_sections and position > 0:
            story.append(PageBreak())
        if heading:
            story.append(Paragraph(_escape_html(heading), heading_style))
        for paragraph in _normalise_paragraphs(text):
            char_count += len(paragraph)
            story.extend((Paragraph(_escape_html(paragraph), body_style), Spacer(1, 4)))
        for image in images or []:
            try:
                picture = _make_image_flowable(image)
            except Exception:  # noqa: BLE001 — bad image shouldn't kill PDF
                continue
            if image.caption:
                companion = Paragraph(_escape_html(image.caption), caption_style)
            else:
                companion = Spacer(1, 8)
            story.append(KeepTogether([picture, companion]))

    template.build(story)
    payload = sink.getvalue()
    destination.write_bytes(_scrub_dates(payload) if deterministic else payload)

    page_estimate = max(1, char_count // 3000 + len(sections))
    return RenderedPdf(path=destination, n_pages_estimate=page_estimate, n_chars=char_count)
|
||||
|
||||
|
||||
def render_text_files_to_pdf(
    *,
    title: str,
    files: Iterable[Path],
    output_path: Path,
    deterministic: bool = True,
) -> RenderedPdf:
    """Read text files and render them into one PDF via ``render_pdf``.

    Each file becomes one section, in the order given. The section
    heading is the file name without its final suffix (``Path.stem``),
    e.g. ``admission_note.txt`` gives ``admission_note``. Files are
    read as UTF-8.
    """

    sources = [Path(entry) for entry in files]
    sections: list[tuple[str | None, str]] = [
        (source.stem, source.read_text(encoding="utf-8")) for source in sources
    ]
    return render_pdf(
        title=title,
        sections=sections,
        output_path=output_path,
        deterministic=deterministic,
    )
|
||||
|
||||
|
||||
# Tiny self-check — handy when debugging.
|
||||
def _self_test() -> None:  # pragma: no cover
    """Render a small two-section sample PDF in the working directory.

    Prints the output path and the character count reported by
    ``render_pdf``. Intended for manual debugging only.
    """
    result = render_pdf(
        title="Self test",
        sections=[
            ("intro", "Hello world.\n\nThis is a test."),
            ("body", "Line one.\nLine two."),
        ],
        output_path=Path("./_render_self_test.pdf"),
    )
    print(f"wrote {result.path} ({result.n_chars} chars)")
|
||||
|
||||
|
||||
# Importing ``datetime`` keeps the timezone helper handy if a future
|
||||
# benchmark wants to embed a real timestamp without losing determinism.
|
||||
_NOW_FROZEN = datetime(2026, 5, 11, tzinfo=UTC)
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
"""External LLM providers (used by the native arm).
|
||||
|
||||
Lazy imports so the SurfSense-only path doesn't transitively load the
|
||||
OpenRouter client until something actually constructs ``OpenRouterPdfProvider``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .openrouter_pdf import OpenRouterPdfProvider, OpenRouterResponse
|
||||
|
||||
__all__ = ["OpenRouterPdfProvider", "OpenRouterResponse"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Resolve the package's lazily exported names on first access.

    Only ``OpenRouterPdfProvider`` and ``OpenRouterResponse`` are served;
    both come from the ``openrouter_pdf`` submodule, which is imported
    here rather than at package import time. Any other name raises
    ``AttributeError``.
    """
    if name not in {"OpenRouterPdfProvider", "OpenRouterResponse"}:
        raise AttributeError(f"module 'surfsense_evals.core.providers' has no attribute {name!r}")

    # Deferred so importing the package alone does not load the client module.
    from . import openrouter_pdf as _mod

    return getattr(_mod, name)
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
"""Bare OpenRouter ``chat/completions`` provider — no PDF, no plugins.
|
||||
|
||||
Used by ``BareLlmArm`` to measure "what does the model answer with
|
||||
zero retrieval context?". Same wire shape as ``OpenRouterPdfProvider``
|
||||
minus the file-parser plugin and the ``file`` content part:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "openai/gpt-5.4-mini",
|
||||
"messages": [
|
||||
{"role": "system", "content": "<optional>"},
|
||||
{"role": "user", "content": "<prompt>"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The response shape is identical to the PDF provider's, so we re-use
|
||||
``_parse_chat_completion`` from ``openrouter_pdf`` and only specialise
|
||||
the request builder. That keeps cost-extraction, token-counting, and
|
||||
content-array handling in one place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .openrouter_pdf import (
|
||||
OpenRouterResponse,
|
||||
_DEFAULT_HEADERS,
|
||||
_parse_chat_completion,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenRouterChatProvider:
    """Plain OpenRouter ``chat/completions`` client: no PDF, no plugins.

    An instance only holds configuration (API key, base URL, model slug
    and timeout); nothing is kept between calls.
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str = "https://openrouter.ai/api/v1",
        model: str,
        timeout_s: float = 600.0,
    ) -> None:
        """Store the connection settings.

        Args:
            api_key: OpenRouter bearer token; must be non-empty.
            base_url: API root; a trailing ``/`` is stripped.
            model: Model slug sent with every request.
            timeout_s: Overall request timeout in seconds (connect is
                capped at 15 seconds).

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("OPENROUTER_API_KEY is required for the bare-LLM arm.")
        self._api_key = api_key
        self._base = base_url.rstrip("/")
        self._model = model
        self._timeout = httpx.Timeout(timeout_s, connect=15.0)

    @property
    def model(self) -> str:
        """Model slug this provider was configured with."""
        return self._model

    def _build_payload(
        self,
        *,
        prompt: str,
        system_prompt: str | None,
        max_tokens: int | None,
    ) -> dict[str, Any]:
        """Assemble the JSON request body.

        A system message is included only when ``system_prompt`` is
        truthy, and ``max_tokens`` is included only when it is truthy.
        """
        system_part: list[dict[str, Any]] = (
            [{"role": "system", "content": system_prompt}] if system_prompt else []
        )
        request: dict[str, Any] = {
            "model": self._model,
            "messages": [*system_part, {"role": "user", "content": prompt}],
        }
        if max_tokens:
            request["max_tokens"] = max_tokens
        return request

    async def complete(
        self,
        *,
        prompt: str,
        system_prompt: str | None = None,
        max_tokens: int | None = None,
        http: httpx.AsyncClient | None = None,
    ) -> OpenRouterResponse:
        """Send one chat completion and return the parsed response.

        Args:
            prompt: User message text.
            system_prompt: Optional system message.
            max_tokens: Optional completion token cap.
            http: Client to send the request on; when omitted a
                short-lived client is created for this single call.

        Raises:
            httpx.HTTPStatusError: If the response status is 400 or above.
                No retry is attempted here; that is left to the caller.
        """
        request_body = self._build_payload(
            prompt=prompt,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
        )
        request_headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
            **_DEFAULT_HEADERS,
        }
        endpoint = f"{self._base}/chat/completions"

        # The timer starts before any client is created, so for the
        # short-lived path the latency includes client setup.
        t0 = time.monotonic()
        if http is None:
            async with httpx.AsyncClient(timeout=self._timeout) as own_client:
                reply = await own_client.post(
                    endpoint, json=request_body, headers=request_headers, timeout=self._timeout
                )
        else:
            reply = await http.post(
                endpoint, json=request_body, headers=request_headers, timeout=self._timeout
            )
        elapsed_ms = int((time.monotonic() - t0) * 1000)

        if reply.status_code >= 400:
            raise httpx.HTTPStatusError(
                f"OpenRouter HTTP {reply.status_code}: {reply.text[:300]}",
                request=reply.request,
                response=reply,
            )
        return _parse_chat_completion(reply.json(), latency_ms=elapsed_ms)
|
||||
|
||||
|
||||
__all__ = ["OpenRouterChatProvider"]
|
||||
|
|
@ -0,0 +1,231 @@
|
|||
"""Native-PDF arm provider: OpenRouter ``chat/completions`` with PDF input.
|
||||
|
||||
Per `<https://openrouter.ai/docs/features/multimodal/pdfs>`__ the wire
|
||||
shape is OpenAI-compatible with one PDF-specific extra:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "anthropic/claude-sonnet-4.5",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "file", "file": {"filename": "case.pdf",
|
||||
"file_data": "data:application/pdf;base64,<b64>"}},
|
||||
{"type": "text", "text": "<prompt>"}
|
||||
]
|
||||
}],
|
||||
"plugins": [{"id": "file-parser", "pdf": {"engine": "native"}}]
|
||||
}
|
||||
```
|
||||
|
||||
``engine: "native"`` is the only engine that doesn't pre-OCR the
|
||||
PDF — it forwards raw bytes to PDF-native models (Claude, Gemini),
|
||||
matching what a human user does when "dropping the PDF into Claude".
|
||||
``mistral-ocr`` and ``cloudflare-ai`` are exposed as enum options for
|
||||
non-native models.
|
||||
|
||||
Headers ``HTTP-Referer`` and ``X-Title`` make spend show up cleanly on
|
||||
the OpenRouter dashboard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PdfEngine(str, Enum):
    """PDF engines selectable for the OpenRouter ``file-parser`` plugin.

    Members are strings, so ``.value`` can be placed directly in the
    request payload.
    """

    # Forwards the PDF without pre-OCR, for PDF-native models.
    NATIVE = "native"
    # OCR-based engines, offered for models that are not PDF-native.
    MISTRAL_OCR = "mistral-ocr"
    CLOUDFLARE_AI = "cloudflare-ai"
|
||||
|
||||
|
||||
@dataclass
class OpenRouterResponse:
    """The parts of an OpenRouter chat completion that scoring needs."""

    # Assistant text of the first choice.
    text: str
    # Token counts taken from the response's ``usage`` object.
    input_tokens: int
    output_tokens: int
    total_tokens: int
    # Cost as an integer number of micro-dollars.
    cost_micros: int
    # Request latency in milliseconds, supplied by the caller.
    latency_ms: int
    # ``finish_reason`` of the first choice, when one was reported.
    finish_reason: str | None
    # The full response payload, kept for later inspection.
    raw: dict[str, Any]
|
||||
|
||||
|
||||
_DEFAULT_HEADERS = {
|
||||
"HTTP-Referer": "https://github.com/MODSetter/SurfSense",
|
||||
"X-Title": "SurfSense-evals",
|
||||
}
|
||||
|
||||
|
||||
class OpenRouterPdfProvider:
    """Small httpx client for OpenRouter chat completions with a PDF attached.

    An instance only holds configuration (API key, base URL, model slug,
    PDF engine and timeout); nothing is kept between calls.
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str = "https://openrouter.ai/api/v1",
        model: str,
        engine: PdfEngine = PdfEngine.NATIVE,
        timeout_s: float = 600.0,
    ) -> None:
        """Store the connection settings.

        Args:
            api_key: OpenRouter bearer token; must be non-empty.
            base_url: API root; a trailing ``/`` is stripped.
            model: Model slug sent with every request.
            engine: PDF engine requested from the ``file-parser`` plugin.
            timeout_s: Overall request timeout in seconds (connect is
                capped at 15 seconds).

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("OPENROUTER_API_KEY is required for the native arm.")
        self._api_key = api_key
        self._base = base_url.rstrip("/")
        self._model = model
        self._engine = engine
        self._timeout = httpx.Timeout(timeout_s, connect=15.0)

    @property
    def model(self) -> str:
        """Model slug this provider was configured with."""
        return self._model

    @property
    def engine(self) -> PdfEngine:
        """PDF engine this provider was configured with."""
        return self._engine

    def _build_payload(
        self,
        *,
        prompt: str,
        pdf_path: Path,
        max_tokens: int | None,
        extra_messages: list[dict[str, Any]] | None,
    ) -> dict[str, Any]:
        """Assemble the JSON request body, embedding the PDF as a data URL.

        The whole file at ``pdf_path`` is read and base64-encoded. Any
        ``extra_messages`` are placed before the user message, and
        ``max_tokens`` is included only when it is truthy.
        """
        encoded = base64.b64encode(pdf_path.read_bytes()).decode("ascii")
        file_part: dict[str, Any] = {
            "type": "file",
            "file": {
                "filename": pdf_path.name,
                "file_data": f"data:application/pdf;base64,{encoded}",
            },
        }
        text_part: dict[str, Any] = {"type": "text", "text": prompt}

        # Build a new list so the caller's ``extra_messages`` is not mutated.
        conversation: list[dict[str, Any]] = [
            *(extra_messages or []),
            {"role": "user", "content": [file_part, text_part]},
        ]
        request: dict[str, Any] = {
            "model": self._model,
            "messages": conversation,
            "plugins": [
                {"id": "file-parser", "pdf": {"engine": self._engine.value}}
            ],
        }
        if max_tokens:
            request["max_tokens"] = max_tokens
        return request

    async def complete(
        self,
        *,
        prompt: str,
        pdf_path: Path,
        max_tokens: int | None = None,
        extra_messages: list[dict[str, Any]] | None = None,
        http: httpx.AsyncClient | None = None,
    ) -> OpenRouterResponse:
        """Send one PDF-backed chat completion and return the parsed response.

        Args:
            prompt: User question about the PDF.
            pdf_path: PDF file to attach.
            max_tokens: Optional completion token cap.
            extra_messages: Optional messages placed before the user turn.
            http: Client to send the request on; when omitted a
                short-lived client is created for this single call.

        Raises:
            httpx.HTTPStatusError: If the response status is 400 or above.
                No retry is attempted here; that is left to the runner.
        """
        request_body = self._build_payload(
            prompt=prompt,
            pdf_path=pdf_path,
            max_tokens=max_tokens,
            extra_messages=extra_messages,
        )
        request_headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
            **_DEFAULT_HEADERS,
        }
        endpoint = f"{self._base}/chat/completions"

        # Timing starts after the payload (file read + base64) is built.
        t0 = time.monotonic()
        if http is None:
            async with httpx.AsyncClient(timeout=self._timeout) as own_client:
                reply = await own_client.post(
                    endpoint, json=request_body, headers=request_headers, timeout=self._timeout
                )
        else:
            reply = await http.post(
                endpoint, json=request_body, headers=request_headers, timeout=self._timeout
            )
        elapsed_ms = int((time.monotonic() - t0) * 1000)

        if reply.status_code >= 400:
            raise httpx.HTTPStatusError(
                f"OpenRouter HTTP {reply.status_code}: {reply.text[:300]}",
                request=reply.request,
                response=reply,
            )
        parsed_json = reply.json()
        return _parse_chat_completion(parsed_json, latency_ms=elapsed_ms)
|
||||
|
||||
|
||||
def _parse_chat_completion(payload: dict[str, Any], *, latency_ms: int) -> OpenRouterResponse:
    """Tolerant parser for OpenRouter / OpenAI chat-completions JSON.

    The canonical shape is ``choices[0].message.content`` (a string OR a
    list of content parts) plus ``usage.prompt_tokens`` /
    ``completion_tokens`` / ``total_tokens``. Cost is read from
    ``payload["usage"]["cost"]`` and, when that is absent, from the
    top-level ``payload["cost"]``.

    Args:
        payload: Decoded JSON body of the chat-completions response.
        latency_ms: Request latency measured by the caller; stored as-is.

    Returns:
        An ``OpenRouterResponse``. Missing fields degrade to ``""`` / ``0``
        / ``None``; ``cost_micros`` is ``0`` when the cost is missing or
        cannot be converted to a finite number.
    """

    text = ""
    finish_reason: str | None = None
    choices = payload.get("choices") or []
    if choices:
        first_choice = choices[0] or {}
        message = first_choice.get("message") or {}
        content = message.get("content")
        if isinstance(content, str):
            text = content
        elif isinstance(content, list):
            chunks: list[str] = []
            for part in content:
                if isinstance(part, dict) and part.get("type") in {"text", "output_text"}:
                    part_text = part.get("text")
                    # An explicit null must not become the literal "None".
                    if part_text is not None:
                        chunks.append(str(part_text))
            text = "".join(chunks)
        finish_reason = first_choice.get("finish_reason") or None

    usage = payload.get("usage") or {}
    input_tokens = int(usage.get("prompt_tokens") or 0)
    output_tokens = int(usage.get("completion_tokens") or 0)
    total_tokens = int(usage.get("total_tokens") or (input_tokens + output_tokens))

    # OpenRouter exposes cost in dollars on `usage.cost` or `cost`. We
    # convert to integer micros to avoid float-summing surprises across
    # 7,663 MIRAGE questions.
    raw_cost = usage.get("cost")
    if raw_cost is None:
        raw_cost = payload.get("cost")
    cost_micros = 0
    if raw_cost is not None:
        try:
            cost_micros = int(round(float(raw_cost) * 1_000_000))
        except (TypeError, ValueError, OverflowError):
            # ValueError covers unparsable strings and NaN; OverflowError
            # covers +/-infinity, which ``round`` cannot turn into an int.
            cost_micros = 0

    return OpenRouterResponse(
        text=text,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        cost_micros=cost_micros,
        latency_ms=latency_ms,
        finish_reason=finish_reason,
        raw=payload,
    )
|
||||
|
||||
|
||||
__all__ = ["OpenRouterPdfProvider", "OpenRouterResponse", "PdfEngine"]
|
||||
265
surfsense_evals/src/surfsense_evals/core/registry.py
Normal file
265
surfsense_evals/src/surfsense_evals/core/registry.py
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
"""Suite + Benchmark protocols and the global registry.
|
||||
|
||||
The extensibility seam: ``core.cli`` walks ``surfsense_evals.suites`` on
|
||||
import, which auto-imports every benchmark subpackage, which calls
|
||||
``register(<benchmark>)`` at module bottom. The CLI then iterates the
|
||||
populated registry to build subcommand groups dynamically.
|
||||
|
||||
Adding a new domain = drop a folder under ``suites/<domain>/<bench>/``
|
||||
that ends in ``register(MyBenchmark())``. No edits anywhere in
|
||||
``core/`` are required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
import httpx
|
||||
|
||||
from .clients import DocumentsClient, NewChatClient, SearchSpaceClient
|
||||
from .config import Config, SuiteState
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run context — what every benchmark.ingest/run receives
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class RunContext:
    """Environment handed to a benchmark's ``ingest`` and ``run`` calls.

    It bundles the pinned suite state, the loaded config and the shared
    HTTP session, and offers helpers for building API clients and
    locating per-suite directories.

    ``http`` is the authenticated SurfSense client (auth event hook
    attached). It is **not** an OpenRouter client: providers create their
    own short-lived clients because OpenRouter doesn't share the SurfSense
    bearer.
    """

    suite: str
    benchmark: str
    config: Config
    suite_state: SuiteState
    http: httpx.AsyncClient

    # -- values pinned in the suite state --------------------------------

    @property
    def search_space_id(self) -> int:
        """Search space id recorded in the suite state."""
        state = self.suite_state
        return state.search_space_id

    @property
    def agent_llm_id(self) -> int:
        """Agent LLM id recorded in the suite state."""
        state = self.suite_state
        return state.agent_llm_id

    @property
    def provider_model(self) -> str:
        """Slug used by the SurfSense agent (and, by default, the native arm).

        In ``cost-arbitrage`` scenarios this is the *cheap, text-only*
        slug: SurfSense answers from chunks the vision LLM extracted at
        ingest, and the native arm should use ``native_arm_model`` instead.
        """
        state = self.suite_state
        return state.provider_model

    @property
    def native_arm_model(self) -> str:
        """Slug the native_pdf arm should use.

        Defaults to ``provider_model`` (head-to-head / symmetric-cheap);
        for ``cost-arbitrage`` it is the explicit ``--native-arm-model`` so
        the native arm can fairly answer image-bearing questions.
        """
        state = self.suite_state
        return state.effective_native_arm_model

    @property
    def vision_provider_model(self) -> str | None:
        """Slug of the OpenRouter vision LLM SurfSense used at ingest.

        ``None`` if no vision config was attached at setup (legacy or
        text-only suite). Runners use it only to record what was used in
        ``RunArtifact.extra`` and to label reports.
        """
        state = self.suite_state
        return state.vision_provider_model

    @property
    def scenario(self) -> str:
        """Scenario name pinned at setup time (see ``config.SCENARIOS``)."""
        state = self.suite_state
        return state.scenario

    # -- API clients bound to the shared session -------------------------

    def search_space_client(self) -> SearchSpaceClient:
        """Build a ``SearchSpaceClient`` on the shared session."""
        return SearchSpaceClient(self.http, self.config.surfsense_api_base)

    def documents_client(self) -> DocumentsClient:
        """Build a ``DocumentsClient`` on the shared session."""
        return DocumentsClient(self.http, self.config.surfsense_api_base)

    def new_chat_client(self) -> NewChatClient:
        """Build a ``NewChatClient`` on the shared session."""
        return NewChatClient(self.http, self.config.surfsense_api_base)

    # -- directories (created on demand) ---------------------------------

    @staticmethod
    def _ensure_dir(target: Path) -> Path:
        """Create ``target`` (and any missing parents) and return it."""
        target.mkdir(parents=True, exist_ok=True)
        return target

    def maps_dir(self) -> Path:
        """Return this suite's maps directory, creating it if needed."""
        return self._ensure_dir(self.config.suite_maps_dir(self.suite))

    def runs_dir(self, *, run_timestamp: str) -> Path:
        """Return ``<suite runs dir>/<run_timestamp>/<benchmark>``, creating it."""
        return self._ensure_dir(
            self.config.suite_runs_dir(self.suite) / run_timestamp / self.benchmark
        )

    def benchmark_data_dir(self) -> Path:
        """Return ``<suite data dir>/<benchmark>``, creating it if needed."""
        return self._ensure_dir(self.config.suite_data_dir(self.suite) / self.benchmark)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run artifact + report section
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class RunArtifact:
    """What a runner persists so the report writer can consume it.

    ``metrics`` is a free-form dict filled in by the benchmark, e.g.
    ``{"native": {...}, "surfsense": {...}, "delta": {...}}``.
    """

    suite: str
    benchmark: str
    run_timestamp: str
    # JSONL file holding the per-question ``ArmResult`` rows.
    raw_path: Path
    # Each instance gets its own fresh dict for the two fields below.
    metrics: dict[str, Any] = field(default_factory=dict)
    extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class ReportSection:
    """A single benchmark's contribution to the final summary."""

    # Heading text for the section.
    title: str
    # Whether the section belongs to the headline group.
    headline: bool
    # Markdown body of the section.
    body_md: str
    # Structured counterpart of the body; a fresh dict per instance.
    body_json: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark protocol + registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@runtime_checkable
class Benchmark(Protocol):
    """Structural contract for objects passed to ``register``.

    Any object exposing these attributes and methods qualifies; no
    inheritance is required.
    """

    # Suite name and benchmark name; together they form the registry key.
    suite: str
    name: str
    # Marks the benchmark as a headline one.
    headline: bool
    # Free-text description of the benchmark.
    description: str

    async def ingest(self, ctx: RunContext, **opts: Any) -> None:  # pragma: no cover - protocol
        """Ingest step of the benchmark; receives the run context."""
        ...

    async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:  # pragma: no cover - protocol
        """Run step of the benchmark; returns the resulting artifact."""
        ...

    def add_run_args(self, parser: argparse.ArgumentParser) -> None:  # pragma: no cover - protocol
        """Add benchmark-specific flags to ``run <suite> <benchmark>``."""

    def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:  # pragma: no cover - protocol
        """Turn run artifacts into this benchmark's report section."""
        ...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry storage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_REGISTRY: dict[tuple[str, str], Benchmark] = {}
|
||||
|
||||
|
||||
def register(benchmark: Benchmark) -> None:
    """Store ``benchmark`` in the registry under ``(suite, name)``.

    A key that is already present is overwritten (last registration
    wins) after a warning is logged. It deliberately does not raise, so a
    benchmark module imported twice (once via auto-discovery, once by a
    test importing it directly) does not break the CLI.
    """
    suite, name = benchmark.suite, benchmark.name
    if (suite, name) in _REGISTRY:
        # Imported here so logging is only touched on the duplicate path.
        import logging

        logging.getLogger(__name__).warning(
            "Benchmark %s/%s re-registered (overwriting prior)", suite, name
        )
    _REGISTRY[(suite, name)] = benchmark
|
||||
|
||||
|
||||
def unregister(suite: str, name: str) -> None:
    """Test helper: remove one benchmark from the registry.

    Does nothing when ``(suite, name)`` is not registered.
    """
    key = (suite, name)
    if key in _REGISTRY:
        del _REGISTRY[key]
|
||||
|
||||
|
||||
def reset() -> None:
    """Test helper: empty the registry (pair with monkeypatched discovery)."""
    # Cleared in place so the module-level dict object stays the same.
    _REGISTRY.clear()
|
||||
|
||||
|
||||
def get(suite: str, name: str) -> Benchmark:
    """Look up the benchmark registered as ``suite``/``name``.

    Raises:
        KeyError: If no such benchmark is registered. The message lists
            every registered ``suite/name`` pair in sorted order, or
            ``<none>`` when the registry is empty.
    """
    try:
        found = _REGISTRY[(suite, name)]
    except KeyError as exc:
        known = [f"{s}/{n}" for s, n in sorted(_REGISTRY)]
        available = ", ".join(known) or "<none>"
        raise KeyError(
            f"Unknown benchmark '{suite}/{name}'. Registered: {available}"
        ) from exc
    return found
|
||||
|
||||
|
||||
def list_suites() -> list[str]:
    """Return the distinct suite names in the registry, sorted."""
    suite_names = set()
    for suite_name, _benchmark_name in _REGISTRY:
        suite_names.add(suite_name)
    return sorted(suite_names)
|
||||
|
||||
|
||||
def list_benchmarks(suite: str | None = None) -> list[Benchmark]:
    """Return registered benchmarks ordered by ``(suite, name)``.

    When ``suite`` is given, only benchmarks of that suite are returned;
    ``None`` returns every registered benchmark.
    """
    keys = sorted(_REGISTRY)
    if suite is not None:
        keys = [key for key in keys if key[0] == suite]
    return [_REGISTRY[key] for key in keys]
|
||||
|
||||
|
||||
def snapshot() -> Mapping[tuple[str, str], Benchmark]:
    """Return a shallow copy of the registry for diagnostics.

    Used e.g. for ``benchmarks list`` rendering. Because the result is a
    copy, changing it does not affect the registry itself.
    """
    return _REGISTRY.copy()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Arm",
|
||||
"Benchmark",
|
||||
"ReportSection",
|
||||
"RunArtifact",
|
||||
"RunContext",
|
||||
"get",
|
||||
"list_benchmarks",
|
||||
"list_suites",
|
||||
"register",
|
||||
"reset",
|
||||
"snapshot",
|
||||
"unregister",
|
||||
]
|
||||
|
||||
|
||||
# Re-export Arm from arms.base so suites can `from core.registry import Arm`.
|
||||
from .arms.base import Arm # noqa: E402, F401 (deliberate re-export at bottom)
|
||||
18
surfsense_evals/src/surfsense_evals/core/report/__init__.py
Normal file
18
surfsense_evals/src/surfsense_evals/core/report/__init__.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Report writer + section composition primitives. Lazy import."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .writer import write_report
|
||||
|
||||
__all__ = ["write_report"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Resolve ``write_report`` lazily on first attribute access.

    The ``writer`` submodule is imported only when ``write_report`` is
    requested; any other name raises ``AttributeError``.
    """
    if name != "write_report":
        raise AttributeError(f"module 'surfsense_evals.core.report' has no attribute {name!r}")

    from .writer import write_report

    return write_report
|
||||
89
surfsense_evals/src/surfsense_evals/core/report/writer.py
Normal file
89
surfsense_evals/src/surfsense_evals/core/report/writer.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""Report writer — composes per-benchmark sections into one summary.
|
||||
|
||||
Output:
|
||||
|
||||
* ``reports/<suite>/<run-timestamp>/summary.md`` — human-readable.
|
||||
Bullet lists only (no tables) per project's coding-standards.
|
||||
* ``reports/<suite>/<run-timestamp>/summary.json`` — same content as
|
||||
structured JSON for downstream tooling (CI dashboards, regressions).
|
||||
|
||||
Headline benchmarks come first in both outputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
|
||||
from ..config import Config
|
||||
from ..registry import ReportSection
|
||||
|
||||
|
||||
def write_report(
    *,
    config: Config,
    suite: str,
    sections: Iterable[ReportSection],
    run_timestamp: str,
) -> Path:
    """Write ``summary.md`` and ``summary.json`` for one suite run.

    Sections are ordered headline-first, then case-insensitively by
    title. Both files are placed in
    ``config.suite_reports_dir(suite) / run_timestamp``, which is created
    when missing.

    Args:
        config: Supplies ``suite_reports_dir``.
        suite: Suite name, used for the directory and the report header.
        sections: Report sections to include; consumed once.
        run_timestamp: Run identifier, used as the sub-directory name.

    Returns:
        Path of the written ``summary.md``.
    """
    ordered = sorted(sections, key=lambda s: (not s.headline, s.title.lower()))

    out_dir = config.suite_reports_dir(suite) / run_timestamp
    out_dir.mkdir(parents=True, exist_ok=True)
    md_path = out_dir / "summary.md"
    json_path = out_dir / "summary.json"

    lines: list[str] = [
        f"# SurfSense evals — suite `{suite}`",
        "",
        f"- Run timestamp: `{run_timestamp}`",
        f"- Sections: {len(ordered)}",
        "",
    ]
    # Two groups share one layout; a group heading appears only when the
    # group has at least one section.
    groups = (
        ("## Headline", [s for s in ordered if s.headline]),
        ("## Secondary measurements", [s for s in ordered if not s.headline]),
    )
    for group_heading, members in groups:
        if not members:
            continue
        lines += [group_heading, ""]
        for member in members:
            lines += [f"### {member.title}", "", member.body_md.rstrip(), ""]

    md_path.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8")

    serialised_sections = [
        {
            "title": s.title,
            "headline": s.headline,
            "body_md": s.body_md,
            "body_json": s.body_json,
        }
        for s in ordered
    ]
    document = {
        "suite": suite,
        "run_timestamp": run_timestamp,
        "sections": serialised_sections,
    }
    json_path.write_text(
        json.dumps(document, indent=2, sort_keys=True) + "\n",
        encoding="utf-8",
    )
    return md_path
|
||||
|
||||
|
||||
__all__ = ["ReportSection", "write_report"]
|
||||
58
surfsense_evals/src/surfsense_evals/core/scenarios.py
Normal file
58
surfsense_evals/src/surfsense_evals/core/scenarios.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Shared scenario formatting helpers for head-to-head benchmark reports.
|
||||
|
||||
The scenario chosen at ``setup`` time (``head-to-head``, ``symmetric-cheap``,
|
||||
``cost-arbitrage``) materially changes how a head-to-head report should be
|
||||
read. This module produces the one-bullet summary every head-to-head
|
||||
runner stamps near the top of its ``report_section`` body so reviewers
|
||||
immediately see the framing — no need to dig into ``run_artifact.json``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
|
||||
def format_scenario_md(extra: Mapping[str, Any] | None) -> str:
    """Build the one-bullet scenario summary for a benchmark report.

    Reads ``scenario``, ``provider_model``, ``native_arm_model`` and
    ``vision_provider_model`` from ``extra``. A missing or empty scenario,
    or any unrecognised one, is rendered as the plain "head-to-head" line,
    so artifacts written before scenarios existed still render.

    Args:
        extra: The ``extra`` mapping recorded by the runner, or ``None``.

    Returns:
        A single Markdown bullet starting with ``- Scenario:``.
    """
    fields = dict(extra or {})
    scenario = str(fields.get("scenario") or "head-to-head")
    surf_slug = str(fields.get("provider_model") or "?")
    # The native arm falls back to the SurfSense slug when none is recorded.
    native_slug = str(fields.get("native_arm_model") or surf_slug)
    vision_slug = fields.get("vision_provider_model")

    if scenario == "cost-arbitrage":
        extracted_by = f" by `{vision_slug}`" if vision_slug else ""
        return (
            "- Scenario: **cost-arbitrage** — native arm answers with "
            f"`{native_slug}` (vision); SurfSense answers with `{surf_slug}` "
            "over chunks vision-extracted at ingest"
            f"{extracted_by}. "
            "Measures how close SurfSense gets to native at a fraction of "
            "the per-query cost."
        )

    if scenario == "symmetric-cheap":
        extracted_via = f" via `{vision_slug}`" if vision_slug else ""
        return (
            "- Scenario: **symmetric-cheap** — both arms answer with "
            f"`{surf_slug}`; SurfSense pre-extracted images at ingest"
            f"{extracted_via}. "
            "Native arm structurally loses on image-bearing questions "
            "(text-only model can't see images) — that's the point."
        )

    bullet = f"- Scenario: head-to-head — both arms answer with `{surf_slug}` via OpenRouter."
    if vision_slug:
        bullet += f" SurfSense ingest VLM: `{vision_slug}`."
    return bullet
|
||||
|
||||
|
||||
__all__ = ["format_scenario_md"]
|
||||
127
surfsense_evals/src/surfsense_evals/core/vision_llm.py
Normal file
127
surfsense_evals/src/surfsense_evals/core/vision_llm.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Vision LLM resolution + auto-pick logic for the harness's ``setup`` command.
|
||||
|
||||
Two responsibilities:
|
||||
|
||||
1. Resolve an explicit ``--vision-llm <slug>`` to a global OpenRouter
|
||||
vision LLM config id that ``set_llm_preferences(vision_llm_config_id=...)``
|
||||
can accept.
|
||||
2. Auto-pick the strongest registered vision config when the operator
|
||||
doesn't pass ``--vision-llm`` but the scenario / benchmark needs one.
|
||||
|
||||
The priority list mirrors the recommended slugs in the README so the
|
||||
auto-pick is deterministic and reviewable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .clients.search_space import VisionLlmConfigEntry
|
||||
|
||||
# Order matters — first match wins when auto-picking. Keep these in sync
|
||||
# with the "Recommended vision slugs" table in the README so the
|
||||
# auto-pick story is the same one users read about.
|
||||
RECOMMENDED_VISION_PRIORITY: tuple[str, ...] = (
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"anthropic/claude-opus-4.7",
|
||||
"openai/gpt-5",
|
||||
"google/gemini-2.5-pro",
|
||||
)
|
||||
|
||||
|
||||
class VisionConfigError(RuntimeError):
    """Signals that no vision config could be resolved, explicitly or by auto-pick."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ResolvedVisionConfig:
    """Outcome of ``resolve_vision_llm``: the config to attach plus a log label."""

    # Id of the chosen vision LLM config.
    config_id: int
    # Model slug of the chosen config.
    provider_model: str
    # How the config was chosen.
    selected_via: str  # "explicit" | "auto-priority" | "auto-fallback"
|
||||
|
||||
|
||||
def _openrouter_only(entries: Iterable[VisionLlmConfigEntry]) -> list[VisionLlmConfigEntry]:
|
||||
return [e for e in entries if e.provider == "OPENROUTER" and not e.is_auto_mode]
|
||||
|
||||
|
||||
def resolve_vision_llm(
    candidates: list[VisionLlmConfigEntry],
    *,
    explicit_slug: str | None,
) -> ResolvedVisionConfig:
    """Resolve a vision LLM config id from a slug or by auto-picking.

    Only OpenRouter, non-auto-mode entries of ``candidates`` are considered.

    * If ``explicit_slug`` is given: it must match exactly one such entry's
      ``model_name``. Zero or several matches raise ``VisionConfigError``
      with a listing to help the operator.
    * Otherwise: walk ``RECOMMENDED_VISION_PRIORITY`` in order and return
      the first slug that is registered. If none of the recommended slugs
      are registered, fall back to the first OpenRouter vision config in
      the list.

    Args:
        candidates: Vision LLM config entries as listed by the backend.
        explicit_slug: Operator-provided slug, or ``None`` to auto-pick.

    Returns:
        The chosen config, labelled with how it was selected
        (``"explicit"``, ``"auto-priority"`` or ``"auto-fallback"``).

    Raises:
        VisionConfigError: If the explicit slug matches zero or several
            entries, or if no OpenRouter vision config is available.
    """
    or_vision = _openrouter_only(candidates)

    if explicit_slug is not None:
        matches = [e for e in or_vision if e.model_name == explicit_slug]
        if not matches:
            # Show a short sample so the operator can spot a typo in the slug.
            sample = ", ".join(e.model_name for e in or_vision[:8]) or "<none>"
            raise VisionConfigError(
                f"No OpenRouter vision config found for slug '{explicit_slug}'. "
                "Make sure `openrouter_integration.vision_enabled: true` in "
                "global_llm_config.yaml and that the Celery worker has finished "
                "its first refresh. "
                f"Available OpenRouter vision slugs (sample): {sample}."
            )
        if len(matches) > 1:
            listing = "\n".join(f" id={e.id} name={e.name!r}" for e in matches)
            raise VisionConfigError(
                f"Multiple OpenRouter vision configs match '{explicit_slug}':\n{listing}"
            )
        only = matches[0]
        return ResolvedVisionConfig(
            config_id=only.id,
            provider_model=only.model_name,
            selected_via="explicit",
        )

    if not or_vision:
        raise VisionConfigError(
            "No OpenRouter vision LLM configs are registered with this "
            "SurfSense backend. Either pass `--no-vision-llm` to the ingest "
            "step (text-only ingestion), or enable "
            "`openrouter_integration.vision_enabled: true` in "
            "global_llm_config.yaml so the Celery worker syncs vision-capable "
            "OpenRouter models on next refresh."
        )

    # First entry wins when several entries share a ``model_name``, so the
    # priority pick follows listing order just like the fallback below.
    by_slug: dict[str, VisionLlmConfigEntry] = {}
    for entry in or_vision:
        by_slug.setdefault(entry.model_name, entry)

    for preferred in RECOMMENDED_VISION_PRIORITY:
        match = by_slug.get(preferred)
        if match is not None:
            return ResolvedVisionConfig(
                config_id=match.id,
                provider_model=match.model_name,
                selected_via="auto-priority",
            )

    # Fallback: first registered OpenRouter vision config. Deterministic
    # because the backend returns them in a stable order.
    fallback = or_vision[0]
    return ResolvedVisionConfig(
        config_id=fallback.id,
        provider_model=fallback.model_name,
        selected_via="auto-fallback",
    )
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RECOMMENDED_VISION_PRIORITY",
|
||||
"ResolvedVisionConfig",
|
||||
"VisionConfigError",
|
||||
"resolve_vision_llm",
|
||||
]
|
||||
66
surfsense_evals/src/surfsense_evals/suites/__init__.py
Normal file
66
surfsense_evals/src/surfsense_evals/suites/__init__.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""Suite registry auto-discovery.
|
||||
|
||||
Importing ``surfsense_evals.suites`` walks every subpackage one level deep
|
||||
(domain like ``medical``) AND its benchmark subpackages
|
||||
(``medical/medxpertqa``, ``medical/mirage``, ``medical/cure``). Each
|
||||
benchmark's ``__init__.py`` is expected to call
|
||||
``core.registry.register(<Benchmark>)`` at module bottom; merely importing
|
||||
the module is enough to populate the registry.
|
||||
|
||||
Adding a new domain is therefore: drop a folder under ``suites/`` with the
|
||||
right structure. No edits anywhere else.
|
||||
|
||||
Subpackages whose name starts with ``_`` are skipped — that's reserved for
|
||||
test fixtures (e.g. ``suites/_demo/``) so they don't accidentally show up
|
||||
in ``benchmarks list``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import pkgutil
|
||||
from typing import Iterable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _iter_subpackages(package) -> Iterable[str]:
    """Yield the dotted names of ``package``'s direct subpackages.

    Plain modules are ignored, and so is any subpackage whose own (last)
    name component begins with an underscore.
    """

    dotted_prefix = f"{package.__name__}."
    for info in pkgutil.iter_modules(package.__path__, prefix=dotted_prefix):
        short_name = info.name.rsplit(".", 1)[-1]
        if info.ispkg and not short_name.startswith("_"):
            yield info.name
|
||||
|
||||
|
||||
def discover_suites() -> list[str]:
    """Import each domain package and its benchmark subpackages.

    Importing a benchmark package is what triggers its registration, so
    this function exists for its import side effects. An import failure at
    either level is logged as a warning and skipped rather than raised, so
    one broken benchmark cannot take the rest down.

    Returns:
        Fully-qualified names of the benchmark packages that imported
        successfully, in discovery order.
    """

    import surfsense_evals.suites as suites_pkg  # self-import for __path__

    loaded: list[str] = []
    for domain_name in _iter_subpackages(suites_pkg):
        try:
            domain_pkg = importlib.import_module(domain_name)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Failed to import suite domain %s: %s", domain_name, exc)
            continue
        for benchmark_name in _iter_subpackages(domain_pkg):
            try:
                importlib.import_module(benchmark_name)
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "Failed to import benchmark %s: %s", benchmark_name, exc
                )
            else:
                loaded.append(benchmark_name)
    return loaded
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
"""Test fixture suite — skipped by the auto-discovery walker (name starts with ``_``).
|
||||
|
||||
Imported explicitly by ``tests/core/test_registry.py`` to prove the
|
||||
register-on-import contract works without polluting the production
|
||||
benchmark list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
"""Demo benchmark — registers on import, used only by the registry tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from typing import Any
|
||||
|
||||
from ....core.registry import (
|
||||
Benchmark,
|
||||
ReportSection,
|
||||
RunArtifact,
|
||||
RunContext,
|
||||
register,
|
||||
)
|
||||
|
||||
|
||||
class HelloBenchmark:
    """Minimal benchmark object used to exercise the registry."""

    suite: str = "_demo"
    name: str = "hello"
    headline: bool = False
    description: str = "Demo benchmark used by the registry test."

    def add_run_args(self, parser: argparse.ArgumentParser) -> None:
        """Add the single ``--echo`` option (default ``"hi"``)."""
        parser.add_argument("--echo", default="hi")

    async def ingest(self, ctx: RunContext, **_opts: Any) -> None:  # pragma: no cover
        """Do nothing; the demo has no data to ingest."""
        return None

    async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:  # pragma: no cover
        """Return an artifact whose only metric echoes the ``echo`` option."""
        echoed = opts.get("echo")
        raw_file = ctx.benchmark_data_dir() / "raw.jsonl"
        return RunArtifact(
            suite=self.suite,
            benchmark=self.name,
            run_timestamp="0",
            raw_path=raw_file,
            metrics={"echo": echoed},
        )

    def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
        """Build a one-line section stating how many artifacts were given."""
        run_count = len(artifacts)
        return ReportSection(
            title="Hello demo",
            headline=False,
            body_md=f"- runs: {run_count}",
        )
|
||||
|
||||
|
||||
register(HelloBenchmark())
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""Medical RAG benchmarks (MedXpertQA-MM headline + MIRAGE/CUREv1 secondary).
|
||||
|
||||
Subpackages register themselves with ``core.registry`` on import. The
|
||||
``suites/__init__.py`` discovery walker imports them automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""CUREv1 — secondary single-arm SurfSense retrieval measurement.
|
||||
|
||||
Source: https://huggingface.co/datasets/clinia/CUREv1
|
||||
Paper: https://arxiv.org/html/2412.06954v4
|
||||
|
||||
Pure retrieval benchmark — 10 medical disciplines, English/French/Spanish
|
||||
queries, expert-curated qrels (graded 0/1/2). The harness ingests the
|
||||
corpus, runs each query via SurfSense's ``/api/v1/new_chat``, parses
|
||||
chunk citations, maps them back to CUREv1 ``corpus-id``, and scores
|
||||
Recall@k / MRR / nDCG@10 against qrels.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .runner import CureBenchmark
|
||||
from ....core import registry as _registry
|
||||
|
||||
_registry.register(CureBenchmark())
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
"""CUREv1 ingestion.
|
||||
|
||||
For each (lang, discipline) requested, downloads the corpus split via
|
||||
``datasets.load_dataset(path="clinia/CUREv1", name="corpus", split=<discipline>)``,
|
||||
batches passages into ~5 MB markdown bundles, uploads them to
|
||||
SurfSense, polls until ``ready``, and persists the
|
||||
``corpus_id -> document_id`` map under
|
||||
``data/medical/maps/cure_corpus_map_<discipline>.jsonl``. A union map
|
||||
``cure_corpus_map.jsonl`` is also written so the runner can resolve
|
||||
citations across disciplines without juggling per-file paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_BATCH_SIZE_BYTES = 5 * 1024 * 1024
|
||||
|
||||
# 10 disciplines covered by the dataset card. We exhaustively list
# them so a smoke test can default to one.
# NOTE(review): each name is passed verbatim as the ``split`` argument of
# ``load_dataset`` by ``_stream_corpus``. ``run_ingest`` only logs a warning
# and skips a discipline whose load fails, so a name that does not match a
# real split is easy to miss. Confirm these against the split names on the
# CUREv1 dataset card.
DISCIPLINES = (
    "anesthesiology",
    "cardiology",
    "dermatology",
    "endocrinology",
    "gastroenterology",
    "hematology",
    "nephrology",
    "neurology",
    "obstetrics_gynecology",
    "psychiatry",
)
|
||||
|
||||
|
||||
@dataclass
class CorpusPassage:
    """One corpus row: its id, title and body text."""

    corpus_id: str
    title: str
    text: str

    def to_markdown(self) -> str:
        """Render the passage as a markdown section.

        The heading is the stripped title (``"Untitled"`` when blank), then
        a line carrying the corpus id in backticks, then the stripped body.
        """
        heading = (self.title or "").strip()
        if not heading:
            heading = "Untitled"
        content = (self.text or "").strip()
        return f"# {heading}\n\n_id: `{self.corpus_id}`_\n\n{content}\n"
|
||||
|
||||
|
||||
@dataclass
class PassageBatch:
    """A markdown bundle on disk together with the passages it contains."""

    # Location of the written ``.md`` bundle.
    path: Path
    # Corpus ids of the passages in the bundle, in write order.
    corpus_ids: list[str]
|
||||
|
||||
|
||||
def _stream_corpus(discipline: str) -> Iterable[CorpusPassage]:
    """Yield one ``CorpusPassage`` per corpus row of ``discipline``.

    The split is loaded with ``datasets.load_dataset`` (imported lazily).
    Rows with a missing or empty ``_id`` are dropped; a missing title or
    text becomes an empty string.
    """

    from datasets import load_dataset  # noqa: PLC0415

    logger.info("Loading CUREv1 corpus for discipline=%s", discipline)
    rows = load_dataset(path="clinia/CUREv1", name="corpus", split=discipline)
    for record in rows:
        passage_id = str(record.get("_id") or "")
        if passage_id:
            yield CorpusPassage(
                corpus_id=passage_id,
                title=str(record.get("title") or ""),
                text=str(record.get("text") or ""),
            )
|
||||
|
||||
|
||||
def _write_batches(
    passages: Iterable[CorpusPassage],
    *,
    out_dir: Path,
    discipline: str,
    batch_bytes: int = _BATCH_SIZE_BYTES,
) -> list[PassageBatch]:
    """Pack passages into markdown bundle files of roughly ``batch_bytes``.

    Each passage is rendered to markdown followed by a ``---`` separator.
    A bundle is closed just before the passage that would push its UTF-8
    size past ``batch_bytes``; a single oversized passage still gets a
    bundle of its own. Files are named
    ``cure_<discipline>_<4-digit index>.md`` under ``out_dir``, which is
    created if needed.

    Returns:
        One ``PassageBatch`` per file written, in write order.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    written: list[PassageBatch] = []
    pending_text: list[str] = []
    pending_ids: list[str] = []
    pending_size = 0

    def _emit() -> None:
        # The next file index always equals the number of bundles written so far.
        target = out_dir / f"cure_{discipline}_{len(written):04d}.md"
        target.write_text("".join(pending_text), encoding="utf-8")
        written.append(PassageBatch(path=target, corpus_ids=list(pending_ids)))
        pending_text.clear()
        pending_ids.clear()

    for passage in passages:
        rendered = passage.to_markdown() + "\n---\n\n"
        rendered_size = len(rendered.encode("utf-8"))
        if pending_ids and pending_size + rendered_size > batch_bytes:
            _emit()
            pending_size = 0
        pending_text.append(rendered)
        pending_ids.append(passage.corpus_id)
        pending_size += rendered_size
    if pending_ids:
        _emit()
    return written
|
||||
|
||||
|
||||
async def run_ingest(
    ctx: RunContext,
    *,
    disciplines: list[str] | None = None,
    max_per_discipline: int | None = None,
    settings: IngestSettings | None = None,
) -> None:
    """Ingest the CUREv1 corpus into SurfSense and persist the id maps.

    For each discipline: stream the corpus, write markdown bundles, upload
    them, wait for the new documents to become ready, then write

    * ``cure_corpus_map_<discipline>.jsonl`` (``corpus_id -> document_id``,
      preceded by a settings header line),
    * ``cure_chunk_map_<discipline>.jsonl`` (``chunk_id -> document_id``,
      no header line), and
    * rows appended to the union ``cure_corpus_map.jsonl`` (also preceded
      by a settings header line).

    A discipline whose corpus cannot be loaded or batched is logged and
    skipped. Finally the union map path is recorded in the suite state.

    Args:
        ctx: Run context providing clients, directories and suite state.
        disciplines: Disciplines to ingest; defaults to all ``DISCIPLINES``.
        max_per_discipline: Optional cap on corpus rows per discipline.
        settings: Per-upload settings; defaults to text-only "basic" mode.
    """
    disciplines = disciplines or list(DISCIPLINES)
    settings = settings or IngestSettings(use_vision_llm=False, processing_mode="basic")
    bench_dir = ctx.benchmark_data_dir()
    batches_root = bench_dir / "batches"
    batches_root.mkdir(parents=True, exist_ok=True)

    docs_client = ctx.documents_client()
    union_map_path = ctx.maps_dir() / "cure_corpus_map.jsonl"
    # ``with`` guarantees the handle is closed even if the header write
    # itself fails, which a try/finally starting after that write would not.
    with union_map_path.open("w", encoding="utf-8") as union_map_fh:
        # Header row records the ingest-time settings so the runner can
        # surface them in the report (see core/ingest_settings.py).
        union_map_fh.write(settings_header_line(settings) + "\n")
        for discipline in disciplines:
            try:
                passages_iter = _stream_corpus(discipline)
                if max_per_discipline is not None:
                    passages_iter = _take(passages_iter, max_per_discipline)
                # The corpus is loaded lazily, so dataset errors surface
                # here while the batches are being written.
                batches = _write_batches(
                    passages_iter,
                    out_dir=batches_root / discipline,
                    discipline=discipline,
                )
            except Exception as exc:  # noqa: BLE001
                logger.warning("Skipping discipline %s: %s", discipline, exc)
                continue
            if not batches:
                logger.warning("Discipline %s produced 0 batches; skipping upload", discipline)
                continue
            logger.info(
                "Uploading %d batches for discipline %s", len(batches), discipline
            )
            upload_result = await docs_client.upload(
                files=[b.path for b in batches],
                search_space_id=ctx.search_space_id,
                should_summarize=settings.should_summarize,
                use_vision_llm=settings.use_vision_llm,
                processing_mode=settings.processing_mode,
            )
            new_doc_ids = list(upload_result.document_ids)
            if new_doc_ids:
                await docs_client.wait_until_ready(
                    search_space_id=ctx.search_space_id,
                    document_ids=new_doc_ids,
                    timeout_s=3600.0,
                    max_poll_s=15.0,
                )
            # Duplicates are included so re-runs can still map bundles that
            # the backend already holds.
            statuses = await docs_client.get_status(
                search_space_id=ctx.search_space_id,
                document_ids=new_doc_ids + upload_result.duplicate_document_ids,
            )
            # Bundles are matched to documents by title == bundle file name.
            title_to_doc = {s.title: s.document_id for s in statuses}

            per_discipline_path = (
                ctx.maps_dir() / f"cure_corpus_map_{discipline}.jsonl"
            )
            with per_discipline_path.open("w", encoding="utf-8") as fh:
                fh.write(settings_header_line(settings) + "\n")
                for batch in batches:
                    doc_id = title_to_doc.get(batch.path.name)
                    if doc_id is None:
                        logger.warning("No document_id for batch %s", batch.path.name)
                        continue
                    for cid in batch.corpus_ids:
                        record = {
                            "corpus_id": cid,
                            "document_id": doc_id,
                            "discipline": discipline,
                        }
                        fh.write(json.dumps(record) + "\n")
                        union_map_fh.write(json.dumps(record) + "\n")

            chunks_map_path = ctx.maps_dir() / f"cure_chunk_map_{discipline}.jsonl"
            with chunks_map_path.open("w", encoding="utf-8") as fh:
                for doc_id in {title_to_doc.get(b.path.name) for b in batches} - {None}:
                    try:
                        chunks = await docs_client.list_chunks(int(doc_id))
                    except Exception as exc:  # noqa: BLE001
                        logger.warning(
                            "Failed to list chunks for doc_id=%s: %s", doc_id, exc
                        )
                        continue
                    for chunk in chunks:
                        fh.write(
                            json.dumps(
                                {
                                    "chunk_id": chunk.id,
                                    "document_id": doc_id,
                                    "discipline": discipline,
                                }
                            )
                            + "\n"
                        )

    new_state = ctx.suite_state
    new_state.ingestion_maps["cure"] = str(union_map_path)
    set_suite_state(ctx.config, ctx.suite, new_state)
    logger.info("CUREv1 ingestion complete; union map at %s", union_map_path)
|
||||
|
||||
|
||||
def _take(it: Iterable, n: int) -> Iterable:
    """Lazily yield at most the first ``n`` items of ``it``.

    Stops without pulling an extra item from ``it`` once ``n`` items have
    been produced, so a streamed source is not advanced past what is used.
    A non-positive ``n`` yields nothing.
    """
    from itertools import islice  # noqa: PLC0415

    # ``islice`` rejects negative stops, so clamp to keep "yield nothing".
    yield from islice(it, max(0, n))
|
||||
|
||||
|
||||
__all__ = ["DISCIPLINES", "CorpusPassage", "PassageBatch", "run_ingest"]
|
||||
|
|
@ -0,0 +1,397 @@
|
|||
"""CUREv1 runner — single-arm SurfSense retrieval scoring.
|
||||
|
||||
For each query we ask SurfSense via ``/api/v1/new_chat`` (no
|
||||
``mentioned_document_ids``) and parse chunk citations from the
|
||||
streamed answer. Cited ``chunk_id`` → ``document_id`` (chunk map) →
|
||||
``corpus_id`` (corpus map). The resulting ranked list is scored
|
||||
against the dataset's qrels.
|
||||
|
||||
The prompt nudges the model to surface its supporting passages via
|
||||
SurfSense's standard ``[citation:CHUNK_ID]`` format (already required
|
||||
by the agent system prompt), so we recover retrieval ordering from
|
||||
the answer text without needing a separate retrieval API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.arms import ArmRequest, ArmResult, SurfSenseArm
|
||||
from ....core.config import utc_iso_timestamp
|
||||
from ....core.ingest_settings import (
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
is_settings_header,
|
||||
read_settings_header,
|
||||
)
|
||||
from ....core.metrics.retrieval import score_run
|
||||
from ....core.registry import (
|
||||
Benchmark,
|
||||
ReportSection,
|
||||
RunArtifact,
|
||||
RunContext,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_PROMPT = """\
|
||||
You are a medical literature retrieval assistant for the question
|
||||
below. Identify the top passages from the knowledge base that best
|
||||
answer it and cite each one in the standard format
|
||||
[citation:CHUNK_ID]. List as many citations as are useful, ordered
|
||||
from most to least relevant. Provide a one-sentence justification
|
||||
for each citation.
|
||||
|
||||
Query: {query}
|
||||
"""
|
||||
|
||||
|
||||
_DESCRIPTION = "CUREv1 retrieval (single-arm SurfSense): Recall@k / MRR / nDCG@10."
|
||||
|
||||
# CUREv1 corpus is text-only markdown bundles; vision LLM at ingest
|
||||
# is wasted by default but the operator can flip it via CLI for an
|
||||
# A/B comparison.
|
||||
_DEFAULT_INGEST_SETTINGS = IngestSettings(
|
||||
use_vision_llm=False,
|
||||
processing_mode="basic",
|
||||
should_summarize=False,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class CureQuery:
    """One CUREv1 query and the discipline split it was loaded from."""

    # Query id, compared against the qrels' query ids.
    qid: str
    # Query text inserted into the prompt.
    text: str
    # Name of the discipline split the query came from.
    discipline: str
|
||||
|
||||
|
||||
def _load_chunk_map(maps_dir: Path) -> dict[int, int]:
    """Merge every ``cure_chunk_map_<discipline>.jsonl`` into ``chunk_id -> document_id``.

    Files are read in sorted name order. Blank lines, settings-header rows
    and rows whose ids are missing or not integer-like are ignored.
    """

    mapping: dict[int, int] = {}
    for map_file in sorted(maps_dir.glob("cure_chunk_map_*.jsonl")):
        with map_file.open("r", encoding="utf-8") as handle:
            for raw_line in handle:
                if not raw_line.strip():
                    continue
                record = json.loads(raw_line)
                if is_settings_header(record):
                    continue
                try:
                    document_id = int(record["document_id"])
                    chunk_id = int(record["chunk_id"])
                except (KeyError, TypeError, ValueError):
                    continue
                mapping[chunk_id] = document_id
    return mapping
|
||||
|
||||
|
||||
def _load_doc_to_corpus(maps_dir: Path) -> dict[int, list[str]]:
    """Read the union map into ``document_id -> [corpus_id, ...]``.

    One uploaded markdown document can bundle many passages, hence the
    list values; corpus ids keep their order of appearance in the file.
    An absent union map gives an empty mapping. Blank lines,
    settings-header rows and malformed rows are ignored.
    """

    grouped: dict[int, list[str]] = defaultdict(list)
    union_file = maps_dir / "cure_corpus_map.jsonl"
    if not union_file.exists():
        return grouped
    with union_file.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            record = json.loads(raw_line)
            if is_settings_header(record):
                continue
            try:
                grouped[int(record["document_id"])].append(str(record["corpus_id"]))
            except (KeyError, TypeError, ValueError):
                continue
    return grouped
|
||||
|
||||
|
||||
def _load_queries(*, lang: str, disciplines: list[str], sample_n: int | None) -> list[CureQuery]:
    """Load queries for ``lang`` across ``disciplines``, optionally sampled.

    A discipline whose split fails to load is logged and skipped; rows
    without an id or text are dropped. The result is sorted by
    ``(discipline, qid)``. With a positive ``sample_n``, at most
    ``max(1, sample_n // len(disciplines))`` queries are kept per
    discipline, and collection stops once ``sample_n`` queries are kept.
    """
    from datasets import load_dataset  # noqa: PLC0415

    collected: list[CureQuery] = []
    for discipline in disciplines:
        try:
            rows = load_dataset(path="clinia/CUREv1", name=f"queries-{lang}", split=discipline)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Skipping queries for %s/%s: %s", lang, discipline, exc)
            continue
        for record in rows:
            query_id = str(record.get("_id") or "")
            query_text = str(record.get("text") or "")
            if query_id and query_text:
                collected.append(
                    CureQuery(qid=query_id, text=query_text, discipline=discipline)
                )
    collected.sort(key=lambda item: (item.discipline, item.qid))
    if sample_n is None or sample_n <= 0:
        return collected

    # Stratified-by-discipline slice.
    quota = max(1, sample_n // max(1, len(disciplines)))
    picked: list[CureQuery] = []
    taken: dict[str, int] = defaultdict(int)
    for query in collected:
        if taken[query.discipline] >= quota:
            continue
        picked.append(query)
        taken[query.discipline] += 1
        if len(picked) >= sample_n:
            break
    return picked
|
||||
|
||||
|
||||
def _load_qrels(*, disciplines: list[str]) -> dict[str, dict[str, float]]:
    """Load relevance judgments as ``query_id -> {corpus_id: score}``.

    Both hyphenated (``query-id``) and underscored (``query_id``) column
    names are accepted. A discipline whose split fails to load is logged
    and skipped; rows with a missing id, a missing score or a score that
    is not float-convertible are ignored.
    """
    from datasets import load_dataset  # noqa: PLC0415

    judgments: dict[str, dict[str, float]] = defaultdict(dict)
    for discipline in disciplines:
        try:
            rows = load_dataset(path="clinia/CUREv1", name="qrels", split=discipline)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Skipping qrels for %s: %s", discipline, exc)
            continue
        for record in rows:
            query_id = str(record.get("query-id") or record.get("query_id") or "")
            corpus_id = str(record.get("corpus-id") or record.get("corpus_id") or "")
            raw_score = record.get("score")
            if not (query_id and corpus_id) or raw_score is None:
                continue
            try:
                judgments[query_id][corpus_id] = float(raw_score)
            except (TypeError, ValueError):
                continue
    return judgments
|
||||
|
||||
|
||||
async def _gather_with_limit(coros, *, concurrency: int) -> list[Any]:
    """Await every item of ``coros`` with at most ``concurrency`` in flight.

    Values of ``concurrency`` below 1 are treated as 1. Results come back
    in the same order as ``coros``.
    """
    gate = asyncio.Semaphore(max(1, concurrency))

    async def _guarded(awaitable):
        # Hold a semaphore slot for the full lifetime of the awaited item.
        await gate.acquire()
        try:
            return await awaitable
        finally:
            gate.release()

    guarded = [_guarded(item) for item in coros]
    return await asyncio.gather(*guarded)
|
||||
|
||||
|
||||
class CureBenchmark:
    """CUREv1 retrieval benchmark (single-arm SurfSense).

    ``ingest`` uploads the corpus; ``run`` asks each query through the
    SurfSense arm, maps cited chunks back to corpus ids via the ingestion
    maps, and scores the ranked lists against qrels; ``report_section``
    renders the latest run.
    """

    suite: str = "medical"
    name: str = "cure"
    headline: bool = False
    description: str = _DESCRIPTION

    def add_run_args(self, parser: argparse.ArgumentParser) -> None:
        """Register the benchmark's CLI flags on ``parser``."""
        parser.add_argument("--lang", default="en", choices=("en", "es", "fr"))
        parser.add_argument(
            "--discipline",
            default=None,
            help="Restrict to one discipline (default: all ingested).",
        )
        parser.add_argument("--n", dest="sample_n", type=int, default=None)
        parser.add_argument("--concurrency", type=int, default=4)
        parser.add_argument(
            "--max-passages-per-discipline",
            type=int,
            default=None,
            help="(ingest only) cap corpus rows per discipline for smoke testing.",
        )
        # Per-upload knobs forwarded to /documents/fileupload at ingest;
        # ignored at run-time (runner reads resolved settings from the
        # union-map header).
        add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)

    async def ingest(self, ctx: RunContext, **opts: Any) -> None:
        """Ingest every discipline in ``DISCIPLINES`` with the merged settings.

        NOTE(review): ``--discipline`` is not consulted here, so ingestion
        always covers all disciplines; confirm that is intended.
        """
        from .ingest import DISCIPLINES, run_ingest

        settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
        await run_ingest(
            ctx,
            disciplines=list(DISCIPLINES),
            max_per_discipline=opts.get("max_passages_per_discipline"),
            settings=settings,
        )

    @staticmethod
    def _cited_chunk_ids(citations: Any) -> list[int]:
        """Return distinct cited chunk ids in first-occurrence order.

        Non-chunk citations are ignored. A chunk citation whose id is
        missing or not integer-like is logged and skipped, so one malformed
        citation cannot abort a run whose queries were already answered.
        """
        ordered: list[int] = []
        seen: set[int] = set()
        for citation in citations:
            if citation.get("kind") != "chunk":
                continue
            try:
                chunk_id = int(citation.get("chunk_id"))
            except (TypeError, ValueError):
                logger.warning("Ignoring citation with unusable chunk_id: %r", citation)
                continue
            if chunk_id in seen:
                continue
            ordered.append(chunk_id)
            seen.add(chunk_id)
        return ordered

    async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
        """Query SurfSense for each CUREv1 query and score the citations.

        Writes ``raw.jsonl`` (one row per query) and ``run_artifact.json``
        into the run directory.

        Raises:
            RuntimeError: If the ingestion maps are missing or empty, or if
                no query matches the requested language and disciplines.
        """
        lang = opts.get("lang") or "en"
        discipline_filter = opts.get("discipline")
        sample_n = opts.get("sample_n")
        concurrency = int(opts.get("concurrency") or 4)

        maps_dir = ctx.maps_dir()
        chunk_to_doc = _load_chunk_map(maps_dir)
        doc_to_corpus = _load_doc_to_corpus(maps_dir)
        ingest_settings = read_settings_header(maps_dir / "cure_corpus_map.jsonl")
        if not chunk_to_doc or not doc_to_corpus:
            raise RuntimeError(
                "CUREv1 not ingested for this suite. Run "
                "`python -m surfsense_evals ingest medical cure` first."
            )

        # Disciplines to query are determined by the per-discipline maps
        # actually present (either user-filtered or whatever was ingested).
        ingested_disciplines = sorted({
            path.stem[len("cure_corpus_map_"):]
            for path in maps_dir.glob("cure_corpus_map_*.jsonl")
        })
        if discipline_filter:
            disciplines = [discipline_filter]
        else:
            disciplines = ingested_disciplines or ["dermatology"]

        queries = _load_queries(lang=lang, disciplines=disciplines, sample_n=sample_n)
        if not queries:
            raise RuntimeError(
                f"No CUREv1 queries matched lang={lang!r} disciplines={disciplines!r}."
            )
        qrels = _load_qrels(disciplines=disciplines)
        logger.info(
            "CUREv1: %d queries / %d qrels across disciplines %s",
            len(queries),
            len(qrels),
            disciplines,
        )

        arm = SurfSenseArm(
            client=ctx.new_chat_client(),
            search_space_id=ctx.search_space_id,
            ephemeral_threads=True,
        )

        async def _ask(q: CureQuery) -> ArmResult:
            return await arm.answer(
                ArmRequest(
                    question_id=f"{q.discipline}::{q.qid}",
                    prompt=_PROMPT.format(query=q.text.strip()),
                )
            )

        results: list[ArmResult] = await _gather_with_limit(
            (_ask(q) for q in queries), concurrency=concurrency
        )

        # Citation order is taken as retrieval rank: cited chunk ->
        # document -> every corpus passage bundled in that document.
        per_query_retrieved: dict[str, list[str]] = {}
        for q, res in zip(queries, results):
            corpus_ids: list[str] = []
            seen_corpus: set[str] = set()
            for chunk_id in self._cited_chunk_ids(res.citations):
                doc_id = chunk_to_doc.get(chunk_id)
                if doc_id is None:
                    continue
                for corpus_id in doc_to_corpus.get(doc_id, []):
                    if corpus_id in seen_corpus:
                        continue
                    corpus_ids.append(corpus_id)
                    seen_corpus.add(corpus_id)
            per_query_retrieved[q.qid] = corpus_ids

        scores = score_run(
            per_query_retrieved=per_query_retrieved,
            per_query_qrels=qrels,
            ks=(1, 5, 10, 32),
            ndcg_k=10,
        )

        run_timestamp = utc_iso_timestamp()
        run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
        raw_path = run_dir / "raw.jsonl"
        with raw_path.open("w", encoding="utf-8") as fh:
            for q, res in zip(queries, results):
                fh.write(
                    json.dumps(
                        {
                            "discipline": q.discipline,
                            "qid": q.qid,
                            "lang": lang,
                            "retrieved_corpus_ids": per_query_retrieved.get(q.qid, []),
                            **res.to_jsonl(),
                        }
                    )
                    + "\n"
                )

        metrics = scores.to_dict()
        metrics["lang"] = lang
        metrics["disciplines"] = disciplines

        artifact = RunArtifact(
            suite=self.suite,
            benchmark=self.name,
            run_timestamp=run_timestamp,
            raw_path=raw_path,
            metrics=metrics,
            extra={
                "n_queries": len(queries),
                "lang": lang,
                "disciplines": disciplines,
                "concurrency": concurrency,
                "provider_model": ctx.provider_model,
                "ingest_settings": ingest_settings,
            },
        )
        manifest_path = run_dir / "run_artifact.json"
        manifest_path.write_text(
            json.dumps(
                {
                    "suite": self.suite,
                    "benchmark": self.name,
                    "raw_path": "raw.jsonl",
                    "metrics": metrics,
                    "extra": artifact.extra,
                },
                indent=2,
                sort_keys=True,
            )
            + "\n",
            encoding="utf-8",
        )
        return artifact

    def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
        """Render the artifact with the greatest ``run_timestamp`` as a section.

        Returns a placeholder section when ``artifacts`` is empty.
        """
        if not artifacts:
            return ReportSection(
                title="CUREv1 — single-arm SurfSense retrieval",
                headline=False,
                body_md="(no run artifacts found)",
                body_json={},
            )
        latest = max(artifacts, key=lambda a: a.run_timestamp)
        m = latest.metrics
        recall = m.get("recall_at_k", {})
        lines: list[str] = [
            format_ingest_settings_md(latest.extra.get("ingest_settings")),
            f"- Language: {m.get('lang', '?')}",
            f"- Disciplines: {', '.join(m.get('disciplines', []) or ['?'])}",
            f"- n_queries (after qrels intersection): {m.get('n_queries', 0)}",
        ]
        for k in (1, 5, 10, 32):
            # Keys may be strings when the metrics were loaded back from JSON.
            v = recall.get(str(k), recall.get(k))
            if v is not None:
                lines.append(f"- Recall@{k}: {float(v):.3f}")
        lines.append(f"- MRR: {float(m.get('mrr', 0.0)):.3f}")
        lines.append(f"- nDCG@10: {float(m.get('ndcg_at_10', 0.0)):.3f}")
        return ReportSection(
            title="CUREv1 — single-arm SurfSense retrieval",
            headline=False,
            body_md="\n".join(lines),
            body_json=m,
        )
|
||||
|
||||
|
||||
__all__ = ["CureBenchmark", "CureQuery"]
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
"""MedXpertQA-MM — multimodal medical exam head-to-head (medical suite headline).
|
||||
|
||||
Source: https://huggingface.co/datasets/TsinghuaC3I/MedXpertQA
|
||||
Paper: https://arxiv.org/abs/2501.18362 (ICML 2025)
|
||||
|
||||
* MM subset: ~2,000 expert-level exam questions with diverse medical
|
||||
images (radiology, dermatology, pathology, ECGs, gross specimens,
|
||||
fundus photos) and structured patient information embedded in the
|
||||
question stem.
|
||||
* 5 answer choices per MM question (A–E).
|
||||
* USMLE / COMLEX / 17 specialty board sources; rigorously filtered
|
||||
and reviewed by physicians.
|
||||
|
||||
Real diagnostic images carry signal that text-only patient charts
|
||||
cannot (e.g. CT scans, dermoscopy), so this benchmark exercises the
|
||||
full vision RAG pipeline end-to-end against a vision-capable model
|
||||
fed the same PDF natively.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import registry as _registry
|
||||
from .runner import MedXpertQAMMBenchmark
|
||||
|
||||
_registry.register(MedXpertQAMMBenchmark())
|
||||
|
|
@ -0,0 +1,394 @@
|
|||
"""MedXpertQA-MM ingestion.
|
||||
|
||||
Steps:
|
||||
|
||||
1. Pull ``MM/test.jsonl`` (and optionally ``MM/dev.jsonl``) plus
|
||||
``images.zip`` from
|
||||
``hf://datasets/TsinghuaC3I/MedXpertQA``. Cache under
|
||||
``<data_dir>/medical/medxpertqa/``.
|
||||
2. Extract ``images.zip`` once into ``<data_dir>/medical/medxpertqa/images/``.
|
||||
3. Render one PDF per MM question (text question + structured patient
|
||||
info embedded in the question stem + each image flowable + answer
|
||||
options). Output: ``<data_dir>/medical/medxpertqa/pdfs/<id>.pdf``.
|
||||
4. Upload each PDF to SurfSense with ``use_vision_llm=True``; persist
|
||||
``id -> document_id`` in
|
||||
``<data_dir>/medical/maps/medxpertqa_doc_map.jsonl``.
|
||||
|
||||
Both arms then receive byte-identical PDFs. The native arm sends the
|
||||
PDF directly to OpenRouter; SurfSense ingests via its own vision
|
||||
pipeline and the runner queries with ``mentioned_document_ids=[...]``
|
||||
to scope retrieval to the question's PDF.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.pdf import PdfImage, render_pdf_with_images
|
||||
from ....core.registry import RunContext
|
||||
from .prompt import format_options
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Hugging Face Hub coordinates handed to ``hf_hub_download`` for every
# MedXpertQA file this module fetches.
HF_REPO_TYPE = "dataset"
HF_REPO_ID = "TsinghuaC3I/MedXpertQA"
|
||||
|
||||
|
||||
def _hf_hub_download(*args, **kwargs):
    """Forward all arguments to ``huggingface_hub.hf_hub_download``.

    ``huggingface_hub`` is imported inside the function, i.e. on first
    call rather than when this module is imported.
    """
    import huggingface_hub

    return huggingface_hub.hf_hub_download(*args, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class MedXpertQuestion:
    """One MedXpertQA-MM question as parsed from the dataset JSONL.

    Field order defines the positional order of the generated
    ``__init__``.
    """

    # Dataset question id, e.g. "MM-26".
    qid: str
    # Full question text (case + ask).
    question: str
    # Answer options keyed by letter (A-E).
    options: dict[str, str]
    # Correct option letter, "A".."E".
    label: str
    # Filenames inside images.zip.
    image_files: list[str]
    medical_task: str
    body_system: str
    question_type: str
    # "test" or "dev".
    split: str
|
||||
|
||||
|
||||
def _load_jsonl(path: Path, *, split: str) -> list[MedXpertQuestion]:
    """Parse one MedXpertQA ``*.jsonl`` file into ``MedXpertQuestion`` rows.

    Blank lines are ignored. A row is skipped when it is not a JSON
    object, or when any of ``id``, ``question``, ``label`` or ``options``
    is missing or empty (``options`` must be a non-empty mapping).

    Parameters
    ----------
    path : Path
        UTF-8 JSONL file with one question per line.
    split : str
        Split tag stamped onto every returned question.

    Returns
    -------
    list[MedXpertQuestion]
        Accepted questions in file order.

    Raises
    ------
    json.JSONDecodeError
        If a non-blank line is not valid JSON.
    """

    out: list[MedXpertQuestion] = []
    with path.open("r", encoding="utf-8") as fh:
        for raw_line in fh:
            line = raw_line.strip()
            if not line:
                continue
            row = json.loads(line)
            if not isinstance(row, dict):
                # A bare list / string / number has no fields to read.
                continue
            qid = str(row.get("id") or "").strip()
            question = str(row.get("question") or "").strip()
            label = str(row.get("label") or "").strip().upper()
            if not (qid and question and label):
                continue
            options = row.get("options")
            # Check emptiness explicitly: defaulting a missing value to
            # ``{}`` would let option-less rows through the type check.
            if not isinstance(options, dict) or not options:
                continue
            opts = {str(k).strip().upper(): str(v).strip() for k, v in options.items()}
            images = row.get("images") or []
            if not isinstance(images, list):
                images = []
            out.append(MedXpertQuestion(
                qid=qid,
                question=question,
                options=opts,
                label=label,
                image_files=[str(x).strip() for x in images if str(x).strip()],
                medical_task=str(row.get("medical_task") or "").strip(),
                body_system=str(row.get("body_system") or "").strip(),
                question_type=str(row.get("question_type") or "").strip(),
                split=split,
            ))
    return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image archive helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_images_extracted(images_zip: Path, images_dir: Path) -> None:
    """Unpack ``images_zip`` into ``images_dir`` unless already done.

    A ``.extracted_ok`` sentinel file inside ``images_dir`` is written
    after ``extractall`` returns; when the sentinel already exists the
    call does nothing.
    """

    sentinel = images_dir / ".extracted_ok"
    if not sentinel.exists():
        images_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Extracting MedXpertQA images.zip -> %s", images_dir)
        with zipfile.ZipFile(images_zip) as archive:
            archive.extractall(images_dir)
        sentinel.write_text("ok\n", encoding="utf-8")
|
||||
|
||||
|
||||
def _resolve_image_path(image_filename: str, images_dir: Path) -> Path | None:
    """Find a question's image in the (possibly nested) extract directory.

    Lookup order:

    1. ``images_dir / image_filename`` (flat layout);
    2. ``images_dir / "images" / image_filename`` (nested layout);
    3. a recursive search below ``images_dir``.

    Only regular files count as a hit, so a directory that happens to
    share the name is never returned. When the recursive search finds
    several files, the smallest path in sort order is returned so the
    result does not depend on directory enumeration order.

    Returns the matching path, or ``None`` when nothing matches.
    """

    for candidate in (images_dir / image_filename, images_dir / "images" / image_filename):
        if candidate.is_file():
            return candidate
    # Last-ditch: glob recursively (slow but correct for unusual layouts).
    hits = (p for p in images_dir.rglob(image_filename) if p.is_file())
    return min(hits, default=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PDF rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _render_question_pdf(
    q: MedXpertQuestion,
    *,
    images_dir: Path,
    pdfs_dir: Path,
) -> tuple[Path, list[str]]:
    """Write ``<pdfs_dir>/<qid>.pdf`` for one MedXpertQA question.

    Document layout:
      Title:            MedXpertQA-MM <qid> (medical_task / body_system / question_type)
      "Clinical case":  full question text plus every image that could be
                        located, each captioned ``Image: <filename>``
      "Answer choices": output of ``format_options``

    Empty metadata fields are left out of the title suffix; with no
    metadata at all the suffix is omitted.

    Returns ``(pdf_path, missing_images)`` where ``missing_images`` lists
    the referenced filenames that were not found under ``images_dir``.
    """

    found: list[PdfImage] = []
    not_found: list[str] = []
    for name in q.image_files:
        location = _resolve_image_path(name, images_dir)
        if location is not None:
            found.append(PdfImage(path=location, caption=f"Image: {name}", max_width_in=5.5))
        else:
            not_found.append(name)

    tags = [tag for tag in (q.medical_task, q.body_system, q.question_type) if tag]
    heading = f"MedXpertQA-MM {q.qid}"
    if tags:
        heading += f" ({' / '.join(tags)})"

    destination = pdfs_dir / f"{q.qid}.pdf"
    render_pdf_with_images(
        title=heading,
        sections=[
            ("Clinical case", q.question, found),
            ("Answer choices", format_options(q.options), None),
        ],
        output_path=destination,
    )
    return destination, not_found
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _upload_pdfs(
    ctx: RunContext,
    pdf_paths: Iterable[Path],
    *,
    batch_size: int,
    settings: IngestSettings,
) -> dict[str, int]:
    """Upload PDFs to SurfSense in batches and map titles to document ids.

    For each batch: upload the files, wait (up to 1800 s) for the newly
    created documents, then query status for new + duplicate ids and
    record ``status.title -> status.document_id``.

    Parameters
    ----------
    ctx : RunContext
        Supplies the documents client and the target search space id.
    pdf_paths : Iterable[Path]
        PDFs to upload; consumed once.
    batch_size : int
        Files per upload call. Values below 1 are treated as 1, so a zero
        value cannot raise from ``range`` and a negative value cannot
        silently skip every upload.
    settings : IngestSettings
        Source of ``should_summarize``, ``use_vision_llm`` and
        ``processing_mode`` for each upload call.

    Returns
    -------
    dict[str, int]
        Document title -> document id for every status row returned.
    """

    docs_client = ctx.documents_client()
    name_to_id: dict[str, int] = {}
    pdf_list = list(pdf_paths)
    step = max(1, batch_size)
    for batch_start in range(0, len(pdf_list), step):
        batch = pdf_list[batch_start:batch_start + step]
        result = await docs_client.upload(
            files=batch,
            search_space_id=ctx.search_space_id,
            should_summarize=settings.should_summarize,
            use_vision_llm=settings.use_vision_llm,
            processing_mode=settings.processing_mode,
        )
        all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
        if all_ids:
            # Only newly created documents are awaited; duplicates are
            # just included in the status query below.
            await docs_client.wait_until_ready(
                search_space_id=ctx.search_space_id,
                document_ids=result.document_ids,
                timeout_s=1800.0,
            )
            statuses = await docs_client.get_status(
                search_space_id=ctx.search_space_id,
                document_ids=all_ids,
            )
            for s in statuses:
                name_to_id[s.title] = s.document_id
        logger.info(
            "Uploaded MedXpertQA batch %d-%d: %d new, %d duplicate",
            batch_start, batch_start + len(batch),
            len(result.document_ids), len(result.duplicate_document_ids),
        )
    return name_to_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_ingest(
    ctx: RunContext,
    *,
    split: str = "test",
    max_questions: int | None = None,
    upload_batch_size: int = 8,
    skip_upload: bool = False,
    include_dev: bool = False,
    settings: IngestSettings | None = None,
) -> None:
    """Ingest MedXpertQA-MM into the medical suite.

    Pipeline: download the split JSONL file(s), materialise and extract
    ``images.zip``, render one PDF per question, optionally upload the
    PDFs to SurfSense, then write ``questions.jsonl`` and the doc map and
    record the map path in the suite state.

    Parameters
    ----------
    split : 'test' (default), 'dev', 'both' or 'all'
        Which subset to render + upload. ``'all'`` is accepted as an
        alias of ``'both'`` because that is how the CLI ``--split``
        option spells it.
    max_questions : int | None
        Cap on number of questions ingested (handy for fast iteration).
        ``None`` or a non-positive value means no cap.
    upload_batch_size : int
        PDFs per ``fileupload`` call.
    skip_upload : bool
        Render PDFs locally but don't push to SurfSense. Doc-map rows
        then carry a null ``document_id``.
    include_dev : bool
        Convenience: equivalent to ``split='both'``.
    settings : IngestSettings | None
        Per-upload settings; when omitted, vision LLM on with the
        ``"basic"`` processing mode.

    Raises
    ------
    ValueError
        If ``split`` is not one of the accepted values.
    RuntimeError
        If no question could be loaded.
    """

    settings = settings or IngestSettings(use_vision_llm=True, processing_mode="basic")
    bench_dir = ctx.benchmark_data_dir()
    images_zip_local = bench_dir / "images.zip"
    images_dir = bench_dir / "images"
    pdfs_dir = bench_dir / "pdfs"
    pdfs_dir.mkdir(parents=True, exist_ok=True)
    hf_cache = bench_dir / ".hf_cache"
    hf_cache.mkdir(parents=True, exist_ok=True)

    # Step 1: download jsonl(s)
    splits_to_load: list[str] = []
    if include_dev or split in {"both", "all"}:
        splits_to_load = ["dev", "test"]
    elif split in {"dev", "test"}:
        splits_to_load = [split]
    else:
        raise ValueError(f"Unknown split {split!r}; use 'test' / 'dev' / 'both'")

    questions: list[MedXpertQuestion] = []
    for sp in splits_to_load:
        rel = f"MM/{sp}.jsonl"
        local = _hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=rel,
            repo_type=HF_REPO_TYPE,
            cache_dir=str(hf_cache),
        )
        loaded = _load_jsonl(Path(local), split=sp)
        questions.extend(loaded)
        logger.info("Loaded %d MedXpertQA-MM questions from %s split", len(loaded), sp)

    if max_questions is not None and max_questions > 0:
        questions = questions[:max_questions]
    if not questions:
        raise RuntimeError("No MedXpertQA-MM questions loaded; check the split argument.")

    # Step 2: download images.zip + extract once
    if not images_zip_local.exists():
        local_zip = _hf_hub_download(
            repo_id=HF_REPO_ID,
            filename="images.zip",
            repo_type=HF_REPO_TYPE,
            cache_dir=str(hf_cache),
        )
        # The hub cache may hand back a symlink into its blob store.
        # Hard-linking a symlink can yield a link to the symlink itself
        # (whose relative target would dangle from ``bench_dir``), so
        # link / copy the resolved file instead.
        source_zip = Path(local_zip).resolve()
        # Materialise into bench_dir so the path is stable.
        try:
            from os import link as _link
            _link(source_zip, images_zip_local)
        except OSError:
            # e.g. cross-device link or a filesystem without hard links.
            from shutil import copy2
            copy2(source_zip, images_zip_local)
    _ensure_images_extracted(images_zip_local, images_dir)

    # Step 3: render PDFs
    pdf_paths: dict[str, Path] = {}
    missing_image_count = 0
    for i, q in enumerate(questions, start=1):
        try:
            pdf, missing = _render_question_pdf(q, images_dir=images_dir, pdfs_dir=pdfs_dir)
            pdf_paths[q.qid] = pdf
            if missing:
                missing_image_count += len(missing)
                logger.debug("qid=%s missing %d images: %s", q.qid, len(missing), missing)
        except Exception as exc:  # noqa: BLE001
            # Best effort: one unrenderable question must not abort the
            # whole ingest; it simply gets no doc-map row below.
            logger.warning("Failed to render MedXpertQA PDF for %s: %s", q.qid, exc)
        if i % 50 == 0:
            logger.info("  ... rendered %d / %d PDFs", i, len(questions))
    if missing_image_count:
        logger.warning(
            "MedXpertQA: %d image references could not be resolved on disk "
            "(rendered PDFs may be missing some images).",
            missing_image_count,
        )

    # Step 4: upload
    name_to_id: dict[str, int] = {}
    if skip_upload:
        logger.info("MedXpertQA: --skip-upload set; skipping SurfSense ingestion")
    else:
        logger.info("MedXpertQA upload settings: %s", settings.render_label())
        name_to_id = await _upload_pdfs(
            ctx,
            pdf_paths.values(),
            batch_size=upload_batch_size,
            settings=settings,
        )

    # Step 5: persist manifest + questions
    questions_jsonl = bench_dir / "questions.jsonl"
    with questions_jsonl.open("w", encoding="utf-8") as fh:
        for q in questions:
            fh.write(json.dumps({
                "qid": q.qid,
                "question": q.question,
                "options": q.options,
                "label": q.label,
                "image_files": q.image_files,
                "medical_task": q.medical_task,
                "body_system": q.body_system,
                "question_type": q.question_type,
                "split": q.split,
            }) + "\n")
    logger.info("Wrote %d MedXpertQA questions to %s", len(questions), questions_jsonl)

    map_path = ctx.maps_dir() / "medxpertqa_doc_map.jsonl"
    with map_path.open("w", encoding="utf-8") as fh:
        # Header line records the resolved ingest settings
        # (see core/ingest_settings.py).
        fh.write(settings_header_line(settings) + "\n")
        for q in questions:
            local = pdf_paths.get(q.qid)
            if local is None:
                # Rendering failed for this question; leave it out.
                continue
            fh.write(json.dumps({
                "qid": q.qid,
                # Upload results are keyed by title; look up by PDF file name.
                "document_id": name_to_id.get(local.name),
                "pdf_path": str(local),
                "n_images": len(q.image_files),
                "split": q.split,
            }) + "\n")
    logger.info("Wrote MedXpertQA doc map to %s", map_path)

    new_state = ctx.suite_state
    new_state.ingestion_maps["medxpertqa"] = str(map_path)
    set_suite_state(ctx.config, ctx.suite, new_state)
|
||||
|
||||
|
||||
__all__ = ["MedXpertQuestion", "run_ingest"]
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"""MedXpertQA-MM prompt.
|
||||
|
||||
Mirrors the upstream paper's evaluation prompt (Zuo et al., ICML 2025
|
||||
§3.4): present case + 5 options A-E, ask for a single letter answer.
|
||||
We also instruct the model to use the embedded images explicitly,
|
||||
since the whole point of the MM subset is that the answer depends on
|
||||
visual evidence (radiology / dermoscopy / pathology / ECG, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
# Option letters a MedXpertQA-MM question may use, in display order.
ANSWER_LETTERS = tuple("ABCDE")


# Prompt template; ``{question}`` and ``{options_block}`` are its only
# placeholders.
_PROMPT = """\
You are a board-certified physician. The following exam question
includes a clinical case and one or more medical images (radiology,
dermatology, pathology, ECG, etc.). Use BOTH the text and the images
to choose the best answer. Do not rely on memorisation of the case;
read the images carefully — they often determine the correct answer.

Case + question:
{question}

Answer choices:
{options_block}

Respond on a single line in the format `Answer: X` where X is one of
A, B, C, D, or E.
"""
|
||||
|
||||
|
||||
def format_options(options: Mapping[str, str]) -> str:
    """Build the newline-separated ``A) ...`` through ``E) ...`` block.

    Letters are emitted in ``ANSWER_LETTERS`` order. A letter whose value
    is absent, ``None`` or whitespace-only is left out; keys outside
    ``ANSWER_LETTERS`` are never looked at.
    """

    candidates = ((letter, options.get(letter)) for letter in ANSWER_LETTERS)
    return "\n".join(
        f"{letter}) {str(value).strip()}"
        for letter, value in candidates
        if value is not None and str(value).strip() != ""
    )
|
||||
|
||||
|
||||
def build_prompt(question: str, options: Mapping[str, str]) -> str:
    """Fill the prompt template for one question.

    ``question`` is stripped of surrounding whitespace and ``options`` is
    rendered with ``format_options`` before substitution.
    """

    fields = {
        "question": question.strip(),
        "options_block": format_options(options),
    }
    return _PROMPT.format(**fields)
|
||||
|
||||
|
||||
__all__ = ["ANSWER_LETTERS", "build_prompt", "format_options"]
|
||||
|
|
@ -0,0 +1,681 @@
|
|||
"""MedXpertQA-MM runner — Native PDF (vision) vs SurfSense (vision RAG).
|
||||
|
||||
Headline benchmark for the medical suite.
|
||||
|
||||
* Native arm reads the rendered PDF (case + images + options) via
|
||||
OpenRouter ``chat/completions`` + the file-parser plugin.
|
||||
* SurfSense arm queries ``POST /api/v1/new_chat`` scoped via
|
||||
``mentioned_document_ids=[doc_id]`` to the same per-question PDF.
|
||||
|
||||
Operational notes:
|
||||
|
||||
* PDFs contain real images (radiology, dermoscopy, pathology, ECGs).
|
||||
Operator must pin a vision-capable model via
|
||||
``setup --provider-model anthropic/claude-sonnet-4.5`` (or similar);
|
||||
the runner emits a warning if a known text-only slug is pinned.
|
||||
* MedXpertQA tags ``medical_task`` (Diagnosis / Treatment / Basic
|
||||
Medicine) and ``body_system`` (Cardiovascular / Lymphatic / …)
|
||||
directly on every row; we slice the report by both.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.arms import ArmRequest, ArmResult, NativePdfArm, SurfSenseArm
|
||||
from ....core.config import utc_iso_timestamp
|
||||
from ....core.ingest_settings import (
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
is_settings_header,
|
||||
)
|
||||
from ....core.metrics.comparison import (
|
||||
bootstrap_delta_ci,
|
||||
mcnemar_test,
|
||||
paired_aggregate,
|
||||
)
|
||||
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci
|
||||
from ....core.providers.openrouter_pdf import OpenRouterPdfProvider, PdfEngine
|
||||
from ....core.registry import (
|
||||
ReportSection,
|
||||
RunArtifact,
|
||||
RunContext,
|
||||
)
|
||||
from ....core.scenarios import format_scenario_md
|
||||
from .prompt import ANSWER_LETTERS, build_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Substrings searched for in the lower-cased native-arm model slug by
# ``run``; a hit only selects which log message is emitted.
_TEXT_ONLY_HINTS = ("gpt-5.4-mini", "gpt-3.5", "text-only", "instruct-")
|
||||
|
||||
|
||||
@dataclass
class MXQuestion:
    """One runnable question: a ``questions.jsonl`` row joined with its doc-map row."""

    # --- taken from questions.jsonl ---
    qid: str
    question: str
    options: dict[str, str]
    label: str
    medical_task: str
    body_system: str
    question_type: str
    split: str
    # --- taken from the doc map ---
    n_images: int
    pdf_path: Path
    # Whatever the doc map stored for ``document_id``; may be ``None``.
    document_id: int | None
|
||||
|
||||
|
||||
def _load_doc_map(map_path: Path) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
    """Parse the doc-map JSONL file.

    Returns ``(rows, settings)``: ``rows`` maps ``str(qid)`` to its JSON
    row, and ``settings`` is a copy of the ``__settings__`` header blob
    (``{}`` when the file has no header, i.e. a legacy map). Blank lines
    are ignored.
    """

    by_qid: dict[str, dict[str, Any]] = {}
    header: dict[str, Any] = {}
    with map_path.open("r", encoding="utf-8") as handle:
        non_blank = (text for text in map(str.strip, handle) if text)
        for record in map(json.loads, non_blank):
            if is_settings_header(record):
                header = dict(record["__settings__"])
            else:
                by_qid[str(record["qid"])] = record
    return by_qid, header
|
||||
|
||||
|
||||
def _load_questions(
    questions_jsonl: Path,
    doc_map: dict[str, dict[str, Any]],
    *,
    split_filter: str | None,
    task_filter: str | None,
    body_filter: str | None,
    require_images: bool,
    sample_n: int | None,
) -> list[MXQuestion]:
    """Load, filter and order the questions to run.

    A row is kept when it has a non-empty ``qid``, passes every active
    filter (a filter that is falsy or ``"all"`` is inactive), has a
    doc-map entry, and, with ``require_images``, a positive ``n_images``
    in that entry. Survivors are sorted by ``(split, qid)`` and, when
    ``sample_n`` is a positive number, truncated to the first
    ``sample_n``.
    """

    def _rejected(wanted: str | None, actual: Any) -> bool:
        # Active filter whose value does not match the row.
        return bool(wanted) and wanted != "all" and actual != wanted

    selected: list[MXQuestion] = []
    with questions_jsonl.open("r", encoding="utf-8") as handle:
        for text in handle:
            text = text.strip()
            if not text:
                continue
            record = json.loads(text)
            qid = str(record.get("qid") or "").strip()
            if not qid:
                continue
            if (
                _rejected(split_filter, record.get("split"))
                or _rejected(task_filter, record.get("medical_task"))
                or _rejected(body_filter, record.get("body_system"))
            ):
                continue
            entry = doc_map.get(qid)
            if entry is None:
                logger.debug("No doc-map entry for %s; skipping", qid)
                continue
            image_count = int(entry.get("n_images", 0))
            if require_images and image_count <= 0:
                continue
            raw_options = record.get("options") or {}
            selected.append(MXQuestion(
                qid=qid,
                question=str(record.get("question") or ""),
                options={str(k).upper(): str(v) for k, v in raw_options.items()},
                label=str(record.get("label") or "").strip().upper(),
                medical_task=str(record.get("medical_task") or "").strip(),
                body_system=str(record.get("body_system") or "").strip(),
                question_type=str(record.get("question_type") or "").strip(),
                split=str(record.get("split") or ""),
                n_images=image_count,
                pdf_path=Path(entry["pdf_path"]),
                document_id=entry.get("document_id"),
            ))
    selected.sort(key=lambda item: (item.split, item.qid))
    if sample_n is not None and sample_n > 0:
        return selected[:sample_n]
    return selected
|
||||
|
||||
|
||||
async def _gather_with_limit(coros: Iterable, *, concurrency: int) -> list[Any]:
    """Await every awaitable in ``coros`` with at most ``concurrency`` running.

    A ``concurrency`` below 1 is treated as 1. Results come back in input
    order, as ``asyncio.gather`` returns them.
    """
    gate = asyncio.Semaphore(concurrency if concurrency > 1 else 1)

    async def _bounded(pending):
        async with gate:
            return await pending

    guarded = [_bounded(item) for item in coros]
    return await asyncio.gather(*guarded)
|
||||
|
||||
|
||||
# One-line benchmark summary exposed as ``MedXpertQAMMBenchmark.description``.
_DESCRIPTION = (
    "MedXpertQA-MM (~2,000 multimodal medical exam questions, 5 options, "
    "with images) — Native PDF (vision) vs SurfSense (vision RAG) "
    "head-to-head."
)
|
||||
|
||||
# The PDFs for this benchmark embed clinical images, so the vision LLM is
# switched on at ingest by default. Operators can flip
# ``--no-vision-llm`` to measure how much quality degrades without it
# (likely material).
_DEFAULT_INGEST_SETTINGS = IngestSettings(
    should_summarize=False,
    processing_mode="basic",
    use_vision_llm=True,
)
|
||||
|
||||
|
||||
class MedXpertQAMMBenchmark:
|
||||
"""Multimodal medical exam head-to-head."""
|
||||
|
||||
suite: str = "medical"
|
||||
name: str = "medxpertqa"
|
||||
headline: bool = True # The medical suite headline.
|
||||
description: str = _DESCRIPTION
|
||||
|
||||
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
    """Register this benchmark's CLI options on ``parser``.

    Run-time selectors come first, then the options marked
    ``(ingest only)`` in their help text, and finally the shared
    ingest-settings options seeded with ``_DEFAULT_INGEST_SETTINGS``.
    """
    add = parser.add_argument

    # --- question selection / execution ---
    add(
        "--split",
        default="test",
        choices=["test", "dev", "all"],
        help="Which MedXpertQA-MM split to run (default: test).",
    )
    add(
        "--task",
        default="all",
        help="Filter by medical_task value (e.g. Diagnosis, Treatment, Basic Medicine).",
    )
    add(
        "--body-system",
        dest="body_filter",
        default="all",
        help="Filter by body_system value (e.g. Cardiovascular, Lymphatic).",
    )
    add(
        "--require-images",
        dest="require_images",
        action="store_true",
        help="Skip rare MM rows that ended up with zero resolvable images.",
    )
    add(
        "--n",
        dest="sample_n",
        type=int,
        default=None,
        help="Run only the first N questions after filters apply.",
    )
    add(
        "--concurrency",
        type=int,
        default=4,
        help="Parallel question workers per arm.",
    )
    add(
        "--no-mentions",
        dest="no_mentions",
        action="store_true",
        help="SurfSense arm: skip mentioned_document_ids (unscoped retrieval).",
    )
    add(
        "--pdf-engine",
        default="native",
        choices=[engine.value for engine in PdfEngine],
        help="OpenRouter file-parser engine for the native arm.",
    )
    add(
        "--max-output-tokens",
        type=int,
        default=512,
        help="Cap on completion length for both arms.",
    )

    # --- ingest-only knobs (forwarded by the CLI to ingest.run_ingest) ---
    add(
        "--max-questions",
        dest="max_questions",
        type=int,
        default=None,
        help="(ingest only) cap on number of MM questions to render + upload.",
    )
    add(
        "--upload-batch-size",
        dest="upload_batch_size",
        type=int,
        default=8,
        help="(ingest only) PDFs per fileupload call.",
    )
    add(
        "--skip-upload",
        dest="skip_upload",
        action="store_true",
        help="(ingest only) render PDFs locally but don't push to SurfSense.",
    )
    add(
        "--include-dev",
        dest="include_dev",
        action="store_true",
        help="(ingest only) shorthand for --split all.",
    )

    # Per-upload knobs forwarded to /documents/fileupload at ingest. They
    # are ignored at run time: the runner reads the resolved settings
    # back from the doc-map manifest header.
    add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
|
||||
|
||||
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
    """Run MedXpertQA-MM ingestion with the parsed CLI options.

    Merges ``opts`` over ``_DEFAULT_INGEST_SETTINGS`` and forwards the
    ingest-related options to ``run_ingest``. Missing or falsy options
    fall back to: split ``"test"``, batch size 8, ``skip_upload`` and
    ``include_dev`` false.
    """
    from .ingest import run_ingest

    settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
    split = opts.get("split") or "test"
    if split == "all":
        # ``--split`` offers "all" for "every split", whereas
        # ``run_ingest`` documents that selection as "both".
        split = "both"
    await run_ingest(
        ctx,
        split=split,
        max_questions=opts.get("max_questions"),
        upload_batch_size=int(opts.get("upload_batch_size") or 8),
        skip_upload=bool(opts.get("skip_upload", False)),
        include_dev=bool(opts.get("include_dev", False)),
        settings=settings,
    )
|
||||
|
||||
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
    """Answer the selected questions with both arms and persist the results.

    Steps: resolve options, load the doc map and questions written at
    ingest, build the native and SurfSense arms, run both arms
    concurrently (each limited to ``concurrency`` in-flight questions),
    write ``raw.jsonl`` (native row then SurfSense row per question),
    compute metrics, and write ``run_artifact.json``.

    Raises
    ------
    RuntimeError
        If the ingest outputs are missing, no question matches the
        filters, or ``OPENROUTER_API_KEY`` is unset.
    """
    pick = opts.get

    # --- options, with the same fallbacks as the CLI defaults ---
    split_filter = pick("split") or "test"
    task_filter = pick("task") or "all"
    body_filter = pick("body_filter") or "all"
    require_images = bool(pick("require_images"))
    sample_n = pick("sample_n")
    concurrency = int(pick("concurrency") or 4)
    no_mentions = bool(pick("no_mentions"))
    pdf_engine_name = pick("pdf_engine") or "native"
    max_output_tokens = int(pick("max_output_tokens") or 512)

    # --- ingest outputs ---
    bench_dir = ctx.benchmark_data_dir()
    questions_jsonl = bench_dir / "questions.jsonl"
    map_path = ctx.maps_dir() / "medxpertqa_doc_map.jsonl"
    if not (questions_jsonl.exists() and map_path.exists()):
        raise RuntimeError(
            "MedXpertQA-MM not ingested for this suite. Run "
            "`python -m surfsense_evals ingest medical medxpertqa` first."
        )

    doc_map, ingest_settings = _load_doc_map(map_path)
    questions = _load_questions(
        questions_jsonl,
        doc_map,
        split_filter=split_filter,
        task_filter=None if task_filter == "all" else task_filter,
        body_filter=None if body_filter == "all" else body_filter,
        require_images=require_images,
        sample_n=sample_n,
    )
    if not questions:
        raise RuntimeError(
            "No MedXpertQA-MM questions matched the filters; broaden --split/--task/--body-system/--n."
        )
    logger.info("MedXpertQA-MM: scheduled %d questions", len(questions))

    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise RuntimeError("OPENROUTER_API_KEY env var is required for the native arm.")

    # Native arm slug differs from SurfSense slug only in cost-arbitrage
    # scenario; otherwise both arms answer with provider_model.
    native_arm_model = ctx.native_arm_model
    slug = native_arm_model.lower()
    looks_text_only = any(hint in slug for hint in _TEXT_ONLY_HINTS)
    if looks_text_only and ctx.scenario == "symmetric-cheap":
        logger.info(
            "symmetric-cheap: native arm pinned to text-only %r as "
            "intended; expect it to lose on image-bearing questions "
            "(SurfSense answers from vision-extracted chunks).",
            native_arm_model,
        )
    elif looks_text_only:
        logger.warning(
            "Native arm slug %r looks text-only; image content in "
            "MedXpertQA PDFs will be ignored. Re-pin via "
            "`setup --provider-model anthropic/claude-sonnet-4.5` "
            "(or pass --native-arm-model and --scenario cost-arbitrage "
            "to make this asymmetry explicit).",
            native_arm_model,
        )

    # --- arms ---
    provider = OpenRouterPdfProvider(
        api_key=api_key,
        base_url=ctx.config.openrouter_base_url,
        model=native_arm_model,
        engine=PdfEngine(pdf_engine_name),
    )
    native_arm = NativePdfArm(provider=provider, max_output_tokens=max_output_tokens)
    surf_arm = SurfSenseArm(
        client=ctx.new_chat_client(),
        search_space_id=ctx.search_space_id,
        ephemeral_threads=True,
    )

    run_timestamp = utc_iso_timestamp()
    run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
    raw_path = run_dir / "raw.jsonl"

    async def _native_one(q: MXQuestion) -> ArmResult:
        request = _make_native_request(q, max_output_tokens)
        return await native_arm.answer(request)

    async def _surf_one(q: MXQuestion) -> ArmResult:
        request = _make_surfsense_request(q, no_mentions=no_mentions)
        return await surf_arm.answer(request)

    # Both arms progress side by side, each under its own concurrency cap.
    native_results, surf_results = await asyncio.gather(
        _gather_with_limit((_native_one(q) for q in questions), concurrency=concurrency),
        _gather_with_limit((_surf_one(q) for q in questions), concurrency=concurrency),
    )

    # --- raw per-question rows: native first, then SurfSense ---
    with raw_path.open("w", encoding="utf-8") as sink:
        for q, *arm_results in zip(questions, native_results, surf_results, strict=False):
            shared = {
                "qid": q.qid,
                "split": q.split,
                "medical_task": q.medical_task,
                "body_system": q.body_system,
                "question_type": q.question_type,
                "n_images": q.n_images,
                "correct": q.label,
                "document_id": q.document_id,
            }
            for result in arm_results:
                sink.write(json.dumps({**shared, **result.to_jsonl()}) + "\n")

    metrics = _compute_metrics(questions, native_results, surf_results)
    artifact = RunArtifact(
        suite=self.suite,
        benchmark=self.name,
        run_timestamp=run_timestamp,
        raw_path=raw_path,
        metrics=metrics,
        extra={
            "n_questions": len(questions),
            "concurrency": concurrency,
            "split_filter": split_filter,
            "task_filter": task_filter,
            "body_filter": body_filter,
            "require_images": require_images,
            "no_mentions": no_mentions,
            "pdf_engine": pdf_engine_name,
            "scenario": ctx.scenario,
            "provider_model": ctx.provider_model,
            "native_arm_model": native_arm_model,
            "vision_provider_model": ctx.vision_provider_model,
            "agent_llm_id": ctx.agent_llm_id,
            "ingest_settings": ingest_settings,
        },
    )

    # --- manifest next to raw.jsonl (raw_path stored relative) ---
    manifest = {
        "suite": self.suite,
        "benchmark": self.name,
        "raw_path": "raw.jsonl",
        "metrics": metrics,
        "extra": artifact.extra,
    }
    manifest_path = run_dir / "run_artifact.json"
    manifest_path.write_text(
        json.dumps(manifest, indent=2, sort_keys=True) + "\n",
        encoding="utf-8",
    )
    return artifact
|
||||
|
||||
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
    """Build the report section for this benchmark from its run artifacts.

    Only the artifact with the greatest ``run_timestamp`` is rendered. When
    ``artifacts`` is empty a placeholder section with an empty JSON body is
    returned instead.
    """
    section_title = "MedXpertQA-MM — Native PDF (vision) vs SurfSense (vision RAG)"
    if not artifacts:
        return ReportSection(
            title=section_title,
            headline=False,
            body_md="(no run artifacts found)",
            body_json={},
        )

    newest = max(artifacts, key=lambda art: art.run_timestamp)
    metrics = newest.metrics
    run_extra = newest.extra
    native_arm = metrics.get("native", {})
    surf_arm = metrics.get("surfsense", {})
    paired = metrics.get("delta", {})
    by_task = metrics.get("per_task", {})
    by_body = metrics.get("per_body_system", {})

    # The native arm may run a different model than the SurfSense arm; fall
    # back to the shared provider model when no override was recorded.
    native_model = run_extra.get('native_arm_model') or run_extra.get('provider_model', '?')

    md_lines: list[str] = [
        (
            f"- Sample size: {run_extra.get('n_questions', '?')} questions "
            f"(split: `{run_extra.get('split_filter', 'test')}`, "
            f"task: `{run_extra.get('task_filter', 'all')}`, "
            f"body: `{run_extra.get('body_filter', 'all')}`, "
            f"engine: `{run_extra.get('pdf_engine', 'native')}`)."
        ),
        format_scenario_md(run_extra),
        format_ingest_settings_md(run_extra.get("ingest_settings")),
        (
            "- Native arm (OpenRouter `chat/completions` + file plugin, "
            f"`{native_model}`):"
        ),
        _arm_summary_lines(native_arm, indent=" "),
        (
            "- SurfSense arm (`POST /api/v1/new_chat`, vision RAG over chunks, "
            f"`{run_extra.get('provider_model', '?')}`):"
        ),
        _arm_summary_lines(surf_arm, indent=" "),
        "- Delta (paired):",
        (
            f" - Accuracy: SurfSense {_pp(paired.get('accuracy_pp'))} pp "
            f"(McNemar p={_fmt(paired.get('mcnemar_p_value'), 4)}, "
            f"method={paired.get('mcnemar_method')})"
        ),
        (
            " - Bootstrap 95% CI on delta: "
            f"[{_pp(paired.get('bootstrap_ci_low'))}pp, {_pp(paired.get('bootstrap_ci_high'))}pp]"
        ),
        (
            f" - Cost / question: native ${_dollars(native_arm.get('cost_micros_mean'))}, "
            f"surfsense ${_dollars(surf_arm.get('cost_micros_mean'))} "
            f"(SurfSense delta {_pct_change(paired.get('cost_micros_pct'))})"
        ),
        (
            f" - Latency p50: native {_ms_to_s(native_arm.get('latency_ms_median'))}, "
            f"surfsense {_ms_to_s(surf_arm.get('latency_ms_median'))} "
            f"(SurfSense delta {_pct_change(paired.get('latency_ms_pct'))})"
        ),
    ]

    if by_task:
        md_lines.append("- Per-medical_task split:")
        md_lines.extend(
            f" - {task_name}: SurfSense {_pp(stats.get('delta_accuracy_pp'))} pp "
            f"(n={stats.get('n')})"
            for task_name, stats in sorted(by_task.items())
        )

    if by_body:
        md_lines.append("- Per-body_system split (top 5 by sample size):")
        # Negated key (rather than reverse=True) keeps ties in their
        # original insertion order.
        largest = sorted(by_body.items(), key=lambda kv: -kv[1].get("n", 0))[:5]
        md_lines.extend(
            f" - {body_name}: SurfSense {_pp(stats.get('delta_accuracy_pp'))} pp "
            f"(n={stats.get('n')})"
            for body_name, stats in largest
        )

    return ReportSection(
        title=section_title,
        headline=False,
        body_md="\n".join(md_lines),
        body_json=metrics,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-question helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_native_request(q: MXQuestion, max_tokens: int) -> ArmRequest:
    """Build the native-arm request for one question.

    The request carries the MCQ prompt, the question's own PDF path, and a
    ``max_tokens`` option.
    """
    mcq_prompt = build_prompt(q.question, q.options)
    request_options = {"max_tokens": max_tokens}
    return ArmRequest(
        question_id=q.qid,
        prompt=mcq_prompt,
        pdf_paths=[q.pdf_path],
        options=request_options,
    )
|
||||
|
||||
|
||||
def _make_surfsense_request(q: MXQuestion, *, no_mentions: bool) -> ArmRequest:
    """Build the SurfSense-arm request for one question.

    Unless ``no_mentions`` is set, the question's ingested document (when it
    has a ``document_id``) is passed as the single mentioned document;
    otherwise ``mentioned_document_ids`` is ``None``.
    """
    mcq_prompt = build_prompt(q.question, q.options)
    pin_document = (not no_mentions) and q.document_id is not None
    mentioned: list[int] | None = [int(q.document_id)] if pin_document else None
    return ArmRequest(
        question_id=q.qid,
        prompt=mcq_prompt,
        mentioned_document_ids=mentioned,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_metrics(
    questions: list[MXQuestion],
    native_results: list[ArmResult],
    surf_results: list[ArmResult],
) -> dict[str, Any]:
    """Aggregate paired native-vs-SurfSense results into a metrics dict.

    The returned dict has the keys ``native``, ``surfsense``, ``delta``,
    ``per_task`` and ``per_body_system``. An answer counts as correct when
    its upper-cased letter equals the question's ``label`` and that label is
    one of ``ANSWER_LETTERS``.
    """

    def _hit(result: ArmResult, gold: str) -> bool:
        return (result.answer_letter or "").upper() == gold and gold in ANSWER_LETTERS

    def _column(results: list[ArmResult], attr: str) -> list[float]:
        return [float(getattr(res, attr)) for res in results]

    def _mean_or_zero(values: list[float]) -> float:
        return (sum(values) / len(values)) if values else 0.0

    def _arm_block(acc: Any, cost: Any, latency: Any) -> dict[str, Any]:
        # Keys shared by both arms, in the order they are emitted.
        return {
            **acc.to_dict(),
            "cost_micros_mean": cost.mean,
            "cost_micros_median": cost.median,
            "latency_ms_mean": latency.mean,
            "latency_ms_median": latency.median,
            "latency_ms_p95": latency.p95,
        }

    native_correct: list[bool] = []
    surf_correct: list[bool] = []
    for q, n_res, s_res in zip(questions, native_results, surf_results, strict=False):
        native_correct.append(_hit(n_res, q.label))
        surf_correct.append(_hit(s_res, q.label))

    native_in = _column(native_results, "input_tokens")
    native_out = _column(native_results, "output_tokens")

    native_acc = accuracy_with_wilson_ci(sum(native_correct), len(native_correct))
    surf_acc = accuracy_with_wilson_ci(sum(surf_correct), len(surf_correct))
    mc = mcnemar_test(native_correct, surf_correct)
    boot = bootstrap_delta_ci(native_correct, surf_correct, n_resamples=2000)

    native_cost_agg = paired_aggregate(_column(native_results, "cost_micros"))
    surf_cost_agg = paired_aggregate(_column(surf_results, "cost_micros"))
    native_lat_agg = paired_aggregate(_column(native_results, "latency_ms"))
    surf_lat_agg = paired_aggregate(_column(surf_results, "latency_ms"))

    # Cost change is computed on means, latency change on medians.
    cost_pct = _safe_pct(surf_cost_agg.mean, native_cost_agg.mean)
    lat_pct = _safe_pct(surf_lat_agg.median, native_lat_agg.median)

    per_task = _per_field(
        questions, native_correct, surf_correct, key=lambda q: q.medical_task or "unknown"
    )
    per_body = _per_field(
        questions, native_correct, surf_correct, key=lambda q: q.body_system or "unknown"
    )

    native_block = _arm_block(native_acc, native_cost_agg, native_lat_agg)
    # Token usage is only reported for the native arm.
    native_block["input_tokens_mean"] = _mean_or_zero(native_in)
    native_block["output_tokens_mean"] = _mean_or_zero(native_out)

    delta_block = {
        "accuracy_pp": 100.0 * (surf_acc.accuracy - native_acc.accuracy),
        "mcnemar_p_value": mc.p_value,
        "mcnemar_method": mc.method,
        "mcnemar_b_native_only": mc.b,
        "mcnemar_c_surfsense_only": mc.c,
        "bootstrap_ci_low": 100.0 * boot.ci_low,
        "bootstrap_ci_high": 100.0 * boot.ci_high,
        "cost_micros_pct": cost_pct,
        "latency_ms_pct": lat_pct,
    }

    return {
        "native": native_block,
        "surfsense": _arm_block(surf_acc, surf_cost_agg, surf_lat_agg),
        "delta": delta_block,
        "per_task": per_task,
        "per_body_system": per_body,
    }
|
||||
|
||||
|
||||
def _per_field(
|
||||
questions: list[MXQuestion],
|
||||
native_correct: list[bool],
|
||||
surf_correct: list[bool],
|
||||
*,
|
||||
key,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
bucket: dict[str, list[tuple[bool, bool]]] = {}
|
||||
for q, n_ok, s_ok in zip(questions, native_correct, surf_correct, strict=False):
|
||||
bucket.setdefault(key(q), []).append((n_ok, s_ok))
|
||||
out: dict[str, dict[str, Any]] = {}
|
||||
for k, pairs in bucket.items():
|
||||
n_correct = [a for a, _ in pairs]
|
||||
s_correct = [b for _, b in pairs]
|
||||
out[k] = {
|
||||
"n": len(pairs),
|
||||
"native_accuracy": (sum(n_correct) / len(pairs)) if pairs else 0.0,
|
||||
"surfsense_accuracy": (sum(s_correct) / len(pairs)) if pairs else 0.0,
|
||||
"delta_accuracy_pp": (
|
||||
100.0 * (sum(s_correct) - sum(n_correct)) / len(pairs)
|
||||
if pairs else 0.0
|
||||
),
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def _safe_pct(numerator: float, denominator: float) -> float | None:
|
||||
if denominator == 0:
|
||||
return None
|
||||
return 100.0 * (numerator - denominator) / denominator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Formatters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _arm_summary_lines(d: dict[str, Any], *, indent: str) -> str:
|
||||
if not d:
|
||||
return f"{indent}(no data)"
|
||||
acc = d.get("accuracy", 0.0)
|
||||
low = d.get("ci_low", 0.0)
|
||||
high = d.get("ci_high", 0.0)
|
||||
lines = [
|
||||
f"{indent}- Accuracy: {acc * 100:.1f}% (Wilson 95% CI: {low * 100:.1f}% – {high * 100:.1f}%)",
|
||||
f"{indent}- Cost / question: ${_dollars(d.get('cost_micros_mean'))} (mean), "
|
||||
f"${_dollars(d.get('cost_micros_median'))} (median)",
|
||||
f"{indent}- Latency: p50 {_ms_to_s(d.get('latency_ms_median'))}, "
|
||||
f"p95 {_ms_to_s(d.get('latency_ms_p95'))}",
|
||||
]
|
||||
if "input_tokens_mean" in d:
|
||||
lines.append(
|
||||
f"{indent}- Mean tokens / question: in {d.get('input_tokens_mean', 0):.0f}, "
|
||||
f"out {d.get('output_tokens_mean', 0):.0f}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _dollars(micros: Any) -> str:
|
||||
if micros is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{(float(micros) / 1_000_000):.4f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _ms_to_s(ms: Any) -> str:
|
||||
if ms is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(ms) / 1000:.1f}s"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _pp(value: Any) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):+.1f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _pct_change(value: Any) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):+.0f}%"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _fmt(value: Any, ndigits: int) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):.{ndigits}f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
__all__ = ["MedXpertQAMMBenchmark", "MXQuestion"]
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
"""MIRAGE — secondary single-arm SurfSense MCQ measurement.
|
||||
|
||||
Source: https://github.com/Teddy-XiongGZ/MIRAGE, paper
|
||||
https://aclanthology.org/2024.findings-acl.372/. 7,663 questions
|
||||
across MMLU-Med, MedQA-US, MedMCQA, PubMedQA*, BioASQ-Y/N.
|
||||
|
||||
This is a SurfSense-only measurement (not a head-to-head); native
|
||||
PDF-in-LLM doesn't apply because there is no per-question discrete
|
||||
document — the corpus is millions of biomedical snippets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .runner import MirageBenchmark
|
||||
from ....core import registry as _registry
|
||||
|
||||
_registry.register(MirageBenchmark())
|
||||
|
|
@ -0,0 +1,548 @@
|
|||
"""MIRAGE ingestion.
|
||||
|
||||
Downloads:
|
||||
|
||||
* ``benchmark.json`` (≈ 4 MB; questions for the 5 sub-tasks).
|
||||
* ``retrieved_snippets_10k.zip`` (the union of top-10k snippet ids
|
||||
retrieved by every retriever in the MedRAG paper, per task — a
|
||||
recall ceiling that avoids needing the full 23.9M-doc PubMed mirror).
|
||||
|
||||
Snippet *content* lives in the MedRAG HF mirrors
|
||||
(``MedRAG/textbooks``, ``MedRAG/pubmed``, ``MedRAG/statpearls``,
|
||||
``MedRAG/wikipedia``). We default to ``MedRAG/textbooks`` (212 MB,
|
||||
125k snippets) which is the smallest and covers the majority of
|
||||
``MedQA-US`` and the medical examination subsets. Operators can
|
||||
opt into larger corpora with ``--corpus``.
|
||||
|
||||
Each snippet is written as one markdown file then batched into
|
||||
``~5 MB`` markdown bundles for SurfSense's file upload (smaller
|
||||
than backend default ``MAX_FILE_SIZE_BYTES`` and avoids the per-call
|
||||
overhead of one HTTP request per snippet).
|
||||
|
||||
The ingestion produces two maps under ``data/medical/maps/``:
|
||||
|
||||
* ``mirage_snippet_map.jsonl`` — ``{snippet_id, document_id, batch_path}``
|
||||
* ``mirage_chunk_map.jsonl`` — ``{chunk_id, document_id, snippet_id?}``
|
||||
(best-effort; chunk text is heuristically attributed to the
|
||||
snippet it overlaps when the SurfSense chunker splits a batched
|
||||
markdown).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MIRAGE_BENCHMARK_URL = (
|
||||
"https://raw.githubusercontent.com/Teddy-XiongGZ/MIRAGE/main/benchmark.json"
|
||||
)
|
||||
# Upstream only ships ONE zip — top-10k retrievals across 5 retrievers,
|
||||
# ~16 GB. We default to skipping it (see `--skip-snippet-filter`) and
|
||||
# ingesting the chosen corpus in full; this URL is only fetched when
|
||||
# the operator explicitly opts in.
|
||||
MIRAGE_SNIPPETS_ZIP_URL = (
|
||||
"https://virginia.box.com/shared/static/cxq17th6eisl2pn04vp0x723zczlvlzc.zip"
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_CORPUS = "MedRAG/textbooks"
|
||||
_BATCH_SIZE_BYTES = 5 * 1024 * 1024
|
||||
# 2 GB safety cap. Anything larger requires --allow-large-download.
|
||||
# Set high enough that ``benchmark.json`` and small zips pass through
|
||||
# untouched but the 16 GB MIRAGE retrievals zip trips the guard.
|
||||
_LARGE_DOWNLOAD_BYTES = 2 * 1024 * 1024 * 1024
|
||||
_DOWNLOAD_RETRIES = 5
|
||||
_RETRYABLE_NET_EXC: tuple[type[BaseException], ...] = (
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.ReadError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class SnippetRow:
    """One corpus snippet: its id, title and text content."""

    snippet_id: str
    title: str
    content: str

    def to_markdown(self) -> str:
        """Render the snippet as a small markdown document.

        The title becomes an H1 (``Untitled`` when blank), followed by an
        italic ``_id: ..._`` line carrying the snippet id, then the
        whitespace-stripped content.
        """
        heading = (self.title or "").strip()
        if not heading:
            heading = "Untitled"
        text = (self.content or "").strip()
        return f"# {heading}\n\n_id: `{self.snippet_id}`_\n\n{text}\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Download helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _fetch_to_path(
    url: str,
    *,
    dest: Path,
    label: str,
    timeout_s: float = 600.0,
    allow_large_download: bool = False,
    expect_zip: bool = False,
) -> Path:
    """Download ``url`` to ``dest`` with retry, atomic-rename, and
    HTTP ``Range`` resume.

    Operational properties:

    * If ``dest`` already exists *and* (when ``expect_zip`` is True) the
      cached file is a valid ZIP, returns it immediately. A corrupt ZIP
      is removed and re-downloaded — this is the safety net for the
      `box.com truncated 16 GB zip` failure mode where the previous
      run wrote a half-completed file then exited with an exception.
    * Bytes are written to ``<dest>.partial`` and renamed only after the
      stream completes cleanly (and, for zips, only after a quick
      central-directory check). A failure mid-download leaves the
      ``.partial`` file in place so the next attempt can resume from
      where it stopped via an HTTP ``Range`` header.
    * When a resume request is answered with HTTP 416 the ``.partial``
      file is treated as already complete. The 416 response body is an
      error payload, not file content, so it is **not** appended; the
      existing bytes are validated (zips only) and promoted as they are.
    * Retries on transient network errors (``RemoteProtocolError``,
      ``ReadError``, ``ReadTimeout``, ``ConnectError``,
      ``ConnectTimeout``) with exponential backoff, up to
      ``_DOWNLOAD_RETRIES``.
    * Aborts before downloading if the ``Content-Length`` (or already-
      downloaded ``.partial`` size) is over ``_LARGE_DOWNLOAD_BYTES``
      and ``allow_large_download`` is False, to keep an operator from
      surprise-grabbing 16 GB on a slow link.

    Returns:
        ``dest``, once the file is in place.

    Raises:
        _LargeDownloadAbort: the planned size exceeds the safety cap.
        httpx.HTTPStatusError: ``raise_for_status`` rejected the response
            (not retried).
        RuntimeError: every attempt failed with a retryable error; the
            last such error is attached as ``__cause__``.
    """

    if dest.exists():
        if expect_zip and not _is_valid_zip(dest):
            logger.warning(
                "Cached %s at %s failed ZIP validation (size=%d B); deleting "
                "and re-downloading.",
                label,
                dest,
                dest.stat().st_size,
            )
            dest.unlink(missing_ok=True)
        else:
            logger.info("Using cached %s at %s", label, dest)
            return dest

    dest.parent.mkdir(parents=True, exist_ok=True)
    partial = dest.with_suffix(dest.suffix + ".partial")
    last_exc: BaseException | None = None

    for attempt in range(1, _DOWNLOAD_RETRIES + 1):
        existing_bytes = partial.stat().st_size if partial.exists() else 0
        headers: dict[str, str] = {}
        if existing_bytes:
            headers["Range"] = f"bytes={existing_bytes}-"
            logger.info(
                "Resuming %s from byte %d (attempt %d/%d)",
                label,
                existing_bytes,
                attempt,
                _DOWNLOAD_RETRIES,
            )
        else:
            logger.info(
                "Downloading %s from %s (attempt %d/%d)",
                label,
                url,
                attempt,
                _DOWNLOAD_RETRIES,
            )

        try:
            async with httpx.AsyncClient(
                timeout=httpx.Timeout(timeout_s, connect=20.0),
                follow_redirects=True,
            ) as client:
                async with client.stream("GET", url, headers=headers) as response:
                    already_complete = False
                    if existing_bytes and response.status_code == 200:
                        logger.warning(
                            "Server ignored Range header for %s; restarting from 0.",
                            label,
                        )
                        partial.unlink(missing_ok=True)
                        existing_bytes = 0
                    elif existing_bytes and response.status_code == 416:
                        # Range not satisfiable — the .partial is at or
                        # past the end. Treat as "already downloaded" and
                        # skip the body: it is the server's error payload,
                        # and appending it would corrupt the file.
                        logger.info(
                            "Server reports %s already complete (HTTP 416).",
                            label,
                        )
                        already_complete = True
                    elif response.status_code not in (200, 206):
                        # Includes a 416 on a request that sent no Range
                        # header, which is a plain error.
                        response.raise_for_status()

                    total_size = _planned_total_size(response, existing_bytes)
                    if (
                        total_size is not None
                        and total_size > _LARGE_DOWNLOAD_BYTES
                        and not allow_large_download
                    ):
                        raise _LargeDownloadAbort(label, total_size)

                    if not already_complete:
                        mode = "ab" if existing_bytes else "wb"
                        with partial.open(mode) as fh:
                            async for chunk in response.aiter_bytes(chunk_size=1 << 18):
                                fh.write(chunk)
            # Optional content sanity check before promoting to dest.
            if expect_zip and not _is_valid_zip(partial):
                raise zipfile.BadZipFile(
                    f"{label} downloaded to {partial} but failed central-"
                    "directory check; will retry."
                )
            partial.replace(dest)
            return dest
        except _LargeDownloadAbort:
            raise
        except _RETRYABLE_NET_EXC as exc:
            last_exc = exc
            wait = min(60.0, 2.0 ** attempt)
            logger.warning(
                "Network error fetching %s (%s: %s); retrying in %.0fs.",
                label,
                type(exc).__name__,
                exc,
                wait,
            )
            await asyncio.sleep(wait)
        except zipfile.BadZipFile as exc:
            last_exc = exc
            # Truncated body — drop the partial and retry from scratch.
            partial.unlink(missing_ok=True)
            wait = min(60.0, 2.0 ** attempt)
            logger.warning(
                "Truncated ZIP for %s; restarting from byte 0 in %.0fs.",
                label,
                wait,
            )
            await asyncio.sleep(wait)

    raise RuntimeError(
        f"Failed to download {label} after {_DOWNLOAD_RETRIES} attempts: {last_exc!s}"
    ) from last_exc
|
||||
|
||||
|
||||
def _planned_total_size(response: httpx.Response, existing_bytes: int) -> int | None:
|
||||
"""Best-effort total size including any already-buffered .partial bytes."""
|
||||
|
||||
cl = response.headers.get("Content-Length")
|
||||
if not cl:
|
||||
return None
|
||||
try:
|
||||
remaining = int(cl)
|
||||
except ValueError:
|
||||
return None
|
||||
return existing_bytes + remaining
|
||||
|
||||
|
||||
def _is_valid_zip(path: Path) -> bool:
|
||||
"""Cheap ZIP validity check via central-directory parse."""
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(path) as zf:
|
||||
# ``namelist`` forces the central directory to be parsed.
|
||||
zf.namelist()
|
||||
return True
|
||||
except (zipfile.BadZipFile, OSError):
|
||||
return False
|
||||
|
||||
|
||||
class _LargeDownloadAbort(RuntimeError):
    """Signals that a download was refused for exceeding ``_LARGE_DOWNLOAD_BYTES``.

    The message names the download, its size in GB, the cap, and the two
    CLI flags that let the operator proceed or bypass the download.
    """

    def __init__(self, label: str, size_bytes: int) -> None:
        bytes_per_gb = 1024 ** 3
        requested_gb = size_bytes / bytes_per_gb
        cap_gb = _LARGE_DOWNLOAD_BYTES / bytes_per_gb
        message = (
            f"{label} would download ~{requested_gb:.1f} GB, above the {cap_gb:.0f} GB safety cap. "
            "Re-run with `--allow-large-download` to acknowledge, or use "
            "`--skip-snippet-filter` to bypass this download entirely and "
            "ingest the full corpus instead."
        )
        super().__init__(message)
|
||||
|
||||
|
||||
def _read_snippet_ids(zip_path: Path, *, tasks: list[str]) -> dict[str, set[str]]:
|
||||
"""Walk the ZIP for files whose path contains any task name.
|
||||
|
||||
Each MedRAG retriever produces one JSON file per task in the zip;
|
||||
we union all retrievers' top-K ids. The exact directory layout has
|
||||
historically been ``<retriever>/<task>.json`` mapping
|
||||
``question_id -> [snippet_id, ...]``.
|
||||
"""
|
||||
|
||||
out: dict[str, set[str]] = {t: set() for t in tasks}
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
for member in zf.namelist():
|
||||
if not member.lower().endswith(".json"):
|
||||
continue
|
||||
stem = Path(member).stem.lower()
|
||||
for task in tasks:
|
||||
if task.lower() in stem:
|
||||
try:
|
||||
with zf.open(member) as fh:
|
||||
payload = json.loads(fh.read().decode("utf-8"))
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
for ids in payload.values():
|
||||
if isinstance(ids, list):
|
||||
for sid in ids:
|
||||
if isinstance(sid, str):
|
||||
out[task].add(sid)
|
||||
elif isinstance(sid, dict) and "id" in sid:
|
||||
out[task].add(str(sid["id"]))
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def _load_corpus(
|
||||
corpus_name: str, snippet_ids: set[str] | None
|
||||
) -> Iterable[SnippetRow]:
|
||||
"""Stream rows from a MedRAG HF corpus.
|
||||
|
||||
* ``snippet_ids=None`` → yield every row (full-corpus ingestion path).
|
||||
* ``snippet_ids={...}`` → filter to the requested ids.
|
||||
|
||||
Imported lazily — ``datasets`` is a heavyweight dep.
|
||||
"""
|
||||
|
||||
if snippet_ids is not None and not snippet_ids:
|
||||
return iter(())
|
||||
from datasets import load_dataset # noqa: PLC0415
|
||||
|
||||
logger.info("Loading corpus %s (this may take a while)", corpus_name)
|
||||
ds = load_dataset(corpus_name, split="train", streaming=True)
|
||||
for row in ds:
|
||||
sid = str(row.get("id") or "")
|
||||
if snippet_ids is not None and sid not in snippet_ids:
|
||||
continue
|
||||
yield SnippetRow(
|
||||
snippet_id=sid,
|
||||
title=str(row.get("title") or ""),
|
||||
content=str(row.get("content") or row.get("contents") or ""),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batching + upload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class SnippetBatch:
    """A markdown bundle on disk together with the snippet ids written into it."""

    # Location of the bundle file.
    path: Path
    # Ids of the snippets in the bundle, in the order they were written.
    snippet_ids: list[str]
|
||||
|
||||
|
||||
def _write_batches(
    snippets: Iterable[SnippetRow],
    *,
    out_dir: Path,
    batch_bytes: int = _BATCH_SIZE_BYTES,
    prefix: str = "mirage",
) -> list[SnippetBatch]:
    """Pack snippets into markdown bundle files under ``out_dir``.

    Snippets are appended in order, each followed by a ``---`` separator.
    A bundle is closed before a snippet that would push its UTF-8 size past
    ``batch_bytes``. A single oversized snippet still gets a bundle of its
    own. Files are named ``<prefix>_NNNN.md``, numbered from zero.

    Returns:
        One ``SnippetBatch`` per file written, in write order.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    written: list[SnippetBatch] = []
    pending_text: list[str] = []
    pending_ids: list[str] = []
    pending_size = 0

    def _emit() -> None:
        # The next file index always equals the number of bundles written.
        target = out_dir / f"{prefix}_{len(written):04d}.md"
        target.write_text("".join(pending_text), encoding="utf-8")
        written.append(SnippetBatch(path=target, snippet_ids=list(pending_ids)))
        pending_text.clear()
        pending_ids.clear()

    for snippet in snippets:
        block = snippet.to_markdown() + "\n---\n\n"
        block_size = len(block.encode("utf-8"))
        if pending_ids and pending_size + block_size > batch_bytes:
            _emit()
            pending_size = 0
        pending_text.append(block)
        pending_ids.append(snippet.snippet_id)
        pending_size += block_size

    # Flush whatever is left over after the last snippet.
    if pending_ids:
        _emit()
    return written
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_ingest(
    ctx: RunContext,
    *,
    tasks: list[str] | None = None,
    corpus: str = _DEFAULT_CORPUS,
    max_snippets_per_task: int | None = None,
    skip_snippet_filter: bool = True,
    allow_large_download: bool = False,
    settings: IngestSettings | None = None,
) -> None:
    """Ingest a MedRAG corpus into the suite SearchSpace.

    By default (``skip_snippet_filter=True``) we ingest the **entire**
    chosen corpus and let SurfSense's own retriever do the work. The
    upstream MIRAGE retrieval zip is ~16 GB and only useful when you
    want to pre-filter the corpus to the set of snippets some other
    retriever surfaced; for ``MedRAG/textbooks`` (212 MB / 125k snippets)
    that pre-filter is unnecessary overhead and routinely fails to
    download (box.com truncates the stream). Set
    ``skip_snippet_filter=False`` (CLI: ``--use-snippet-filter``) only
    if you specifically want the upstream filter — and budget the
    16 GB zip transfer.

    Steps, in order: fetch ``benchmark.json``; load the corpus (optionally
    filtered by the ids in the retrieval zip); write markdown bundles;
    upload them and wait until they are ready; write the snippet map and
    the chunk map; record the snippet map path in the suite state.

    Args:
        ctx: Run context providing data directories, clients and state.
        tasks: Task names used to select zip members. Defaults to the five
            MIRAGE tasks.
        corpus: HF dataset name to load snippets from.
        max_snippets_per_task: Optional cap on ids kept per task when the
            snippet filter is in use. The ids kept are the first ones in
            sorted order, so the selection is the same on every run.
        skip_snippet_filter: Ingest the whole corpus instead of filtering.
        allow_large_download: Permit downloads above the size safety cap.
        settings: Upload settings. Defaults to no vision LLM and ``basic``
            processing.

    Raises:
        RuntimeError: no snippet ids could be parsed from the zip, or the
            corpus produced no rows to ingest.
    """

    tasks = tasks or ["mmlu", "medqa", "medmcqa", "pubmedqa", "bioasq"]
    settings = settings or IngestSettings(use_vision_llm=False, processing_mode="basic")

    bench_path = ctx.benchmark_data_dir() / "benchmark.json"
    await _fetch_to_path(MIRAGE_BENCHMARK_URL, dest=bench_path, label="MIRAGE benchmark.json")

    if skip_snippet_filter:
        logger.info(
            "Skipping retrieved_snippets_10k.zip (skip_snippet_filter=True); "
            "ingesting entire corpus %s.",
            corpus,
        )
        snippets = list(_load_corpus(corpus, snippet_ids=None))
    else:
        zip_path = ctx.benchmark_data_dir() / "retrieved_snippets_10k.zip"
        await _fetch_to_path(
            MIRAGE_SNIPPETS_ZIP_URL,
            dest=zip_path,
            label="MIRAGE retrieved_snippets_10k.zip",
            allow_large_download=allow_large_download,
            expect_zip=True,
        )

        by_task = _read_snippet_ids(zip_path, tasks=tasks)
        if max_snippets_per_task is not None:
            # Sets iterate in an order that changes between processes, so
            # truncate a sorted view to keep the chosen subset reproducible.
            by_task = {
                k: set(sorted(v)[:max_snippets_per_task]) for k, v in by_task.items()
            }

        union_ids = set().union(*by_task.values())
        logger.info(
            "MIRAGE: tasks=%s, snippet ids per task: %s, union=%d",
            tasks,
            {k: len(v) for k, v in by_task.items()},
            len(union_ids),
        )
        if not union_ids:
            raise RuntimeError(
                f"No snippet ids parsed for tasks {tasks!r} from {zip_path}. "
                "Check the zip layout (the upstream archive may have changed)."
            )

        snippets = list(_load_corpus(corpus, snippet_ids=union_ids))
        logger.info(
            "Loaded %d / %d requested snippets from corpus %s",
            len(snippets),
            len(union_ids),
            corpus,
        )

    # An empty corpus leaves nothing to upload on either path.
    if not snippets:
        raise RuntimeError(
            f"Corpus {corpus} returned 0 matching rows. Either the snippet "
            "ids reference a different corpus (e.g. PubMed) or the HF mirror "
            "is unavailable. Pass --corpus to override."
        )

    batches_dir = ctx.benchmark_data_dir() / "batches"
    batches = _write_batches(snippets, out_dir=batches_dir)
    logger.info("Wrote %d snippet batches to %s", len(batches), batches_dir)

    docs_client = ctx.documents_client()
    upload_result = await docs_client.upload(
        files=[b.path for b in batches],
        search_space_id=ctx.search_space_id,
        should_summarize=settings.should_summarize,
        use_vision_llm=settings.use_vision_llm,
        processing_mode=settings.processing_mode,
    )
    logger.info("MIRAGE upload settings: %s", settings.render_label())
    new_doc_ids = list(upload_result.document_ids)
    if new_doc_ids:
        await docs_client.wait_until_ready(
            search_space_id=ctx.search_space_id,
            document_ids=new_doc_ids,
            timeout_s=3600.0,
            max_poll_s=15.0,
        )

    statuses = await docs_client.get_status(
        search_space_id=ctx.search_space_id,
        # Unpacking avoids assuming the duplicate ids arrive as a list.
        document_ids=[*new_doc_ids, *upload_result.duplicate_document_ids],
    )
    # Bundles are matched back to documents by title == bundle file name.
    title_to_doc = {s.title: s.document_id for s in statuses}

    snippet_map_path = ctx.maps_dir() / "mirage_snippet_map.jsonl"
    chunk_map_path = ctx.maps_dir() / "mirage_chunk_map.jsonl"
    with snippet_map_path.open("w", encoding="utf-8") as fh:
        # Header line records the ingest-time settings (see
        # core/ingest_settings.py for the protocol).
        fh.write(settings_header_line(settings) + "\n")
        for batch in batches:
            doc_id = title_to_doc.get(batch.path.name)
            if doc_id is None:
                logger.warning("No document_id for batch %s", batch.path.name)
                continue
            for sid in batch.snippet_ids:
                fh.write(
                    json.dumps(
                        {
                            "snippet_id": sid,
                            "document_id": doc_id,
                            "batch_path": str(batch.path),
                        }
                    )
                    + "\n"
                )

    # Best-effort chunk map. SurfSense doesn't expose snippet attribution
    # per chunk, so we just record (chunk_id -> document_id) here; the
    # MIRAGE runner only needs document_id for accuracy scoring.
    # ``dict.fromkeys`` de-duplicates while keeping bundle order, so the
    # file is written in a stable order.
    mapped_doc_ids = list(
        dict.fromkeys(
            doc_id
            for doc_id in (title_to_doc.get(b.path.name) for b in batches)
            if doc_id is not None
        )
    )
    with chunk_map_path.open("w", encoding="utf-8") as fh:
        for doc_id in mapped_doc_ids:
            try:
                chunks = await docs_client.list_chunks(int(doc_id))
            except Exception as exc:  # noqa: BLE001
                logger.warning("Failed to list chunks for doc_id=%s: %s", doc_id, exc)
                continue
            for chunk in chunks:
                fh.write(
                    json.dumps({"chunk_id": chunk.id, "document_id": doc_id})
                    + "\n"
                )

    new_state = ctx.suite_state
    new_state.ingestion_maps["mirage"] = str(snippet_map_path)
    set_suite_state(ctx.config, ctx.suite, new_state)
    logger.info("Wrote MIRAGE maps to %s and %s", snippet_map_path, chunk_map_path)
|
||||
|
||||
|
||||
__all__ = ["run_ingest", "SnippetRow", "SnippetBatch"]
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
"""MedRAG ``{step_by_step_thinking, answer_choice}`` MCQ prompt.
|
||||
|
||||
Mirrors the MedRAG paper's prompt format so accuracy numbers are
|
||||
comparable to the published MIRAGE leaderboard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
_PROMPT_TEMPLATE = """\
|
||||
You are a helpful medical expert. Answer the following multiple-choice
|
||||
question using the relevant medical knowledge available to you (and any
|
||||
retrieved context, if provided).
|
||||
|
||||
Respond with a JSON object on a single line:
|
||||
{{"step_by_step_thinking": "<your reasoning>", "answer_choice": "<letter>"}}
|
||||
|
||||
Question: {question}
|
||||
|
||||
Options:
|
||||
{options_block}
|
||||
"""
|
||||
|
||||
|
||||
def _options_block(options: Mapping[str, str]) -> str:
    """Render the answer options as ``"<letter>) <text>"`` lines.

    Letters are emitted in sorted order. Options whose text is ``None`` or
    the empty string are left out. The lines are joined with newlines.
    """
    rendered = (
        f"{key}) {options[key]}"
        for key in sorted(options)
        if options[key] is not None and options[key] != ""
    )
    return "\n".join(rendered)
|
||||
|
||||
|
||||
def build_prompt(question: str, options: Mapping[str, str]) -> str:
    """Fill the MedRAG MCQ prompt template.

    The question is stripped of surrounding whitespace and the options are
    rendered by ``_options_block`` before both are substituted into
    ``_PROMPT_TEMPLATE``.
    """
    stem = question.strip()
    choices = _options_block(options)
    return _PROMPT_TEMPLATE.format(question=stem, options_block=choices)
|
||||
|
||||
|
||||
__all__ = ["build_prompt"]
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
"""MIRAGE runner: SurfSense-only per-task accuracy.
|
||||
|
||||
The benchmark file format is one top-level dict per task (``mmlu``,
|
||||
``medqa``, ``medmcqa``, ``pubmedqa``, ``bioasq``); each task value is
|
||||
``{question_id: {question, options, answer}}``.
|
||||
|
||||
We restrict retrieval to the suite SearchSpace's full corpus (no
|
||||
``mentioned_document_ids`` — MIRAGE has no per-question ground-truth
|
||||
document; retrieval *is* the test). Accuracy is paired against the
|
||||
``answer`` letter from the dataset.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from ....core.arms import ArmRequest, ArmResult, SurfSenseArm
|
||||
from ....core.config import utc_iso_timestamp
|
||||
from ....core.ingest_settings import (
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
read_settings_header,
|
||||
)
|
||||
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci, macro_accuracy
|
||||
from ....core.registry import (
|
||||
Benchmark,
|
||||
ReportSection,
|
||||
RunArtifact,
|
||||
RunContext,
|
||||
)
|
||||
from .prompt import build_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_TASKS = ("mmlu", "medqa", "medmcqa", "pubmedqa", "bioasq")
|
||||
_DESCRIPTION = "MIRAGE (7,663 medical MCQs) — single-arm SurfSense per-task accuracy."
|
||||
|
||||
# MIRAGE corpus is text-only (textbook + abstract markdown). Vision
|
||||
# LLM at ingest is wasted compute by default; flip ``--use-vision-llm``
|
||||
# to measure cost.
|
||||
_DEFAULT_INGEST_SETTINGS = IngestSettings(
|
||||
use_vision_llm=False,
|
||||
processing_mode="basic",
|
||||
should_summarize=False,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class MirageQuestion:
    """One MIRAGE multiple-choice question together with its task name."""

    task: str  # benchmark task key the question was loaded under
    qid: str  # question key within that task
    question: str
    options: dict[str, str]  # option letter -> option text
    correct: str  # gold answer letter

    @property
    def question_id(self) -> str:
        """Return the combined identifier ``"<task>::<qid>"``."""
        return "{}::{}".format(self.task, self.qid)
|
||||
|
||||
|
||||
def _load_questions(
    benchmark: dict[str, Any],
    *,
    tasks: list[str],
    sample_n: int | None,
) -> list[MirageQuestion]:
    """Flatten the benchmark dict into a sorted list of questions.

    For each requested task, rows that are not dicts, whose ``options`` are
    not a dict, or whose ``answer`` is blank are skipped. Falsy option
    values are dropped and the gold letter is the upper-cased first
    character of the stripped answer. The result is sorted by
    ``(task, qid)``.

    When ``sample_n`` is a positive integer, the sorted list is reduced to
    at most ``max(1, sample_n // len(tasks))`` questions per task and at
    most ``sample_n`` questions overall.
    """
    loaded: list[MirageQuestion] = []
    for task in tasks:
        rows = benchmark.get(task) or {}
        if not isinstance(rows, dict):
            continue
        for qid, raw in rows.items():
            if not isinstance(raw, dict):
                continue
            options = raw.get("options") or {}
            if not isinstance(options, dict):
                continue
            answer = str(raw.get("answer") or "").strip()
            if not answer:
                continue
            kept_options = {
                str(letter): str(text) for letter, text in options.items() if text
            }
            loaded.append(
                MirageQuestion(
                    task=task,
                    qid=str(qid),
                    question=str(raw.get("question", "")),
                    options=kept_options,
                    correct=answer[:1].upper(),
                )
            )
    loaded.sort(key=lambda item: (item.task, item.qid))

    if sample_n is None or sample_n <= 0:
        return loaded

    # Stratified-by-task slice so smoke runs cover every task.
    quota = max(1, sample_n // max(1, len(tasks)))
    taken: dict[str, int] = {}
    sample: list[MirageQuestion] = []
    for item in loaded:
        already = taken.get(item.task, 0)
        if already < quota:
            sample.append(item)
            taken[item.task] = already + 1
            if len(sample) >= sample_n:
                break
    return sample
|
||||
|
||||
|
||||
async def _gather_with_limit(coros, *, concurrency: int) -> list[Any]:
    """Await every awaitable in *coros* with a cap on how many run at once.

    A semaphore sized ``max(1, concurrency)`` guards each awaitable, so a
    ``concurrency`` below 1 behaves like 1. Results are returned by
    ``asyncio.gather``, i.e. in the order the awaitables were supplied.
    """
    gate = asyncio.Semaphore(max(1, concurrency))

    async def _throttled(awaitable):
        await gate.acquire()
        try:
            return await awaitable
        finally:
            gate.release()

    throttled = [_throttled(item) for item in coros]
    return await asyncio.gather(*throttled)
|
||||
|
||||
|
||||
class MirageBenchmark:
    """MIRAGE benchmark: SurfSense-only accuracy, reported per task.

    Bundles the CLI options, the ingest step, the evaluation run and the
    report section for the ``medical`` / ``mirage`` benchmark.
    """

    suite: str = "medical"
    name: str = "mirage"
    headline: bool = False
    description: str = _DESCRIPTION

    def add_run_args(self, parser: argparse.ArgumentParser) -> None:
        """Register the MIRAGE command-line options on *parser*."""
        parser.add_argument(
            "--task",
            default="all",
            choices=("all", *_TASKS),
            help="Run a single task or all (default: all).",
        )
        parser.add_argument("--n", dest="sample_n", type=int, default=None,
                            help="Stratified sample size across tasks.")
        parser.add_argument("--concurrency", type=int, default=4)
        parser.add_argument(
            "--corpus", default="MedRAG/textbooks",
            help="HF MedRAG corpus to ingest from (default: MedRAG/textbooks).",
        )
        parser.add_argument(
            "--max-snippets-per-task", type=int, default=None,
            help="Cap the per-task ingestion to N snippets (smoke).",
        )
        # Mutually exclusive: by default we skip the upstream 16 GB
        # retrievals zip and ingest the entire corpus. Operators who
        # want the upstream pre-filter pass --use-snippet-filter (and,
        # if their corpus mismatch warrants the 16 GB transfer,
        # --allow-large-download).
        snippet_group = parser.add_mutually_exclusive_group()
        snippet_group.add_argument(
            "--use-snippet-filter", dest="use_snippet_filter", action="store_true",
            default=False,
            help="Download retrieved_snippets_10k.zip (~16 GB) and "
                 "filter the corpus to those ids before ingest. "
                 "Default: skip and ingest entire corpus.",
        )
        snippet_group.add_argument(
            "--skip-snippet-filter", dest="use_snippet_filter", action="store_false",
            help="(Default) Skip the 16 GB upstream zip; ingest entire corpus.",
        )
        parser.add_argument(
            "--allow-large-download", action="store_true", default=False,
            help="Permit downloads larger than 2 GB (e.g. retrieved_snippets_10k.zip).",
        )
        # Per-upload knobs; ignored at run-time (runner reads the
        # resolved settings out of the snippet-map manifest header).
        add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)

    async def ingest(self, ctx: RunContext, **opts: Any) -> None:
        """Ingest the MIRAGE corpus using the parsed CLI options in *opts*.

        Ingest settings are the benchmark defaults merged with *opts*; the
        remaining options are forwarded to ``run_ingest``.
        """
        # Imported lazily so the runner module can be imported without
        # pulling in the ingest module.
        from .ingest import run_ingest

        settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
        await run_ingest(
            ctx,
            corpus=str(opts.get("corpus") or "MedRAG/textbooks"),
            max_snippets_per_task=opts.get("max_snippets_per_task"),
            skip_snippet_filter=not bool(opts.get("use_snippet_filter")),
            allow_large_download=bool(opts.get("allow_large_download")),
            settings=settings,
        )

    async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
        """Ask every selected question and write the run outputs.

        Writes ``raw.jsonl`` (one row per question) and
        ``run_artifact.json`` into the run directory and returns the
        corresponding ``RunArtifact`` with per-task and macro accuracy.

        Raises:
            RuntimeError: if ``benchmark.json`` is missing or no question
                matches the requested task / sample size.
        """
        task_filter = opts.get("task") or "all"
        tasks = list(_TASKS) if task_filter == "all" else [task_filter]
        sample_n = opts.get("sample_n")
        concurrency = int(opts.get("concurrency") or 4)

        bench_path = ctx.benchmark_data_dir() / "benchmark.json"
        if not bench_path.exists():
            raise RuntimeError(
                "MIRAGE benchmark.json missing. Run "
                "`python -m surfsense_evals ingest medical mirage` first."
            )
        benchmark = json.loads(bench_path.read_text(encoding="utf-8"))
        ingest_settings = read_settings_header(
            ctx.maps_dir() / "mirage_snippet_map.jsonl"
        )
        questions = _load_questions(benchmark, tasks=tasks, sample_n=sample_n)
        if not questions:
            raise RuntimeError(
                f"No MIRAGE questions matched task={task_filter!r} sample_n={sample_n!r}."
            )
        logger.info("MIRAGE: scheduled %d questions across tasks %s",
                    len(questions), tasks)

        arm = SurfSenseArm(
            client=ctx.new_chat_client(),
            search_space_id=ctx.search_space_id,
            ephemeral_threads=True,
        )

        async def _ask(q: MirageQuestion) -> ArmResult:
            request = ArmRequest(
                question_id=q.question_id,
                prompt=build_prompt(q.question, q.options),
            )
            return await arm.answer(request)

        # Results come back in the same order as ``questions``, which the
        # zips below rely on.
        results: list[ArmResult] = await _gather_with_limit(
            (_ask(q) for q in questions), concurrency=concurrency
        )

        run_timestamp = utc_iso_timestamp()
        run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
        raw_path = run_dir / "raw.jsonl"
        with raw_path.open("w", encoding="utf-8") as fh:
            for q, res in zip(questions, results):
                fh.write(
                    json.dumps(
                        {
                            "task": q.task,
                            "qid": q.qid,
                            "correct": q.correct,
                            **res.to_jsonl(),
                        }
                    )
                    + "\n"
                )

        per_task_acc: dict[str, dict[str, Any]] = {}
        for task in tasks:
            n_correct = 0
            n_total = 0
            for q, res in zip(questions, results):
                if q.task != task:
                    continue
                n_total += 1
                if (res.answer_letter or "").upper() == q.correct:
                    n_correct += 1
            acc = accuracy_with_wilson_ci(n_correct, n_total)
            per_task_acc[task] = acc.to_dict()

        macro = macro_accuracy(
            {t: accuracy_with_wilson_ci(d["n_correct"], d["n_total"]) for t, d in per_task_acc.items()}
        )
        metrics = {"per_task": per_task_acc, "macro_accuracy": macro}

        artifact = RunArtifact(
            suite=self.suite,
            benchmark=self.name,
            run_timestamp=run_timestamp,
            raw_path=raw_path,
            metrics=metrics,
            extra={
                "n_questions": len(questions),
                "task_filter": task_filter,
                "concurrency": concurrency,
                "provider_model": ctx.provider_model,
                "ingest_settings": ingest_settings,
            },
        )
        manifest_path = run_dir / "run_artifact.json"
        manifest_path.write_text(
            json.dumps(
                {
                    "suite": self.suite,
                    "benchmark": self.name,
                    "raw_path": "raw.jsonl",
                    "metrics": metrics,
                    "extra": artifact.extra,
                },
                indent=2,
                sort_keys=True,
            )
            + "\n",
            encoding="utf-8",
        )
        return artifact

    def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
        """Build the report section from the most recent run artifact.

        The body lists the ingest settings, one line per task that has
        results (or a placeholder line when none has), and the macro
        accuracy.
        """
        if not artifacts:
            return ReportSection(
                title="MIRAGE — single-arm SurfSense per-task accuracy",
                headline=False,
                body_md="(no run artifacts found)",
                body_json={},
            )
        latest = max(artifacts, key=lambda a: a.run_timestamp)
        per_task = latest.metrics.get("per_task", {})
        macro = latest.metrics.get("macro_accuracy", 0.0)
        lines: list[str] = []
        lines.append(format_ingest_settings_md(latest.extra.get("ingest_settings")))
        # Task rows are collected separately: ``lines`` already holds the
        # ingest-settings entry, so testing ``lines`` for emptiness could
        # never trigger the placeholder.
        task_lines: list[str] = []
        for task in _TASKS:
            row = per_task.get(task)
            if not row:
                continue
            acc = row.get("accuracy", 0.0)
            low = row.get("ci_low", 0.0)
            high = row.get("ci_high", 0.0)
            task_lines.append(
                f"- {task}: {acc * 100:.1f}% "
                f"(Wilson 95% CI: {low * 100:.1f}% – {high * 100:.1f}%, "
                f"n={row.get('n_total', '?')})"
            )
        if not task_lines:
            task_lines.append("- (no per-task results)")
        lines.extend(task_lines)
        lines.append(f"- Macro accuracy: {macro * 100:.2f}%")
        return ReportSection(
            title="MIRAGE — single-arm SurfSense per-task accuracy",
            headline=False,
            body_md="\n".join(lines),
            body_json=latest.metrics,
        )
|
||||
|
||||
|
||||
__all__ = ["MirageBenchmark", "MirageQuestion"]
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
"""Multimodal long-document benchmarks (PDFs with embedded images/charts/tables).
|
||||
|
||||
Distinct from the medical suite because these documents are domain-mixed
|
||||
(research reports, financials, manuals, government, brochures, papers).
|
||||
The hypothesis being tested here is *general*: does SurfSense's
|
||||
chunking-based vision RAG preserve information that lives in pixels —
|
||||
across long PDFs, across pages — versus feeding the same PDF directly
|
||||
to a vision-capable model?
|
||||
|
||||
Subpackages register themselves with ``core.registry`` on import. The
|
||||
``suites/__init__.py`` discovery walker imports them automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""MMLongBench-Doc — head-to-head Native PDF (vision) vs SurfSense (vision RAG).
|
||||
|
||||
Source: https://huggingface.co/datasets/yubo2333/MMLongBench-Doc
|
||||
Paper: https://arxiv.org/abs/2407.01523 (NeurIPS 2024 D&B Track)
|
||||
|
||||
* 135 long PDFs (avg 47 pages, multi-modal: text, images, charts, tables)
|
||||
* 1,091 expert-annotated questions
|
||||
* 33% require evidence from multiple pages
|
||||
* ~22% intentionally unanswerable (tests hallucination resistance)
|
||||
* 7 document types: research report, tutorial/workshop, academic paper,
|
||||
financial report, brochure, government, manuals
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import registry as _registry
|
||||
from .runner import MMLongBenchDocBenchmark
|
||||
|
||||
_registry.register(MMLongBenchDocBenchmark())
|
||||
|
|
@ -0,0 +1,236 @@
|
|||
"""Format-aware grader for MMLongBench-Doc answers.
|
||||
|
||||
The dataset ships with five ``answer_format`` values per question:
|
||||
|
||||
* ``Str`` — short factoid string
|
||||
* ``Int`` — integer count / year
|
||||
* ``Float`` — decimal number (often with units stripped)
|
||||
* ``List`` — comma- or semicolon-separated bag of items
|
||||
* ``None`` — gold answer is literally "Not answerable" (hallucination probe)
|
||||
|
||||
The official MMLongBench-Doc paper grades with GPT-4 as judge. We
|
||||
implement a *deterministic* rule-based grader as the default (so two
|
||||
researchers running the same harness get the same number); an
|
||||
LLM-judge mode is exposed via ``--judge gpt5`` and routed through the
|
||||
same OpenRouter key the arms use, but is opt-in to keep cost down.
|
||||
|
||||
Returned by every grading call:
|
||||
|
||||
* ``correct: bool`` — final pass/fail used for accuracy + McNemar
|
||||
* ``f1: float`` — token-level F1 (continuous credit, useful when
|
||||
comparing arms that get *most* of a list right)
|
||||
* ``method: str`` — which path graded the row (one of
|
||||
``str_norm`` / ``int_eq`` / ``float_tol`` / ``list_set`` /
|
||||
``none_match`` / ``llm_judge``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import string
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class GradeResult:
    """Outcome of grading one (prediction, gold) pair."""

    correct: bool  # final pass/fail verdict
    f1: float  # continuous credit in [0, 1]
    method: str  # name of the grading path that produced the verdict
    normalised_pred: str = ""  # prediction as it was compared
    normalised_gold: str = ""  # gold answer as it was compared
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Normalisation helpers (shared)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PUNCT_TABLE = str.maketrans({c: " " for c in string.punctuation})
|
||||
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.IGNORECASE)
|
||||
_WS = re.compile(r"\s+")
|
||||
_NOT_ANSWERABLE_TOKENS = {
|
||||
"not answerable",
|
||||
"cannot be answered",
|
||||
"cannot answer",
|
||||
"no answer",
|
||||
"unknown",
|
||||
"none",
|
||||
"not specified",
|
||||
"not mentioned",
|
||||
"not provided",
|
||||
"the answer is not in the document",
|
||||
}
|
||||
|
||||
# Abbreviations that should be matched literally on the lowercased
|
||||
# prediction (because normalisation strips their punctuation and
|
||||
# leaves them too short to be safe as substring tokens).
|
||||
_NOT_ANSWERABLE_LITERAL = {"n/a", "na/", "n.a.", "n a"}
|
||||
|
||||
|
||||
def _normalise_text(s: str) -> str:
    """Return *s* in SQuAD-style normal form.

    The text is lower-cased, punctuation is replaced by spaces, the
    articles ``a`` / ``an`` / ``the`` are removed, and runs of whitespace
    are collapsed to a single space with the ends trimmed.
    """
    without_punct = s.lower().translate(_PUNCT_TABLE)
    without_articles = _ARTICLES.sub(" ", without_punct)
    return _WS.sub(" ", without_articles).strip()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-format graders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _grade_str(pred: str, gold: str) -> GradeResult:
    """Grade a short string answer after text normalisation.

    An empty normalised prediction is wrong with zero credit. An exact
    match is correct with full credit. Otherwise the answer is correct
    when the (non-empty) gold is contained in the prediction or the
    prediction is contained in the gold, which covers a model emitting a
    fuller sentence around the gold. In that last case ``f1`` is the
    token-level F1.
    """
    p = _normalise_text(pred)
    g = _normalise_text(gold)
    if not p:
        return GradeResult(False, 0.0, "str_norm", p, g)
    if p == g:
        return GradeResult(True, 1.0, "str_norm", p, g)
    contained = bool(g) and (g in p or p in g)
    return GradeResult(contained, _f1_tokens(p, g), "str_norm", p, g)
|
||||
|
||||
|
||||
_INT_RE = re.compile(r"-?\d[\d,]*")
|
||||
|
||||
|
||||
def _grade_int(pred: str, gold: str) -> GradeResult:
    """Grade an integer answer by comparing the first integer in each text.

    Thousands separators (commas) inside the matched number are ignored.
    When the gold text contains no integer, grading falls back to
    ``_grade_str``. When the prediction contains no integer, the answer is
    wrong and ``normalised_pred`` is the empty string.
    """
    g_match = _INT_RE.search(gold)
    if g_match is None:
        return _grade_str(pred, gold)
    g_val = int(g_match.group(0).replace(",", ""))
    p_match = _INT_RE.search(pred)
    if p_match is None:
        # No number in the prediction: report an empty normalised value
        # rather than the string "None".
        return GradeResult(False, 0.0, "int_eq", "", str(g_val))
    p_val = int(p_match.group(0).replace(",", ""))
    same = p_val == g_val
    return GradeResult(same, 1.0 if same else 0.0, "int_eq", str(p_val), str(g_val))
|
||||
|
||||
|
||||
_FLOAT_RE = re.compile(r"-?\d+(?:[.,]\d+)?")
|
||||
|
||||
|
||||
def _grade_float(pred: str, gold: str, *, rel_tol: float = 1e-2) -> GradeResult:
    """Grade a decimal answer by comparing the first number in each text.

    A comma inside the matched number is read as a decimal point. When the
    gold text contains no number, grading falls back to ``_grade_str``.
    When the prediction contains no number, the answer is wrong. Otherwise
    the answer is correct when the values differ by at most
    ``max(abs(gold) * rel_tol, 0.01)``.
    """
    gold_hit = _FLOAT_RE.search(gold)
    if gold_hit is None:
        return _grade_str(pred, gold)
    gold_val = float(gold_hit.group(0).replace(",", "."))

    pred_hit = _FLOAT_RE.search(pred)
    if pred_hit is None:
        return GradeResult(False, 0.0, "float_tol", "", str(gold_val))
    pred_val = float(pred_hit.group(0).replace(",", "."))

    # Tolerance: relative (rel_tol) or 0.01 absolute, whichever is looser.
    allowed = max(abs(gold_val) * rel_tol, 0.01)
    within = abs(pred_val - gold_val) <= allowed
    score = 1.0 if within else 0.0
    return GradeResult(within, score, "float_tol", str(pred_val), str(gold_val))
|
||||
|
||||
|
||||
_LIST_SPLIT = re.compile(r"[;,\n]")
|
||||
|
||||
|
||||
def _grade_list(pred: str, gold: str) -> GradeResult:
    """Grade a list answer as an unordered set of normalised items.

    Both texts are split on ``;``, ``,`` and newlines; blank pieces are
    dropped and the rest normalised. When the gold yields no items,
    grading falls back to ``_grade_str``. ``f1`` is the set-level F1 and
    the answer counts as correct only when that F1 is (numerically) 1.
    """
    gold_set = {_normalise_text(part) for part in _LIST_SPLIT.split(gold) if part.strip()}
    pred_set = {_normalise_text(part) for part in _LIST_SPLIT.split(pred) if part.strip()}
    if not gold_set:
        return _grade_str(pred, gold)

    pred_repr = ", ".join(sorted(pred_set))
    gold_repr = ", ".join(sorted(gold_set))

    shared = gold_set & pred_set
    if not shared:
        return GradeResult(False, 0.0, "list_set", pred_repr, gold_repr)

    precision = len(shared) / len(pred_set) if pred_set else 0.0
    recall = len(shared) / len(gold_set)
    denominator = precision + recall
    f1 = (2 * precision * recall / denominator) if denominator else 0.0
    return GradeResult(f1 >= 0.999, f1, "list_set", pred_repr, gold_repr)
|
||||
|
||||
|
||||
def _grade_none(pred: str, gold: str) -> GradeResult:
    """Grade a question whose gold answer is "Not answerable".

    The prediction earns credit when it expresses inability to answer.
    Two passes are made:

    1. The abbreviations in ``_NOT_ANSWERABLE_LITERAL`` (``n/a`` etc.) are
       looked up in the lower-cased, stripped prediction. A hit only
       counts when it is not glued to a word character on either side, so
       ordinary text such as "seven apples" or "in a table" does not match
       the literal ``n a``.
    2. The phrases in ``_NOT_ANSWERABLE_TOKENS`` are looked up as whole
       space-delimited phrases in the normalised prediction.
    """
    raw_lower = (pred or "").strip().lower()
    p = _normalise_text(pred or "")

    # Pass 1: abbreviation hits on the raw lowercased text, bounded by
    # non-word characters (or the string edges).
    expressed_unknown = any(
        re.search(rf"(?<!\w){re.escape(lit)}(?!\w)", raw_lower) is not None
        for lit in _NOT_ANSWERABLE_LITERAL
    )

    # Pass 2: whole-phrase check on the normalised text.
    if not expressed_unknown:
        p_padded = f" {p} "
        for tok_raw in _NOT_ANSWERABLE_TOKENS:
            tok = _normalise_text(tok_raw)
            # Very short tokens are skipped; they are too ambiguous.
            if not tok or len(tok) < 3:
                continue
            if f" {tok} " in p_padded:
                expressed_unknown = True
                break

    return GradeResult(
        expressed_unknown, 1.0 if expressed_unknown else 0.0,
        "none_match", p, _normalise_text(gold),
    )
|
||||
|
||||
|
||||
def _f1_tokens(pred: str, gold: str) -> float:
    """Return the token-level F1 between two whitespace-tokenised strings.

    Token overlap is counted with multiplicity. The result is ``0.0`` when
    either side has no tokens or when no token is shared.
    """
    pred_tokens = pred.split()
    gold_tokens = gold.split()
    if not (pred_tokens and gold_tokens):
        return 0.0
    shared = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if shared == 0:
        return 0.0
    precision = shared / len(pred_tokens)
    recall = shared / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_FORMAT_DISPATCH = {
|
||||
"str": _grade_str,
|
||||
"int": _grade_int,
|
||||
"float": _grade_float,
|
||||
"list": _grade_list,
|
||||
"none": _grade_none,
|
||||
}
|
||||
|
||||
|
||||
def grade(*, pred: str, gold: str, answer_format: str) -> GradeResult:
    """Grade one (prediction, gold) pair with the grader for its format.

    ``answer_format`` is the dataset's ``answer_format`` column value; it
    is matched case-insensitively after stripping. Unknown or blank
    formats are graded as strings. ``None`` inputs for ``pred`` / ``gold``
    are treated as empty strings.
    """
    key = (answer_format or "").strip().lower()
    grader = _FORMAT_DISPATCH[key] if key in _FORMAT_DISPATCH else _grade_str
    return grader(pred or "", gold or "")
|
||||
|
||||
|
||||
__all__ = ["GradeResult", "grade"]
|
||||
|
|
@ -0,0 +1,365 @@
|
|||
"""MMLongBench-Doc ingestion.
|
||||
|
||||
Steps:
|
||||
|
||||
1. Pull the questions parquet from
|
||||
``hf://datasets/yubo2333/MMLongBench-Doc/data/`` and cache locally.
|
||||
2. Resolve the unique set of ``doc_id`` referenced by questions, and
|
||||
download each PDF from
|
||||
``hf://datasets/yubo2333/MMLongBench-Doc/documents/<doc_id>``.
|
||||
``huggingface_hub.hf_hub_download`` is resumable + content-hash
|
||||
verifying; we cache PDFs under ``<data_dir>/multimodal_doc/mmlongbench/pdfs/``.
|
||||
3. Upload every PDF to SurfSense via ``DocumentsClient.upload`` with
|
||||
``use_vision_llm=True`` so SurfSense's Pillow + LiteLLM vision
|
||||
pipeline extracts captions / OCR for embedded images, charts, and
|
||||
tables.
|
||||
4. Wait for ``processed`` status and persist
|
||||
``doc_id -> document_id`` in
|
||||
``<data_dir>/multimodal_doc/maps/mmlongbench_doc_map.jsonl``.
|
||||
|
||||
By default we ingest **all** 135 PDFs (~660 MB, totally manageable).
|
||||
Operators can scope to a subset with ``--max-docs N`` if iterating on
|
||||
a slow vision pipeline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HF_REPO_ID = "yubo2333/MMLongBench-Doc"
|
||||
HF_REPO_TYPE = "dataset"
|
||||
|
||||
# Lazy import: huggingface_hub + pyarrow are heavyweight; keep the
|
||||
# benchmark module importable on machines that have only the core
|
||||
# install (e.g. CI lint jobs).
|
||||
def _hf_hub_download(*args, **kwargs):
    """Forward all arguments to ``huggingface_hub.hf_hub_download``.

    The package is imported inside the function so that importing this
    module does not require ``huggingface_hub``.
    """
    import huggingface_hub

    return huggingface_hub.hf_hub_download(*args, **kwargs)
|
||||
|
||||
|
||||
def _list_repo_files() -> list[str]:
    """List the files of the MMLongBench-Doc dataset repo on the HF hub.

    ``huggingface_hub`` is imported inside the function so that importing
    this module does not require it.
    """
    import huggingface_hub

    return huggingface_hub.list_repo_files(
        repo_id=HF_REPO_ID,
        repo_type=HF_REPO_TYPE,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question parquet -> Python rows
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class MMLongBenchQuestion:
    """One MMLongBench-Doc question row."""

    doc_id: str  # filename inside the documents/ folder
    doc_type: str
    question: str
    answer: str
    answer_format: str  # Str / Int / Float / List / None
    evidence_pages: list[int]
    evidence_sources: list[str]
|
||||
|
||||
|
||||
def _load_questions_from_parquet(parquet_path: Path) -> list[MMLongBenchQuestion]:
    """Read the questions parquet into ``MMLongBenchQuestion`` rows.

    Rows with a blank ``doc_id`` or a blank ``question`` are skipped. Text
    columns are stringified and stripped; missing values become ``""``.
    The evidence columns are parsed into lists.
    """
    # Imported lazily: pyarrow is heavyweight.
    import pyarrow.parquet as pq

    def _text(record, column: str) -> str:
        return str(record.get(column) or "").strip()

    parsed: list[MMLongBenchQuestion] = []
    for record in pq.read_table(parquet_path).to_pylist():
        doc_id = _text(record, "doc_id")
        if not doc_id:
            continue
        question = _text(record, "question")
        if not question:
            continue
        parsed.append(
            MMLongBenchQuestion(
                doc_id=doc_id,
                doc_type=_text(record, "doc_type"),
                question=question,
                answer=_text(record, "answer"),
                answer_format=_text(record, "answer_format"),
                evidence_pages=_parse_int_list(record.get("evidence_pages")),
                evidence_sources=_parse_str_list(record.get("evidence_sources")),
            )
        )
    return parsed
|
||||
|
||||
|
||||
def _parse_int_list(raw) -> list[int]:
    """Parse a list of integers from a real list or its string form.

    ``None`` gives ``[]``. For a list, each element is converted with
    ``int()`` and elements that cannot be converted are skipped. Anything
    else is stringified, stripped of surrounding ``[`` / ``]``, split on
    commas, and each piece (minus surrounding quotes) is kept only if it
    consists of decimal digits.
    """
    if raw is None:
        return []
    values: list[int] = []
    if isinstance(raw, list):
        for item in raw:
            try:
                values.append(int(item))
            except (TypeError, ValueError):
                continue
        return values
    text = str(raw).strip().strip("[]")
    if not text:
        return []
    for tok in text.split(","):
        tok = tok.strip().strip("'\"")
        # isdecimal(), not isdigit(): isdigit() also accepts characters
        # such as superscript digits, which int() rejects with ValueError.
        if tok.isdecimal():
            values.append(int(tok))
    return values
|
||||
|
||||
|
||||
def _parse_str_list(raw) -> list[str]:
    """Parse a list of strings from a real list or its string form.

    ``None`` gives ``[]``. For a list, each element is stringified. Any
    other value is stringified, stripped of surrounding ``[`` / ``]`` and
    split on commas. In both cases blank pieces are dropped and the rest
    are trimmed of whitespace and surrounding quotes.
    """
    if raw is None:
        return []
    if isinstance(raw, list):
        pieces = [str(item) for item in raw]
    else:
        pieces = str(raw).strip().strip("[]").split(",")
    return [piece.strip().strip("'\"") for piece in pieces if piece.strip()]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Download helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _download_questions_parquet(cache_dir: Path) -> Path:
    """Download every parquet shard under ``data/`` and return one path.

    The repo file listing is filtered to ``data/*.parquet`` and each shard
    is fetched (in sorted order) into *cache_dir*. A single shard is
    returned as-is; several shards are merged into one file first. The
    dataset usually has a single split, but enumerating keeps this robust
    to repo restructuring.

    Raises:
        RuntimeError: if the repo has no parquet file under ``data/``.
    """
    shards = sorted(
        name
        for name in _list_repo_files()
        if name.startswith("data/") and name.endswith(".parquet")
    )
    if not shards:
        raise RuntimeError(
            f"No parquet files found under data/ in {HF_REPO_ID}; "
            f"upstream repo may have been restructured."
        )

    local_paths: list[Path] = []
    for shard in shards:
        fetched = _hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=shard,
            repo_type=HF_REPO_TYPE,
            cache_dir=str(cache_dir),
        )
        local_paths.append(Path(fetched))
        logger.info("Cached MMLongBench parquet shard %s -> %s", shard, fetched)

    if len(local_paths) == 1:
        return local_paths[0]
    return _merge_parquets(local_paths, cache_dir)
|
||||
|
||||
|
||||
def _merge_parquets(paths: list[Path], cache_dir: Path) -> Path:
    """Concatenate several parquet shards into one file.

    The merged table is written to ``merged_questions.parquet`` inside
    *cache_dir* and that path is returned.
    """
    # Imported lazily: pyarrow is heavyweight.
    import pyarrow as pa
    import pyarrow.parquet as pq

    merged_path = cache_dir / "merged_questions.parquet"
    shard_tables = [pq.read_table(shard) for shard in paths]
    combined = pa.concat_tables(shard_tables, promote_options="default")
    pq.write_table(combined, merged_path)
    return merged_path
|
||||
|
||||
|
||||
def _download_pdf(doc_id: str, cache_dir: Path, pdfs_dir: Path) -> Path:
    """Download one PDF and place it at ``pdfs_dir / doc_id``.

    The file is fetched as ``documents/<doc_id>`` through the hub cache in
    *cache_dir*. It is then materialised at a stable path under
    *pdfs_dir*, so the runner does not depend on cache internals. An
    existing destination with the same size as the cached file is reused.
    """
    cached = _hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=f"documents/{doc_id}",
        repo_type=HF_REPO_TYPE,
        cache_dir=str(cache_dir),
    )
    dest = pdfs_dir / doc_id
    if dest.exists() and dest.stat().st_size == Path(cached).stat().st_size:
        return dest

    # Prefer a hardlink (cheap); fall back to a copy if linking fails.
    try:
        if dest.exists():
            dest.unlink()
        os.link(cached, dest)
    except OSError:
        from shutil import copy2

        copy2(cached, dest)
    return dest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _upload_pdfs(
    ctx: RunContext,
    pdf_paths: Iterable[Path],
    *,
    batch_size: int,
    settings: IngestSettings,
) -> dict[str, int]:
    """Upload PDFs in batches and return a ``title -> document_id`` map.

    Each batch is uploaded with the given ingest *settings*. Newly added
    documents are polled until ready; duplicates are not polled. The
    status of both new and duplicate documents is then fetched and each
    reported title is mapped to its document id.

    NOTE(review): callers treat the keys as filenames, which assumes the
    title reported by ``get_status`` equals the uploaded filename. Confirm
    against the documents client.
    """
    docs_client = ctx.documents_client()
    name_to_id: dict[str, int] = {}
    pdf_list = list(pdf_paths)
    for batch_start in range(0, len(pdf_list), batch_size):
        batch = pdf_list[batch_start:batch_start + batch_size]
        result = await docs_client.upload(
            files=batch,
            search_space_id=ctx.search_space_id,
            should_summarize=settings.should_summarize,
            use_vision_llm=settings.use_vision_llm,
            processing_mode=settings.processing_mode,
        )
        new_ids = list(result.document_ids)
        all_ids = new_ids + list(result.duplicate_document_ids)
        if new_ids:
            # Only newly added documents need polling. An all-duplicate
            # batch has nothing to wait for, so the call is skipped
            # instead of being made with an empty id list.
            await docs_client.wait_until_ready(
                search_space_id=ctx.search_space_id,
                document_ids=result.document_ids,
                timeout_s=1800.0,  # vision pipeline is slow on long PDFs
            )
        if all_ids:
            statuses = await docs_client.get_status(
                search_space_id=ctx.search_space_id,
                document_ids=all_ids,
            )
            for s in statuses:
                name_to_id[s.title] = s.document_id
        logger.info(
            "Uploaded MMLongBench batch %d-%d: %d new, %d duplicate",
            batch_start, batch_start + len(batch),
            len(result.document_ids), len(result.duplicate_document_ids),
        )
    return name_to_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_ingest(
    ctx: RunContext,
    *,
    max_docs: int | None = None,
    upload_batch_size: int = 8,
    skip_upload: bool = False,
    settings: IngestSettings | None = None,
) -> None:
    """Ingest MMLongBench-Doc into the multimodal_doc suite.

    Steps: download and parse the questions parquet, write
    ``questions.jsonl``, cache each referenced PDF locally, optionally
    upload the PDFs to SurfSense, write the ``doc_id -> document_id``
    map, and record the map path in the suite state.

    Parameters
    ----------
    max_docs : int | None
        Cap the number of PDFs to download + upload. ``None`` = all 135.
        Useful when iterating on the runner without paying for the full
        vision pipeline pass each time.
    upload_batch_size : int
        How many PDFs to send per ``fileupload`` call. Smaller batches
        recover faster from individual failures; larger batches reduce
        round-trip overhead.
    skip_upload : bool
        Download + cache PDFs locally but skip SurfSense ingestion.
        Useful for testing the native arm in isolation.
    settings : IngestSettings | None
        Per-upload knobs. When omitted, falls back to
        ``IngestSettings(use_vision_llm=True, processing_mode="basic")``.
        The resolved settings are written as the map's header line.

    Raises
    ------
    RuntimeError
        If the parquet yields no parseable questions.
    """
    settings = settings or IngestSettings(use_vision_llm=True, processing_mode="basic")
    bench_dir = ctx.benchmark_data_dir()
    pdfs_dir = bench_dir / "pdfs"
    pdfs_dir.mkdir(parents=True, exist_ok=True)
    hf_cache = bench_dir / ".hf_cache"
    hf_cache.mkdir(parents=True, exist_ok=True)

    # Step 1: questions
    parquet_path = _download_questions_parquet(hf_cache)
    questions = _load_questions_from_parquet(parquet_path)
    if not questions:
        raise RuntimeError(
            "MMLongBench-Doc parquet contains no parseable questions. "
            "Upstream may have changed schema."
        )

    # Persist a copy alongside the PDFs so the runner has one place to read.
    questions_jsonl = bench_dir / "questions.jsonl"
    with questions_jsonl.open("w", encoding="utf-8") as fh:
        for q in questions:
            fh.write(json.dumps({
                "doc_id": q.doc_id,
                "doc_type": q.doc_type,
                "question": q.question,
                "answer": q.answer,
                "answer_format": q.answer_format,
                "evidence_pages": q.evidence_pages,
                "evidence_sources": q.evidence_sources,
            }) + "\n")
    logger.info("Wrote %d MMLongBench questions to %s", len(questions), questions_jsonl)

    # Count questions per document once instead of re-scanning the whole
    # question list for every map row.
    questions_per_doc: dict[str, int] = {}
    for q in questions:
        questions_per_doc[q.doc_id] = questions_per_doc.get(q.doc_id, 0) + 1

    # Step 2: download unique PDFs
    unique_doc_ids = sorted(questions_per_doc)
    if max_docs is not None and max_docs > 0:
        unique_doc_ids = unique_doc_ids[:max_docs]
    logger.info("MMLongBench: downloading %d unique PDFs", len(unique_doc_ids))

    pdf_paths: dict[str, Path] = {}
    for i, doc_id in enumerate(unique_doc_ids, start=1):
        # Best effort: one failed download must not abort the whole ingest;
        # the document is simply left out of the map.
        try:
            pdf_paths[doc_id] = _download_pdf(doc_id, hf_cache, pdfs_dir)
            if i % 10 == 0:
                logger.info(" ... %d / %d PDFs cached", i, len(unique_doc_ids))
        except Exception as exc:  # noqa: BLE001
            logger.warning("Failed to download MMLongBench PDF %s: %s", doc_id, exc)

    # Step 3: upload to SurfSense
    name_to_id: dict[str, int] = {}
    if skip_upload:
        logger.info("MMLongBench: --skip-upload set; skipping SurfSense ingestion")
    else:
        logger.info("MMLongBench upload settings: %s", settings.render_label())
        name_to_id = await _upload_pdfs(
            ctx,
            pdf_paths.values(),
            batch_size=upload_batch_size,
            settings=settings,
        )

    # Step 4: persist doc_id -> document_id manifest
    map_path = ctx.maps_dir() / "mmlongbench_doc_map.jsonl"
    unresolved = 0
    with map_path.open("w", encoding="utf-8") as fh:
        # Header line records the resolved ingest settings
        # (see core/ingest_settings.py).
        fh.write(settings_header_line(settings) + "\n")
        for doc_id in unique_doc_ids:
            local = pdf_paths.get(doc_id)
            if local is None:
                continue
            document_id = name_to_id.get(local.name)
            if document_id is None:
                unresolved += 1
            fh.write(json.dumps({
                "doc_id": doc_id,
                "document_id": document_id,
                "pdf_path": str(local),
                "n_questions": questions_per_doc.get(doc_id, 0),
            }) + "\n")
    logger.info("Wrote MMLongBench doc map to %s", map_path)
    if unresolved and not skip_upload:
        # The upload ran but returned no id under the local filename;
        # surface it instead of silently writing null ids.
        logger.warning(
            "MMLongBench: %d of %d cached PDFs have no SurfSense document_id "
            "after upload; their map rows carry document_id=null.",
            unresolved, len(pdf_paths),
        )

    new_state = ctx.suite_state
    new_state.ingestion_maps["mmlongbench"] = str(map_path)
    set_suite_state(ctx.config, ctx.suite, new_state)
|
||||
|
||||
|
||||
__all__ = ["MMLongBenchQuestion", "run_ingest"]
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
"""MMLongBench-Doc prompt template.
|
||||
|
||||
Both arms get the same prompt — only the document delivery channel
|
||||
differs (native PDF embedded in the OpenRouter request vs SurfSense
|
||||
RAG retrieval). The format hint in the prompt mirrors what the
|
||||
upstream paper uses so the grader's regex can reliably extract the
|
||||
answer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-format hint blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FORMAT_HINTS: dict[str, str] = {
|
||||
"str": (
|
||||
"Respond with the answer as a short phrase, no full sentence. "
|
||||
"Format your final line as `Answer: <text>`."
|
||||
),
|
||||
"int": (
|
||||
"Respond with a single integer only. "
|
||||
"Format your final line as `Answer: <integer>`."
|
||||
),
|
||||
"float": (
|
||||
"Respond with a single decimal number only (no units). "
|
||||
"Format your final line as `Answer: <number>`."
|
||||
),
|
||||
"list": (
|
||||
"Respond with a comma-separated list of items, no extra text. "
|
||||
"Format your final line as `Answer: item1, item2, item3`."
|
||||
),
|
||||
"none": (
|
||||
"If the answer cannot be determined from the document, say so explicitly. "
|
||||
"Format your final line as `Answer: Not answerable`."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
_PROMPT = """\
|
||||
You are a document-understanding assistant. Use ONLY the provided
|
||||
document to answer the question. The document may contain text,
|
||||
tables, charts, figures, and images. If the answer is in a chart or
|
||||
image, read it carefully. Do not use external knowledge.
|
||||
|
||||
Question: {question}
|
||||
|
||||
{format_hint}
|
||||
"""
|
||||
|
||||
|
||||
def build_prompt(question: str, *, answer_format: str) -> str:
    """Return the complete prompt text for a single MMLongBench question.

    ``answer_format`` is trimmed and lower-cased before the hint lookup;
    an empty or unrecognised value falls back to the ``"str"`` hint. The
    question is stripped of surrounding whitespace before substitution.
    """
    key = (answer_format or "str").strip().lower()
    if key not in _FORMAT_HINTS:
        key = "str"
    return _PROMPT.format(question=question.strip(), format_hint=_FORMAT_HINTS[key])
|
||||
|
||||
|
||||
__all__ = ["build_prompt"]
|
||||
|
|
@ -0,0 +1,704 @@
|
|||
"""MMLongBench-Doc runner — head-to-head Native PDF (vision) vs SurfSense (vision RAG).
|
||||
|
||||
Differences from a typical MCQ head-to-head:
|
||||
|
||||
* Open-ended answers (Str / Int / Float / List / Not-answerable) — uses
|
||||
``extract_freeform_answer`` instead of ``extract_answer_letter``.
|
||||
* Format-aware grader (see ``.grader``) returns both binary correctness
|
||||
(for accuracy / McNemar) and continuous F1 (for nuanced reporting).
|
||||
* Native arm requires a vision-capable model — we don't enforce this
|
||||
in code (operator's choice via ``setup --provider-model``) but we
|
||||
emit a warning if the pinned slug looks text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.arms import ArmRequest, ArmResult, NativePdfArm, SurfSenseArm
|
||||
from ....core.config import utc_iso_timestamp
|
||||
from ....core.ingest_settings import (
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
is_settings_header,
|
||||
)
|
||||
from ....core.metrics.comparison import (
|
||||
bootstrap_delta_ci,
|
||||
mcnemar_test,
|
||||
paired_aggregate,
|
||||
)
|
||||
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci
|
||||
from ....core.parse.freeform_answer import extract_freeform_answer
|
||||
from ....core.providers.openrouter_pdf import OpenRouterPdfProvider, PdfEngine
|
||||
from ....core.registry import (
|
||||
ReportSection,
|
||||
RunArtifact,
|
||||
RunContext,
|
||||
)
|
||||
from ....core.scenarios import format_scenario_md
|
||||
from .grader import GradeResult, grade
|
||||
from .prompt import build_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question + map row shapes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class MMLBQuestion:
    """One MMLongBench-Doc question joined with its doc-map row.

    * ``qid`` is synthesised from ``doc_id`` plus a per-document index.
    * ``doc_id`` is the filename inside the documents/ folder.
    * ``document_id`` is the SurfSense doc id (``None`` if upload skipped).
    """

    qid: str
    doc_id: str
    doc_type: str
    question: str
    gold_answer: str
    answer_format: str
    evidence_pages: list[int]
    evidence_sources: list[str]
    pdf_path: Path
    document_id: int | None
|
||||
|
||||
|
||||
def _load_doc_map(map_path: Path) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
    """Parse the doc-map JSONL file.

    Blank lines are skipped. A line recognised by ``is_settings_header``
    supplies the settings blob (its ``__settings__`` value); every other
    line is stored under its stringified ``doc_id``.

    Returns ``(rows, settings)``; ``settings`` stays ``{}`` when the file
    has no header line (legacy maps).
    """
    docs: dict[str, dict[str, Any]] = {}
    header: dict[str, Any] = {}
    with map_path.open("r", encoding="utf-8") as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue
            record = json.loads(text)
            if is_settings_header(record):
                header = dict(record["__settings__"])
            else:
                docs[str(record["doc_id"])] = record
    return docs, header
|
||||
|
||||
|
||||
def _load_questions(
    questions_jsonl: Path,
    doc_map: dict[str, dict[str, Any]],
    *,
    doc_filter: list[str] | None,
    format_filter: str | None,
    sample_n: int | None,
    skip_unanswerable: bool,
) -> list[MMLBQuestion]:
    """Load ``questions.jsonl``, join each row with its doc-map entry, filter.

    Parameters
    ----------
    questions_jsonl : Path
        JSONL file with one question per line.
    doc_map : dict
        ``doc_id -> map row``; rows whose document is absent are skipped.
    doc_filter : list[str] | None
        When non-empty, keep only these ``doc_id`` values.
    format_filter : str | None
        Keep only this (lower-case) answer format; ``None`` or ``"all"``
        disables the filter.
    sample_n : int | None
        Keep the first N questions after sorting; ignored unless positive.
    skip_unanswerable : bool
        Drop questions whose answer format is ``"none"``.

    Returns
    -------
    list[MMLBQuestion]
        Sorted by ``(doc_id, qid)``.
    """
    allowed_docs = set(doc_filter) if doc_filter else None
    out: list[MMLBQuestion] = []
    per_doc_counter: dict[str, int] = {}
    with questions_jsonl.open("r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            doc_id = str(row.get("doc_id") or "").strip()
            if not doc_id:
                continue
            # Number every question of a document BEFORE any filter runs, so
            # a question keeps the same qid whatever --format /
            # --skip-unanswerable selection is active.
            idx = per_doc_counter.get(doc_id, 0)
            per_doc_counter[doc_id] = idx + 1
            if allowed_docs is not None and doc_id not in allowed_docs:
                continue
            map_row = doc_map.get(doc_id)
            if map_row is None:
                logger.debug("No doc-map entry for %s; skipping", doc_id)
                continue
            answer_format = str(row.get("answer_format") or "").strip().lower()
            if format_filter and format_filter != "all" and format_filter != answer_format:
                continue
            if skip_unanswerable and answer_format == "none":
                continue
            # Test for None explicitly: ``or ""`` would erase a falsy gold
            # answer such as a numeric 0.
            raw_answer = row.get("answer")
            gold = "" if raw_answer is None else str(raw_answer).strip()
            out.append(MMLBQuestion(
                qid=f"{doc_id}::Q{idx:03d}",
                doc_id=doc_id,
                doc_type=str(row.get("doc_type") or "").strip(),
                question=str(row.get("question") or "").strip(),
                gold_answer=gold,
                answer_format=answer_format,
                evidence_pages=list(row.get("evidence_pages") or []),
                evidence_sources=list(row.get("evidence_sources") or []),
                pdf_path=Path(map_row["pdf_path"]),
                document_id=map_row.get("document_id"),
            ))
    out.sort(key=lambda q: (q.doc_id, q.qid))
    if sample_n is not None and sample_n > 0:
        out = out[:sample_n]
    return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bounded concurrency helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _gather_with_limit(coros: Iterable, *, concurrency: int) -> list[Any]:
|
||||
sem = asyncio.Semaphore(max(1, concurrency))
|
||||
|
||||
async def _wrap(coro):
|
||||
async with sem:
|
||||
return await coro
|
||||
|
||||
return await asyncio.gather(*(_wrap(c) for c in coros))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_DESCRIPTION = (
|
||||
"MMLongBench-Doc (135 long PDFs, 1,091 multimodal questions) — "
|
||||
"Native PDF (vision) vs SurfSense (vision RAG) head-to-head."
|
||||
)
|
||||
|
||||
|
||||
# Substrings matched against the lower-cased native-arm model slug. A hit
# makes ``run`` log that the pinned model looks text-only.
_TEXT_ONLY_HINTS = ("gpt-5.4-mini", "gpt-3.5", "text-only", "instruct-")
|
||||
|
||||
# Baseline ingest configuration for this benchmark. The PDFs are long
# documents with figures, charts, and tables, so vision extraction at
# ingest is the whole point; pass --no-vision-llm to measure how much
# SurfSense degrades on real document images.
_DEFAULT_INGEST_SETTINGS = IngestSettings(
    processing_mode="basic",
    should_summarize=False,
    use_vision_llm=True,
)
|
||||
|
||||
|
||||
class MMLongBenchDocBenchmark:
    """Long-document multimodal RAG vs native vision.

    ``add_run_args`` declares the CLI flags, ``ingest`` delegates to
    ``ingest.run_ingest``, ``run`` answers every selected question with
    both arms and writes ``raw.jsonl`` + ``run_artifact.json``, and
    ``report_section`` renders the most recent run as markdown.
    """

    suite: str = "multimodal_doc"
    name: str = "mmlongbench"
    headline: bool = True
    description: str = _DESCRIPTION

    def add_run_args(self, parser: argparse.ArgumentParser) -> None:
        """Register the run-time and ingest-time CLI flags on ``parser``."""
        parser.add_argument(
            "--docs",
            default=None,
            help="Comma-separated doc_ids (filenames) to run (default: all).",
        )
        parser.add_argument(
            "--format",
            default="all",
            choices=["all", "str", "int", "float", "list", "none"],
            help="Filter to one answer format. 'none' = unanswerable probes only.",
        )
        parser.add_argument(
            "--n", dest="sample_n", type=int, default=None,
            help="Run only the first N questions after filters apply.",
        )
        parser.add_argument(
            "--skip-unanswerable", dest="skip_unanswerable", action="store_true",
            help="Drop ~22%% unanswerable questions (use to compare against baselines that don't include them).",
        )
        parser.add_argument(
            "--concurrency", type=int, default=4,
            help="Parallel question workers per arm.",
        )
        parser.add_argument(
            "--no-mentions", dest="no_mentions", action="store_true",
            help="SurfSense arm: skip mentioned_document_ids (unscoped retrieval).",
        )
        parser.add_argument(
            "--pdf-engine", default="native",
            choices=[e.value for e in PdfEngine],
            help="OpenRouter file-parser engine for the native arm.",
        )
        parser.add_argument(
            "--max-output-tokens", type=int, default=512,
            # Only the native arm and its requests receive this value in
            # ``run``; the SurfSense request carries no token cap.
            help="Cap on completion length for the native arm.",
        )
        # Ingest-only knobs (forwarded by the CLI to ingest.run_ingest).
        parser.add_argument(
            "--max-docs", dest="max_docs", type=int, default=None,
            help="(ingest only) cap on number of unique PDFs to download + upload.",
        )
        parser.add_argument(
            "--upload-batch-size", dest="upload_batch_size", type=int, default=8,
            help="(ingest only) PDFs per fileupload call.",
        )
        parser.add_argument(
            "--skip-upload", dest="skip_upload", action="store_true",
            help="(ingest only) cache PDFs locally but don't push to SurfSense.",
        )
        # Per-upload knobs forwarded to /documents/fileupload at ingest;
        # ignored at run-time (runner reads the resolved settings out of
        # the doc-map manifest header).
        add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)

    async def ingest(self, ctx: RunContext, **opts: Any) -> None:
        """Merge CLI options into the default ingest settings and run the ingest."""
        # Imported lazily, as in the original, so the ingest module is only
        # loaded when an ingest is actually requested.
        from .ingest import run_ingest

        settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
        await run_ingest(
            ctx,
            max_docs=opts.get("max_docs"),
            upload_batch_size=int(opts.get("upload_batch_size") or 8),
            skip_upload=bool(opts.get("skip_upload", False)),
            settings=settings,
        )

    async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
        """Answer the selected questions with both arms and persist the results.

        Writes ``raw.jsonl`` (two rows per question: native, then
        SurfSense) and ``run_artifact.json`` into the run directory.

        Raises
        ------
        RuntimeError
            If the benchmark has not been ingested, no question matches
            the filters, or ``OPENROUTER_API_KEY`` is unset.
        """
        docs_raw: str | None = opts.get("docs")
        doc_filter = [d.strip() for d in docs_raw.split(",")] if docs_raw else None
        format_filter = opts.get("format") or "all"
        sample_n = opts.get("sample_n")
        skip_unanswerable = bool(opts.get("skip_unanswerable"))
        concurrency = int(opts.get("concurrency") or 4)
        no_mentions = bool(opts.get("no_mentions"))
        pdf_engine_name = opts.get("pdf_engine") or "native"
        max_output_tokens = int(opts.get("max_output_tokens") or 512)

        bench_dir = ctx.benchmark_data_dir()
        questions_jsonl = bench_dir / "questions.jsonl"
        map_path = ctx.maps_dir() / "mmlongbench_doc_map.jsonl"
        if not questions_jsonl.exists() or not map_path.exists():
            raise RuntimeError(
                "MMLongBench-Doc not ingested for this suite. Run "
                "`python -m surfsense_evals ingest multimodal_doc mmlongbench` first."
            )

        doc_map, ingest_settings = _load_doc_map(map_path)
        questions = _load_questions(
            questions_jsonl, doc_map,
            doc_filter=doc_filter,
            format_filter=None if format_filter == "all" else format_filter,
            sample_n=sample_n,
            skip_unanswerable=skip_unanswerable,
        )
        if not questions:
            raise RuntimeError(
                "No MMLongBench questions matched the filters; broaden --docs/--format/--n."
            )
        logger.info("MMLongBench-Doc: scheduled %d questions", len(questions))

        if not no_mentions:
            # Questions without a document_id are sent with no document
            # mention, so say so instead of silently changing the
            # retrieval scope of the SurfSense arm.
            unscoped = sum(1 for q in questions if q.document_id is None)
            if unscoped:
                logger.warning(
                    "%d of %d questions have no SurfSense document_id in the doc "
                    "map; the SurfSense arm will run without document mentions "
                    "for them.",
                    unscoped, len(questions),
                )

        api_key = os.environ.get("OPENROUTER_API_KEY")
        if not api_key:
            raise RuntimeError(
                "OPENROUTER_API_KEY env var is required for the native arm."
            )

        # Native arm slug differs from SurfSense slug only in cost-arbitrage
        # scenario; otherwise both arms answer with provider_model.
        native_arm_model = ctx.native_arm_model
        if any(hint in native_arm_model.lower() for hint in _TEXT_ONLY_HINTS):
            if ctx.scenario == "symmetric-cheap":
                logger.info(
                    "symmetric-cheap: native arm pinned to text-only %r as "
                    "intended; expect it to lose on image-bearing pages "
                    "(SurfSense answers from vision-extracted chunks).",
                    native_arm_model,
                )
            else:
                logger.warning(
                    "Native arm slug %r looks text-only; image content in "
                    "PDFs will be ignored. Re-pin via "
                    "`setup --provider-model anthropic/claude-sonnet-4.5` "
                    "(or pass --native-arm-model and --scenario cost-arbitrage "
                    "to make this asymmetry explicit).",
                    native_arm_model,
                )

        provider = OpenRouterPdfProvider(
            api_key=api_key,
            base_url=ctx.config.openrouter_base_url,
            model=native_arm_model,
            engine=PdfEngine(pdf_engine_name),
        )
        native_arm = NativePdfArm(provider=provider, max_output_tokens=max_output_tokens)
        surf_arm = SurfSenseArm(
            client=ctx.new_chat_client(),
            search_space_id=ctx.search_space_id,
            ephemeral_threads=True,
        )

        run_timestamp = utc_iso_timestamp()
        run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
        raw_path = run_dir / "raw.jsonl"

        async def _native_one(q: MMLBQuestion) -> ArmResult:
            return await native_arm.answer(_make_native_request(q, max_output_tokens))

        async def _surf_one(q: MMLBQuestion) -> ArmResult:
            return await surf_arm.answer(_make_surfsense_request(q, no_mentions=no_mentions))

        # Both arms run side by side; each has its own concurrency limit.
        native_results, surf_results = await asyncio.gather(
            _gather_with_limit((_native_one(q) for q in questions), concurrency=concurrency),
            _gather_with_limit((_surf_one(q) for q in questions), concurrency=concurrency),
        )

        native_grades = [_grade_one(q, r) for q, r in zip(questions, native_results, strict=False)]
        surf_grades = [_grade_one(q, r) for q, r in zip(questions, surf_results, strict=False)]

        with raw_path.open("w", encoding="utf-8") as fh:
            for q, n_res, s_res, n_g, s_g in zip(
                questions, native_results, surf_results, native_grades, surf_grades, strict=False
            ):
                meta = {
                    "qid": q.qid,
                    "doc_id": q.doc_id,
                    "doc_type": q.doc_type,
                    "answer_format": q.answer_format,
                    "gold": q.gold_answer,
                    "evidence_pages": q.evidence_pages,
                    "evidence_sources": q.evidence_sources,
                    "document_id": q.document_id,
                }
                fh.write(json.dumps({
                    **meta,
                    **n_res.to_jsonl(),
                    "graded": _grade_to_jsonl(n_g),
                }) + "\n")
                fh.write(json.dumps({
                    **meta,
                    **s_res.to_jsonl(),
                    "graded": _grade_to_jsonl(s_g),
                }) + "\n")

        metrics = _compute_metrics(questions, native_results, surf_results, native_grades, surf_grades)
        artifact = RunArtifact(
            suite=self.suite,
            benchmark=self.name,
            run_timestamp=run_timestamp,
            raw_path=raw_path,
            metrics=metrics,
            extra={
                "n_questions": len(questions),
                "concurrency": concurrency,
                "format_filter": format_filter,
                "skip_unanswerable": skip_unanswerable,
                "no_mentions": no_mentions,
                "pdf_engine": pdf_engine_name,
                "scenario": ctx.scenario,
                "provider_model": ctx.provider_model,
                "native_arm_model": native_arm_model,
                "vision_provider_model": ctx.vision_provider_model,
                "agent_llm_id": ctx.agent_llm_id,
                "ingest_settings": ingest_settings,
            },
        )

        manifest_path = run_dir / "run_artifact.json"
        manifest_path.write_text(
            json.dumps({
                "suite": self.suite,
                "benchmark": self.name,
                "raw_path": "raw.jsonl",
                "metrics": metrics,
                "extra": artifact.extra,
            }, indent=2, sort_keys=True) + "\n",
            encoding="utf-8",
        )
        return artifact

    def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
        """Render the artifact with the greatest ``run_timestamp`` as markdown.

        Returns a placeholder section when ``artifacts`` is empty.
        """
        title = "MMLongBench-Doc — Native PDF (vision) vs SurfSense (vision RAG)"
        if not artifacts:
            return ReportSection(
                title=title,
                headline=True,
                body_md="(no run artifacts found)",
                body_json={},
            )
        latest = max(artifacts, key=lambda a: a.run_timestamp)
        m = latest.metrics
        native = m.get("native", {})
        surf = m.get("surfsense", {})
        delta = m.get("delta", {})
        per_format = m.get("per_format", {})
        extra = latest.extra

        body_lines: list[str] = []
        body_lines.append(
            f"- Sample size: {extra.get('n_questions', '?')} questions "
            f"(format filter: `{extra.get('format_filter', 'all')}`, "
            f"skip-unanswerable: `{extra.get('skip_unanswerable', False)}`, "
            f"engine: `{extra.get('pdf_engine', 'native')}`)."
        )
        body_lines.append(format_scenario_md(extra))
        body_lines.append(format_ingest_settings_md(extra.get("ingest_settings")))
        body_lines.append(
            "- Native arm (OpenRouter `chat/completions` + file plugin, "
            f"`{extra.get('native_arm_model') or extra.get('provider_model', '?')}`):"
        )
        body_lines.append(_arm_summary_lines(native, indent=" "))
        body_lines.append(
            "- SurfSense arm (`POST /api/v1/new_chat`, vision RAG over chunks, "
            f"`{extra.get('provider_model', '?')}`):"
        )
        body_lines.append(_arm_summary_lines(surf, indent=" "))
        body_lines.append("- Delta (paired):")
        body_lines.append(
            f" - Accuracy: SurfSense {_pp(delta.get('accuracy_pp'))} pp "
            f"(McNemar p={_fmt(delta.get('mcnemar_p_value'), 4)}, "
            f"method={delta.get('mcnemar_method')})"
        )
        body_lines.append(
            f" - F1 (mean): SurfSense {_pp(delta.get('f1_pp'))} pp"
        )
        body_lines.append(
            f" - Bootstrap 95% CI on accuracy delta: "
            f"[{_pp(delta.get('bootstrap_ci_low'))}pp, {_pp(delta.get('bootstrap_ci_high'))}pp]"
        )
        body_lines.append(
            f" - Cost / question: native ${_dollars(native.get('cost_micros_mean'))}, "
            f"surfsense ${_dollars(surf.get('cost_micros_mean'))} "
            f"(SurfSense delta {_pct_change(delta.get('cost_micros_pct'))})"
        )
        body_lines.append(
            f" - Latency p50: native {_ms_to_s(native.get('latency_ms_median'))}, "
            f"surfsense {_ms_to_s(surf.get('latency_ms_median'))} "
            f"(SurfSense delta {_pct_change(delta.get('latency_ms_pct'))})"
        )
        if per_format:
            body_lines.append("- Per-format split (accuracy delta in pp):")
            for fmt, vals in sorted(per_format.items()):
                body_lines.append(
                    f" - {fmt}: SurfSense {_pp(vals.get('delta_accuracy_pp'))} pp "
                    f"(n={vals.get('n')}, native acc={vals.get('native_accuracy', 0)*100:.1f}%, "
                    f"surf acc={vals.get('surfsense_accuracy', 0)*100:.1f}%)"
                )

        return ReportSection(
            title=title,
            headline=True,
            body_md="\n".join(body_lines),
            body_json=m,
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-question helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_native_request(q: MMLBQuestion, max_tokens: int) -> ArmRequest:
    """Build the native-arm request: prompt, the question's PDF, and a token cap."""
    return ArmRequest(
        question_id=q.qid,
        prompt=build_prompt(q.question, answer_format=q.answer_format),
        pdf_paths=[q.pdf_path],
        options={"max_tokens": max_tokens},
    )
|
||||
|
||||
|
||||
def _make_surfsense_request(q: MMLBQuestion, *, no_mentions: bool) -> ArmRequest:
    """Build the SurfSense-arm request for one question.

    The request mentions the question's document only when mentions are
    enabled and the question has a ``document_id``; otherwise
    ``mentioned_document_ids`` is ``None``.
    """
    scoped = not no_mentions and q.document_id is not None
    return ArmRequest(
        question_id=q.qid,
        prompt=build_prompt(q.question, answer_format=q.answer_format),
        mentioned_document_ids=[int(q.document_id)] if scoped else None,
    )
|
||||
|
||||
|
||||
def _grade_one(q: MMLBQuestion, result: ArmResult) -> GradeResult:
    """Extract the free-form answer from an arm's raw text and grade it against gold."""
    return grade(
        pred=extract_freeform_answer(result.raw_text or ""),
        gold=q.gold_answer,
        answer_format=q.answer_format,
    )
|
||||
|
||||
|
||||
def _grade_to_jsonl(g: GradeResult) -> dict[str, Any]:
|
||||
return {
|
||||
"correct": g.correct,
|
||||
"f1": g.f1,
|
||||
"method": g.method,
|
||||
"normalised_pred": g.normalised_pred,
|
||||
"normalised_gold": g.normalised_gold,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metrics aggregation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_metrics(
    questions: list[MMLBQuestion],
    native_results: list[ArmResult],
    surf_results: list[ArmResult],
    native_grades: list[GradeResult],
    surf_grades: list[GradeResult],
) -> dict[str, Any]:
    """Aggregate per-question results into the run's metrics dict.

    The returned dict has four sections: ``native`` and ``surfsense``
    (accuracy with Wilson CI, mean F1, cost and latency aggregates; the
    native arm also reports mean token counts), ``delta`` (SurfSense
    minus native, plus McNemar and bootstrap statistics), and
    ``per_format`` (accuracy split by answer format). All inputs are
    index-aligned.
    """

    def _mean(values: list[float]) -> float:
        # Arithmetic mean, 0.0 for an empty list.
        return sum(values) / len(values) if values else 0.0

    native_correct = [g.correct for g in native_grades]
    surf_correct = [g.correct for g in surf_grades]
    native_f1 = [g.f1 for g in native_grades]
    surf_f1 = [g.f1 for g in surf_grades]

    native_costs = [float(r.cost_micros) for r in native_results]
    surf_costs = [float(r.cost_micros) for r in surf_results]
    native_latencies = [float(r.latency_ms) for r in native_results]
    surf_latencies = [float(r.latency_ms) for r in surf_results]
    native_in_tokens = [float(r.input_tokens) for r in native_results]
    native_out_tokens = [float(r.output_tokens) for r in native_results]

    native_acc = accuracy_with_wilson_ci(sum(native_correct), len(native_correct))
    surf_acc = accuracy_with_wilson_ci(sum(surf_correct), len(surf_correct))
    mc = mcnemar_test(native_correct, surf_correct)
    boot = bootstrap_delta_ci(native_correct, surf_correct, n_resamples=2000)

    native_cost_agg = paired_aggregate(native_costs)
    surf_cost_agg = paired_aggregate(surf_costs)
    native_latency_agg = paired_aggregate(native_latencies)
    surf_latency_agg = paired_aggregate(surf_latencies)

    # Group the paired correctness flags by answer format.
    by_format: dict[str, list[tuple[bool, bool]]] = {}
    for q, n_ok, s_ok in zip(questions, native_correct, surf_correct, strict=False):
        bucket = by_format.setdefault(q.answer_format or "unknown", [])
        bucket.append((n_ok, s_ok))

    per_format: dict[str, dict[str, Any]] = {}
    for fmt, pairs in by_format.items():
        total = len(pairs)
        native_hits = sum([a for a, _ in pairs])
        surf_hits = sum([b for _, b in pairs])
        per_format[fmt] = {
            "n": total,
            "native_accuracy": (native_hits / total) if pairs else 0.0,
            "surfsense_accuracy": (surf_hits / total) if pairs else 0.0,
            "delta_accuracy_pp": (
                100.0 * (surf_hits - native_hits) / total if pairs else 0.0
            ),
        }

    native_f1_mean = _mean(native_f1)
    surf_f1_mean = _mean(surf_f1)

    native_block = {
        **native_acc.to_dict(),
        "f1_mean": native_f1_mean,
        "cost_micros_mean": native_cost_agg.mean,
        "cost_micros_median": native_cost_agg.median,
        "latency_ms_mean": native_latency_agg.mean,
        "latency_ms_median": native_latency_agg.median,
        "latency_ms_p95": native_latency_agg.p95,
        "input_tokens_mean": _mean(native_in_tokens),
        "output_tokens_mean": _mean(native_out_tokens),
    }
    surf_block = {
        **surf_acc.to_dict(),
        "f1_mean": surf_f1_mean,
        "cost_micros_mean": surf_cost_agg.mean,
        "cost_micros_median": surf_cost_agg.median,
        "latency_ms_mean": surf_latency_agg.mean,
        "latency_ms_median": surf_latency_agg.median,
        "latency_ms_p95": surf_latency_agg.p95,
    }
    delta_block = {
        "accuracy_pp": 100.0 * (surf_acc.accuracy - native_acc.accuracy),
        "f1_pp": 100.0 * (surf_f1_mean - native_f1_mean),
        "mcnemar_p_value": mc.p_value,
        "mcnemar_method": mc.method,
        "mcnemar_b_native_only": mc.b,
        "mcnemar_c_surfsense_only": mc.c,
        "bootstrap_ci_low": 100.0 * boot.ci_low,
        "bootstrap_ci_high": 100.0 * boot.ci_high,
        # Cost compares means, latency compares medians.
        "cost_micros_pct": _safe_pct(surf_cost_agg.mean, native_cost_agg.mean),
        "latency_ms_pct": _safe_pct(surf_latency_agg.median, native_latency_agg.median),
    }
    return {
        "native": native_block,
        "surfsense": surf_block,
        "delta": delta_block,
        "per_format": per_format,
    }
|
||||
|
||||
|
||||
def _safe_pct(numerator: float, denominator: float) -> float | None:
|
||||
if denominator == 0:
|
||||
return None
|
||||
return 100.0 * (numerator - denominator) / denominator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tiny formatting helpers used by report_section
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _arm_summary_lines(d: dict[str, Any], *, indent: str) -> str:
    """Render one arm's metrics as newline-joined markdown bullets.

    Every bullet is prefixed with ``indent``. An empty ``d`` yields a
    single ``(no data)`` line. The token bullet is added only when
    ``input_tokens_mean`` is present.
    """
    if not d:
        return f"{indent}(no data)"

    accuracy_pct = d.get("accuracy", 0.0) * 100
    ci_low_pct = d.get("ci_low", 0.0) * 100
    ci_high_pct = d.get("ci_high", 0.0) * 100
    f1_pct = d.get("f1_mean", 0.0) * 100

    bullets: list[str] = []
    bullets.append(
        f"{indent}- Accuracy: {accuracy_pct:.1f}% "
        f"(Wilson 95% CI: {ci_low_pct:.1f}% – {ci_high_pct:.1f}%)"
    )
    bullets.append(f"{indent}- F1 (token-level mean): {f1_pct:.1f}%")
    bullets.append(
        f"{indent}- Cost / question: ${_dollars(d.get('cost_micros_mean'))} (mean), "
        f"${_dollars(d.get('cost_micros_median'))} (median)"
    )
    bullets.append(
        f"{indent}- Latency: p50 {_ms_to_s(d.get('latency_ms_median'))}, "
        f"p95 {_ms_to_s(d.get('latency_ms_p95'))}"
    )
    if "input_tokens_mean" in d:
        bullets.append(
            f"{indent}- Mean tokens / question: in {d.get('input_tokens_mean', 0):.0f}, "
            f"out {d.get('output_tokens_mean', 0):.0f}"
        )
    return "\n".join(bullets)
|
||||
|
||||
|
||||
def _dollars(micros: Any) -> str:
    """Format an amount given in millionths as a four-decimal string.

    The input is divided by 1,000,000; no currency symbol is added (callers
    prepend ``$`` themselves). ``None`` and anything ``float()`` rejects
    render as ``"?"``.
    """
    try:
        # float(None) raises TypeError, so the missing-value case is covered here.
        amount = float(micros) / 1_000_000
    except (TypeError, ValueError):
        return "?"
    return f"{amount:.4f}"
|
||||
|
||||
|
||||
def _ms_to_s(ms: Any) -> str:
    """Format a millisecond duration as seconds with one decimal (``"1.5s"``).

    Returns ``"?"`` when ``ms`` is ``None`` or cannot be converted to a float.
    """
    if ms is not None:
        try:
            seconds = float(ms) / 1000
        except (TypeError, ValueError):
            pass
        else:
            return f"{seconds:.1f}s"
    return "?"
|
||||
|
||||
|
||||
def _pp(value: Any) -> str:
    """Format a value as an explicitly signed number with one decimal.

    Produces e.g. ``"+1.5"`` or ``"-0.3"``. ``None`` and values ``float()``
    rejects render as ``"?"``.
    """
    try:
        # float(None) raises TypeError, which maps to the "?" placeholder.
        number = float(value)
    except (TypeError, ValueError):
        return "?"
    return f"{number:+.1f}"
|
||||
|
||||
|
||||
def _pct_change(value: Any) -> str:
    """Format a value as a signed whole-number percentage (``"+12%"``).

    Returns ``"?"`` when ``value`` is ``None`` or cannot be converted to a
    float.
    """
    if value is None:
        return "?"
    try:
        rendered = format(float(value), "+.0f")
    except (TypeError, ValueError):
        return "?"
    return rendered + "%"
|
||||
|
||||
|
||||
def _fmt(value: Any, ndigits: int) -> str:
    """Format ``value`` as a fixed-point number with ``ndigits`` decimals.

    Returns ``"?"`` when ``value`` is ``None``, cannot be converted to a
    float, or the resulting format specification is rejected.
    """
    if value is None:
        return "?"
    try:
        # Formatting stays inside the ``try``: an invalid ``ndigits`` makes
        # the format spec itself raise ``ValueError``.
        return format(float(value), f".{ndigits}f")
    except (TypeError, ValueError):
        return "?"
|
||||
|
||||
|
||||
# Names exported by a star-import of this module; the underscore-prefixed
# formatting helpers above stay private.
__all__ = ["MMLBQuestion", "MMLongBenchDocBenchmark"]
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""Research / multi-document RAG benchmarks.
|
||||
|
||||
Distinct from ``multimodal_doc`` (PDF-bound) and ``medical`` (one
|
||||
question = one source PDF). Benchmarks here put *retrieval and
|
||||
reasoning across many documents* in the critical path — the regime
|
||||
where SurfSense's chunk-level RAG should shine versus "pour the
|
||||
entire document into the LLM" or "ask the LLM cold".
|
||||
|
||||
* ``frames`` (google/frames-benchmark) — 824 multi-hop Wikipedia
|
||||
questions; tests bare-LLM vs SurfSense over a shared ~330-doc
|
||||
corpus.
|
||||
* ``crag`` (facebookresearch/CRAG, KDD Cup 2024) — 2,706 web QA
|
||||
pairs with 5 pre-retrieved HTML pages each; tests bare-LLM vs
|
||||
long-context-stuffed LLM vs SurfSense over the question's 5
|
||||
scoped pages — the closest comparison to a competing RAG product.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
"""CRAG — Comprehensive RAG Benchmark (Yang et al., Meta, KDD Cup 2024).
|
||||
|
||||
Source: https://github.com/facebookresearch/CRAG (Tasks 1, 2, and 3)
|
||||
Paper: https://arxiv.org/abs/2406.04744
|
||||
|
||||
This package registers two siblings:
|
||||
|
||||
* ``crag`` — Tasks 1 & 2: 5 candidate pages per question.
|
||||
* ``crag_t3`` — Task 3: 50 candidate pages per question. The
|
||||
long-context arm is capped to the top-5 (the realistic "naive
|
||||
RAG = pick top-K results" baseline); SurfSense retrieves over
|
||||
all 50, where its rerank becomes the entire contribution.
|
||||
|
||||
Both share the grader, prompt, runner, and report code; only the
|
||||
ingest path differs (single bz2 vs 4-part tar.bz2 streamed).
|
||||
|
||||
CRAG ships ~2,706 factual QA pairs, each paired with **5 full HTML
|
||||
pages** retrieved as the top-5 of a real web search at ``query_time``
|
||||
(50 in Task 3).
|
||||
The benchmark spans 5 domains (finance, music, movie, sports, open)
|
||||
and 8 question types (simple, comparison, aggregation, set, multi-hop,
|
||||
post-processing, false_premise, simple_w_condition) — heads/torsos/
|
||||
tails of entity popularity — and an explicit static→real-time
|
||||
freshness axis.
|
||||
|
||||
Why CRAG demonstrates SurfSense more clearly than FRAMES
|
||||
--------------------------------------------------------
|
||||
FRAMES tested SurfSense vs. *no retrieval at all* — a fair "naive
|
||||
prompting" baseline (the published 40.8% number) but not a competing
|
||||
RAG product. CRAG enables a three-way comparison:
|
||||
|
||||
* ``bare_llm`` — chat completion with the question only. CRAG
|
||||
paper: ≤34% accuracy ("LLM cold").
|
||||
* ``long_context`` — stuff all 5 extracted page texts straight into
|
||||
the prompt (the "naive RAG" / "straightforward RAG" arm in the
|
||||
paper). Published baseline: ~44%.
|
||||
* ``surfsense`` — POST ``/api/v1/new_chat`` with retrieval scoped
|
||||
to the question's 5 ingested pages (``mentioned_document_ids``).
|
||||
|
||||
So the headline becomes "SurfSense vs. context-stuffed long-context
|
||||
LLM, both fed the same 5 pages" — a head-to-head against the simplest
|
||||
realistic RAG strategy, not against an unarmed model.
|
||||
|
||||
Scoring follows the CRAG paper: each prediction is graded as
|
||||
**correct** (+1), **missing/I-don't-know** (0), or **incorrect** (-1),
|
||||
and the headline metric is the *Truthfulness Score*:
|
||||
``(#correct - #incorrect) / total`` — penalising hallucinations
|
||||
relative to refusals.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import registry as _registry
|
||||
from .runner import CragBenchmark, CragTask3Benchmark
|
||||
|
||||
# Import-time side effect: importing this package registers both CRAG
# variants (Tasks 1 & 2, and Task 3) with the shared benchmark registry.
_registry.register(CragBenchmark())
_registry.register(CragTask3Benchmark())
|
||||
|
|
@ -0,0 +1,335 @@
|
|||
"""CRAG dataset loader — download ``crag_task_1_and_2_dev_v4.jsonl.bz2`` and parse.
|
||||
|
||||
The CRAG repo (``facebookresearch/CRAG``) ships Tasks 1 & 2 as a
|
||||
single bzip2-compressed JSONL on GitHub raw. Each row carries:
|
||||
|
||||
* ``interaction_id`` — opaque per-question id (we keep verbatim)
|
||||
* ``query_time`` — wall clock of the original web search
|
||||
* ``domain`` — finance | music | movie | sports | open
|
||||
* ``question_type`` — simple | comparison | aggregation | set |
|
||||
multi-hop | post-processing | false_premise |
|
||||
simple_w_condition
|
||||
* ``static_or_dynamic`` — static | slow-changing | fast-changing | real-time
|
||||
* ``query`` — the question
|
||||
* ``answer`` — gold short answer
|
||||
* ``alt_ans`` — list[str] of alternative valid answers
|
||||
(paraphrases / synonyms / unit variants)
|
||||
* ``split`` — 0 = validation, 1 = public test
|
||||
* ``popularity`` — head | torso | tail (KG questions); empty for web
|
||||
* ``search_results`` — list of up to 5 ``{page_name, page_url,
|
||||
page_snippet, page_result, page_last_modified}``;
|
||||
``page_result`` is full HTML.
|
||||
|
||||
We materialise this into ``CragQuestion`` objects keeping ``pages`` as
|
||||
a list of ``CragPage`` so downstream ingest can save each as its own
|
||||
file and SurfSense can dedupe on filename.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bz2
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import urllib.request
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)


# Tasks 1 & 2 share one bz2-compressed JSONL on the public CRAG repo; the
# download URL is the repo's raw ``data/`` directory plus that file name.
CRAG_TASK_1_2_FILENAME = "crag_task_1_and_2_dev_v4.jsonl.bz2"
CRAG_TASK_1_2_URL = (
    "https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/"
    + CRAG_TASK_1_2_FILENAME
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question / page dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class CragPage:
    """A single pre-retrieved web page attached to a CRAG question."""

    page_name: str
    page_url: str
    page_snippet: str
    # Page body; ``_parse_pages`` fills this from the row's ``page_result``.
    page_html: str
    page_last_modified: str | None = None

    @property
    def url_hash(self) -> str:
        """Return the first 12 hex characters of the SHA-1 of ``page_url``.

        Raw URLs contain slashes, query strings and unicode, so they are
        unsuitable as file names; this digest is a filename-safe key.
        Twelve hex characters are 48 bits, which is ample for a corpus of
        a few thousand pages.
        """
        digest = hashlib.sha1(self.page_url.encode("utf-8"))
        return digest.hexdigest()[:12]
|
||||
|
||||
|
||||
@dataclass
class CragQuestion:
    """A parsed CRAG row: the question, its gold answers and its pages."""

    qid: str  # synthesised "C00000".."C02705"
    interaction_id: str
    query_time: str
    query: str
    gold_answer: str
    alt_answers: list[str]
    domain: str
    question_type: str
    static_or_dynamic: str
    popularity: str  # may be "" for web-sourced questions
    split: int  # 0=validation, 1=public_test
    raw_index: int  # row index in the source JSONL
    pages: list[CragPage] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a lightweight dict view of the question.

        Page HTML is left out: the pages are summarised as ``n_pages`` and
        ``page_urls`` only. ``alt_answers`` is copied so the result does
        not alias this instance's list.
        """
        record: dict[str, Any] = {}
        for name in ("qid", "interaction_id", "query_time", "query", "gold_answer"):
            record[name] = getattr(self, name)
        record["alt_answers"] = list(self.alt_answers)
        for name in (
            "domain",
            "question_type",
            "static_or_dynamic",
            "popularity",
            "split",
            "raw_index",
        ):
            record[name] = getattr(self, name)
        record["n_pages"] = len(self.pages)
        record["page_urls"] = [page.page_url for page in self.pages]
        return record
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Download + decompress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def download_task_1_2(cache_dir: Path) -> Path:
    """Fetch the Tasks 1 & 2 archive into ``cache_dir`` unless already cached.

    The download is synchronous (stdlib ``urllib``). Data is streamed into
    a ``.part`` sibling and renamed onto the final name only after the
    transfer finishes, so a file at the final path is never a partial
    download. Calling again once a non-empty file exists is a no-op.

    Args:
        cache_dir: Directory that holds (or will hold) the archive; created
            if missing.

    Returns:
        Path to the local ``.jsonl.bz2`` file.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    dest = cache_dir / CRAG_TASK_1_2_FILENAME
    already_cached = dest.exists() and dest.stat().st_size > 0
    if already_cached:
        logger.debug("CRAG bz2 already cached at %s", dest)
        return dest

    logger.info("Downloading CRAG (Tasks 1 & 2) from %s ...", CRAG_TASK_1_2_URL)
    partial = dest.with_suffix(dest.suffix + ".part")
    request = urllib.request.Request(
        CRAG_TASK_1_2_URL,
        headers={"User-Agent": "SurfSense-Evals/0.1 (CRAG dataset fetch)"},
    )
    one_mib = 1 << 20
    with urllib.request.urlopen(request, timeout=600) as response, partial.open("wb") as sink:
        # Copy in 1 MiB chunks until the response reports end-of-stream.
        for chunk in iter(lambda: response.read(one_mib), b""):
            sink.write(chunk)
    partial.replace(dest)
    size_mib = dest.stat().st_size / 1024 / 1024
    logger.info("CRAG bz2 downloaded: %s (%.1f MiB)", dest, size_mib)
    return dest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parse
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_pages(raw_search_results: Any) -> list[CragPage]:
    """Convert a row's raw ``search_results`` value into ``CragPage`` objects.

    Anything that is not a list yields ``[]``. Entries that are not dicts,
    or that have no URL or only blank HTML, are dropped because they are
    useless for retrieval.
    """
    if not isinstance(raw_search_results, list):
        return []

    parsed: list[CragPage] = []
    for result in raw_search_results:
        if not isinstance(result, dict):
            continue
        page_url = str(result.get("page_url") or "").strip()
        page_html = str(result.get("page_result") or "")
        if not (page_url and page_html.strip()):
            # No URL or empty HTML => useless for retrieval.
            continue
        modified = result.get("page_last_modified")
        parsed.append(
            CragPage(
                page_name=str(result.get("page_name") or "").strip(),
                page_url=page_url,
                page_snippet=str(result.get("page_snippet") or "").strip(),
                page_html=page_html,
                page_last_modified=str(modified).strip() if modified else None,
            )
        )
    return parsed
|
||||
|
||||
|
||||
def _parse_alt_answers(raw: Any) -> list[str]:
    """Normalise the raw ``alt_ans`` field into a list of non-empty strings.

    Args:
        raw: Either a list of alternative answers, a single string, or
            anything else (treated as "no alternatives").

    Returns:
        Stripped, non-empty answer strings in their original order.
        Non-string list items (e.g. numbers) are stringified; ``None``
        items are dropped.
    """
    if isinstance(raw, str):
        candidates = [raw]
    elif isinstance(raw, list):
        # Skip JSON nulls explicitly: ``str(None)`` would otherwise produce
        # the bogus alternative answer "None".
        candidates = [str(item) for item in raw if item is not None]
    else:
        return []
    stripped = (candidate.strip() for candidate in candidates)
    return [text for text in stripped if text]
|
||||
|
||||
|
||||
def iter_questions(jsonl_bz2_path: Path) -> list[CragQuestion]:
    """Decompress and parse the CRAG JSONL into ``CragQuestion`` objects.

    The archive is decompressed line by line via ``bz2.open(..., "rt")``,
    so the decompressed file is never materialised as a whole. Each line
    is one JSON object, potentially very large because it embeds full
    page HTML.

    NOTE(review): despite the name, this returns a fully built list whose
    questions keep their pages' HTML. Presumably callers sample and ingest
    right away; confirm the memory budget before parsing the full file.

    Rows are skipped (not raised on) when the line is blank, is not valid
    JSON, is valid JSON but not an object, or lacks a query or an answer.

    Args:
        jsonl_bz2_path: Path to the bz2-compressed JSONL.

    Returns:
        Parsed questions in file order. ``qid`` and ``raw_index`` come
        from the line's position in the file, so skipped rows leave gaps.
    """
    out: list[CragQuestion] = []
    with bz2.open(jsonl_bz2_path, mode="rt", encoding="utf-8") as fh:
        for raw_idx, line in enumerate(fh):
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as exc:
                logger.warning("Skipping malformed CRAG row %d: %s", raw_idx, exc)
                continue
            if not isinstance(row, dict):
                # Valid JSON that is not an object (list, number, ...) has
                # no fields to read; treat it like any other malformed row.
                logger.warning(
                    "Skipping non-object CRAG row %d (%s)",
                    raw_idx, type(row).__name__,
                )
                continue
            query = str(row.get("query") or "").strip()
            answer = str(row.get("answer") or "").strip()
            if not query or not answer:
                logger.debug("Skipping CRAG row %d with missing query/answer", raw_idx)
                continue
            interaction_id = str(row.get("interaction_id") or "").strip()
            pages = _parse_pages(row.get("search_results"))
            out.append(CragQuestion(
                qid=f"C{raw_idx:05d}",
                interaction_id=interaction_id,
                query_time=str(row.get("query_time") or "").strip(),
                query=query,
                gold_answer=answer,
                alt_answers=_parse_alt_answers(row.get("alt_ans")),
                domain=str(row.get("domain") or "").strip().lower(),
                question_type=str(row.get("question_type") or "").strip().lower(),
                static_or_dynamic=str(row.get("static_or_dynamic") or "").strip().lower(),
                popularity=str(row.get("popularity") or "").strip().lower(),
                split=int(row.get("split") or 0),
                raw_index=raw_idx,
                pages=pages,
            ))
    return out
|
||||
|
||||
|
||||
def stratified_sample(
    questions: list[CragQuestion],
    *,
    n: int,
    seed: int = 17,
) -> list[CragQuestion]:
    """Pick ``n`` questions spread across the domain × question-type mix.

    Questions are bucketed on ``(domain, question_type)``, each bucket is
    shuffled with a ``random.Random(seed)`` generator, and the buckets are
    then visited round-robin in sorted key order, taking one question per
    visit until ``n`` are collected. The seeded shuffle keeps the sample
    identical across re-runs.

    Args:
        questions: Full question list; not modified.
        n: Sample size. ``n <= 0`` or ``n >= len(questions)`` returns a
            copy of the whole list.
        seed: Seed for the per-bucket shuffle.

    Returns:
        The sampled questions ordered by ``raw_index``.
    """
    total = len(questions)
    if not 0 < n < total:
        return list(questions)
    import random

    rng = random.Random(seed)
    pools: dict[tuple[str, str], list[CragQuestion]] = {}
    for question in questions:
        stratum = (question.domain, question.question_type)
        if stratum not in pools:
            pools[stratum] = []
        pools[stratum].append(question)
    # Shuffle in bucket-creation order so the RNG stream is reproducible.
    for pool in pools.values():
        rng.shuffle(pool)

    ordered_strata = sorted(pools)
    picked: list[CragQuestion] = []
    while len(picked) < n:
        took_any = False
        for stratum in ordered_strata:
            if len(picked) == n:
                break
            pool = pools[stratum]
            if pool:
                picked.append(pool.pop())
                took_any = True
        if not took_any:
            # Every bucket is exhausted; nothing more to take.
            break
    return sorted(picked, key=lambda question: question.raw_index)
|
||||
|
||||
|
||||
def write_questions_jsonl(questions: list[CragQuestion], dest: Path) -> None:
    """Write one JSON line per question (its ``to_dict()`` view) to ``dest``.

    Parent directories are created as needed and an existing file is
    overwritten.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    with dest.open("w", encoding="utf-8") as fh:
        fh.writelines(f"{json.dumps(question.to_dict())}\n" for question in questions)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reading the lightweight questions.jsonl back
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_questions_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a questions JSONL file back into a list of parsed rows.

    A missing file yields ``[]``. Blank lines and lines that are not valid
    JSON are skipped silently.
    """
    if not path.exists():
        return []

    rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as fh:
        for text in map(str.strip, fh):
            if not text:
                continue
            try:
                rows.append(json.loads(text))
            except json.JSONDecodeError:
                continue
    return rows
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience: decompress a snippet to memory for tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def decompress_to_memory(jsonl_bz2_path: Path) -> io.StringIO:
    """Decompress the whole bz2 file and return it as an ``io.StringIO``.

    Meant for tests and one-off scripts: the entire decompressed payload
    is held in memory, so only use it on small fixtures. Use
    ``iter_questions`` in production.
    """
    with bz2.open(jsonl_bz2_path, mode="rb") as fh:
        raw = fh.read()
    text = raw.decode("utf-8")
    return io.StringIO(text)
|
||||
|
||||
|
||||
# Names exported by a star-import of this module. The underscore-prefixed
# parsers (``_parse_pages``, ``_parse_alt_answers``) are intentionally absent.
__all__ = [
    "CRAG_TASK_1_2_FILENAME",
    "CRAG_TASK_1_2_URL",
    "CragPage",
    "CragQuestion",
    "decompress_to_memory",
    "download_task_1_2",
    "iter_questions",
    "load_questions_jsonl",
    "stratified_sample",
    "write_questions_jsonl",
]
|
||||
|
|
@ -0,0 +1,263 @@
|
|||
"""CRAG Task 3 dataset loader — 4-part tar.bz2 → streaming JSONL.
|
||||
|
||||
Task 3 ships ~7 GB of compressed data split into 4 parts on GitHub:
|
||||
|
||||
crag_task_3_dev_v4.tar.bz2.part1 (≈2 GB)
|
||||
crag_task_3_dev_v4.tar.bz2.part2 (≈2 GB)
|
||||
crag_task_3_dev_v4.tar.bz2.part3 (≈2 GB)
|
||||
crag_task_3_dev_v4.tar.bz2.part4 (≈1.3 GB)
|
||||
|
||||
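Concatenated, they form a bz2-compressed tar archive holding the JSONL data
(``iter_questions_task3`` reads every ``.jsonl`` member it finds, so one
file or several shards both work).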
|
||||
Decompressed, that JSONL is on the order of 30-50 GB because each row
|
||||
embeds 50 full HTML pages (vs 5 in Tasks 1 & 2).
|
||||
|
||||
Materialising the JSONL would blow the disk budget (we have ~50 GB
|
||||
free at the time of writing), so we stream the whole thing instead:
|
||||
|
||||
1. Download parts (idempotent; ``scripts/download_crag_task3.py``).
|
||||
2. Concat them into a virtual file via ``_MultiPartReader``.
|
||||
3. Wrap in ``bz2.BZ2File`` for on-the-fly decompression.
|
||||
4. Wrap in ``tarfile.open(fileobj=..., mode="r|")`` for streaming
|
||||
tar member iteration.
|
||||
5. For each JSONL member inside, ``tar.extractfile()`` returns a
   binary file-like; we iterate its lines and parse each row into a
   ``CragQuestion``.
|
||||
|
||||
Pass ``max_questions`` to stop as soon as enough samples are collected —
nothing past the consumed point is decompressed.
|
||||
|
||||
Schema is identical to Tasks 1 & 2 (see ``dataset.py``); only
|
||||
``search_results`` is bigger (50 entries instead of 5).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bz2
|
||||
import json
|
||||
import logging
|
||||
import tarfile
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import IO
|
||||
|
||||
from .dataset import (
|
||||
CragPage,
|
||||
CragQuestion,
|
||||
_parse_alt_answers,
|
||||
_parse_pages,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)


# The Task 3 archive is split into four numbered parts. Each download URL is
# the CRAG repo's raw ``data/`` directory followed by the part's file name.
CRAG_TASK_3_PART_NAMES: tuple[str, ...] = tuple(
    f"crag_task_3_dev_v4.tar.bz2.part{part}" for part in range(1, 5)
)
CRAG_TASK_3_PART_URLS: tuple[str, ...] = tuple(
    "https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/" + name
    for name in CRAG_TASK_3_PART_NAMES
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-part virtual file (concatenates N files transparently)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _MultiPartReader:
    """Expose several files, read back-to-back, as one forward-only stream.

    Only the slice of the file protocol that ``bz2.BZ2File`` needs is
    provided: ``read(n)``, ``readable()`` and ``close()``. There is no
    ``seek`` — the bz2 + tarfile streaming path only ever moves forward.
    Also usable as a context manager, which closes the reader on exit.
    """

    def __init__(self, paths: list[Path]) -> None:
        """Validate that every part exists and open the first one.

        Raises:
            ValueError: ``paths`` is empty.
            FileNotFoundError: A listed path does not exist.
        """
        if not paths:
            raise ValueError("_MultiPartReader needs at least one path")
        missing = next((path for path in paths if not path.exists()), None)
        if missing is not None:
            raise FileNotFoundError(missing)
        self._paths = list(paths)
        self._idx = 0
        self._fh: IO[bytes] | None = self._paths[0].open("rb")
        self._closed = False

    def read(self, n: int = -1) -> bytes:
        """Read up to ``n`` bytes, crossing part boundaries as needed.

        ``n`` of ``None`` or a negative value reads everything that is
        left. Returns ``b""`` once all parts are exhausted.

        Raises:
            ValueError: The reader has been closed.
        """
        if self._closed:
            raise ValueError("read of closed _MultiPartReader")

        pieces: list[bytes] = []
        if n is None or n < 0:
            # Drain the current part and every following one.
            while self._fh is not None:
                pieces.append(self._fh.read())
                self._advance()
            return b"".join(pieces)

        wanted = n
        while wanted > 0 and self._fh is not None:
            data = self._fh.read(wanted)
            if data:
                pieces.append(data)
                wanted -= len(data)
            else:
                # Current part is exhausted; carry on with the next one.
                self._advance()
        return b"".join(pieces)

    def _advance(self) -> None:
        """Close the current part and open the next, if there is one."""
        current = self._fh
        if current is not None:
            current.close()
            self._fh = None
        self._idx += 1
        if self._idx < len(self._paths):
            self._fh = self._paths[self._idx].open("rb")

    def readable(self) -> bool:
        """Return ``True`` until the reader is closed."""
        return not self._closed

    def close(self) -> None:
        """Close the currently open part (if any) and mark the reader closed."""
        current = self._fh
        if current is not None:
            current.close()
            self._fh = None
        self._closed = True

    def __enter__(self) -> _MultiPartReader:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:  # type: ignore[no-untyped-def]
        self.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stream the JSONL inside the tar.bz2
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _is_jsonl_member(name: str) -> bool:
    """Return ``True`` when a tar member name ends in ``.jsonl`` or ``.jsonl.txt``."""
    return name.endswith((".jsonl", ".jsonl.txt"))
|
||||
|
||||
|
||||
def iter_questions_task3(
    parts_dir: Path,
    *,
    max_questions: int | None = None,
) -> list[CragQuestion]:
    """Stream-parse Task 3 rows into ``CragQuestion`` objects.

    The part files under ``parts_dir`` are read back-to-back through
    ``_MultiPartReader``, decompressed on the fly with ``bz2.BZ2File`` and
    walked as a forward-only tar stream. Every regular member whose name
    looks like a JSONL file is parsed line by line, so an archive that
    shards its questions across several ``.jsonl`` members is handled
    transparently.

    Rows are skipped (not raised on) when the line is blank, is not valid
    JSON, is valid JSON but not an object, or lacks a query or an answer.

    Args:
        parts_dir: Directory holding the files named in
            ``CRAG_TASK_3_PART_NAMES``.
        max_questions: Return as soon as this many questions have been
            collected; nothing past that point is read or decompressed.
            ``None`` reads the whole archive. The limit is checked after a
            question is appended, so any non-``None`` value still yields
            one question when the archive contains a valid row.

    Returns:
        Parsed questions in archive order. ``qid`` / ``raw_index`` count
        every non-blank line seen so far (skipped ones included) across
        all JSONL members.

    Raises:
        FileNotFoundError: A part file is missing (from ``_MultiPartReader``).
        RuntimeError: The archive was read to the end without any JSONL
            member turning up.
    """
    parts = [parts_dir / name for name in CRAG_TASK_3_PART_NAMES]
    out: list[CragQuestion] = []
    raw_idx = 0
    found_jsonl = False
    # Context managers close tar -> bz2 -> part files in reverse order on
    # every exit path, including when a later layer fails to open.
    with _MultiPartReader(parts) as multi, bz2.BZ2File(multi, mode="rb") as bz:
        with tarfile.open(fileobj=bz, mode="r|") as tar:
            for member in tar:
                if not member.isfile() or not _is_jsonl_member(member.name):
                    continue
                found_jsonl = True
                logger.info(
                    "CRAG Task 3: streaming JSONL shard %s (size: %d bytes)",
                    member.name, member.size,
                )
                fh = tar.extractfile(member)
                if fh is None:
                    logger.warning("tar.extractfile returned None for %s; skipping", member.name)
                    continue
                with fh:
                    for raw_line in fh:
                        line = raw_line.decode("utf-8", errors="replace").strip()
                        if not line:
                            continue
                        try:
                            row = json.loads(line)
                        except json.JSONDecodeError as exc:
                            logger.warning(
                                "Skipping malformed CRAG Task 3 row %d in %s: %s",
                                raw_idx, member.name, exc,
                            )
                            raw_idx += 1
                            continue
                        if not isinstance(row, dict):
                            # Valid JSON but not an object: there are no
                            # fields to read, so treat it as malformed.
                            logger.warning(
                                "Skipping non-object CRAG Task 3 row %d in %s (%s)",
                                raw_idx, member.name, type(row).__name__,
                            )
                            raw_idx += 1
                            continue
                        query = str(row.get("query") or "").strip()
                        answer = str(row.get("answer") or "").strip()
                        if not query or not answer:
                            raw_idx += 1
                            continue
                        out.append(CragQuestion(
                            qid=f"T3_{raw_idx:05d}",
                            interaction_id=str(row.get("interaction_id") or "").strip(),
                            query_time=str(row.get("query_time") or "").strip(),
                            query=query,
                            gold_answer=answer,
                            alt_answers=_parse_alt_answers(row.get("alt_ans")),
                            domain=str(row.get("domain") or "").strip().lower(),
                            question_type=str(row.get("question_type") or "").strip().lower(),
                            static_or_dynamic=str(row.get("static_or_dynamic") or "").strip().lower(),
                            popularity=str(row.get("popularity") or "").strip().lower(),
                            split=int(row.get("split") or 0),
                            raw_index=raw_idx,
                            pages=_parse_pages(row.get("search_results")),
                        ))
                        raw_idx += 1
                        if max_questions is not None and len(out) >= max_questions:
                            # Early exit: the ``with`` blocks unwind here, so
                            # the rest of the archive is never decompressed.
                            return out
    if not found_jsonl:
        raise RuntimeError(
            "No JSONL member found inside Task 3 tar.bz2 archive; "
            "schema may have changed upstream."
        )
    return out
|
||||
|
||||
|
||||
def parts_present(parts_dir: Path) -> bool:
    """Return ``True`` iff every Task 3 part exists in ``parts_dir`` and is non-empty."""
    candidates = (parts_dir / name for name in CRAG_TASK_3_PART_NAMES)
    return all(part.exists() and part.stat().st_size != 0 for part in candidates)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Re-exports for convenience
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Names exported by a star-import of this module. ``CragPage`` and
# ``CragQuestion`` are imported from ``.dataset`` and re-exported here.
__all__ = [
    "CRAG_TASK_3_PART_NAMES",
    "CRAG_TASK_3_PART_URLS",
    "CragPage",
    "CragQuestion",
    "iter_questions_task3",
    "parts_present",
]
|
||||
|
|
@ -0,0 +1,540 @@
|
|||
"""CRAG 3-class grader: ``correct`` (+1) / ``missing`` (0) / ``incorrect`` (-1).
|
||||
|
||||
The CRAG paper's headline metric is the **Truthfulness Score**:
|
||||
|
||||
score = (#correct - #incorrect) / total
|
||||
|
||||
which rewards calibrated abstention — refusing to answer is neutral
|
||||
(0), guessing wrong is negative (-1). Grading is therefore a 3-class
|
||||
problem rather than the 2-class accuracy used for FRAMES.
|
||||
|
||||
Pipeline per (pred, gold, alt_ans, question_type):
|
||||
|
||||
1. Detect refusal first (``Answer: I don't know`` / "I don't know" /
|
||||
"no information") → ``missing`` (deterministic, never billed).
|
||||
2. ``false_premise`` questions: gold is canonically "the question
|
||||
contains a false premise" — reward any answer that flags the
|
||||
false premise (substring "false premise" / "incorrect premise" /
|
||||
"no such") as correct.
|
||||
3. Run the FRAMES-style deterministic shortcut (exact / numeric /
|
||||
substring) on ``pred`` against ``gold ∪ alt_ans``. Hit → correct.
|
||||
4. Fall through to the LLM judge (if configured), which returns one
|
||||
of ``{correct, missing, incorrect}`` — verbatim CRAG protocol.
|
||||
5. No judge configured → record ``incorrect`` (pessimistic but at
|
||||
least monotone with the deterministic grader).
|
||||
|
||||
The judge is throttled by an asyncio.Semaphore so it doesn't outrun
|
||||
the OpenRouter rate limit; the pre-judge deterministic pass keeps
|
||||
the bill bounded (most easy "Beyoncé"-vs-"Beyoncé Knowles" cases
|
||||
short-circuit before we burn judge tokens).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import string
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from ....core.providers.openrouter_chat import OpenRouterChatProvider
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# The three outcomes of CRAG's grading rubric; ``_grade_to_score`` maps them
# to +1 / 0 / -1 respectively.
GradeClass = Literal["correct", "missing", "incorrect"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class CragGradeResult:
    """Outcome of grading one prediction under CRAG's three-class rubric."""

    grade: GradeClass
    score: int  # +1 / 0 / -1
    # Which rule produced the grade: exact, numeric, substring, refusal,
    # false_premise_correct, false_premise_miss, llm_judge, lexical_miss, ...
    method: str
    # presumably the ``_normalise``-d prediction / gold text — TODO confirm
    normalised_pred: str = ""
    normalised_gold: str = ""
    judge_rationale: str = ""

    @property
    def correct(self) -> bool:
        """``True`` when the grade is ``"correct"``."""
        return self.grade == "correct"

    @property
    def missing(self) -> bool:
        """``True`` when the grade is ``"missing"``."""
        return self.grade == "missing"

    @property
    def incorrect(self) -> bool:
        """``True`` when the grade is ``"incorrect"``."""
        return self.grade == "incorrect"

    def to_dict(self) -> dict[str, Any]:
        """Return all six fields as a plain dict, in declaration order."""
        field_names = (
            "grade",
            "score",
            "method",
            "normalised_pred",
            "normalised_gold",
            "judge_rationale",
        )
        return {name: getattr(self, name) for name in field_names}
|
||||
|
||||
|
||||
def _grade_to_score(grade: GradeClass) -> int:
    """Map a grade label to its score: correct=+1, missing=0, incorrect=-1.

    Raises:
        KeyError: ``grade`` is not one of the three known labels.
    """
    score_by_grade = {"correct": 1, "missing": 0, "incorrect": -1}
    return score_by_grade[grade]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Normalisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Translation table turning every ASCII punctuation character into a space.
_PUNCT_TABLE = str.maketrans(dict.fromkeys(string.punctuation, " "))
# English articles as whole words.
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.IGNORECASE)
# Any run of whitespace.
_WS = re.compile(r"\s+")


def _normalise(s: str) -> str:
    """Canonicalise text for lexical comparison.

    Lower-cases the input, replaces punctuation with spaces, removes the
    articles ``a`` / ``an`` / ``the`` and collapses whitespace. A falsy
    input (``None`` or ``""``) normalises to ``""``.
    """
    lowered = (s or "").lower()
    depunctuated = lowered.translate(_PUNCT_TABLE)
    without_articles = _ARTICLES.sub(" ", depunctuated)
    return _WS.sub(" ", without_articles).strip()
|
||||
|
||||
|
||||
# Spelled-out numbers recognised when the text has no digits (zero..twenty).
_WORD_NUMBERS = {
    "zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
    "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11,
    "twelve": 12, "thirteen": 13, "fourteen": 14, "fifteen": 15, "sixteen": 16,
    "seventeen": 17, "eighteen": 18, "nineteen": 19, "twenty": 20,
}

# Two alternatives, tried in order at each position:
#   1. a properly grouped thousands number ("1,234,567", optional ".89");
#      the ``(?!\d)`` guard rejects malformed groupings such as "12,3456";
#   2. the legacy form: digits with at most one "." or "," separated tail.
# Without (1), "1,234,567" was truncated to "1,234" and "1,234.56" lost its
# decimals, so distinct large numbers could compare equal.
_NUMERIC_RE = re.compile(
    r"-?\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?"
    r"|-?\d+(?:[.,]\d+)?"
)


def _maybe_number(s: str) -> float | None:
    """Extract a single numeric value from free text, if there is one.

    The first digit sequence matched by ``_NUMERIC_RE`` wins; commas are
    treated as grouping characters and removed before conversion. When no
    digits are present, the first spelled-out number found in the
    normalised text (``zero`` .. ``twenty``) is used instead.

    Args:
        s: Raw text; ``None`` and blank strings are accepted.

    Returns:
        The number as a ``float``, or ``None`` when nothing numeric is found.
    """
    raw = (s or "").strip().lower()
    if not raw:
        return None
    match = _NUMERIC_RE.search(raw)
    if match:
        try:
            return float(match.group(0).replace(",", ""))
        except ValueError:
            # Fall through to the spelled-out lookup below.
            pass
    for tok in _normalise(s).split():
        if tok in _WORD_NUMBERS:
            return float(_WORD_NUMBERS[tok])
    return None
|
||||
|
||||
|
||||
def _whole_word_substring(haystack: str, needle: str) -> bool:
    """Return ``True`` when ``needle`` occurs in ``haystack`` on word boundaries.

    Both strings are padded with a single space so a match must start and
    end at a space or at either end of ``haystack``. An empty ``needle``
    never matches.
    """
    if needle:
        padded_needle = f" {needle} "
        padded_haystack = f" {haystack} "
        return padded_needle in padded_haystack
    return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Refusal detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Phrases that mark an "I don't know"-style answer. The apostrophe class
# accepts both the ASCII "'" and the typographic U+2019, which models
# frequently emit ("I don’t know").
_REFUSAL_PATTERNS = [
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"\bi\s+don['\u2019]?t\s+know\b",
        r"\bi\s+do\s+not\s+know\b",
        r"\bnot\s+enough\s+information\b",
        r"\binsufficient\s+information\b",
        r"\bcannot\s+(?:be\s+)?(?:answered|determined)\b",
        r"\bunable\s+to\s+(?:answer|determine)\b",
        r"\bno\s+(?:information|data|evidence)\b",
    )
]


def _is_refusal(pred: str) -> bool:
    """Cheap deterministic check for "I don't know"-shaped responses.

    Args:
        pred: The model's answer text; ``None`` is accepted.

    Returns:
        ``True`` when the answer is empty or blank (a de facto refusal) or
        contains any phrase from ``_REFUSAL_PATTERNS``.
    """
    if not pred or not pred.strip():
        return True  # empty answer is a de facto refusal
    return any(pattern.search(pred) for pattern in _REFUSAL_PATTERNS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# False-premise handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_FALSE_PREMISE_PATTERNS = [
|
||||
re.compile(r"false\s+premise", re.IGNORECASE),
|
||||
re.compile(r"incorrect\s+premise", re.IGNORECASE),
|
||||
re.compile(r"premise\s+(?:is|of)\s+the\s+question", re.IGNORECASE),
|
||||
re.compile(r"\bno\s+such\b", re.IGNORECASE),
|
||||
re.compile(r"never\s+(?:happened|occurred|existed)", re.IGNORECASE),
|
||||
re.compile(r"\bdid\s+not\s+(?:happen|occur|exist)\b", re.IGNORECASE),
|
||||
re.compile(r"\bdoes\s+not\s+exist\b", re.IGNORECASE),
|
||||
re.compile(r"is\s+not\s+(?:true|correct|accurate)", re.IGNORECASE),
|
||||
re.compile(r"\bisn'?t\s+(?:true|correct|accurate)\b", re.IGNORECASE),
|
||||
re.compile(r"\binvalid\s+(?:premise|question|assumption)\b", re.IGNORECASE),
|
||||
]
|
||||
|
||||
|
||||
def _flags_false_premise(pred: str) -> bool:
|
||||
return any(p.search(pred) for p in _FALSE_PREMISE_PATTERNS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deterministic grader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def grade_deterministic(
    *,
    pred: str,
    gold: str,
    alt_answers: Sequence[str] = (),
    question_type: str = "",
) -> CragGradeResult:
    """Try to grade without the LLM judge. Returns a final result.

    Always returns *some* result — the caller checks ``method`` to
    decide whether the LLM judge should overturn it. ``lexical_miss``
    and ``false_premise_unclear`` are the two methods that trigger the
    judge fallback.

    Check order: refusal -> empty gold -> false-premise handling ->
    per-candidate chain (exact, numeric within 1%, whole-word substring,
    reverse substring) over ``gold`` followed by ``alt_answers``.
    """

    qtype = (question_type or "").lower()
    n_pred = _normalise(pred)
    n_gold = _normalise(gold)

    def _verdict(grade: str, score: int, method: str, matched_gold: str = n_gold) -> CragGradeResult:
        # Single construction point so every branch reports the same fields.
        return CragGradeResult(
            grade=grade,
            score=score,
            method=method,
            normalised_pred=n_pred,
            normalised_gold=matched_gold,
        )

    if _is_refusal(pred):
        # CRAG protocol: refusal is *missing* (0), even on false-premise
        # questions where one might argue refusal == correct. We
        # follow the paper's grading literally.
        return _verdict("missing", 0, "refusal")

    # Empty-gold guard (shouldn't happen, but defensively):
    if not n_gold:
        return _verdict("incorrect", -1, "empty_gold")

    # False-premise questions: gold is typically "the question contains
    # a false premise" / "no such X" / similar. Any answer that
    # explicitly flags the false premise is correct.
    if qtype == "false_premise":
        if _flags_false_premise(pred):
            return _verdict("correct", 1, "false_premise_flagged")
        # If the model commits to *any* concrete answer on a false-
        # premise question without flagging the premise, it is wrong.
        # But we don't classify ourselves — let the judge decide on
        # the off chance the gold itself is e.g. "no" and the pred
        # is "no" without explicit "false premise" wording.
        return _verdict("incorrect", -1, "false_premise_unclear")

    # All non-false-premise questions: try the standard chain against
    # gold and each alt answer. First match wins.
    p_num = _maybe_number(pred)  # depends only on pred, so compute once
    for candidate in (gold, *(alt_answers or ())):
        if not candidate:
            continue
        # Alt answers may arrive as non-strings (e.g. JSON numbers); the
        # helpers below call str methods, so coerce once up front.
        cand_text = str(candidate)
        if not cand_text.strip():
            continue
        cand_norm = _normalise(cand_text)
        if not cand_norm:
            continue
        if n_pred == cand_norm:
            return _verdict("correct", 1, "exact", cand_norm)
        c_num = _maybe_number(cand_text)
        if p_num is not None and c_num is not None:
            # Pure 1% relative tolerance for CRAG (currency, counts,
            # ratios). Unlike FRAMES (which uses a 0.5 absolute floor
            # for year-shaped answers), CRAG's numeric questions are
            # often small-value (stock prices, percentages) where a
            # 0.5 floor would let "$2.05" match "$2.17". The judge is
            # the safety net for borderline rounding cases.
            tol = abs(c_num) * 0.01
            if abs(p_num - c_num) <= tol:
                return _verdict("correct", 1, "numeric", cand_norm)
            # Numeric question with different numbers — keep looking
            # at other candidates rather than declaring miss now;
            # alt answers may include word forms that pass.
        if _whole_word_substring(n_pred, cand_norm):
            return _verdict("correct", 1, "substring", cand_norm)
        if _whole_word_substring(cand_norm, n_pred) and len(n_pred) >= 3:
            return _verdict("correct", 1, "substring_reverse", cand_norm)

    return _verdict("incorrect", -1, "lexical_miss")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM-as-judge (3-class)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_JUDGE_SYSTEM = (
|
||||
"You are an impartial grader for short-answer factual questions, "
|
||||
"following the CRAG benchmark rubric. Given a question, the gold "
|
||||
"answer (and any alternative valid answers), and a model's "
|
||||
"prediction, classify the prediction into exactly one of three "
|
||||
"categories:\n\n"
|
||||
"* \"correct\" — the prediction expresses the same factual "
|
||||
"content as the gold answer (paraphrasing OK; numbers as words "
|
||||
"OK; partial-but-correct names OK; non-contradictory extra "
|
||||
"detail OK).\n"
|
||||
"* \"missing\" — the prediction explicitly refuses, says \"I "
|
||||
"don't know\", says there is insufficient information, or hedges "
|
||||
"without committing.\n"
|
||||
"* \"incorrect\" — the prediction commits to a fact that is "
|
||||
"different from the gold answer, or fails to flag a false "
|
||||
"premise when the question contains one.\n\n"
|
||||
"Special case: if the question contains a false premise and the "
|
||||
"gold answer says so, then a prediction that flags the false "
|
||||
"premise is \"correct\".\n\n"
|
||||
"Respond with ONLY a JSON object on a single line:\n"
|
||||
'{\"grade\": \"correct\"|\"missing\"|\"incorrect\", \"rationale\": \"<one short sentence>\"}'
|
||||
)
|
||||
|
||||
|
||||
_JUDGE_TEMPLATE = """\
|
||||
Question: {question}
|
||||
Question type: {question_type}
|
||||
Gold answer: {gold}
|
||||
{alt_block}Model prediction: {pred}
|
||||
|
||||
Decide whether the prediction is correct, missing, or incorrect.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
class CragJudgeConfig:
    """Connection and generation settings for the CRAG LLM judge.

    ``__repr__`` is written by hand so the API key never appears in logs
    or tracebacks; ``@dataclass`` leaves an explicitly defined
    ``__repr__`` in place.
    """

    api_key: str  # credential handed to the chat provider; masked in repr
    model: str = "anthropic/claude-sonnet-4.5"
    base_url: str = "https://openrouter.ai/api/v1"
    max_tokens: int = 200  # completion budget per judge call
    concurrency: int = 4  # upper bound on simultaneous judge calls

    def __repr__(self) -> str:
        masked = "'***'" if self.api_key else repr(self.api_key)
        return (
            f"{type(self).__name__}(api_key={masked}, model={self.model!r}, "
            f"base_url={self.base_url!r}, max_tokens={self.max_tokens!r}, "
            f"concurrency={self.concurrency!r})"
        )
|
||||
|
||||
|
||||
class CragLlmJudge:
    """Async LLM judge over OpenRouter chat completions, 3-class output."""

    def __init__(self, *, config: CragJudgeConfig) -> None:
        self._config = config
        self._provider = OpenRouterChatProvider(
            api_key=config.api_key,
            base_url=config.base_url,
            model=config.model,
        )
        # At least one slot, even if the configured concurrency is <= 0.
        self._sem = asyncio.Semaphore(max(1, config.concurrency))

    @property
    def model(self) -> str:
        """Model identifier taken from the judge configuration."""
        return self._config.model

    async def judge(
        self,
        *,
        question: str,
        gold: str,
        alt_answers: Sequence[str],
        pred: str,
        question_type: str = "",
    ) -> tuple[GradeClass, str]:
        """Return ``(grade, rationale)``. Errors return incorrect + reason.

        The provider call runs under the instance semaphore; any exception
        it raises is converted into an ``"incorrect"`` verdict whose
        rationale starts with ``judge_error:``.
        """

        # Blank alternatives are dropped; the block is omitted when none remain.
        bullet_lines = [f" - {a}" for a in (alt_answers or ()) if a]
        if bullet_lines:
            joined = "\n".join(bullet_lines)
            alt_block = f"Alternative valid answers:\n{joined}\n"
        else:
            alt_block = ""

        prompt = _JUDGE_TEMPLATE.format(
            question=question,
            question_type=question_type or "unknown",
            gold=gold,
            alt_block=alt_block,
            pred=pred,
        )
        try:
            async with self._sem:
                response = await self._provider.complete(
                    prompt=prompt,
                    system_prompt=_JUDGE_SYSTEM,
                    max_tokens=self._config.max_tokens,
                )
        except Exception as exc:  # noqa: BLE001
            return "incorrect", f"judge_error: {type(exc).__name__}: {exc}"
        return _parse_judge_response(response.text)
|
||||
|
||||
|
||||
def _parse_judge_response(text: str) -> tuple[GradeClass, str]:
|
||||
"""Parse the judge reply into a 3-class label + rationale."""
|
||||
|
||||
if not text or not text.strip():
|
||||
return "incorrect", "judge_returned_empty"
|
||||
match = re.search(r"\{[^{}]*\}", text, flags=re.DOTALL)
|
||||
candidate = match.group(0) if match else text
|
||||
try:
|
||||
data = json.loads(candidate)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
lowered = text.strip().lower()
|
||||
if "correct" in lowered and "incorrect" not in lowered:
|
||||
return "correct", "yes (parser_fallback)"
|
||||
if "missing" in lowered or "i don" in lowered:
|
||||
return "missing", "missing (parser_fallback)"
|
||||
return "incorrect", f"unparseable_judge_response: {text[:200]}"
|
||||
raw_grade = str(data.get("grade") or "").strip().lower()
|
||||
rationale = str(data.get("rationale", "")).strip()[:280]
|
||||
if raw_grade in {"correct", "missing", "incorrect"}:
|
||||
return raw_grade, rationale # type: ignore[return-value]
|
||||
return "incorrect", f"unknown_grade={raw_grade!r}; {rationale}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Combined grader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Methods that should *not* trigger the LLM judge — the deterministic
|
||||
# verdict is conclusive (refusal, exact match, numeric match, etc.).
|
||||
_TERMINAL_METHODS = frozenset({
|
||||
"refusal",
|
||||
"exact",
|
||||
"numeric",
|
||||
"substring",
|
||||
"substring_reverse",
|
||||
"false_premise_flagged",
|
||||
"empty_gold",
|
||||
})
|
||||
|
||||
|
||||
async def grade_with_judge(
    *,
    pred: str,
    gold: str,
    alt_answers: Sequence[str],
    question: str,
    question_type: str,
    judge: CragLlmJudge | None,
) -> CragGradeResult:
    """One row → deterministic shortcut → optional LLM judge fallback.

    The deterministic verdict is returned unchanged when its method is in
    ``_TERMINAL_METHODS`` or when no judge is supplied.  Otherwise the
    judge's grade replaces it, reported with ``method="llm_judge"`` and
    the judge's rationale attached.
    """

    deterministic = grade_deterministic(
        pred=pred,
        gold=gold,
        alt_answers=alt_answers,
        question_type=question_type,
    )
    is_conclusive = deterministic.method in _TERMINAL_METHODS
    if is_conclusive or judge is None:
        # Without a judge, ``lexical_miss`` / ``false_premise_unclear`` stand.
        return deterministic

    verdict, why = await judge.judge(
        question=question,
        gold=gold,
        alt_answers=alt_answers,
        pred=pred,
        question_type=question_type,
    )
    return CragGradeResult(
        grade=verdict,
        score=_grade_to_score(verdict),
        method="llm_judge",
        normalised_pred=deterministic.normalised_pred,
        normalised_gold=deterministic.normalised_gold,
        judge_rationale=why,
    )
|
||||
|
||||
|
||||
@dataclass
class CragGradeRow:
    """One row to grade. Mirrors the FRAMES grader's tuple but typed."""

    # Question identifier; carried on the row but not read by ``grade_many``.
    qid: str
    # Question text; forwarded to the LLM judge prompt.
    question: str
    # Reference answer the prediction is graded against.
    gold: str
    # Additional accepted answers, tried after ``gold``.
    alt_answers: list[str]
    # Model output being graded.
    pred: str
    # Question category; "false_premise" switches the deterministic grader
    # to its false-premise branch.
    question_type: str = ""
|
||||
|
||||
|
||||
async def grade_many(
    *,
    rows: Sequence[CragGradeRow],
    judge: CragLlmJudge | None,
) -> list[CragGradeResult]:
    """Grade every row concurrently. Judge enforces its own concurrency cap.

    Results come back in the same order as ``rows``; an empty input
    yields an empty list without touching the event loop machinery.
    """

    if not rows:
        return []

    def _grade(row: CragGradeRow):
        # Build (but do not await) the grading coroutine for one row.
        return grade_with_judge(
            pred=row.pred,
            gold=row.gold,
            alt_answers=row.alt_answers,
            question=row.question,
            question_type=row.question_type,
            judge=judge,
        )

    graded = await asyncio.gather(*map(_grade, rows))
    return list(graded)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CragGradeResult",
|
||||
"CragGradeRow",
|
||||
"CragJudgeConfig",
|
||||
"CragLlmJudge",
|
||||
"GradeClass",
|
||||
"grade_deterministic",
|
||||
"grade_many",
|
||||
"grade_with_judge",
|
||||
]
|
||||
|
|
@ -0,0 +1,206 @@
|
|||
"""HTML → markdown for CRAG pages, with boilerplate removal.
|
||||
|
||||
Each CRAG page is a *full* HTML document (nav, ads, recommended-for-
|
||||
you, footer, ...). Without removing that boilerplate, retrieval over
|
||||
the chunks would surface menu items and "subscribe to our newsletter"
|
||||
boxes instead of the actual page content. We use ``trafilatura``,
|
||||
which is purpose-built for main-content extraction (the same library
|
||||
Common Crawl downstream pipelines use). It outputs clean prose with
|
||||
section headers, lists, and tables preserved.
|
||||
|
||||
Extraction policy:
|
||||
1. ``trafilatura.extract`` with ``output_format="markdown"`` — main
|
||||
content only, headers preserved, tables kept.
|
||||
2. If extraction fails or returns < 200 chars (paywalled / JS-only
|
||||
page / extraction confused), fall back to a plain stdlib
|
||||
``HTMLParser`` that strips tags and collapses whitespace. Some
|
||||
text is better than no text — SurfSense's chunker handles noisy
|
||||
prose.
|
||||
|
||||
We *intentionally* keep the page name and URL as visible H1 / link
|
||||
metadata so the SurfSense chunker preserves doc identity at the top of
|
||||
the first chunk (mirrors what we do for FRAMES Wikipedia pages).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from html.parser import HTMLParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_MIN_TRAFILATURA_LENGTH = 200
|
||||
_MAX_OUTPUT_CHARS = 200_000 # cap to keep upload payloads sane
|
||||
|
||||
|
||||
@dataclass
class ExtractionResult:
    """Outcome of converting one HTML blob to plain markdown."""

    # Final markdown handed to ingest; "" when nothing could be extracted.
    text: str
    method: str  # "trafilatura" | "fallback_strip" | "empty"
    # Length recorded by the producer; ``ok`` treats 0 as "nothing extracted".
    n_chars: int

    @property
    def ok(self) -> bool:
        """Return True when the recorded character count is positive."""
        return self.n_chars > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trafilatura wrapper (lazy import so tests / small scripts don't pay)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _trafilatura_extract(html_text: str, *, url: str) -> str | None:
    """Run ``trafilatura.extract`` and return stripped markdown, or ``None``.

    ``None`` is returned when trafilatura is not importable, when the
    extraction call raises, or when it yields nothing.  The import is
    local so that merely importing this module does not load trafilatura.
    """

    try:
        import trafilatura
    except ImportError:  # pragma: no cover - dependency is required
        logger.warning("trafilatura not installed; falling back to strip-tags only")
        return None

    extract_options = {
        "url": url or None,
        "output_format": "markdown",
        "include_links": False,
        "include_images": False,
        "include_tables": True,
        "favor_recall": True,
    }
    try:
        extracted = trafilatura.extract(html_text, **extract_options)
    except Exception as exc:  # noqa: BLE001 - trafilatura raises a zoo
        logger.debug("trafilatura.extract crashed for %s: %s", url, exc)
        return None

    return extracted.strip() if extracted else None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stdlib fallback: strip HTML tags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StripHTMLParser(HTMLParser):
|
||||
"""Collect text content, treating block tags as paragraph breaks.
|
||||
|
||||
We deliberately drop ``<script>``, ``<style>``, ``<nav>``,
|
||||
``<header>``, ``<footer>``, and ``<aside>`` content — these are
|
||||
almost always boilerplate and they are the dominant source of
|
||||
noise SurfSense ends up retrieving against if not removed.
|
||||
"""
|
||||
|
||||
_SKIP_TAGS = frozenset({"script", "style", "nav", "header", "footer", "aside", "svg"})
|
||||
_BLOCK_TAGS = frozenset({
|
||||
"p", "div", "section", "article", "li", "ul", "ol",
|
||||
"h1", "h2", "h3", "h4", "h5", "h6", "br", "tr",
|
||||
"td", "th", "table", "blockquote", "pre",
|
||||
})
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(convert_charrefs=True)
|
||||
self._buffer: list[str] = []
|
||||
self._skip_depth: int = 0
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list) -> None: # noqa: ARG002
|
||||
if tag in self._SKIP_TAGS:
|
||||
self._skip_depth += 1
|
||||
if tag in self._BLOCK_TAGS:
|
||||
self._buffer.append("\n")
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if tag in self._SKIP_TAGS and self._skip_depth > 0:
|
||||
self._skip_depth -= 1
|
||||
if tag in self._BLOCK_TAGS:
|
||||
self._buffer.append("\n")
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
if self._skip_depth:
|
||||
return
|
||||
self._buffer.append(data)
|
||||
|
||||
def get_text(self) -> str:
|
||||
text = "".join(self._buffer)
|
||||
# Decode any leftover entities and collapse whitespace.
|
||||
text = html.unescape(text)
|
||||
text = re.sub(r"[ \t]+", " ", text)
|
||||
text = re.sub(r"\n[ \t]+", "\n", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _strip_tags(html_text: str) -> str:
    """Return the visible text of ``html_text`` with tags removed.

    Uses ``_StripHTMLParser``; if the parser raises, falls back to a
    regex that blanks out anything tag-shaped and collapses whitespace.
    """

    parser = _StripHTMLParser()
    try:
        parser.feed(html_text)
        # ``feed`` may hold back trailing text (e.g. input ending in a bare
        # "&") while waiting for more data; ``close`` flushes it.
        parser.close()
    except Exception as exc:  # noqa: BLE001 - HTMLParser is fragile on garbage input
        logger.debug("HTMLParser failed; using regex strip: %s", exc)
        no_tags = re.sub(r"<[^>]+>", " ", html_text)
        return re.sub(r"\s+", " ", html.unescape(no_tags)).strip()
    return parser.get_text()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extract_main_content(
    html_text: str,
    *,
    url: str = "",
    page_name: str = "",
    last_modified: str | None = None,
) -> ExtractionResult:
    """Convert one HTML blob into clean markdown for ingest.

    The returned ``text`` is prefixed with a small metadata header
    (``# {page_name}\\n\\nSource: {url}\\n``) so that:

    * SurfSense's chunker has a stable doc-identity anchor at the top
      of the first chunk (matches what we do for FRAMES Wikipedia).
    * The retrieval-augmented arm sees the URL inline, which the LLM
      can surface as a citation if the prompt asks for one.

    The heading falls back from ``page_name`` to ``url`` to
    ``"Untitled"``, skipping any value that is blank after stripping.
    An ``ExtractionResult`` with ``method="empty"`` and ``n_chars=0`` is
    returned when no body text could be produced.
    """

    body = ""
    method = "empty"
    if html_text and html_text.strip():
        body = _trafilatura_extract(html_text, url=url) or ""
        if body and len(body) >= _MIN_TRAFILATURA_LENGTH:
            method = "trafilatura"
        else:
            stripped = _strip_tags(html_text)
            # Short or empty trafilatura output: switch to the tag-stripped
            # text when trafilatura produced nothing, or when the stripped
            # text is more than 1.5x longer; otherwise keep the short
            # trafilatura output.
            if stripped and (not body or len(stripped) > len(body) * 1.5):
                body = stripped
                method = "fallback_strip"
            elif body:
                method = "trafilatura"

    body = body.strip()
    if len(body) > _MAX_OUTPUT_CHARS:
        body = body[:_MAX_OUTPUT_CHARS] + "\n\n[...truncated...]"

    if not body:
        return ExtractionResult(text="", method="empty", n_chars=0)

    # Strip each candidate *before* choosing, so a whitespace-only page
    # name falls through to the URL instead of producing an empty heading.
    title_line = (page_name or "").strip() or (url or "").strip() or "Untitled"
    header_lines = [f"# {title_line}"]
    if url:
        header_lines.append(f"Source: {url}")
    if last_modified:
        header_lines.append(f"Last modified: {last_modified}")
    final = "\n".join(header_lines) + "\n\n" + body + "\n"
    return ExtractionResult(text=final, method=method, n_chars=len(final))
|
||||
|
||||
|
||||
__all__ = ["ExtractionResult", "extract_main_content"]
|
||||
|
|
@ -0,0 +1,447 @@
|
|||
"""CRAG ingestion: download → extract → upload → per-question doc map.
|
||||
|
||||
Steps:
|
||||
|
||||
1. Download ``crag_task_1_and_2_dev_v4.jsonl.bz2`` from
|
||||
``facebookresearch/CRAG`` (skip if cached).
|
||||
2. Stream-parse into ``CragQuestion`` objects.
|
||||
3. Optionally cap to ``--n-questions N`` (and *stratified* sample
|
||||
across ``(domain, question_type)`` so the smoke / partial run
|
||||
isn't dominated by ``finance`` or ``simple``).
|
||||
4. For each question, extract the 5 web pages to clean markdown via
|
||||
``trafilatura`` and write them to
|
||||
``<bench_dir>/pages/<qid>__<page_idx>__<url_hash>.md``. The
|
||||
filename is unique across the whole sample (so SurfSense's
|
||||
``(filename, search_space)`` dedup never collides between
|
||||
questions) and round-trippable (the ``<qid>__`` prefix lets the
|
||||
ingest infer doc-membership at the title level even before we
|
||||
land on a stable status response).
|
||||
5. Upload all extracted pages to SurfSense in batches with text-only
|
||||
ETL (``use_vision_llm=False, processing_mode="basic"``) — these
|
||||
are extracted plaintext, no images involved.
|
||||
6. Persist a doc map at
|
||||
``<suite_data>/maps/crag_doc_map.jsonl`` with one row per question:
|
||||
|
||||
{"qid": "C00042",
|
||||
"interaction_id": "<uuid>",
|
||||
"question": "<text>",
|
||||
"gold_answer": "<text>",
|
||||
"alt_answers": [...],
|
||||
"domain": "...", "question_type": "...",
|
||||
"static_or_dynamic": "...", "popularity": "...",
|
||||
"query_time": "...",
|
||||
"page_filenames": ["C00042__00__abc123.md", ...],
|
||||
"document_ids": [42101, 42102, ...],
|
||||
"missing_pages": [...] # filenames whose upload failed
|
||||
}
|
||||
|
||||
The runner uses ``document_ids`` to scope SurfSense retrieval to
|
||||
exactly the 5 pages of the question (matches CRAG protocol — the
|
||||
benchmark explicitly hands over its own retrieved pages).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.clients.documents import (
|
||||
DocumentProcessingFailed,
|
||||
DocumentProcessingTimeout,
|
||||
)
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
from .dataset import (
|
||||
CragPage,
|
||||
CragQuestion,
|
||||
download_task_1_2,
|
||||
iter_questions,
|
||||
stratified_sample,
|
||||
write_questions_jsonl,
|
||||
)
|
||||
from .html_extract import extract_main_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_FILENAME_SAFE = re.compile(r"[^A-Za-z0-9._\-]+")
|
||||
|
||||
|
||||
def _page_filename(qid: str, page_idx: int, page: CragPage) -> str:
|
||||
"""Filesystem-safe, globally unique markdown filename for a CRAG page.
|
||||
|
||||
Format: ``<qid>__<idx>__<url_hash>.md``. Both the qid (``C00042``)
|
||||
and the URL-hash (``[:12]``) are alphanumeric so we don't need to
|
||||
sanitise them, but we strip anything else just in case.
|
||||
"""
|
||||
|
||||
qid_safe = _FILENAME_SAFE.sub("_", qid)
|
||||
return f"{qid_safe}__{page_idx:02d}__{page.url_hash}.md"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class _IngestStats:
    # Plain counters and paths describing one ingest run.  The code that
    # fills these in is not visible here; the notes below follow the field
    # names only -- TODO confirm against the producer.
    n_questions: int
    n_pages_total: int
    n_pages_extracted: int
    n_pages_empty: int
    n_uploaded: int
    n_existing: int
    bench_dir: Path  # presumably the benchmark data directory
    map_path: Path  # presumably where the doc map JSONL is written
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Page extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _materialise_pages(
    questions: list[CragQuestion],
    *,
    pages_dir: Path,
    overwrite: bool = False,
) -> tuple[dict[str, list[str]], dict[str, str]]:
    """Extract every page in every question to ``pages_dir`` as markdown.

    Returns:
      * ``qid -> [filename, filename, ...]`` (in page order, only
        successful extractions)
      * ``filename -> source_url`` for diagnostics

    A non-empty file already on disk is reused unless ``overwrite`` is
    set.  Empty extractions (paywall / JS / parse-fail with no fallback
    output) are skipped — better to retrieve from 4 pages than feed
    SurfSense's chunker an empty file.
    """

    pages_dir.mkdir(parents=True, exist_ok=True)
    qid_to_files: dict[str, list[str]] = {}
    file_to_url: dict[str, str] = {}
    method_counts: dict[str, int] = {}
    n_empty = 0

    def _tally(label: str) -> None:
        # Per-method counter that feeds the summary log line.
        method_counts[label] = method_counts.get(label, 0) + 1

    for question in questions:
        kept: list[str] = []
        for page_idx, page in enumerate(question.pages):
            filename = _page_filename(question.qid, page_idx, page)
            dest = pages_dir / filename
            reuse_cached = not overwrite and dest.exists() and dest.stat().st_size > 0
            if reuse_cached:
                _tally("cache_hit")
            else:
                result = extract_main_content(
                    page.page_html,
                    url=page.page_url,
                    page_name=page.page_name,
                    last_modified=page.page_last_modified,
                )
                _tally(result.method)
                if not result.ok:
                    n_empty += 1
                    continue
                dest.write_text(result.text, encoding="utf-8")
            kept.append(filename)
            file_to_url[filename] = page.page_url
        qid_to_files[question.qid] = kept

    logger.info(
        "CRAG page extraction: %s; empty=%d, total_files=%d across %d questions",
        method_counts, n_empty, len(file_to_url), len(qid_to_files),
    )
    return qid_to_files, file_to_url
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _upload_pages(
    ctx: RunContext,
    *,
    pages_dir: Path,
    filenames: list[str],
    batch_size: int,
    settings: IngestSettings,
) -> dict[str, int]:
    """Upload ``filenames`` (already on disk under ``pages_dir``) and return name → doc_id.

    Files are sent in batches of ``batch_size`` (clamped to at least 1).
    After each batch the function waits for newly created documents to
    finish processing — a processing failure or timeout is logged and
    does not abort the run — then reads back the status of new and
    duplicate documents to build the mapping.  Each document is
    registered under its title, the title without ``.md`` and the title
    with ``.md``, so lookups by stem or by filename both succeed.
    Filenames that no longer exist on disk are skipped.
    """

    if not filenames:
        return {}
    docs_client = ctx.documents_client()
    name_to_id: dict[str, int] = {}
    candidates = (pages_dir / fn for fn in filenames)
    paths = [path for path in candidates if path.exists()]
    # ``range`` rejects a zero step and yields nothing for a negative one;
    # clamp so a bad batch size cannot crash or silently skip every upload.
    step = max(1, batch_size)

    for batch_start in range(0, len(paths), step):
        batch = paths[batch_start : batch_start + step]
        result = await docs_client.upload(
            files=batch,
            search_space_id=ctx.search_space_id,
            should_summarize=settings.should_summarize,
            use_vision_llm=settings.use_vision_llm,
            processing_mode=settings.processing_mode,
        )
        all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
        if result.document_ids:
            try:
                await docs_client.wait_until_ready(
                    search_space_id=ctx.search_space_id,
                    document_ids=result.document_ids,
                    timeout_s=900.0,
                )
            except (DocumentProcessingFailed, DocumentProcessingTimeout) as exc:
                logger.warning("CRAG batch processing issue: %s", exc)
        if all_ids:
            statuses = await docs_client.get_status(
                search_space_id=ctx.search_space_id,
                document_ids=all_ids,
            )
            for s in statuses:
                has_md_suffix = s.title.endswith(".md")
                stem = Path(s.title).stem if has_md_suffix else s.title
                name_to_id[stem] = s.document_id
                name_to_id[s.title] = s.document_id
                if not has_md_suffix:
                    name_to_id[f"{s.title}.md"] = s.document_id
        logger.info(
            "CRAG upload batch %d-%d: %d new, %d duplicate",
            batch_start, batch_start + len(batch),
            len(result.document_ids), len(result.duplicate_document_ids),
        )
    return name_to_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Doc map writer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_question_doc_ids(
|
||||
questions: list[CragQuestion],
|
||||
qid_to_files: dict[str, list[str]],
|
||||
name_to_id: dict[str, int],
|
||||
) -> list[dict[str, Any]]:
|
||||
rows: list[dict[str, Any]] = []
|
||||
for q in questions:
|
||||
filenames = qid_to_files.get(q.qid, [])
|
||||
doc_ids: list[int] = []
|
||||
missing: list[str] = []
|
||||
for fn in filenames:
|
||||
stem = Path(fn).stem
|
||||
doc_id = name_to_id.get(stem) or name_to_id.get(fn)
|
||||
if doc_id is not None and doc_id not in doc_ids:
|
||||
doc_ids.append(doc_id)
|
||||
else:
|
||||
missing.append(fn)
|
||||
rows.append({
|
||||
"qid": q.qid,
|
||||
"interaction_id": q.interaction_id,
|
||||
"raw_index": q.raw_index,
|
||||
"question": q.query,
|
||||
"gold_answer": q.gold_answer,
|
||||
"alt_answers": list(q.alt_answers),
|
||||
"domain": q.domain,
|
||||
"question_type": q.question_type,
|
||||
"static_or_dynamic": q.static_or_dynamic,
|
||||
"popularity": q.popularity,
|
||||
"query_time": q.query_time,
|
||||
"split": q.split,
|
||||
"page_filenames": filenames,
|
||||
"document_ids": doc_ids,
|
||||
"missing_pages": missing,
|
||||
"n_pages": len(filenames),
|
||||
})
|
||||
return rows
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_ingest(
    ctx: RunContext,
    *,
    n_questions: int | None = None,
    upload_batch_size: int = 16,
    skip_upload: bool = False,
    overwrite_extract: bool = False,
    settings: IngestSettings | None = None,
    sample_seed: int = 17,
) -> None:
    """Ingest the CRAG benchmark (Tasks 1 & 2) into the research suite.

    Steps: download + parse the archive, optionally take a stratified
    sample, write ``questions.jsonl``, extract pages to markdown,
    optionally upload them, write the doc map and record its path in
    the suite state.

    Parameters
    ----------
    n_questions
        Cap on the number of CRAG questions to materialise.
        ``None`` = all 2,706 (~13,500 pages — large; smoke runs
        should pass 10-20 and full runs ~200). Zero or a negative
        value is treated like ``None`` (no sampling).
    upload_batch_size
        Markdown files per ``/documents/fileupload`` call.
    skip_upload
        Extract + cache markdown locally but don't push to SurfSense
        (useful for debugging the extraction step). The doc map is
        still written, with every page listed as missing.
    overwrite_extract
        Re-run trafilatura even when a cached markdown file exists.
        Default False so re-running ingest is idempotent.
    settings
        Override per-upload knobs. CRAG defaults to text-only basic
        ETL — these are *extracted* plaintext, no images.
    sample_seed
        RNG seed for ``stratified_sample``. Pin this for reproducibility.

    Raises
    ------
    RuntimeError
        If the downloaded archive yields no parseable questions.
    """

    settings = settings or IngestSettings(
        use_vision_llm=False,
        processing_mode="basic",
        should_summarize=False,
    )
    bench_dir = ctx.benchmark_data_dir()
    pages_dir = bench_dir / "pages"
    raw_cache = bench_dir / ".raw_cache"
    raw_cache.mkdir(parents=True, exist_ok=True)

    bz2_path = download_task_1_2(raw_cache)
    logger.info("CRAG: parsing %s ...", bz2_path.name)
    all_questions = iter_questions(bz2_path)
    if not all_questions:
        raise RuntimeError(
            "CRAG JSONL contained no parseable rows; upstream may have changed schema."
        )
    logger.info("CRAG: parsed %d total questions", len(all_questions))

    if n_questions is not None and n_questions > 0:
        questions = stratified_sample(all_questions, n=n_questions, seed=sample_seed)
        logger.info(
            "CRAG: stratified sample of %d questions across %d (domain, qtype) buckets",
            len(questions),
            len({(q.domain, q.question_type) for q in questions}),
        )
    else:
        questions = all_questions

    questions_jsonl = bench_dir / "questions.jsonl"
    write_questions_jsonl(questions, questions_jsonl)

    n_pages_total = sum(len(q.pages) for q in questions)
    logger.info(
        "CRAG: extracting up to %d pages across %d questions ...",
        n_pages_total, len(questions),
    )
    # The filename -> URL mapping is not needed here; only the
    # per-question file lists feed the upload and the doc map.
    qid_to_files, _file_to_url = _materialise_pages(
        questions, pages_dir=pages_dir, overwrite=overwrite_extract,
    )
    n_pages_extracted = sum(len(v) for v in qid_to_files.values())

    name_to_id: dict[str, int] = {}
    if skip_upload:
        logger.info("CRAG: --skip-upload; skipping SurfSense ingestion")
    else:
        # Pages can be shared between questions; upload each file once.
        all_filenames = sorted({fn for fns in qid_to_files.values() for fn in fns})
        logger.info("CRAG: uploading %d unique pages ...", len(all_filenames))
        name_to_id = await _upload_pages(
            ctx,
            pages_dir=pages_dir,
            filenames=all_filenames,
            batch_size=upload_batch_size,
            settings=settings,
        )

    doc_rows = _resolve_question_doc_ids(questions, qid_to_files, name_to_id)
    map_path = ctx.maps_dir() / "crag_doc_map.jsonl"
    with map_path.open("w", encoding="utf-8") as fh:
        # First line records the ingest settings; one JSON row per
        # question follows.
        fh.write(settings_header_line(settings) + "\n")
        for row in doc_rows:
            fh.write(json.dumps(row) + "\n")
    logger.info("Wrote CRAG doc map to %s (%d rows)", map_path, len(doc_rows))

    new_state = ctx.suite_state
    new_state.ingestion_maps["crag"] = str(map_path)
    set_suite_state(ctx.config, ctx.suite, new_state)

    stats = _IngestStats(
        n_questions=len(questions),
        n_pages_total=n_pages_total,
        n_pages_extracted=n_pages_extracted,
        n_pages_empty=n_pages_total - n_pages_extracted,
        # ``name_to_id`` may hold two keys per document (the title with
        # and without ``.md``), so count distinct ids, not keys.
        n_uploaded=len(set(name_to_id.values())),
        n_existing=0,
        bench_dir=bench_dir,
        map_path=map_path,
    )
    logger.info("CRAG ingest done: %s", stats)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# For runner: read extracted page text back from disk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_page_markdown(bench_dir: Path, filename: str) -> str | None:
    """Load the cached markdown for one extracted page, if it is readable.

    The long-context runner arm calls this at inference time to build
    its prompt, so page bodies never have to be held in memory between
    ingest and run.

    Returns the UTF-8 decoded contents of ``<bench_dir>/pages/<filename>``,
    or ``None`` when the file does not exist or reading it raises
    ``OSError``.
    """

    page_path = bench_dir / "pages" / filename
    body: str | None = None
    if page_path.exists():
        try:
            body = page_path.read_text(encoding="utf-8")
        except OSError:
            body = None
    return body
|
||||
|
||||
|
||||
async def _retry_upload_idempotent(  # noqa: D401 - hidden helper
    ctx: RunContext,
    *,
    pages_dir: Path,
    filenames: list[str],
    batch_size: int,
    settings: IngestSettings,
    max_attempts: int = 2,
) -> dict[str, int]:
    """Future-proofing hook (unused today): retry the ingest upload pass.

    Calls ``_upload_pages`` up to ``max_attempts`` times with a linearly
    growing pause (2s, 4s, ...) between attempts and returns the first
    successful result.

    Raises the exception from the final attempt when every attempt
    fails. Returns ``{}`` when ``max_attempts`` is zero or negative,
    because no attempt is made.
    """

    last_exc: Exception | None = None
    for attempt in range(max_attempts):
        try:
            return await _upload_pages(
                ctx,
                pages_dir=pages_dir,
                filenames=filenames,
                batch_size=batch_size,
                settings=settings,
            )
        except Exception as exc:  # noqa: BLE001
            last_exc = exc
            logger.warning("CRAG upload attempt %d failed: %s", attempt + 1, exc)
            # Back off only when another attempt follows; sleeping after
            # the last failure would merely delay the re-raise.
            if attempt + 1 < max_attempts:
                await asyncio.sleep(2.0 * (attempt + 1))
    if last_exc is not None:
        raise last_exc
    return {}
|
||||
|
||||
|
||||
# NOTE: underscore-prefixed helpers are exported deliberately. The Task 3
# ingest module imports ``_IngestStats``, ``_materialise_pages``,
# ``_resolve_question_doc_ids`` and ``_upload_pages`` from ``.ingest``;
# ``_page_filename`` is presumably listed for the same kind of reuse
# (TODO confirm against its callers).
__all__ = [
    "_IngestStats",
    "_materialise_pages",
    "_page_filename",
    "_resolve_question_doc_ids",
    "_upload_pages",
    "read_page_markdown",
    "run_ingest",
]
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
"""CRAG Task 3 ingestion: 4-part download → streaming JSONL → upload.
|
||||
|
||||
Same flow as ``ingest.run_ingest`` for Tasks 1 & 2 (extract HTML →
|
||||
upload markdown → resolve doc_ids → write doc map), but:
|
||||
|
||||
* Source: 4 .tar.bz2 parts streamed via ``dataset_task3``.
|
||||
* Page count: 50 per question instead of 5 — the whole point of
|
||||
Task 3 (the long-context arm now structurally has to choose what
|
||||
to keep, while SurfSense's retrieval becomes mandatory).
|
||||
* Stratified sampling re-uses the Task 1 helper since the question
|
||||
schema is identical.
|
||||
|
||||
Doc map lands at ``<suite_data>/maps/crag_t3_doc_map.jsonl`` with the
|
||||
same row shape as Task 1's map (so the runner only needs to know
|
||||
which file to load; everything else is shared).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
from .dataset import stratified_sample, write_questions_jsonl
|
||||
from .dataset_task3 import (
|
||||
CRAG_TASK_3_PART_NAMES,
|
||||
iter_questions_task3,
|
||||
parts_present,
|
||||
)
|
||||
from .ingest import (
|
||||
_IngestStats,
|
||||
_materialise_pages,
|
||||
_resolve_question_doc_ids,
|
||||
_upload_pages,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_INSTRUCTIONS_TO_DOWNLOAD = (
|
||||
"Run `python scripts/download_crag_task3.py` first to fetch the "
|
||||
"4 tar.bz2 parts (~7 GB total) into "
|
||||
"data/research/crag_t3/.raw_cache/. The downloader is idempotent "
|
||||
"and parallel."
|
||||
)
|
||||
|
||||
|
||||
async def run_ingest_task3(
    ctx: RunContext,
    *,
    n_questions: int | None = None,
    upload_batch_size: int = 16,
    skip_upload: bool = False,
    overwrite_extract: bool = False,
    settings: IngestSettings | None = None,
    sample_seed: int = 17,
    parse_cap: int | None = None,
) -> None:
    """Ingest CRAG Task 3 (50 pages per question) into the research suite.

    Parameters
    ----------
    n_questions
        Cap on the post-stratified-sample question count. ``None`` =
        "use whatever ``parse_cap`` produced". For real runs aim for
        50 (~2,500 pages) — n=200 (10k pages) is doable but slow.
    parse_cap
        Hard cap on how many rows we *parse* from the streaming
        archive before stratified sampling. Defaults to
        ``max(400, 6*n_questions)`` — enough to cover all (domain,
        question_type) buckets ~5x but small enough to fit in the
        first shard or two (each shard is ≈5 GB decompressed and
        holds ~300 rows; bz2 throughput is ~50 MB/s). Lowering this
        is the only knob that bounds streaming cost since we can
        ``break`` out of the JSONL stream early without decompressing
        the rest of the ~50 GB archive body. When both ``parse_cap``
        and ``n_questions`` are unset (or 0) the stream is not capped.
    upload_batch_size
        Markdown files per ``/documents/fileupload`` call.
    skip_upload
        Extract markdown locally, don't push to SurfSense.
    overwrite_extract
        Re-run trafilatura even when a cached markdown is present.
    settings
        Per-upload knobs override (default: text-only basic ETL).
    sample_seed
        RNG seed for stratified sampling (deterministic).

    Raises
    ------
    RuntimeError
        If any archive part is missing from the raw cache, or if
        streaming the archive yields no rows.
    """

    settings = settings or IngestSettings(
        use_vision_llm=False,
        processing_mode="basic",
        should_summarize=False,
    )
    bench_dir = ctx.benchmark_data_dir()
    pages_dir = bench_dir / "pages"
    raw_cache = bench_dir / ".raw_cache"
    raw_cache.mkdir(parents=True, exist_ok=True)

    if not parts_present(raw_cache):
        missing = [
            n for n in CRAG_TASK_3_PART_NAMES
            if not (raw_cache / n).exists()
        ]
        raise RuntimeError(
            f"CRAG Task 3 parts missing from {raw_cache}: {missing}. "
            f"{_INSTRUCTIONS_TO_DOWNLOAD}"
        )

    # 1. Stream-parse (capped). For n=50 we don't need the full 2,706
    #    rows — just enough that the stratified sampler can balance.
    #    Each tar shard ~5 GB / ~300 rows / ~2 min decompress, so
    #    400-500 rows = shard 0 + a slice of shard 1 ≈ 3-4 min.
    parse_cap = parse_cap or (
        max(400, 6 * n_questions) if n_questions else None
    )
    logger.info(
        "CRAG Task 3: streaming JSONL (parse_cap=%s) ...",
        parse_cap if parse_cap else "no-cap",
    )
    all_questions = iter_questions_task3(raw_cache, max_questions=parse_cap)
    logger.info("CRAG Task 3: parsed %d rows", len(all_questions))

    if not all_questions:
        raise RuntimeError("CRAG Task 3 streaming returned 0 rows; check archive integrity.")

    if n_questions is not None and n_questions > 0:
        questions = stratified_sample(all_questions, n=n_questions, seed=sample_seed)
        logger.info(
            "CRAG Task 3: stratified sample of %d questions across %d (domain, qtype) buckets",
            len(questions),
            len({(q.domain, q.question_type) for q in questions}),
        )
    else:
        questions = all_questions

    questions_jsonl = bench_dir / "questions.jsonl"
    write_questions_jsonl(questions, questions_jsonl)

    n_pages_total = sum(len(q.pages) for q in questions)
    logger.info(
        "CRAG Task 3: extracting up to %d pages across %d questions ...",
        n_pages_total, len(questions),
    )
    qid_to_files, _file_to_url = _materialise_pages(
        questions, pages_dir=pages_dir, overwrite=overwrite_extract,
    )
    n_pages_extracted = sum(len(v) for v in qid_to_files.values())

    name_to_id: dict[str, int] = {}
    if skip_upload:
        logger.info("CRAG Task 3: --skip-upload; skipping SurfSense ingestion")
    else:
        # Pages can be shared between questions; upload each file once.
        all_filenames = sorted({fn for fns in qid_to_files.values() for fn in fns})
        logger.info("CRAG Task 3: uploading %d unique pages ...", len(all_filenames))
        name_to_id = await _upload_pages(
            ctx,
            pages_dir=pages_dir,
            filenames=all_filenames,
            batch_size=upload_batch_size,
            settings=settings,
        )

    doc_rows = _resolve_question_doc_ids(questions, qid_to_files, name_to_id)
    map_path = ctx.maps_dir() / "crag_t3_doc_map.jsonl"
    with map_path.open("w", encoding="utf-8") as fh:
        # First line records the ingest settings; one JSON row per
        # question follows.
        fh.write(settings_header_line(settings) + "\n")
        for row in doc_rows:
            fh.write(json.dumps(row) + "\n")
    logger.info("Wrote CRAG Task 3 doc map to %s (%d rows)", map_path, len(doc_rows))

    new_state = ctx.suite_state
    new_state.ingestion_maps["crag_t3"] = str(map_path)
    set_suite_state(ctx.config, ctx.suite, new_state)

    stats = _IngestStats(
        n_questions=len(questions),
        n_pages_total=n_pages_total,
        n_pages_extracted=n_pages_extracted,
        n_pages_empty=n_pages_total - n_pages_extracted,
        # The name -> id map can carry more than one key per uploaded
        # document, so count distinct document ids rather than keys.
        n_uploaded=len(set(name_to_id.values())),
        n_existing=0,
        bench_dir=bench_dir,
        map_path=map_path,
    )
    logger.info("CRAG Task 3 ingest done: %s", stats)
|
||||
|
||||
|
||||
__all__ = ["run_ingest_task3"]
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
"""CRAG prompt templates for the three competing arms.
|
||||
|
||||
The CRAG paper grades each prediction as one of:
|
||||
|
||||
* **correct** — answer matches gold (with paraphrasing tolerance)
|
||||
* **missing** — model refuses or says "I don't know"
|
||||
* **incorrect** — model commits to a wrong answer (hallucination)
|
||||
|
||||
The truthfulness score `(correct - incorrect) / total` rewards
|
||||
calibrated abstention, so the prompts below explicitly *invite* the
|
||||
model to refuse when it isn't confident — otherwise the bare-LLM arm
|
||||
gets penalised twice (no docs *and* a no-refusal prompt) and the
|
||||
comparison stops being fair to the LLM-only baseline.
|
||||
|
||||
Three templates, byte-identical instructions:
|
||||
|
||||
* ``build_bare_prompt(q)`` — question-only.
|
||||
* ``build_long_context_prompt(q, contexts)`` — question + concatenated
|
||||
page extracts, all stuffed into the user message. Mirrors the
|
||||
paper's "straightforward RAG" baseline.
|
||||
* ``build_surfsense_prompt(q)`` — question + a hint that retrieval
|
||||
over the question's 5 ingested pages is available; the SurfSense
|
||||
agent itself owns the retrieval step.
|
||||
|
||||
The ``Answer:`` line at the end is parsed by ``extract_freeform_answer``
|
||||
in the runner, so the format is mandatory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
_BASE_INSTRUCTIONS = (
|
||||
"You are a careful question-answering assistant. The question is a "
|
||||
"real-world factual question that may be about finance, music, "
|
||||
"movies, sports, or any other domain.\n\n"
|
||||
"Important rules:\n"
|
||||
"1. If the question contains a false premise (an assumption that "
|
||||
"is factually wrong), say so explicitly in your final answer "
|
||||
"rather than answering as if the premise were true.\n"
|
||||
"2. If you are not confident in an answer, prefer saying \"I don't "
|
||||
"know\" over guessing. A wrong commit is penalised more than a "
|
||||
"refusal.\n"
|
||||
"3. Keep the final answer short — a name, a number, a date, a "
|
||||
"phrase. Do not repeat the question.\n\n"
|
||||
"Format your final line EXACTLY as:\n"
|
||||
"Answer: <short answer>\n"
|
||||
"If you don't know, write `Answer: I don't know`."
|
||||
)
|
||||
|
||||
|
||||
_BARE_TEMPLATE = """\
|
||||
{instructions}
|
||||
|
||||
Question: {question}
|
||||
Question time: {query_time}
|
||||
"""
|
||||
|
||||
|
||||
_SURFSENSE_TEMPLATE = """\
|
||||
{instructions}
|
||||
|
||||
You have access to a search index of up to 5 web pages that were
|
||||
retrieved for this question. Use the retrieval tool to look up any
|
||||
facts you are not confident about. The pages may be partially or fully
|
||||
relevant; some may contradict each other (prefer the more authoritative
|
||||
or more recent source).
|
||||
|
||||
Question: {question}
|
||||
Question time: {query_time}
|
||||
"""
|
||||
|
||||
|
||||
_LONG_CONTEXT_TEMPLATE = """\
|
||||
{instructions}
|
||||
|
||||
You are given the full text of {n_contexts} web pages that were
|
||||
retrieved for this question. Read all of them, then answer. The
|
||||
pages may be partially or fully relevant; some may contradict each
|
||||
other (prefer the more authoritative or more recent source).
|
||||
|
||||
{contexts}
|
||||
|
||||
Question: {question}
|
||||
Question time: {query_time}
|
||||
"""
|
||||
|
||||
|
||||
def build_bare_prompt(question: str, *, query_time: str = "") -> str:
    """Render the question-only prompt for the no-retrieval baseline arm.

    The question is stripped of surrounding whitespace; an empty or
    whitespace-only ``query_time`` is rendered as ``unknown``.
    """

    asked = question.strip()
    when = query_time.strip()
    if not when:
        when = "unknown"
    return _BARE_TEMPLATE.format(
        instructions=_BASE_INSTRUCTIONS,
        question=asked,
        query_time=when,
    )
|
||||
|
||||
|
||||
def build_surfsense_prompt(question: str, *, query_time: str = "") -> str:
    """Render the prompt for the SurfSense arm (the agent retrieves itself).

    The question is stripped of surrounding whitespace; an empty or
    whitespace-only ``query_time`` is rendered as ``unknown``.
    """

    asked = question.strip()
    when = query_time.strip()
    if not when:
        when = "unknown"
    return _SURFSENSE_TEMPLATE.format(
        instructions=_BASE_INSTRUCTIONS,
        question=asked,
        query_time=when,
    )
|
||||
|
||||
|
||||
def build_long_context_prompt(
    question: str,
    *,
    contexts: list[tuple[str, str]],
    query_time: str = "",
    per_page_char_cap: int = 12_000,
) -> str:
    """Prompt for the "stuff all pages into the prompt" baseline.

    ``contexts`` is a list of ``(page_title_or_url, page_text)`` pairs.
    Each page is truncated at ``per_page_char_cap`` (default 12k chars
    ≈ 3k tokens) so a 5-page CRAG question fits well under any
    modern long-context window with room for the question + reasoning.

    A page whose title is empty, ``None`` or whitespace-only is labelled
    ``page_<n>`` (1-based) so every page header stays identifiable. An
    empty ``contexts`` list renders ``(no pages retrieved)`` in place of
    the page blocks. An empty or whitespace-only ``query_time`` is
    rendered as ``unknown``.
    """

    blocks: list[str] = []
    for idx, (title, text) in enumerate(contexts, start=1):
        body = (text or "").strip()
        if len(body) > per_page_char_cap:
            body = body[:per_page_char_cap].rstrip() + "\n[...truncated...]"
        # Clean first, then fall back: a whitespace-only title is truthy
        # and would otherwise produce an empty "--- PAGE n:  ---" header.
        title_clean = (title or "").strip().replace("\n", " ") or f"page_{idx}"
        blocks.append(
            f"--- PAGE {idx}: {title_clean} ---\n{body}\n"
        )
    contexts_block = "\n".join(blocks) if blocks else "(no pages retrieved)"
    return _LONG_CONTEXT_TEMPLATE.format(
        instructions=_BASE_INSTRUCTIONS,
        n_contexts=len(contexts),
        contexts=contexts_block,
        question=question.strip(),
        query_time=query_time.strip() or "unknown",
    )
|
||||
|
||||
|
||||
__all__ = [
|
||||
"build_bare_prompt",
|
||||
"build_long_context_prompt",
|
||||
"build_surfsense_prompt",
|
||||
]
|
||||
1053
surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py
Normal file
1053
surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,29 @@
|
|||
"""FRAMES — multi-hop Wikipedia retrieval & reasoning (google/frames-benchmark).
|
||||
|
||||
Source: https://huggingface.co/datasets/google/frames-benchmark
|
||||
Paper: https://arxiv.org/abs/2409.12941 (Krishna et al., 2024)
|
||||
|
||||
* 824 multi-hop questions, each requiring 2-15 Wikipedia articles
|
||||
* 5 reasoning types: numerical, tabular, multiple constraints,
|
||||
temporal, post-processing
|
||||
* Published Gemini-Pro-1.5 baselines:
|
||||
- Naive prompting (no retrieval): 40.8%
|
||||
- BM25, top-4: 47.4%
|
||||
- Multi-step retrieval & reasoning: 66.0%
|
||||
- Oracle retrieval (gold articles): 72.9%
|
||||
|
||||
This is the benchmark that *finally* puts SurfSense's strongest claim
|
||||
on trial: cross-document iterative retrieval. The harness ingests
|
||||
every Wikipedia article referenced by any question in the run sample
|
||||
into a single SearchSpace; SurfSense answers without
|
||||
``mentioned_document_ids`` so its agent has to actually retrieve.
|
||||
The bare-LLM arm answers from the prompt only (the published 40.8%
|
||||
baseline number).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import registry as _registry
|
||||
from .runner import FramesBenchmark
|
||||
|
||||
_registry.register(FramesBenchmark())
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
"""FRAMES dataset loader — download ``test.tsv`` from HF and parse rows.
|
||||
|
||||
The HF repo (``google/frames-benchmark``) ships a single tab-separated
|
||||
file at ``test.tsv`` (824 rows). Columns of interest for us:
|
||||
|
||||
* unnamed first column → row index (``id`` we synthesise as ``Q000``..)
|
||||
* ``Prompt`` → the question (free-text, often multi-clause)
|
||||
* ``Answer`` → gold answer (short string: name, number, year, ...)
|
||||
* ``wikipedia_link_1`` ... ``wikipedia_link_11+`` → sparse per-question
|
||||
link cells (we ignore in favour of the consolidated column below).
|
||||
* ``reasoning_types`` → pipe-separated tags (``"Numerical reasoning |
|
||||
Tabular reasoning | Multiple constraints"``)
|
||||
* ``wiki_links`` → Python-list literal of every URL the question relies
|
||||
on, e.g. ``"['https://en.wikipedia.org/wiki/...', '...']"``
|
||||
|
||||
We use ``wiki_links`` (already deduplicated per row) and
|
||||
``ast.literal_eval`` to materialise it. The legacy
|
||||
``wikipedia_link_*`` columns are kept around only so a curious
|
||||
operator can compare cell-vs-list if the two ever drift apart upstream.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HF_REPO_ID = "google/frames-benchmark"
|
||||
HF_REPO_TYPE = "dataset"
|
||||
HF_TEST_FILE = "test.tsv"
|
||||
|
||||
|
||||
def _hf_hub_download(*args: Any, **kwargs: Any) -> str:
    """Forward all arguments to ``huggingface_hub.hf_hub_download``.

    The import lives inside the function, so ``huggingface_hub`` is only
    loaded when a download is actually requested rather than when this
    module is imported.
    """
    from huggingface_hub import hf_hub_download

    return hf_hub_download(*args, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class FramesQuestion:
    """A single parsed FRAMES row."""

    qid: str  # synthesised id, "Q000" .. "Q823"
    question: str
    gold_answer: str
    wiki_urls: list[str]  # deduped, in original order
    reasoning_types: list[str]  # split on "|"
    raw_index: int  # row index from the TSV (for debugging)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view; list fields are shallow-copied."""

        return dict(
            qid=self.qid,
            question=self.question,
            gold_answer=self.gold_answer,
            wiki_urls=[*self.wiki_urls],
            reasoning_types=[*self.reasoning_types],
            raw_index=self.raw_index,
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Download + parse
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def download_test_tsv(cache_dir: Path) -> Path:
    """Fetch ``test.tsv`` through ``huggingface_hub`` and return its local path.

    ``cache_dir`` is created if needed and handed to the hub client as
    its cache directory.
    """

    cache_dir.mkdir(parents=True, exist_ok=True)
    request = {
        "repo_id": HF_REPO_ID,
        "filename": HF_TEST_FILE,
        "repo_type": HF_REPO_TYPE,
        "cache_dir": str(cache_dir),
    }
    return Path(_hf_hub_download(**request))
|
||||
|
||||
|
||||
def _parse_wiki_links(raw: Any) -> list[str]:
    """Convert the ``wiki_links`` cell (Python list literal) to ``list[str]``.

    Accepts an already-materialised list/tuple, a Python list literal
    such as ``"['https://...', 'https://...']"``, or (as a fallback for
    cells that are not valid literals) a comma-separated string.

    Returns the URLs stripped of surrounding whitespace, with empty
    entries dropped and duplicates removed while preserving first-seen
    order. Falsy input yields ``[]``.
    """

    if not raw:
        return []
    if isinstance(raw, (list, tuple)):
        items = list(raw)
    else:
        text = str(raw).strip()
        if not text:
            return []
        try:
            parsed = ast.literal_eval(text)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # Not a valid literal (e.g. a truncated list or a bare
            # comma-separated string). Split by hand and drop any quote
            # characters that would otherwise stay glued to the URLs.
            items = [
                tok.strip().strip("'\"")
                for tok in text.strip("[]").split(",")
            ]
        else:
            items = list(parsed) if isinstance(parsed, (list, tuple)) else [parsed]
    cleaned = (str(item).strip() for item in items)
    # dict.fromkeys de-duplicates while keeping the original order.
    return list(dict.fromkeys(url for url in cleaned if url))
|
||||
|
||||
|
||||
def _parse_reasoning_types(raw: Any) -> list[str]:
    """Split a pipe-separated ``reasoning_types`` cell into trimmed tags.

    Falsy or blank input gives ``[]``; empty segments are dropped.
    """

    text = str(raw).strip() if raw else ""
    segments = (piece.strip() for piece in text.split("|"))
    return [segment for segment in segments if segment]
|
||||
|
||||
|
||||
def load_questions(tsv_path: Path) -> list[FramesQuestion]:
    """Parse the FRAMES TSV at ``tsv_path`` into ``FramesQuestion`` objects.

    The file is read with pandas, using the first (unnamed) column as
    the row index and leaving empty cells as empty strings. Rows that
    lack a ``Prompt`` or an ``Answer`` are skipped with a debug log.
    URLs come from the ``wiki_links`` column; when that yields nothing,
    the non-empty ``wikipedia_link*`` cells are collected instead
    (first occurrence wins).
    """

    import pandas as pd

    frame = pd.read_csv(tsv_path, sep="\t", index_col=0, keep_default_na=False)
    questions: list[FramesQuestion] = []
    for raw_idx, row in frame.iterrows():
        prompt = str(row.get("Prompt") or "").strip()
        answer = str(row.get("Answer") or "").strip()
        if not (prompt and answer):
            logger.debug("Skipping FRAMES row %s with missing Prompt/Answer", raw_idx)
            continue

        urls = _parse_wiki_links(row.get("wiki_links"))
        if not urls:
            # Consolidated column gave nothing: gather the per-cell links.
            urls = []
            link_columns = (c for c in row.index if c.startswith("wikipedia_link"))
            for column in link_columns:
                cell = str(row.get(column) or "").strip()
                if cell and cell not in urls:
                    urls.append(cell)

        reasoning = _parse_reasoning_types(row.get("reasoning_types"))
        row_number = int(raw_idx)
        questions.append(
            FramesQuestion(
                qid=f"Q{row_number:03d}",
                question=prompt,
                gold_answer=answer,
                wiki_urls=urls,
                reasoning_types=reasoning,
                raw_index=row_number,
            )
        )
    return questions
|
||||
|
||||
|
||||
def write_questions_jsonl(questions: list[FramesQuestion], dest: Path) -> None:
    """Write ``questions`` to ``dest`` as JSON Lines (one ``to_dict()`` per line).

    Missing parent directories are created; an existing file is
    overwritten.
    """

    dest.parent.mkdir(parents=True, exist_ok=True)
    with dest.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(item.to_dict()) + "\n" for item in questions)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FramesQuestion",
|
||||
"download_test_tsv",
|
||||
"load_questions",
|
||||
"write_questions_jsonl",
|
||||
]
|
||||
|
|
@ -0,0 +1,341 @@
|
|||
"""FRAMES grader: deterministic shortcut + LLM-as-judge fallback.
|
||||
|
||||
FRAMES gold answers are short factoids (a name, a year, an ordinal,
|
||||
a count). The published paper uses an LLM judge for grading, citing
|
||||
the long tail of paraphrasing ("Jane Ballou" vs "Mrs. Ballou (Jane)";
|
||||
"5" vs "five"; "London, England" vs "London"). We replicate that
|
||||
faithfully *but* avoid burning judge tokens on the obvious cases.
|
||||
|
||||
Pipeline per (pred, gold):
|
||||
|
||||
1. Normalise both sides (SQuAD-style).
|
||||
2. If normalised pred == normalised gold → CORRECT (``method=exact``).
|
||||
3. Numeric path: if both extract to a single number and the values
|
||||
match within 1% relative tolerance → CORRECT (``method=numeric``).
|
||||
4. Substring path: if normalised gold appears as a *whole-word phrase*
|
||||
inside normalised pred (or vice versa) → CORRECT
|
||||
(``method=substring``).
|
||||
5. Otherwise → call the LLM judge if a judge is wired; the judge
|
||||
returns yes/no with a one-line rationale.
|
||||
6. If no judge is configured, fall through to ``False``
|
||||
(``method=lexical_miss``).
|
||||
|
||||
The judge is called *concurrently* across the run via a semaphore (so
|
||||
it doesn't outrun the upstream rate limit). Cached on
|
||||
``(arm, qid)`` so re-running ``report`` doesn't re-judge.
|
||||
|
||||
Returned shape mirrors ``mmlongbench.grader.GradeResult`` to keep
|
||||
report writers uniform across benchmarks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import string
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from ....core.providers.openrouter_chat import OpenRouterChatProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class GradeResult:
    """Grading outcome; same shape as mmlongbench.grader.GradeResult so
    report writers can treat every benchmark alike."""

    correct: bool
    f1: float
    method: str
    normalised_pred: str = ""
    normalised_gold: str = ""
    judge_rationale: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict, in declaration order."""

        field_names = (
            "correct",
            "f1",
            "method",
            "normalised_pred",
            "normalised_gold",
            "judge_rationale",
        )
        return {name: getattr(self, name) for name in field_names}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Normalisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Translation table mapping every ASCII punctuation character to a space.
_PUNCT_TABLE = str.maketrans(dict.fromkeys(string.punctuation, " "))
# English articles as standalone words.
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.IGNORECASE)
# Any run of whitespace.
_WS = re.compile(r"\s+")


def _normalise(s: str) -> str:
    """Canonicalise an answer string for lexical comparison.

    Lowercases, turns punctuation into spaces, removes the articles
    a/an/the, then collapses whitespace and trims the ends. ``None`` or
    an empty string gives ``""``.
    """

    depunctuated = (s or "").lower().translate(_PUNCT_TABLE)
    without_articles = _ARTICLES.sub(" ", depunctuated)
    return _WS.sub(" ", without_articles).strip()
|
||||
|
||||
|
||||
# Spelled-out numbers recognised by the word-form fallback (0-20 only).
_WORD_NUMBERS = {
    "zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
    "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11,
    "twelve": 12, "thirteen": 13, "fourteen": 14, "fifteen": 15, "sixteen": 16,
    "seventeen": 17, "eighteen": 18, "nineteen": 19, "twenty": 20,
}

# One number: an optional sign, then either comma-grouped thousands
# ("1,234,567") or a plain digit run, then an optional decimal part.
# * The sign only counts when it does not directly follow a word
#   character, so the hyphen in "top-5" is not read as a minus.
# * Comma groups must be exactly three digits and must not be followed
#   by another digit, so list-like text ("1,2,3") is not glued together.
_NUMERIC_RE = re.compile(
    r"(?:(?<!\w)-)?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"
)


def _maybe_number(s: str) -> float | None:
    """Extract a single numeric value, recognising digit and word forms.

    Operates on the lowercased *raw* text (rather than the
    punctuation-stripped normalisation) so that thousands separators
    like ``1,234,567`` and decimals like ``1,234.56`` are preserved
    through the regex and parsed correctly. We only fall back to
    ``_normalise`` for the word-number pass, which doesn't care about
    punctuation.

    Returns the first number found in the text as a float, or ``None``
    when the text is empty or contains neither digits nor a recognised
    number word.
    """

    raw = (s or "").strip().lower()
    if not raw:
        return None
    match = _NUMERIC_RE.search(raw)
    if match:
        try:
            return float(match.group(0).replace(",", ""))
        except ValueError:
            # Defensive only: fall through to the word-number pass.
            pass
    for tok in _normalise(s).split():
        if tok in _WORD_NUMBERS:
            return float(_WORD_NUMBERS[tok])
    return None
|
||||
|
||||
|
||||
def _whole_word_substring(haystack: str, needle: str) -> bool:
    """Report whether ``needle`` occurs in ``haystack`` on word boundaries.

    Both strings are padded with one space on each side before the
    containment test, so a match must start and end at a space or at
    either end of ``haystack``. An empty ``needle`` never matches.
    """

    if needle:
        return f" {needle} " in f" {haystack} "
    return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deterministic shortcut
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def grade_deterministic(*, pred: str, gold: str) -> GradeResult:
    """Try to grade without the LLM judge. Returns a final-result object.

    Checks run in this order and the first decisive one wins:

    1. ``empty_pred`` / ``empty_gold`` — nothing to compare; incorrect.
    2. ``exact`` — the normalised strings are identical.
    3. ``numeric`` / ``numeric_miss`` — both sides yield a number via
       ``_maybe_number``. Integer-valued pairs (years, counts) must be
       equal; otherwise a 1% relative tolerance with a 0.5 absolute
       floor applies.
    4. ``substring`` — normalised gold appears as a whole-word phrase in
       the prediction; ``substring_reverse`` — the prediction (at least
       3 normalised chars) appears as a whole-word phrase in gold.
    5. ``lexical_miss`` — none of the above.

    A ``False`` result with ``method == "lexical_miss"`` is the signal
    to the caller that the LLM judge should be consulted (if available).
    """
    if not (pred or "").strip():
        return GradeResult(False, 0.0, "empty_pred", "", _normalise(gold))

    p = _normalise(pred)
    g = _normalise(gold)
    if not g:
        # Defensively: gold should never be empty; if it is, we can't grade.
        return GradeResult(False, 0.0, "empty_gold", p, g)

    if p == g:
        return GradeResult(True, 1.0, "exact", p, g)

    p_num = _maybe_number(pred)
    g_num = _maybe_number(gold)
    if p_num is not None and g_num is not None:
        if p_num.is_integer() and g_num.is_integer():
            # Years and counts are exact facts: a relative tolerance would
            # accept 1975 for gold 1990 (1% of 1990 is ~20) or 151 for 150.
            tol = 0.0
        else:
            # 1% relative tolerance, 0.5 absolute floor (absorbs rounding of
            # fractional answers).
            tol = max(abs(g_num) * 0.01, 0.5)
        if abs(p_num - g_num) <= tol:
            return GradeResult(True, 1.0, "numeric", p, g)
        return GradeResult(False, 0.0, "numeric_miss", p, g)

    if _whole_word_substring(p, g):
        return GradeResult(True, 1.0, "substring", p, g)
    if _whole_word_substring(g, p) and len(p) >= 3:
        # Be conservative in the other direction: the length guard only
        # rejects one- or two-character fragments. Note that a longer
        # partial answer such as "john" is still credited against gold
        # "john f kennedy" by this rule.
        return GradeResult(True, 1.0, "substring_reverse", p, g)

    return GradeResult(False, 0.0, "lexical_miss", p, g)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM-as-judge
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# System prompt for the LLM judge: states the correctness criteria and asks
# for a single-line JSON object with ``correct`` and ``rationale`` keys, which
# is the shape ``_parse_judge_response`` looks for.
_JUDGE_SYSTEM = (
    "You are an impartial grader for short-answer factual questions. "
    "Given a question, the gold answer, and a model's prediction, "
    "decide whether the prediction is correct. The prediction is "
    "correct if it expresses the same factual content as the gold "
    "answer, allowing for paraphrasing, surface-level differences "
    "(numbers as words, names with/without titles), and additional "
    "non-contradictory detail. The prediction is incorrect if it "
    "expresses a different fact, omits the central answer, or hedges "
    "without committing.\n\n"
    "Respond with ONLY a JSON object on a single line:\n"
    '{\"correct\": true|false, \"rationale\": \"<one short sentence>\"}'
)


# Per-row user message; ``LlmJudge.judge`` fills ``{question}``, ``{gold}``
# and ``{pred}`` with ``str.format``.
_JUDGE_TEMPLATE = """\
Question: {question}
Gold answer: {gold}
Model prediction: {pred}

Decide whether the prediction is correct.
"""
|
||||
|
||||
|
||||
@dataclass
class JudgeConfig:
    """Configuration handed to ``LlmJudge`` at construction time."""

    # Credential forwarded to ``OpenRouterChatProvider``.
    api_key: str
    # Model slug forwarded to the provider; also exposed as ``LlmJudge.model``.
    model: str = "anthropic/claude-sonnet-4.5"
    # Endpoint root forwarded to the provider.
    base_url: str = "https://openrouter.ai/api/v1"
    # ``max_tokens`` passed on every judge completion call.
    max_tokens: int = 200
    # Upper bound on in-flight judge calls; ``LlmJudge`` clamps it to >= 1.
    concurrency: int = 4
|
||||
|
||||
|
||||
class LlmJudge:
    """LLM-as-judge client built on ``OpenRouterChatProvider``.

    One instance owns a provider and a semaphore sized from
    ``JudgeConfig.concurrency`` (minimum 1), so any number of ``judge``
    calls can be scheduled at once while only that many are in flight.
    """

    def __init__(self, *, config: JudgeConfig) -> None:
        self._config = config
        self._provider = OpenRouterChatProvider(
            api_key=config.api_key,
            base_url=config.base_url,
            model=config.model,
        )
        permits = max(1, config.concurrency)
        self._sem = asyncio.Semaphore(permits)

    @property
    def model(self) -> str:
        """Model slug this judge was configured with."""
        return self._config.model

    async def judge(
        self,
        *,
        question: str,
        gold: str,
        pred: str,
    ) -> tuple[bool, str]:
        """Ask the judge model whether ``pred`` answers ``question`` like ``gold``.

        Returns ``(is_correct, rationale)``. Any exception raised while
        waiting for the provider is swallowed and reported as
        ``(False, "judge_error: ...")`` so one failed call cannot abort a
        batch.
        """
        user_message = _JUDGE_TEMPLATE.format(question=question, gold=gold, pred=pred)
        try:
            async with self._sem:
                reply = await self._provider.complete(
                    prompt=user_message,
                    system_prompt=_JUDGE_SYSTEM,
                    max_tokens=self._config.max_tokens,
                )
        except Exception as exc:  # noqa: BLE001
            failure = f"judge_error: {type(exc).__name__}: {exc}"
            return False, failure
        else:
            return _parse_judge_response(reply.text)
|
||||
|
||||
|
||||
def _parse_judge_response(text: str) -> tuple[bool, str]:
|
||||
"""Pull ``correct`` + ``rationale`` out of the judge's reply."""
|
||||
|
||||
if not text or not text.strip():
|
||||
return False, "judge_returned_empty"
|
||||
# Accept JSON anywhere in the message; some models prepend prose.
|
||||
match = re.search(r"\{[^{}]*\}", text, flags=re.DOTALL)
|
||||
candidate = match.group(0) if match else text
|
||||
try:
|
||||
data = json.loads(candidate)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Fallback: yes/no parsing.
|
||||
lowered = text.strip().lower()
|
||||
if lowered.startswith("yes") or "correct: yes" in lowered or '"correct": true' in lowered:
|
||||
return True, "yes (parser_fallback)"
|
||||
if lowered.startswith("no") or "correct: no" in lowered or '"correct": false' in lowered:
|
||||
return False, "no (parser_fallback)"
|
||||
return False, f"unparseable_judge_response: {text[:200]}"
|
||||
correct = bool(data.get("correct"))
|
||||
rationale = str(data.get("rationale", "")).strip()[:280]
|
||||
return correct, rationale
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Combined grader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def grade_with_judge(
    *,
    pred: str,
    gold: str,
    question: str,
    judge: LlmJudge | None,
) -> GradeResult:
    """Grade a single row, escalating to the LLM judge only when needed.

    The deterministic grader runs first. Its verdict is final unless it
    is an incorrect ``"lexical_miss"`` *and* a ``judge`` was supplied; in
    that case the judge decides, and the result is reported with
    ``method="llm_judge"`` plus the judge's rationale, reusing the
    normalised strings computed by the deterministic pass.
    """
    shortcut = grade_deterministic(pred=pred, gold=gold)
    needs_judge = not shortcut.correct and shortcut.method == "lexical_miss"
    if not needs_judge or judge is None:
        return shortcut

    verdict, why = await judge.judge(question=question, gold=gold, pred=pred)
    score = 1.0 if verdict else 0.0
    return GradeResult(
        correct=verdict,
        f1=score,
        method="llm_judge",
        normalised_pred=shortcut.normalised_pred,
        normalised_gold=shortcut.normalised_gold,
        judge_rationale=why,
    )
|
||||
|
||||
|
||||
async def grade_many(
    *,
    rows: Sequence[tuple[str, str, str, str]],
    judge: LlmJudge | None,
) -> list[GradeResult]:
    """Grade every ``(qid, question, gold, pred)`` row concurrently.

    All rows are handed to ``asyncio.gather`` at once; throttling is left
    to the judge's own semaphore. Results come back in input order, which
    is how callers line them up with their rows — ``qid`` itself is not
    used here. An empty ``rows`` yields ``[]``.
    """
    if not rows:
        return []

    pending = []
    for _qid, question, gold, pred in rows:
        pending.append(
            grade_with_judge(pred=pred, gold=gold, question=question, judge=judge)
        )
    results = await asyncio.gather(*pending)
    return list(results)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GradeResult",
|
||||
"JudgeConfig",
|
||||
"LlmJudge",
|
||||
"grade_deterministic",
|
||||
"grade_many",
|
||||
"grade_with_judge",
|
||||
]
|
||||
|
|
@ -0,0 +1,341 @@
|
|||
"""FRAMES ingestion: download → fetch Wikipedia → upload markdown.
|
||||
|
||||
Steps:
|
||||
|
||||
1. Download ``test.tsv`` from ``hf://datasets/google/frames-benchmark``.
|
||||
2. Parse rows into ``FramesQuestion`` objects.
|
||||
3. Optionally cap to the first ``--max-questions N`` so a smoke run
|
||||
doesn't trigger a 1k-article fetch.
|
||||
4. Build the **deduplicated** set of Wikipedia URLs across the chosen
|
||||
sample (questions share many articles — Q1 and Q42 might both
|
||||
reference ``James_A._Garfield``).
|
||||
5. Fetch each unique article via ``WikiFetcher`` (polite 2 RPS) into
|
||||
``<bench_dir>/wiki/<title>.md``.
|
||||
6. Upload the resulting markdown files to SurfSense in batches with
|
||||
``use_vision_llm=False, processing_mode="basic"`` (text-only — no
|
||||
reason to pay vision LLM costs on Wikipedia plaintext).
|
||||
7. Persist a doc map at
|
||||
``<suite_data>/maps/frames_doc_map.jsonl`` with one row per question
|
||||
listing its ``document_ids`` (so the runner *could* scope retrieval
|
||||
if requested, though by default we don't — see ``runner.py``).
|
||||
|
||||
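The first line of the doc map is a settings header (written via
``settings_header_line``); every following line is one question row:

    {"qid": "Q000",
     "raw_index": 0,
     "n_wiki_urls": 3,
     "wiki_titles": ["President of the United States", "James Buchanan", ...],
     "document_ids": [123, 124, ...],
     "missing_urls": []}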
|
||||
|
||||
We resolve titles → SurfSense document_ids via the post-upload
|
||||
``DocumentStatus.title`` field. SurfSense's title is the uploaded
|
||||
filename (without extension), so we round-trip via
|
||||
``cache_filename_for_title`` to match.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.clients.documents import (
|
||||
DocumentProcessingFailed,
|
||||
DocumentProcessingTimeout,
|
||||
)
|
||||
from ....core.config import set_suite_state
|
||||
from ....core.ingest_settings import IngestSettings, settings_header_line
|
||||
from ....core.registry import RunContext
|
||||
from .dataset import (
|
||||
download_test_tsv,
|
||||
load_questions,
|
||||
write_questions_jsonl,
|
||||
)
|
||||
from .wiki_fetch import (
|
||||
WikiArticle,
|
||||
WikiFetcher,
|
||||
cache_filename_for_title,
|
||||
title_from_url,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class _IngestStats:
    # Summary of one ``run_ingest`` invocation; built at the end of the run
    # and only used for the final log line.

    # Questions kept after the optional ``max_questions`` cap.
    n_questions: int
    # Distinct Wikipedia URLs referenced by those questions.
    n_unique_urls: int
    # URLs for which the fetcher returned an article.
    n_fetched: int
    # URLs whose cache file already existed before fetching started.
    n_cached_hits: int
    # URLs for which the fetch raised or returned ``None``.
    n_missing: int
    # Count derived from the name -> document_id map returned by the upload
    # step (0 when the upload is skipped).
    n_uploaded: int
    # NOTE(review): ``run_ingest`` currently always reports 0 here.
    n_existing: int
    # Benchmark data directory for this run.
    bench_dir: Path
    # Location of the written doc map JSONL.
    map_path: Path
|
||||
|
||||
|
||||
async def _fetch_articles(
    fetcher: WikiFetcher,
    urls: list[str],
) -> tuple[dict[str, WikiArticle], list[str]]:
    """Fetch each URL serially (the WikiFetcher's rate-limiter serialises anyway).

    Returns ``(url -> WikiArticle, missing_urls)``. A URL is recorded as
    missing when ``fetcher.fetch`` raises (logged as a warning) or
    returns ``None``.

    Progress is logged every 25 URLs and once more after the last one,
    counting every processed URL regardless of whether it was fetched
    or missing, so the final line is always emitted.
    """
    fetched: dict[str, WikiArticle] = {}
    missing: list[str] = []
    n_total = len(urls)
    for i, url in enumerate(urls, start=1):
        article = None
        try:
            article = await fetcher.fetch(url)
        except Exception as exc:  # noqa: BLE001
            # Best-effort: one bad article must not abort the whole ingest.
            logger.warning("FRAMES wiki fetch %s failed: %s", url, exc)

        if article is None:
            missing.append(url)
        else:
            fetched[url] = article

        # Kept outside the success path so a missing 25th / last URL does
        # not swallow the progress line.
        if i % 25 == 0 or i == n_total:
            logger.info("  ... fetched %d / %d Wikipedia articles", i, n_total)
    return fetched, missing
|
||||
|
||||
|
||||
async def _upload_markdowns(
    ctx: RunContext,
    articles: list[WikiArticle],
    *,
    batch_size: int,
    settings: IngestSettings,
) -> dict[str, int]:
    """Upload deduplicated markdown files. Returns ``filename -> document_id``.

    Files are sent in slices of ``batch_size`` (values below 1 are
    clamped to 1). After each upload we wait for the newly created
    documents to finish processing — failures and timeouts are logged,
    not raised — and then query the status endpoint for both new and
    duplicate ids to learn their titles.

    SurfSense dedupes uploads on ``(filename, search_space_id)``, so
    re-running ingest after a crash is idempotent — duplicates land in
    ``duplicate_document_ids`` and we still recover their ids via the
    status endpoint.

    Each document is registered under its reported title and, when that
    title ends in ``.md``, under the title without the extension too.
    """
    if not articles:
        return {}

    # ``range`` raises on a zero step and yields nothing for a negative one;
    # clamp so a bad value can neither crash nor silently skip the upload.
    batch_size = max(1, batch_size)

    docs_client = ctx.documents_client()
    name_to_id: dict[str, int] = {}
    paths = [a.markdown_path for a in articles]
    for batch_start in range(0, len(paths), batch_size):
        batch = paths[batch_start : batch_start + batch_size]
        result = await docs_client.upload(
            files=batch,
            search_space_id=ctx.search_space_id,
            should_summarize=settings.should_summarize,
            use_vision_llm=settings.use_vision_llm,
            processing_mode=settings.processing_mode,
        )
        all_ids = list(result.document_ids) + list(result.duplicate_document_ids)
        if result.document_ids:
            # Only freshly created documents need the processing wait.
            try:
                await docs_client.wait_until_ready(
                    search_space_id=ctx.search_space_id,
                    document_ids=result.document_ids,
                    timeout_s=900.0,
                )
            except (DocumentProcessingFailed, DocumentProcessingTimeout) as exc:
                logger.warning("FRAMES batch processing issue: %s", exc)
        if all_ids:
            statuses = await docs_client.get_status(
                search_space_id=ctx.search_space_id,
                document_ids=all_ids,
            )
            for s in statuses:
                # SurfSense stores the uploaded filename as ``title`` (no extension).
                stem = Path(s.title).stem if s.title.endswith(".md") else s.title
                name_to_id[stem] = s.document_id
                name_to_id[s.title] = s.document_id
        logger.info(
            "FRAMES upload batch %d-%d: %d new, %d duplicate",
            batch_start, batch_start + len(batch),
            len(result.document_ids), len(result.duplicate_document_ids),
        )
    return name_to_id
|
||||
|
||||
|
||||
def _resolve_question_doc_ids(
    questions: list[Any],
    fetched: dict[str, WikiArticle],
    name_to_id: dict[str, int],
) -> list[dict[str, Any]]:
    """Build one doc-map row per question.

    For every URL a question references: if the article was fetched, its
    title is recorded and its document id is looked up by cache-filename
    stem, falling back to the markdown file name; otherwise the URL is
    listed under ``missing_urls``. Ids are kept unique per question, in
    first-seen order.
    """
    resolved_rows: list[dict[str, Any]] = []
    for question in questions:
        ids: list[int] = []
        found_titles: list[str] = []
        unresolved: list[str] = []
        for url in question.wiki_urls:
            article = fetched.get(url)
            if article is None:
                unresolved.append(url)
            else:
                found_titles.append(article.title)
                cache_stem = Path(cache_filename_for_title(article.title)).stem
                doc_id = name_to_id.get(cache_stem) or name_to_id.get(
                    article.markdown_path.name
                )
                if doc_id is None or doc_id in ids:
                    continue
                ids.append(doc_id)
        resolved_rows.append(
            {
                "qid": question.qid,
                "raw_index": question.raw_index,
                "n_wiki_urls": len(question.wiki_urls),
                "wiki_titles": found_titles,
                "document_ids": ids,
                "missing_urls": unresolved,
            }
        )
    return resolved_rows
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_ingest(
    ctx: RunContext,
    *,
    max_questions: int | None = None,
    upload_batch_size: int = 16,
    skip_upload: bool = False,
    fetch_rate_limit_rps: float = 2.0,
    settings: IngestSettings | None = None,
) -> None:
    """Ingest the FRAMES benchmark into the research suite.

    Pipeline: download + parse ``test.tsv`` → optional cap → write
    ``questions.jsonl`` → fetch the deduplicated Wikipedia URLs → upload
    the markdown (unless ``skip_upload``) → write the per-question doc
    map → record the map path in the suite state.

    Parameters
    ----------
    max_questions : int | None
        Cap on the number of FRAMES questions to materialise. ``None``
        (or a non-positive value) keeps every parsed question. Smoke
        runs should pass 5-10.
    upload_batch_size : int
        Markdown files per ``/documents/fileupload`` call. Larger
        batches reduce round-trip overhead; smaller batches recover
        faster from individual processing failures.
    skip_upload : bool
        Fetch + cache Wikipedia articles locally but don't push to
        SurfSense. Useful for debugging the fetcher in isolation. The
        doc map is still written, with empty ``document_ids``.
    fetch_rate_limit_rps : float
        Maximum requests-per-second to the Wikipedia API. Default 2.0
        is a polite ceiling; raise cautiously.
    settings : IngestSettings | None
        Override per-upload knobs. FRAMES defaults to text-only
        (no vision LLM, basic mode) — the corpus is plain wikitext.

    Raises
    ------
    RuntimeError
        If the downloaded TSV yields no parseable questions.
    """
    settings = settings or IngestSettings(
        use_vision_llm=False,
        processing_mode="basic",
        should_summarize=False,
    )
    bench_dir = ctx.benchmark_data_dir()
    wiki_cache = bench_dir / "wiki"
    wiki_cache.mkdir(parents=True, exist_ok=True)
    hf_cache = bench_dir / ".hf_cache"
    hf_cache.mkdir(parents=True, exist_ok=True)

    # 1. Download + parse questions.
    tsv_path = download_test_tsv(hf_cache)
    questions = load_questions(tsv_path)
    if not questions:
        raise RuntimeError(
            "FRAMES test.tsv contained no parseable rows; upstream may "
            "have changed schema."
        )
    logger.info("FRAMES: parsed %d questions from %s", len(questions), tsv_path.name)
    if max_questions is not None and max_questions > 0:
        questions = questions[:max_questions]
        logger.info("FRAMES: capped to first %d questions", len(questions))

    questions_jsonl = bench_dir / "questions.jsonl"
    write_questions_jsonl(questions, questions_jsonl)

    # 2. Build deduplicated URL set (preserving first-seen order).
    unique_urls = list(
        dict.fromkeys(url for q in questions for url in q.wiki_urls)
    )
    logger.info(
        "FRAMES: %d unique Wikipedia URLs across %d questions",
        len(unique_urls), len(questions),
    )

    # 3. Fetch (with cache). Cache hits are counted *before* fetching, since
    # the fetch itself populates the cache directory.
    fetcher = WikiFetcher(cache_dir=wiki_cache, rate_limit_rps=fetch_rate_limit_rps)
    n_cached = sum(
        1 for url in unique_urls
        if (wiki_cache / cache_filename_for_title(_safe_title(url))).exists()
    )
    fetched, missing_urls = await _fetch_articles(fetcher, unique_urls)
    logger.info(
        "FRAMES: fetched=%d, cache_hits=%d, missing=%d",
        len(fetched), n_cached, len(missing_urls),
    )

    # 4. Upload to SurfSense (deduped by filename).
    name_to_id: dict[str, int] = {}
    if skip_upload:
        logger.info("FRAMES: --skip-upload; skipping SurfSense ingestion")
    else:
        # Different URLs can resolve to the same cached file; upload it once.
        unique_articles = list({a.markdown_path: a for a in fetched.values()}.values())
        name_to_id = await _upload_markdowns(
            ctx,
            unique_articles,
            batch_size=upload_batch_size,
            settings=settings,
        )

    # 5. Persist per-question doc map.
    doc_rows = _resolve_question_doc_ids(questions, fetched, name_to_id)

    map_path = ctx.maps_dir() / "frames_doc_map.jsonl"
    with map_path.open("w", encoding="utf-8") as fh:
        fh.write(settings_header_line(settings) + "\n")
        for row in doc_rows:
            fh.write(json.dumps(row) + "\n")
    logger.info("Wrote FRAMES doc map to %s (%d rows)", map_path, len(doc_rows))

    # 6. Update suite state.
    new_state = ctx.suite_state
    new_state.ingestion_maps["frames"] = str(map_path)
    set_suite_state(ctx.config, ctx.suite, new_state)

    stats = _IngestStats(
        n_questions=len(questions),
        n_unique_urls=len(unique_urls),
        n_fetched=len(fetched),
        n_cached_hits=n_cached,
        n_missing=len(missing_urls),
        # ``name_to_id`` may hold two keys (title and stem) per document, so
        # count distinct ids rather than map entries.
        n_uploaded=len(set(name_to_id.values())),
        n_existing=0,
        bench_dir=bench_dir,
        map_path=map_path,
    )
    logger.info("FRAMES ingest done: %s", stats)
|
||||
|
||||
|
||||
def _safe_title(url: str) -> str:
    """Resolve ``url`` to a title for the pre-fetch cache probe.

    Delegates to ``title_from_url``; a ``ValueError`` from it is turned
    into the empty string instead of propagating.
    """
    try:
        title = title_from_url(url)
    except ValueError:
        title = ""
    return title
|
||||
|
||||
|
||||
__all__ = ["run_ingest"]
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
"""FRAMES prompt templates.
|
||||
|
||||
Two templates: one for the bare-LLM arm (no retrieval), one for
|
||||
SurfSense (the agent retrieves; we mostly just instruct it on
|
||||
output format). Both arms must use byte-identical *content* for the
|
||||
question itself so the head-to-head is fair — the wrappers diverge
|
||||
only in framing.
|
||||
|
||||
Format expectations (mirrors the FRAMES paper, section 4):
|
||||
|
||||
* Short factual answer — names, dates, numbers, ordinals
|
||||
* No extra explanation in the final line; we anchor on
|
||||
``Answer: <text>`` for deterministic extraction
|
||||
* Free-text reasoning is *allowed* before the final ``Answer:`` line —
|
||||
multi-hop questions often benefit from it. We just don't grade it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
# Instructions shared verbatim by both arms so that only the framing differs:
# reason freely, then finish with a single ``Answer: <short answer>`` line.
_BASE_INSTRUCTIONS = (
    "You are a careful question-answering assistant. The question may "
    "require combining facts from multiple sources, doing arithmetic, "
    "or reasoning about dates. Think step by step if needed, then give "
    "the final answer.\n\n"
    "Format your final line EXACTLY as:\n"
    "Answer: <short answer>\n\n"
    "The answer should be as short as possible — a name, a number, a "
    "date, a single phrase. Do not repeat the question. Do not include "
    "punctuation at the end unless it is part of the answer."
)


# No-retrieval arm: shared instructions followed by the question only.
_BARE_TEMPLATE = """\
{instructions}

Question: {question}
"""


# SurfSense arm: same instructions plus a paragraph telling the agent that a
# Wikipedia knowledge base is available and that it must retrieve from it.
_SURFSENSE_TEMPLATE = """\
{instructions}

You have access to a Wikipedia knowledge base via retrieval. Use it
to look up any facts you are not confident about. The corpus contains
the Wikipedia articles needed to answer this question, but you must
retrieve them yourself — they are not pre-selected.

Question: {question}
"""
|
||||
|
||||
|
||||
def build_bare_prompt(question: str) -> str:
    """Render the prompt for the baseline arm that answers without retrieval.

    The question is stripped of surrounding whitespace and substituted,
    together with the shared instructions, into ``_BARE_TEMPLATE``.
    """
    cleaned = question.strip()
    return _BARE_TEMPLATE.format(instructions=_BASE_INSTRUCTIONS, question=cleaned)
|
||||
|
||||
|
||||
def build_surfsense_prompt(question: str) -> str:
    """Render the prompt for the retrieval-augmented SurfSense arm.

    The question is stripped of surrounding whitespace and substituted,
    together with the shared instructions, into ``_SURFSENSE_TEMPLATE``.
    """
    cleaned = question.strip()
    return _SURFSENSE_TEMPLATE.format(
        instructions=_BASE_INSTRUCTIONS, question=cleaned
    )
|
||||
|
||||
|
||||
__all__ = ["build_bare_prompt", "build_surfsense_prompt"]
|
||||
|
|
@ -0,0 +1,686 @@
|
|||
"""FRAMES runner — Bare LLM (no retrieval) vs SurfSense (multi-hop RAG).
|
||||
|
||||
Two arms run paired on every question in the sample:
|
||||
|
||||
1. ``BareLlmArm`` — OpenRouter chat completion with the question only.
|
||||
Reproduces the published "naive prompting" baseline (40.8% on
|
||||
Gemini-Pro-1.5).
|
||||
2. ``SurfSenseArm`` — POST ``/api/v1/new_chat`` with **no**
|
||||
``mentioned_document_ids`` so the agent retrieves over the entire
|
||||
ingested Wikipedia corpus. This is the "multi-step retrieval &
|
||||
reasoning" cell in the FRAMES paper.
|
||||
|
||||
Open-ended grading: deterministic shortcut + optional LLM-as-judge
|
||||
(``--no-judge`` to disable). Cost / latency / token aggregates are
|
||||
collected per arm. Paired stats (McNemar, bootstrap CI) for the
|
||||
accuracy delta. Per-reasoning-type breakdown to surface where one
|
||||
arm beats the other (numerical vs temporal vs multi-constraint, ...).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ....core.arms import ArmRequest, ArmResult, BareLlmArm, SurfSenseArm
|
||||
from ....core.config import utc_iso_timestamp
|
||||
from ....core.ingest_settings import (
|
||||
IngestSettings,
|
||||
add_ingest_settings_args,
|
||||
format_ingest_settings_md,
|
||||
is_settings_header,
|
||||
)
|
||||
from ....core.metrics.comparison import (
|
||||
bootstrap_delta_ci,
|
||||
mcnemar_test,
|
||||
paired_aggregate,
|
||||
)
|
||||
from ....core.metrics.mc_accuracy import accuracy_with_wilson_ci
|
||||
from ....core.parse.freeform_answer import extract_freeform_answer
|
||||
from ....core.providers.openrouter_chat import OpenRouterChatProvider
|
||||
from ....core.registry import ReportSection, RunArtifact, RunContext
|
||||
from ....core.scenarios import format_scenario_md
|
||||
from .grader import GradeResult, JudgeConfig, LlmJudge, grade_many
|
||||
from .prompt import build_bare_prompt, build_surfsense_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Question shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class FramesRunnerQuestion:
    # One FRAMES question as seen by the runner: fields from
    # ``questions.jsonl`` joined with the matching doc-map row (if any).

    # Question id; the join key into the doc map.
    qid: str
    # ``raw_index`` from questions.jsonl; used as the sort key when loading.
    raw_index: int
    question: str
    gold_answer: str
    # Reasoning tags from questions.jsonl; matched by ``--reasoning``.
    reasoning_types: list[str]
    document_ids: list[int]  # subset of corpus relevant to this Q (may be empty)
    # Taken from the doc-map row; 0 when the question has no map row.
    n_wiki_urls: int
    # URLs the ingest step could not fetch, from the doc-map row.
    missing_urls: list[str]
|
||||
|
||||
|
||||
def _load_doc_map(map_path: Path) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
    """Read the doc-map JSONL written by ingest.

    Returns ``(rows_by_qid, settings)``. Blank lines are ignored. A line
    recognised by ``is_settings_header`` supplies ``settings`` (a copy of
    its ``"__settings__"`` value); every other line is stored under its
    stringified ``qid``. ``settings`` is ``{}`` if no header is present.
    """
    by_qid: dict[str, dict[str, Any]] = {}
    header: dict[str, Any] = {}
    with map_path.open("r", encoding="utf-8") as fh:
        for raw_line in fh:
            text = raw_line.strip()
            if text:
                record = json.loads(text)
                if is_settings_header(record):
                    header = dict(record["__settings__"])
                else:
                    by_qid[str(record["qid"])] = record
    return by_qid, header
|
||||
|
||||
|
||||
def _load_questions(
    questions_jsonl: Path,
    doc_map: dict[str, dict[str, Any]],
    *,
    sample_n: int | None,
    reasoning_filter: str | None,
) -> list[FramesRunnerQuestion]:
    """Load the questions to run, joined with their doc-map rows.

    Rows without a ``qid`` are skipped. When ``reasoning_filter`` is
    given, a question is kept only if the filter occurs — as a
    case-insensitive substring — in at least one of its reasoning tags,
    which is the behaviour the ``--reasoning`` help text promises.
    The surviving questions are sorted by ``raw_index`` and, if
    ``sample_n`` is a positive integer, truncated to the first
    ``sample_n``.
    """
    # Normalise here as well so direct callers need not pre-lowercase.
    wanted = (reasoning_filter or "").strip().lower()

    out: list[FramesRunnerQuestion] = []
    with questions_jsonl.open("r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            qid = str(row.get("qid") or "").strip()
            if not qid:
                continue
            reasoning = list(row.get("reasoning_types") or [])
            if wanted and not any(wanted in str(tag).lower() for tag in reasoning):
                continue
            # Questions absent from the doc map still run, with empty ids.
            map_row = doc_map.get(qid, {})
            out.append(FramesRunnerQuestion(
                qid=qid,
                raw_index=int(row.get("raw_index") or 0),
                question=str(row.get("question") or "").strip(),
                gold_answer=str(row.get("gold_answer") or "").strip(),
                reasoning_types=reasoning,
                document_ids=list(map_row.get("document_ids") or []),
                n_wiki_urls=int(map_row.get("n_wiki_urls") or 0),
                missing_urls=list(map_row.get("missing_urls") or []),
            ))
    out.sort(key=lambda q: q.raw_index)
    if sample_n is not None and sample_n > 0:
        out = out[:sample_n]
    return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bounded concurrency helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _gather_with_limit(coros: Iterable, *, concurrency: int) -> list[Any]:
|
||||
sem = asyncio.Semaphore(max(1, concurrency))
|
||||
|
||||
async def _wrap(coro):
|
||||
async with sem:
|
||||
return await coro
|
||||
|
||||
return await asyncio.gather(*(_wrap(c) for c in coros))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Human-readable summary exposed as ``FramesBenchmark.description``.
_DESCRIPTION = (
    "FRAMES (824 multi-hop Wikipedia questions, 5 reasoning types) — "
    "Bare LLM (no retrieval) vs SurfSense (multi-step RAG over the "
    "Wikipedia corpus). Tests cross-document retrieval + reasoning."
)


# Upload settings used when the CLI does not override them: text-only
# processing with no vision LLM and no summarisation. Passed as the defaults
# to ``add_ingest_settings_args`` and ``IngestSettings.merge``.
_DEFAULT_INGEST_SETTINGS = IngestSettings(
    use_vision_llm=False,
    processing_mode="basic",
    should_summarize=False,
)
|
||||
|
||||
|
||||
class FramesBenchmark:
|
||||
"""Multi-hop Wikipedia RAG vs naive prompting."""
|
||||
|
||||
suite: str = "research"
|
||||
name: str = "frames"
|
||||
headline: bool = True
|
||||
description: str = _DESCRIPTION
|
||||
|
||||
def add_run_args(self, parser: argparse.ArgumentParser) -> None:
    """Register the FRAMES command-line options on ``parser``.

    The same parser carries both run-time options (sampling, filtering,
    concurrency, judge configuration) and the options marked
    "(ingest only)", followed by the shared ingest-settings flags added
    by ``add_ingest_settings_args`` with the FRAMES defaults.
    """
    # --- run options ---
    parser.add_argument(
        "--n", dest="sample_n", type=int, default=None,
        help="Run only the first N questions after filters (default: all 824).",
    )
    parser.add_argument(
        "--reasoning",
        dest="reasoning_filter",
        default=None,
        help=(
            "Filter to questions tagged with this reasoning type "
            "(e.g. 'numerical reasoning', 'temporal reasoning'). "
            "Case-insensitive substring against the upstream tags."
        ),
    )
    parser.add_argument(
        "--concurrency", type=int, default=4,
        help="Parallel question workers per arm.",
    )
    parser.add_argument(
        "--scope-mentions", dest="scope_mentions", action="store_true",
        help=(
            "SurfSense arm: scope retrieval to the per-question "
            "document_ids (oracle-retrieval upper bound). Default "
            "is full-corpus retrieval (the realistic FRAMES setting)."
        ),
    )
    parser.add_argument(
        "--max-output-tokens", type=int, default=512,
        help="Cap on completion length for both arms.",
    )
    # --- grading options ---
    parser.add_argument(
        "--no-judge", dest="no_judge", action="store_true",
        help=(
            "Disable LLM-as-judge fallback grading; use only the "
            "deterministic grader (faster but more pessimistic)."
        ),
    )
    parser.add_argument(
        "--judge-model",
        dest="judge_model",
        default="anthropic/claude-sonnet-4.5",
        help="OpenRouter slug for the LLM judge (default: claude-sonnet-4.5).",
    )
    parser.add_argument(
        "--judge-concurrency",
        dest="judge_concurrency",
        type=int,
        default=4,
        help="Parallel judge calls (default: 4).",
    )
    # Ingest-only knobs.
    parser.add_argument(
        "--max-questions", dest="max_questions", type=int, default=None,
        help="(ingest only) cap on number of questions to materialise + ingest.",
    )
    parser.add_argument(
        "--upload-batch-size", dest="upload_batch_size", type=int, default=16,
        help="(ingest only) markdown files per fileupload call.",
    )
    parser.add_argument(
        "--skip-upload", dest="skip_upload", action="store_true",
        help="(ingest only) cache wiki articles locally but don't push to SurfSense.",
    )
    parser.add_argument(
        "--fetch-rps", dest="fetch_rate_limit_rps", type=float, default=2.0,
        help="(ingest only) max requests/second to the Wikipedia API.",
    )
    add_ingest_settings_args(parser, defaults=_DEFAULT_INGEST_SETTINGS)
|
||||
|
||||
async def ingest(self, ctx: RunContext, **opts: Any) -> None:
    """Run FRAMES ingestion with options taken from the parsed CLI ``opts``.

    Ingest settings are the FRAMES defaults merged with ``opts``. Missing
    or falsy ``upload_batch_size`` / ``fetch_rate_limit_rps`` values fall
    back to 16 and 2.0 respectively.
    """
    # Imported lazily, at call time, as in the original implementation.
    from .ingest import run_ingest

    merged_settings = IngestSettings.merge(_DEFAULT_INGEST_SETTINGS, opts)
    question_cap = opts.get("max_questions")
    batch_size = int(opts.get("upload_batch_size") or 16)
    skip = bool(opts.get("skip_upload", False))
    rps = float(opts.get("fetch_rate_limit_rps") or 2.0)
    await run_ingest(
        ctx,
        max_questions=question_cap,
        upload_batch_size=batch_size,
        skip_upload=skip,
        fetch_rate_limit_rps=rps,
        settings=merged_settings,
    )
|
||||
|
||||
async def run(self, ctx: RunContext, **opts: Any) -> RunArtifact:
    """Answer every scheduled FRAMES question with both arms, grade, and persist.

    Steps, in order:

    1. Normalise CLI options (falsy values fall back to the defaults below).
    2. Load the ingested questions and the doc map from disk.
    3. Build the bare-LLM arm, the SurfSense arm and (unless ``no_judge``)
       the LLM judge.
    4. Run both arms over all questions, then grade each arm's answers.
    5. Write ``raw.jsonl`` (two lines per question: bare, then SurfSense)
       and ``run_artifact.json`` into the run directory.

    Args:
        ctx: Run context supplying directories, config, models and clients.
        **opts: Parsed CLI options (``sample_n``, ``reasoning_filter``,
            ``concurrency``, ``scope_mentions``, ``max_output_tokens``,
            ``no_judge``, ``judge_model``, ``judge_concurrency``).

    Returns:
        The ``RunArtifact`` describing this run (metrics + run settings).

    Raises:
        RuntimeError: If the ingest outputs are missing, if no question
            survives the filters, or if ``OPENROUTER_API_KEY`` is unset.
    """
    sample_n = opts.get("sample_n")
    reasoning_filter = opts.get("reasoning_filter")
    if reasoning_filter:
        # A whitespace-only filter collapses to None (i.e. no filter).
        reasoning_filter = reasoning_filter.strip().lower() or None
    concurrency = int(opts.get("concurrency") or 4)
    scope_mentions = bool(opts.get("scope_mentions"))
    max_output_tokens = int(opts.get("max_output_tokens") or 512)
    no_judge = bool(opts.get("no_judge"))
    judge_model = str(opts.get("judge_model") or "anthropic/claude-sonnet-4.5")
    judge_concurrency = int(opts.get("judge_concurrency") or 4)

    # Both files are expected to exist before a run; the error below points
    # the user at the ingest command when either is missing.
    bench_dir = ctx.benchmark_data_dir()
    questions_jsonl = bench_dir / "questions.jsonl"
    map_path = ctx.maps_dir() / "frames_doc_map.jsonl"
    if not questions_jsonl.exists() or not map_path.exists():
        raise RuntimeError(
            "FRAMES not ingested for this suite. Run "
            "`python -m surfsense_evals ingest research frames` first."
        )

    doc_map, ingest_settings = _load_doc_map(map_path)
    questions = _load_questions(
        questions_jsonl, doc_map,
        sample_n=sample_n,
        reasoning_filter=reasoning_filter,
    )
    if not questions:
        raise RuntimeError(
            "No FRAMES questions matched the filters; broaden --reasoning/--n."
        )
    logger.info("FRAMES: scheduled %d questions", len(questions))

    # The same key is used by the bare-LLM provider and, below, by the judge.
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY env var is required for the bare-LLM arm."
        )

    bare_provider = OpenRouterChatProvider(
        api_key=api_key,
        base_url=ctx.config.openrouter_base_url,
        model=ctx.native_arm_model,
    )
    bare_arm = BareLlmArm(
        provider=bare_provider,
        max_output_tokens=max_output_tokens,
    )
    surf_arm = SurfSenseArm(
        client=ctx.new_chat_client(),
        search_space_id=ctx.search_space_id,
        ephemeral_threads=True,
    )

    # judge stays None when --no-judge is set; it is passed as-is to grading.
    judge: LlmJudge | None = None
    if not no_judge:
        judge = LlmJudge(config=JudgeConfig(
            api_key=api_key,
            model=judge_model,
            base_url=ctx.config.openrouter_base_url,
            concurrency=judge_concurrency,
        ))

    run_timestamp = utc_iso_timestamp()
    run_dir = ctx.runs_dir(run_timestamp=run_timestamp)
    raw_path = run_dir / "raw.jsonl"

    async def _bare_one(q: FramesRunnerQuestion) -> ArmResult:
        return await bare_arm.answer(_make_bare_request(q, max_output_tokens))

    async def _surf_one(q: FramesRunnerQuestion) -> ArmResult:
        return await surf_arm.answer(
            _make_surfsense_request(q, scope_mentions=scope_mentions)
        )

    # Both arms are awaited together; each gets its own `concurrency` limit.
    # NOTE(review): presumably up to 2 * concurrency requests can be in
    # flight at once — confirm against `_gather_with_limit`.
    bare_results, surf_results = await asyncio.gather(
        _gather_with_limit((_bare_one(q) for q in questions), concurrency=concurrency),
        _gather_with_limit((_surf_one(q) for q in questions), concurrency=concurrency),
    )

    bare_grades = await _grade_results(questions, bare_results, judge=judge)
    surf_grades = await _grade_results(questions, surf_results, judge=judge)

    # One JSON line per (question, arm): the bare line first, then SurfSense.
    # Both lines share the same question-level `meta` fields.
    with raw_path.open("w", encoding="utf-8") as fh:
        for q, b_res, s_res, b_g, s_g in zip(
            questions, bare_results, surf_results, bare_grades, surf_grades, strict=False
        ):
            meta = {
                "qid": q.qid,
                "raw_index": q.raw_index,
                "reasoning_types": q.reasoning_types,
                "n_wiki_urls": q.n_wiki_urls,
                "n_resolved_doc_ids": len(q.document_ids),
                "n_missing_urls": len(q.missing_urls),
                "gold": q.gold_answer,
            }
            fh.write(json.dumps({
                **meta,
                **b_res.to_jsonl(),
                "graded": b_g.to_dict(),
            }) + "\n")
            fh.write(json.dumps({
                **meta,
                **s_res.to_jsonl(),
                "graded": s_g.to_dict(),
            }) + "\n")

    metrics = _compute_metrics(questions, bare_results, surf_results, bare_grades, surf_grades)
    artifact = RunArtifact(
        suite=self.suite,
        benchmark=self.name,
        run_timestamp=run_timestamp,
        raw_path=raw_path,
        metrics=metrics,
        extra={
            "n_questions": len(questions),
            "concurrency": concurrency,
            "reasoning_filter": reasoning_filter,
            "scope_mentions": scope_mentions,
            "no_judge": no_judge,
            # Recorded as None when the judge is disabled.
            "judge_model": judge_model if not no_judge else None,
            "scenario": ctx.scenario,
            "provider_model": ctx.provider_model,
            "native_arm_model": ctx.native_arm_model,
            "vision_provider_model": ctx.vision_provider_model,
            "agent_llm_id": ctx.agent_llm_id,
            "ingest_settings": ingest_settings,
            "bare_arm_label": "bare_llm",
        },
    )

    # The manifest stores raw_path as a file name relative to the run dir.
    manifest_path = run_dir / "run_artifact.json"
    manifest_path.write_text(
        json.dumps({
            "suite": self.suite,
            "benchmark": self.name,
            "raw_path": "raw.jsonl",
            "metrics": metrics,
            "extra": artifact.extra,
        }, indent=2, sort_keys=True) + "\n",
        encoding="utf-8",
    )
    return artifact
|
||||
|
||||
def report_section(self, artifacts: list[RunArtifact]) -> ReportSection:
    """Render the markdown/JSON report section for the most recent run.

    Args:
        artifacts: Run artifacts for this benchmark; may be empty.

    Returns:
        A headline ``ReportSection``. With no artifacts the body is a
        placeholder; otherwise it summarises the artifact with the greatest
        ``run_timestamp`` and carries its metrics as ``body_json``.
    """
    section_title = "FRAMES — Bare LLM vs SurfSense (multi-hop Wikipedia RAG)"
    if not artifacts:
        return ReportSection(
            title=section_title,
            headline=True,
            body_md="(no run artifacts found)",
            body_json={},
        )

    newest = max(artifacts, key=lambda art: art.run_timestamp)
    m = newest.metrics
    extra = newest.extra
    bare = m.get("bare", {})
    surf = m.get("surfsense", {})
    delta = m.get("delta", {})
    per_reasoning = m.get("per_reasoning", {})

    # Header bullet: sample size and the run settings that shape the numbers.
    sample_size = extra.get('n_questions', '?')
    filter_label = extra.get('reasoning_filter') or 'none'
    scoped = extra.get('scope_mentions', False)
    judge_label = extra.get('judge_model') or 'deterministic-only'
    bare_model = extra.get('native_arm_model') or extra.get('provider_model', '?')
    surf_model = extra.get('provider_model', '?')

    body_lines: list[str] = [
        f"- Sample size: {sample_size} questions "
        f"(reasoning filter: `{filter_label}`, "
        f"scope-mentions: `{scoped}`, "
        f"judge: `{judge_label}`).",
        format_scenario_md(extra),
        format_ingest_settings_md(extra.get("ingest_settings")),
        "- Bare LLM arm (OpenRouter chat, no retrieval, "
        f"`{bare_model}`):",
        _arm_summary_lines(bare, indent=" "),
        "- SurfSense arm (`POST /api/v1/new_chat`, multi-step RAG, "
        f"`{surf_model}`):",
        _arm_summary_lines(surf, indent=" "),
        "- Delta (paired):",
    ]

    # Paired deltas: accuracy, bootstrap CI, cost and latency.
    accuracy_delta = _pp(delta.get('accuracy_pp'))
    mcnemar_p = _fmt(delta.get('mcnemar_p_value'), 4)
    mcnemar_method = delta.get('mcnemar_method')
    body_lines.append(
        f" - Accuracy: SurfSense {accuracy_delta} pp "
        f"(McNemar p={mcnemar_p}, "
        f"method={mcnemar_method})"
    )

    ci_low = _pp(delta.get('bootstrap_ci_low'))
    ci_high = _pp(delta.get('bootstrap_ci_high'))
    body_lines.append(
        f" - Bootstrap 95% CI on accuracy delta: "
        f"[{ci_low}pp, {ci_high}pp]"
    )

    bare_cost = _dollars(bare.get('cost_micros_mean'))
    surf_cost = _dollars(surf.get('cost_micros_mean'))
    cost_change = _pct_change(delta.get('cost_micros_pct'))
    body_lines.append(
        f" - Cost / question: bare ${bare_cost}, "
        f"surfsense ${surf_cost} "
        f"(SurfSense delta {cost_change})"
    )

    bare_p50 = _ms_to_s(bare.get('latency_ms_median'))
    surf_p50 = _ms_to_s(surf.get('latency_ms_median'))
    latency_change = _pct_change(delta.get('latency_ms_pct'))
    body_lines.append(
        f" - Latency p50: bare {bare_p50}, "
        f"surfsense {surf_p50} "
        f"(SurfSense delta {latency_change})"
    )

    # Optional per-tag breakdown, sorted by tag name.
    if per_reasoning:
        body_lines.append("- Per-reasoning-type split (accuracy delta in pp):")
        for tag, vals in sorted(per_reasoning.items()):
            tag_delta = _pp(vals.get('delta_accuracy_pp'))
            tag_n = vals.get('n')
            bare_pct = vals.get('bare_accuracy', 0)*100
            surf_pct = vals.get('surfsense_accuracy', 0)*100
            body_lines.append(
                f" - {tag}: SurfSense {tag_delta} pp "
                f"(n={tag_n}, bare acc={bare_pct:.1f}%, "
                f"surf acc={surf_pct:.1f}%)"
            )

    return ReportSection(
        title=section_title,
        headline=True,
        body_md="\n".join(body_lines),
        body_json=m,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-question helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bare_request(q: FramesRunnerQuestion, max_tokens: int) -> ArmRequest:
    """Build the bare-LLM arm request for one question.

    The prompt comes from ``build_bare_prompt`` and the token cap is passed
    through the request options under ``"max_tokens"``.
    """
    bare_prompt = build_bare_prompt(q.question)
    request_options = {"max_tokens": max_tokens}
    return ArmRequest(question_id=q.qid, prompt=bare_prompt, options=request_options)
|
||||
|
||||
|
||||
def _make_surfsense_request(q: FramesRunnerQuestion, *, scope_mentions: bool) -> ArmRequest:
    """Build the SurfSense arm request for one question.

    When ``scope_mentions`` is set and the question has resolved document
    ids, a copy of those ids is attached as ``mentioned_document_ids``;
    otherwise that field is ``None``.
    """
    attach_mentions = scope_mentions and q.document_ids
    scoped_ids: list[int] | None = list(q.document_ids) if attach_mentions else None
    surf_prompt = build_surfsense_prompt(q.question)
    return ArmRequest(
        question_id=q.qid,
        prompt=surf_prompt,
        mentioned_document_ids=scoped_ids,
    )
|
||||
|
||||
|
||||
async def _grade_results(
    questions: list[FramesRunnerQuestion],
    results: list[ArmResult],
    *,
    judge: LlmJudge | None,
) -> list[GradeResult]:
    """Grade one arm's answers against the gold answers.

    Each question is paired positionally with its result; the prediction is
    extracted from the result's raw text (an empty string when there is
    none) and the ``(qid, question, gold, prediction)`` rows are handed to
    ``grade_many`` together with the optional judge.
    """
    graded_rows: list[tuple[str, str, str, str]] = [
        (
            question.qid,
            question.question,
            question.gold_answer,
            extract_freeform_answer(result.raw_text or ""),
        )
        for question, result in zip(questions, results, strict=False)
    ]
    return await grade_many(rows=graded_rows, judge=judge)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metrics aggregation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_metrics(
    questions: list[FramesRunnerQuestion],
    bare_results: list[ArmResult],
    surf_results: list[ArmResult],
    bare_grades: list[GradeResult],
    surf_grades: list[GradeResult],
) -> dict[str, Any]:
    """Aggregate per-question results into the metrics dict for a run.

    Returns a dict with five sections: ``bare`` and ``surfsense`` (accuracy
    with Wilson CI plus cost/latency aggregates; token means for the bare
    arm only), ``delta`` (paired accuracy difference, McNemar test,
    bootstrap CI, cost/latency percentage change), ``per_reasoning``
    (accuracy split per reasoning tag) and ``grader_methods`` (how many
    grades each grading method produced, per arm).
    """
    bare_hits = [grade.correct for grade in bare_grades]
    surf_hits = [grade.correct for grade in surf_grades]

    bare_costs = [float(res.cost_micros) for res in bare_results]
    surf_costs = [float(res.cost_micros) for res in surf_results]
    bare_latencies = [float(res.latency_ms) for res in bare_results]
    surf_latencies = [float(res.latency_ms) for res in surf_results]
    bare_tokens_in = [float(res.input_tokens) for res in bare_results]
    bare_tokens_out = [float(res.output_tokens) for res in bare_results]

    bare_acc = accuracy_with_wilson_ci(sum(bare_hits), len(bare_hits))
    surf_acc = accuracy_with_wilson_ci(sum(surf_hits), len(surf_hits))
    mc = mcnemar_test(bare_hits, surf_hits)
    boot = bootstrap_delta_ci(bare_hits, surf_hits, n_resamples=2000)

    bare_cost_agg = paired_aggregate(bare_costs)
    surf_cost_agg = paired_aggregate(surf_costs)
    bare_latency_agg = paired_aggregate(bare_latencies)
    surf_latency_agg = paired_aggregate(surf_latencies)
    # Cost change compares means; latency change compares medians.
    cost_pct = _safe_pct(surf_cost_agg.mean, bare_cost_agg.mean)
    latency_pct = _safe_pct(surf_latency_agg.median, bare_latency_agg.median)

    # A question carrying several reasoning tags is counted once under each
    # of them, so the per-tag ``n`` values do not add up to len(questions).
    pairs_by_tag: dict[str, list[tuple[bool, bool]]] = {}
    for question, bare_ok, surf_ok in zip(questions, bare_hits, surf_hits, strict=False):
        for tag in question.reasoning_types or ["(untagged)"]:
            pairs_by_tag.setdefault(tag, []).append((bare_ok, surf_ok))

    per_reasoning: dict[str, dict[str, Any]] = {}
    for tag, tagged in pairs_by_tag.items():
        n_tagged = len(tagged)
        bare_n = sum([pair[0] for pair in tagged])
        surf_n = sum([pair[1] for pair in tagged])
        if tagged:
            tag_bare_acc = bare_n / n_tagged
            tag_surf_acc = surf_n / n_tagged
            tag_delta_pp = 100.0 * (surf_n - bare_n) / n_tagged
        else:
            tag_bare_acc = tag_surf_acc = tag_delta_pp = 0.0
        per_reasoning[tag] = {
            "n": n_tagged,
            "bare_accuracy": tag_bare_acc,
            "surfsense_accuracy": tag_surf_acc,
            "delta_accuracy_pp": tag_delta_pp,
        }

    grader_methods = {
        "bare": _count_methods(bare_grades),
        "surfsense": _count_methods(surf_grades),
    }

    if bare_tokens_in:
        mean_tokens_in = sum(bare_tokens_in) / len(bare_tokens_in)
    else:
        mean_tokens_in = 0.0
    if bare_tokens_out:
        mean_tokens_out = sum(bare_tokens_out) / len(bare_tokens_out)
    else:
        mean_tokens_out = 0.0

    bare_section = {
        **bare_acc.to_dict(),
        "cost_micros_mean": bare_cost_agg.mean,
        "cost_micros_median": bare_cost_agg.median,
        "latency_ms_mean": bare_latency_agg.mean,
        "latency_ms_median": bare_latency_agg.median,
        "latency_ms_p95": bare_latency_agg.p95,
        "input_tokens_mean": mean_tokens_in,
        "output_tokens_mean": mean_tokens_out,
    }
    surf_section = {
        **surf_acc.to_dict(),
        "cost_micros_mean": surf_cost_agg.mean,
        "cost_micros_median": surf_cost_agg.median,
        "latency_ms_mean": surf_latency_agg.mean,
        "latency_ms_median": surf_latency_agg.median,
        "latency_ms_p95": surf_latency_agg.p95,
    }
    delta_section = {
        "accuracy_pp": 100.0 * (surf_acc.accuracy - bare_acc.accuracy),
        "mcnemar_p_value": mc.p_value,
        "mcnemar_method": mc.method,
        "mcnemar_b_bare_only": mc.b,
        "mcnemar_c_surfsense_only": mc.c,
        "bootstrap_ci_low": 100.0 * boot.ci_low,
        "bootstrap_ci_high": 100.0 * boot.ci_high,
        "cost_micros_pct": cost_pct,
        "latency_ms_pct": latency_pct,
    }
    return {
        "bare": bare_section,
        "surfsense": surf_section,
        "delta": delta_section,
        "per_reasoning": per_reasoning,
        "grader_methods": grader_methods,
    }
|
||||
|
||||
|
||||
def _count_methods(grades: list[GradeResult]) -> dict[str, int]:
|
||||
out: dict[str, int] = {}
|
||||
for g in grades:
|
||||
out[g.method] = out.get(g.method, 0) + 1
|
||||
return out
|
||||
|
||||
|
||||
def _safe_pct(numerator: float, denominator: float) -> float | None:
|
||||
if denominator == 0:
|
||||
return None
|
||||
return 100.0 * (numerator - denominator) / denominator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tiny formatting helpers used by report_section
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _arm_summary_lines(d: dict[str, Any], *, indent: str) -> str:
    """Format one arm's metrics as indented markdown bullets.

    Returns ``"<indent>(no data)"`` for an empty dict. Otherwise emits
    accuracy (with Wilson CI), cost and latency bullets, plus a token
    bullet when the dict contains ``input_tokens_mean``.
    """
    if not d:
        return f"{indent}(no data)"

    accuracy_pct = d.get("accuracy", 0.0) * 100
    ci_low_pct = d.get("ci_low", 0.0) * 100
    ci_high_pct = d.get("ci_high", 0.0) * 100
    mean_cost = _dollars(d.get('cost_micros_mean'))
    median_cost = _dollars(d.get('cost_micros_median'))
    latency_p50 = _ms_to_s(d.get('latency_ms_median'))
    latency_p95 = _ms_to_s(d.get('latency_ms_p95'))

    bullets = [
        f"{indent}- Accuracy: {accuracy_pct:.1f}% (Wilson 95% CI: {ci_low_pct:.1f}% – {ci_high_pct:.1f}%)",
        f"{indent}- Cost / question: ${mean_cost} (mean), "
        f"${median_cost} (median)",
        f"{indent}- Latency: p50 {latency_p50}, "
        f"p95 {latency_p95}",
    ]
    # Only the bare arm records token means, so this bullet is optional.
    if "input_tokens_mean" in d:
        tokens_in = d.get('input_tokens_mean', 0)
        tokens_out = d.get('output_tokens_mean', 0)
        bullets.append(
            f"{indent}- Mean tokens / question: in {tokens_in:.0f}, "
            f"out {tokens_out:.0f}"
        )
    return "\n".join(bullets)
|
||||
|
||||
|
||||
def _dollars(micros: Any) -> str:
|
||||
if micros is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{(float(micros) / 1_000_000):.4f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _ms_to_s(ms: Any) -> str:
|
||||
if ms is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(ms) / 1000:.1f}s"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _pp(value: Any) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):+.1f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _pct_change(value: Any) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):+.0f}%"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
def _fmt(value: Any, ndigits: int) -> str:
|
||||
if value is None:
|
||||
return "?"
|
||||
try:
|
||||
return f"{float(value):.{ndigits}f}"
|
||||
except (TypeError, ValueError):
|
||||
return "?"
|
||||
|
||||
|
||||
# Names exported by ``from <this module> import *``.
__all__ = ["FramesBenchmark", "FramesRunnerQuestion"]
|
||||
|
|
@ -0,0 +1,241 @@
|
|||
"""Wikipedia article fetcher → plain-text markdown, with disk cache.
|
||||
|
||||
We hit the MediaWiki action API for *plain text* extracts:
|
||||
|
||||
GET https://en.wikipedia.org/w/api.php
|
||||
?action=query&prop=extracts&explaintext=true
|
||||
&redirects=1&titles=<Title>&format=json&formatversion=2
|
||||
|
||||
This avoids HTML→markdown conversion (and its many edge cases). The
|
||||
``explaintext=true`` mode strips infoboxes / templates / wikilinks
|
||||
and returns clean section-headered prose, which is exactly what we
|
||||
want SurfSense to chunk + embed. We prepend ``# <Title>\n\n`` so the
|
||||
markdown has a visible H1 (helps SurfSense's chunker preserve doc
|
||||
identity at the top of the first chunk).
|
||||
|
||||
Caching: every fetched article lands in
``<bench_dir>/wiki/<sanitised-title>.md`` and is reused on subsequent
runs. The cache filename is derived from the URL-decoded title, so
``Charlotte_Bront%C3%AB`` and ``Charlotte_Brontë`` share one cache
entry regardless of percent-encoding. Characters outside
``A-Za-z0-9._-`` are replaced with ``_`` in the filename; title casing
is not normalised.
|
||||
|
||||
Politeness: 2 RPS rate limit + a descriptive User-Agent (Wikimedia
|
||||
asks for one). We don't parallelise above 2 RPS — this is a courtesy
|
||||
to Wikipedia and only ~300 articles for the n=100 sample.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Endpoint of the English Wikipedia MediaWiki action API.
WIKI_API = "https://en.wikipedia.org/w/api.php"
# Descriptive User-Agent sent with every API request (Wikimedia asks for one).
USER_AGENT = (
    "SurfSense-Evals/0.1 (https://github.com/MODSetter/SurfSense; "
    "research-benchmark fetch; respects 2 RPS rate limit)"
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WikiArticle:
    """One fetched article + metadata.

    Immutable record returned by ``WikiFetcher.fetch``; the article body
    itself stays on disk at ``markdown_path``.
    """

    title: str  # article title; the canonical MW title (post-redirect) on a fresh fetch
    source_url: str  # the URL we were asked to fetch
    markdown_path: Path  # where the cached body lives on disk
    n_chars: int  # size of the cached markdown (H1 line included)
    redirected_from: str | None = None  # requested title when MW returned a different one
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Title <-> URL helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Matches ``/wiki/<title>``; the title may contain slashes but no ``?``/``#``.
_WIKI_PATH_RE = re.compile(r"^/wiki/(?P<title>[^?#]+)$")


def title_from_url(url: str) -> str:
    """Pull the page title out of a wiki URL.

    ``https://en.wikipedia.org/wiki/Charlotte_Bront%C3%AB`` → ``Charlotte Brontë``.
    Spaces are preserved (the API accepts spaces and underscores
    interchangeably; we use spaces to keep cache filenames human-readable).

    Args:
        url: Absolute Wikipedia URL, or a host-less ``/wiki/<Title>`` path.

    Returns:
        The percent-decoded title with underscores turned into spaces and
        surrounding whitespace removed.

    Raises:
        ValueError: If the URL has a host that is not ``wikipedia.org`` or
            one of its subdomains, if the path is not ``/wiki/<Title>``, or
            if the decoded title is empty.
    """
    parsed = urllib.parse.urlparse(url)
    if parsed.netloc:
        # Compare the real hostname (no port / userinfo, lower-cased) against
        # the domain suffix. A plain substring test would also accept hosts
        # such as ``wikipedia.org.evil.example`` or ``notwikipedia.org``.
        host = (parsed.hostname or "").lower()
        if host != "wikipedia.org" and not host.endswith(".wikipedia.org"):
            raise ValueError(f"Not a Wikipedia URL: {url!r}")
    match = _WIKI_PATH_RE.match(parsed.path)
    if not match:
        raise ValueError(f"Unrecognised wiki path: {parsed.path!r}")
    raw_title = urllib.parse.unquote(match.group("title"))
    # MW treats underscores and spaces as equivalent; spaces are friendlier.
    title = raw_title.replace("_", " ").strip()
    if not title:
        # e.g. ``/wiki/_`` — there is nothing to look up.
        raise ValueError(f"Empty wiki title in URL: {url!r}")
    return title
|
||||
|
||||
|
||||
# Anything outside ASCII letters, digits, ``.``, ``_``, ``-`` and space.
_FILENAME_SAFE = re.compile(r"[^A-Za-z0-9._\- ]")


def cache_filename_for_title(title: str) -> str:
    """Turn an article title into a filesystem-safe ``.md`` filename.

    Every character that is not an ASCII letter, digit, ``.``, ``_``,
    ``-`` or space becomes ``_`` (so accented and other non-ASCII letters
    are replaced too); the result is stripped and remaining spaces become
    ``_``. Distinct titles can therefore map to the same filename; that
    collision risk is accepted in exchange for readable names.
    """
    stem = _FILENAME_SAFE.sub("_", title).strip()
    return stem.replace(" ", "_") + ".md"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async fetcher with rate limiting + retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WikiFetcher:
    """Polite fetch + disk cache + redirect handling.

    Articles are fetched as plain-text extracts from the MediaWiki API,
    written to ``cache_dir`` as markdown with an ``# <Title>`` first line,
    and served from that cache on later calls. Requests are spaced out by a
    shared rate limiter and retried with exponential backoff.
    """

    def __init__(
        self,
        *,
        cache_dir: Path,
        rate_limit_rps: float = 2.0,
        timeout_s: float = 30.0,
        max_retries: int = 3,
    ) -> None:
        """Create the cache directory and set up rate limiting.

        Args:
            cache_dir: Directory for cached ``.md`` files (created if missing).
            rate_limit_rps: Maximum requests per second; floored at 0.1.
            timeout_s: Overall HTTP timeout in seconds (connect timeout is 10s).
            max_retries: Number of attempts per article; at least one
                attempt is always made.
        """
        self._cache_dir = Path(cache_dir)
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        # Floor the rate so a zero or negative value cannot divide by zero.
        self._min_interval = 1.0 / max(rate_limit_rps, 0.1)
        self._last_request_at = 0.0
        self._rate_lock = asyncio.Lock()
        self._timeout = httpx.Timeout(timeout_s, connect=10.0)
        self._max_retries = max_retries

    async def _throttle(self) -> None:
        """Wait until at least ``_min_interval`` has passed since the last request."""
        async with self._rate_lock:
            # Always called from a coroutine, so the running loop is the
            # right clock source.
            loop = asyncio.get_running_loop()
            wait = self._last_request_at + self._min_interval - loop.time()
            if wait > 0:
                await asyncio.sleep(wait)
            self._last_request_at = loop.time()

    @staticmethod
    def _title_from_markdown(markdown: str, *, fallback: str) -> str:
        """Return the title stored in the ``# <Title>`` first line, else ``fallback``."""
        first_line = markdown.split("\n", 1)[0]
        if first_line.startswith("# "):
            heading = first_line[2:].strip()
            if heading:
                return heading
        return fallback

    @staticmethod
    def _read_cache(cache_path: Path) -> str | None:
        """Return the cached markdown, or ``None`` when there is no usable entry."""
        try:
            text = cache_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        except UnicodeDecodeError:
            # A corrupt entry is treated as a miss and refetched.
            logger.warning("Ignoring undecodable cache file %s", cache_path)
            return None
        return text or None

    @staticmethod
    def _write_cache(cache_path: Path, markdown: str) -> None:
        """Write the cache entry via a temp file so a crash cannot leave a partial file."""
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        tmp_path.write_text(markdown, encoding="utf-8")
        tmp_path.replace(cache_path)

    async def fetch(
        self,
        url: str,
        *,
        http: httpx.AsyncClient | None = None,
    ) -> WikiArticle | None:
        """Fetch one article, serving it from the disk cache when possible.

        Args:
            url: Wikipedia article URL.
            http: Optional shared client; a temporary one is used otherwise.

        Returns:
            The article, or ``None`` when the URL is not a recognised wiki
            URL, MW reports the title as missing, or the extract is empty.

        Raises:
            httpx.HTTPError, RuntimeError: The last error once every attempt
                has failed. Caller decides whether to abort the whole ingest
                or continue with the successfully-fetched subset.
        """
        try:
            title = title_from_url(url)
        except ValueError as exc:
            logger.warning("Skipping non-wiki URL %s: %s", url, exc)
            return None

        cache_path = self._cache_dir / cache_filename_for_title(title)
        cached = self._read_cache(cache_path)
        if cached is not None:
            # Report the same metadata a fresh fetch would: the canonical
            # title stored in the H1 line and the length in characters
            # (the file size in bytes differs for non-ASCII text).
            cached_title = self._title_from_markdown(cached, fallback=title)
            return WikiArticle(
                title=cached_title,
                source_url=url,
                markdown_path=cache_path,
                n_chars=len(cached),
                redirected_from=title if cached_title != title else None,
            )

        attempts = max(1, self._max_retries)
        for attempt in range(attempts):
            try:
                await self._throttle()
                payload = await self._fetch_extract(title, http=http)
                break
            except (httpx.HTTPError, RuntimeError) as exc:
                if attempt + 1 >= attempts:
                    # Out of attempts: re-raise now instead of sleeping first.
                    logger.warning(
                        "wiki fetch %r failed after %d attempt(s): %s",
                        title, attempts, exc,
                    )
                    raise
                wait = 1.0 * (2 ** attempt)
                logger.warning(
                    "wiki fetch %r attempt %d failed: %s; retry in %.1fs",
                    title, attempt + 1, exc, wait,
                )
                await asyncio.sleep(wait)

        page = payload.get("page") or {}
        if not page or page.get("missing"):
            logger.warning("Wikipedia reports missing page for %r (url=%s)", title, url)
            return None

        canonical_title = str(page.get("title") or title).strip()
        body = str(page.get("extract") or "").strip()
        if not body:
            logger.warning("Wikipedia returned empty extract for %r", title)
            return None
        markdown = f"# {canonical_title}\n\n{body}\n"
        self._write_cache(cache_path, markdown)
        return WikiArticle(
            title=canonical_title,
            source_url=url,
            markdown_path=cache_path,
            n_chars=len(markdown),
            redirected_from=title if canonical_title != title else None,
        )

    async def _fetch_extract(
        self,
        title: str,
        *,
        http: httpx.AsyncClient | None,
    ) -> dict:
        """One MW API call. Returns ``{'page': {...}}`` (formatversion=2).

        The ``page`` dict is empty when the response lists no pages.

        Raises:
            httpx.HTTPError: On transport errors or a non-2xx status.
            RuntimeError: When the response body carries an ``error`` key.
        """
        params = {
            "action": "query",
            "prop": "extracts",
            "explaintext": "true",
            "redirects": "1",
            "format": "json",
            "formatversion": "2",
            "titles": title,
        }
        headers = {"User-Agent": USER_AGENT, "Accept": "application/json"}
        if http is not None:
            response = await http.get(WIKI_API, params=params, headers=headers, timeout=self._timeout)
        else:
            # No shared client supplied: use a short-lived one for this call.
            async with httpx.AsyncClient(timeout=self._timeout) as client:
                response = await client.get(WIKI_API, params=params, headers=headers, timeout=self._timeout)
        response.raise_for_status()
        data = response.json()
        if "error" in data:
            raise RuntimeError(f"MediaWiki API error: {data['error']!r}")
        pages = (data.get("query") or {}).get("pages") or []
        if not pages:
            return {"page": {}}
        return {"page": pages[0]}
|
||||
|
||||
|
||||
# Names exported by ``from <this module> import *``.
__all__ = [
    "WIKI_API",
    "USER_AGENT",
    "WikiArticle",
    "WikiFetcher",
    "cache_filename_for_title",
    "title_from_url",
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue