mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 23:11:00 +02:00
Add diversity-aware selection after Document-RAG reranking (#1014)
* Add Document-RAG diversity selection helper * Add optional MMR diversity selection after reranking * Fix Document-RAG diversity test method signatures
This commit is contained in:
parent
db7fdbc652
commit
f04ae5331d
5 changed files with 412 additions and 12 deletions
|
|
@ -0,0 +1,114 @@
|
||||||
|
import importlib.util
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
import sys

# Load rerank.py directly from its path in the repository tree instead of
# going through the package import system.
# NOTE(review): parents[3] is assumed to be the repository root - confirm if
# this test file is ever moved to a different depth.
REPO_ROOT = Path(__file__).resolve().parents[3]
RERANK_PATH = (
    REPO_ROOT
    / "trustgraph-flow"
    / "trustgraph"
    / "retrieval"
    / "document_rag"
    / "rerank.py"
)

spec = importlib.util.spec_from_file_location(
    "document_rag_diversity_rerank",
    RERANK_PATH,
)
if spec is None or spec.loader is None:
    # Fail with a clear message instead of an AttributeError on None below.
    raise ImportError(f"Cannot create an import spec for {RERANK_PATH}")

rerank = importlib.util.module_from_spec(spec)
# Register the module before executing it, as the importlib recipe for
# loading a source file does. Code that looks the module up by name while it
# is being executed (for example dataclass handling of string annotations)
# would otherwise find nothing in sys.modules.
sys.modules[spec.name] = rerank
spec.loader.exec_module(rerank)

# Short aliases for the names under test.
RerankCandidate = rerank.RerankCandidate
normalize_candidate_scores = rerank.normalize_candidate_scores
mmr_select = rerank.mmr_select
_pair_diversity_penalty = rerank._pair_diversity_penalty
|
||||||
|
|
||||||
|
|
||||||
|
def candidate(index, chunk_id, text, score):
    """Build a RerankCandidate for a test; `score` becomes its raw reranker_score."""
    fields = dict(
        index=index,
        chunk_id=chunk_id,
        text=text,
        reranker_score=score,
    )
    return RerankCandidate(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_candidate_scores_min_max_scales_raw_scores():
    """Min-max scaling maps the lowest raw score to 0.0 and the highest to 1.0."""
    import math

    candidates = [
        candidate(0, "a", "alpha", -2.0),
        candidate(1, "b", "beta", 0.0),
        candidate(2, "c", "gamma", 4.0),
    ]

    normalized = normalize_candidate_scores(candidates)

    # The endpoints are exact: (min - min) / range is 0.0 and
    # (max - min) / range is 1.0.
    assert normalized[0].normalized_score == 0.0
    # 1/3 has no exact binary representation, so compare with a tolerance
    # rather than tying the test to one particular arithmetic form.
    assert math.isclose(normalized[1].normalized_score, 1.0 / 3.0)
    assert normalized[2].normalized_score == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_candidate_scores_handles_equal_scores():
    """When every raw score is identical, each candidate normalizes to 0.5."""
    labelled_texts = [("a", "alpha"), ("b", "beta"), ("c", "gamma")]
    tied = [
        candidate(position, chunk_id, text, 3.0)
        for position, (chunk_id, text) in enumerate(labelled_texts)
    ]

    result = normalize_candidate_scores(tied)

    observed = [item.normalized_score for item in result]
    assert observed == [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmr_select_limits_results():
    """mmr_select returns `limit` items when the pool holds more than that."""
    pool = [
        candidate(0, "a", "alpha policy", 0.9),
        candidate(1, "b", "beta refund", 0.8),
        candidate(2, "c", "gamma shipping", 0.7),
    ]

    picked = mmr_select(pool, limit=2)

    assert len(picked) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmr_select_prefers_highest_reranker_score_first():
    """With nothing selected yet, the first pick is the top-scored candidate."""
    pool = [
        candidate(0, "a", "weakly relevant text", 0.1),
        candidate(1, "b", "strongly relevant answer", 10.0),
        candidate(2, "c", "medium relevant text", 5.0),
    ]

    picked = mmr_select(pool, limit=1)

    first = picked[0]
    assert first.chunk_id == "b"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmr_select_penalizes_near_duplicate_chunks():
    """A near-duplicate of the first pick loses to a lower-scored distinct chunk."""
    pool = [
        candidate(0, "a", "apple banana fruit return policy", 1.00),
        candidate(1, "b", "apple banana fruit return policy duplicate", 0.95),
        candidate(2, "c", "engine motor vehicle warranty", 0.90),
    ]

    # A low lambda weights the overlap penalty more heavily than relevance.
    options = dict(limit=2, lambda_mult=0.2, token_overlap_weight=1.0)
    picked = mmr_select(pool, **options)

    picked_ids = [item.chunk_id for item in picked]
    assert picked_ids == ["a", "c"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pair_diversity_penalty_is_clamped():
    """Identical texts with an oversized weight still give a penalty of 1.0."""
    repeated = "same same same"
    first = candidate(0, "a", repeated, 1.0)
    second = candidate(1, "b", repeated, 0.9)

    result = _pair_diversity_penalty(first, second, token_overlap_weight=10.0)

    assert result == 1.0
|
||||||
|
|
@ -476,3 +476,75 @@ class TestRerankActive:
|
||||||
await rag.query(query="What is the return policy?")
|
await rag.query(query="What is the return policy?")
|
||||||
|
|
||||||
assert reranker.calls == []
|
assert reranker.calls == []
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 3. Diversity selection: optional MMR after cross-encoder scoring
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_diversity_mode_scores_full_candidate_pool_before_selecting(self):
    """
    In "mmr" mode the reranker is asked to score every fetched candidate
    (limit equal to len(ORDERED_CONTENT)), and only afterwards is the set
    narrowed to doc_limit documents for the prompt.
    """
    clients = build_mock_clients()
    stub_scores = [1.00, 0.95, 0.90]
    reranker = StubReranker([
        RerankerResult(document_id=str(position), query_id="0", score=score)
        for position, score in enumerate(stub_scores)
    ])
    rag = DocumentRag(
        *clients,
        reranker_client=reranker,
        rerank_diversity_mode="mmr",
    )

    await rag.query(query="What is the return policy?", doc_limit=2)

    first_rerank_call = reranker.calls[0]
    assert first_rerank_call["limit"] == len(ORDERED_CONTENT)

    prompt_kwargs = rag.prompt_client.document_prompt.call_args.kwargs
    assert len(prompt_kwargs["documents"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_diversity_mode_selects_less_redundant_context_set(self):
    """
    MMR treats cross-encoder scores as relevance and penalizes redundant
    chunks, so the near-duplicate second chunk is dropped in favour of a
    slightly lower-scored chunk with different content.
    """
    clients = build_mock_clients()
    _, _, _, fetch_chunk = clients

    duplicate_a = "apple banana fruit return policy"
    duplicate_b = "apple banana fruit return policy duplicate"
    diverse_c = "engine motor vehicle warranty"

    text_by_chunk = {
        CHUNK_A: duplicate_a,
        CHUNK_B: duplicate_b,
        CHUNK_C: diverse_c,
    }

    async def mock_fetch(chunk_id):
        return text_by_chunk[chunk_id]

    fetch_chunk.side_effect = mock_fetch

    stub_scores = [1.00, 0.95, 0.90]
    reranker = StubReranker([
        RerankerResult(document_id=str(position), query_id="0", score=score)
        for position, score in enumerate(stub_scores)
    ])
    rag = DocumentRag(
        *clients,
        reranker_client=reranker,
        rerank_diversity_mode="mmr",
        rerank_diversity_lambda=0.2,
    )

    await rag.query(query="What is the return policy?", doc_limit=2)

    prompt_kwargs = rag.prompt_client.document_prompt.call_args.kwargs
    assert prompt_kwargs["documents"] == [duplicate_a, diverse_c]
|
||||||
|
|
@ -20,6 +20,8 @@ from trustgraph.provenance import (
|
||||||
GRAPH_RETRIEVAL,
|
GRAPH_RETRIEVAL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .rerank import RerankCandidate, mmr_select
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -150,6 +152,8 @@ class DocumentRag:
|
||||||
fetch_chunk,
|
fetch_chunk,
|
||||||
reranker_client=None,
|
reranker_client=None,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
|
rerank_diversity_mode="none",
|
||||||
|
rerank_diversity_lambda=0.7,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
@ -162,6 +166,8 @@ class DocumentRag:
|
||||||
# Optional cross-encoder reranker. When None, the retrieval path is
|
# Optional cross-encoder reranker. When None, the retrieval path is
|
||||||
# byte-identical to the pre-reranker behaviour.
|
# byte-identical to the pre-reranker behaviour.
|
||||||
self.reranker_client = reranker_client
|
self.reranker_client = reranker_client
|
||||||
|
self.rerank_diversity_mode = rerank_diversity_mode
|
||||||
|
self.rerank_diversity_lambda = rerank_diversity_lambda
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("DocumentRag initialized")
|
logger.debug("DocumentRag initialized")
|
||||||
|
|
@ -277,30 +283,74 @@ class DocumentRag:
|
||||||
# skipped entirely and behaviour is byte-identical to before.
|
# skipped entirely and behaviour is byte-identical to before.
|
||||||
reranked = False
|
reranked = False
|
||||||
if self.reranker_client is not None and docs:
|
if self.reranker_client is not None and docs:
|
||||||
|
use_diversity = self.rerank_diversity_mode == "mmr"
|
||||||
|
|
||||||
|
# Without diversity selection, preserve the existing #1011
|
||||||
|
# behavior: ask the reranker for exactly doc_limit results.
|
||||||
|
#
|
||||||
|
# With diversity selection enabled, ask the reranker to score the
|
||||||
|
# full fetched candidate pool first, then let MMR choose the final
|
||||||
|
# doc_limit context set.
|
||||||
|
rerank_limit = len(docs) if use_diversity else doc_limit
|
||||||
|
|
||||||
results = await self.reranker_client.rerank(
|
results = await self.reranker_client.rerank(
|
||||||
queries=[{"id": "0", "text": query}],
|
queries=[{"id": "0", "text": query}],
|
||||||
documents=[
|
documents=[
|
||||||
{"id": str(i), "text": d} for i, d in enumerate(docs)
|
{"id": str(i), "text": d} for i, d in enumerate(docs)
|
||||||
],
|
],
|
||||||
# Narrow the over-fetched candidate pool down to the final
|
limit=rerank_limit,
|
||||||
# doc_limit requested for synthesis.
|
|
||||||
limit=doc_limit,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# results are sorted desc by score and truncated to limit by the
|
source_docs = docs
|
||||||
# reranker service, so order gives the surviving top-N directly.
|
source_chunk_ids = chunk_ids
|
||||||
order = [int(r.document_id) for r in results]
|
|
||||||
docs = [docs[i] for i in order]
|
if use_diversity:
|
||||||
chunk_ids = [chunk_ids[i] for i in order]
|
candidates = [
|
||||||
|
RerankCandidate(
|
||||||
|
index=int(r.document_id),
|
||||||
|
chunk_id=source_chunk_ids[int(r.document_id)],
|
||||||
|
text=source_docs[int(r.document_id)],
|
||||||
|
reranker_score=r.score,
|
||||||
|
)
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
|
selected_candidates = mmr_select(
|
||||||
|
candidates,
|
||||||
|
limit=doc_limit,
|
||||||
|
lambda_mult=self.rerank_diversity_lambda,
|
||||||
|
)
|
||||||
|
|
||||||
|
docs = [candidate.text for candidate in selected_candidates]
|
||||||
|
chunk_ids = [
|
||||||
|
candidate.chunk_id for candidate in selected_candidates
|
||||||
|
]
|
||||||
|
|
||||||
|
selected_chunks_with_scores = [
|
||||||
|
{
|
||||||
|
"chunk_id": candidate.chunk_id,
|
||||||
|
"score": candidate.reranker_score,
|
||||||
|
}
|
||||||
|
for candidate in selected_candidates
|
||||||
|
]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# results are sorted desc by score and truncated to limit by the
|
||||||
|
# reranker service, so order gives the surviving top-N directly.
|
||||||
|
order = [int(r.document_id) for r in results]
|
||||||
|
docs = [source_docs[i] for i in order]
|
||||||
|
chunk_ids = [source_chunk_ids[i] for i in order]
|
||||||
|
|
||||||
|
selected_chunks_with_scores = [
|
||||||
|
{"chunk_id": chunk_ids[i], "score": r.score}
|
||||||
|
for i, r in enumerate(results)
|
||||||
|
]
|
||||||
|
|
||||||
reranked = True
|
reranked = True
|
||||||
|
|
||||||
# Emit chunk-selection (focus) explainability: surviving chunks
|
# Emit chunk-selection (focus) explainability: surviving chunks
|
||||||
# with their cross-encoder scores, derived from exploration.
|
# with their cross-encoder scores, derived from exploration.
|
||||||
if explain_callback:
|
if explain_callback:
|
||||||
selected_chunks_with_scores = [
|
|
||||||
{"chunk_id": chunk_ids[i], "score": r.score}
|
|
||||||
for i, r in enumerate(results)
|
|
||||||
]
|
|
||||||
foc_triples = set_graph(
|
foc_triples = set_graph(
|
||||||
docrag_chunk_selection_triples(
|
docrag_chunk_selection_triples(
|
||||||
foc_uri, exp_uri,
|
foc_uri, exp_uri,
|
||||||
|
|
|
||||||
|
|
@ -33,17 +33,23 @@ class Processor(FlowProcessor):
|
||||||
# reranking; the rerank step narrows it back down to doc_limit for the
|
# reranking; the rerank step narrows it back down to doc_limit for the
|
||||||
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
|
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
|
||||||
fetch_limit = params.get("fetch_limit", 0)
|
fetch_limit = params.get("fetch_limit", 0)
|
||||||
|
rerank_diversity_mode = params.get("rerank_diversity_mode", "none")
|
||||||
|
rerank_diversity_lambda = params.get("rerank_diversity_lambda", 0.7)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"id": id,
|
"id": id,
|
||||||
"doc_limit": doc_limit,
|
"doc_limit": doc_limit,
|
||||||
"fetch_limit": fetch_limit,
|
"fetch_limit": fetch_limit,
|
||||||
|
"rerank_diversity_mode": rerank_diversity_mode,
|
||||||
|
"rerank_diversity_lambda": rerank_diversity_lambda,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.doc_limit = doc_limit
|
self.doc_limit = doc_limit
|
||||||
self.fetch_limit = fetch_limit
|
self.fetch_limit = fetch_limit
|
||||||
|
self.rerank_diversity_mode = rerank_diversity_mode
|
||||||
|
self.rerank_diversity_lambda = rerank_diversity_lambda
|
||||||
|
|
||||||
self.register_specification(
|
self.register_specification(
|
||||||
ConsumerSpec(
|
ConsumerSpec(
|
||||||
|
|
@ -122,6 +128,8 @@ class Processor(FlowProcessor):
|
||||||
fetch_chunk = fetch_chunk,
|
fetch_chunk = fetch_chunk,
|
||||||
reranker_client = flow("reranker-request"),
|
reranker_client = flow("reranker-request"),
|
||||||
verbose=True,
|
verbose=True,
|
||||||
|
rerank_diversity_mode=self.rerank_diversity_mode,
|
||||||
|
rerank_diversity_lambda=self.rerank_diversity_lambda,
|
||||||
)
|
)
|
||||||
|
|
||||||
if v.doc_limit:
|
if v.doc_limit:
|
||||||
|
|
@ -277,6 +285,20 @@ class Processor(FlowProcessor):
|
||||||
'(default: derive from doc-limit)'
|
'(default: derive from doc-limit)'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--rerank-diversity-mode',
|
||||||
|
choices=['none', 'mmr'],
|
||||||
|
default='none',
|
||||||
|
help='Optional diversity-aware selection after reranking (default: none)'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--rerank-diversity-lambda',
|
||||||
|
type=float,
|
||||||
|
default=0.7,
|
||||||
|
help='MMR relevance/diversity tradeoff, higher values prefer relevance'
|
||||||
|
)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
||||||
Processor.launch(default_ident, __doc__)
|
Processor.launch(default_ident, __doc__)
|
||||||
|
|
|
||||||
142
trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py
Normal file
142
trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py
Normal file
|
|
@ -0,0 +1,142 @@
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import List, Sequence, Set
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class RerankCandidate:
    """
    Immutable record for one chunk scored by the cross-encoder reranker.

    The raw reranker_score is kept exactly as the backend returned it and is
    not guaranteed to lie on any particular scale. MMR selection should read
    normalized_score, which stays at 0.0 until a normalized copy is made.
    """

    # Index identifying the candidate within the reranked set.
    index: int
    # Identifier of the chunk the text was taken from.
    chunk_id: str
    # Chunk text; used for the token-overlap diversity measure.
    text: str
    # Raw, backend-specific relevance score.
    reranker_score: float
    # Relevance rescaled for MMR; defaults to 0.0 until normalized.
    normalized_score: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# Word tokenizer for the token-overlap measure. On a str pattern `\w` is
# Unicode-aware, so accented and non-Latin text also yields tokens. An
# ASCII-only character class would split "café" and produce no tokens at all
# for, say, Cyrillic or CJK chunks, leaving them with no overlap penalty.
# For pure ASCII input the matches are the same as with [A-Za-z0-9_]+.
_TOKEN_RE = re.compile(r"\w+")
|
||||||
|
|
||||||
|
|
||||||
|
def _clamp01(value: float) -> float:
    """Limit value to the closed interval [0.0, 1.0]."""
    # Cap at the upper bound first, then lift to the lower bound.
    capped = value if value < 1.0 else 1.0
    return capped if capped > 0.0 else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _token_set(text: str) -> Set[str]:
    """Return the distinct lower-cased tokens of text; falsy text gives an empty set."""
    source = text or ""
    return {match.lower() for match in _TOKEN_RE.findall(source)}
|
||||||
|
|
||||||
|
|
||||||
|
def _jaccard(a: str, b: str) -> float:
    """
    Jaccard similarity of the token sets of a and b.

    Returns |A & B| / |A | B|, or 0.0 when either text has no tokens.
    """
    left, right = _token_set(a), _token_set(b)

    if left and right:
        shared = left & right
        combined = left | right
        return len(shared) / len(combined)

    return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_candidate_scores(
    candidates: Sequence[RerankCandidate],
) -> List[RerankCandidate]:
    """
    Rescale raw reranker scores to [0, 1] relative to this candidate set.

    Backends differ in what they return (probabilities, logits, or
    prompt-defined scores), so no global range is assumed: the smallest score
    in the set maps to 0.0 and the largest to 1.0. When every score is the
    same, each candidate gets 0.5.

    Returns new candidates in the input order with normalized_score filled
    in. An empty input gives an empty list.
    """
    if not candidates:
        return []

    raw = [float(item.reranker_score) for item in candidates]
    lowest, highest = min(raw), max(raw)

    if highest == lowest:
        # No spread to scale by; use the midpoint for everyone.
        scaled = [0.5 for _ in raw]
    else:
        span = highest - lowest
        scaled = [(value - lowest) / span for value in raw]

    return [
        replace(item, normalized_score=score)
        for item, score in zip(candidates, scaled)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _pair_diversity_penalty(
    candidate: RerankCandidate,
    selected: RerankCandidate,
    token_overlap_weight: float,
) -> float:
    """
    Redundancy penalty in [0, 1] for candidate against one selected chunk.

    Only token overlap is used in this first revision, because the
    Document-RAG reranker document_id is currently the candidate index and
    not a source document id.
    """
    overlap = _jaccard(candidate.text, selected.text)
    weighted = token_overlap_weight * overlap
    return _clamp01(weighted)
|
||||||
|
|
||||||
|
|
||||||
|
def mmr_select(
    candidates: Sequence[RerankCandidate],
    limit: int,
    lambda_mult: float = 0.7,
    token_overlap_weight: float = 1.0,
) -> List[RerankCandidate]:
    """
    Greedily pick up to `limit` candidates by maximal marginal relevance.

    Each round scores every remaining candidate as

        lambda * normalized_score - (1 - lambda) * max_overlap_penalty

    where the penalty is the largest token-overlap penalty against any
    chunk already picked (0.0 for the first pick). The best-scoring
    candidate moves to the result; on a tie the earliest one wins.

    lambda_mult is clamped to [0, 1] and a negative token_overlap_weight is
    treated as 0. A non-positive limit returns an empty list. The result is
    in selection order and carries normalized_score.
    """
    if limit <= 0:
        return []

    relevance_weight = _clamp01(lambda_mult)
    overlap_weight = max(0.0, token_overlap_weight)

    pool = normalize_candidate_scores(candidates)
    picked: List[RerankCandidate] = []

    def marginal_score(item):
        # Worst-case redundancy against everything picked so far.
        redundancy = max(
            (
                _pair_diversity_penalty(
                    item,
                    chosen,
                    token_overlap_weight=overlap_weight,
                )
                for chosen in picked
            ),
            default=0.0,
        )
        return (
            relevance_weight * item.normalized_score
            - (1.0 - relevance_weight) * redundancy
        )

    while pool and len(picked) < limit:
        # max() keeps the first of equally good positions.
        best_position = max(
            range(len(pool)),
            key=lambda position: marginal_score(pool[position]),
        )
        picked.append(pool.pop(best_position))

    return picked
|
||||||
Loading…
Add table
Add a link
Reference in a new issue