Add diversity-aware selection after Document-RAG reranking (#1014)

* Add Document-RAG diversity selection helper

* Add optional MMR diversity selection after reranking

* Fix Document-RAG diversity test method signatures
This commit is contained in:
YingzuoLiu 2026-07-03 20:35:42 +08:00 committed by GitHub
parent db7fdbc652
commit f04ae5331d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 412 additions and 12 deletions

View file

@@ -476,3 +476,75 @@ class TestRerankActive:
await rag.query(query="What is the return policy?")
assert reranker.calls == []
# ---------------------------------------------------------------------------
# 3. Diversity selection: optional MMR after cross-encoder scoring
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_diversity_mode_scores_full_candidate_pool_before_selecting(self):
    """
    In MMR diversity mode the reranker must be asked to score every fetched
    candidate; trimming to ``doc_limit`` happens only afterwards.

    Checked two ways: the ``limit`` recorded on the stub reranker's first
    call equals ``len(ORDERED_CONTENT)``, and the prompt client ends up
    receiving exactly ``doc_limit`` (2) documents.
    """
    mock_clients = build_mock_clients()

    # Three stubbed cross-encoder results with ids "0".."2", best score first.
    stub_scores = [
        RerankerResult(document_id=str(position), query_id="0", score=score)
        for position, score in enumerate((1.00, 0.95, 0.90))
    ]
    reranker = StubReranker(stub_scores)

    rag = DocumentRag(
        *mock_clients,
        reranker_client=reranker,
        rerank_diversity_mode="mmr",
    )

    await rag.query(query="What is the return policy?", doc_limit=2)

    # The reranker was asked for the whole pool, not just doc_limit entries.
    first_rerank_call = reranker.calls[0]
    assert first_rerank_call["limit"] == len(ORDERED_CONTENT)

    # ...while the prompt only received doc_limit documents.
    prompt_call = rag.prompt_client.document_prompt.call_args
    documents_sent = prompt_call.kwargs["documents"]
    assert len(documents_sent) == 2
@pytest.mark.asyncio
async def test_diversity_mode_selects_less_redundant_context_set(self):
    """
    MMR treats the cross-encoder score as relevance and penalises redundancy,
    so a lower-scored chunk that adds new content can displace a
    higher-scored near-duplicate.

    Chunks A and B are near-identical texts scored 1.00 and 0.95; chunk C is
    unrelated text scored 0.90. With ``rerank_diversity_lambda=0.2`` and
    ``doc_limit=2`` the documents passed to the prompt must be A then C.
    """
    mock_clients = build_mock_clients()
    # Only the chunk fetcher (fourth client) is customised in this test.
    _, _, _, fetch_chunk = mock_clients

    first_duplicate = "apple banana fruit return policy"
    second_duplicate = "apple banana fruit return policy duplicate"
    unrelated_text = "engine motor vehicle warranty"

    text_by_chunk = {
        CHUNK_A: first_duplicate,
        CHUNK_B: second_duplicate,
        CHUNK_C: unrelated_text,
    }

    async def lookup_chunk(chunk_id):
        # Unknown ids raise KeyError, surfacing unexpected fetches.
        return text_by_chunk[chunk_id]

    fetch_chunk.side_effect = lookup_chunk

    # Stubbed cross-encoder results with ids "0".."2", descending scores.
    stub_scores = [
        RerankerResult(document_id=str(position), query_id="0", score=score)
        for position, score in enumerate((1.00, 0.95, 0.90))
    ]
    reranker = StubReranker(stub_scores)

    rag = DocumentRag(
        *mock_clients,
        reranker_client=reranker,
        rerank_diversity_mode="mmr",
        rerank_diversity_lambda=0.2,
    )

    await rag.query(query="What is the return policy?", doc_limit=2)

    prompt_call = rag.prompt_client.document_prompt.call_args
    documents_sent = prompt_call.kwargs["documents"]
    expected_selection = [first_duplicate, unrelated_text]
    assert documents_sent == expected_selection