feat: direction-aware reranker text in GraphRAG hop-and-filter (#1016)

The reranker document text now reflects the traversal direction,
showing only the new information relative to the frontier entity:
- From S (subject is frontier): text = "{predicate} {object}"
- From O (object is frontier): text = "{subject} {predicate}"
- From P (predicate is frontier): text = "{subject} {object}"

This eliminates duplicate reranker texts when traversing inward
from shared object nodes (e.g. 18 CPUs all producing identical
"hasSubcategory Processors" text when the subject was dropped).

execute_batch_triple_queries now returns (triple, direction)
tuples so hop_and_filter can select the appropriate text format.

Updates tech spec to document the direction-aware approach.
Adds unit tests for direction tracking and reranker text
construction.
This commit is contained in:
cybermaggedon 2026-07-02 21:14:47 +01:00 committed by GitHub
parent 9cf7dcb578
commit db7fdbc652
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 502 additions and 19 deletions

View file

@ -129,6 +129,9 @@ class TestBatchTripleQueries:
# 3 queries, alternating results
assert len(result) == 3
# Each result is a (triple, direction) tuple
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_exception_in_one_query_does_not_block_others(self):
@ -153,6 +156,8 @@ class TestBatchTripleQueries:
# 3 queries: 2 succeed, 1 fails → 2 triples
assert len(result) == 2
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_none_results_filtered(self):
@ -176,6 +181,8 @@ class TestBatchTripleQueries:
# 3 queries: 1 returns None, 2 return triples
assert len(result) == 2
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_empty_entities_no_queries(self):
@ -220,6 +227,80 @@ class TestBatchTripleQueries:
assert calls[2].kwargs["p"] is None
assert calls[2].kwargs["o"] == "ent-1"
@pytest.mark.asyncio
async def test_directions_assigned_correctly(self):
    """Each query position should produce the correct direction tag.

    A single entity is queried and every stubbed query returns the same
    triple, so three results are expected, tagged FROM_S, FROM_P and
    FROM_O in that order.
    """
    triple = _make_triple("s", "p", "o")

    async def one_triple_each(**kwargs):
        # Whatever the query arguments are, answer with the same triple.
        return [triple]

    client = AsyncMock()
    client.query_stream = one_triple_each
    query = _make_query(triples_client=client)

    result = await query.execute_batch_triple_queries(
        ["e1"], limit_per_entity=10
    )

    assert len(result) == 3
    # Order matches query order: s-position, p-position, o-position
    assert result[0][1] == Query.FROM_S
    assert result[1][1] == Query.FROM_P
    assert result[2][1] == Query.FROM_O
@pytest.mark.asyncio
async def test_directions_correct_for_multiple_entities(self):
    """Direction tags repeat in S, P, O order for every entity queried."""
    shared_triple = _make_triple("s", "p", "o")

    stub_client = AsyncMock()
    stub_client.query_stream = AsyncMock(return_value=[shared_triple])
    query = _make_query(triples_client=stub_client)

    result = await query.execute_batch_triple_queries(
        ["e1", "e2"], limit_per_entity=10
    )

    assert len(result) == 6

    # Two entities, three query positions each.
    one_entity = [Query.FROM_S, Query.FROM_P, Query.FROM_O]
    expected_directions = one_entity + one_entity
    assert all(
        direction == expected
        for (_, direction), expected in zip(result, expected_directions)
    )
@pytest.mark.asyncio
async def test_direction_preserved_with_multiple_triples(self):
    """All triples coming back from one query carry the same direction."""
    first = _make_triple("a", "p1", "b")
    second = _make_triple("a", "p2", "c")

    # Only the very first call yields triples; every later call is empty.
    canned = iter([[first, second]])

    async def multi_results(**kwargs):
        return next(canned, [])

    client = AsyncMock()
    client.query_stream = multi_results
    query = _make_query(triples_client=client)

    result = await query.execute_batch_triple_queries(
        ["e1"], limit_per_entity=10
    )

    # First query (FROM_S) returns 2 triples, both should be FROM_S
    assert len(result) == 2
    assert result[0] == (first, Query.FROM_S)
    assert result[1] == (second, Query.FROM_S)
class TestLRUCacheWithTTL:

View file

@ -0,0 +1,353 @@
"""
Tests for direction-aware reranker text in GraphRAG hop-and-filter.
The reranker document text varies by traversal direction:
- From S (subject is the frontier entity): text = "{p} {o}"
- From O (object is the frontier entity): text = "{s} {p}"
- From P (predicate is the frontier entity): text = "{s} {o}"
"""
import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
from trustgraph.schema import Term, IRI, LITERAL
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_rag(reranker_results=None):
    """Build a MagicMock GraphRag whose clients are all stubbed.

    ``reranker_results`` becomes the return value of the stubbed
    ``reranker_client.rerank``; when omitted, ``rerank`` returns an
    empty list.
    """
    rag = MagicMock()
    rag.label_cache = LRUCacheWithTTL()
    rag.triples_client = AsyncMock()
    rag.reranker_client = AsyncMock()

    # Label lookups return empty (fall back to URI)
    rag.triples_client.query.return_value = []

    rag.reranker_client.rerank.return_value = (
        [] if reranker_results is None else reranker_results
    )
    return rag
def _make_query(rag, max_path_length=1, edge_limit=25):
    """Construct a Query bound to ``rag`` with fixed test settings.

    Only the path length and the edge limit are adjustable; every other
    constructor argument is pinned to a constant.
    """
    settings = dict(
        collection="test",
        verbose=False,
        entity_limit=50,
        triple_limit=30,
        max_subgraph_size=1000,
        max_path_length=max_path_length,
        edge_limit=edge_limit,
    )
    return Query(rag=rag, **settings)
def _make_schema_triple(s, p, o):
"""Create a mock triple matching the schema interface."""
t = MagicMock()
t.s = s
t.p = p
t.o = o
return t
def _reranker_result(document_id, query_id="0", score=0.9):
r = MagicMock()
r.document_id = str(document_id)
r.query_id = str(query_id)
r.score = score
return r
# ---------------------------------------------------------------------------
# Tests: execute_batch_triple_queries direction tracking
# ---------------------------------------------------------------------------
class TestDirectionTracking:
    """Checks the direction tag that execute_batch_triple_queries pairs
    with each returned triple."""

    @staticmethod
    def _only_when_set(position, triple):
        """Build a query_stream stub that yields ``triple`` only when the
        argument named by ``position`` ("s", "p" or "o") is not None."""

        async def query_stream(s=None, p=None, o=None, **kwargs):
            supplied = {"s": s, "p": p, "o": o}[position]
            return [] if supplied is None else [triple]

        return query_stream

    @pytest.mark.asyncio
    async def test_from_s_direction(self):
        """Triples from s=entity queries are tagged FROM_S."""
        triple = _make_schema_triple("ent1", "pred", "obj")
        rag = _make_rag()
        rag.triples_client.query_stream.side_effect = (
            self._only_when_set("s", triple)
        )

        q = _make_query(rag)
        result = await q.execute_batch_triple_queries(["ent1"], 10)

        tagged = [t for t, d in result if d == Query.FROM_S]
        assert len(tagged) == 1
        assert tagged[0] is triple

    @pytest.mark.asyncio
    async def test_from_o_direction(self):
        """Triples from o=entity queries are tagged FROM_O."""
        triple = _make_schema_triple("subj", "pred", "ent1")
        rag = _make_rag()
        rag.triples_client.query_stream.side_effect = (
            self._only_when_set("o", triple)
        )

        q = _make_query(rag)
        result = await q.execute_batch_triple_queries(["ent1"], 10)

        tagged = [t for t, d in result if d == Query.FROM_O]
        assert len(tagged) == 1
        assert tagged[0] is triple

    @pytest.mark.asyncio
    async def test_from_p_direction(self):
        """Triples from p=entity queries are tagged FROM_P."""
        triple = _make_schema_triple("subj", "ent1", "obj")
        rag = _make_rag()
        rag.triples_client.query_stream.side_effect = (
            self._only_when_set("p", triple)
        )

        q = _make_query(rag)
        result = await q.execute_batch_triple_queries(["ent1"], 10)

        tagged = [t for t, d in result if d == Query.FROM_P]
        assert len(tagged) == 1
        assert tagged[0] is triple
# ---------------------------------------------------------------------------
# Tests: hop_and_filter reranker document text
# ---------------------------------------------------------------------------
class TestDirectionAwareRerankerText:
    """Checks the document text hop_and_filter hands to the reranker for
    each traversal direction."""

    @staticmethod
    def _rerank_documents(rag):
        """Return the ``documents`` keyword argument of the most recent
        call to the stubbed reranker."""
        return rag.reranker_client.rerank.call_args.kwargs["documents"]

    @pytest.mark.asyncio
    async def test_from_s_uses_predicate_object(self):
        """From-S traversal: reranker text should be '{p} {o}'."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        async def query_stream(s=None, p=None, o=None, **kwargs):
            # Only subject-bound queries find the edge.
            return [] if s is None else [edge]

        rag.triples_client.query_stream.side_effect = query_stream

        q = _make_query(rag, max_path_length=1, edge_limit=10)
        await q.hop_and_filter(
            seed_entities=["http://ex/entity-A"],
            concepts=["likes"],
        )

        documents = self._rerank_documents(rag)

        # Text should be "{p} {o}" — the URIs since no labels found
        assert len(documents) == 1
        assert documents[0]["text"] == "http://ex/likes http://ex/entity-B"

    @pytest.mark.asyncio
    async def test_from_o_uses_subject_predicate(self):
        """From-O traversal: reranker text should be '{s} {p}'."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        async def query_stream(s=None, p=None, o=None, **kwargs):
            # Only object-bound queries find the edge.
            return [] if o is None else [edge]

        rag.triples_client.query_stream.side_effect = query_stream

        q = _make_query(rag, max_path_length=1, edge_limit=10)
        await q.hop_and_filter(
            seed_entities=["http://ex/entity-B"],
            concepts=["likes"],
        )

        documents = self._rerank_documents(rag)

        assert len(documents) == 1
        assert documents[0]["text"] == "http://ex/entity-A http://ex/likes"

    @pytest.mark.asyncio
    async def test_from_p_uses_subject_object(self):
        """From-P traversal: reranker text should be '{s} {o}'."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        async def query_stream(s=None, p=None, o=None, **kwargs):
            # Only predicate-bound queries find the edge.
            return [] if p is None else [edge]

        rag.triples_client.query_stream.side_effect = query_stream

        q = _make_query(rag, max_path_length=1, edge_limit=10)
        await q.hop_and_filter(
            seed_entities=["http://ex/likes"],
            concepts=["entity"],
        )

        documents = self._rerank_documents(rag)

        assert len(documents) == 1
        assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B"

    @pytest.mark.asyncio
    async def test_mixed_directions_produce_different_text(self):
        """Edges from different directions use different text formats."""
        outgoing = _make_schema_triple(
            "http://ex/seed", "http://ex/rel", "http://ex/target",
        )
        incoming = _make_schema_triple(
            "http://ex/other", "http://ex/ref", "http://ex/seed",
        )
        rag = _make_rag(reranker_results=[
            _reranker_result(0), _reranker_result(1),
        ])

        async def query_stream(s=None, p=None, o=None, **kwargs):
            if s == "http://ex/seed":
                found = [outgoing]
            elif o == "http://ex/seed":
                found = [incoming]
            else:
                found = []
            return found

        rag.triples_client.query_stream.side_effect = query_stream

        q = _make_query(rag, max_path_length=1, edge_limit=10)
        await q.hop_and_filter(
            seed_entities=["http://ex/seed"],
            concepts=["test"],
        )

        texts = {doc["text"] for doc in self._rerank_documents(rag)}

        # From S: "{p} {o}" = "http://ex/rel http://ex/target"
        assert "http://ex/rel http://ex/target" in texts
        # From O: "{s} {p}" = "http://ex/other http://ex/ref"
        assert "http://ex/other http://ex/ref" in texts

    @pytest.mark.asyncio
    async def test_labels_applied_to_direction_text(self):
        """Labels should be resolved and used in the direction-aware text."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
        labels = {
            "http://ex/entity-A": "Alice",
            "http://ex/likes": "likes",
            "http://ex/entity-B": "Bob",
        }

        async def query_stream(s=None, p=None, o=None, **kwargs):
            # Subject bound and predicate unbound: the edge query.
            if s is None or p is not None:
                return []
            return [edge]

        async def label_query(s=None, p=None, o=None, limit=1, **kwargs):
            # Answer only label-predicate lookups for known subjects.
            if p != LABEL or s not in labels:
                return []
            return [MagicMock(o=labels[s])]

        rag.triples_client.query_stream.side_effect = query_stream
        rag.triples_client.query.side_effect = label_query

        q = _make_query(rag, max_path_length=1, edge_limit=10)
        await q.hop_and_filter(
            seed_entities=["http://ex/entity-A"],
            concepts=["friendship"],
        )

        documents = self._rerank_documents(rag)

        assert len(documents) == 1
        # From S with labels: "{p_label} {o_label}"
        assert documents[0]["text"] == "likes Bob"

    @pytest.mark.asyncio
    async def test_no_duplicate_text_from_shared_object(self):
        """Multiple edges sharing an object should produce distinct texts."""
        cpu_a_edge = _make_schema_triple(
            "http://ex/cpu-A", "http://ex/hasCategory", "http://ex/Processors",
        )
        cpu_b_edge = _make_schema_triple(
            "http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors",
        )
        rag = _make_rag(reranker_results=[
            _reranker_result(0), _reranker_result(1),
        ])

        async def query_stream(s=None, p=None, o=None, **kwargs):
            if o != "http://ex/Processors":
                return []
            return [cpu_a_edge, cpu_b_edge]

        rag.triples_client.query_stream.side_effect = query_stream

        q = _make_query(rag, max_path_length=1, edge_limit=10)
        await q.hop_and_filter(
            seed_entities=["http://ex/Processors"],
            concepts=["CPUs"],
        )

        texts = [doc["text"] for doc in self._rerank_documents(rag)]

        assert len(texts) == 2
        # From O: "{s} {p}" — subjects differ, so texts differ
        assert texts[0] != texts[1]
        assert "http://ex/cpu-A" in texts[0]
        assert "http://ex/cpu-B" in texts[1]