# Mirror of https://github.com/trustgraph-ai/trustgraph.git
# Synced 2026-04-25 16:36:21 +02:00
# 205 lines, 6.5 KiB, Python
|
|
"""
Unit tests for universal decoder section grouping strategies.
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import MagicMock
|
||
|
|
|
||
|
|
|
||
|
|
from trustgraph.decoding.universal.strategies import (
|
||
|
|
group_whole_document,
|
||
|
|
group_by_heading,
|
||
|
|
group_by_element_type,
|
||
|
|
group_by_count,
|
||
|
|
group_by_size,
|
||
|
|
get_strategy,
|
||
|
|
STRATEGIES,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def make_element(category="NarrativeText", text="Some text"):
    """Create a mock unstructured element.

    Args:
        category: Value assigned to the mock's ``category`` attribute.
        text: Value assigned to the mock's ``text`` attribute.

    Returns:
        A ``MagicMock`` whose ``category`` and ``text`` attributes are set
        to the given values.
    """
    el = MagicMock()
    el.category = category
    el.text = text
    return el
|
||
|
|
|
||
|
|
|
||
|
|
class TestGroupWholeDocument:
    """Tests for the ``group_whole_document`` strategy."""

    def test_empty_input(self):
        """An empty element list yields an empty result."""
        assert group_whole_document([]) == []

    def test_returns_single_group(self):
        """Five elements are returned as one group of five."""
        elements = [make_element() for _ in range(5)]
        result = group_whole_document(elements)
        assert len(result) == 1
        assert len(result[0]) == 5

    def test_preserves_all_elements(self):
        """The single group equals the input list, in order."""
        elements = [make_element(text=f"text-{i}") for i in range(3)]
        result = group_whole_document(elements)
        assert result[0] == elements
|
||
|
|
|
||
|
|
|
||
|
|
class TestGroupByHeading:
    """Tests for the ``group_by_heading`` strategy."""

    def test_empty_input(self):
        """An empty element list yields an empty result."""
        assert group_by_heading([]) == []

    def test_no_headings_falls_back(self):
        """Without any ``Title`` element, everything lands in one group."""
        elements = [make_element("NarrativeText") for _ in range(3)]
        result = group_by_heading(elements)
        assert len(result) == 1
        assert len(result[0]) == 3

    def test_splits_at_headings(self):
        """Each ``Title`` element starts a new group."""
        elements = [
            make_element("Title", "Heading 1"),
            make_element("NarrativeText", "Paragraph 1"),
            make_element("NarrativeText", "Paragraph 2"),
            make_element("Title", "Heading 2"),
            make_element("NarrativeText", "Paragraph 3"),
        ]
        result = group_by_heading(elements)
        assert len(result) == 2
        assert len(result[0]) == 3  # Heading 1 + 2 paragraphs
        assert len(result[1]) == 2  # Heading 2 + 1 paragraph

    def test_leading_content_before_first_heading(self):
        """Content preceding the first heading forms its own group."""
        elements = [
            make_element("NarrativeText", "Preamble"),
            make_element("Title", "Heading 1"),
            make_element("NarrativeText", "Content"),
        ]
        result = group_by_heading(elements)
        assert len(result) == 2
        assert len(result[0]) == 1  # Preamble
        assert len(result[1]) == 2  # Heading + content

    def test_consecutive_headings(self):
        """Two back-to-back headings followed by content give two groups."""
        elements = [
            make_element("Title", "H1"),
            make_element("Title", "H2"),
            make_element("NarrativeText", "Content"),
        ]
        result = group_by_heading(elements)
        assert len(result) == 2
|
||
|
|
|
||
|
|
|
||
|
|
class TestGroupByElementType:
    """Tests for the ``group_by_element_type`` strategy."""

    def test_empty_input(self):
        """An empty element list yields an empty result."""
        assert group_by_element_type([]) == []

    def test_all_same_type(self):
        """Elements that all share one category form a single group."""
        elements = [make_element("NarrativeText") for _ in range(3)]
        result = group_by_element_type(elements)
        assert len(result) == 1

    def test_splits_at_table_boundary(self):
        """A table between narrative runs produces three groups."""
        elements = [
            make_element("NarrativeText", "Intro"),
            make_element("NarrativeText", "More text"),
            make_element("Table", "Table data"),
            make_element("NarrativeText", "After table"),
        ]
        result = group_by_element_type(elements)
        assert len(result) == 3
        assert len(result[0]) == 2  # Two narrative elements
        assert len(result[1]) == 1  # One table
        assert len(result[2]) == 1  # One narrative

    def test_consecutive_tables_stay_grouped(self):
        """Adjacent tables are kept together in one group."""
        elements = [
            make_element("Table", "Table 1"),
            make_element("Table", "Table 2"),
        ]
        result = group_by_element_type(elements)
        assert len(result) == 1
        assert len(result[0]) == 2
|
||
|
|
|
||
|
|
|
||
|
|
class TestGroupByCount:
    """Tests for the ``group_by_count`` strategy."""

    def test_empty_input(self):
        """An empty element list yields an empty result."""
        assert group_by_count([]) == []

    def test_exact_multiple(self):
        """Six elements with ``element_count=3`` give two groups of three."""
        elements = [make_element() for _ in range(6)]
        result = group_by_count(elements, element_count=3)
        assert len(result) == 2
        assert all(len(g) == 3 for g in result)

    def test_remainder_group(self):
        """Seven elements with ``element_count=3`` leave a final group of one."""
        elements = [make_element() for _ in range(7)]
        result = group_by_count(elements, element_count=3)
        assert len(result) == 3
        assert len(result[0]) == 3
        assert len(result[1]) == 3
        assert len(result[2]) == 1

    def test_fewer_than_count(self):
        """Fewer elements than ``element_count`` still form one group."""
        elements = [make_element() for _ in range(2)]
        result = group_by_count(elements, element_count=10)
        assert len(result) == 1
        assert len(result[0]) == 2
|
||
|
|
|
||
|
|
|
||
|
|
class TestGroupBySize:
    """Tests for the ``group_by_size`` strategy."""

    def test_empty_input(self):
        """An empty element list yields an empty result."""
        assert group_by_size([]) == []

    def test_small_elements_grouped(self):
        """Five tiny elements fit in one group under ``max_size=100``."""
        elements = [make_element(text="Hi") for _ in range(5)]
        result = group_by_size(elements, max_size=100)
        assert len(result) == 1

    def test_splits_at_size_limit(self):
        """100-char elements with ``max_size=250`` are grouped 2, 2, 1."""
        elements = [make_element(text="x" * 100) for _ in range(5)]
        result = group_by_size(elements, max_size=250)
        # 2 elements per group (200 chars), then split
        assert len(result) == 3
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert len(result[2]) == 1

    def test_large_element_own_group(self):
        """An element larger than ``max_size`` results in three groups."""
        elements = [
            make_element(text="small"),
            make_element(text="x" * 5000),  # Exceeds max
            make_element(text="small"),
        ]
        result = group_by_size(elements, max_size=100)
        assert len(result) == 3

    def test_respects_element_boundaries(self):
        """Groups are built from whole elements only."""
        # Each element is 50 chars, max is 120
        # Should get 2 per group, not split mid-element
        elements = [make_element(text="x" * 50) for _ in range(5)]
        result = group_by_size(elements, max_size=120)
        assert len(result) == 3
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert len(result[2]) == 1
|
||
|
|
|
||
|
|
|
||
|
|
class TestGetStrategy:
    """Tests for ``get_strategy`` lookups against ``STRATEGIES``."""

    def test_all_strategies_accessible(self):
        """Every name in ``STRATEGIES`` resolves to a callable."""
        for name in STRATEGIES:
            fn = get_strategy(name)
            assert callable(fn)

    def test_unknown_strategy_raises(self):
        """An unregistered name raises ``ValueError`` with a matching message."""
        with pytest.raises(ValueError, match="Unknown section strategy"):
            get_strategy("nonexistent")

    def test_returns_correct_function(self):
        """Each known strategy name maps to its grouping function."""
        assert get_strategy("whole-document") is group_whole_document
        assert get_strategy("heading") is group_by_heading
        assert get_strategy("element-type") is group_by_element_type
        assert get_strategy("count") is group_by_count
        assert get_strategy("size") is group_by_size
|