mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
chore: evals
This commit is contained in:
parent
2402b730fa
commit
3737118050
122 changed files with 22598 additions and 13 deletions
1
surfsense_evals/tests/suites/__init__.py
Normal file
1
surfsense_evals/tests/suites/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
224
surfsense_evals/tests/suites/test_crag_dataset.py
Normal file
224
surfsense_evals/tests/suites/test_crag_dataset.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
"""Tests for the CRAG dataset loader (parser + sampling).
|
||||
|
||||
The full bz2 download is excluded — these tests synthesise a tiny
|
||||
JSONL-bz2 in a tmp dir and verify the parser / stratified-sampler
|
||||
produce well-shaped objects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bz2
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.suites.research.crag.dataset import (
|
||||
CragPage,
|
||||
CragQuestion,
|
||||
iter_questions,
|
||||
stratified_sample,
|
||||
)
|
||||
|
||||
|
||||
def _make_jsonl_bz2(rows: list[dict], tmp_path: Path) -> Path:
|
||||
"""Write ``rows`` as one JSON object per line, bz2-compressed."""
|
||||
|
||||
dest = tmp_path / "fake.jsonl.bz2"
|
||||
payload = "\n".join(json.dumps(r) for r in rows).encode("utf-8")
|
||||
with bz2.open(dest, "wb") as fh:
|
||||
fh.write(payload)
|
||||
return dest
|
||||
|
||||
|
||||
def _row(
|
||||
*,
|
||||
interaction_id: str,
|
||||
query: str,
|
||||
answer: str,
|
||||
domain: str = "movie",
|
||||
question_type: str = "simple",
|
||||
pages: list[dict] | None = None,
|
||||
alt_ans: list[str] | None = None,
|
||||
popularity: str = "head",
|
||||
static_or_dynamic: str = "static",
|
||||
split: int = 0,
|
||||
query_time: str = "2024-04-01",
|
||||
) -> dict:
|
||||
return {
|
||||
"interaction_id": interaction_id,
|
||||
"query_time": query_time,
|
||||
"domain": domain,
|
||||
"question_type": question_type,
|
||||
"static_or_dynamic": static_or_dynamic,
|
||||
"query": query,
|
||||
"answer": answer,
|
||||
"alt_ans": alt_ans or [],
|
||||
"split": split,
|
||||
"popularity": popularity,
|
||||
"search_results": pages or [],
|
||||
}
|
||||
|
||||
|
||||
class TestParser:
|
||||
def test_basic_parse(self, tmp_path: Path) -> None:
|
||||
rows = [
|
||||
_row(
|
||||
interaction_id="abc",
|
||||
query="Who directed Inception?",
|
||||
answer="Christopher Nolan",
|
||||
pages=[{
|
||||
"page_name": "Inception (film)",
|
||||
"page_url": "https://en.wikipedia.org/wiki/Inception",
|
||||
"page_snippet": "snippet",
|
||||
"page_result": "<html>full html</html>",
|
||||
"page_last_modified": "2024-01-01",
|
||||
}],
|
||||
),
|
||||
]
|
||||
path = _make_jsonl_bz2(rows, tmp_path)
|
||||
parsed = iter_questions(path)
|
||||
assert len(parsed) == 1
|
||||
q = parsed[0]
|
||||
assert q.query == "Who directed Inception?"
|
||||
assert q.gold_answer == "Christopher Nolan"
|
||||
assert q.qid == "C00000"
|
||||
assert q.domain == "movie"
|
||||
assert q.question_type == "simple"
|
||||
assert len(q.pages) == 1
|
||||
page = q.pages[0]
|
||||
assert page.page_name == "Inception (film)"
|
||||
assert page.page_url == "https://en.wikipedia.org/wiki/Inception"
|
||||
|
||||
def test_skips_missing_query_or_answer(self, tmp_path: Path) -> None:
|
||||
rows = [
|
||||
_row(interaction_id="1", query="", answer="x"),
|
||||
_row(interaction_id="2", query="ok?", answer=""),
|
||||
_row(interaction_id="3", query="ok?", answer="x"),
|
||||
]
|
||||
path = _make_jsonl_bz2(rows, tmp_path)
|
||||
parsed = iter_questions(path)
|
||||
assert len(parsed) == 1
|
||||
assert parsed[0].interaction_id == "3"
|
||||
|
||||
def test_skips_empty_pages(self, tmp_path: Path) -> None:
|
||||
rows = [
|
||||
_row(
|
||||
interaction_id="x",
|
||||
query="q?",
|
||||
answer="a",
|
||||
pages=[
|
||||
{"page_url": "", "page_result": "<html/>"}, # no URL
|
||||
{"page_url": "https://x.test/", "page_result": ""}, # empty html
|
||||
{"page_url": "https://y.test/", "page_result": "<html>good</html>"},
|
||||
],
|
||||
),
|
||||
]
|
||||
path = _make_jsonl_bz2(rows, tmp_path)
|
||||
parsed = iter_questions(path)
|
||||
assert len(parsed) == 1
|
||||
assert len(parsed[0].pages) == 1
|
||||
assert parsed[0].pages[0].page_url == "https://y.test/"
|
||||
|
||||
def test_alt_answers_parsed(self, tmp_path: Path) -> None:
|
||||
rows = [
|
||||
_row(interaction_id="z", query="q?", answer="42",
|
||||
alt_ans=["forty-two", "42.0"]),
|
||||
]
|
||||
path = _make_jsonl_bz2(rows, tmp_path)
|
||||
parsed = iter_questions(path)
|
||||
assert parsed[0].alt_answers == ["forty-two", "42.0"]
|
||||
|
||||
def test_handles_malformed_line(self, tmp_path: Path) -> None:
|
||||
# Manually construct a bz2 with one valid line and one garbage line.
|
||||
good = json.dumps(_row(interaction_id="ok", query="q?", answer="a"))
|
||||
path = tmp_path / "mixed.jsonl.bz2"
|
||||
with bz2.open(path, "wb") as fh:
|
||||
fh.write(b"not-json{\n")
|
||||
fh.write((good + "\n").encode("utf-8"))
|
||||
parsed = iter_questions(path)
|
||||
# Malformed line is skipped; one good row survives at index 1.
|
||||
assert len(parsed) == 1
|
||||
assert parsed[0].interaction_id == "ok"
|
||||
|
||||
|
||||
class TestPageHash:
|
||||
def test_url_hash_stable(self) -> None:
|
||||
a = CragPage(
|
||||
page_name="A", page_url="https://x.test/p?q=1",
|
||||
page_snippet="", page_html="<html/>",
|
||||
)
|
||||
b = CragPage(
|
||||
page_name="B", page_url="https://x.test/p?q=1",
|
||||
page_snippet="", page_html="<html/>",
|
||||
)
|
||||
assert a.url_hash == b.url_hash
|
||||
assert len(a.url_hash) == 12
|
||||
|
||||
def test_url_hash_unique(self) -> None:
|
||||
a = CragPage(
|
||||
page_name="A", page_url="https://x.test/a", page_snippet="", page_html="<html/>",
|
||||
)
|
||||
b = CragPage(
|
||||
page_name="B", page_url="https://x.test/b", page_snippet="", page_html="<html/>",
|
||||
)
|
||||
assert a.url_hash != b.url_hash
|
||||
|
||||
|
||||
class TestStratifiedSample:
|
||||
def _make_pool(self) -> list[CragQuestion]:
|
||||
out: list[CragQuestion] = []
|
||||
idx = 0
|
||||
# 30 finance/simple, 20 movie/comparison, 5 sports/multi-hop.
|
||||
for n, domain, qtype in (
|
||||
(30, "finance", "simple"),
|
||||
(20, "movie", "comparison"),
|
||||
(5, "sports", "multi-hop"),
|
||||
):
|
||||
for _ in range(n):
|
||||
out.append(CragQuestion(
|
||||
qid=f"C{idx:05d}",
|
||||
interaction_id=f"i{idx}",
|
||||
query_time="2024-01-01",
|
||||
query=f"q{idx}?",
|
||||
gold_answer="a",
|
||||
alt_answers=[],
|
||||
domain=domain,
|
||||
question_type=qtype,
|
||||
static_or_dynamic="static",
|
||||
popularity="head",
|
||||
split=0,
|
||||
raw_index=idx,
|
||||
pages=[],
|
||||
))
|
||||
idx += 1
|
||||
return out
|
||||
|
||||
def test_sample_smaller_than_pool(self) -> None:
|
||||
pool = self._make_pool()
|
||||
sample = stratified_sample(pool, n=15, seed=7)
|
||||
assert len(sample) == 15
|
||||
# Should pull from all three buckets at least once.
|
||||
domains = {q.domain for q in sample}
|
||||
assert domains == {"finance", "movie", "sports"}
|
||||
|
||||
def test_sample_returns_pool_when_n_geq(self) -> None:
|
||||
pool = self._make_pool()
|
||||
sample = stratified_sample(pool, n=999, seed=1)
|
||||
assert len(sample) == len(pool)
|
||||
|
||||
def test_sample_sorted_by_raw_index(self) -> None:
|
||||
pool = self._make_pool()
|
||||
sample = stratified_sample(pool, n=10, seed=42)
|
||||
assert [q.raw_index for q in sample] == sorted(q.raw_index for q in sample)
|
||||
|
||||
def test_sample_deterministic(self) -> None:
|
||||
pool = self._make_pool()
|
||||
s1 = stratified_sample(pool, n=20, seed=11)
|
||||
s2 = stratified_sample(pool, n=20, seed=11)
|
||||
assert [q.qid for q in s1] == [q.qid for q in s2]
|
||||
|
||||
def test_n_zero_or_negative_returns_pool(self) -> None:
|
||||
pool = self._make_pool()
|
||||
assert len(stratified_sample(pool, n=0)) == len(pool)
|
||||
assert len(stratified_sample(pool, n=-1)) == len(pool)
|
||||
259
surfsense_evals/tests/suites/test_crag_dataset_task3.py
Normal file
259
surfsense_evals/tests/suites/test_crag_dataset_task3.py
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
"""Unit tests for CRAG Task 3 streaming dataset loader.
|
||||
|
||||
We don't (and shouldn't) hit the real 7 GB upstream archive in
|
||||
unit tests. Instead we construct tiny tar.bz2 archives split across
|
||||
N parts and verify:
|
||||
|
||||
* ``_MultiPartReader`` correctly stitches N files together.
|
||||
* The streaming path (multi → bz2 → tar → JSONL) yields parsed
|
||||
``CragQuestion`` rows with the right shape.
|
||||
* ``max_questions`` cap is honoured (early break, no greedy read).
|
||||
* ``parts_present`` correctly detects missing/empty parts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bz2
|
||||
import io
|
||||
import json
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.suites.research.crag.dataset_task3 import (
|
||||
_MultiPartReader,
|
||||
iter_questions_task3,
|
||||
parts_present,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: build a tiny synthetic Task 3 archive
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_jsonl_payload(n_rows: int) -> bytes:
|
||||
rows = []
|
||||
for i in range(n_rows):
|
||||
rows.append({
|
||||
"interaction_id": f"int_{i:04d}",
|
||||
"query_time": "2024-01-01 00:00:00",
|
||||
"domain": ["finance", "music", "movie", "sports", "open"][i % 5],
|
||||
"question_type": ["simple", "comparison", "aggregation", "multi-hop"][i % 4],
|
||||
"static_or_dynamic": "static",
|
||||
"popularity": "head",
|
||||
"split": 0,
|
||||
"query": f"Synthetic CRAG question {i}?",
|
||||
"answer": f"answer-{i}",
|
||||
"alt_ans": [f"alt-{i}-a", f"alt-{i}-b"],
|
||||
"search_results": [
|
||||
{
|
||||
"page_name": f"Page {j} for q{i}",
|
||||
"page_url": f"https://example.com/q{i}/p{j}",
|
||||
"page_snippet": "snippet",
|
||||
"page_result": f"<html><body><p>q{i} p{j} body</p></body></html>",
|
||||
"page_last_modified": "",
|
||||
}
|
||||
for j in range(50)
|
||||
],
|
||||
})
|
||||
return b"\n".join(json.dumps(r).encode("utf-8") for r in rows) + b"\n"
|
||||
|
||||
|
||||
def _make_tar_bz2(jsonl_bytes: bytes, *, member_name: str = "data.jsonl") -> bytes:
|
||||
bio = io.BytesIO()
|
||||
with bz2.BZ2File(bio, mode="wb") as bz:
|
||||
with tarfile.open(fileobj=bz, mode="w") as tar:
|
||||
info = tarfile.TarInfo(name=member_name)
|
||||
info.size = len(jsonl_bytes)
|
||||
tar.addfile(info, io.BytesIO(jsonl_bytes))
|
||||
return bio.getvalue()
|
||||
|
||||
|
||||
def _make_tar_bz2_multi(shards: list[tuple[str, bytes]]) -> bytes:
|
||||
"""Build a tar.bz2 archive containing multiple JSONL shards.
|
||||
|
||||
Mirrors the real CRAG Task 3 layout: one tar with N JSONL members
|
||||
named ``crag_task_3_dev_v4_{i}.jsonl`` (or whatever the caller
|
||||
passes in).
|
||||
"""
|
||||
|
||||
bio = io.BytesIO()
|
||||
with bz2.BZ2File(bio, mode="wb") as bz:
|
||||
with tarfile.open(fileobj=bz, mode="w") as tar:
|
||||
for name, payload in shards:
|
||||
info = tarfile.TarInfo(name=name)
|
||||
info.size = len(payload)
|
||||
tar.addfile(info, io.BytesIO(payload))
|
||||
return bio.getvalue()
|
||||
|
||||
|
||||
def _split_into_parts(blob: bytes, n_parts: int) -> list[bytes]:
|
||||
"""Split byte string into N roughly-equal chunks (last gets remainder)."""
|
||||
chunk = max(1, len(blob) // n_parts)
|
||||
parts = [blob[i * chunk : (i + 1) * chunk] for i in range(n_parts - 1)]
|
||||
parts.append(blob[(n_parts - 1) * chunk :])
|
||||
return parts
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task3_parts_dir(tmp_path: Path) -> Path:
|
||||
"""A directory containing a 4-part synthetic CRAG Task 3 archive (12 rows)."""
|
||||
blob = _make_tar_bz2(_make_jsonl_payload(12))
|
||||
parts = _split_into_parts(blob, 4)
|
||||
parts_dir = tmp_path / ".raw_cache"
|
||||
parts_dir.mkdir()
|
||||
for i, b in enumerate(parts, start=1):
|
||||
(parts_dir / f"crag_task_3_dev_v4.tar.bz2.part{i}").write_bytes(b)
|
||||
return parts_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _MultiPartReader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMultiPartReader:
|
||||
def test_concatenates_parts_in_order(self, tmp_path: Path) -> None:
|
||||
a = tmp_path / "a"
|
||||
b = tmp_path / "b"
|
||||
c = tmp_path / "c"
|
||||
a.write_bytes(b"hello, ")
|
||||
b.write_bytes(b"streaming ")
|
||||
c.write_bytes(b"world!")
|
||||
with _MultiPartReader([a, b, c]) as r:
|
||||
assert r.read() == b"hello, streaming world!"
|
||||
|
||||
def test_read_n_crosses_part_boundary(self, tmp_path: Path) -> None:
|
||||
a = tmp_path / "a"
|
||||
b = tmp_path / "b"
|
||||
a.write_bytes(b"AAA")
|
||||
b.write_bytes(b"BBBB")
|
||||
with _MultiPartReader([a, b]) as r:
|
||||
# Read 5 bytes — straddles boundary between parts.
|
||||
assert r.read(5) == b"AAABB"
|
||||
assert r.read(5) == b"BB"
|
||||
assert r.read(5) == b""
|
||||
|
||||
def test_close_is_idempotent(self, tmp_path: Path) -> None:
|
||||
a = tmp_path / "a"
|
||||
a.write_bytes(b"x")
|
||||
r = _MultiPartReader([a])
|
||||
r.close()
|
||||
r.close()
|
||||
with pytest.raises(ValueError):
|
||||
r.read(1)
|
||||
|
||||
def test_missing_part_raises(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError):
|
||||
_MultiPartReader([tmp_path / "does-not-exist"])
|
||||
|
||||
def test_empty_paths_raises(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
_MultiPartReader([])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# iter_questions_task3
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task3_multi_shard_dir(tmp_path: Path) -> Path:
|
||||
"""A 4-part archive whose tar contains 3 JSONL shards (4 + 4 + 4 rows)."""
|
||||
payload_a = _make_jsonl_payload(4)
|
||||
payload_b = _make_jsonl_payload(4)
|
||||
payload_c = _make_jsonl_payload(4)
|
||||
blob = _make_tar_bz2_multi([
|
||||
("crag_task_3_dev_v4_0.jsonl", payload_a),
|
||||
("crag_task_3_dev_v4_1.jsonl", payload_b),
|
||||
("crag_task_3_dev_v4_2.jsonl", payload_c),
|
||||
])
|
||||
parts = _split_into_parts(blob, 4)
|
||||
parts_dir = tmp_path / ".raw_cache"
|
||||
parts_dir.mkdir()
|
||||
for i, b in enumerate(parts, start=1):
|
||||
(parts_dir / f"crag_task_3_dev_v4.tar.bz2.part{i}").write_bytes(b)
|
||||
return parts_dir
|
||||
|
||||
|
||||
class TestIterQuestionsTask3:
|
||||
def test_streams_full_archive(self, task3_parts_dir: Path) -> None:
|
||||
questions = iter_questions_task3(task3_parts_dir)
|
||||
assert len(questions) == 12
|
||||
# All questions get the T3_ prefix and 50 pages each.
|
||||
assert all(q.qid.startswith("T3_") for q in questions)
|
||||
assert all(len(q.pages) == 50 for q in questions)
|
||||
# Schema fields preserved.
|
||||
first = questions[0]
|
||||
assert first.query == "Synthetic CRAG question 0?"
|
||||
assert first.gold_answer == "answer-0"
|
||||
assert first.domain == "finance"
|
||||
assert "alt-0-a" in first.alt_answers
|
||||
|
||||
def test_max_questions_caps_early(self, task3_parts_dir: Path) -> None:
|
||||
questions = iter_questions_task3(task3_parts_dir, max_questions=3)
|
||||
assert len(questions) == 3
|
||||
# Sequential indices 0..2 — we don't skip rows.
|
||||
assert [q.raw_index for q in questions] == [0, 1, 2]
|
||||
|
||||
def test_streams_multi_shard_archive(self, task3_multi_shard_dir: Path) -> None:
|
||||
# Three shards × four rows each = twelve rows total.
|
||||
questions = iter_questions_task3(task3_multi_shard_dir)
|
||||
assert len(questions) == 12
|
||||
# raw_index increments monotonically across shards.
|
||||
assert [q.raw_index for q in questions] == list(range(12))
|
||||
# qids are unique and sequential across shards.
|
||||
assert len({q.qid for q in questions}) == 12
|
||||
|
||||
def test_max_questions_short_circuits_first_shard(self, task3_multi_shard_dir: Path) -> None:
|
||||
# Cap < shard size — shouldn't touch shards 1 or 2 at all.
|
||||
questions = iter_questions_task3(task3_multi_shard_dir, max_questions=2)
|
||||
assert len(questions) == 2
|
||||
# Both come from shard 0 (raw_index 0, 1).
|
||||
assert [q.raw_index for q in questions] == [0, 1]
|
||||
|
||||
def test_max_questions_spans_shards(self, task3_multi_shard_dir: Path) -> None:
|
||||
# Cap = 6 → all 4 from shard 0 + first 2 from shard 1.
|
||||
questions = iter_questions_task3(task3_multi_shard_dir, max_questions=6)
|
||||
assert len(questions) == 6
|
||||
assert [q.raw_index for q in questions] == [0, 1, 2, 3, 4, 5]
|
||||
|
||||
def test_raises_when_no_jsonl_member(self, tmp_path: Path) -> None:
|
||||
# Archive containing a non-jsonl member.
|
||||
bio = io.BytesIO()
|
||||
with bz2.BZ2File(bio, mode="wb") as bz:
|
||||
with tarfile.open(fileobj=bz, mode="w") as tar:
|
||||
info = tarfile.TarInfo(name="README.md")
|
||||
payload = b"not jsonl"
|
||||
info.size = len(payload)
|
||||
tar.addfile(info, io.BytesIO(payload))
|
||||
parts_dir = tmp_path / ".raw_cache"
|
||||
parts_dir.mkdir()
|
||||
for i, name in enumerate(
|
||||
("part1", "part2", "part3", "part4"), start=1,
|
||||
):
|
||||
half = len(bio.getvalue()) // 4
|
||||
chunk = bio.getvalue()[(i - 1) * half : i * half if i < 4 else len(bio.getvalue())]
|
||||
(parts_dir / f"crag_task_3_dev_v4.tar.bz2.{name}").write_bytes(chunk)
|
||||
with pytest.raises(RuntimeError, match="No JSONL member"):
|
||||
iter_questions_task3(parts_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parts_present
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPartsPresent:
|
||||
def test_all_present(self, task3_parts_dir: Path) -> None:
|
||||
assert parts_present(task3_parts_dir) is True
|
||||
|
||||
def test_one_missing(self, task3_parts_dir: Path) -> None:
|
||||
(task3_parts_dir / "crag_task_3_dev_v4.tar.bz2.part2").unlink()
|
||||
assert parts_present(task3_parts_dir) is False
|
||||
|
||||
def test_one_empty(self, task3_parts_dir: Path) -> None:
|
||||
(task3_parts_dir / "crag_task_3_dev_v4.tar.bz2.part3").write_bytes(b"")
|
||||
assert parts_present(task3_parts_dir) is False
|
||||
248
surfsense_evals/tests/suites/test_crag_grader.py
Normal file
248
surfsense_evals/tests/suites/test_crag_grader.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
"""Tests for the CRAG 3-class deterministic grader.
|
||||
|
||||
The LLM-judge fallback is excluded here (network call); these tests
|
||||
exercise the deterministic shortcut + the special-case routing for
|
||||
``false_premise`` questions and refusal detection (``I don't know``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.suites.research.crag.grader import (
|
||||
CragGradeResult,
|
||||
_flags_false_premise,
|
||||
_is_refusal,
|
||||
_maybe_number,
|
||||
_normalise,
|
||||
_whole_word_substring,
|
||||
grade_deterministic,
|
||||
)
|
||||
|
||||
|
||||
class TestNormalisation:
|
||||
def test_lowercase_and_punct_stripped(self) -> None:
|
||||
assert _normalise("Apple Inc.") == "apple inc"
|
||||
|
||||
def test_articles_removed(self) -> None:
|
||||
assert _normalise("The Apple Watch") == "apple watch"
|
||||
|
||||
def test_empty_returns_empty(self) -> None:
|
||||
assert _normalise("") == ""
|
||||
|
||||
|
||||
class TestNumericExtraction:
|
||||
def test_simple_int(self) -> None:
|
||||
assert _maybe_number("42") == 42.0
|
||||
|
||||
def test_int_with_commas(self) -> None:
|
||||
assert _maybe_number("$1,234") == 1234.0
|
||||
|
||||
def test_year_in_sentence(self) -> None:
|
||||
assert _maybe_number("released in 2008") == 2008.0
|
||||
|
||||
def test_word_number(self) -> None:
|
||||
assert _maybe_number("seven") == 7.0
|
||||
|
||||
|
||||
class TestWholeWordSubstring:
|
||||
def test_phrase_match(self) -> None:
|
||||
assert _whole_word_substring("the new york yankees", "new york")
|
||||
|
||||
def test_word_boundary_required(self) -> None:
|
||||
assert not _whole_word_substring("yorkshire", "york")
|
||||
|
||||
|
||||
class TestRefusalDetection:
|
||||
def test_explicit_idk(self) -> None:
|
||||
assert _is_refusal("Answer: I don't know")
|
||||
|
||||
def test_idk_no_apostrophe(self) -> None:
|
||||
assert _is_refusal("I dont know")
|
||||
|
||||
def test_no_information(self) -> None:
|
||||
assert _is_refusal("There is no information available about this.")
|
||||
|
||||
def test_unable_to_answer(self) -> None:
|
||||
assert _is_refusal("I am unable to answer this question.")
|
||||
|
||||
def test_empty_is_refusal(self) -> None:
|
||||
assert _is_refusal("")
|
||||
assert _is_refusal(" ")
|
||||
|
||||
def test_real_answer_is_not_refusal(self) -> None:
|
||||
assert not _is_refusal("Answer: Apple Inc")
|
||||
assert not _is_refusal("The CEO is Tim Cook.")
|
||||
|
||||
|
||||
class TestFalsePremiseDetection:
|
||||
def test_explicit_false_premise(self) -> None:
|
||||
assert _flags_false_premise(
|
||||
"The question contains a false premise; the company never had that product."
|
||||
)
|
||||
|
||||
def test_no_such(self) -> None:
|
||||
assert _flags_false_premise("There is no such album.")
|
||||
|
||||
def test_did_not_happen(self) -> None:
|
||||
assert _flags_false_premise("That event did not happen.")
|
||||
|
||||
def test_does_not_exist(self) -> None:
|
||||
assert _flags_false_premise("That movie does not exist.")
|
||||
|
||||
def test_normal_answer_is_not_premise_flag(self) -> None:
|
||||
assert not _flags_false_premise("Apple, headquartered in Cupertino.")
|
||||
|
||||
|
||||
class TestGradeDeterministicHappyPath:
|
||||
def test_exact_match_correct(self) -> None:
|
||||
result = grade_deterministic(pred="Tim Cook", gold="Tim Cook", question_type="simple")
|
||||
assert result.grade == "correct"
|
||||
assert result.score == 1
|
||||
assert result.method == "exact"
|
||||
|
||||
def test_substring_match(self) -> None:
|
||||
result = grade_deterministic(
|
||||
pred="The answer is Tim Cook, CEO of Apple.",
|
||||
gold="Tim Cook",
|
||||
question_type="simple",
|
||||
)
|
||||
assert result.grade == "correct"
|
||||
assert result.method == "substring"
|
||||
|
||||
def test_alt_answer_match(self) -> None:
|
||||
result = grade_deterministic(
|
||||
pred="2,008",
|
||||
gold="two thousand eight",
|
||||
alt_answers=["2008"],
|
||||
question_type="simple",
|
||||
)
|
||||
assert result.grade == "correct"
|
||||
assert result.score == 1
|
||||
|
||||
def test_numeric_within_tolerance(self) -> None:
|
||||
result = grade_deterministic(
|
||||
pred="The revenue was $1,234,000 USD",
|
||||
gold="$1,234,123",
|
||||
question_type="aggregation",
|
||||
)
|
||||
assert result.grade == "correct"
|
||||
assert result.method == "numeric"
|
||||
|
||||
def test_numeric_outside_tolerance(self) -> None:
|
||||
result = grade_deterministic(
|
||||
pred="100",
|
||||
gold="500",
|
||||
question_type="aggregation",
|
||||
)
|
||||
assert result.grade == "incorrect"
|
||||
assert result.score == -1
|
||||
|
||||
def test_numeric_strict_small_currency(self) -> None:
|
||||
# CRAG (unlike FRAMES) does not apply a 0.5 absolute floor —
|
||||
# ``$2.05`` should NOT match ``$2.17`` (≈5.5% off, well over 1%).
|
||||
result = grade_deterministic(
|
||||
pred="$2.05",
|
||||
gold="$2.17",
|
||||
question_type="simple",
|
||||
)
|
||||
# Falls through to lexical_miss (no substring overlap either).
|
||||
assert result.grade == "incorrect"
|
||||
assert result.method == "lexical_miss"
|
||||
|
||||
|
||||
class TestGradeDeterministicRefusal:
|
||||
def test_idk_maps_to_missing(self) -> None:
|
||||
result = grade_deterministic(
|
||||
pred="I don't know.", gold="Tim Cook", question_type="simple",
|
||||
)
|
||||
assert result.grade == "missing"
|
||||
assert result.score == 0
|
||||
assert result.method == "refusal"
|
||||
|
||||
def test_empty_pred_maps_to_missing(self) -> None:
|
||||
result = grade_deterministic(pred="", gold="Tim Cook", question_type="simple")
|
||||
assert result.grade == "missing"
|
||||
|
||||
def test_no_information_maps_to_missing(self) -> None:
|
||||
result = grade_deterministic(
|
||||
pred="There is not enough information to answer.",
|
||||
gold="42",
|
||||
question_type="simple",
|
||||
)
|
||||
assert result.grade == "missing"
|
||||
|
||||
|
||||
class TestGradeDeterministicFalsePremise:
|
||||
def test_flagging_premise_is_correct(self) -> None:
|
||||
result = grade_deterministic(
|
||||
pred="The question contains a false premise; that movie does not exist.",
|
||||
gold="invalid question",
|
||||
question_type="false_premise",
|
||||
)
|
||||
assert result.grade == "correct"
|
||||
assert result.method == "false_premise_flagged"
|
||||
|
||||
def test_committing_to_false_answer_is_unclear(self) -> None:
|
||||
# Should land in false_premise_unclear → judge fallback territory.
|
||||
result = grade_deterministic(
|
||||
pred="The album was released in 2010.",
|
||||
gold="invalid question",
|
||||
question_type="false_premise",
|
||||
)
|
||||
assert result.grade == "incorrect"
|
||||
assert result.method == "false_premise_unclear"
|
||||
|
||||
def test_idk_on_false_premise_is_missing(self) -> None:
|
||||
# Refusal precedes false-premise routing.
|
||||
result = grade_deterministic(
|
||||
pred="I don't know.",
|
||||
gold="invalid question",
|
||||
question_type="false_premise",
|
||||
)
|
||||
assert result.grade == "missing"
|
||||
|
||||
|
||||
class TestGradeDeterministicLexicalMiss:
|
||||
def test_unknown_paraphrase_routes_to_judge(self) -> None:
|
||||
result = grade_deterministic(
|
||||
pred="It is the technology giant in Cupertino.",
|
||||
gold="Apple Inc",
|
||||
question_type="simple",
|
||||
)
|
||||
# Without a judge, we fall through to lexical_miss → incorrect.
|
||||
assert result.grade == "incorrect"
|
||||
assert result.method == "lexical_miss"
|
||||
|
||||
def test_short_pred_no_substring_credit(self) -> None:
|
||||
# Reverse-substring path requires len >= 3 to credit.
|
||||
result = grade_deterministic(
|
||||
pred="JK",
|
||||
gold="JK Rowling",
|
||||
question_type="simple",
|
||||
)
|
||||
assert result.grade == "incorrect"
|
||||
|
||||
|
||||
class TestGradeResultShape:
|
||||
def test_to_dict_round_trip(self) -> None:
|
||||
result = CragGradeResult(
|
||||
grade="correct", score=1, method="exact",
|
||||
normalised_pred="x", normalised_gold="x",
|
||||
)
|
||||
d = result.to_dict()
|
||||
assert d["grade"] == "correct"
|
||||
assert d["score"] == 1
|
||||
assert d["method"] == "exact"
|
||||
|
||||
def test_score_matches_grade(self) -> None:
|
||||
# Construct via grader so the score field is populated correctly.
|
||||
for gold, pred, want_grade in (
|
||||
("hi", "hi", "correct"),
|
||||
("hi", "I don't know", "missing"),
|
||||
("hi", "bye", "incorrect"),
|
||||
):
|
||||
result = grade_deterministic(pred=pred, gold=gold, question_type="simple")
|
||||
assert result.grade == want_grade
|
||||
expected_score = {"correct": 1, "missing": 0, "incorrect": -1}[want_grade]
|
||||
assert result.score == expected_score
|
||||
149
surfsense_evals/tests/suites/test_crag_html_extract.py
Normal file
149
surfsense_evals/tests/suites/test_crag_html_extract.py
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
"""Tests for the CRAG HTML extractor.
|
||||
|
||||
We don't network-fetch trafilatura; we just verify the wrapper:
|
||||
|
||||
* Strips obvious boilerplate (nav/footer/scripts) out of the result.
|
||||
* Falls back to the stdlib stripper on degenerate input.
|
||||
* Caps output at the configured ceiling.
|
||||
* Always prepends a metadata header (``# title``) when content is
|
||||
produced.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.suites.research.crag.html_extract import (
|
||||
extract_main_content,
|
||||
)
|
||||
|
||||
|
||||
_RICH_HTML = """\
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Apple Q3 Earnings</title>
|
||||
<script>const a=1;</script>
|
||||
<style>body{font-family:sans;}</style>
|
||||
</head>
|
||||
<body>
|
||||
<nav><a href="/home">Home</a><a href="/about">About</a></nav>
|
||||
<header><h1>Tech News Site</h1><p>Subscribe to our newsletter</p></header>
|
||||
<main>
|
||||
<article>
|
||||
<h1>Apple posts $90B revenue in Q3 2024</h1>
|
||||
<p>Apple Inc. announced its Q3 2024 financial results today, reporting
|
||||
$90 billion in revenue, beating analyst expectations of $87 billion.</p>
|
||||
<p>The company saw growth across iPhone, services, and wearables.
|
||||
CEO Tim Cook attributed the performance to strong demand in emerging
|
||||
markets, particularly India.</p>
|
||||
<h2>Segment breakdown</h2>
|
||||
<ul>
|
||||
<li>iPhone: $45B</li>
|
||||
<li>Services: $24B</li>
|
||||
<li>Mac: $7B</li>
|
||||
</ul>
|
||||
</article>
|
||||
</main>
|
||||
<footer><p>Copyright 2024 Tech News Site. All rights reserved.</p></footer>
|
||||
</body></html>
|
||||
"""
|
||||
|
||||
|
||||
class TestExtractMainContent:
|
||||
def test_extracts_main_article(self) -> None:
|
||||
result = extract_main_content(
|
||||
_RICH_HTML,
|
||||
url="https://example.com/apple",
|
||||
page_name="Apple Q3 Earnings",
|
||||
)
|
||||
assert result.ok
|
||||
assert "Apple" in result.text
|
||||
assert "Q3 2024" in result.text
|
||||
# Header line is prepended.
|
||||
assert result.text.startswith("# Apple Q3 Earnings")
|
||||
assert "Source: https://example.com/apple" in result.text
|
||||
|
||||
def test_strips_boilerplate(self) -> None:
|
||||
result = extract_main_content(
|
||||
_RICH_HTML,
|
||||
url="https://example.com/apple",
|
||||
page_name="Apple Q3 Earnings",
|
||||
)
|
||||
assert result.ok
|
||||
# Boilerplate strings should NOT make it through.
|
||||
assert "Subscribe to our newsletter" not in result.text
|
||||
assert "Copyright 2024 Tech News Site" not in result.text
|
||||
assert "const a=1" not in result.text # script content
|
||||
|
||||
def test_includes_last_modified_when_provided(self) -> None:
|
||||
result = extract_main_content(
|
||||
_RICH_HTML,
|
||||
url="https://example.com/apple",
|
||||
page_name="Apple Q3 Earnings",
|
||||
last_modified="2024-08-01",
|
||||
)
|
||||
assert "Last modified: 2024-08-01" in result.text
|
||||
|
||||
def test_empty_html_returns_empty_result(self) -> None:
|
||||
result = extract_main_content("", url="https://x.test/")
|
||||
assert not result.ok
|
||||
assert result.method == "empty"
|
||||
assert result.n_chars == 0
|
||||
|
||||
def test_whitespace_only_html_is_empty(self) -> None:
|
||||
result = extract_main_content(" \n ", url="https://x.test/")
|
||||
assert not result.ok
|
||||
|
||||
def test_garbage_html_falls_back(self) -> None:
|
||||
# Trafilatura should reject this, fallback strip should still yield text.
|
||||
result = extract_main_content(
|
||||
"<<weird>>not a tag>>>The brown fox<<jumped<<",
|
||||
url="https://x.test/garbage",
|
||||
page_name="Garbage",
|
||||
)
|
||||
# Either trafilatura recovers something or fallback_strip does.
|
||||
if result.ok:
|
||||
assert "brown fox" in result.text or "jumped" in result.text
|
||||
|
||||
|
||||
class TestFallbackStripper:
|
||||
def test_extract_when_no_clear_main(self) -> None:
|
||||
html = """
|
||||
<html><body>
|
||||
<p>This is content one.</p>
|
||||
<p>This is content two.</p>
|
||||
</body></html>
|
||||
"""
|
||||
result = extract_main_content(
|
||||
html, url="https://x.test/", page_name="Title",
|
||||
)
|
||||
assert result.ok
|
||||
assert "content one" in result.text
|
||||
assert "content two" in result.text
|
||||
|
||||
def test_html_entities_decoded(self) -> None:
|
||||
html = """<html><body>
|
||||
<article>
|
||||
<p>Tom & Jerry — classic cartoon © 1940.</p>
|
||||
<p>It's a story about a cat <Tom> and a mouse <Jerry>.</p>
|
||||
</article>
|
||||
</body></html>"""
|
||||
result = extract_main_content(html, url="https://x.test/")
|
||||
assert result.ok
|
||||
# & should be decoded
|
||||
assert "&" not in result.text
|
||||
assert "Tom" in result.text and "Jerry" in result.text
|
||||
|
||||
|
||||
class TestOutputCapping:
|
||||
def test_long_output_is_truncated(self) -> None:
|
||||
# Generate enough content to exceed 200k cap.
|
||||
body = "<p>" + ("hello world " * 50_000) + "</p>"
|
||||
html = f"<html><body><article>{body}</article></body></html>"
|
||||
result = extract_main_content(html, url="https://x.test/", page_name="long")
|
||||
assert result.ok
|
||||
# The body text itself + the metadata header. Truncation marker
|
||||
# appears either at the body limit or before EOF.
|
||||
if "[...truncated...]" in result.text:
|
||||
# The truncation kicked in.
|
||||
assert len(result.text) <= 250_000 # header + 200k cap + slack
|
||||
154
surfsense_evals/tests/suites/test_frames_dataset.py
Normal file
154
surfsense_evals/tests/suites/test_frames_dataset.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
"""Tests for the FRAMES dataset parser.
|
||||
|
||||
Network-free: we round-trip a tiny fixture TSV through pandas and
|
||||
``load_questions`` to confirm:
|
||||
|
||||
* row indices become zero-padded ``Q###`` ids,
|
||||
* ``wiki_links`` (Python list literal) is materialised correctly,
|
||||
* ``reasoning_types`` is split on the pipe separator,
|
||||
* missing Prompt/Answer rows are dropped, and
|
||||
* the legacy ``wikipedia_link_*`` per-cell fallback works when
|
||||
``wiki_links`` is missing/empty.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.suites.research.frames.dataset import (
|
||||
FramesQuestion,
|
||||
_parse_reasoning_types,
|
||||
_parse_wiki_links,
|
||||
load_questions,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure-function tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseWikiLinks:
|
||||
def test_python_list_literal(self) -> None:
|
||||
s = "['https://en.wikipedia.org/wiki/A', 'https://en.wikipedia.org/wiki/B']"
|
||||
assert _parse_wiki_links(s) == [
|
||||
"https://en.wikipedia.org/wiki/A",
|
||||
"https://en.wikipedia.org/wiki/B",
|
||||
]
|
||||
|
||||
def test_none_or_empty(self) -> None:
|
||||
assert _parse_wiki_links(None) == []
|
||||
assert _parse_wiki_links("") == []
|
||||
assert _parse_wiki_links("[]") == []
|
||||
|
||||
def test_unquoted_csv_fallback(self) -> None:
|
||||
# Defensive: non-Python-list strings still split on commas.
|
||||
s = "https://a, https://b"
|
||||
assert _parse_wiki_links(s) == ["https://a", "https://b"]
|
||||
|
||||
def test_already_a_list(self) -> None:
|
||||
assert _parse_wiki_links(["x", "y"]) == ["x", "y"]
|
||||
|
||||
|
||||
class TestParseReasoningTypes:
|
||||
def test_pipe_separated(self) -> None:
|
||||
assert _parse_reasoning_types("Numerical reasoning | Multiple constraints") == [
|
||||
"Numerical reasoning",
|
||||
"Multiple constraints",
|
||||
]
|
||||
|
||||
def test_single_tag(self) -> None:
|
||||
assert _parse_reasoning_types("Tabular reasoning") == ["Tabular reasoning"]
|
||||
|
||||
def test_empty(self) -> None:
|
||||
assert _parse_reasoning_types(None) == []
|
||||
assert _parse_reasoning_types("") == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Round-trip via pandas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _write_tsv(path: Path, body: str) -> None:
|
||||
"""Helper that writes a tab-separated fixture exactly as the user typed it."""
|
||||
|
||||
path.write_text(textwrap.dedent(body), encoding="utf-8")
|
||||
|
||||
|
||||
def test_load_questions_basic(tmp_path: Path) -> None:
|
||||
tsv = tmp_path / "test.tsv"
|
||||
rows = [
|
||||
# Header (first column is unnamed → pandas treats as index)
|
||||
"\tPrompt\tAnswer\twikipedia_link_1\twikipedia_link_2\treasoning_types\twiki_links",
|
||||
# Row 0
|
||||
"0\tWho was the 15th president?\tJames Buchanan\t"
|
||||
"https://en.wikipedia.org/wiki/James_Buchanan\t\t"
|
||||
"Multiple constraints\t"
|
||||
"['https://en.wikipedia.org/wiki/James_Buchanan']",
|
||||
# Row 1
|
||||
"1\tHow many years between A and B?\t87\t"
|
||||
"https://en.wikipedia.org/wiki/A\thttps://en.wikipedia.org/wiki/B\t"
|
||||
"Numerical reasoning | Temporal reasoning\t"
|
||||
"['https://en.wikipedia.org/wiki/A', 'https://en.wikipedia.org/wiki/B']",
|
||||
# Row 2 (intentionally missing Prompt — should be dropped)
|
||||
"2\t\tunused\t\t\t\t",
|
||||
]
|
||||
tsv.write_text("\n".join(rows) + "\n", encoding="utf-8")
|
||||
|
||||
questions = load_questions(tsv)
|
||||
assert len(questions) == 2
|
||||
|
||||
q0, q1 = questions
|
||||
assert isinstance(q0, FramesQuestion)
|
||||
assert q0.qid == "Q000"
|
||||
assert q0.raw_index == 0
|
||||
assert q0.gold_answer == "James Buchanan"
|
||||
assert q0.wiki_urls == ["https://en.wikipedia.org/wiki/James_Buchanan"]
|
||||
assert q0.reasoning_types == ["Multiple constraints"]
|
||||
|
||||
assert q1.qid == "Q001"
|
||||
assert q1.gold_answer == "87"
|
||||
assert q1.wiki_urls == [
|
||||
"https://en.wikipedia.org/wiki/A",
|
||||
"https://en.wikipedia.org/wiki/B",
|
||||
]
|
||||
assert q1.reasoning_types == ["Numerical reasoning", "Temporal reasoning"]
|
||||
|
||||
|
||||
def test_load_questions_falls_back_to_per_cell_links(tmp_path: Path) -> None:
|
||||
"""When ``wiki_links`` is empty, the loader should glue the
|
||||
``wikipedia_link_*`` cells back together."""
|
||||
|
||||
tsv = tmp_path / "test.tsv"
|
||||
rows = [
|
||||
"\tPrompt\tAnswer\twikipedia_link_1\twikipedia_link_2\treasoning_types\twiki_links",
|
||||
"0\tQ?\tA\t"
|
||||
"https://en.wikipedia.org/wiki/Cell1\thttps://en.wikipedia.org/wiki/Cell2\t"
|
||||
"Numerical reasoning\t",
|
||||
]
|
||||
tsv.write_text("\n".join(rows) + "\n", encoding="utf-8")
|
||||
questions = load_questions(tsv)
|
||||
assert len(questions) == 1
|
||||
assert questions[0].wiki_urls == [
|
||||
"https://en.wikipedia.org/wiki/Cell1",
|
||||
"https://en.wikipedia.org/wiki/Cell2",
|
||||
]
|
||||
|
||||
|
||||
def test_load_questions_to_dict_round_trip(tmp_path: Path) -> None:
|
||||
tsv = tmp_path / "test.tsv"
|
||||
rows = [
|
||||
"\tPrompt\tAnswer\treasoning_types\twiki_links",
|
||||
"0\tQ?\tParis\tTemporal reasoning\t['https://en.wikipedia.org/wiki/Paris']",
|
||||
]
|
||||
tsv.write_text("\n".join(rows) + "\n", encoding="utf-8")
|
||||
|
||||
[q] = load_questions(tsv)
|
||||
d = q.to_dict()
|
||||
assert d["qid"] == "Q000"
|
||||
assert d["wiki_urls"] == ["https://en.wikipedia.org/wiki/Paris"]
|
||||
assert d["reasoning_types"] == ["Temporal reasoning"]
|
||||
160
surfsense_evals/tests/suites/test_frames_grader.py
Normal file
160
surfsense_evals/tests/suites/test_frames_grader.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
"""Tests for the FRAMES grader's deterministic shortcut.
|
||||
|
||||
The LLM-judge fallback is excluded here (network call); we just
|
||||
confirm the rule-based path picks up obvious correct/incorrect
|
||||
cases and routes the ambiguous ones to ``lexical_miss`` so the
|
||||
runner knows to consult the judge.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.suites.research.frames.grader import (
|
||||
GradeResult,
|
||||
_maybe_number,
|
||||
_normalise,
|
||||
_whole_word_substring,
|
||||
grade_deterministic,
|
||||
)
|
||||
|
||||
|
||||
class TestNormalisation:
|
||||
def test_lowercase_and_punct_stripped(self) -> None:
|
||||
assert _normalise("Jane Ballou.") == "jane ballou"
|
||||
|
||||
def test_articles_removed(self) -> None:
|
||||
assert _normalise("The Eiffel Tower") == "eiffel tower"
|
||||
|
||||
def test_whitespace_squashed(self) -> None:
|
||||
assert _normalise(" multi space\tinput ") == "multi space input"
|
||||
|
||||
def test_empty_returns_empty(self) -> None:
|
||||
assert _normalise("") == ""
|
||||
assert _normalise(None) == "" # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestNumericExtraction:
|
||||
def test_simple_int(self) -> None:
|
||||
assert _maybe_number("42") == 42.0
|
||||
|
||||
def test_int_with_commas(self) -> None:
|
||||
assert _maybe_number("1,234") == 1234.0
|
||||
|
||||
def test_year_in_sentence(self) -> None:
|
||||
assert _maybe_number("It was published in 1847.") == 1847.0
|
||||
|
||||
def test_word_number(self) -> None:
|
||||
assert _maybe_number("five") == 5.0
|
||||
assert _maybe_number("Twenty") == 20.0
|
||||
|
||||
def test_no_number_returns_none(self) -> None:
|
||||
assert _maybe_number("Jane Ballou") is None
|
||||
assert _maybe_number("") is None
|
||||
|
||||
|
||||
class TestWholeWordSubstring:
|
||||
def test_phrase_match(self) -> None:
|
||||
assert _whole_word_substring("president of the united states", "united states")
|
||||
|
||||
def test_word_boundary_required(self) -> None:
|
||||
# "states" should NOT match inside "statesman"
|
||||
assert not _whole_word_substring("the renowned statesman", "states")
|
||||
|
||||
def test_empty_needle(self) -> None:
|
||||
assert not _whole_word_substring("anything", "")
|
||||
|
||||
|
||||
class TestExactMatch:
|
||||
def test_identical(self) -> None:
|
||||
r = grade_deterministic(pred="Jane Ballou", gold="Jane Ballou")
|
||||
assert r.correct is True
|
||||
assert r.method == "exact"
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
r = grade_deterministic(pred="paris", gold="Paris")
|
||||
assert r.correct is True
|
||||
assert r.method == "exact"
|
||||
|
||||
def test_punctuation_ignored(self) -> None:
|
||||
r = grade_deterministic(pred="Jane Ballou.", gold="Jane Ballou")
|
||||
assert r.correct is True
|
||||
|
||||
|
||||
class TestNumericPath:
|
||||
def test_int_match(self) -> None:
|
||||
r = grade_deterministic(pred="The answer is 87", gold="87")
|
||||
assert r.correct is True
|
||||
assert r.method == "numeric"
|
||||
|
||||
def test_word_number_matches_digit(self) -> None:
|
||||
r = grade_deterministic(pred="five", gold="5")
|
||||
assert r.correct is True
|
||||
assert r.method == "numeric"
|
||||
|
||||
def test_off_by_more_than_tolerance_fails(self) -> None:
|
||||
r = grade_deterministic(pred="86", gold="87")
|
||||
# 86 vs 87, abs diff = 1, tol = max(0.01*87, 0.5) = 0.87 → fails
|
||||
assert r.correct is False
|
||||
assert r.method == "numeric_miss"
|
||||
|
||||
def test_within_one_percent_passes(self) -> None:
|
||||
r = grade_deterministic(pred="100", gold="101")
|
||||
# 1.0 abs diff, tol = max(0.01*101, 0.5) = 1.01 → passes
|
||||
assert r.correct is True
|
||||
|
||||
|
||||
class TestSubstringPath:
|
||||
def test_pred_contains_gold(self) -> None:
|
||||
r = grade_deterministic(
|
||||
pred="The answer is Jane Ballou according to records",
|
||||
gold="Jane Ballou",
|
||||
)
|
||||
assert r.correct is True
|
||||
assert r.method == "substring"
|
||||
|
||||
def test_gold_contains_pred_with_minimum_length(self) -> None:
|
||||
# Gold = "John F Kennedy", pred = "Kennedy" → reverse substring,
|
||||
# ≥3 chars, but the FRAMES style usually accepts this.
|
||||
r = grade_deterministic(pred="Kennedy", gold="John F. Kennedy")
|
||||
assert r.correct is True
|
||||
assert r.method == "substring_reverse"
|
||||
|
||||
def test_too_short_pred_no_reverse_credit(self) -> None:
|
||||
r = grade_deterministic(pred="of", gold="World of Warcraft")
|
||||
# "of" passes length but is a stopword; the article-stripping
|
||||
# normaliser removes it from gold, so substring fails. Either
|
||||
# way, the grader should NOT credit this.
|
||||
assert r.correct is False
|
||||
|
||||
|
||||
class TestLexicalMiss:
|
||||
def test_completely_different_pred_falls_through(self) -> None:
|
||||
r = grade_deterministic(pred="London", gold="Paris")
|
||||
assert r.correct is False
|
||||
assert r.method == "lexical_miss"
|
||||
|
||||
def test_empty_pred(self) -> None:
|
||||
r = grade_deterministic(pred="", gold="Paris")
|
||||
assert r.correct is False
|
||||
assert r.method == "empty_pred"
|
||||
|
||||
def test_empty_gold_defensive(self) -> None:
|
||||
r = grade_deterministic(pred="something", gold="")
|
||||
# Defensive guard — gold should never be empty in practice.
|
||||
assert r.correct is False
|
||||
assert r.method == "empty_gold"
|
||||
|
||||
|
||||
class TestGradeResultShape:
|
||||
def test_dict_has_all_expected_keys(self) -> None:
|
||||
r = grade_deterministic(pred="Paris", gold="Paris")
|
||||
d = r.to_dict()
|
||||
assert set(d) >= {
|
||||
"correct",
|
||||
"f1",
|
||||
"method",
|
||||
"normalised_pred",
|
||||
"normalised_gold",
|
||||
"judge_rationale",
|
||||
}
|
||||
112
surfsense_evals/tests/suites/test_frames_wiki_fetch.py
Normal file
112
surfsense_evals/tests/suites/test_frames_wiki_fetch.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
"""Tests for the FRAMES Wikipedia fetcher.
|
||||
|
||||
We mock the MW API with respx so tests are network-free. Coverage:
|
||||
|
||||
* URL → title parsing (percent-encoded, underscores, redirects)
|
||||
* Filename safety (slashes, special chars)
|
||||
* Cache hit short-circuits the API call
|
||||
* Missing pages return ``None`` (not an exception)
|
||||
* Successful fetches write ``# Title`` markdown to disk
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from surfsense_evals.suites.research.frames.wiki_fetch import (
|
||||
WIKI_API,
|
||||
WikiFetcher,
|
||||
cache_filename_for_title,
|
||||
title_from_url,
|
||||
)
|
||||
|
||||
|
||||
class TestTitleFromUrl:
|
||||
def test_basic(self) -> None:
|
||||
assert title_from_url("https://en.wikipedia.org/wiki/James_Buchanan") == "James Buchanan"
|
||||
|
||||
def test_percent_encoded(self) -> None:
|
||||
assert (
|
||||
title_from_url("https://en.wikipedia.org/wiki/Charlotte_Bront%C3%AB")
|
||||
== "Charlotte Brontë"
|
||||
)
|
||||
|
||||
def test_query_string_dropped(self) -> None:
|
||||
assert title_from_url("https://en.wikipedia.org/wiki/Foo?action=edit") == "Foo"
|
||||
|
||||
def test_non_wiki_raises(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
title_from_url("https://example.com/wiki/Foo")
|
||||
|
||||
|
||||
class TestCacheFilename:
|
||||
def test_simple(self) -> None:
|
||||
assert cache_filename_for_title("James Buchanan") == "James_Buchanan.md"
|
||||
|
||||
def test_unicode_replaced_with_underscore(self) -> None:
|
||||
# Brontë's diaeresis is non-ASCII so the regex replaces it with `_`.
|
||||
# The space → `_` happens after the unicode swap, so the final
|
||||
# name has exactly one underscore for the diaeresis. Acceptable:
|
||||
# filenames stay round-trippable as long as the rule is deterministic.
|
||||
assert cache_filename_for_title("Charlotte Brontë") == "Charlotte_Bront_.md"
|
||||
|
||||
def test_slashes_replaced(self) -> None:
|
||||
# Wikipedia titles can contain ``/`` (e.g. "I/O"), which would
|
||||
# break the filesystem layout if not sanitised.
|
||||
assert cache_filename_for_title("I/O") == "I_O.md"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_fetch_success_writes_markdown(tmp_path: Path) -> None:
|
||||
respx.get(WIKI_API).mock(return_value=httpx.Response(
|
||||
200,
|
||||
json={"query": {"pages": [{
|
||||
"pageid": 1,
|
||||
"title": "James Buchanan",
|
||||
"extract": "James Buchanan was the 15th president of the United States.",
|
||||
}]}},
|
||||
))
|
||||
fetcher = WikiFetcher(cache_dir=tmp_path, rate_limit_rps=100) # disable throttle
|
||||
article = await fetcher.fetch("https://en.wikipedia.org/wiki/James_Buchanan")
|
||||
assert article is not None
|
||||
assert article.title == "James Buchanan"
|
||||
body = article.markdown_path.read_text(encoding="utf-8")
|
||||
assert body.startswith("# James Buchanan")
|
||||
assert "15th president" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_fetch_missing_page_returns_none(tmp_path: Path) -> None:
|
||||
respx.get(WIKI_API).mock(return_value=httpx.Response(
|
||||
200,
|
||||
json={"query": {"pages": [{
|
||||
"title": "DoesNotExist",
|
||||
"missing": True,
|
||||
}]}},
|
||||
))
|
||||
fetcher = WikiFetcher(cache_dir=tmp_path, rate_limit_rps=100)
|
||||
article = await fetcher.fetch("https://en.wikipedia.org/wiki/DoesNotExist")
|
||||
assert article is None
|
||||
assert not (tmp_path / "DoesNotExist.md").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_fetch_cache_hit_skips_api(tmp_path: Path) -> None:
|
||||
# Pre-populate the cache.
|
||||
cached = tmp_path / cache_filename_for_title("Cached Page")
|
||||
cached.write_text("# Cached Page\n\nfrom disk\n", encoding="utf-8")
|
||||
fetcher = WikiFetcher(cache_dir=tmp_path, rate_limit_rps=100)
|
||||
|
||||
# No respx mock registered; if the fetcher hits the network, respx
|
||||
# would error out (it intercepts everything inside the decorator).
|
||||
article = await fetcher.fetch("https://en.wikipedia.org/wiki/Cached_Page")
|
||||
assert article is not None
|
||||
assert article.markdown_path == cached
|
||||
assert article.markdown_path.read_text(encoding="utf-8").endswith("from disk\n")
|
||||
129
surfsense_evals/tests/suites/test_mmlongbench_grader.py
Normal file
129
surfsense_evals/tests/suites/test_mmlongbench_grader.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""Tests for the MMLongBench-Doc format-aware grader.
|
||||
|
||||
The grader is the critical correctness piece for the open-ended
|
||||
benchmark (no MCQ shortcut), so we cover all five formats with
|
||||
representative happy-path + edge-case rows.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from surfsense_evals.suites.multimodal_doc.mmlongbench.grader import grade
|
||||
|
||||
|
||||
class TestStrFormat:
|
||||
def test_exact_match(self) -> None:
|
||||
r = grade(pred="Apollo 11", gold="Apollo 11", answer_format="Str")
|
||||
assert r.correct is True
|
||||
assert r.f1 == 1.0
|
||||
assert r.method == "str_norm"
|
||||
|
||||
def test_lowercase_normalised(self) -> None:
|
||||
r = grade(pred="paris", gold="Paris", answer_format="Str")
|
||||
assert r.correct is True
|
||||
|
||||
def test_punctuation_difference_drops_to_substring(self) -> None:
|
||||
# "N.A.S.A." normalises to "n a s a" (whitespace tokens) which
|
||||
# doesn't equal "nasa" — but the F1 token overlap is still 0
|
||||
# because none of the single letters appear standalone in "nasa".
|
||||
# We assert the grader fails closed rather than over-claiming.
|
||||
r = grade(pred="N.A.S.A.", gold="NASA", answer_format="Str")
|
||||
assert r.correct is False # explicit: this is a failure mode we accept
|
||||
|
||||
def test_substring_credit(self) -> None:
|
||||
r = grade(pred="The answer is Paris.", gold="Paris", answer_format="Str")
|
||||
assert r.correct is True
|
||||
|
||||
def test_completely_wrong(self) -> None:
|
||||
r = grade(pred="London", gold="Paris", answer_format="Str")
|
||||
assert r.correct is False
|
||||
assert r.f1 < 0.5
|
||||
|
||||
def test_empty_pred(self) -> None:
|
||||
r = grade(pred="", gold="Paris", answer_format="Str")
|
||||
assert r.correct is False
|
||||
assert r.f1 == 0.0
|
||||
|
||||
|
||||
class TestIntFormat:
|
||||
def test_exact_int(self) -> None:
|
||||
assert grade(pred="42", gold="42", answer_format="Int").correct is True
|
||||
|
||||
def test_int_in_sentence(self) -> None:
|
||||
assert grade(pred="The answer is 42 years.", gold="42", answer_format="Int").correct is True
|
||||
|
||||
def test_int_with_commas(self) -> None:
|
||||
assert grade(pred="1,500", gold="1500", answer_format="Int").correct is True
|
||||
|
||||
def test_wrong_int(self) -> None:
|
||||
assert grade(pred="41", gold="42", answer_format="Int").correct is False
|
||||
|
||||
def test_no_int_in_pred(self) -> None:
|
||||
assert grade(pred="not answerable", gold="42", answer_format="Int").correct is False
|
||||
|
||||
|
||||
class TestFloatFormat:
|
||||
def test_exact_float(self) -> None:
|
||||
assert grade(pred="3.14", gold="3.14", answer_format="Float").correct is True
|
||||
|
||||
def test_within_tolerance(self) -> None:
|
||||
# 1% tolerance — 3.14 vs 3.13 is well within.
|
||||
assert grade(pred="3.13", gold="3.14", answer_format="Float").correct is True
|
||||
|
||||
def test_outside_tolerance(self) -> None:
|
||||
assert grade(pred="3.5", gold="3.14", answer_format="Float").correct is False
|
||||
|
||||
def test_european_decimal_comma(self) -> None:
|
||||
# ``3,14`` should parse as 3.14
|
||||
assert grade(pred="3,14", gold="3.14", answer_format="Float").correct is True
|
||||
|
||||
def test_zero_gold_with_small_abs_diff(self) -> None:
|
||||
# Absolute tolerance of 0.01 should kick in for near-zero golds.
|
||||
assert grade(pred="0.005", gold="0", answer_format="Float").correct is True
|
||||
|
||||
|
||||
class TestListFormat:
|
||||
def test_exact_set_match(self) -> None:
|
||||
r = grade(pred="apple, banana, cherry", gold="apple, banana, cherry", answer_format="List")
|
||||
assert r.correct is True
|
||||
assert r.f1 == pytest.approx(1.0)
|
||||
|
||||
def test_set_match_different_order(self) -> None:
|
||||
r = grade(pred="cherry, apple, banana", gold="apple, banana, cherry", answer_format="List")
|
||||
assert r.correct is True
|
||||
|
||||
def test_partial_overlap_gives_f1(self) -> None:
|
||||
r = grade(pred="apple, banana", gold="apple, banana, cherry", answer_format="List")
|
||||
assert r.correct is False
|
||||
assert 0.0 < r.f1 < 1.0
|
||||
|
||||
def test_extra_items_lower_precision(self) -> None:
|
||||
r = grade(pred="apple, banana, cherry, date", gold="apple, banana, cherry", answer_format="List")
|
||||
assert 0.0 < r.f1 < 1.0
|
||||
# Recall=1, precision=3/4 → F1 ~= 0.857
|
||||
assert r.f1 == pytest.approx(2 * (3 / 4) * 1 / (3 / 4 + 1), rel=1e-3)
|
||||
|
||||
|
||||
class TestNoneFormat:
|
||||
def test_unknown_phrase_credited(self) -> None:
|
||||
for phrase in ("Not answerable", "I cannot answer this.", "No answer", "N/A"):
|
||||
r = grade(pred=phrase, gold="Not answerable", answer_format="None")
|
||||
assert r.correct is True, phrase
|
||||
|
||||
def test_actual_answer_marked_wrong(self) -> None:
|
||||
# The arm hallucinated an answer when it should have said "I don't know".
|
||||
r = grade(pred="The answer is 42.", gold="Not answerable", answer_format="None")
|
||||
assert r.correct is False
|
||||
|
||||
|
||||
class TestUnknownFormatFallsBackToStr:
|
||||
def test_blank_format_uses_str_grader(self) -> None:
|
||||
r = grade(pred="Paris", gold="Paris", answer_format="")
|
||||
assert r.correct is True
|
||||
assert r.method == "str_norm"
|
||||
|
||||
def test_garbage_format_uses_str_grader(self) -> None:
|
||||
r = grade(pred="Paris", gold="Paris", answer_format="quux")
|
||||
assert r.correct is True
|
||||
assert r.method == "str_norm"
|
||||
Loading…
Add table
Add a link
Reference in a new issue