trustgraph/trustgraph-base/trustgraph/messaging/translators/retrieval.py
Sunny 6c9a545a06
feat: add cross-encoder reranking to Document-RAG with two-limit control (#878) (#1011)
Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after
vector retrieval, over-fetch a wider candidate pool, rerank with the
cross-encoder, and keep the top doc_limit chunks for synthesis.

Per maintainer review, the fetch and select sizes are two caller-controlled
limits rather than one internal heuristic:

- doc_limit:   chunks selected into the synthesis prompt (unchanged meaning).
- fetch_limit: candidate pool pulled from the vector store before reranking.
  0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are
  raised to it. Lets the caller control how hard the reranker has to work.

Details:
- schema: DocumentRagQuery.fetch_limit (additive, backward compatible).
- document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors
  doc_limit); the core applies the heuristic default and derives synthesis
  provenance from the chunk-selection focus when reranking ran.
- provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection).
- request translator + client SDKs + CLI: fetch-limit / --fetch-limit,
  threaded exactly like doc_limit and the GraphRAG limits.
- tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic
  default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring.

Reranking is skipped byte-identically when no reranker role is wired.
Requires the companion trustgraph-templates change wiring the reranker
topics into the document-rag flow (mirrors #279 for GraphRAG).
2026-07-02 09:50:13 +01:00

184 lines
No EOL
7 KiB
Python

from typing import Dict, Any, Tuple
from ...schema import DocumentRagQuery, DocumentRagResponse, GraphRagQuery, GraphRagResponse
from .base import MessageTranslator
from .primitives import TripleTranslator
class DocumentRagRequestTranslator(MessageTranslator):
    """Convert Document-RAG requests between dict form and DocumentRagQuery.

    The dict form uses hyphenated keys ("doc-limit", "fetch-limit");
    the schema object uses the matching underscore attribute names.
    """

    def decode(self, data: Dict[str, Any]) -> DocumentRagQuery:
        """Build a DocumentRagQuery from a request dict.

        "query" is required; a missing key raises KeyError.  The other
        keys are optional: "collection" falls back to "default",
        "doc-limit" to 20, "fetch-limit" to 0 and "streaming" to False.
        Both limits are passed through int().
        """
        # Fields are read in schema order so the required-key lookup
        # happens before any of the optional integer conversions.
        question = data["query"]
        collection = data.get("collection", "default")
        doc_limit = int(data.get("doc-limit", 20))
        fetch_limit = int(data.get("fetch-limit", 0))
        streaming = data.get("streaming", False)

        return DocumentRagQuery(
            query=question,
            collection=collection,
            doc_limit=doc_limit,
            fetch_limit=fetch_limit,
            streaming=streaming,
        )

    def encode(self, obj: DocumentRagQuery) -> Dict[str, Any]:
        """Render a DocumentRagQuery as a dict with hyphenated keys.

        "streaming" is read with a False fallback in case the object
        does not carry that attribute; every other field is read
        directly from the object.
        """
        encoded: Dict[str, Any] = {}
        encoded["query"] = obj.query
        encoded["collection"] = obj.collection
        encoded["doc-limit"] = obj.doc_limit
        encoded["fetch-limit"] = obj.fetch_limit
        encoded["streaming"] = getattr(obj, "streaming", False)
        return encoded
class DocumentRagResponseTranslator(MessageTranslator):
    """Convert DocumentRagResponse schema objects into plain dicts."""

    def __init__(self):
        # Used to encode each entry of explain_triples.
        self.triple_translator = TripleTranslator()

    def decode(self, data: Dict[str, Any]) -> DocumentRagResponse:
        """Not supported for responses; always raises NotImplementedError."""
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def encode(self, obj: DocumentRagResponse) -> Dict[str, Any]:
        """Render a response object as a dict.

        "end_of_stream" and "end_of_session" are always present
        (defaulting to False).  Every other key is emitted only when the
        corresponding field carries a value, as noted inline.
        """
        payload: Dict[str, Any] = {}

        # message_type tells chunk messages apart from explain messages;
        # omitted when empty.
        kind = getattr(obj, "message_type", "")
        if kind:
            payload["message_type"] = kind

        # Response content (chunk messages); an empty string is kept,
        # only None is dropped.
        answer = obj.response
        if answer is not None:
            payload["response"] = answer

        # Explain messages: identifier, named-graph filter and triples.
        explain_ref = getattr(obj, "explain_id", None)
        if explain_ref:
            payload["explain_id"] = explain_ref

        named_graph = getattr(obj, "explain_graph", None)
        if named_graph is not None:
            payload["explain_graph"] = named_graph

        triples = getattr(obj, "explain_triples", [])
        if triples:
            to_dict = self.triple_translator.encode
            payload["explain_triples"] = [to_dict(item) for item in triples]

        # Completion flags: LLM stream complete, then whole session
        # complete.
        for flag in ("end_of_stream", "end_of_session"):
            payload[flag] = getattr(obj, flag, False)

        # Error details are included only when an error with a non-empty
        # message is attached.
        failure = getattr(obj, "error", None)
        if failure and failure.message:
            payload["error"] = {
                "message": failure.message,
                "type": failure.type,
            }

        # Optional usage fields, copied through when set.
        for field in ("in_token", "out_token", "model"):
            value = getattr(obj, field)
            if value is not None:
                payload[field] = value

        return payload

    def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final)"""
        # The session is finished once end_of_session is set.
        session_done = getattr(obj, "end_of_session", False)
        encoded = self.encode(obj)
        return encoded, session_done
class GraphRagRequestTranslator(MessageTranslator):
    """Convert Graph-RAG requests between dict form and GraphRagQuery.

    The dict form uses hyphenated keys; the schema object uses the
    matching underscore attribute names.
    """

    # (dict key, schema attribute, default) for every integer limit, in
    # the order the fields are decoded and encoded.
    _LIMIT_FIELDS = (
        ("entity-limit", "entity_limit", 50),
        ("triple-limit", "triple_limit", 30),
        ("max-subgraph-size", "max_subgraph_size", 1000),
        ("max-path-length", "max_path_length", 2),
        ("edge-score-limit", "edge_score_limit", 30),
        ("edge-limit", "edge_limit", 25),
    )

    def decode(self, data: Dict[str, Any]) -> GraphRagQuery:
        """Build a GraphRagQuery from a request dict.

        "query" is required; a missing key raises KeyError.
        "collection" falls back to "default" and "streaming" to False.
        Each integer limit falls back to its default from the table
        above and is passed through int().
        """
        fields: Dict[str, Any] = {
            "query": data["query"],
            "collection": data.get("collection", "default"),
        }
        for wire_key, attr, fallback in self._LIMIT_FIELDS:
            fields[attr] = int(data.get(wire_key, fallback))
        fields["streaming"] = data.get("streaming", False)
        return GraphRagQuery(**fields)

    def encode(self, obj: GraphRagQuery) -> Dict[str, Any]:
        """Render a GraphRagQuery as a dict with hyphenated keys.

        "streaming" is read with a False fallback in case the object
        does not carry that attribute; every other field is read
        directly from the object.
        """
        encoded: Dict[str, Any] = {
            "query": obj.query,
            "collection": obj.collection,
        }
        for wire_key, attr, _ in self._LIMIT_FIELDS:
            encoded[wire_key] = getattr(obj, attr)
        encoded["streaming"] = getattr(obj, "streaming", False)
        return encoded
class GraphRagResponseTranslator(MessageTranslator):
    """Convert GraphRagResponse schema objects into plain dicts."""

    def __init__(self):
        # Used to encode each entry of explain_triples.
        self.triple_translator = TripleTranslator()

    def decode(self, data: Dict[str, Any]) -> GraphRagResponse:
        """Not supported for responses; always raises NotImplementedError."""
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def encode(self, obj: GraphRagResponse) -> Dict[str, Any]:
        """Render a response object as a dict.

        The two completion flags are always written (False when the
        object lacks them); the remaining keys appear only when the
        object carries a value for them.
        """
        out: Dict[str, Any] = {}

        # Message type, skipped when empty.
        msg_type = getattr(obj, "message_type", "")
        if msg_type:
            out["message_type"] = msg_type

        # Response content for chunk messages; only None is skipped.
        text = obj.response
        if text is not None:
            out["response"] = text

        # Explain id for explain messages.
        ident = getattr(obj, "explain_id", None)
        if ident:
            out["explain_id"] = ident

        # Named graph filter for explain messages.
        graph = getattr(obj, "explain_graph", None)
        if graph is not None:
            out["explain_graph"] = graph

        # Explain triples, each encoded via the triple translator.
        triples = getattr(obj, "explain_triples", [])
        if triples:
            out["explain_triples"] = list(
                map(self.triple_translator.encode, triples)
            )

        # LLM stream complete / entire session complete.
        out["end_of_stream"] = getattr(obj, "end_of_stream", False)
        out["end_of_session"] = getattr(obj, "end_of_session", False)

        # Error details, only when an error with a non-empty message is
        # attached.
        err = getattr(obj, "error", None)
        if err and err.message:
            out["error"] = {"message": err.message, "type": err.type}

        # Optional usage fields, copied through when set.
        for name in ("in_token", "out_token", "model"):
            value = getattr(obj, name)
            if value is not None:
                out[name] = value

        return out

    def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final)"""
        # The session is finished once end_of_session is set.
        finished = getattr(obj, "end_of_session", False)
        body = self.encode(obj)
        return body, finished