mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 23:11:00 +02:00
feat: direction-aware reranker text in GraphRAG hop-and-filter (#1016)
The reranker document text now reflects the traversal direction,
showing only the new information relative to the frontier entity:
- From S (subject is frontier): text = "{predicate} {object}"
- From O (object is frontier): text = "{subject} {predicate}"
- From P (predicate is frontier): text = "{subject} {object}"
This eliminates duplicate reranker texts when traversing inward
from shared object nodes (e.g. 18 CPUs all producing identical
"hasSubcategory Processors" text when the subject was dropped).
execute_batch_triple_queries now returns (triple, direction)
tuples so hop_and_filter can select the appropriate text format.
Updates tech spec to document the direction-aware approach.
Adds unit tests for direction tracking and reranker text
construction.
This commit is contained in:
parent
9cf7dcb578
commit
db7fdbc652
4 changed files with 502 additions and 19 deletions
|
|
@ -129,6 +129,9 @@ class TestBatchTripleQueries:
|
|||
|
||||
# 3 queries, alternating results
|
||||
assert len(result) == 3
|
||||
# Each result is a (triple, direction) tuple
|
||||
for triple, direction in result:
|
||||
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_in_one_query_does_not_block_others(self):
|
||||
|
|
@ -153,6 +156,8 @@ class TestBatchTripleQueries:
|
|||
|
||||
# 3 queries: 2 succeed, 1 fails → 2 triples
|
||||
assert len(result) == 2
|
||||
for triple, direction in result:
|
||||
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_results_filtered(self):
|
||||
|
|
@ -176,6 +181,8 @@ class TestBatchTripleQueries:
|
|||
|
||||
# 3 queries: 1 returns None, 2 return triples
|
||||
assert len(result) == 2
|
||||
for triple, direction in result:
|
||||
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_entities_no_queries(self):
|
||||
|
|
@ -220,6 +227,80 @@ class TestBatchTripleQueries:
|
|||
assert calls[2].kwargs["p"] is None
|
||||
assert calls[2].kwargs["o"] == "ent-1"
|
||||
|
||||
@pytest.mark.asyncio
async def test_directions_assigned_correctly(self):
    """Each query position should produce the correct direction tag.

    The stub answers every query with the same single-triple list, so
    one entity is expected to yield three (triple, direction) tuples,
    one per query position.
    """
    triple = _make_triple("s", "p", "o")

    # Stateless stub: every query gets the same answer, so no call
    # bookkeeping is needed.
    async def one_triple_each(**kwargs):
        return [triple]

    client = AsyncMock()
    client.query_stream = one_triple_each
    query = _make_query(triples_client=client)

    result = await query.execute_batch_triple_queries(
        ["e1"], limit_per_entity=10
    )

    assert len(result) == 3
    # Order matches query order: s-position, p-position, o-position
    assert result[0][1] == Query.FROM_S
    assert result[1][1] == Query.FROM_P
    assert result[2][1] == Query.FROM_O
|
||||
|
||||
@pytest.mark.asyncio
async def test_directions_correct_for_multiple_entities(self):
    """Direction tags cycle correctly across multiple entities.

    Every stubbed query returns one triple, so two entities are
    expected to produce six results whose direction tags repeat the
    S, P, O sequence once per entity.
    """
    stub_triple = _make_triple("s", "p", "o")

    client = AsyncMock()
    client.query_stream = AsyncMock(return_value=[stub_triple])
    query = _make_query(triples_client=client)

    entities = ["e1", "e2"]
    result = await query.execute_batch_triple_queries(
        entities, limit_per_entity=10
    )

    assert len(result) == 6

    # One S/P/O cycle per entity, in entity order.
    one_cycle = [Query.FROM_S, Query.FROM_P, Query.FROM_O]
    expected_directions = one_cycle * len(entities)
    observed_directions = [direction for _, direction in result]
    assert observed_directions == expected_directions
|
||||
|
||||
@pytest.mark.asyncio
async def test_direction_preserved_with_multiple_triples(self):
    """All triples from one query share the same direction.

    Only the first stubbed query returns data (two triples); every
    later query returns an empty list. Both results must carry the
    direction of that first query.
    """
    first = _make_triple("a", "p1", "b")
    second = _make_triple("a", "p2", "c")

    # Scripted answers: consumed one per call, then fall back to [].
    scripted = iter([[first, second]])

    async def multi_results(**kwargs):
        return next(scripted, [])

    client = AsyncMock()
    client.query_stream = multi_results
    query = _make_query(triples_client=client)

    result = await query.execute_batch_triple_queries(
        ["e1"], limit_per_entity=10
    )

    # First query (FROM_S) returns 2 triples, both should be FROM_S
    assert len(result) == 2
    assert result[0] == (first, Query.FROM_S)
    assert result[1] == (second, Query.FROM_S)
|
||||
|
||||
|
||||
class TestLRUCacheWithTTL:
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue