mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
224 lines
7.6 KiB
Python
224 lines
7.6 KiB
Python
"""Tests for the CRAG dataset loader (parser + sampling).
|
|
|
|
The full bz2 download is excluded — these tests synthesise a tiny
|
|
JSONL-bz2 in a tmp dir and verify the parser / stratified-sampler
|
|
produce well-shaped objects.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import bz2
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from surfsense_evals.suites.research.crag.dataset import (
|
|
CragPage,
|
|
CragQuestion,
|
|
iter_questions,
|
|
stratified_sample,
|
|
)
|
|
|
|
|
|
def _make_jsonl_bz2(rows: list[dict], tmp_path: Path) -> Path:
    """Serialise ``rows`` as JSON lines and store them bz2-compressed.

    Each row becomes one ``json.dumps`` line; lines are joined with a single
    newline (no trailing newline after the last row) and encoded as UTF-8.
    The archive is always written to ``tmp_path / "fake.jsonl.bz2"`` and that
    path is returned.
    """
    encoded_lines = [json.dumps(row) for row in rows]
    raw = "\n".join(encoded_lines).encode("utf-8")
    archive = tmp_path / "fake.jsonl.bz2"
    # Compress in memory, then write the finished archive in one call.
    archive.write_bytes(bz2.compress(raw))
    return archive
|
|
|
|
|
|
def _row(
    *,
    interaction_id: str,
    query: str,
    answer: str,
    domain: str = "movie",
    question_type: str = "simple",
    pages: list[dict] | None = None,
    alt_ans: list[str] | None = None,
    popularity: str = "head",
    static_or_dynamic: str = "static",
    split: int = 0,
    query_time: str = "2024-04-01",
) -> dict:
    """Assemble one raw dataset record (a plain dict) for the fixtures.

    All arguments are keyword-only. ``pages`` is stored under the
    ``"search_results"`` key; a falsy ``pages`` or ``alt_ans`` (``None`` or
    an empty list) is replaced by a fresh empty list. Every other argument is
    stored unchanged under the key of the same name.
    """
    alternates = alt_ans if alt_ans else []
    search_results = pages if pages else []
    # Pairs are listed in the order the keys must appear in the record.
    fields = (
        ("interaction_id", interaction_id),
        ("query_time", query_time),
        ("domain", domain),
        ("question_type", question_type),
        ("static_or_dynamic", static_or_dynamic),
        ("query", query),
        ("answer", answer),
        ("alt_ans", alternates),
        ("split", split),
        ("popularity", popularity),
        ("search_results", search_results),
    )
    return dict(fields)
|
|
|
|
|
|
class TestParser:
    """Checks ``iter_questions`` against small synthetic JSONL-bz2 fixtures."""

    def test_basic_parse(self, tmp_path: Path) -> None:
        """One complete row is expected to yield one question with its page."""
        wiki_page = {
            "page_name": "Inception (film)",
            "page_url": "https://en.wikipedia.org/wiki/Inception",
            "page_snippet": "snippet",
            "page_result": "<html>full html</html>",
            "page_last_modified": "2024-01-01",
        }
        record = _row(
            interaction_id="abc",
            query="Who directed Inception?",
            answer="Christopher Nolan",
            pages=[wiki_page],
        )
        archive = _make_jsonl_bz2([record], tmp_path)

        parsed = iter_questions(archive)

        assert len(parsed) == 1
        question = parsed[0]
        assert question.query == "Who directed Inception?"
        assert question.gold_answer == "Christopher Nolan"
        assert question.qid == "C00000"
        assert question.domain == "movie"
        assert question.question_type == "simple"
        assert len(question.pages) == 1
        first_page = question.pages[0]
        assert first_page.page_name == "Inception (film)"
        assert first_page.page_url == "https://en.wikipedia.org/wiki/Inception"

    def test_skips_missing_query_or_answer(self, tmp_path: Path) -> None:
        """Only the row with both a query and an answer is expected to survive."""
        no_query = _row(interaction_id="1", query="", answer="x")
        no_answer = _row(interaction_id="2", query="ok?", answer="")
        complete = _row(interaction_id="3", query="ok?", answer="x")
        archive = _make_jsonl_bz2([no_query, no_answer, complete], tmp_path)

        parsed = iter_questions(archive)

        assert len(parsed) == 1
        assert parsed[0].interaction_id == "3"

    def test_skips_empty_pages(self, tmp_path: Path) -> None:
        """Pages lacking a URL or HTML are expected to be dropped from the question."""
        candidate_pages = [
            {"page_url": "", "page_result": "<html/>"},  # no URL
            {"page_url": "https://x.test/", "page_result": ""},  # empty html
            {"page_url": "https://y.test/", "page_result": "<html>good</html>"},
        ]
        record = _row(
            interaction_id="x",
            query="q?",
            answer="a",
            pages=candidate_pages,
        )
        archive = _make_jsonl_bz2([record], tmp_path)

        parsed = iter_questions(archive)

        assert len(parsed) == 1
        assert len(parsed[0].pages) == 1
        assert parsed[0].pages[0].page_url == "https://y.test/"

    def test_alt_answers_parsed(self, tmp_path: Path) -> None:
        """The ``alt_ans`` list is expected to surface as ``alt_answers``."""
        record = _row(
            interaction_id="z",
            query="q?",
            answer="42",
            alt_ans=["forty-two", "42.0"],
        )
        archive = _make_jsonl_bz2([record], tmp_path)

        parsed = iter_questions(archive)

        assert parsed[0].alt_answers == ["forty-two", "42.0"]

    def test_handles_malformed_line(self, tmp_path: Path) -> None:
        """A garbage line ahead of a valid row is expected to be skipped."""
        # Hand-build the archive: a non-JSON first line, then one valid row.
        valid_line = json.dumps(_row(interaction_id="ok", query="q?", answer="a"))
        payload = b"".join([b"not-json{\n", (valid_line + "\n").encode("utf-8")])
        archive = tmp_path / "mixed.jsonl.bz2"
        with bz2.open(archive, "wb") as handle:
            handle.write(payload)

        parsed = iter_questions(archive)

        # Malformed line is skipped; one good row survives at index 1.
        assert len(parsed) == 1
        assert parsed[0].interaction_id == "ok"
|
|
|
|
|
|
class TestPageHash:
    """Checks the ``url_hash`` attribute exposed by ``CragPage``."""

    def test_url_hash_stable(self) -> None:
        """Pages sharing a URL but not a name are expected to hash alike (12 chars)."""
        shared = {
            "page_url": "https://x.test/p?q=1",
            "page_snippet": "",
            "page_html": "<html/>",
        }
        first = CragPage(page_name="A", **shared)
        second = CragPage(page_name="B", **shared)

        assert first.url_hash == second.url_hash
        assert len(first.url_hash) == 12

    def test_url_hash_unique(self) -> None:
        """Pages with different URLs are expected to get different hashes."""
        common = {"page_snippet": "", "page_html": "<html/>"}
        first = CragPage(page_name="A", page_url="https://x.test/a", **common)
        second = CragPage(page_name="B", page_url="https://x.test/b", **common)

        assert first.url_hash != second.url_hash
|
|
|
|
|
|
class TestStratifiedSample:
    """Checks ``stratified_sample`` on a fixed 55-question synthetic pool."""

    def _make_pool(self) -> list[CragQuestion]:
        """Build the pool: consecutive ``raw_index`` values, grouped by bucket."""
        # 30 finance/simple, 20 movie/comparison, 5 sports/multi-hop.
        buckets = (
            (30, "finance", "simple"),
            (20, "movie", "comparison"),
            (5, "sports", "multi-hop"),
        )
        # Expand each bucket into one (domain, question_type) label per question.
        labels = [
            (domain, qtype)
            for count, domain, qtype in buckets
            for _ in range(count)
        ]
        return [
            CragQuestion(
                qid=f"C{idx:05d}",
                interaction_id=f"i{idx}",
                query_time="2024-01-01",
                query=f"q{idx}?",
                gold_answer="a",
                alt_answers=[],
                domain=domain,
                question_type=qtype,
                static_or_dynamic="static",
                popularity="head",
                split=0,
                raw_index=idx,
                pages=[],
            )
            for idx, (domain, qtype) in enumerate(labels)
        ]

    def test_sample_smaller_than_pool(self) -> None:
        """A 15-item sample is expected to have 15 items spanning every domain."""
        sample = stratified_sample(self._make_pool(), n=15, seed=7)

        assert len(sample) == 15
        # Should pull from all three buckets at least once.
        seen_domains = {question.domain for question in sample}
        assert seen_domains == {"finance", "movie", "sports"}

    def test_sample_returns_pool_when_n_geq(self) -> None:
        """Requesting more items than exist is expected to return the pool size."""
        pool = self._make_pool()

        sample = stratified_sample(pool, n=999, seed=1)

        assert len(sample) == len(pool)

    def test_sample_sorted_by_raw_index(self) -> None:
        """The sample is expected to come back ordered by ``raw_index``."""
        sample = stratified_sample(self._make_pool(), n=10, seed=42)

        indices = [question.raw_index for question in sample]
        assert indices == sorted(indices)

    def test_sample_deterministic(self) -> None:
        """The same seed is expected to select the same questions in the same order."""
        pool = self._make_pool()

        first_run = stratified_sample(pool, n=20, seed=11)
        second_run = stratified_sample(pool, n=20, seed=11)

        first_ids = [question.qid for question in first_run]
        second_ids = [question.qid for question in second_run]
        assert first_ids == second_ids

    def test_n_zero_or_negative_returns_pool(self) -> None:
        """``n`` of 0 or -1 is expected to return as many items as the pool holds."""
        pool = self._make_pool()

        for requested in (0, -1):
            assert len(stratified_sample(pool, n=requested)) == len(pool)
|