feat: direction-aware reranker text in GraphRAG hop-and-filter (#1016)

The reranker document text now reflects the traversal direction,
showing only the new information relative to the frontier entity:
- From S (subject is frontier): text = "{predicate} {object}"
- From O (object is frontier): text = "{subject} {predicate}"
- From P (predicate is frontier): text = "{subject} {object}"

This eliminates duplicate reranker texts when traversing inward
from shared object nodes (e.g. 18 CPUs all producing identical
"hasSubcategory Processors" text when the subject was dropped).

execute_batch_triple_queries now returns (triple, direction)
tuples so hop_and_filter can select the appropriate text format.

Updates tech spec to document the direction-aware approach.
Adds unit tests for direction tracking and reranker text
construction.
This commit is contained in:
cybermaggedon 2026-07-02 21:14:47 +01:00 committed by GitHub
parent 9cf7dcb578
commit db7fdbc652
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 502 additions and 19 deletions

View file

@ -224,12 +224,27 @@ The current embedding pre-filter represents each edge as
- **Drop commas.** Commas add tokenisation noise without semantic
value.
- **Drop the subject.** The subject identifies which entity the
edge belongs to, but it does not contribute to whether the
edge's content is relevant to the query. The predicate and
object carry the semantic meaning — what relationship exists
and what it connects to. Representing edges as `"{p} {o}"`
produces cleaner cross-encoder matches.
- **Direction-aware text.** The reranker text should highlight
the *new* information relative to the traversal direction.
The frontier entity is already known context — repeating it
adds noise and, when traversing from an object node, causes
many edges to produce identical reranker text (e.g. 18
products sharing the same `hasSubcategory Processors` triple
all collapse to the same string when the subject is dropped).
The text is constructed based on which position the frontier
entity occupied in the triple:
- **From subject** (s=entity): `"{predicate} {object}"`
the subject is known, predicate and object are new.
- **From object** (o=entity): `"{subject} {predicate}"`
the object is known, subject and predicate are new.
- **From predicate** (p=entity): `"{subject} {object}"`
the predicate is known, subject and object are new.
This eliminates the duplicate-text problem that arises when
traversing inward from a shared object node, and gives the
cross-encoder a more informative signal at every hop.
#### Remove the embedding pre-filter (step 3)
@ -389,7 +404,10 @@ no LLM call. These fields are dropped from the Focus entity.
a. Retrieve all edges one hop from the current frontier nodes.
b. Represent each edge as `"{predicate} {object}"`.
b. Represent each edge using direction-aware text: from a
subject node use `"{predicate} {object}"`, from an object
node use `"{subject} {predicate}"`, from a predicate node
use `"{subject} {object}"`.
c. Score edges against the extracted concepts using the
cross-encoder service.

View file

@ -129,6 +129,9 @@ class TestBatchTripleQueries:
# 3 queries, alternating results
assert len(result) == 3
# Each result is a (triple, direction) tuple
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_exception_in_one_query_does_not_block_others(self):
@ -153,6 +156,8 @@ class TestBatchTripleQueries:
# 3 queries: 2 succeed, 1 fails → 2 triples
assert len(result) == 2
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_none_results_filtered(self):
@ -176,6 +181,8 @@ class TestBatchTripleQueries:
# 3 queries: 1 returns None, 2 return triples
assert len(result) == 2
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_empty_entities_no_queries(self):
@ -220,6 +227,80 @@ class TestBatchTripleQueries:
assert calls[2].kwargs["p"] is None
assert calls[2].kwargs["o"] == "ent-1"
@pytest.mark.asyncio
async def test_directions_assigned_correctly(self):
"""Each query position should produce the correct direction tag."""
triple = _make_triple("s", "p", "o")
call_count = 0
async def one_triple_each(**kwargs):
nonlocal call_count
call_count += 1
return [triple]
client = AsyncMock()
client.query_stream = one_triple_each
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1"], limit_per_entity=10
)
assert len(result) == 3
# Order matches query order: s-position, p-position, o-position
assert result[0][1] == Query.FROM_S
assert result[1][1] == Query.FROM_P
assert result[2][1] == Query.FROM_O
@pytest.mark.asyncio
async def test_directions_correct_for_multiple_entities(self):
"""Direction tags cycle correctly across multiple entities."""
triple = _make_triple("s", "p", "o")
client = AsyncMock()
client.query_stream = AsyncMock(return_value=[triple])
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1", "e2"], limit_per_entity=10
)
assert len(result) == 6
expected_directions = [
Query.FROM_S, Query.FROM_P, Query.FROM_O,
Query.FROM_S, Query.FROM_P, Query.FROM_O,
]
for (_, direction), expected in zip(result, expected_directions):
assert direction == expected
@pytest.mark.asyncio
async def test_direction_preserved_with_multiple_triples(self):
"""All triples from one query share the same direction."""
t1 = _make_triple("a", "p1", "b")
t2 = _make_triple("a", "p2", "c")
call_count = 0
async def multi_results(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return [t1, t2]
return []
client = AsyncMock()
client.query_stream = multi_results
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1"], limit_per_entity=10
)
# First query (FROM_S) returns 2 triples, both should be FROM_S
assert len(result) == 2
assert result[0] == (t1, Query.FROM_S)
assert result[1] == (t2, Query.FROM_S)
class TestLRUCacheWithTTL:

View file

@ -0,0 +1,353 @@
"""
Tests for direction-aware reranker text in GraphRAG hop-and-filter.
The reranker document text varies by traversal direction:
- From S (subject is the frontier entity): text = "{p} {o}"
- From O (object is the frontier entity): text = "{s} {p}"
- From P (predicate is the frontier entity): text = "{s} {o}"
"""
import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
from trustgraph.schema import Term, IRI, LITERAL
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_rag(reranker_results=None):
"""Create a mock GraphRag with all clients stubbed."""
rag = MagicMock()
rag.label_cache = LRUCacheWithTTL()
rag.triples_client = AsyncMock()
rag.reranker_client = AsyncMock()
# Label lookups return empty (fall back to URI)
rag.triples_client.query.return_value = []
if reranker_results is not None:
rag.reranker_client.rerank.return_value = reranker_results
else:
rag.reranker_client.rerank.return_value = []
return rag
def _make_query(rag, max_path_length=1, edge_limit=25):
return Query(
rag=rag,
collection="test",
verbose=False,
entity_limit=50,
triple_limit=30,
max_subgraph_size=1000,
max_path_length=max_path_length,
edge_limit=edge_limit,
)
def _make_schema_triple(s, p, o):
"""Create a mock triple matching the schema interface."""
t = MagicMock()
t.s = s
t.p = p
t.o = o
return t
def _reranker_result(document_id, query_id="0", score=0.9):
r = MagicMock()
r.document_id = str(document_id)
r.query_id = str(query_id)
r.score = score
return r
# ---------------------------------------------------------------------------
# Tests: execute_batch_triple_queries direction tracking
# ---------------------------------------------------------------------------
class TestDirectionTracking:
@pytest.mark.asyncio
async def test_from_s_direction(self):
"""Triples from s=entity queries are tagged FROM_S."""
triple = _make_schema_triple("ent1", "pred", "obj")
rag = _make_rag()
async def query_stream(s=None, p=None, o=None, **kwargs):
if s is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag)
result = await q.execute_batch_triple_queries(["ent1"], 10)
from_s = [(t, d) for t, d in result if d == Query.FROM_S]
assert len(from_s) == 1
assert from_s[0][0] is triple
@pytest.mark.asyncio
async def test_from_o_direction(self):
"""Triples from o=entity queries are tagged FROM_O."""
triple = _make_schema_triple("subj", "pred", "ent1")
rag = _make_rag()
async def query_stream(s=None, p=None, o=None, **kwargs):
if o is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag)
result = await q.execute_batch_triple_queries(["ent1"], 10)
from_o = [(t, d) for t, d in result if d == Query.FROM_O]
assert len(from_o) == 1
assert from_o[0][0] is triple
@pytest.mark.asyncio
async def test_from_p_direction(self):
"""Triples from p=entity queries are tagged FROM_P."""
triple = _make_schema_triple("subj", "ent1", "obj")
rag = _make_rag()
async def query_stream(s=None, p=None, o=None, **kwargs):
if p is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag)
result = await q.execute_batch_triple_queries(["ent1"], 10)
from_p = [(t, d) for t, d in result if d == Query.FROM_P]
assert len(from_p) == 1
assert from_p[0][0] is triple
# ---------------------------------------------------------------------------
# Tests: hop_and_filter reranker document text
# ---------------------------------------------------------------------------
class TestDirectionAwareRerankerText:
@pytest.mark.asyncio
async def test_from_s_uses_predicate_object(self):
"""From-S traversal: reranker text should be '{p} {o}'."""
triple = _make_schema_triple(
"http://ex/entity-A",
"http://ex/likes",
"http://ex/entity-B",
)
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
async def query_stream(s=None, p=None, o=None, **kwargs):
if s is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/entity-A"],
concepts=["likes"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
# Text should be "{p} {o}" — the URIs since no labels found
assert len(documents) == 1
assert documents[0]["text"] == "http://ex/likes http://ex/entity-B"
@pytest.mark.asyncio
async def test_from_o_uses_subject_predicate(self):
"""From-O traversal: reranker text should be '{s} {p}'."""
triple = _make_schema_triple(
"http://ex/entity-A",
"http://ex/likes",
"http://ex/entity-B",
)
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
async def query_stream(s=None, p=None, o=None, **kwargs):
if o is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/entity-B"],
concepts=["likes"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
assert len(documents) == 1
assert documents[0]["text"] == "http://ex/entity-A http://ex/likes"
@pytest.mark.asyncio
async def test_from_p_uses_subject_object(self):
"""From-P traversal: reranker text should be '{s} {o}'."""
triple = _make_schema_triple(
"http://ex/entity-A",
"http://ex/likes",
"http://ex/entity-B",
)
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
async def query_stream(s=None, p=None, o=None, **kwargs):
if p is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/likes"],
concepts=["entity"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
assert len(documents) == 1
assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B"
@pytest.mark.asyncio
async def test_mixed_directions_produce_different_text(self):
"""Edges from different directions use different text formats."""
triple_from_s = _make_schema_triple(
"http://ex/seed", "http://ex/rel", "http://ex/target",
)
triple_from_o = _make_schema_triple(
"http://ex/other", "http://ex/ref", "http://ex/seed",
)
rag = _make_rag(reranker_results=[
_reranker_result(0), _reranker_result(1),
])
async def query_stream(s=None, p=None, o=None, **kwargs):
if s == "http://ex/seed":
return [triple_from_s]
if o == "http://ex/seed":
return [triple_from_o]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/seed"],
concepts=["test"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
texts = {d["text"] for d in documents}
# From S: "{p} {o}" = "http://ex/rel http://ex/target"
assert "http://ex/rel http://ex/target" in texts
# From O: "{s} {p}" = "http://ex/other http://ex/ref"
assert "http://ex/other http://ex/ref" in texts
@pytest.mark.asyncio
async def test_labels_applied_to_direction_text(self):
"""Labels should be resolved and used in the direction-aware text."""
triple = _make_schema_triple(
"http://ex/entity-A",
"http://ex/likes",
"http://ex/entity-B",
)
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
async def query_stream(s=None, p=None, o=None, **kwargs):
if s is not None and p is None:
return [triple]
return []
async def label_query(s=None, p=None, o=None, limit=1, **kwargs):
if p == LABEL:
labels = {
"http://ex/entity-A": "Alice",
"http://ex/likes": "likes",
"http://ex/entity-B": "Bob",
}
if s in labels:
return [MagicMock(o=labels[s])]
return []
rag.triples_client.query_stream.side_effect = query_stream
rag.triples_client.query.side_effect = label_query
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/entity-A"],
concepts=["friendship"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
assert len(documents) == 1
# From S with labels: "{p_label} {o_label}"
assert documents[0]["text"] == "likes Bob"
@pytest.mark.asyncio
async def test_no_duplicate_text_from_shared_object(self):
"""Multiple edges sharing an object should produce distinct texts."""
triple_a = _make_schema_triple(
"http://ex/cpu-A", "http://ex/hasCategory", "http://ex/Processors",
)
triple_b = _make_schema_triple(
"http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors",
)
rag = _make_rag(reranker_results=[
_reranker_result(0), _reranker_result(1),
])
async def query_stream(s=None, p=None, o=None, **kwargs):
if o == "http://ex/Processors":
return [triple_a, triple_b]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/Processors"],
concepts=["CPUs"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
texts = [d["text"] for d in documents]
assert len(texts) == 2
# From O: "{s} {p}" — subjects differ, so texts differ
assert texts[0] != texts[1]
assert "http://ex/cpu-A" in texts[0]
assert "http://ex/cpu-B" in texts[1]

View file

@ -241,38 +241,56 @@ class Query:
self.rag.label_cache.put(cache_key, label)
return label
FROM_S = "from_s"
FROM_P = "from_p"
FROM_O = "from_o"
async def execute_batch_triple_queries(self, entities, limit_per_entity):
"""Execute triple queries for multiple entities concurrently."""
"""Execute triple queries for multiple entities concurrently.
Returns a list of (triple, direction) tuples where direction
indicates which position the frontier entity occupied.
"""
tasks = []
directions = []
for entity in entities:
tasks.extend([
tasks.append(
self.rag.triples_client.query_stream(
s=entity, p=None, o=None,
limit=limit_per_entity,
collection=self.collection,
batch_size=20, g="",
),
)
directions.append(self.FROM_S)
tasks.append(
self.rag.triples_client.query_stream(
s=None, p=entity, o=None,
limit=limit_per_entity,
collection=self.collection,
batch_size=20, g="",
),
)
directions.append(self.FROM_P)
tasks.append(
self.rag.triples_client.query_stream(
s=None, p=None, o=entity,
limit=limit_per_entity,
collection=self.collection,
batch_size=20, g="",
)
])
),
)
directions.append(self.FROM_O)
results = await asyncio.gather(*tasks, return_exceptions=True)
all_triples = []
for result in results:
for direction, result in zip(directions, results):
if not isinstance(result, Exception) and result is not None:
all_triples.extend(result)
all_triples.extend((triple, direction) for triple in result)
return all_triples
@ -325,7 +343,8 @@ class Query:
# Deduplicate and filter already-seen edges
hop_triples = []
hop_term_map = {}
for triple in triples:
hop_directions = {}
for triple, direction in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
if triple_tuple[1] == LABEL:
continue
@ -336,6 +355,7 @@ class Query:
hop_term_map[triple_tuple] = (
to_term(triple.s), to_term(triple.p), to_term(triple.o),
)
hop_directions[triple_tuple] = direction
if not hop_triples:
visited_entities.update(frontier)
@ -361,7 +381,10 @@ class Query:
else:
label_map[entity] = entity
# Build labeled edges and documents for cross-encoder
# Build labeled edges and documents for cross-encoder.
# The reranker text highlights the NEW information relative
# to the traversal direction: arriving from S means p,o are
# new; from O means s,p are new; from P means s,o are new.
labeled_hop = []
for s, p, o in hop_triples:
ls = label_map.get(s, s)
@ -369,10 +392,18 @@ class Query:
lo = label_map.get(o, o)
labeled_hop.append((ls, lp, lo))
documents = [
{"id": str(i), "text": f"{lp} {lo}"}
for i, (ls, lp, lo) in enumerate(labeled_hop)
]
documents = []
for i, (triple_tuple, (ls, lp, lo)) in enumerate(
zip(hop_triples, labeled_hop)
):
direction = hop_directions[triple_tuple]
if direction == self.FROM_S:
text = f"{lp} {lo}"
elif direction == self.FROM_O:
text = f"{ls} {lp}"
else:
text = f"{ls} {lo}"
documents.append({"id": str(i), "text": text})
queries = [
{"id": str(i), "text": c}