feat: add cross-encoder reranking to Document-RAG with two-limit control (#878) (#1011)

Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after
vector retrieval, over-fetch a wider candidate pool, rerank with the
cross-encoder, and keep the top doc_limit chunks for synthesis.

Per maintainer review, the fetch and select sizes are two caller-controlled
limits rather than one internal heuristic:

- doc_limit:   chunks selected into the synthesis prompt (unchanged meaning).
- fetch_limit: candidate pool pulled from the vector store before reranking.
  0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are
  raised to it. Lets the caller control how hard the reranker has to work.

Details:
- schema: DocumentRagQuery.fetch_limit (additive, backward compatible).
- document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors
  doc_limit); the core applies the heuristic default and derives synthesis
  provenance from the chunk-selection focus when reranking ran.
- provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection).
- request translator + client SDKs + CLI: fetch-limit / --fetch-limit,
  threaded exactly like doc_limit and the GraphRAG limits.
- tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic
  default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring.

Reranking is skipped byte-identically when no reranker role is wired.
Requires the companion trustgraph-templates change wiring the reranker
topics into the document-rag flow (mirrors #279 for GraphRAG).
This commit is contained in:
Sunny 2026-07-02 02:50:13 -06:00 committed by GitHub
parent f18d48dc39
commit 6c9a545a06
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 853 additions and 26 deletions

View file

@ -9,10 +9,12 @@ from trustgraph.provenance import (
docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri,
docrag_focus_uri,
docrag_synthesis_uri,
docrag_question_triples,
grounding_triples,
docrag_exploration_triples,
docrag_chunk_selection_triples,
docrag_synthesis_triples,
set_graph,
GRAPH_RETRIEVAL,
@ -21,19 +23,25 @@ from trustgraph.provenance import (
# Module logger
logger = logging.getLogger(__name__)
# When the caller does not specify a fetch_limit, reranking over-fetches this
# many times the final doc_limit as the candidate pool, so the cross-encoder can
# recover relevant chunks the bi-encoder ranked just outside the top doc_limit.
# This is only the fallback default: an explicit fetch_limit overrides it.
OVERFETCH_FACTOR = 3
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class Query:
def __init__(
self, rag, workspace, collection, verbose,
doc_limit=20, track_usage=None,
fetch_limit=20, track_usage=None,
):
self.rag = rag
self.workspace = workspace
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
self.fetch_limit = fetch_limit
self.track_usage = track_usage
async def extract_concepts(self, query):
@ -91,7 +99,7 @@ class Query:
# Query chunk matches for each concept concurrently
per_concept_limit = max(
1, self.doc_limit // len(vectors)
1, self.fetch_limit // len(vectors)
)
async def query_concept(vec):
@ -140,6 +148,7 @@ class DocumentRag:
def __init__(
self, prompt_client, embeddings_client, doc_embeddings_client,
fetch_chunk,
reranker_client=None,
verbose=False,
):
@ -150,12 +159,16 @@ class DocumentRag:
self.doc_embeddings_client = doc_embeddings_client
self.fetch_chunk = fetch_chunk
# Optional cross-encoder reranker. When None, the retrieval path is
# byte-identical to the pre-reranker behaviour.
self.reranker_client = reranker_client
if self.verbose:
logger.debug("DocumentRag initialized")
async def query(
self, query, workspace="default", collection="default",
doc_limit=20, streaming=False, chunk_callback=None,
doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None,
explain_callback=None, save_answer_callback=None,
):
"""
@ -165,7 +178,10 @@ class DocumentRag:
query: The query string
workspace: Workspace for isolation (also scopes chunk lookup)
collection: Collection identifier
doc_limit: Max chunks to retrieve
doc_limit: Chunks selected into the synthesis prompt (after rerank)
fetch_limit: Candidate pool fetched from the vector store before
reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a
reranker is wired, else doc_limit).
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for explainability
@ -197,6 +213,7 @@ class DocumentRag:
q_uri = docrag_question_uri(session_id)
gnd_uri = docrag_grounding_uri(session_id)
exp_uri = docrag_exploration_uri(session_id)
foc_uri = docrag_focus_uri(session_id)
syn_uri = docrag_synthesis_uri(session_id)
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
@ -209,10 +226,21 @@ class DocumentRag:
)
await explain_callback(q_triples, q_uri)
# Resolve the candidate-pool size fetched from the vector store. When a
# reranker is wired, honour an explicit fetch_limit; if unset, fall back
# to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit,
# else the rerank could not fill the prompt. Without a reranker, fetch
# doc_limit as before (byte-identical behaviour).
if self.reranker_client is not None:
fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit)
fetch_count = max(fl, doc_limit)
else:
fetch_count = doc_limit
q = Query(
rag=self, workspace=workspace, collection=collection,
verbose=self.verbose,
doc_limit=doc_limit, track_usage=track_usage,
fetch_limit=fetch_count, track_usage=track_usage,
)
# Extract concepts from query (grounding step)
@ -235,6 +263,7 @@ class DocumentRag:
docs, chunk_ids = await q.get_docs(concepts)
# Emit exploration explainability after chunks retrieved
# (full candidate set, before any reranking)
if explain_callback:
exp_triples = set_graph(
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
@ -242,6 +271,45 @@ class DocumentRag:
)
await explain_callback(exp_triples, exp_uri)
# Optional cross-encoder reranking pass between retrieval and
# synthesis. Mirrors GraphRAG's reranker usage but with a single
# query (the question). When no reranker is wired, this block is
# skipped entirely and behaviour is byte-identical to before.
reranked = False
if self.reranker_client is not None and docs:
results = await self.reranker_client.rerank(
queries=[{"id": "0", "text": query}],
documents=[
{"id": str(i), "text": d} for i, d in enumerate(docs)
],
# Narrow the over-fetched candidate pool down to the final
# doc_limit requested for synthesis.
limit=doc_limit,
)
# results are sorted desc by score and truncated to limit by the
# reranker service, so order gives the surviving top-N directly.
order = [int(r.document_id) for r in results]
docs = [docs[i] for i in order]
chunk_ids = [chunk_ids[i] for i in order]
reranked = True
# Emit chunk-selection (focus) explainability: surviving chunks
# with their cross-encoder scores, derived from exploration.
if explain_callback:
selected_chunks_with_scores = [
{"chunk_id": chunk_ids[i], "score": r.score}
for i, r in enumerate(results)
]
foc_triples = set_graph(
docrag_chunk_selection_triples(
foc_uri, exp_uri,
selected_chunks_with_scores, session_id,
),
GRAPH_RETRIEVAL
)
await explain_callback(foc_triples, foc_uri)
if self.verbose:
logger.debug("Invoking LLM...")
logger.debug(f"Documents: {docs}")
@ -291,9 +359,15 @@ class DocumentRag:
logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None
# When reranking ran, synthesis derives from the focus (the
# reranked chunks actually fed to the LLM), as GraphRAG always does.
# When no reranker is wired, there is no focus stage, so synthesis
# derives from exploration (the unchanged no-op lineage) - a
# deliberate divergence from GraphRAG's always-on focus.
syn_parent = foc_uri if reranked else exp_uri
syn_triples = set_graph(
docrag_synthesis_triples(
syn_uri, exp_uri,
syn_uri, syn_parent,
document_id=synthesis_doc_id,
in_token=synthesis_result.in_token if synthesis_result else None,
out_token=synthesis_result.out_token if synthesis_result else None,

View file

@ -13,6 +13,7 @@ from . document_rag import DocumentRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import DocumentEmbeddingsClientSpec
from ... base import RerankerClientSpec
from ... base import LibrarianSpec
# Module logger
@ -28,14 +29,21 @@ class Processor(FlowProcessor):
doc_limit = params.get("doc_limit", 5)
# Instance-default candidate-pool size fetched before cross-encoder
# reranking; the rerank step narrows it back down to doc_limit for the
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
fetch_limit = params.get("fetch_limit", 0)
super(Processor, self).__init__(
**params | {
"id": id,
"doc_limit": doc_limit,
"fetch_limit": fetch_limit,
}
)
self.doc_limit = doc_limit
self.fetch_limit = fetch_limit
self.register_specification(
ConsumerSpec(
@ -66,6 +74,13 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
RerankerClientSpec(
request_name = "reranker-request",
response_name = "reranker-response",
)
)
self.register_specification(
ProducerSpec(
name = "response",
@ -105,6 +120,7 @@ class Processor(FlowProcessor):
doc_embeddings_client = flow("document-embeddings-request"),
prompt_client = flow("prompt-request"),
fetch_chunk = fetch_chunk,
reranker_client = flow("reranker-request"),
verbose=True,
)
@ -113,6 +129,13 @@ class Processor(FlowProcessor):
else:
doc_limit = self.doc_limit
# Candidate-pool size: per-request override, else the instance
# default; 0 lets the core derive it from doc_limit.
if v.fetch_limit:
fetch_limit = v.fetch_limit
else:
fetch_limit = self.fetch_limit
async def send_explainability(triples, explain_id):
await flow("explainability").send(Triples(
metadata=Metadata(
@ -163,6 +186,7 @@ class Processor(FlowProcessor):
workspace=flow.workspace,
collection=v.collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
streaming=True,
chunk_callback=send_chunk,
explain_callback=send_explainability,
@ -188,6 +212,7 @@ class Processor(FlowProcessor):
workspace=flow.workspace,
collection=v.collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
explain_callback=send_explainability,
save_answer_callback=save_answer,
)
@ -243,6 +268,15 @@ class Processor(FlowProcessor):
help=f'Default document fetch limit (default: 10)'
)
parser.add_argument(
'--fetch-limit',
type=int,
default=0,
help='Candidate chunks to fetch from the vector store and rerank '
'before keeping the top doc-limit for the LLM '
'(default: derive from doc-limit)'
)
def run():
Processor.launch(default_ident, __doc__)