diff --git a/docs/tech-specs/graph-rag-semantic-filter.md b/docs/tech-specs/graph-rag-semantic-filter.md index 0401947e..58497d10 100644 --- a/docs/tech-specs/graph-rag-semantic-filter.md +++ b/docs/tech-specs/graph-rag-semantic-filter.md @@ -224,12 +224,27 @@ The current embedding pre-filter represents each edge as - **Drop commas.** Commas add tokenisation noise without semantic value. -- **Drop the subject.** The subject identifies which entity the - edge belongs to, but it does not contribute to whether the - edge's content is relevant to the query. The predicate and - object carry the semantic meaning — what relationship exists - and what it connects to. Representing edges as `"{p} {o}"` - produces cleaner cross-encoder matches. +- **Direction-aware text.** The reranker text should highlight + the *new* information relative to the traversal direction. + The frontier entity is already known context — repeating it + adds noise and, when traversing from an object node, causes + many edges to produce identical reranker text (e.g. 18 + products sharing the same `hasSubcategory Processors` triple + all collapse to the same string when the subject is dropped). + + The text is constructed based on which position the frontier + entity occupied in the triple: + + - **From subject** (s=entity): `"{predicate} {object}"` — + the subject is known, predicate and object are new. + - **From object** (o=entity): `"{subject} {predicate}"` — + the object is known, subject and predicate are new. + - **From predicate** (p=entity): `"{subject} {object}"` — + the predicate is known, subject and object are new. + + This eliminates the duplicate-text problem that arises when + traversing inward from a shared object node, and gives the + cross-encoder a more informative signal at every hop. #### Remove the embedding pre-filter (step 3) @@ -389,7 +404,10 @@ no LLM call. These fields are dropped from the Focus entity. a. Retrieve all edges one hop from the current frontier nodes. - b. Represent each edge as `"{predicate} {object}"`. + b. Represent each edge using direction-aware text: from a + subject node use `"{predicate} {object}"`, from an object + node use `"{subject} {predicate}"`, from a predicate node + use `"{subject} {object}"`. c. Score edges against the extracted concepts using the cross-encoder service. diff --git a/tests/unit/test_concurrency/test_graph_rag_concurrency.py b/tests/unit/test_concurrency/test_graph_rag_concurrency.py index 1b35a238..ed567962 100644 --- a/tests/unit/test_concurrency/test_graph_rag_concurrency.py +++ b/tests/unit/test_concurrency/test_graph_rag_concurrency.py @@ -129,6 +129,9 @@ class TestBatchTripleQueries: # 3 queries, alternating results assert len(result) == 3 + # Each result is a (triple, direction) tuple + for triple, direction in result: + assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O) @pytest.mark.asyncio async def test_exception_in_one_query_does_not_block_others(self): @@ -153,6 +156,8 @@ class TestBatchTripleQueries: # 3 queries: 2 succeed, 1 fails → 2 triples assert len(result) == 2 + for triple, direction in result: + assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O) @pytest.mark.asyncio async def test_none_results_filtered(self): @@ -176,6 +181,8 @@ class TestBatchTripleQueries: # 3 queries: 1 returns None, 2 return triples assert len(result) == 2 + for triple, direction in result: + assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O) @pytest.mark.asyncio async def test_empty_entities_no_queries(self): @@ -220,6 +227,80 @@ class TestBatchTripleQueries: assert calls[2].kwargs["p"] is None assert calls[2].kwargs["o"] == "ent-1" + @pytest.mark.asyncio + async def test_directions_assigned_correctly(self): + """Each query position should produce the correct direction tag.""" + triple = _make_triple("s", "p", "o") + + call_count = 0 + + async def one_triple_each(**kwargs): + nonlocal call_count + call_count += 1 + return [triple] + + client = AsyncMock() + client.query_stream = one_triple_each + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + assert len(result) == 3 + # Order matches query order: s-position, p-position, o-position + assert result[0][1] == Query.FROM_S + assert result[1][1] == Query.FROM_P + assert result[2][1] == Query.FROM_O + + @pytest.mark.asyncio + async def test_directions_correct_for_multiple_entities(self): + """Direction tags cycle correctly across multiple entities.""" + triple = _make_triple("s", "p", "o") + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[triple]) + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1", "e2"], limit_per_entity=10 + ) + + assert len(result) == 6 + expected_directions = [ + Query.FROM_S, Query.FROM_P, Query.FROM_O, + Query.FROM_S, Query.FROM_P, Query.FROM_O, + ] + for (_, direction), expected in zip(result, expected_directions): + assert direction == expected + + @pytest.mark.asyncio + async def test_direction_preserved_with_multiple_triples(self): + """All triples from one query share the same direction.""" + t1 = _make_triple("a", "p1", "b") + t2 = _make_triple("a", "p2", "c") + + call_count = 0 + + async def multi_results(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [t1, t2] + return [] + + client = AsyncMock() + client.query_stream = multi_results + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + # First query (FROM_S) returns 2 triples, both should be FROM_S + assert len(result) == 2 + assert result[0] == (t1, Query.FROM_S) + assert result[1] == (t2, Query.FROM_S) + class TestLRUCacheWithTTL: diff --git a/tests/unit/test_retrieval/test_document_rag_diversity_selection.py b/tests/unit/test_retrieval/test_document_rag_diversity_selection.py new file mode 100644 index 00000000..6dcd9458 --- /dev/null +++ b/tests/unit/test_retrieval/test_document_rag_diversity_selection.py @@ -0,0 +1,92 @@ +from trustgraph.retrieval.document_rag.rerank import ( + RerankCandidate, normalize_candidate_scores, mmr_select, + _pair_diversity_penalty +) + +def candidate(index, chunk_id, text, score): + return RerankCandidate( + index=index, + chunk_id=chunk_id, + text=text, + reranker_score=score, + ) + + +def test_normalize_candidate_scores_min_max_scales_raw_scores(): + candidates = [ + candidate(0, "a", "alpha", -2.0), + candidate(1, "b", "beta", 0.0), + candidate(2, "c", "gamma", 4.0), + ] + + normalized = normalize_candidate_scores(candidates) + + assert normalized[0].normalized_score == 0.0 + assert normalized[1].normalized_score == 1.0 / 3.0 + assert normalized[2].normalized_score == 1.0 + + +def test_normalize_candidate_scores_handles_equal_scores(): + candidates = [ + candidate(0, "a", "alpha", 3.0), + candidate(1, "b", "beta", 3.0), + candidate(2, "c", "gamma", 3.0), + ] + + normalized = normalize_candidate_scores(candidates) + + assert [c.normalized_score for c in normalized] == [0.5, 0.5, 0.5] + + +def test_mmr_select_limits_results(): + candidates = [ + candidate(0, "a", "alpha policy", 0.9), + candidate(1, "b", "beta refund", 0.8), + candidate(2, "c", "gamma shipping", 0.7), + ] + + selected = mmr_select(candidates, limit=2) + + assert len(selected) == 2 + + +def test_mmr_select_prefers_highest_reranker_score_first(): + candidates = [ + candidate(0, "a", "weakly relevant text", 0.1), + candidate(1, "b", "strongly relevant answer", 10.0), + candidate(2, "c", "medium relevant text", 5.0), + ] + + selected = mmr_select(candidates, limit=1) + + assert selected[0].chunk_id == "b" + + +def test_mmr_select_penalizes_near_duplicate_chunks(): + candidates = [ + candidate(0, "a", "apple banana fruit return policy", 1.00), + candidate(1, "b", "apple banana fruit return policy duplicate", 0.95), + candidate(2, "c", "engine motor vehicle warranty", 0.90), + ] + + selected = mmr_select( + candidates, + limit=2, + lambda_mult=0.2, + token_overlap_weight=1.0, + ) + + assert [c.chunk_id for c in selected] == ["a", "c"] + + +def test_pair_diversity_penalty_is_clamped(): + left = candidate(0, "a", "same same same", 1.0) + right = candidate(1, "b", "same same same", 0.9) + + penalty = _pair_diversity_penalty( + left, + right, + token_overlap_weight=10.0, + ) + + assert penalty == 1.0 diff --git a/tests/unit/test_retrieval/test_document_rag_rerank.py b/tests/unit/test_retrieval/test_document_rag_rerank.py index d711d57c..67b3a2b1 100644 --- a/tests/unit/test_retrieval/test_document_rag_rerank.py +++ b/tests/unit/test_retrieval/test_document_rag_rerank.py @@ -476,3 +476,75 @@ class TestRerankActive: await rag.query(query="What is the return policy?") assert reranker.calls == [] + +# --------------------------------------------------------------------------- +# 3. Diversity selection: optional MMR after cross-encoder scoring +# --------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_diversity_mode_scores_full_candidate_pool_before_selecting(self): + """ + With diversity selection enabled, the cross-encoder should score the full + fetched candidate pool before MMR narrows it down to doc_limit. + """ + clients = build_mock_clients() + reranker = StubReranker([ + RerankerResult(document_id="0", query_id="0", score=1.00), + RerankerResult(document_id="1", query_id="0", score=0.95), + RerankerResult(document_id="2", query_id="0", score=0.90), + ]) + rag = DocumentRag( + *clients, + reranker_client=reranker, + rerank_diversity_mode="mmr", + ) + + await rag.query(query="What is the return policy?", doc_limit=2) + + assert reranker.calls[0]["limit"] == len(ORDERED_CONTENT) + + call = rag.prompt_client.document_prompt.call_args + passed_docs = call.kwargs["documents"] + assert len(passed_docs) == 2 + + + @pytest.mark.asyncio + async def test_diversity_mode_selects_less_redundant_context_set(self): + """ + MMR should use cross-encoder scores as relevance while penalizing redundant + chunks, so a slightly lower-scored but less redundant chunk can be selected. + """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + + duplicate_a = "apple banana fruit return policy" + duplicate_b = "apple banana fruit return policy duplicate" + diverse_c = "engine motor vehicle warranty" + + async def mock_fetch(chunk_id): + return { + CHUNK_A: duplicate_a, + CHUNK_B: duplicate_b, + CHUNK_C: diverse_c, + }[chunk_id] + + fetch_chunk.side_effect = mock_fetch + + reranker = StubReranker([ + RerankerResult(document_id="0", query_id="0", score=1.00), + RerankerResult(document_id="1", query_id="0", score=0.95), + RerankerResult(document_id="2", query_id="0", score=0.90), + ]) + rag = DocumentRag( + *clients, + reranker_client=reranker, + rerank_diversity_mode="mmr", + rerank_diversity_lambda=0.2, + ) + + await rag.query(query="What is the return policy?", doc_limit=2) + + call = rag.prompt_client.document_prompt.call_args + passed_docs = call.kwargs["documents"] + + assert passed_docs == [duplicate_a, diverse_c] \ No newline at end of file diff --git a/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py b/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py new file mode 100644 index 00000000..cc95228a --- /dev/null +++ b/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py @@ -0,0 +1,353 @@ +""" +Tests for direction-aware reranker text in GraphRAG hop-and-filter. + +The reranker document text varies by traversal direction: +- From S (subject is the frontier entity): text = "{p} {o}" +- From O (object is the frontier entity): text = "{s} {p}" +- From P (predicate is the frontier entity): text = "{s} {o}" +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock + +from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL +from trustgraph.schema import Term, IRI, LITERAL + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_rag(reranker_results=None): + """Create a mock GraphRag with all clients stubbed.""" + rag = MagicMock() + rag.label_cache = LRUCacheWithTTL() + rag.triples_client = AsyncMock() + rag.reranker_client = AsyncMock() + + # Label lookups return empty (fall back to URI) + rag.triples_client.query.return_value = [] + + if reranker_results is not None: + rag.reranker_client.rerank.return_value = reranker_results + else: + rag.reranker_client.rerank.return_value = [] + + return rag + + +def _make_query(rag, max_path_length=1, edge_limit=25): + return Query( + rag=rag, + collection="test", + verbose=False, + entity_limit=50, + triple_limit=30, + max_subgraph_size=1000, + max_path_length=max_path_length, + edge_limit=edge_limit, + ) + + +def _make_schema_triple(s, p, o): + """Create a mock triple matching the schema interface.""" + t = MagicMock() + t.s = s + t.p = p + t.o = o + return t + + +def _reranker_result(document_id, query_id="0", score=0.9): + r = MagicMock() + r.document_id = str(document_id) + r.query_id = str(query_id) + r.score = score + return r + + +# --------------------------------------------------------------------------- +# Tests: execute_batch_triple_queries direction tracking +# --------------------------------------------------------------------------- + +class TestDirectionTracking: + + @pytest.mark.asyncio + async def test_from_s_direction(self): + """Triples from s=entity queries are tagged FROM_S.""" + triple = _make_schema_triple("ent1", "pred", "obj") + rag = _make_rag() + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + q = _make_query(rag) + + result = await q.execute_batch_triple_queries(["ent1"], 10) + + from_s = [(t, d) for t, d in result if d == Query.FROM_S] + assert len(from_s) == 1 + assert from_s[0][0] is triple + + @pytest.mark.asyncio + async def test_from_o_direction(self): + """Triples from o=entity queries are tagged FROM_O.""" + triple = _make_schema_triple("subj", "pred", "ent1") + rag = _make_rag() + + async def query_stream(s=None, p=None, o=None, **kwargs): + if o is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + q = _make_query(rag) + + result = await q.execute_batch_triple_queries(["ent1"], 10) + + from_o = [(t, d) for t, d in result if d == Query.FROM_O] + assert len(from_o) == 1 + assert from_o[0][0] is triple + + @pytest.mark.asyncio + async def test_from_p_direction(self): + """Triples from p=entity queries are tagged FROM_P.""" + triple = _make_schema_triple("subj", "ent1", "obj") + rag = _make_rag() + + async def query_stream(s=None, p=None, o=None, **kwargs): + if p is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + q = _make_query(rag) + + result = await q.execute_batch_triple_queries(["ent1"], 10) + + from_p = [(t, d) for t, d in result if d == Query.FROM_P] + assert len(from_p) == 1 + assert from_p[0][0] is triple + + +# --------------------------------------------------------------------------- +# Tests: hop_and_filter reranker document text +# --------------------------------------------------------------------------- + +class TestDirectionAwareRerankerText: + + @pytest.mark.asyncio + async def test_from_s_uses_predicate_object(self): + """From-S traversal: reranker text should be '{p} {o}'.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/entity-A"], + concepts=["likes"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + # Text should be "{p} {o}" — the URIs since no labels found + assert len(documents) == 1 + assert documents[0]["text"] == "http://ex/likes http://ex/entity-B" + + @pytest.mark.asyncio + async def test_from_o_uses_subject_predicate(self): + """From-O traversal: reranker text should be '{s} {p}'.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if o is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/entity-B"], + concepts=["likes"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + assert len(documents) == 1 + assert documents[0]["text"] == "http://ex/entity-A http://ex/likes" + + @pytest.mark.asyncio + async def test_from_p_uses_subject_object(self): + """From-P traversal: reranker text should be '{s} {o}'.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if p is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/likes"], + concepts=["entity"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + assert len(documents) == 1 + assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B" + + @pytest.mark.asyncio + async def test_mixed_directions_produce_different_text(self): + """Edges from different directions use different text formats.""" + triple_from_s = _make_schema_triple( + "http://ex/seed", "http://ex/rel", "http://ex/target", + ) + triple_from_o = _make_schema_triple( + "http://ex/other", "http://ex/ref", "http://ex/seed", + ) + + rag = _make_rag(reranker_results=[ + _reranker_result(0), _reranker_result(1), + ]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s == "http://ex/seed": + return [triple_from_s] + if o == "http://ex/seed": + return [triple_from_o] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/seed"], + concepts=["test"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + texts = {d["text"] for d in documents} + + # From S: "{p} {o}" = "http://ex/rel http://ex/target" + assert "http://ex/rel http://ex/target" in texts + # From O: "{s} {p}" = "http://ex/other http://ex/ref" + assert "http://ex/other http://ex/ref" in texts + + @pytest.mark.asyncio + async def test_labels_applied_to_direction_text(self): + """Labels should be resolved and used in the direction-aware text.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + LABEL = "http://www.w3.org/2000/01/rdf-schema#label" + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s is not None and p is None: + return [triple] + return [] + + async def label_query(s=None, p=None, o=None, limit=1, **kwargs): + if p == LABEL: + labels = { + "http://ex/entity-A": "Alice", + "http://ex/likes": "likes", + "http://ex/entity-B": "Bob", + } + if s in labels: + return [MagicMock(o=labels[s])] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + rag.triples_client.query.side_effect = label_query + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/entity-A"], + concepts=["friendship"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + assert len(documents) == 1 + # From S with labels: "{p_label} {o_label}" + assert documents[0]["text"] == "likes Bob" + + @pytest.mark.asyncio + async def test_no_duplicate_text_from_shared_object(self): + """Multiple edges sharing an object should produce distinct texts.""" + triple_a = _make_schema_triple( + "http://ex/cpu-A", "http://ex/hasCategory", "http://ex/Processors", + ) + triple_b = _make_schema_triple( + "http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors", + ) + + rag = _make_rag(reranker_results=[ + _reranker_result(0), _reranker_result(1), + ]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if o == "http://ex/Processors": + return [triple_a, triple_b] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/Processors"], + concepts=["CPUs"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + texts = [d["text"] for d in documents] + + assert len(texts) == 2 + # From O: "{s} {p}" — subjects differ, so texts differ + assert texts[0] != texts[1] + assert "http://ex/cpu-A" in texts[0] + assert "http://ex/cpu-B" in texts[1] diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index a3730eb9..f2087912 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -20,6 +20,8 @@ from trustgraph.provenance import ( GRAPH_RETRIEVAL, ) +from .rerank import RerankCandidate, mmr_select + # Module logger logger = logging.getLogger(__name__) @@ -150,6 +152,8 @@ class DocumentRag: fetch_chunk, reranker_client=None, verbose=False, + rerank_diversity_mode="none", + rerank_diversity_lambda=0.7, ): self.verbose = verbose @@ -162,6 +166,8 @@ class DocumentRag: # Optional cross-encoder reranker. When None, the retrieval path is # byte-identical to the pre-reranker behaviour. self.reranker_client = reranker_client + self.rerank_diversity_mode = rerank_diversity_mode + self.rerank_diversity_lambda = rerank_diversity_lambda if self.verbose: logger.debug("DocumentRag initialized") @@ -277,30 +283,74 @@ class DocumentRag: # skipped entirely and behaviour is byte-identical to before. reranked = False if self.reranker_client is not None and docs: + use_diversity = self.rerank_diversity_mode == "mmr" + + # Without diversity selection, preserve the existing #1011 + # behavior: ask the reranker for exactly doc_limit results. + # + # With diversity selection enabled, ask the reranker to score the + # full fetched candidate pool first, then let MMR choose the final + # doc_limit context set. + rerank_limit = len(docs) if use_diversity else doc_limit + results = await self.reranker_client.rerank( queries=[{"id": "0", "text": query}], documents=[ {"id": str(i), "text": d} for i, d in enumerate(docs) ], - # Narrow the over-fetched candidate pool down to the final - # doc_limit requested for synthesis. - limit=doc_limit, + limit=rerank_limit, ) - # results are sorted desc by score and truncated to limit by the - # reranker service, so order gives the surviving top-N directly. - order = [int(r.document_id) for r in results] - docs = [docs[i] for i in order] - chunk_ids = [chunk_ids[i] for i in order] + source_docs = docs + source_chunk_ids = chunk_ids + + if use_diversity: + candidates = [ + RerankCandidate( + index=int(r.document_id), + chunk_id=source_chunk_ids[int(r.document_id)], + text=source_docs[int(r.document_id)], + reranker_score=r.score, + ) + for r in results + ] + + selected_candidates = mmr_select( + candidates, + limit=doc_limit, + lambda_mult=self.rerank_diversity_lambda, + ) + + docs = [candidate.text for candidate in selected_candidates] + chunk_ids = [ + candidate.chunk_id for candidate in selected_candidates + ] + + selected_chunks_with_scores = [ + { + "chunk_id": candidate.chunk_id, + "score": candidate.reranker_score, + } + for candidate in selected_candidates + ] + + else: + # results are sorted desc by score and truncated to limit by the + # reranker service, so order gives the surviving top-N directly. + order = [int(r.document_id) for r in results] + docs = [source_docs[i] for i in order] + chunk_ids = [source_chunk_ids[i] for i in order] + + selected_chunks_with_scores = [ + {"chunk_id": chunk_ids[i], "score": r.score} + for i, r in enumerate(results) + ] + reranked = True # Emit chunk-selection (focus) explainability: surviving chunks # with their cross-encoder scores, derived from exploration. if explain_callback: - selected_chunks_with_scores = [ - {"chunk_id": chunk_ids[i], "score": r.score} - for i, r in enumerate(results) - ] foc_triples = set_graph( docrag_chunk_selection_triples( foc_uri, exp_uri, diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index 158cbefc..80dfb6b1 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -33,17 +33,23 @@ class Processor(FlowProcessor): # reranking; the rerank step narrows it back down to doc_limit for the # LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit). fetch_limit = params.get("fetch_limit", 0) + rerank_diversity_mode = params.get("rerank_diversity_mode", "none") + rerank_diversity_lambda = params.get("rerank_diversity_lambda", 0.7) super(Processor, self).__init__( **params | { "id": id, "doc_limit": doc_limit, "fetch_limit": fetch_limit, + "rerank_diversity_mode": rerank_diversity_mode, + "rerank_diversity_lambda": rerank_diversity_lambda, } ) self.doc_limit = doc_limit self.fetch_limit = fetch_limit + self.rerank_diversity_mode = rerank_diversity_mode + self.rerank_diversity_lambda = rerank_diversity_lambda self.register_specification( ConsumerSpec( @@ -122,6 +128,8 @@ class Processor(FlowProcessor): fetch_chunk = fetch_chunk, reranker_client = flow("reranker-request"), verbose=True, + rerank_diversity_mode=self.rerank_diversity_mode, + rerank_diversity_lambda=self.rerank_diversity_lambda, ) if v.doc_limit: @@ -277,6 +285,20 @@ class Processor(FlowProcessor): '(default: derive from doc-limit)' ) + parser.add_argument( + '--rerank-diversity-mode', + choices=['none', 'mmr'], + default='none', + help='Optional diversity-aware selection after reranking (default: none)' + ) + + parser.add_argument( + '--rerank-diversity-lambda', + type=float, + default=0.7, + help='MMR relevance/diversity tradeoff, higher values prefer relevance' + ) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py new file mode 100644 index 00000000..a0a7e8ee --- /dev/null +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py @@ -0,0 +1,142 @@ +import re +from dataclasses import dataclass, replace +from typing import List, Sequence, Set + + +@dataclass(frozen=True) +class RerankCandidate: + """ + Candidate chunk after cross-encoder reranking. + + reranker_score is the raw score returned by the reranker backend. It may + not be normalized, so MMR should use normalized_score instead. + """ + index: int + chunk_id: str + text: str + reranker_score: float + normalized_score: float = 0.0 + + +_TOKEN_RE = re.compile(r"[A-Za-z0-9_]+") + + +def _clamp01(value: float) -> float: + return max(0.0, min(1.0, value)) + + +def _token_set(text: str) -> Set[str]: + return set(token.lower() for token in _TOKEN_RE.findall(text or "")) + + +def _jaccard(a: str, b: str) -> float: + a_tokens = _token_set(a) + b_tokens = _token_set(b) + + if not a_tokens or not b_tokens: + return 0.0 + + return len(a_tokens & b_tokens) / len(a_tokens | b_tokens) + + +def normalize_candidate_scores( + candidates: Sequence[RerankCandidate], +) -> List[RerankCandidate]: + """ + Min-max normalize reranker scores within the current candidate set. + + Reranker backends may return different score scales: probabilities, + logits, or prompt-defined scores. MMR needs a stable [0, 1] relevance + signal, so normalize per candidate set instead of assuming a global range. + """ + if not candidates: + return [] + + scores = [float(candidate.reranker_score) for candidate in candidates] + min_score = min(scores) + max_score = max(scores) + + if max_score == min_score: + return [ + replace(candidate, normalized_score=0.5) + for candidate in candidates + ] + + score_range = max_score - min_score + + return [ + replace( + candidate, + normalized_score=(float(candidate.reranker_score) - min_score) / score_range, + ) + for candidate in candidates + ] + + +def _pair_diversity_penalty( + candidate: RerankCandidate, + selected: RerankCandidate, + token_overlap_weight: float, +) -> float: + """ + Pairwise diversity penalty between two candidate chunks. + + The first revision only uses token overlap because the current Document-RAG + reranker document_id is the candidate index, not a source document id. + """ + penalty = token_overlap_weight * _jaccard(candidate.text, selected.text) + return _clamp01(penalty) + + +def mmr_select( + candidates: Sequence[RerankCandidate], + limit: int, + lambda_mult: float = 0.7, + token_overlap_weight: float = 1.0, +) -> List[RerankCandidate]: + """ + Select a diverse final context set using MMR. + + Relevance comes from normalized cross-encoder reranker scores. + Diversity comes from token overlap against already selected chunks. + """ + if limit <= 0: + return [] + + lambda_mult = _clamp01(lambda_mult) + token_overlap_weight = max(0.0, token_overlap_weight) + + remaining = normalize_candidate_scores(candidates) + selected: List[RerankCandidate] = [] + + while remaining and len(selected) < limit: + best_idx = 0 + best_score = None + + for idx, candidate in enumerate(remaining): + relevance = candidate.normalized_score + + if selected: + diversity_penalty = max( + _pair_diversity_penalty( + candidate, + chosen, + token_overlap_weight=token_overlap_weight, + ) + for chosen in selected + ) + else: + diversity_penalty = 0.0 + + mmr_score = ( + lambda_mult * relevance + - (1.0 - lambda_mult) * diversity_penalty + ) + + if best_score is None or mmr_score > best_score: + best_score = mmr_score + best_idx = idx + + selected.append(remaining.pop(best_idx)) + + return selected diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 06c0b5b4..2054cb0f 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -241,38 +241,56 @@ class Query: self.rag.label_cache.put(cache_key, label) return label + FROM_S = "from_s" + FROM_P = "from_p" + FROM_O = "from_o" + async def execute_batch_triple_queries(self, entities, limit_per_entity): - """Execute triple queries for multiple entities concurrently.""" + """Execute triple queries for multiple entities concurrently. + + Returns a list of (triple, direction) tuples where direction + indicates which position the frontier entity occupied. + """ tasks = [] + directions = [] for entity in entities: - tasks.extend([ + tasks.append( self.rag.triples_client.query_stream( s=entity, p=None, o=None, limit=limit_per_entity, collection=self.collection, batch_size=20, g="", ), + ) + directions.append(self.FROM_S) + + tasks.append( self.rag.triples_client.query_stream( s=None, p=entity, o=None, limit=limit_per_entity, collection=self.collection, batch_size=20, g="", ), + ) + directions.append(self.FROM_P) + + tasks.append( self.rag.triples_client.query_stream( s=None, p=None, o=entity, limit=limit_per_entity, collection=self.collection, batch_size=20, g="", - ) - ]) + ), + ) + directions.append(self.FROM_O) results = await asyncio.gather(*tasks, return_exceptions=True) all_triples = [] - for result in results: + for direction, result in zip(directions, results): if not isinstance(result, Exception) and result is not None: - all_triples.extend(result) + all_triples.extend((triple, direction) for triple in result) return all_triples @@ -325,7 +343,8 @@ class Query: # Deduplicate and filter already-seen edges hop_triples = [] hop_term_map = {} - for triple in triples: + hop_directions = {} + for triple, direction in triples: triple_tuple = (str(triple.s), str(triple.p), str(triple.o)) if triple_tuple[1] == LABEL: continue @@ -336,6 +355,7 @@ class Query: hop_term_map[triple_tuple] = ( to_term(triple.s), to_term(triple.p), to_term(triple.o), ) + hop_directions[triple_tuple] = direction if not hop_triples: visited_entities.update(frontier) @@ -361,7 +381,10 @@ class Query: else: label_map[entity] = entity - # Build labeled edges and documents for cross-encoder + # Build labeled edges and documents for cross-encoder. + # The reranker text highlights the NEW information relative + # to the traversal direction: arriving from S means p,o are + # new; from O means s,p are new; from P means s,o are new. labeled_hop = [] for s, p, o in hop_triples: ls = label_map.get(s, s) @@ -369,10 +392,18 @@ class Query: lo = label_map.get(o, o) labeled_hop.append((ls, lp, lo)) - documents = [ - {"id": str(i), "text": f"{lp} {lo}"} - for i, (ls, lp, lo) in enumerate(labeled_hop) - ] + documents = [] + for i, (triple_tuple, (ls, lp, lo)) in enumerate( + zip(hop_triples, labeled_hop) + ): + direction = hop_directions[triple_tuple] + if direction == self.FROM_S: + text = f"{lp} {lo}" + elif direction == self.FROM_O: + text = f"{ls} {lp}" + else: + text = f"{ls} {lo}" + documents.append({"id": str(i), "text": text}) queries = [ {"id": str(i), "text": c}