Merge branch 'release/v2.6'

This commit is contained in:
Cyber MacGeddon 2026-07-03 13:45:02 +01:00
commit 508d0bb5c1
9 changed files with 892 additions and 31 deletions

View file

@ -224,12 +224,27 @@ The current embedding pre-filter represents each edge as
- **Drop commas.** Commas add tokenisation noise without semantic
value.
- **Drop the subject.** The subject identifies which entity the
edge belongs to, but it does not contribute to whether the
edge's content is relevant to the query. The predicate and
object carry the semantic meaning — what relationship exists
and what it connects to. Representing edges as `"{p} {o}"`
produces cleaner cross-encoder matches.
- **Direction-aware text.** The reranker text should highlight
the *new* information relative to the traversal direction.
The frontier entity is already known context — repeating it
adds noise and, when traversing from an object node, causes
many edges to produce identical reranker text (e.g. 18
products sharing the same `hasSubcategory Processors` triple
all collapse to the same string when the subject is dropped).
The text is constructed based on which position the frontier
entity occupied in the triple:
- **From subject** (s=entity): `"{predicate} {object}"` —
the subject is known; predicate and object are new.
- **From object** (o=entity): `"{subject} {predicate}"` —
the object is known; subject and predicate are new.
- **From predicate** (p=entity): `"{subject} {object}"` —
the predicate is known; subject and object are new.
This eliminates the duplicate-text problem that arises when
traversing inward from a shared object node, and gives the
cross-encoder a more informative signal at every hop.
#### Remove the embedding pre-filter (step 3)
@ -389,7 +404,10 @@ no LLM call. These fields are dropped from the Focus entity.
a. Retrieve all edges one hop from the current frontier nodes.
b. Represent each edge as `"{predicate} {object}"`.
b. Represent each edge using direction-aware text: from a
subject node use `"{predicate} {object}"`, from an object
node use `"{subject} {predicate}"`, from a predicate node
use `"{subject} {object}"`.
c. Score edges against the extracted concepts using the
cross-encoder service.

View file

@ -129,6 +129,9 @@ class TestBatchTripleQueries:
# 3 queries, alternating results
assert len(result) == 3
# Each result is a (triple, direction) tuple
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_exception_in_one_query_does_not_block_others(self):
@ -153,6 +156,8 @@ class TestBatchTripleQueries:
# 3 queries: 2 succeed, 1 fails → 2 triples
assert len(result) == 2
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_none_results_filtered(self):
@ -176,6 +181,8 @@ class TestBatchTripleQueries:
# 3 queries: 1 returns None, 2 return triples
assert len(result) == 2
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_empty_entities_no_queries(self):
@ -220,6 +227,80 @@ class TestBatchTripleQueries:
assert calls[2].kwargs["p"] is None
assert calls[2].kwargs["o"] == "ent-1"
@pytest.mark.asyncio
async def test_directions_assigned_correctly(self):
    """Each query position should produce the correct direction tag."""
    triple = _make_triple("s", "p", "o")

    async def one_triple_each(**kwargs):
        # Every query, whatever position it binds, yields the same triple.
        return [triple]

    client = AsyncMock()
    client.query_stream = one_triple_each
    query = _make_query(triples_client=client)

    result = await query.execute_batch_triple_queries(
        ["e1"], limit_per_entity=10
    )

    assert len(result) == 3
    # Order matches query order: s-position, p-position, o-position
    assert result[0][1] == Query.FROM_S
    assert result[1][1] == Query.FROM_P
    assert result[2][1] == Query.FROM_O
@pytest.mark.asyncio
async def test_directions_correct_for_multiple_entities(self):
    """Direction tags repeat S, P, O for each entity in turn."""
    stub_client = AsyncMock()
    stub_client.query_stream = AsyncMock(
        return_value=[_make_triple("s", "p", "o")]
    )
    query = _make_query(triples_client=stub_client)

    tagged = await query.execute_batch_triple_queries(
        ["e1", "e2"], limit_per_entity=10
    )

    # Two entities, three positional queries each, one triple per query.
    per_entity = [Query.FROM_S, Query.FROM_P, Query.FROM_O]
    assert len(tagged) == 6
    observed = [direction for _, direction in tagged]
    assert observed == per_entity + per_entity
@pytest.mark.asyncio
async def test_direction_preserved_with_multiple_triples(self):
    """Every triple returned by a single query carries that query's tag."""
    first = _make_triple("a", "p1", "b")
    second = _make_triple("a", "p2", "c")

    # Only the first query to run returns triples; all later ones are empty.
    scripted = iter([[first, second]])

    async def multi_results(**kwargs):
        return next(scripted, [])

    stub_client = AsyncMock()
    stub_client.query_stream = multi_results
    query = _make_query(triples_client=stub_client)

    tagged = await query.execute_batch_triple_queries(
        ["e1"], limit_per_entity=10
    )

    # First query (FROM_S) returns 2 triples, both should be FROM_S
    assert len(tagged) == 2
    assert tagged[0] == (first, Query.FROM_S)
    assert tagged[1] == (second, Query.FROM_S)
class TestLRUCacheWithTTL:

View file

@ -0,0 +1,92 @@
from trustgraph.retrieval.document_rag.rerank import (
RerankCandidate, normalize_candidate_scores, mmr_select,
_pair_diversity_penalty
)
def candidate(index, chunk_id, text, score):
    """Build a RerankCandidate whose raw reranker score is ``score``.

    normalized_score is not passed, so it keeps the dataclass default.
    """
    fields = {
        "index": index,
        "chunk_id": chunk_id,
        "text": text,
        "reranker_score": score,
    }
    return RerankCandidate(**fields)
def test_normalize_candidate_scores_min_max_scales_raw_scores():
    """Raw scores -2, 0 and 4 are min-max scaled to 0, 1/3 and 1."""
    import math

    candidates = [
        candidate(0, "a", "alpha", -2.0),
        candidate(1, "b", "beta", 0.0),
        candidate(2, "c", "gamma", 4.0),
    ]

    normalized = normalize_candidate_scores(candidates)

    # The end points are exact by construction: (min - min) / range and
    # range / range.
    assert normalized[0].normalized_score == 0.0
    assert normalized[2].normalized_score == 1.0
    # The interior point is a computed quotient; compare with a tolerance so
    # the test does not depend on the exact order of float operations.
    assert math.isclose(normalized[1].normalized_score, 1.0 / 3.0)
def test_normalize_candidate_scores_handles_equal_scores():
    """When every raw score is equal, each candidate is normalized to 0.5."""
    labels = [("a", "alpha"), ("b", "beta"), ("c", "gamma")]
    tied = [
        candidate(position, chunk_id, text, 3.0)
        for position, (chunk_id, text) in enumerate(labels)
    ]

    result = normalize_candidate_scores(tied)

    assert [item.normalized_score for item in result] == [0.5, 0.5, 0.5]
def test_mmr_select_limits_results():
    """Three candidates with limit=2 produce exactly two selections."""
    pool = [
        candidate(0, "a", "alpha policy", 0.9),
        candidate(1, "b", "beta refund", 0.8),
        candidate(2, "c", "gamma shipping", 0.7),
    ]

    assert len(mmr_select(pool, limit=2)) == 2
def test_mmr_select_prefers_highest_reranker_score_first():
    """With limit=1 the candidate with the highest raw score is chosen."""
    pool = [
        candidate(0, "a", "weakly relevant text", 0.1),
        candidate(1, "b", "strongly relevant answer", 10.0),
        candidate(2, "c", "medium relevant text", 5.0),
    ]

    winner = mmr_select(pool, limit=1)[0]

    assert winner.chunk_id == "b"
def test_mmr_select_penalizes_near_duplicate_chunks():
    """A near-duplicate of the top chunk loses to a lower-scored distinct one."""
    top = candidate(0, "a", "apple banana fruit return policy", 1.00)
    near_copy = candidate(
        1, "b", "apple banana fruit return policy duplicate", 0.95
    )
    unrelated = candidate(2, "c", "engine motor vehicle warranty", 0.90)

    picked = mmr_select(
        [top, near_copy, unrelated],
        limit=2,
        lambda_mult=0.2,
        token_overlap_weight=1.0,
    )

    assert [item.chunk_id for item in picked] == ["a", "c"]
def test_pair_diversity_penalty_is_clamped():
    """An overlap weight of 10 on identical text still gives exactly 1.0."""
    shared_text = "same same same"

    penalty = _pair_diversity_penalty(
        candidate(0, "a", shared_text, 1.0),
        candidate(1, "b", shared_text, 0.9),
        token_overlap_weight=10.0,
    )

    assert penalty == 1.0

View file

@ -476,3 +476,75 @@ class TestRerankActive:
await rag.query(query="What is the return policy?")
assert reranker.calls == []
# ---------------------------------------------------------------------------
# 3. Diversity selection: optional MMR after cross-encoder scoring
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_diversity_mode_scores_full_candidate_pool_before_selecting(self):
    """
    With diversity selection enabled, the reranker is asked to score every
    fetched candidate; only afterwards is the set cut down to doc_limit.
    """
    scores = (1.00, 0.95, 0.90)
    reranker = StubReranker([
        RerankerResult(document_id=str(position), query_id="0", score=score)
        for position, score in enumerate(scores)
    ])
    rag = DocumentRag(
        *build_mock_clients(),
        reranker_client=reranker,
        rerank_diversity_mode="mmr",
    )

    await rag.query(query="What is the return policy?", doc_limit=2)

    # The rerank request covers the whole pool, not just doc_limit.
    first_call = reranker.calls[0]
    assert first_call["limit"] == len(ORDERED_CONTENT)

    # The prompt still receives only doc_limit documents.
    prompt_kwargs = rag.prompt_client.document_prompt.call_args.kwargs
    assert len(prompt_kwargs["documents"]) == 2
@pytest.mark.asyncio
async def test_diversity_mode_selects_less_redundant_context_set(self):
    """
    MMR keeps cross-encoder scores as the relevance signal but penalizes
    redundant chunks, so a slightly lower-scored, less redundant chunk can
    displace a near-duplicate.
    """
    duplicate_a = "apple banana fruit return policy"
    duplicate_b = "apple banana fruit return policy duplicate"
    diverse_c = "engine motor vehicle warranty"
    content_by_chunk = {
        CHUNK_A: duplicate_a,
        CHUNK_B: duplicate_b,
        CHUNK_C: diverse_c,
    }

    async def mock_fetch(chunk_id):
        return content_by_chunk[chunk_id]

    clients = build_mock_clients()
    _, _, _, fetch_chunk = clients
    fetch_chunk.side_effect = mock_fetch

    scores = (1.00, 0.95, 0.90)
    reranker = StubReranker([
        RerankerResult(document_id=str(position), query_id="0", score=score)
        for position, score in enumerate(scores)
    ])
    rag = DocumentRag(
        *clients,
        reranker_client=reranker,
        rerank_diversity_mode="mmr",
        rerank_diversity_lambda=0.2,
    )

    await rag.query(query="What is the return policy?", doc_limit=2)

    prompt_kwargs = rag.prompt_client.document_prompt.call_args.kwargs
    assert prompt_kwargs["documents"] == [duplicate_a, diverse_c]

View file

@ -0,0 +1,353 @@
"""
Tests for direction-aware reranker text in GraphRAG hop-and-filter.
The reranker document text varies by traversal direction:
- From S (subject is the frontier entity): text = "{p} {o}"
- From O (object is the frontier entity): text = "{s} {p}"
- From P (predicate is the frontier entity): text = "{s} {o}"
"""
import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
from trustgraph.schema import Term, IRI, LITERAL
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_rag(reranker_results=None):
    """Return a MagicMock standing in for GraphRag, with stubbed clients.

    The triples and reranker clients are AsyncMocks and the label cache is
    a real LRUCacheWithTTL. rerank() returns ``reranker_results``, or an
    empty list when none is given.
    """
    triples_client = AsyncMock()
    # Label lookups return empty (fall back to URI)
    triples_client.query.return_value = []

    reranker_client = AsyncMock()
    reranker_client.rerank.return_value = (
        [] if reranker_results is None else reranker_results
    )

    rag = MagicMock()
    rag.label_cache = LRUCacheWithTTL()
    rag.triples_client = triples_client
    rag.reranker_client = reranker_client
    return rag
def _make_query(rag, max_path_length=1, edge_limit=25):
    """Build a Query on ``rag`` with fixed test limits.

    Only the path length and edge limit vary between tests; every other
    setting is pinned here.
    """
    fixed_settings = {
        "collection": "test",
        "verbose": False,
        "entity_limit": 50,
        "triple_limit": 30,
        "max_subgraph_size": 1000,
    }
    return Query(
        rag=rag,
        max_path_length=max_path_length,
        edge_limit=edge_limit,
        **fixed_settings,
    )
def _make_schema_triple(s, p, o):
"""Create a mock triple matching the schema interface."""
t = MagicMock()
t.s = s
t.p = p
t.o = o
return t
def _reranker_result(document_id, query_id="0", score=0.9):
r = MagicMock()
r.document_id = str(document_id)
r.query_id = str(query_id)
r.score = score
return r
# ---------------------------------------------------------------------------
# Tests: execute_batch_triple_queries direction tracking
# ---------------------------------------------------------------------------
class TestDirectionTracking:
    """execute_batch_triple_queries tags each triple with its query position."""

    @staticmethod
    async def _query_with_single_position(position, triple):
        """Run the batch query for "ent1" with only one position answering.

        ``position`` is "s", "p" or "o". The stub returns ``[triple]`` for
        the query that binds that position and an empty list for the rest.
        """
        rag = _make_rag()

        async def query_stream(s=None, p=None, o=None, **kwargs):
            bound = {"s": s, "p": p, "o": o}[position]
            return [triple] if bound is not None else []

        rag.triples_client.query_stream.side_effect = query_stream
        q = _make_query(rag)
        return await q.execute_batch_triple_queries(["ent1"], 10)

    @pytest.mark.asyncio
    async def test_from_s_direction(self):
        """Triples from s=entity queries are tagged FROM_S."""
        triple = _make_schema_triple("ent1", "pred", "obj")

        tagged = await self._query_with_single_position("s", triple)

        matches = [t for t, d in tagged if d == Query.FROM_S]
        assert len(matches) == 1
        assert matches[0] is triple

    @pytest.mark.asyncio
    async def test_from_o_direction(self):
        """Triples from o=entity queries are tagged FROM_O."""
        triple = _make_schema_triple("subj", "pred", "ent1")

        tagged = await self._query_with_single_position("o", triple)

        matches = [t for t, d in tagged if d == Query.FROM_O]
        assert len(matches) == 1
        assert matches[0] is triple

    @pytest.mark.asyncio
    async def test_from_p_direction(self):
        """Triples from p=entity queries are tagged FROM_P."""
        triple = _make_schema_triple("subj", "ent1", "obj")

        tagged = await self._query_with_single_position("p", triple)

        matches = [t for t, d in tagged if d == Query.FROM_P]
        assert len(matches) == 1
        assert matches[0] is triple
# ---------------------------------------------------------------------------
# Tests: hop_and_filter reranker document text
# ---------------------------------------------------------------------------
class TestDirectionAwareRerankerText:
    """hop_and_filter builds reranker text from the non-frontier positions."""

    @staticmethod
    def _stream(matches, triples):
        """Build a query_stream stub.

        The stub returns a copy of ``triples`` when ``matches(s, p, o)`` is
        true for the query's arguments, and an empty list otherwise.
        """

        async def query_stream(s=None, p=None, o=None, **kwargs):
            return list(triples) if matches(s, p, o) else []

        return query_stream

    @staticmethod
    async def _documents_sent_to_reranker(rag, stream, seed, concept):
        """Run one hop from ``seed`` and return the documents given to rerank()."""
        rag.triples_client.query_stream.side_effect = stream
        q = _make_query(rag, max_path_length=1, edge_limit=10)
        await q.hop_and_filter(
            seed_entities=[seed],
            concepts=[concept],
        )
        return rag.reranker_client.rerank.call_args.kwargs["documents"]

    @pytest.mark.asyncio
    async def test_from_s_uses_predicate_object(self):
        """From-S traversal: reranker text should be '{p} {o}'."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        documents = await self._documents_sent_to_reranker(
            rag,
            self._stream(lambda s, p, o: s is not None, [edge]),
            seed="http://ex/entity-A",
            concept="likes",
        )

        # Text should be "{p} {o}" — the URIs since no labels found
        assert len(documents) == 1
        assert documents[0]["text"] == "http://ex/likes http://ex/entity-B"

    @pytest.mark.asyncio
    async def test_from_o_uses_subject_predicate(self):
        """From-O traversal: reranker text should be '{s} {p}'."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        documents = await self._documents_sent_to_reranker(
            rag,
            self._stream(lambda s, p, o: o is not None, [edge]),
            seed="http://ex/entity-B",
            concept="likes",
        )

        assert len(documents) == 1
        assert documents[0]["text"] == "http://ex/entity-A http://ex/likes"

    @pytest.mark.asyncio
    async def test_from_p_uses_subject_object(self):
        """From-P traversal: reranker text should be '{s} {o}'."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        documents = await self._documents_sent_to_reranker(
            rag,
            self._stream(lambda s, p, o: p is not None, [edge]),
            seed="http://ex/likes",
            concept="entity",
        )

        assert len(documents) == 1
        assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B"

    @pytest.mark.asyncio
    async def test_mixed_directions_produce_different_text(self):
        """Edges from different directions use different text formats."""
        seed = "http://ex/seed"
        outgoing = _make_schema_triple(
            "http://ex/seed", "http://ex/rel", "http://ex/target",
        )
        incoming = _make_schema_triple(
            "http://ex/other", "http://ex/ref", "http://ex/seed",
        )
        rag = _make_rag(reranker_results=[
            _reranker_result(0), _reranker_result(1),
        ])

        # Two different answers depending on where the seed is bound, so
        # this stub is written out rather than built with _stream().
        async def query_stream(s=None, p=None, o=None, **kwargs):
            if s == "http://ex/seed":
                return [outgoing]
            if o == "http://ex/seed":
                return [incoming]
            return []

        documents = await self._documents_sent_to_reranker(
            rag, query_stream, seed=seed, concept="test",
        )

        texts = {doc["text"] for doc in documents}
        # From S: "{p} {o}" = "http://ex/rel http://ex/target"
        assert "http://ex/rel http://ex/target" in texts
        # From O: "{s} {p}" = "http://ex/other http://ex/ref"
        assert "http://ex/other http://ex/ref" in texts

    @pytest.mark.asyncio
    async def test_labels_applied_to_direction_text(self):
        """Labels should be resolved and used in the direction-aware text."""
        label_predicate = "http://www.w3.org/2000/01/rdf-schema#label"
        names = {
            "http://ex/entity-A": "Alice",
            "http://ex/likes": "likes",
            "http://ex/entity-B": "Bob",
        }
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        async def label_query(s=None, p=None, o=None, limit=1, **kwargs):
            # Answer only label lookups for the three known URIs.
            if p == label_predicate and s in names:
                return [MagicMock(o=names[s])]
            return []

        rag.triples_client.query.side_effect = label_query

        documents = await self._documents_sent_to_reranker(
            rag,
            self._stream(
                lambda s, p, o: s is not None and p is None, [edge]
            ),
            seed="http://ex/entity-A",
            concept="friendship",
        )

        assert len(documents) == 1
        # From S with labels: "{p_label} {o_label}"
        assert documents[0]["text"] == "likes Bob"

    @pytest.mark.asyncio
    async def test_no_duplicate_text_from_shared_object(self):
        """Multiple edges sharing an object should produce distinct texts."""
        shared_object = "http://ex/Processors"
        edge_a = _make_schema_triple(
            "http://ex/cpu-A", "http://ex/hasCategory", "http://ex/Processors",
        )
        edge_b = _make_schema_triple(
            "http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors",
        )
        rag = _make_rag(reranker_results=[
            _reranker_result(0), _reranker_result(1),
        ])

        documents = await self._documents_sent_to_reranker(
            rag,
            self._stream(
                lambda s, p, o: o == "http://ex/Processors",
                [edge_a, edge_b],
            ),
            seed=shared_object,
            concept="CPUs",
        )

        texts = [doc["text"] for doc in documents]
        assert len(texts) == 2
        # From O: "{s} {p}" — subjects differ, so texts differ
        assert texts[0] != texts[1]
        assert "http://ex/cpu-A" in texts[0]
        assert "http://ex/cpu-B" in texts[1]

View file

@ -20,6 +20,8 @@ from trustgraph.provenance import (
GRAPH_RETRIEVAL,
)
from .rerank import RerankCandidate, mmr_select
# Module logger
logger = logging.getLogger(__name__)
@ -150,6 +152,8 @@ class DocumentRag:
fetch_chunk,
reranker_client=None,
verbose=False,
rerank_diversity_mode="none",
rerank_diversity_lambda=0.7,
):
self.verbose = verbose
@ -162,6 +166,8 @@ class DocumentRag:
# Optional cross-encoder reranker. When None, the retrieval path is
# byte-identical to the pre-reranker behaviour.
self.reranker_client = reranker_client
self.rerank_diversity_mode = rerank_diversity_mode
self.rerank_diversity_lambda = rerank_diversity_lambda
if self.verbose:
logger.debug("DocumentRag initialized")
@ -277,30 +283,74 @@ class DocumentRag:
# skipped entirely and behaviour is byte-identical to before.
reranked = False
if self.reranker_client is not None and docs:
use_diversity = self.rerank_diversity_mode == "mmr"
# Without diversity selection, preserve the existing #1011
# behavior: ask the reranker for exactly doc_limit results.
#
# With diversity selection enabled, ask the reranker to score the
# full fetched candidate pool first, then let MMR choose the final
# doc_limit context set.
rerank_limit = len(docs) if use_diversity else doc_limit
results = await self.reranker_client.rerank(
queries=[{"id": "0", "text": query}],
documents=[
{"id": str(i), "text": d} for i, d in enumerate(docs)
],
# Narrow the over-fetched candidate pool down to the final
# doc_limit requested for synthesis.
limit=doc_limit,
limit=rerank_limit,
)
# results are sorted desc by score and truncated to limit by the
# reranker service, so order gives the surviving top-N directly.
order = [int(r.document_id) for r in results]
docs = [docs[i] for i in order]
chunk_ids = [chunk_ids[i] for i in order]
source_docs = docs
source_chunk_ids = chunk_ids
if use_diversity:
candidates = [
RerankCandidate(
index=int(r.document_id),
chunk_id=source_chunk_ids[int(r.document_id)],
text=source_docs[int(r.document_id)],
reranker_score=r.score,
)
for r in results
]
selected_candidates = mmr_select(
candidates,
limit=doc_limit,
lambda_mult=self.rerank_diversity_lambda,
)
docs = [candidate.text for candidate in selected_candidates]
chunk_ids = [
candidate.chunk_id for candidate in selected_candidates
]
selected_chunks_with_scores = [
{
"chunk_id": candidate.chunk_id,
"score": candidate.reranker_score,
}
for candidate in selected_candidates
]
else:
# results are sorted desc by score and truncated to limit by the
# reranker service, so order gives the surviving top-N directly.
order = [int(r.document_id) for r in results]
docs = [source_docs[i] for i in order]
chunk_ids = [source_chunk_ids[i] for i in order]
selected_chunks_with_scores = [
{"chunk_id": chunk_ids[i], "score": r.score}
for i, r in enumerate(results)
]
reranked = True
# Emit chunk-selection (focus) explainability: surviving chunks
# with their cross-encoder scores, derived from exploration.
if explain_callback:
selected_chunks_with_scores = [
{"chunk_id": chunk_ids[i], "score": r.score}
for i, r in enumerate(results)
]
foc_triples = set_graph(
docrag_chunk_selection_triples(
foc_uri, exp_uri,

View file

@ -33,17 +33,23 @@ class Processor(FlowProcessor):
# reranking; the rerank step narrows it back down to doc_limit for the
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
fetch_limit = params.get("fetch_limit", 0)
rerank_diversity_mode = params.get("rerank_diversity_mode", "none")
rerank_diversity_lambda = params.get("rerank_diversity_lambda", 0.7)
super(Processor, self).__init__(
**params | {
"id": id,
"doc_limit": doc_limit,
"fetch_limit": fetch_limit,
"rerank_diversity_mode": rerank_diversity_mode,
"rerank_diversity_lambda": rerank_diversity_lambda,
}
)
self.doc_limit = doc_limit
self.fetch_limit = fetch_limit
self.rerank_diversity_mode = rerank_diversity_mode
self.rerank_diversity_lambda = rerank_diversity_lambda
self.register_specification(
ConsumerSpec(
@ -122,6 +128,8 @@ class Processor(FlowProcessor):
fetch_chunk = fetch_chunk,
reranker_client = flow("reranker-request"),
verbose=True,
rerank_diversity_mode=self.rerank_diversity_mode,
rerank_diversity_lambda=self.rerank_diversity_lambda,
)
if v.doc_limit:
@ -277,6 +285,20 @@ class Processor(FlowProcessor):
'(default: derive from doc-limit)'
)
parser.add_argument(
'--rerank-diversity-mode',
choices=['none', 'mmr'],
default='none',
help='Optional diversity-aware selection after reranking (default: none)'
)
parser.add_argument(
'--rerank-diversity-lambda',
type=float,
default=0.7,
help='MMR relevance/diversity tradeoff, higher values prefer relevance'
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,142 @@
import re
from dataclasses import dataclass, replace
from typing import List, Sequence, Set
@dataclass(frozen=True)
class RerankCandidate:
    """
    One chunk that has been scored by the cross-encoder reranker.

    Instances are frozen; an updated copy is made with dataclasses.replace.
    """

    # Position of the chunk in the candidate list sent to the reranker.
    index: int
    # Identifier of the chunk this candidate represents.
    chunk_id: str
    # Chunk content; used for the token-overlap diversity penalty.
    text: str
    # Raw score from the reranker backend. It may not be normalized, so MMR
    # should use normalized_score instead.
    reranker_score: float
    # Relevance rescaled for MMR; 0.0 until normalization fills it in.
    normalized_score: float = 0.0
_TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")
def _clamp01(value: float) -> float:
return max(0.0, min(1.0, value))
def _token_set(text: str) -> Set[str]:
return set(token.lower() for token in _TOKEN_RE.findall(text or ""))
def _jaccard(a: str, b: str) -> float:
a_tokens = _token_set(a)
b_tokens = _token_set(b)
if not a_tokens or not b_tokens:
return 0.0
return len(a_tokens & b_tokens) / len(a_tokens | b_tokens)
def normalize_candidate_scores(
    candidates: Sequence[RerankCandidate],
) -> List[RerankCandidate]:
    """
    Rescale raw reranker scores to [0, 1] relative to this candidate set.

    Backends differ in score scale (probabilities, logits, prompt-defined
    scores), so the range is taken from the candidates themselves rather
    than assumed. Returns new candidates with normalized_score filled in;
    the inputs are not modified. An empty input gives an empty list, and
    when all scores are equal every candidate gets 0.5.
    """
    if not candidates:
        return []

    raw = [float(item.reranker_score) for item in candidates]
    lowest = min(raw)
    highest = max(raw)

    # No spread to scale by: place every candidate at the midpoint.
    if highest == lowest:
        return [replace(item, normalized_score=0.5) for item in candidates]

    spread = highest - lowest
    return [
        replace(item, normalized_score=(value - lowest) / spread)
        for item, value in zip(candidates, raw)
    ]
def _pair_diversity_penalty(
    candidate: RerankCandidate,
    selected: RerankCandidate,
    token_overlap_weight: float,
) -> float:
    """
    Redundancy penalty for ``candidate`` against one already-chosen chunk.

    The penalty is the token Jaccard overlap of the two texts scaled by
    ``token_overlap_weight`` and clamped to [0, 1]. This first revision uses
    token overlap only, because the Document-RAG reranker document_id is
    the candidate index rather than a source document id.
    """
    overlap = _jaccard(candidate.text, selected.text)
    return _clamp01(token_overlap_weight * overlap)
def mmr_select(
    candidates: Sequence[RerankCandidate],
    limit: int,
    lambda_mult: float = 0.7,
    token_overlap_weight: float = 1.0,
) -> List[RerankCandidate]:
    """
    Select a diverse final context set using MMR.

    Relevance comes from normalized cross-encoder reranker scores.
    Diversity comes from token overlap against already selected chunks.

    Each round picks the remaining candidate that maximizes
    ``lambda_mult * relevance - (1 - lambda_mult) * penalty``, where
    ``penalty`` is the candidate's largest pairwise penalty against any
    chunk selected so far. Ties go to the candidate earliest in the input.

    Args:
        candidates: Scored chunks to choose from.
        limit: Maximum number of chunks to return; values <= 0 return [].
        lambda_mult: Relevance/diversity trade-off, clamped to [0, 1].
            Higher values favor relevance.
        token_overlap_weight: Scale applied to token overlap; negative
            values are treated as 0.

    Returns:
        Up to ``limit`` candidates in selection order, as copies carrying
        their normalized_score.
    """
    if limit <= 0:
        return []

    lambda_mult = _clamp01(lambda_mult)
    token_overlap_weight = max(0.0, token_overlap_weight)

    remaining = normalize_candidate_scores(candidates)
    # worst_overlap[i] is the largest penalty of remaining[i] against any
    # chunk selected so far. Keeping a running maximum means each round only
    # compares against the newest selection instead of all of them again.
    worst_overlap = [0.0] * len(remaining)
    selected: List[RerankCandidate] = []

    while remaining and len(selected) < limit:
        best_idx = 0
        best_score = None

        for idx, candidate in enumerate(remaining):
            mmr_score = (
                lambda_mult * candidate.normalized_score
                - (1.0 - lambda_mult) * worst_overlap[idx]
            )

            # Strict ">" keeps the earliest candidate on ties.
            if best_score is None or mmr_score > best_score:
                best_score = mmr_score
                best_idx = idx

        chosen = remaining.pop(best_idx)
        worst_overlap.pop(best_idx)
        selected.append(chosen)

        # Fold the new selection into each survivor's running maximum,
        # unless this was the last pick.
        if remaining and len(selected) < limit:
            worst_overlap = [
                max(
                    previous,
                    _pair_diversity_penalty(
                        candidate,
                        chosen,
                        token_overlap_weight=token_overlap_weight,
                    ),
                )
                for previous, candidate in zip(worst_overlap, remaining)
            ]

    return selected

View file

@ -241,38 +241,56 @@ class Query:
self.rag.label_cache.put(cache_key, label)
return label
FROM_S = "from_s"
FROM_P = "from_p"
FROM_O = "from_o"
async def execute_batch_triple_queries(self, entities, limit_per_entity):
"""Execute triple queries for multiple entities concurrently."""
"""Execute triple queries for multiple entities concurrently.
Returns a list of (triple, direction) tuples where direction
indicates which position the frontier entity occupied.
"""
tasks = []
directions = []
for entity in entities:
tasks.extend([
tasks.append(
self.rag.triples_client.query_stream(
s=entity, p=None, o=None,
limit=limit_per_entity,
collection=self.collection,
batch_size=20, g="",
),
)
directions.append(self.FROM_S)
tasks.append(
self.rag.triples_client.query_stream(
s=None, p=entity, o=None,
limit=limit_per_entity,
collection=self.collection,
batch_size=20, g="",
),
)
directions.append(self.FROM_P)
tasks.append(
self.rag.triples_client.query_stream(
s=None, p=None, o=entity,
limit=limit_per_entity,
collection=self.collection,
batch_size=20, g="",
)
])
),
)
directions.append(self.FROM_O)
results = await asyncio.gather(*tasks, return_exceptions=True)
all_triples = []
for result in results:
for direction, result in zip(directions, results):
if not isinstance(result, Exception) and result is not None:
all_triples.extend(result)
all_triples.extend((triple, direction) for triple in result)
return all_triples
@ -325,7 +343,8 @@ class Query:
# Deduplicate and filter already-seen edges
hop_triples = []
hop_term_map = {}
for triple in triples:
hop_directions = {}
for triple, direction in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
if triple_tuple[1] == LABEL:
continue
@ -336,6 +355,7 @@ class Query:
hop_term_map[triple_tuple] = (
to_term(triple.s), to_term(triple.p), to_term(triple.o),
)
hop_directions[triple_tuple] = direction
if not hop_triples:
visited_entities.update(frontier)
@ -361,7 +381,10 @@ class Query:
else:
label_map[entity] = entity
# Build labeled edges and documents for cross-encoder
# Build labeled edges and documents for cross-encoder.
# The reranker text highlights the NEW information relative
# to the traversal direction: arriving from S means p,o are
# new; from O means s,p are new; from P means s,o are new.
labeled_hop = []
for s, p, o in hop_triples:
ls = label_map.get(s, s)
@ -369,10 +392,18 @@ class Query:
lo = label_map.get(o, o)
labeled_hop.append((ls, lp, lo))
documents = [
{"id": str(i), "text": f"{lp} {lo}"}
for i, (ls, lp, lo) in enumerate(labeled_hop)
]
documents = []
for i, (triple_tuple, (ls, lp, lo)) in enumerate(
zip(hop_triples, labeled_hop)
):
direction = hop_directions[triple_tuple]
if direction == self.FROM_S:
text = f"{lp} {lo}"
elif direction == self.FROM_O:
text = f"{ls} {lp}"
else:
text = f"{ls} {lo}"
documents.append({"id": str(i), "text": text})
queries = [
{"id": str(i), "text": c}