mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 06:51:00 +02:00
Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after vector retrieval, over-fetch a wider candidate pool, rerank with the cross-encoder, and keep the top doc_limit chunks for synthesis.

Per maintainer review, the fetch and select sizes are two caller-controlled limits rather than one internal heuristic:
- doc_limit: chunks selected into the synthesis prompt (unchanged meaning).
- fetch_limit: candidate pool pulled from the vector store before reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are raised to it. Lets the caller control how hard the reranker has to work.

Details:
- schema: DocumentRagQuery.fetch_limit (additive, backward compatible).
- document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors doc_limit); the core applies the heuristic default and derives synthesis provenance from the chunk-selection focus when reranking ran.
- provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection).
- request translator + client SDKs + CLI: fetch-limit / --fetch-limit, threaded exactly like doc_limit and the GraphRAG limits.
- tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring.

When no reranker role is wired, reranking is skipped, fetch_limit is not applied, and behaviour is byte-identical to before. Requires the companion trustgraph-templates change wiring the reranker topics into the document-rag flow (mirrors #279 for GraphRAG).
This commit is contained in:
parent
f18d48dc39
commit
6c9a545a06
18 changed files with 853 additions and 26 deletions
|
|
@ -9,10 +9,12 @@ from trustgraph.provenance import (
|
|||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_focus_uri,
|
||||
docrag_synthesis_uri,
|
||||
docrag_question_triples,
|
||||
grounding_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_chunk_selection_triples,
|
||||
docrag_synthesis_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL,
|
||||
|
|
@ -21,19 +23,25 @@ from trustgraph.provenance import (
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# When the caller does not specify a fetch_limit, reranking over-fetches this
|
||||
# many times the final doc_limit as the candidate pool, so the cross-encoder can
|
||||
# recover relevant chunks the bi-encoder ranked just outside the top doc_limit.
|
||||
# This is only the fallback default: an explicit fetch_limit overrides it.
|
||||
OVERFETCH_FACTOR = 3
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
class Query:
|
||||
|
||||
def __init__(
|
||||
self, rag, workspace, collection, verbose,
|
||||
doc_limit=20, track_usage=None,
|
||||
fetch_limit=20, track_usage=None,
|
||||
):
|
||||
self.rag = rag
|
||||
self.workspace = workspace
|
||||
self.collection = collection
|
||||
self.verbose = verbose
|
||||
self.doc_limit = doc_limit
|
||||
self.fetch_limit = fetch_limit
|
||||
self.track_usage = track_usage
|
||||
|
||||
async def extract_concepts(self, query):
|
||||
|
|
@ -91,7 +99,7 @@ class Query:
|
|||
|
||||
# Query chunk matches for each concept concurrently
|
||||
per_concept_limit = max(
|
||||
1, self.doc_limit // len(vectors)
|
||||
1, self.fetch_limit // len(vectors)
|
||||
)
|
||||
|
||||
async def query_concept(vec):
|
||||
|
|
@ -140,6 +148,7 @@ class DocumentRag:
|
|||
def __init__(
|
||||
self, prompt_client, embeddings_client, doc_embeddings_client,
|
||||
fetch_chunk,
|
||||
reranker_client=None,
|
||||
verbose=False,
|
||||
):
|
||||
|
||||
|
|
@ -150,12 +159,16 @@ class DocumentRag:
|
|||
self.doc_embeddings_client = doc_embeddings_client
|
||||
self.fetch_chunk = fetch_chunk
|
||||
|
||||
# Optional cross-encoder reranker. When None, the retrieval path is
|
||||
# byte-identical to the pre-reranker behaviour.
|
||||
self.reranker_client = reranker_client
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("DocumentRag initialized")
|
||||
|
||||
async def query(
|
||||
self, query, workspace="default", collection="default",
|
||||
doc_limit=20, streaming=False, chunk_callback=None,
|
||||
doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None,
|
||||
explain_callback=None, save_answer_callback=None,
|
||||
):
|
||||
"""
|
||||
|
|
@ -165,7 +178,10 @@ class DocumentRag:
|
|||
query: The query string
|
||||
workspace: Workspace for isolation (also scopes chunk lookup)
|
||||
collection: Collection identifier
|
||||
doc_limit: Max chunks to retrieve
|
||||
doc_limit: Chunks selected into the synthesis prompt (after rerank)
|
||||
fetch_limit: Candidate pool fetched from the vector store before
|
||||
reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a
|
||||
reranker is wired, else doc_limit).
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for explainability
|
||||
|
|
@ -197,6 +213,7 @@ class DocumentRag:
|
|||
q_uri = docrag_question_uri(session_id)
|
||||
gnd_uri = docrag_grounding_uri(session_id)
|
||||
exp_uri = docrag_exploration_uri(session_id)
|
||||
foc_uri = docrag_focus_uri(session_id)
|
||||
syn_uri = docrag_synthesis_uri(session_id)
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
|
@ -209,10 +226,21 @@ class DocumentRag:
|
|||
)
|
||||
await explain_callback(q_triples, q_uri)
|
||||
|
||||
# Resolve the candidate-pool size fetched from the vector store. When a
|
||||
# reranker is wired, honour an explicit fetch_limit; if unset, fall back
|
||||
# to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit,
|
||||
# else the rerank could not fill the prompt. Without a reranker, fetch
|
||||
# doc_limit as before (byte-identical behaviour).
|
||||
if self.reranker_client is not None:
|
||||
fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit)
|
||||
fetch_count = max(fl, doc_limit)
|
||||
else:
|
||||
fetch_count = doc_limit
|
||||
|
||||
q = Query(
|
||||
rag=self, workspace=workspace, collection=collection,
|
||||
verbose=self.verbose,
|
||||
doc_limit=doc_limit, track_usage=track_usage,
|
||||
fetch_limit=fetch_count, track_usage=track_usage,
|
||||
)
|
||||
|
||||
# Extract concepts from query (grounding step)
|
||||
|
|
@ -235,6 +263,7 @@ class DocumentRag:
|
|||
docs, chunk_ids = await q.get_docs(concepts)
|
||||
|
||||
# Emit exploration explainability after chunks retrieved
|
||||
# (full candidate set, before any reranking)
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
|
||||
|
|
@ -242,6 +271,45 @@ class DocumentRag:
|
|||
)
|
||||
await explain_callback(exp_triples, exp_uri)
|
||||
|
||||
# Optional cross-encoder reranking pass between retrieval and
|
||||
# synthesis. Mirrors GraphRAG's reranker usage but with a single
|
||||
# query (the question). When no reranker is wired, this block is
|
||||
# skipped entirely and behaviour is byte-identical to before.
|
||||
reranked = False
|
||||
if self.reranker_client is not None and docs:
|
||||
results = await self.reranker_client.rerank(
|
||||
queries=[{"id": "0", "text": query}],
|
||||
documents=[
|
||||
{"id": str(i), "text": d} for i, d in enumerate(docs)
|
||||
],
|
||||
# Narrow the over-fetched candidate pool down to the final
|
||||
# doc_limit requested for synthesis.
|
||||
limit=doc_limit,
|
||||
)
|
||||
|
||||
# results are sorted desc by score and truncated to limit by the
|
||||
# reranker service, so order gives the surviving top-N directly.
|
||||
order = [int(r.document_id) for r in results]
|
||||
docs = [docs[i] for i in order]
|
||||
chunk_ids = [chunk_ids[i] for i in order]
|
||||
reranked = True
|
||||
|
||||
# Emit chunk-selection (focus) explainability: surviving chunks
|
||||
# with their cross-encoder scores, derived from exploration.
|
||||
if explain_callback:
|
||||
selected_chunks_with_scores = [
|
||||
{"chunk_id": chunk_ids[i], "score": r.score}
|
||||
for i, r in enumerate(results)
|
||||
]
|
||||
foc_triples = set_graph(
|
||||
docrag_chunk_selection_triples(
|
||||
foc_uri, exp_uri,
|
||||
selected_chunks_with_scores, session_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(foc_triples, foc_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
logger.debug(f"Documents: {docs}")
|
||||
|
|
@ -291,9 +359,15 @@ class DocumentRag:
|
|||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
synthesis_doc_id = None
|
||||
|
||||
# When reranking ran, synthesis derives from the focus (the
|
||||
# reranked chunks actually fed to the LLM), as GraphRAG always does.
|
||||
# When no reranker is wired, there is no focus stage, so synthesis
|
||||
# derives from exploration (the unchanged no-op lineage) - a
|
||||
# deliberate divergence from GraphRAG's always-on focus.
|
||||
syn_parent = foc_uri if reranked else exp_uri
|
||||
syn_triples = set_graph(
|
||||
docrag_synthesis_triples(
|
||||
syn_uri, exp_uri,
|
||||
syn_uri, syn_parent,
|
||||
document_id=synthesis_doc_id,
|
||||
in_token=synthesis_result.in_token if synthesis_result else None,
|
||||
out_token=synthesis_result.out_token if synthesis_result else None,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from . document_rag import DocumentRag
|
|||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import DocumentEmbeddingsClientSpec
|
||||
from ... base import RerankerClientSpec
|
||||
from ... base import LibrarianSpec
|
||||
|
||||
# Module logger
|
||||
|
|
@ -28,14 +29,21 @@ class Processor(FlowProcessor):
|
|||
|
||||
doc_limit = params.get("doc_limit", 5)
|
||||
|
||||
# Instance-default candidate-pool size fetched before cross-encoder
|
||||
# reranking; the rerank step narrows it back down to doc_limit for the
|
||||
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
|
||||
fetch_limit = params.get("fetch_limit", 0)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"doc_limit": doc_limit,
|
||||
"fetch_limit": fetch_limit,
|
||||
}
|
||||
)
|
||||
|
||||
self.doc_limit = doc_limit
|
||||
self.fetch_limit = fetch_limit
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
|
|
@ -66,6 +74,13 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RerankerClientSpec(
|
||||
request_name = "reranker-request",
|
||||
response_name = "reranker-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
|
|
@ -105,6 +120,7 @@ class Processor(FlowProcessor):
|
|||
doc_embeddings_client = flow("document-embeddings-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
fetch_chunk = fetch_chunk,
|
||||
reranker_client = flow("reranker-request"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
@ -113,6 +129,13 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
doc_limit = self.doc_limit
|
||||
|
||||
# Candidate-pool size: per-request override, else the instance
|
||||
# default; 0 lets the core derive it from doc_limit.
|
||||
if v.fetch_limit:
|
||||
fetch_limit = v.fetch_limit
|
||||
else:
|
||||
fetch_limit = self.fetch_limit
|
||||
|
||||
async def send_explainability(triples, explain_id):
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
|
|
@ -163,6 +186,7 @@ class Processor(FlowProcessor):
|
|||
workspace=flow.workspace,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
streaming=True,
|
||||
chunk_callback=send_chunk,
|
||||
explain_callback=send_explainability,
|
||||
|
|
@ -188,6 +212,7 @@ class Processor(FlowProcessor):
|
|||
workspace=flow.workspace,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
explain_callback=send_explainability,
|
||||
save_answer_callback=save_answer,
|
||||
)
|
||||
|
|
@ -243,6 +268,15 @@ class Processor(FlowProcessor):
|
|||
help=f'Default document fetch limit (default: 10)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--fetch-limit',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Candidate chunks to fetch from the vector store and rerank '
|
||||
'before keeping the top doc-limit for the LLM '
|
||||
'(default: derive from doc-limit)'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue