mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-05 19:32:11 +02:00
Add diversity-aware selection after Document-RAG reranking (#1014)
* Add Document-RAG diversity selection helper * Add optional MMR diversity selection after reranking * Fix Document-RAG diversity test method signatures
This commit is contained in:
parent
db7fdbc652
commit
f04ae5331d
5 changed files with 412 additions and 12 deletions
|
|
@ -0,0 +1,114 @@
|
|||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
RERANK_PATH = (
|
||||
REPO_ROOT
|
||||
/ "trustgraph-flow"
|
||||
/ "trustgraph"
|
||||
/ "retrieval"
|
||||
/ "document_rag"
|
||||
/ "rerank.py"
|
||||
)
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"document_rag_diversity_rerank",
|
||||
RERANK_PATH,
|
||||
)
|
||||
rerank = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(rerank)
|
||||
|
||||
RerankCandidate = rerank.RerankCandidate
|
||||
normalize_candidate_scores = rerank.normalize_candidate_scores
|
||||
mmr_select = rerank.mmr_select
|
||||
_pair_diversity_penalty = rerank._pair_diversity_penalty
|
||||
|
||||
|
||||
def candidate(index, chunk_id, text, score):
|
||||
return RerankCandidate(
|
||||
index=index,
|
||||
chunk_id=chunk_id,
|
||||
text=text,
|
||||
reranker_score=score,
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_candidate_scores_min_max_scales_raw_scores():
|
||||
candidates = [
|
||||
candidate(0, "a", "alpha", -2.0),
|
||||
candidate(1, "b", "beta", 0.0),
|
||||
candidate(2, "c", "gamma", 4.0),
|
||||
]
|
||||
|
||||
normalized = normalize_candidate_scores(candidates)
|
||||
|
||||
assert normalized[0].normalized_score == 0.0
|
||||
assert normalized[1].normalized_score == 1.0 / 3.0
|
||||
assert normalized[2].normalized_score == 1.0
|
||||
|
||||
|
||||
def test_normalize_candidate_scores_handles_equal_scores():
|
||||
candidates = [
|
||||
candidate(0, "a", "alpha", 3.0),
|
||||
candidate(1, "b", "beta", 3.0),
|
||||
candidate(2, "c", "gamma", 3.0),
|
||||
]
|
||||
|
||||
normalized = normalize_candidate_scores(candidates)
|
||||
|
||||
assert [c.normalized_score for c in normalized] == [0.5, 0.5, 0.5]
|
||||
|
||||
|
||||
def test_mmr_select_limits_results():
|
||||
candidates = [
|
||||
candidate(0, "a", "alpha policy", 0.9),
|
||||
candidate(1, "b", "beta refund", 0.8),
|
||||
candidate(2, "c", "gamma shipping", 0.7),
|
||||
]
|
||||
|
||||
selected = mmr_select(candidates, limit=2)
|
||||
|
||||
assert len(selected) == 2
|
||||
|
||||
|
||||
def test_mmr_select_prefers_highest_reranker_score_first():
|
||||
candidates = [
|
||||
candidate(0, "a", "weakly relevant text", 0.1),
|
||||
candidate(1, "b", "strongly relevant answer", 10.0),
|
||||
candidate(2, "c", "medium relevant text", 5.0),
|
||||
]
|
||||
|
||||
selected = mmr_select(candidates, limit=1)
|
||||
|
||||
assert selected[0].chunk_id == "b"
|
||||
|
||||
|
||||
def test_mmr_select_penalizes_near_duplicate_chunks():
|
||||
candidates = [
|
||||
candidate(0, "a", "apple banana fruit return policy", 1.00),
|
||||
candidate(1, "b", "apple banana fruit return policy duplicate", 0.95),
|
||||
candidate(2, "c", "engine motor vehicle warranty", 0.90),
|
||||
]
|
||||
|
||||
selected = mmr_select(
|
||||
candidates,
|
||||
limit=2,
|
||||
lambda_mult=0.2,
|
||||
token_overlap_weight=1.0,
|
||||
)
|
||||
|
||||
assert [c.chunk_id for c in selected] == ["a", "c"]
|
||||
|
||||
|
||||
def test_pair_diversity_penalty_is_clamped():
|
||||
left = candidate(0, "a", "same same same", 1.0)
|
||||
right = candidate(1, "b", "same same same", 0.9)
|
||||
|
||||
penalty = _pair_diversity_penalty(
|
||||
left,
|
||||
right,
|
||||
token_overlap_weight=10.0,
|
||||
)
|
||||
|
||||
assert penalty == 1.0
|
||||
|
|
@ -476,3 +476,75 @@ class TestRerankActive:
|
|||
await rag.query(query="What is the return policy?")
|
||||
|
||||
assert reranker.calls == []
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Diversity selection: optional MMR after cross-encoder scoring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diversity_mode_scores_full_candidate_pool_before_selecting(self):
|
||||
"""
|
||||
With diversity selection enabled, the cross-encoder should score the full
|
||||
fetched candidate pool before MMR narrows it down to doc_limit.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
reranker = StubReranker([
|
||||
RerankerResult(document_id="0", query_id="0", score=1.00),
|
||||
RerankerResult(document_id="1", query_id="0", score=0.95),
|
||||
RerankerResult(document_id="2", query_id="0", score=0.90),
|
||||
])
|
||||
rag = DocumentRag(
|
||||
*clients,
|
||||
reranker_client=reranker,
|
||||
rerank_diversity_mode="mmr",
|
||||
)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=2)
|
||||
|
||||
assert reranker.calls[0]["limit"] == len(ORDERED_CONTENT)
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
assert len(passed_docs) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diversity_mode_selects_less_redundant_context_set(self):
|
||||
"""
|
||||
MMR should use cross-encoder scores as relevance while penalizing redundant
|
||||
chunks, so a slightly lower-scored but less redundant chunk can be selected.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
|
||||
duplicate_a = "apple banana fruit return policy"
|
||||
duplicate_b = "apple banana fruit return policy duplicate"
|
||||
diverse_c = "engine motor vehicle warranty"
|
||||
|
||||
async def mock_fetch(chunk_id):
|
||||
return {
|
||||
CHUNK_A: duplicate_a,
|
||||
CHUNK_B: duplicate_b,
|
||||
CHUNK_C: diverse_c,
|
||||
}[chunk_id]
|
||||
|
||||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
reranker = StubReranker([
|
||||
RerankerResult(document_id="0", query_id="0", score=1.00),
|
||||
RerankerResult(document_id="1", query_id="0", score=0.95),
|
||||
RerankerResult(document_id="2", query_id="0", score=0.90),
|
||||
])
|
||||
rag = DocumentRag(
|
||||
*clients,
|
||||
reranker_client=reranker,
|
||||
rerank_diversity_mode="mmr",
|
||||
rerank_diversity_lambda=0.2,
|
||||
)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=2)
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
|
||||
assert passed_docs == [duplicate_a, diverse_c]
|
||||
Loading…
Add table
Add a link
Reference in a new issue