mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 06:51:00 +02:00
Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after vector retrieval, over-fetch a wider candidate pool, rerank with the cross-encoder, and keep the top doc_limit chunks for synthesis. Per maintainer review, the fetch and select sizes are two caller-controlled limits rather than one internal heuristic: - doc_limit: chunks selected into the synthesis prompt (unchanged meaning). - fetch_limit: candidate pool pulled from the vector store before reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are raised to it. Lets the caller control how hard the reranker has to work. Details: - schema: DocumentRagQuery.fetch_limit (additive, backward compatible). - document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors doc_limit); the core applies the heuristic default and derives synthesis provenance from the chunk-selection focus when reranking ran. - provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection). - request translator + client SDKs + CLI: fetch-limit / --fetch-limit, threaded exactly like doc_limit and the GraphRAG limits. - tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring. Reranking is skipped byte-identically when no reranker role is wired. Requires the companion trustgraph-templates change wiring the reranker topics into the document-rag flow (mirrors #279 for GraphRAG).
This commit is contained in:
parent
f18d48dc39
commit
6c9a545a06
18 changed files with 853 additions and 26 deletions
|
|
@ -527,7 +527,8 @@ class AsyncFlowInstance:
|
|||
return result.get("response", "")
|
||||
|
||||
async def document_rag(self, query: str, collection: str,
|
||||
doc_limit: int = 10, **kwargs: Any) -> str:
|
||||
doc_limit: int = 10, fetch_limit: int = 0,
|
||||
**kwargs: Any) -> str:
|
||||
"""
|
||||
Execute document-based RAG query (non-streaming).
|
||||
|
||||
|
|
@ -541,7 +542,9 @@ class AsyncFlowInstance:
|
|||
Args:
|
||||
query: User query text
|
||||
collection: Collection identifier containing documents
|
||||
doc_limit: Maximum number of document chunks to retrieve (default: 10)
|
||||
doc_limit: Document chunks selected into the prompt (default: 10)
|
||||
fetch_limit: Candidate chunks fetched from the vector store before
|
||||
reranking (default: 0 = derive from doc_limit)
|
||||
**kwargs: Additional service-specific parameters
|
||||
|
||||
Returns:
|
||||
|
|
@ -564,6 +567,7 @@ class AsyncFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": False
|
||||
}
|
||||
request_data.update(kwargs)
|
||||
|
|
|
|||
|
|
@ -379,12 +379,14 @@ class AsyncSocketFlowInstance:
|
|||
yield chunk.content
|
||||
|
||||
async def document_rag(self, query: str, collection: str,
|
||||
doc_limit: int = 10, streaming: bool = False, **kwargs):
|
||||
doc_limit: int = 10, fetch_limit: int = 0,
|
||||
streaming: bool = False, **kwargs):
|
||||
"""Document RAG with optional streaming"""
|
||||
request = {
|
||||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
|
|||
|
|
@ -415,7 +415,7 @@ class FlowInstance:
|
|||
|
||||
def document_rag(
|
||||
self, query,collection="default",
|
||||
doc_limit=10,
|
||||
doc_limit=10, fetch_limit=0,
|
||||
):
|
||||
"""
|
||||
Execute document-based Retrieval-Augmented Generation (RAG) query.
|
||||
|
|
@ -426,7 +426,9 @@ class FlowInstance:
|
|||
Args:
|
||||
query: Natural language query
|
||||
collection: Collection identifier (default: "default")
|
||||
doc_limit: Maximum document chunks to retrieve (default: 10)
|
||||
doc_limit: Document chunks selected into the prompt (default: 10)
|
||||
fetch_limit: Candidate chunks fetched from the vector store before
|
||||
reranking (default: 0 = derive from doc_limit)
|
||||
|
||||
Returns:
|
||||
str: Generated response incorporating document context
|
||||
|
|
@ -447,6 +449,7 @@ class FlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
}
|
||||
|
||||
result = self.request(
|
||||
|
|
|
|||
|
|
@ -752,6 +752,7 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
collection: str,
|
||||
doc_limit: int = 10,
|
||||
fetch_limit: int = 0,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
|
|
@ -764,6 +765,7 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
@ -785,6 +787,7 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
collection: str,
|
||||
doc_limit: int = 10,
|
||||
fetch_limit: int = 0,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
||||
"""Execute document-based RAG query with explainability support."""
|
||||
|
|
@ -792,6 +795,7 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": True,
|
||||
"explainable": True,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
|||
query=data["query"],
|
||||
collection=data.get("collection", "default"),
|
||||
doc_limit=int(data.get("doc-limit", 20)),
|
||||
fetch_limit=int(data.get("fetch-limit", 0)),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
|||
"query": obj.query,
|
||||
"collection": obj.collection,
|
||||
"doc-limit": obj.doc_limit,
|
||||
"fetch-limit": obj.fetch_limit,
|
||||
"streaming": getattr(obj, "streaming", False)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,8 @@ from . uris import (
|
|||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_focus_uri,
|
||||
chunk_selection_uri,
|
||||
docrag_synthesis_uri,
|
||||
)
|
||||
|
||||
|
|
@ -94,6 +96,8 @@ from . namespaces import (
|
|||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Chunk selection entity type
|
||||
TG_CHUNK_SELECTION,
|
||||
# Explainability entity types
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION,
|
||||
|
|
@ -132,6 +136,7 @@ from . triples import (
|
|||
# Query-time provenance triple builders (DocumentRAG)
|
||||
docrag_question_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_chunk_selection_triples,
|
||||
docrag_synthesis_triples,
|
||||
# Utility
|
||||
set_graph,
|
||||
|
|
@ -196,6 +201,8 @@ __all__ = [
|
|||
"docrag_question_uri",
|
||||
"docrag_grounding_uri",
|
||||
"docrag_exploration_uri",
|
||||
"docrag_focus_uri",
|
||||
"chunk_selection_uri",
|
||||
"docrag_synthesis_uri",
|
||||
# Namespaces
|
||||
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
|
||||
|
|
@ -219,6 +226,8 @@ __all__ = [
|
|||
"TG_EDGE_SELECTION",
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
|
||||
# Chunk selection entity type
|
||||
"TG_CHUNK_SELECTION",
|
||||
# Explainability entity types
|
||||
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
|
||||
"TG_ANALYSIS", "TG_CONCLUSION",
|
||||
|
|
@ -254,6 +263,7 @@ __all__ = [
|
|||
# Query-time provenance triple builders (DocumentRAG)
|
||||
"docrag_question_triples",
|
||||
"docrag_exploration_triples",
|
||||
"docrag_chunk_selection_triples",
|
||||
"docrag_synthesis_triples",
|
||||
# Agent provenance triple builders
|
||||
"agent_session_triples",
|
||||
|
|
|
|||
|
|
@ -76,6 +76,9 @@ TG_EDGE_SELECTION = TG + "EdgeSelection"
|
|||
TG_CHUNK_COUNT = TG + "chunkCount"
|
||||
TG_SELECTED_CHUNK = TG + "selectedChunk"
|
||||
|
||||
# Chunk selection entity type (cross-encoder reranked chunk in Focus)
|
||||
TG_CHUNK_SELECTION = TG + "ChunkSelection"
|
||||
|
||||
# Extraction provenance entity types
|
||||
TG_DOCUMENT_TYPE = TG + "Document"
|
||||
TG_PAGE_TYPE = TG + "Page"
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ from . namespaces import (
|
|||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Chunk selection entity type
|
||||
TG_CHUNK_SELECTION,
|
||||
# Explainability entity types
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
# Unifying types
|
||||
|
|
@ -40,7 +42,10 @@ from . namespaces import (
|
|||
TG_IN_TOKEN, TG_OUT_TOKEN,
|
||||
)
|
||||
|
||||
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri
|
||||
from . uris import (
|
||||
activity_uri, agent_uri, subgraph_uri, edge_selection_uri,
|
||||
chunk_selection_uri,
|
||||
)
|
||||
|
||||
|
||||
def set_graph(triples: List[Triple], graph: str) -> List[Triple]:
|
||||
|
|
@ -718,6 +723,75 @@ def docrag_exploration_triples(
|
|||
return triples
|
||||
|
||||
|
||||
def docrag_chunk_selection_triples(
|
||||
focus_uri: str,
|
||||
exploration_uri: str,
|
||||
selected_chunks_with_scores: List[dict],
|
||||
session_id: str,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for a document RAG focus entity (chunks selected by the
|
||||
cross-encoder reranker).
|
||||
|
||||
Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity
|
||||
derived from exploration, with one ChunkSelection sub-entity per surviving
|
||||
chunk carrying the chunk reference and the reranker score.
|
||||
|
||||
Structure:
|
||||
<focus> a tg:Focus ; prov:wasDerivedFrom <exploration> .
|
||||
<focus> tg:selectedChunk <chunk_sel_0> .
|
||||
<chunk_sel_0> a tg:ChunkSelection .
|
||||
<chunk_sel_0> tg:document <chunk_id> .
|
||||
<chunk_sel_0> tg:score "0.97" .
|
||||
|
||||
Args:
|
||||
focus_uri: URI of the focus entity (from docrag_focus_uri)
|
||||
exploration_uri: URI of the parent exploration entity
|
||||
selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score'
|
||||
session_id: Session UUID for generating chunk selection URIs
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
"""
|
||||
triples = [
|
||||
_triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)),
|
||||
_triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")),
|
||||
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
|
||||
]
|
||||
|
||||
for idx, chunk_info in enumerate(selected_chunks_with_scores):
|
||||
chunk_id = chunk_info.get("chunk_id")
|
||||
if not chunk_id:
|
||||
continue
|
||||
|
||||
chunk_sel_uri = chunk_selection_uri(session_id, idx)
|
||||
|
||||
# Link focus to chunk selection entity
|
||||
triples.append(
|
||||
_triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri))
|
||||
)
|
||||
|
||||
# Type the chunk selection entity
|
||||
triples.append(
|
||||
_triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION))
|
||||
)
|
||||
|
||||
# Reference the actual chunk (in librarian)
|
||||
triples.append(
|
||||
_triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id))
|
||||
)
|
||||
|
||||
# Cross-encoder score
|
||||
score = chunk_info.get("score")
|
||||
if score is not None:
|
||||
triples.append(
|
||||
_triple(chunk_sel_uri, TG_SCORE, _literal(str(score)))
|
||||
)
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
def docrag_synthesis_triples(
|
||||
synthesis_uri: str,
|
||||
exploration_uri: str,
|
||||
|
|
|
|||
|
|
@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str:
|
|||
return f"urn:trustgraph:docrag:{session_id}/exploration"
|
||||
|
||||
|
||||
def docrag_focus_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for a document RAG focus entity (chunks selected by the
|
||||
cross-encoder reranker).
|
||||
|
||||
Args:
|
||||
session_id: The session UUID.
|
||||
|
||||
Returns:
|
||||
URN in format: urn:trustgraph:docrag:{uuid}/focus
|
||||
"""
|
||||
return f"urn:trustgraph:docrag:{session_id}/focus"
|
||||
|
||||
|
||||
def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
|
||||
"""
|
||||
Generate URI for a chunk selection item (links a reranked chunk to its
|
||||
score). Mirrors edge_selection_uri for GraphRAG.
|
||||
|
||||
Args:
|
||||
session_id: The session UUID.
|
||||
chunk_index: Index of this chunk in the selection (0-based).
|
||||
|
||||
Returns:
|
||||
URN in format: urn:trustgraph:prov:chunk:{uuid}:{index}
|
||||
"""
|
||||
return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"
|
||||
|
||||
|
||||
def docrag_synthesis_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for a document RAG synthesis entity (final answer).
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from . namespaces import (
|
|||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
TG_EDGE_SELECTION, TG_SCORE,
|
||||
TG_CHUNK_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -95,6 +96,7 @@ TG_CLASS_LABELS = [
|
|||
_label_triple(TG_PLAN_TYPE, "Plan"),
|
||||
_label_triple(TG_STEP_RESULT, "Step Result"),
|
||||
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
|
||||
_label_triple(TG_CHUNK_SELECTION, "Chunk Selection"),
|
||||
]
|
||||
|
||||
# TrustGraph predicate labels
|
||||
|
|
|
|||
|
|
@ -40,7 +40,10 @@ class GraphRagResponse:
|
|||
class DocumentRagQuery:
|
||||
query: str = ""
|
||||
collection: str = ""
|
||||
doc_limit: int = 0
|
||||
doc_limit: int = 0 # docs selected into the synthesis prompt
|
||||
fetch_limit: int = 0 # candidate pool fetched from the vector store
|
||||
# before reranking (0 = derive from doc_limit;
|
||||
# values below doc_limit are raised to it)
|
||||
streaming: bool = False
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue