mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-21 18:55:16 +02:00
260 lines
9.9 KiB
Python
260 lines
9.9 KiB
Python
|
|
"""Unit tests for the CRAG Task 3 streaming dataset loader.

We don't (and shouldn't) hit the real 7 GB upstream archive in
unit tests. Instead we construct tiny tar.bz2 archives split across
N parts and verify:

* ``_MultiPartReader`` correctly stitches N files together.
* The streaming path (multi-part → bz2 → tar → JSONL) yields parsed
  ``CragQuestion`` rows with the right shape.
* The ``max_questions`` cap is honoured (early break, no greedy read).
* ``parts_present`` correctly detects missing/empty parts.
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import bz2
|
|||
|
|
import io
|
|||
|
|
import json
|
|||
|
|
import tarfile
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
|
|||
|
|
from surfsense_evals.suites.research.crag.dataset_task3 import (
|
|||
|
|
_MultiPartReader,
|
|||
|
|
iter_questions_task3,
|
|||
|
|
parts_present,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Fixtures: build a tiny synthetic Task 3 archive
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _make_jsonl_payload(n_rows: int) -> bytes:
    """Return ``n_rows`` synthetic question records encoded as JSONL bytes.

    Each record is one UTF-8 JSON object per line. ``domain`` and
    ``question_type`` cycle through fixed lists based on the row index, and
    every record carries 50 ``search_results`` entries. The output always
    ends with a newline, so ``n_rows == 0`` yields ``b"\\n"``.
    """
    domains = ["finance", "music", "movie", "sports", "open"]
    question_types = ["simple", "comparison", "aggregation", "multi-hop"]

    def _search_results(q):
        # Fifty fake pages per question, each with a tiny HTML body.
        return [
            {
                "page_name": f"Page {p} for q{q}",
                "page_url": f"https://example.com/q{q}/p{p}",
                "page_snippet": "snippet",
                "page_result": f"<html><body><p>q{q} p{p} body</p></body></html>",
                "page_last_modified": "",
            }
            for p in range(50)
        ]

    def _record(q):
        # Key order is kept stable because it determines the serialised bytes.
        return {
            "interaction_id": f"int_{q:04d}",
            "query_time": "2024-01-01 00:00:00",
            "domain": domains[q % 5],
            "question_type": question_types[q % 4],
            "static_or_dynamic": "static",
            "popularity": "head",
            "split": 0,
            "query": f"Synthetic CRAG question {q}?",
            "answer": f"answer-{q}",
            "alt_ans": [f"alt-{q}-a", f"alt-{q}-b"],
            "search_results": _search_results(q),
        }

    encoded_lines = [json.dumps(_record(q)).encode("utf-8") for q in range(n_rows)]
    return b"\n".join(encoded_lines) + b"\n"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _make_tar_bz2(jsonl_bytes: bytes, *, member_name: str = "data.jsonl") -> bytes:
    """Return an in-memory tar.bz2 archive holding one member.

    Args:
        jsonl_bytes: Raw content stored as the single archive member.
        member_name: Name given to that member inside the tar.

    Returns:
        The complete bz2-compressed tar archive as bytes.
    """
    sink = io.BytesIO()
    # The tar layer writes through the bz2 compressor. Both are closed (tar
    # first, then bz2) before the buffer is read, so the archive is complete.
    with bz2.BZ2File(sink, mode="wb") as compressor, tarfile.open(
        fileobj=compressor, mode="w"
    ) as archive:
        member = tarfile.TarInfo(name=member_name)
        member.size = len(jsonl_bytes)
        archive.addfile(member, io.BytesIO(jsonl_bytes))
    return sink.getvalue()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _make_tar_bz2_multi(shards: list[tuple[str, bytes]]) -> bytes:
    """Return an in-memory tar.bz2 archive with one member per shard.

    Mirrors the real CRAG Task 3 layout: a single tar holding several
    JSONL members, e.g. ``crag_task_3_dev_v4_{i}.jsonl`` (the caller picks
    the names).

    Args:
        shards: ``(member_name, content)`` pairs, added to the tar in the
            order given.

    Returns:
        The complete bz2-compressed tar archive as bytes.
    """
    sink = io.BytesIO()
    # tar writes through the bz2 compressor; closing order is tar, then bz2.
    with bz2.BZ2File(sink, mode="wb") as compressor, tarfile.open(
        fileobj=compressor, mode="w"
    ) as archive:
        for member_name, content in shards:
            member = tarfile.TarInfo(name=member_name)
            member.size = len(content)
            archive.addfile(member, io.BytesIO(content))
    return sink.getvalue()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _split_into_parts(blob: bytes, n_parts: int) -> list[bytes]:
    """Split ``blob`` into exactly ``n_parts`` consecutive chunks.

    The first ``n_parts - 1`` chunks each hold ``max(1, len(blob) // n_parts)``
    bytes and the last chunk takes whatever remains. Joining the chunks
    always reproduces ``blob``. When ``blob`` is shorter than ``n_parts``
    the trailing chunks are empty.

    Args:
        blob: Bytes to split.
        n_parts: Number of chunks to produce; must be at least 1.

    Returns:
        A list of ``n_parts`` byte strings in original order.

    Raises:
        ValueError: If ``n_parts`` is less than 1.
    """
    # Without this guard, zero fails with a bare ZeroDivisionError and a
    # negative count silently returns a single arbitrary tail slice.
    if n_parts < 1:
        raise ValueError(f"n_parts must be >= 1, got {n_parts}")
    chunk = max(1, len(blob) // n_parts)
    parts = [blob[i * chunk : (i + 1) * chunk] for i in range(n_parts - 1)]
    # The final slice is open-ended so it absorbs the division remainder.
    parts.append(blob[(n_parts - 1) * chunk :])
    return parts
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def task3_parts_dir(tmp_path: Path) -> Path:
    """Create a 4-part synthetic CRAG Task 3 archive (12 rows) and return its directory.

    The archive has a single JSONL member and is written as
    ``crag_task_3_dev_v4.tar.bz2.part1`` .. ``part4`` inside a
    ``.raw_cache`` directory under ``tmp_path``.
    """
    archive = _make_tar_bz2(_make_jsonl_payload(12))
    cache_dir = tmp_path / ".raw_cache"
    cache_dir.mkdir()
    # Part files are numbered from 1, matching the names used across this module.
    for part_no, piece in enumerate(_split_into_parts(archive, 4), start=1):
        part_path = cache_dir / f"crag_task_3_dev_v4.tar.bz2.part{part_no}"
        part_path.write_bytes(piece)
    return cache_dir
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# _MultiPartReader
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestMultiPartReader:
    """Checks for ``_MultiPartReader`` using tiny part files under ``tmp_path``."""

    def test_concatenates_parts_in_order(self, tmp_path: Path) -> None:
        """An unbounded ``read()`` returns every part's bytes in list order."""
        contents = (("a", b"hello, "), ("b", b"streaming "), ("c", b"world!"))
        part_paths = []
        for file_name, data in contents:
            part_path = tmp_path / file_name
            part_path.write_bytes(data)
            part_paths.append(part_path)
        with _MultiPartReader(part_paths) as reader:
            assert reader.read() == b"hello, streaming world!"

    def test_read_n_crosses_part_boundary(self, tmp_path: Path) -> None:
        """A sized read continues into the next part when the current one runs out."""
        first, second = tmp_path / "a", tmp_path / "b"
        first.write_bytes(b"AAA")
        second.write_bytes(b"BBBB")
        with _MultiPartReader([first, second]) as reader:
            # Five bytes requested, but the first part only holds three.
            assert reader.read(5) == b"AAABB"
            # Two bytes remain: expect a short read, then end-of-stream.
            assert reader.read(5) == b"BB"
            assert reader.read(5) == b""

    def test_close_is_idempotent(self, tmp_path: Path) -> None:
        """Closing twice must not fail; reading afterwards must raise ``ValueError``."""
        only_part = tmp_path / "a"
        only_part.write_bytes(b"x")
        reader = _MultiPartReader([only_part])
        reader.close()
        reader.close()
        with pytest.raises(ValueError):
            reader.read(1)

    def test_missing_part_raises(self, tmp_path: Path) -> None:
        """Constructing the reader over a nonexistent file raises ``FileNotFoundError``."""
        absent = tmp_path / "does-not-exist"
        with pytest.raises(FileNotFoundError):
            _MultiPartReader([absent])

    def test_empty_paths_raises(self) -> None:
        """An empty list of parts is rejected with ``ValueError``."""
        with pytest.raises(ValueError):
            _MultiPartReader([])
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# iter_questions_task3
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def task3_multi_shard_dir(tmp_path: Path) -> Path:
    """Create a 4-part archive whose tar holds 3 JSONL shards (4 + 4 + 4 rows).

    Every shard is built by the same ``_make_jsonl_payload(4)`` call, so the
    three shards carry identical row content. The part files are written as
    ``crag_task_3_dev_v4.tar.bz2.part1`` .. ``part4`` inside a
    ``.raw_cache`` directory under ``tmp_path``.
    """
    shard_names = (
        "crag_task_3_dev_v4_0.jsonl",
        "crag_task_3_dev_v4_1.jsonl",
        "crag_task_3_dev_v4_2.jsonl",
    )
    archive = _make_tar_bz2_multi(
        [(shard_name, _make_jsonl_payload(4)) for shard_name in shard_names]
    )
    cache_dir = tmp_path / ".raw_cache"
    cache_dir.mkdir()
    # Part files are numbered from 1.
    for part_no, piece in enumerate(_split_into_parts(archive, 4), start=1):
        part_path = cache_dir / f"crag_task_3_dev_v4.tar.bz2.part{part_no}"
        part_path.write_bytes(piece)
    return cache_dir
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestIterQuestionsTask3:
    """End-to-end checks of ``iter_questions_task3`` against synthetic part files."""

    def test_streams_full_archive(self, task3_parts_dir: Path) -> None:
        """All 12 rows of the single-shard archive are returned with the expected fields."""
        loaded = iter_questions_task3(task3_parts_dir)
        assert len(loaded) == 12
        # Every question carries the T3_ qid prefix and its 50 pages.
        for question in loaded:
            assert question.qid.startswith("T3_")
            assert len(question.pages) == 50
        # Spot-check that the source fields made it onto the first question.
        head = loaded[0]
        assert head.query == "Synthetic CRAG question 0?"
        assert head.gold_answer == "answer-0"
        assert head.domain == "finance"
        assert "alt-0-a" in head.alt_answers

    def test_max_questions_caps_early(self, task3_parts_dir: Path) -> None:
        """``max_questions`` limits the result to the first rows, with none skipped."""
        loaded = iter_questions_task3(task3_parts_dir, max_questions=3)
        assert len(loaded) == 3
        indices = [question.raw_index for question in loaded]
        assert indices == [0, 1, 2]

    def test_streams_multi_shard_archive(self, task3_multi_shard_dir: Path) -> None:
        """Rows from all three shards are returned with a continuous ``raw_index``."""
        # 3 shards x 4 rows = 12 rows overall.
        loaded = iter_questions_task3(task3_multi_shard_dir)
        assert len(loaded) == 12
        # raw_index keeps counting across shard boundaries.
        indices = [question.raw_index for question in loaded]
        assert indices == list(range(12))
        # No qid may repeat across shards.
        distinct_qids = {question.qid for question in loaded}
        assert len(distinct_qids) == 12

    def test_max_questions_short_circuits_first_shard(self, task3_multi_shard_dir: Path) -> None:
        """A cap below the shard size is satisfied from shard 0 alone."""
        loaded = iter_questions_task3(task3_multi_shard_dir, max_questions=2)
        assert len(loaded) == 2
        # raw_index 0 and 1 both belong to the first shard.
        indices = [question.raw_index for question in loaded]
        assert indices == [0, 1]

    def test_max_questions_spans_shards(self, task3_multi_shard_dir: Path) -> None:
        """A cap of 6 takes all 4 rows of shard 0 plus the first 2 of shard 1."""
        loaded = iter_questions_task3(task3_multi_shard_dir, max_questions=6)
        assert len(loaded) == 6
        indices = [question.raw_index for question in loaded]
        assert indices == [0, 1, 2, 3, 4, 5]

    def test_raises_when_no_jsonl_member(self, tmp_path: Path) -> None:
        """An archive whose only member is not JSONL makes the loader raise ``RuntimeError``."""
        # Build the archive with a README as its only member, then split it
        # into the same four part files the loader is pointed at elsewhere.
        blob = _make_tar_bz2(b"not jsonl", member_name="README.md")
        cache_dir = tmp_path / ".raw_cache"
        cache_dir.mkdir()
        for part_no, piece in enumerate(_split_into_parts(blob, 4), start=1):
            part_path = cache_dir / f"crag_task_3_dev_v4.tar.bz2.part{part_no}"
            part_path.write_bytes(piece)
        with pytest.raises(RuntimeError, match="No JSONL member"):
            iter_questions_task3(cache_dir)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# parts_present
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestPartsPresent:
    """Checks that ``parts_present`` only reports True for a complete set of parts."""

    def test_all_present(self, task3_parts_dir: Path) -> None:
        """The untouched fixture directory counts as complete."""
        result = parts_present(task3_parts_dir)
        assert result is True

    def test_one_missing(self, task3_parts_dir: Path) -> None:
        """Deleting one part file makes the check fail."""
        removed = task3_parts_dir / "crag_task_3_dev_v4.tar.bz2.part2"
        removed.unlink()
        assert parts_present(task3_parts_dir) is False

    def test_one_empty(self, task3_parts_dir: Path) -> None:
        """A part file truncated to zero bytes makes the check fail."""
        truncated = task3_parts_dir / "crag_task_3_dev_v4.tar.bz2.part3"
        truncated.write_bytes(b"")
        assert parts_present(task3_parts_dir) is False
|