From db7fdbc6526537c422755330fa9132d76cc7c832 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 2 Jul 2026 21:14:47 +0100 Subject: [PATCH] feat: direction-aware reranker text in GraphRAG hop-and-filter (#1016) The reranker document text now reflects the traversal direction, showing only the new information relative to the frontier entity: - From S (subject is frontier): text = "{predicate} {object}" - From O (object is frontier): text = "{subject} {predicate}" - From P (predicate is frontier): text = "{subject} {object}" This eliminates duplicate reranker texts when traversing inward from shared object nodes (e.g. 18 CPUs all producing identical "hasSubcategory Processors" text when the subject was dropped). execute_batch_triple_queries now returns (triple, direction) tuples so hop_and_filter can select the appropriate text format. Updates tech spec to document the direction-aware approach. Adds unit tests for direction tracking and reranker text construction. --- docs/tech-specs/graph-rag-semantic-filter.md | 32 +- .../test_graph_rag_concurrency.py | 81 ++++ .../test_graph_rag_direction_aware_text.py | 353 ++++++++++++++++++ .../retrieval/graph_rag/graph_rag.py | 55 ++- 4 files changed, 502 insertions(+), 19 deletions(-) create mode 100644 tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py diff --git a/docs/tech-specs/graph-rag-semantic-filter.md b/docs/tech-specs/graph-rag-semantic-filter.md index 0401947e..58497d10 100644 --- a/docs/tech-specs/graph-rag-semantic-filter.md +++ b/docs/tech-specs/graph-rag-semantic-filter.md @@ -224,12 +224,27 @@ The current embedding pre-filter represents each edge as - **Drop commas.** Commas add tokenisation noise without semantic value. -- **Drop the subject.** The subject identifies which entity the - edge belongs to, but it does not contribute to whether the - edge's content is relevant to the query. The predicate and - object carry the semantic meaning — what relationship exists - and what it connects to. Representing edges as `"{p} {o}"` - produces cleaner cross-encoder matches. +- **Direction-aware text.** The reranker text should highlight + the *new* information relative to the traversal direction. + The frontier entity is already known context — repeating it + adds noise and, when traversing from an object node, causes + many edges to produce identical reranker text (e.g. 18 + products sharing the same `hasSubcategory Processors` triple + all collapse to the same string when the subject is dropped). + + The text is constructed based on which position the frontier + entity occupied in the triple: + + - **From subject** (s=entity): `"{predicate} {object}"` — + the subject is known, predicate and object are new. + - **From object** (o=entity): `"{subject} {predicate}"` — + the object is known, subject and predicate are new. + - **From predicate** (p=entity): `"{subject} {object}"` — + the predicate is known, subject and object are new. + + This eliminates the duplicate-text problem that arises when + traversing inward from a shared object node, and gives the + cross-encoder a more informative signal at every hop. #### Remove the embedding pre-filter (step 3) @@ -389,7 +404,10 @@ no LLM call. These fields are dropped from the Focus entity. a. Retrieve all edges one hop from the current frontier nodes. - b. Represent each edge as `"{predicate} {object}"`. + b. Represent each edge using direction-aware text: from a + subject node use `"{predicate} {object}"`, from an object + node use `"{subject} {predicate}"`, from a predicate node + use `"{subject} {object}"`. c. Score edges against the extracted concepts using the cross-encoder service. diff --git a/tests/unit/test_concurrency/test_graph_rag_concurrency.py b/tests/unit/test_concurrency/test_graph_rag_concurrency.py index 1b35a238..ed567962 100644 --- a/tests/unit/test_concurrency/test_graph_rag_concurrency.py +++ b/tests/unit/test_concurrency/test_graph_rag_concurrency.py @@ -129,6 +129,9 @@ class TestBatchTripleQueries: # 3 queries, alternating results assert len(result) == 3 + # Each result is a (triple, direction) tuple + for triple, direction in result: + assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O) @pytest.mark.asyncio async def test_exception_in_one_query_does_not_block_others(self): @@ -153,6 +156,8 @@ class TestBatchTripleQueries: # 3 queries: 2 succeed, 1 fails → 2 triples assert len(result) == 2 + for triple, direction in result: + assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O) @pytest.mark.asyncio async def test_none_results_filtered(self): @@ -176,6 +181,8 @@ class TestBatchTripleQueries: # 3 queries: 1 returns None, 2 return triples assert len(result) == 2 + for triple, direction in result: + assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O) @pytest.mark.asyncio async def test_empty_entities_no_queries(self): @@ -220,6 +227,80 @@ class TestBatchTripleQueries: assert calls[2].kwargs["p"] is None assert calls[2].kwargs["o"] == "ent-1" + @pytest.mark.asyncio + async def test_directions_assigned_correctly(self): + """Each query position should produce the correct direction tag.""" + triple = _make_triple("s", "p", "o") + + call_count = 0 + + async def one_triple_each(**kwargs): + nonlocal call_count + call_count += 1 + return [triple] + + client = AsyncMock() + client.query_stream = one_triple_each + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + assert len(result) == 3 + # Order matches query order: s-position, p-position, o-position + assert result[0][1] == Query.FROM_S + assert result[1][1] == Query.FROM_P + assert result[2][1] == Query.FROM_O + + @pytest.mark.asyncio + async def test_directions_correct_for_multiple_entities(self): + """Direction tags cycle correctly across multiple entities.""" + triple = _make_triple("s", "p", "o") + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[triple]) + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1", "e2"], limit_per_entity=10 + ) + + assert len(result) == 6 + expected_directions = [ + Query.FROM_S, Query.FROM_P, Query.FROM_O, + Query.FROM_S, Query.FROM_P, Query.FROM_O, + ] + for (_, direction), expected in zip(result, expected_directions): + assert direction == expected + + @pytest.mark.asyncio + async def test_direction_preserved_with_multiple_triples(self): + """All triples from one query share the same direction.""" + t1 = _make_triple("a", "p1", "b") + t2 = _make_triple("a", "p2", "c") + + call_count = 0 + + async def multi_results(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [t1, t2] + return [] + + client = AsyncMock() + client.query_stream = multi_results + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + # First query (FROM_S) returns 2 triples, both should be FROM_S + assert len(result) == 2 + assert result[0] == (t1, Query.FROM_S) + assert result[1] == (t2, Query.FROM_S) + class TestLRUCacheWithTTL: diff --git a/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py b/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py new file mode 100644 index 00000000..cc95228a --- /dev/null +++ b/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py @@ -0,0 +1,353 @@ +""" +Tests for direction-aware reranker text in GraphRAG hop-and-filter. + +The reranker document text varies by traversal direction: +- From S (subject is the frontier entity): text = "{p} {o}" +- From O (object is the frontier entity): text = "{s} {p}" +- From P (predicate is the frontier entity): text = "{s} {o}" +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock + +from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL +from trustgraph.schema import Term, IRI, LITERAL + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_rag(reranker_results=None): + """Create a mock GraphRag with all clients stubbed.""" + rag = MagicMock() + rag.label_cache = LRUCacheWithTTL() + rag.triples_client = AsyncMock() + rag.reranker_client = AsyncMock() + + # Label lookups return empty (fall back to URI) + rag.triples_client.query.return_value = [] + + if reranker_results is not None: + rag.reranker_client.rerank.return_value = reranker_results + else: + rag.reranker_client.rerank.return_value = [] + + return rag + + +def _make_query(rag, max_path_length=1, edge_limit=25): + return Query( + rag=rag, + collection="test", + verbose=False, + entity_limit=50, + triple_limit=30, + max_subgraph_size=1000, + max_path_length=max_path_length, + edge_limit=edge_limit, + ) + + +def _make_schema_triple(s, p, o): + """Create a mock triple matching the schema interface.""" + t = MagicMock() + t.s = s + t.p = p + t.o = o + return t + + +def _reranker_result(document_id, query_id="0", score=0.9): + r = MagicMock() + r.document_id = str(document_id) + r.query_id = str(query_id) + r.score = score + return r + + +# --------------------------------------------------------------------------- +# Tests: execute_batch_triple_queries direction tracking +# --------------------------------------------------------------------------- + +class TestDirectionTracking: + + @pytest.mark.asyncio + async def test_from_s_direction(self): + """Triples from s=entity queries are tagged FROM_S.""" + triple = _make_schema_triple("ent1", "pred", "obj") + rag = _make_rag() + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + q = _make_query(rag) + + result = await q.execute_batch_triple_queries(["ent1"], 10) + + from_s = [(t, d) for t, d in result if d == Query.FROM_S] + assert len(from_s) == 1 + assert from_s[0][0] is triple + + @pytest.mark.asyncio + async def test_from_o_direction(self): + """Triples from o=entity queries are tagged FROM_O.""" + triple = _make_schema_triple("subj", "pred", "ent1") + rag = _make_rag() + + async def query_stream(s=None, p=None, o=None, **kwargs): + if o is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + q = _make_query(rag) + + result = await q.execute_batch_triple_queries(["ent1"], 10) + + from_o = [(t, d) for t, d in result if d == Query.FROM_O] + assert len(from_o) == 1 + assert from_o[0][0] is triple + + @pytest.mark.asyncio + async def test_from_p_direction(self): + """Triples from p=entity queries are tagged FROM_P.""" + triple = _make_schema_triple("subj", "ent1", "obj") + rag = _make_rag() + + async def query_stream(s=None, p=None, o=None, **kwargs): + if p is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + q = _make_query(rag) + + result = await q.execute_batch_triple_queries(["ent1"], 10) + + from_p = [(t, d) for t, d in result if d == Query.FROM_P] + assert len(from_p) == 1 + assert from_p[0][0] is triple + + +# --------------------------------------------------------------------------- +# Tests: hop_and_filter reranker document text +# --------------------------------------------------------------------------- + +class TestDirectionAwareRerankerText: + + @pytest.mark.asyncio + async def test_from_s_uses_predicate_object(self): + """From-S traversal: reranker text should be '{p} {o}'.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/entity-A"], + concepts=["likes"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + # Text should be "{p} {o}" — the URIs since no labels found + assert len(documents) == 1 + assert documents[0]["text"] == "http://ex/likes http://ex/entity-B" + + @pytest.mark.asyncio + async def test_from_o_uses_subject_predicate(self): + """From-O traversal: reranker text should be '{s} {p}'.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if o is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/entity-B"], + concepts=["likes"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + assert len(documents) == 1 + assert documents[0]["text"] == "http://ex/entity-A http://ex/likes" + + @pytest.mark.asyncio + async def test_from_p_uses_subject_object(self): + """From-P traversal: reranker text should be '{s} {o}'.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if p is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/likes"], + concepts=["entity"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + assert len(documents) == 1 + assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B" + + @pytest.mark.asyncio + async def test_mixed_directions_produce_different_text(self): + """Edges from different directions use different text formats.""" + triple_from_s = _make_schema_triple( + "http://ex/seed", "http://ex/rel", "http://ex/target", + ) + triple_from_o = _make_schema_triple( + "http://ex/other", "http://ex/ref", "http://ex/seed", + ) + + rag = _make_rag(reranker_results=[ + _reranker_result(0), _reranker_result(1), + ]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s == "http://ex/seed": + return [triple_from_s] + if o == "http://ex/seed": + return [triple_from_o] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/seed"], + concepts=["test"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + texts = {d["text"] for d in documents} + + # From S: "{p} {o}" = "http://ex/rel http://ex/target" + assert "http://ex/rel http://ex/target" in texts + # From O: "{s} {p}" = "http://ex/other http://ex/ref" + assert "http://ex/other http://ex/ref" in texts + + @pytest.mark.asyncio + async def test_labels_applied_to_direction_text(self): + """Labels should be resolved and used in the direction-aware text.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + LABEL = "http://www.w3.org/2000/01/rdf-schema#label" + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s is not None and p is None: + return [triple] + return [] + + async def label_query(s=None, p=None, o=None, limit=1, **kwargs): + if p == LABEL: + labels = { + "http://ex/entity-A": "Alice", + "http://ex/likes": "likes", + "http://ex/entity-B": "Bob", + } + if s in labels: + return [MagicMock(o=labels[s])] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + rag.triples_client.query.side_effect = label_query + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/entity-A"], + concepts=["friendship"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + assert len(documents) == 1 + # From S with labels: "{p_label} {o_label}" + assert documents[0]["text"] == "likes Bob" + + @pytest.mark.asyncio + async def test_no_duplicate_text_from_shared_object(self): + """Multiple edges sharing an object should produce distinct texts.""" + triple_a = _make_schema_triple( + "http://ex/cpu-A", "http://ex/hasCategory", "http://ex/Processors", + ) + triple_b = _make_schema_triple( + "http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors", + ) + + rag = _make_rag(reranker_results=[ + _reranker_result(0), _reranker_result(1), + ]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if o == "http://ex/Processors": + return [triple_a, triple_b] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/Processors"], + concepts=["CPUs"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + texts = [d["text"] for d in documents] + + assert len(texts) == 2 + # From O: "{s} {p}" — subjects differ, so texts differ + assert texts[0] != texts[1] + assert "http://ex/cpu-A" in texts[0] + assert "http://ex/cpu-B" in texts[1] diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 06c0b5b4..2054cb0f 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -241,38 +241,56 @@ class Query: self.rag.label_cache.put(cache_key, label) return label + FROM_S = "from_s" + FROM_P = "from_p" + FROM_O = "from_o" + async def execute_batch_triple_queries(self, entities, limit_per_entity): - """Execute triple queries for multiple entities concurrently.""" + """Execute triple queries for multiple entities concurrently. + + Returns a list of (triple, direction) tuples where direction + indicates which position the frontier entity occupied. + """ tasks = [] + directions = [] for entity in entities: - tasks.extend([ + tasks.append( self.rag.triples_client.query_stream( s=entity, p=None, o=None, limit=limit_per_entity, collection=self.collection, batch_size=20, g="", ), + ) + directions.append(self.FROM_S) + + tasks.append( self.rag.triples_client.query_stream( s=None, p=entity, o=None, limit=limit_per_entity, collection=self.collection, batch_size=20, g="", ), + ) + directions.append(self.FROM_P) + + tasks.append( self.rag.triples_client.query_stream( s=None, p=None, o=entity, limit=limit_per_entity, collection=self.collection, batch_size=20, g="", - ) - ]) + ), + ) + directions.append(self.FROM_O) results = await asyncio.gather(*tasks, return_exceptions=True) all_triples = [] - for result in results: + for direction, result in zip(directions, results): if not isinstance(result, Exception) and result is not None: - all_triples.extend(result) + all_triples.extend((triple, direction) for triple in result) return all_triples @@ -325,7 +343,8 @@ class Query: # Deduplicate and filter already-seen edges hop_triples = [] hop_term_map = {} - for triple in triples: + hop_directions = {} + for triple, direction in triples: triple_tuple = (str(triple.s), str(triple.p), str(triple.o)) if triple_tuple[1] == LABEL: continue @@ -336,6 +355,7 @@ class Query: hop_term_map[triple_tuple] = ( to_term(triple.s), to_term(triple.p), to_term(triple.o), ) + hop_directions[triple_tuple] = direction if not hop_triples: visited_entities.update(frontier) @@ -361,7 +381,10 @@ class Query: else: label_map[entity] = entity - # Build labeled edges and documents for cross-encoder + # Build labeled edges and documents for cross-encoder. + # The reranker text highlights the NEW information relative + # to the traversal direction: arriving from S means p,o are + # new; from O means s,p are new; from P means s,o are new. labeled_hop = [] for s, p, o in hop_triples: ls = label_map.get(s, s) @@ -369,10 +392,18 @@ class Query: lo = label_map.get(o, o) labeled_hop.append((ls, lp, lo)) - documents = [ - {"id": str(i), "text": f"{lp} {lo}"} - for i, (ls, lp, lo) in enumerate(labeled_hop) - ] + documents = [] + for i, (triple_tuple, (ls, lp, lo)) in enumerate( + zip(hop_triples, labeled_hop) + ): + direction = hop_directions[triple_tuple] + if direction == self.FROM_S: + text = f"{lp} {lo}" + elif direction == self.FROM_O: + text = f"{ls} {lp}" + else: + text = f"{ls} {lo}" + documents.append({"id": str(i), "text": text}) queries = [ {"id": str(i), "text": c}