# trustgraph/tests/unit/test_decoding/test_universal_strategies.py


"""
Unit tests for universal decoder section grouping strategies.
"""
import pytest
from unittest.mock import MagicMock
from trustgraph.decoding.universal.strategies import (
group_whole_document,
group_by_heading,
group_by_element_type,
group_by_count,
group_by_size,
get_strategy,
STRATEGIES,
)
def make_element(category="NarrativeText", text="Some text"):
    """Return a stand-in for an ``unstructured`` document element.

    Only ``category`` and ``text`` are configured explicitly; any other
    attribute read falls through to ``MagicMock``'s auto-created children.

    Args:
        category: Value exposed as the element's ``category`` attribute.
        text: Value exposed as the element's ``text`` attribute.

    Returns:
        A fresh ``MagicMock`` carrying the given ``category`` and ``text``.
    """
    # Keyword arguments to the Mock constructor are set as attributes.
    return MagicMock(category=category, text=text)
class TestGroupWholeDocument:
    """Tests for ``group_whole_document``: everything goes in one section."""

    def test_empty_input(self):
        # No elements yields no sections at all, not one empty section.
        assert group_whole_document([]) == []

    def test_returns_single_group(self):
        items = [make_element() for _ in range(5)]
        groups = group_whole_document(items)
        assert len(groups) == 1
        sole = groups[0]
        assert len(sole) == 5

    def test_preserves_all_elements(self):
        items = [make_element(text=f"text-{n}") for n in range(3)]
        sole = group_whole_document(items)[0]
        # The single section is exactly the input, in the original order.
        assert sole == items
class TestGroupByHeading:
    """Tests for ``group_by_heading``: each ``Title`` element opens a section."""

    def test_empty_input(self):
        assert group_by_heading([]) == []

    def test_no_headings_falls_back(self):
        # With no Title elements at all, everything lands in one section.
        elements = [make_element("NarrativeText") for _ in range(3)]
        result = group_by_heading(elements)
        assert len(result) == 1
        assert len(result[0]) == 3

    def test_splits_at_headings(self):
        elements = [
            make_element("Title", "Heading 1"),
            make_element("NarrativeText", "Paragraph 1"),
            make_element("NarrativeText", "Paragraph 2"),
            make_element("Title", "Heading 2"),
            make_element("NarrativeText", "Paragraph 3"),
        ]
        result = group_by_heading(elements)
        assert len(result) == 2
        assert len(result[0]) == 3  # Heading 1 + 2 paragraphs
        assert len(result[1]) == 2  # Heading 2 + 1 paragraph
        # Each section is led by the heading that opened it.
        assert result[0][0] is elements[0]
        assert result[1][0] is elements[3]

    def test_leading_content_before_first_heading(self):
        elements = [
            make_element("NarrativeText", "Preamble"),
            make_element("Title", "Heading 1"),
            make_element("NarrativeText", "Content"),
        ]
        result = group_by_heading(elements)
        assert len(result) == 2
        assert len(result[0]) == 1  # Preamble
        assert len(result[1]) == 2  # Heading + content
        assert result[0][0] is elements[0]
        assert result[1][0] is elements[1]

    def test_consecutive_headings(self):
        elements = [
            make_element("Title", "H1"),
            make_element("Title", "H2"),
            make_element("NarrativeText", "Content"),
        ]
        result = group_by_heading(elements)
        # A heading with no body still forms its own section; the next
        # heading starts a fresh one that collects the following content.
        assert len(result) == 2
        assert len(result[0]) == 1  # H1 alone
        assert len(result[1]) == 2  # H2 + content
        assert result[0][0] is elements[0]
        assert result[1][0] is elements[1]
class TestGroupByElementType:
    """Tests for ``group_by_element_type``: tables are split from narrative."""

    def test_empty_input(self):
        assert group_by_element_type([]) == []

    def test_all_same_type(self):
        elements = [make_element("NarrativeText") for _ in range(3)]
        result = group_by_element_type(elements)
        assert len(result) == 1
        # The single group must hold every element, not just exist.
        assert len(result[0]) == 3

    def test_splits_at_table_boundary(self):
        elements = [
            make_element("NarrativeText", "Intro"),
            make_element("NarrativeText", "More text"),
            make_element("Table", "Table data"),
            make_element("NarrativeText", "After table"),
        ]
        result = group_by_element_type(elements)
        assert len(result) == 3
        assert len(result[0]) == 2  # Two narrative elements
        assert len(result[1]) == 1  # One table
        assert len(result[2]) == 1  # One narrative
        # The table element itself is the one isolated in the middle group.
        assert result[1][0] is elements[2]
        assert result[2][0] is elements[3]

    def test_consecutive_tables_stay_grouped(self):
        elements = [
            make_element("Table", "Table 1"),
            make_element("Table", "Table 2"),
        ]
        result = group_by_element_type(elements)
        assert len(result) == 1
        assert len(result[0]) == 2
class TestGroupByCount:
    """Tests for ``group_by_count``: fixed-size chunks with a short tail."""

    def test_empty_input(self):
        assert group_by_count([]) == []

    def test_exact_multiple(self):
        items = [make_element() for _ in range(6)]
        chunks = group_by_count(items, element_count=3)
        # Two full chunks of three, no remainder.
        assert [len(chunk) for chunk in chunks] == [3, 3]

    def test_remainder_group(self):
        items = [make_element() for _ in range(7)]
        chunks = group_by_count(items, element_count=3)
        # The leftover element forms its own trailing chunk.
        assert [len(chunk) for chunk in chunks] == [3, 3, 1]

    def test_fewer_than_count(self):
        items = [make_element() for _ in range(2)]
        chunks = group_by_count(items, element_count=10)
        # Fewer elements than the chunk size still yields one chunk.
        assert [len(chunk) for chunk in chunks] == [2]
class TestGroupBySize:
    """Tests for ``group_by_size``: pack whole elements up to ``max_size``."""

    def test_empty_input(self):
        assert group_by_size([]) == []

    def test_small_elements_grouped(self):
        elements = [make_element(text="Hi") for _ in range(5)]
        result = group_by_size(elements, max_size=100)
        assert len(result) == 1
        # All five short elements fit, so none may be dropped.
        assert len(result[0]) == 5

    def test_splits_at_size_limit(self):
        elements = [make_element(text="x" * 100) for _ in range(5)]
        result = group_by_size(elements, max_size=250)
        # 2 elements per group (200 chars); a third would exceed 250.
        assert len(result) == 3
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert len(result[2]) == 1

    def test_large_element_own_group(self):
        elements = [
            make_element(text="small"),
            make_element(text="x" * 5000),  # Exceeds max
            make_element(text="small"),
        ]
        result = group_by_size(elements, max_size=100)
        # The oversized element is isolated: it neither absorbs its
        # neighbours nor leaves behind an empty group.
        assert len(result) == 3
        assert [len(group) for group in result] == [1, 1, 1]
        assert result[1][0] is elements[1]

    def test_respects_element_boundaries(self):
        # Each element is 50 chars, max is 120.
        # Should get 2 per group, not split mid-element.
        elements = [make_element(text="x" * 50) for _ in range(5)]
        result = group_by_size(elements, max_size=120)
        assert len(result) == 3
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert len(result[2]) == 1
class TestGetStrategy:
    """Tests for ``get_strategy`` name-to-function resolution."""

    def test_all_strategies_accessible(self):
        # Every registered strategy name must resolve to a callable.
        assert all(callable(get_strategy(name)) for name in STRATEGIES)

    def test_unknown_strategy_raises(self):
        with pytest.raises(ValueError, match="Unknown section strategy"):
            get_strategy("nonexistent")

    def test_returns_correct_function(self):
        expected = {
            "whole-document": group_whole_document,
            "heading": group_by_heading,
            "element-type": group_by_element_type,
            "count": group_by_count,
            "size": group_by_size,
        }
        for name, strategy_fn in expected.items():
            assert get_strategy(name) is strategy_fn