mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-05 03:12:10 +02:00
Add diversity-aware selection after Document-RAG reranking (#1014)
* Add Document-RAG diversity selection helper * Add optional MMR diversity selection after reranking * Fix Document-RAG diversity test method signatures
This commit is contained in:
parent
db7fdbc652
commit
f04ae5331d
5 changed files with 412 additions and 12 deletions
|
|
@ -476,3 +476,75 @@ class TestRerankActive:
|
|||
await rag.query(query="What is the return policy?")
|
||||
|
||||
assert reranker.calls == []
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Diversity selection: optional MMR after cross-encoder scoring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diversity_mode_scores_full_candidate_pool_before_selecting(self):
|
||||
"""
|
||||
With diversity selection enabled, the cross-encoder should score the full
|
||||
fetched candidate pool before MMR narrows it down to doc_limit.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
reranker = StubReranker([
|
||||
RerankerResult(document_id="0", query_id="0", score=1.00),
|
||||
RerankerResult(document_id="1", query_id="0", score=0.95),
|
||||
RerankerResult(document_id="2", query_id="0", score=0.90),
|
||||
])
|
||||
rag = DocumentRag(
|
||||
*clients,
|
||||
reranker_client=reranker,
|
||||
rerank_diversity_mode="mmr",
|
||||
)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=2)
|
||||
|
||||
assert reranker.calls[0]["limit"] == len(ORDERED_CONTENT)
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
assert len(passed_docs) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diversity_mode_selects_less_redundant_context_set(self):
|
||||
"""
|
||||
MMR should use cross-encoder scores as relevance while penalizing redundant
|
||||
chunks, so a slightly lower-scored but less redundant chunk can be selected.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
|
||||
duplicate_a = "apple banana fruit return policy"
|
||||
duplicate_b = "apple banana fruit return policy duplicate"
|
||||
diverse_c = "engine motor vehicle warranty"
|
||||
|
||||
async def mock_fetch(chunk_id):
|
||||
return {
|
||||
CHUNK_A: duplicate_a,
|
||||
CHUNK_B: duplicate_b,
|
||||
CHUNK_C: diverse_c,
|
||||
}[chunk_id]
|
||||
|
||||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
reranker = StubReranker([
|
||||
RerankerResult(document_id="0", query_id="0", score=1.00),
|
||||
RerankerResult(document_id="1", query_id="0", score=0.95),
|
||||
RerankerResult(document_id="2", query_id="0", score=0.90),
|
||||
])
|
||||
rag = DocumentRag(
|
||||
*clients,
|
||||
reranker_client=reranker,
|
||||
rerank_diversity_mode="mmr",
|
||||
rerank_diversity_lambda=0.2,
|
||||
)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=2)
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
|
||||
assert passed_docs == [duplicate_a, diverse_c]
|
||||
Loading…
Add table
Add a link
Reference in a new issue