feat: add cross-encoder reranking to Document-RAG with two-limit control (#878) (#1011)

Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after
vector retrieval, over-fetch a wider candidate pool, rerank with the
cross-encoder, and keep the top doc_limit chunks for synthesis.

Per maintainer review, the fetch and select sizes are two caller-controlled
limits rather than one internal heuristic:

- doc_limit:   chunks selected into the synthesis prompt (unchanged meaning).
- fetch_limit: candidate pool pulled from the vector store before reranking.
  0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are
  raised to it. Lets the caller control how hard the reranker has to work.

Details:
- schema: DocumentRagQuery.fetch_limit (additive, backward compatible).
- document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors
  doc_limit); the core applies the heuristic default and derives synthesis
  provenance from the chunk-selection focus when reranking ran.
- provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection).
- request translator + client SDKs + CLI: fetch-limit / --fetch-limit,
  threaded exactly like doc_limit and the GraphRAG limits.
- tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic
  default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring.

Reranking is skipped byte-identically when no reranker role is wired.
Requires the companion trustgraph-templates change wiring the reranker
topics into the document-rag flow (mirrors #279 for GraphRAG).
This commit is contained in:
Sunny 2026-07-02 02:50:13 -06:00 committed by GitHub
parent f18d48dc39
commit 6c9a545a06
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 853 additions and 26 deletions

View file

@ -527,7 +527,8 @@ class AsyncFlowInstance:
return result.get("response", "")
async def document_rag(self, query: str, collection: str,
doc_limit: int = 10, **kwargs: Any) -> str:
doc_limit: int = 10, fetch_limit: int = 0,
**kwargs: Any) -> str:
"""
Execute document-based RAG query (non-streaming).
@ -541,7 +542,9 @@ class AsyncFlowInstance:
Args:
query: User query text
collection: Collection identifier containing documents
doc_limit: Maximum number of document chunks to retrieve (default: 10)
doc_limit: Document chunks selected into the prompt (default: 10)
fetch_limit: Candidate chunks fetched from the vector store before
reranking (default: 0 = derive from doc_limit)
**kwargs: Additional service-specific parameters
Returns:
@ -564,6 +567,7 @@ class AsyncFlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": False
}
request_data.update(kwargs)

View file

@ -379,12 +379,14 @@ class AsyncSocketFlowInstance:
yield chunk.content
async def document_rag(self, query: str, collection: str,
doc_limit: int = 10, streaming: bool = False, **kwargs):
doc_limit: int = 10, fetch_limit: int = 0,
streaming: bool = False, **kwargs):
"""Document RAG with optional streaming"""
request = {
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": streaming
}
request.update(kwargs)

View file

@ -415,7 +415,7 @@ class FlowInstance:
def document_rag(
self, query,collection="default",
doc_limit=10,
doc_limit=10, fetch_limit=0,
):
"""
Execute document-based Retrieval-Augmented Generation (RAG) query.
@ -426,7 +426,9 @@ class FlowInstance:
Args:
query: Natural language query
collection: Collection identifier (default: "default")
doc_limit: Maximum document chunks to retrieve (default: 10)
doc_limit: Document chunks selected into the prompt (default: 10)
fetch_limit: Candidate chunks fetched from the vector store before
reranking (default: 0 = derive from doc_limit)
Returns:
str: Generated response incorporating document context
@ -447,6 +449,7 @@ class FlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
}
result = self.request(

View file

@ -752,6 +752,7 @@ class SocketFlowInstance:
query: str,
collection: str,
doc_limit: int = 10,
fetch_limit: int = 0,
streaming: bool = False,
**kwargs: Any
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
@ -764,6 +765,7 @@ class SocketFlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": streaming
}
request.update(kwargs)
@ -785,6 +787,7 @@ class SocketFlowInstance:
query: str,
collection: str,
doc_limit: int = 10,
fetch_limit: int = 0,
**kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""Execute document-based RAG query with explainability support."""
@ -792,6 +795,7 @@ class SocketFlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": True,
"explainable": True,
}

View file

@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
query=data["query"],
collection=data.get("collection", "default"),
doc_limit=int(data.get("doc-limit", 20)),
fetch_limit=int(data.get("fetch-limit", 0)),
streaming=data.get("streaming", False)
)
@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
"query": obj.query,
"collection": obj.collection,
"doc-limit": obj.doc_limit,
"fetch-limit": obj.fetch_limit,
"streaming": getattr(obj, "streaming", False)
}

View file

@ -64,6 +64,8 @@ from . uris import (
docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri,
docrag_focus_uri,
chunk_selection_uri,
docrag_synthesis_uri,
)
@ -94,6 +96,8 @@ from . namespaces import (
TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Chunk selection entity type
TG_CHUNK_SELECTION,
# Explainability entity types
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION,
@ -132,6 +136,7 @@ from . triples import (
# Query-time provenance triple builders (DocumentRAG)
docrag_question_triples,
docrag_exploration_triples,
docrag_chunk_selection_triples,
docrag_synthesis_triples,
# Utility
set_graph,
@ -196,6 +201,8 @@ __all__ = [
"docrag_question_uri",
"docrag_grounding_uri",
"docrag_exploration_uri",
"docrag_focus_uri",
"chunk_selection_uri",
"docrag_synthesis_uri",
# Namespaces
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
@ -219,6 +226,8 @@ __all__ = [
"TG_EDGE_SELECTION",
# Query-time provenance predicates (DocumentRAG)
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
# Chunk selection entity type
"TG_CHUNK_SELECTION",
# Explainability entity types
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
"TG_ANALYSIS", "TG_CONCLUSION",
@ -254,6 +263,7 @@ __all__ = [
# Query-time provenance triple builders (DocumentRAG)
"docrag_question_triples",
"docrag_exploration_triples",
"docrag_chunk_selection_triples",
"docrag_synthesis_triples",
# Agent provenance triple builders
"agent_session_triples",

View file

@ -76,6 +76,9 @@ TG_EDGE_SELECTION = TG + "EdgeSelection"
TG_CHUNK_COUNT = TG + "chunkCount"
TG_SELECTED_CHUNK = TG + "selectedChunk"
# Chunk selection entity type (cross-encoder reranked chunk in Focus)
TG_CHUNK_SELECTION = TG + "ChunkSelection"
# Extraction provenance entity types
TG_DOCUMENT_TYPE = TG + "Document"
TG_PAGE_TYPE = TG + "Page"

View file

@ -30,6 +30,8 @@ from . namespaces import (
TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Chunk selection entity type
TG_CHUNK_SELECTION,
# Explainability entity types
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
# Unifying types
@ -40,7 +42,10 @@ from . namespaces import (
TG_IN_TOKEN, TG_OUT_TOKEN,
)
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri
from . uris import (
activity_uri, agent_uri, subgraph_uri, edge_selection_uri,
chunk_selection_uri,
)
def set_graph(triples: List[Triple], graph: str) -> List[Triple]:
@ -718,6 +723,75 @@ def docrag_exploration_triples(
return triples
def docrag_chunk_selection_triples(
focus_uri: str,
exploration_uri: str,
selected_chunks_with_scores: List[dict],
session_id: str,
) -> List[Triple]:
"""
Build triples for a document RAG focus entity (chunks selected by the
cross-encoder reranker).
Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity
derived from exploration, with one ChunkSelection sub-entity per surviving
chunk carrying the chunk reference and the reranker score.
Structure:
<focus> a tg:Focus ; prov:wasDerivedFrom <exploration> .
<focus> tg:selectedChunk <chunk_sel_0> .
<chunk_sel_0> a tg:ChunkSelection .
<chunk_sel_0> tg:document <chunk_id> .
<chunk_sel_0> tg:score "0.97" .
Args:
focus_uri: URI of the focus entity (from docrag_focus_uri)
exploration_uri: URI of the parent exploration entity
selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score'
session_id: Session UUID for generating chunk selection URIs
Returns:
List of Triple objects
"""
triples = [
_triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)),
_triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")),
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
]
for idx, chunk_info in enumerate(selected_chunks_with_scores):
chunk_id = chunk_info.get("chunk_id")
if not chunk_id:
continue
chunk_sel_uri = chunk_selection_uri(session_id, idx)
# Link focus to chunk selection entity
triples.append(
_triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri))
)
# Type the chunk selection entity
triples.append(
_triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION))
)
# Reference the actual chunk (in librarian)
triples.append(
_triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id))
)
# Cross-encoder score
score = chunk_info.get("score")
if score is not None:
triples.append(
_triple(chunk_sel_uri, TG_SCORE, _literal(str(score)))
)
return triples
def docrag_synthesis_triples(
synthesis_uri: str,
exploration_uri: str,

View file

@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str:
return f"urn:trustgraph:docrag:{session_id}/exploration"
def docrag_focus_uri(session_id: str) -> str:
"""
Generate URI for a document RAG focus entity (chunks selected by the
cross-encoder reranker).
Args:
session_id: The session UUID.
Returns:
URN in format: urn:trustgraph:docrag:{uuid}/focus
"""
return f"urn:trustgraph:docrag:{session_id}/focus"
def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
"""
Generate URI for a chunk selection item (links a reranked chunk to its
score). Mirrors edge_selection_uri for GraphRAG.
Args:
session_id: The session UUID.
chunk_index: Index of this chunk in the selection (0-based).
Returns:
URN in format: urn:trustgraph:prov:chunk:{uuid}:{index}
"""
return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"
def docrag_synthesis_uri(session_id: str) -> str:
"""
Generate URI for a document RAG synthesis entity (final answer).

View file

@ -30,6 +30,7 @@ from . namespaces import (
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
TG_EDGE_SELECTION, TG_SCORE,
TG_CHUNK_SELECTION,
)
@ -95,6 +96,7 @@ TG_CLASS_LABELS = [
_label_triple(TG_PLAN_TYPE, "Plan"),
_label_triple(TG_STEP_RESULT, "Step Result"),
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
_label_triple(TG_CHUNK_SELECTION, "Chunk Selection"),
]
# TrustGraph predicate labels

View file

@ -40,7 +40,10 @@ class GraphRagResponse:
class DocumentRagQuery:
query: str = ""
collection: str = ""
doc_limit: int = 0
doc_limit: int = 0 # docs selected into the synthesis prompt
fetch_limit: int = 0 # candidate pool fetched from the vector store
# before reranking (0 = derive from doc_limit;
# values below doc_limit are raised to it)
streaming: bool = False
@dataclass