feat: add cross-encoder reranking to Document-RAG with two-limit control (#878) (#1011)

Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after
vector retrieval, over-fetch a wider candidate pool, rerank with the
cross-encoder, and keep the top doc_limit chunks for synthesis.

Per maintainer review, the fetch and select sizes are two caller-controlled
limits rather than one internal heuristic:

- doc_limit:   chunks selected into the synthesis prompt (unchanged meaning).
- fetch_limit: candidate pool pulled from the vector store before reranking.
  0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are
  raised to it. Lets the caller control how hard the reranker has to work.

Details:
- schema: DocumentRagQuery.fetch_limit (additive, backward compatible).
- document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors
  doc_limit); the core applies the heuristic default and derives synthesis
  provenance from the chunk-selection focus when reranking ran.
- provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection).
- request translator + client SDKs + CLI: fetch-limit / --fetch-limit,
  threaded exactly like doc_limit and the GraphRAG limits.
- tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic
  default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring.

Reranking is skipped byte-identically when no reranker role is wired.
Requires the companion trustgraph-templates change wiring the reranker
topics into the document-rag flow (mirrors #279 for GraphRAG).
This commit is contained in:
Sunny 2026-07-02 02:50:13 -06:00 committed by GitHub
parent f18d48dc39
commit 6c9a545a06
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 853 additions and 26 deletions

View file

@ -101,27 +101,27 @@ class TestQuery:
assert query.rag == mock_rag assert query.rag == mock_rag
assert query.collection == "test_collection" assert query.collection == "test_collection"
assert query.verbose is False assert query.verbose is False
assert query.doc_limit == 20 # Default value assert query.fetch_limit == 20 # Default value
def test_query_initialization_with_custom_doc_limit(self): def test_query_initialization_with_custom_fetch_limit(self):
"""Test Query initialization with custom doc_limit""" """Test Query initialization with custom fetch_limit"""
# Create mock DocumentRag # Create mock DocumentRag
mock_rag = MagicMock() mock_rag = MagicMock()
# Initialize Query with custom doc_limit # Initialize Query with custom fetch_limit
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
workspace="test_workspace", workspace="test_workspace",
collection="custom_collection", collection="custom_collection",
verbose=True, verbose=True,
doc_limit=50 fetch_limit=50
) )
# Verify initialization # Verify initialization
assert query.rag == mock_rag assert query.rag == mock_rag
assert query.collection == "custom_collection" assert query.collection == "custom_collection"
assert query.verbose is True assert query.verbose is True
assert query.doc_limit == 50 assert query.fetch_limit == 50
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extract_concepts(self): async def test_extract_concepts(self):
@ -224,7 +224,7 @@ class TestQuery:
workspace="test_workspace", workspace="test_workspace",
collection="test_collection", collection="test_collection",
verbose=False, verbose=False,
doc_limit=15 fetch_limit=15
) )
# Call get_docs with concepts list # Call get_docs with concepts list
@ -377,7 +377,7 @@ class TestQuery:
workspace="test_workspace", workspace="test_workspace",
collection="test_collection", collection="test_collection",
verbose=True, verbose=True,
doc_limit=5 fetch_limit=5
) )
# Call get_docs with concepts # Call get_docs with concepts
@ -615,7 +615,7 @@ class TestQuery:
workspace="test_workspace", workspace="test_workspace",
collection="test_collection", collection="test_collection",
verbose=False, verbose=False,
doc_limit=10 fetch_limit=10
) )
docs, chunk_ids = await query.get_docs(["concept A", "concept B"]) docs, chunk_ids = await query.get_docs(["concept A", "concept B"])

View file

@ -0,0 +1,478 @@
"""
Tests for the optional cross-encoder reranking pass in DocumentRag.query().
Two behaviours are covered:
1. No-op: when no reranker_client is wired (the default), query() must feed
the LLM the exact same chunks, in the same order, that retrieval produced
- byte-identical to the pre-reranker behaviour - and must NOT emit a
chunk-selection provenance event.
2. Rerank: when a reranker_client is wired, the retrieved chunks are reordered
and truncated according to the reranker's results, the LLM receives the
reranked top-N, and a tg:ChunkSelection (focus) provenance event is emitted
carrying the per-surviving-chunk scores and chunk references.
These are pure orchestration tests - the reranker is a stub, so there is no
torch / network dependency.
"""
import pytest
from unittest.mock import AsyncMock
from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.base import PromptResult
from trustgraph.schema import RerankerResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_WAS_DERIVED_FROM,
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_FOCUS, TG_SYNTHESIS,
TG_CHUNK_SELECTION, TG_SELECTED_CHUNK, TG_SCORE, TG_DOCUMENT,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
    """
    Return the first triple whose predicate IRI equals ``predicate``.

    Args:
        triples: Iterable of triple objects exposing ``.s.iri`` and ``.p.iri``.
        predicate: Predicate IRI to match against ``t.p.iri``.
        subject: Optional subject IRI; when None, any subject matches.

    Returns:
        The first matching triple in iteration order, or None if none match.
    """
    for t in triples:
        if t.p.iri == predicate:
            if subject is None or t.s.iri == subject:
                return t
    return None
def find_triples(triples, predicate, subject=None):
    """
    Return every triple whose predicate IRI equals ``predicate``.

    Args:
        triples: Iterable of triple objects exposing ``.s.iri`` and ``.p.iri``.
        predicate: Predicate IRI to match against ``t.p.iri``.
        subject: Optional subject IRI; when None, any subject matches.

    Returns:
        List of matching triples, in their original order (possibly empty).
    """
    return [
        t for t in triples
        if t.p.iri == predicate
        and (subject is None or t.s.iri == subject)
    ]
def has_type(triples, subject, rdf_type):
    """
    Report whether ``subject`` is declared to be of ``rdf_type``.

    True when at least one triple has subject IRI ``subject``, predicate
    ``RDF_TYPE`` and object IRI ``rdf_type``.

    Args:
        triples: Iterable of triple objects exposing ``.s``, ``.p``, ``.o``
            terms with an ``.iri`` attribute.
        subject: Subject IRI to look for.
        rdf_type: Type IRI expected as the object.

    Returns:
        bool: True if such a triple exists, False otherwise (including for
        an empty ``triples``).
    """
    return any(
        t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
        for t in triples
    )
def derived_from(triples, subject):
    """
    Return the IRI that ``subject`` is recorded as derived from.

    Looks up the first ``PROV_WAS_DERIVED_FROM`` triple for ``subject`` via
    ``find_triple``.

    Args:
        triples: Iterable of triple objects.
        subject: Subject IRI whose parent is wanted.

    Returns:
        The object IRI of the matching triple, or None if there is none.
    """
    t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    return t.o.iri if t else None
@dataclass
class ChunkMatch:
    """Mimics the result from doc_embeddings_client.query()."""

    # Identifier of the matched chunk; the tests later pass it to fetch_chunk.
    chunk_id: str
# ---------------------------------------------------------------------------
# Fixtures: three retrievable chunks
# ---------------------------------------------------------------------------
# Chunk identifiers handed back by the mocked doc-embeddings query.
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
CHUNK_C = "urn:chunk:policy-doc-1:chunk-2"
# Chunk bodies the mocked fetch_chunk returns for the identifiers above.
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
CHUNK_B_CONTENT = "Our stores are open from 9am to 5pm on weekdays."
CHUNK_C_CONTENT = "Refunds are processed to the original payment method."
# Retrieval (post-dedupe) order is A, B, C.
# The two lists are index-aligned: ORDERED_CONTENT[i] is the body of
# ORDERED_CHUNK_IDS[i].
ORDERED_CONTENT = [CHUNK_A_CONTENT, CHUNK_B_CONTENT, CHUNK_C_CONTENT]
ORDERED_CHUNK_IDS = [CHUNK_A, CHUNK_B, CHUNK_C]
def build_mock_clients():
    """
    Build mock subsidiary clients for a document-rag query returning three
    distinct chunks (A, B, C) in that order.

    Returns:
        Tuple ``(prompt_client, embeddings_client, doc_embeddings_client,
        fetch_chunk)`` of AsyncMock objects, in the order the tests unpack
        them into ``DocumentRag(*clients)``.
    """
    prompt_client = AsyncMock()
    embeddings_client = AsyncMock()
    doc_embeddings_client = AsyncMock()
    fetch_chunk = AsyncMock()

    async def mock_prompt(template_id, variables=None, **kwargs):
        # Concept extraction yields two concepts, one per line; any other
        # template gets an empty text response.
        if template_id == "extract-concepts":
            return PromptResult(response_type="text", text="return policy\nrefund")
        return PromptResult(response_type="text", text="")

    prompt_client.prompt.side_effect = mock_prompt

    # Two vectors, matching the two extracted concepts.
    embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]

    # Each concept query returns the same three chunks; dedupe keeps A, B, C.
    doc_embeddings_client.query.return_value = [
        ChunkMatch(chunk_id=CHUNK_A),
        ChunkMatch(chunk_id=CHUNK_B),
        ChunkMatch(chunk_id=CHUNK_C),
    ]

    # Built once rather than on every fetch; an unknown chunk_id still
    # raises KeyError exactly as before.
    chunk_content = {
        CHUNK_A: CHUNK_A_CONTENT,
        CHUNK_B: CHUNK_B_CONTENT,
        CHUNK_C: CHUNK_C_CONTENT,
    }

    async def mock_fetch(chunk_id):
        return chunk_content[chunk_id]

    fetch_chunk.side_effect = mock_fetch

    # Canned synthesis answer returned by the document prompt.
    prompt_client.document_prompt.return_value = PromptResult(
        response_type="text",
        text="Items can be returned within 30 days for a full refund.",
    )

    return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
class StubReranker:
    """
    Stub reranker_client mirroring RerankerClient.rerank(): returns a fixed,
    pre-sorted, truncated list of RerankerResult - exactly the contract the
    flashrank service guarantees (sorted desc by score, truncated to limit).

    The stub does no sorting or truncation itself: the caller supplies
    results that already satisfy the contract. Every call is recorded in
    ``self.calls`` so tests can assert on the arguments.
    """

    def __init__(self, results):
        # Canned results handed back verbatim by every rerank() call.
        self._results = results
        # One dict per rerank() call: {"queries", "documents", "limit"}.
        self.calls = []

    async def rerank(self, queries, documents, limit=10, timeout=300):
        """
        Record the call and return the canned results.

        ``timeout`` is accepted for signature compatibility but is neither
        recorded nor used.
        """
        self.calls.append(
            {"queries": queries, "documents": documents, "limit": limit}
        )
        return self._results
# ---------------------------------------------------------------------------
# 1. No-op: reranker_client=None must not change anything
# ---------------------------------------------------------------------------
class TestRerankNoOp:
    """With no reranker wired, documents and provenance must be unchanged."""

    @pytest.mark.asyncio
    async def test_documents_passed_to_llm_are_unchanged(self):
        """
        With no reranker wired, document_prompt must receive the retrieved
        chunks in the original order and length.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)  # reranker_client defaults to None

        await rag.query(query="What is the return policy?")

        call = rag.prompt_client.document_prompt.call_args
        passed_docs = call.kwargs["documents"]
        assert passed_docs == ORDERED_CONTENT

    @pytest.mark.asyncio
    async def test_no_chunk_selection_event_emitted(self):
        """
        Without a reranker, the provenance chain is the original 4 stages:
        question, grounding, exploration, synthesis - no focus stage.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        assert len(events) == 4
        types = [
            TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_SYNTHESIS,
        ]
        for i, expected in enumerate(types):
            assert has_type(events[i]["triples"], events[i]["explain_id"], expected)

        # No chunk-selection entity anywhere.
        for e in events:
            assert not any(
                t.o.iri == TG_CHUNK_SELECTION
                for t in e["triples"]
                if t.p.iri == RDF_TYPE
            )

    @pytest.mark.asyncio
    async def test_synthesis_derives_from_exploration_when_no_rerank(self):
        """
        No-op lineage is unchanged: synthesis derives from exploration
        (there is no focus stage). Guards the conditional synthesis parent.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        # events: question, grounding, exploration, synthesis
        exp_uri = events[2]["explain_id"]
        syn_event = events[3]
        assert derived_from(syn_event["triples"], syn_event["explain_id"]) == exp_uri
# ---------------------------------------------------------------------------
# 2. Rerank: reorder + truncate + provenance
# ---------------------------------------------------------------------------
class TestRerankActive:
    """With a reranker wired: reorder, truncate, limits and provenance."""

    def _reranker_keeping_C_then_A(self):
        """Return a stub that ranks chunk C first, then A, and drops B."""
        # Reranker says chunk index 2 (C) is best, then index 0 (A); B dropped.
        # Pre-sorted desc by score and truncated to limit, per the contract.
        return StubReranker([
            RerankerResult(document_id="2", query_id="0", score=0.95),
            RerankerResult(document_id="0", query_id="0", score=0.42),
        ])

    @pytest.mark.asyncio
    async def test_documents_reordered_and_truncated(self):
        """The LLM receives only the surviving chunks, in reranked order."""
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?")

        call = rag.prompt_client.document_prompt.call_args
        passed_docs = call.kwargs["documents"]
        assert passed_docs == [CHUNK_C_CONTENT, CHUNK_A_CONTENT]

    @pytest.mark.asyncio
    async def test_reranker_called_with_single_query_and_all_docs(self):
        """The reranker is called once with one query and every candidate."""
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?", doc_limit=2)

        assert len(reranker.calls) == 1
        c = reranker.calls[0]
        assert c["queries"] == [{"id": "0", "text": "What is the return policy?"}]
        assert c["documents"] == [
            {"id": "0", "text": CHUNK_A_CONTENT},
            {"id": "1", "text": CHUNK_B_CONTENT},
            {"id": "2", "text": CHUNK_C_CONTENT},
        ]
        # The rerank narrows down to the final doc_limit, NOT fetch_limit
        # (fetch_limit is the over-fetched candidate pool size).
        assert c["limit"] == 2

    @pytest.mark.asyncio
    async def test_explicit_fetch_limit_over_fetches_then_narrows(self):
        """
        Semantic guard for the value of reranking AND the maintainer's two-limit
        contract: an explicit fetch_limit makes retrieval OVER-FETCH a wider
        candidate pool so the cross-encoder can surface chunks the bi-encoder
        ranked outside the final doc_limit, then the rerank narrows the pool back
        down to doc_limit. The fetch_limit is honoured directly (caller controls
        how hard the reranker works), not overridden by any heuristic.
        """
        clients = build_mock_clients()
        prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
        reranker = self._reranker_keeping_C_then_A()

        # Candidate pool (fetch_limit=60) >> final doc_limit (6).
        rag = DocumentRag(*clients, reranker_client=reranker)
        await rag.query(
            query="What is the return policy?", doc_limit=6, fetch_limit=60,
        )

        # Over-fetch: the embeddings store is queried with the fetch_limit
        # budget (60 // 2 concept-vectors = 30 per concept), NOT the doc_limit
        # budget (6 // 2 = 3). This is the bug guard.
        q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
        assert q_limit == 30

        # Narrow: the rerank keeps the final doc_limit (6), not fetch_limit.
        assert reranker.calls[0]["limit"] == 6

    @pytest.mark.asyncio
    async def test_default_fetch_limit_derives_overfetch_from_doc_limit(self):
        """
        With no fetch_limit passed to query(), the candidate pool falls back to
        the OVERFETCH_FACTOR x doc_limit heuristic, so over-fetch scales with
        doc_limit and reranking keeps its recall benefit out of the box.
        """
        clients = build_mock_clients()
        prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
        reranker = self._reranker_keeping_C_then_A()

        # No fetch_limit -> heuristic default.
        rag = DocumentRag(*clients, reranker_client=reranker)
        await rag.query(query="What is the return policy?", doc_limit=20)

        # fetch = 3 x 20 = 60 -> 60 // 2 concept-vectors = 30 per concept.
        q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
        assert q_limit == 30

        # Rerank narrows to the final doc_limit (20).
        assert reranker.calls[0]["limit"] == 20

    @pytest.mark.asyncio
    async def test_fetch_limit_floored_at_doc_limit(self):
        """
        A fetch_limit below doc_limit is floored up to doc_limit: retrieval must
        never fetch fewer candidates than the rerank is asked to keep, else the
        prompt could not be filled.
        """
        clients = build_mock_clients()
        prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
        reranker = self._reranker_keeping_C_then_A()

        rag = DocumentRag(*clients, reranker_client=reranker)
        await rag.query(
            query="What is the return policy?", doc_limit=10, fetch_limit=4,
        )

        # fetch = max(4, 10) = 10 -> 10 // 2 concept-vectors = 5 per concept.
        q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
        assert q_limit == 5
        assert reranker.calls[0]["limit"] == 10

    @pytest.mark.asyncio
    async def test_chunk_selection_event_emitted(self):
        """Reranking inserts a focus stage between exploration and synthesis."""
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        # Now 5 stages: question, grounding, exploration, focus, synthesis.
        assert len(events) == 5
        ordered_types = [
            TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
            TG_FOCUS, TG_SYNTHESIS,
        ]
        for i, expected in enumerate(ordered_types):
            assert has_type(events[i]["triples"], events[i]["explain_id"], expected)

    @pytest.mark.asyncio
    async def test_chunk_selection_carries_scores_and_chunk_refs(self):
        """The focus event links one scored ChunkSelection per surviving chunk."""
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        focus_event = events[3]
        foc_uri = focus_event["explain_id"]
        triples = focus_event["triples"]

        # focus is derived from exploration
        exp_uri = events[2]["explain_id"]
        assert derived_from(triples, foc_uri) == exp_uri

        # Two ChunkSelection sub-entities, linked from focus.
        sel_links = find_triples(triples, TG_SELECTED_CHUNK, foc_uri)
        assert len(sel_links) == 2

        # Each selection has a ChunkSelection type, a chunk document ref and a score.
        chunk_refs = set()
        scores = set()
        for link in sel_links:
            sel_uri = link.o.iri
            assert has_type(triples, sel_uri, TG_CHUNK_SELECTION)

            doc_ref = find_triple(triples, TG_DOCUMENT, sel_uri)
            assert doc_ref is not None
            chunk_refs.add(doc_ref.o.iri)

            score_t = find_triple(triples, TG_SCORE, sel_uri)
            assert score_t is not None
            scores.add(score_t.o.value)

        # Surviving chunks are C and A (B dropped), with the reranker scores.
        assert chunk_refs == {CHUNK_C, CHUNK_A}
        assert scores == {"0.95", "0.42"}

    @pytest.mark.asyncio
    async def test_all_focus_triples_in_retrieval_graph(self):
        """Every triple of the focus event is placed in the retrieval graph."""
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        for t in events[3]["triples"]:
            assert t.g == "urn:graph:retrieval"

    @pytest.mark.asyncio
    async def test_synthesis_derives_from_focus_when_reranking(self):
        """
        When reranking runs, synthesis must derive from the focus node (the
        reranked chunks actually fed to the LLM), mirroring GraphRAG - not from
        exploration, which would leave focus as a dangling branch and
        misrepresent what fed the answer.
        """
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            doc_limit=2,
            explain_callback=explain_callback,
        )

        # events: question, grounding, exploration, focus, synthesis
        foc_uri = events[3]["explain_id"]
        syn_event = events[4]
        assert derived_from(syn_event["triples"], syn_event["explain_id"]) == foc_uri

    @pytest.mark.asyncio
    async def test_empty_docs_skips_reranker(self):
        """If retrieval returns no chunks, the reranker is never called."""
        clients = build_mock_clients()
        prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
        doc_embeddings_client.query.return_value = []  # no matches

        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?")

        assert reranker.calls == []

View file

@ -0,0 +1,89 @@
"""
Cross-layer wiring contract for the Document-RAG reranker (issue #878).
The Document-RAG processor registers a ``RerankerClientSpec`` for the
``reranker-request`` / ``reranker-response`` roles (see
``retrieval/document_rag/rag.py``). At flow construction every spec runs
``spec.add(flow, processor, definition)``, and ``RequestResponseSpec.add``
resolves its topics via ``definition["topics"][name]`` - which raises
``KeyError`` if the flow blueprint does not provide those topics.
This means the monorepo code change is only safe to deploy together with the
companion ``trustgraph-templates`` change that wires ``reranker-request`` /
``reranker-response`` into the Document-RAG flow (mirroring what templates
PR #279 did for GraphRAG via ``graph-store.jsonnet``). These tests pin that
contract from the monorepo side:
* with the reranker topics present (as the updated templates compile them),
the spec binds cleanly and registers the client;
* without them (the pre-companion blueprint), construction fails fast with a
KeyError naming the missing role - documenting exactly why the templates
change is required.
No broker/network: the pub/sub backend is mocked (topics are bound at add()
time, connections happen later at start()).
"""
import pytest
from unittest.mock import MagicMock
from trustgraph.base import RerankerClientSpec
def _flow():
f = MagicMock()
f.workspace = "ws"
f.name = "document-rag"
f.id = "proc1"
f.consumer = {}
return f
def _processor():
p = MagicMock()
p.pubsub = MagicMock()
p.id = "proc1"
p.taskgroup = MagicMock()
return p
def _spec():
    """Return a RerankerClientSpec bound to the reranker request/response roles."""
    return RerankerClientSpec(
        request_name="reranker-request",
        response_name="reranker-response",
    )
# Topics dict as the UPDATED document-store.jsonnet compiles them
# (verified by compiling the template: reranker-request -> request:tg:reranker:{workspace}:{id}).
# Carries both the document-rag roles and the two reranker roles that
# _spec() names.
DEFINITION_WITH_RERANKER = {
"topics": {
"request": "request:tg:document-rag:ws:id",
"response": "response:tg:document-rag:ws:id",
"reranker-request": "request:tg:reranker:ws:id",
"reranker-response": "response:tg:reranker:ws:id",
}
}
# Pre-companion blueprint: no reranker topics (document-rag before the templates change).
# Identical to the dict above minus the two reranker roles.
DEFINITION_WITHOUT_RERANKER = {
"topics": {
"request": "request:tg:document-rag:ws:id",
"response": "response:tg:document-rag:ws:id",
}
}
def test_reranker_client_binds_when_flow_provides_topics():
    """With the reranker topics present, the spec binds and registers a client."""
    flow = _flow()
    _spec().add(flow, _processor(), DEFINITION_WITH_RERANKER)

    # The client consumer is registered against the reranker role.
    assert "reranker-request" in flow.consumer
def test_reranker_client_keyerrors_without_companion_template_topics():
    """Without the reranker topics, ``spec.add`` fails with a KeyError naming the role."""
    with pytest.raises(KeyError) as exc:
        _spec().add(_flow(), _processor(), DEFINITION_WITHOUT_RERANKER)

    # Checked after the ``with`` block: a statement placed inside it, after
    # the raising call, would never run.
    # Fails fast naming the missing role -> the trustgraph-templates companion
    # change (wire reranker-request/response into the document-rag flow) is required.
    assert "reranker-request" in str(exc.value)

View file

@ -66,6 +66,7 @@ class TestDocumentRagService:
workspace=ANY, # Workspace comes from flow.workspace (mock) workspace=ANY, # Workspace comes from flow.workspace (mock)
collection="test_coll_1", # Must be from message, not hardcoded default collection="test_coll_1", # Must be from message, not hardcoded default
doc_limit=5, doc_limit=5,
fetch_limit=0, # Unset -> core derives the candidate pool
explain_callback=ANY, # Explainability callback is always passed explain_callback=ANY, # Explainability callback is always passed
save_answer_callback=ANY, # Librarian save callback is always passed save_answer_callback=ANY, # Librarian save callback is always passed
) )

View file

@ -527,7 +527,8 @@ class AsyncFlowInstance:
return result.get("response", "") return result.get("response", "")
async def document_rag(self, query: str, collection: str, async def document_rag(self, query: str, collection: str,
doc_limit: int = 10, **kwargs: Any) -> str: doc_limit: int = 10, fetch_limit: int = 0,
**kwargs: Any) -> str:
""" """
Execute document-based RAG query (non-streaming). Execute document-based RAG query (non-streaming).
@ -541,7 +542,9 @@ class AsyncFlowInstance:
Args: Args:
query: User query text query: User query text
collection: Collection identifier containing documents collection: Collection identifier containing documents
doc_limit: Maximum number of document chunks to retrieve (default: 10) doc_limit: Document chunks selected into the prompt (default: 10)
fetch_limit: Candidate chunks fetched from the vector store before
reranking (default: 0 = derive from doc_limit)
**kwargs: Additional service-specific parameters **kwargs: Additional service-specific parameters
Returns: Returns:
@ -564,6 +567,7 @@ class AsyncFlowInstance:
"query": query, "query": query,
"collection": collection, "collection": collection,
"doc-limit": doc_limit, "doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": False "streaming": False
} }
request_data.update(kwargs) request_data.update(kwargs)

View file

@ -379,12 +379,14 @@ class AsyncSocketFlowInstance:
yield chunk.content yield chunk.content
async def document_rag(self, query: str, collection: str, async def document_rag(self, query: str, collection: str,
doc_limit: int = 10, streaming: bool = False, **kwargs): doc_limit: int = 10, fetch_limit: int = 0,
streaming: bool = False, **kwargs):
"""Document RAG with optional streaming""" """Document RAG with optional streaming"""
request = { request = {
"query": query, "query": query,
"collection": collection, "collection": collection,
"doc-limit": doc_limit, "doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": streaming "streaming": streaming
} }
request.update(kwargs) request.update(kwargs)

View file

@ -415,7 +415,7 @@ class FlowInstance:
def document_rag( def document_rag(
self, query,collection="default", self, query,collection="default",
doc_limit=10, doc_limit=10, fetch_limit=0,
): ):
""" """
Execute document-based Retrieval-Augmented Generation (RAG) query. Execute document-based Retrieval-Augmented Generation (RAG) query.
@ -426,7 +426,9 @@ class FlowInstance:
Args: Args:
query: Natural language query query: Natural language query
collection: Collection identifier (default: "default") collection: Collection identifier (default: "default")
doc_limit: Maximum document chunks to retrieve (default: 10) doc_limit: Document chunks selected into the prompt (default: 10)
fetch_limit: Candidate chunks fetched from the vector store before
reranking (default: 0 = derive from doc_limit)
Returns: Returns:
str: Generated response incorporating document context str: Generated response incorporating document context
@ -447,6 +449,7 @@ class FlowInstance:
"query": query, "query": query,
"collection": collection, "collection": collection,
"doc-limit": doc_limit, "doc-limit": doc_limit,
"fetch-limit": fetch_limit,
} }
result = self.request( result = self.request(

View file

@ -752,6 +752,7 @@ class SocketFlowInstance:
query: str, query: str,
collection: str, collection: str,
doc_limit: int = 10, doc_limit: int = 10,
fetch_limit: int = 0,
streaming: bool = False, streaming: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[TextCompletionResult, Iterator[RAGChunk]]: ) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
@ -764,6 +765,7 @@ class SocketFlowInstance:
"query": query, "query": query,
"collection": collection, "collection": collection,
"doc-limit": doc_limit, "doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": streaming "streaming": streaming
} }
request.update(kwargs) request.update(kwargs)
@ -785,6 +787,7 @@ class SocketFlowInstance:
query: str, query: str,
collection: str, collection: str,
doc_limit: int = 10, doc_limit: int = 10,
fetch_limit: int = 0,
**kwargs: Any **kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""Execute document-based RAG query with explainability support.""" """Execute document-based RAG query with explainability support."""
@ -792,6 +795,7 @@ class SocketFlowInstance:
"query": query, "query": query,
"collection": collection, "collection": collection,
"doc-limit": doc_limit, "doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": True, "streaming": True,
"explainable": True, "explainable": True,
} }

View file

@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
query=data["query"], query=data["query"],
collection=data.get("collection", "default"), collection=data.get("collection", "default"),
doc_limit=int(data.get("doc-limit", 20)), doc_limit=int(data.get("doc-limit", 20)),
fetch_limit=int(data.get("fetch-limit", 0)),
streaming=data.get("streaming", False) streaming=data.get("streaming", False)
) )
@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
"query": obj.query, "query": obj.query,
"collection": obj.collection, "collection": obj.collection,
"doc-limit": obj.doc_limit, "doc-limit": obj.doc_limit,
"fetch-limit": obj.fetch_limit,
"streaming": getattr(obj, "streaming", False) "streaming": getattr(obj, "streaming", False)
} }

View file

@ -64,6 +64,8 @@ from . uris import (
docrag_question_uri, docrag_question_uri,
docrag_grounding_uri, docrag_grounding_uri,
docrag_exploration_uri, docrag_exploration_uri,
docrag_focus_uri,
chunk_selection_uri,
docrag_synthesis_uri, docrag_synthesis_uri,
) )
@ -94,6 +96,8 @@ from . namespaces import (
TG_EDGE_SELECTION, TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG) # Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK, TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Chunk selection entity type
TG_CHUNK_SELECTION,
# Explainability entity types # Explainability entity types
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION, TG_ANALYSIS, TG_CONCLUSION,
@ -132,6 +136,7 @@ from . triples import (
# Query-time provenance triple builders (DocumentRAG) # Query-time provenance triple builders (DocumentRAG)
docrag_question_triples, docrag_question_triples,
docrag_exploration_triples, docrag_exploration_triples,
docrag_chunk_selection_triples,
docrag_synthesis_triples, docrag_synthesis_triples,
# Utility # Utility
set_graph, set_graph,
@ -196,6 +201,8 @@ __all__ = [
"docrag_question_uri", "docrag_question_uri",
"docrag_grounding_uri", "docrag_grounding_uri",
"docrag_exploration_uri", "docrag_exploration_uri",
"docrag_focus_uri",
"chunk_selection_uri",
"docrag_synthesis_uri", "docrag_synthesis_uri",
# Namespaces # Namespaces
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT", "PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
@ -219,6 +226,8 @@ __all__ = [
"TG_EDGE_SELECTION", "TG_EDGE_SELECTION",
# Query-time provenance predicates (DocumentRAG) # Query-time provenance predicates (DocumentRAG)
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK", "TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
# Chunk selection entity type
"TG_CHUNK_SELECTION",
# Explainability entity types # Explainability entity types
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS", "TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
"TG_ANALYSIS", "TG_CONCLUSION", "TG_ANALYSIS", "TG_CONCLUSION",
@ -254,6 +263,7 @@ __all__ = [
# Query-time provenance triple builders (DocumentRAG) # Query-time provenance triple builders (DocumentRAG)
"docrag_question_triples", "docrag_question_triples",
"docrag_exploration_triples", "docrag_exploration_triples",
"docrag_chunk_selection_triples",
"docrag_synthesis_triples", "docrag_synthesis_triples",
# Agent provenance triple builders # Agent provenance triple builders
"agent_session_triples", "agent_session_triples",

View file

@ -76,6 +76,9 @@ TG_EDGE_SELECTION = TG + "EdgeSelection"
TG_CHUNK_COUNT = TG + "chunkCount" TG_CHUNK_COUNT = TG + "chunkCount"
TG_SELECTED_CHUNK = TG + "selectedChunk" TG_SELECTED_CHUNK = TG + "selectedChunk"
# Chunk selection entity type (cross-encoder reranked chunk in Focus)
TG_CHUNK_SELECTION = TG + "ChunkSelection"
# Extraction provenance entity types # Extraction provenance entity types
TG_DOCUMENT_TYPE = TG + "Document" TG_DOCUMENT_TYPE = TG + "Document"
TG_PAGE_TYPE = TG + "Page" TG_PAGE_TYPE = TG + "Page"

View file

@ -30,6 +30,8 @@ from . namespaces import (
TG_EDGE_SELECTION, TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG) # Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK, TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Chunk selection entity type
TG_CHUNK_SELECTION,
# Explainability entity types # Explainability entity types
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
# Unifying types # Unifying types
@ -40,7 +42,10 @@ from . namespaces import (
TG_IN_TOKEN, TG_OUT_TOKEN, TG_IN_TOKEN, TG_OUT_TOKEN,
) )
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri from . uris import (
activity_uri, agent_uri, subgraph_uri, edge_selection_uri,
chunk_selection_uri,
)
def set_graph(triples: List[Triple], graph: str) -> List[Triple]: def set_graph(triples: List[Triple], graph: str) -> List[Triple]:
@ -718,6 +723,75 @@ def docrag_exploration_triples(
return triples return triples
def docrag_chunk_selection_triples(
    focus_uri: str,
    exploration_uri: str,
    selected_chunks_with_scores: List[dict],
    session_id: str,
) -> List[Triple]:
    """
    Build the provenance triples for a document RAG focus entity, i.e. the
    chunks that survived cross-encoder reranking.

    The focus entity is typed as both a PROV entity and a tg:Focus, labelled
    "Chunk Selection", and derived from the exploration entity. Each
    surviving chunk gets its own ChunkSelection sub-entity that points at
    the chunk and, when available, records the reranker score.

    Structure:
        <focus> a tg:Focus ; prov:wasDerivedFrom <exploration> .
        <focus> tg:selectedChunk <chunk_sel_0> .
        <chunk_sel_0> a tg:ChunkSelection .
        <chunk_sel_0> tg:document <chunk_id> .
        <chunk_sel_0> tg:score "0.97" .

    Args:
        focus_uri: URI of the focus entity
        exploration_uri: URI of the parent exploration entity
        selected_chunks_with_scores: List of dicts read via the keys
            'chunk_id' and 'score'; entries with a missing or empty
            'chunk_id' are skipped
        session_id: Session UUID used to build chunk selection URIs

    Returns:
        List of Triple objects
    """

    # Triples describing the focus entity itself
    out = [
        _triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)),
        _triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)),
        _triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")),
        _triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
    ]

    for position, entry in enumerate(selected_chunks_with_scores):

        chunk_ref = entry.get("chunk_id")

        if chunk_ref:

            # The URI index is the entry's position in the input list, so a
            # skipped entry leaves a gap rather than renumbering the rest.
            selection = chunk_selection_uri(session_id, position)

            # Focus -> selection link, selection type, and the reference to
            # the underlying chunk, in that order.
            out.extend([
                _triple(focus_uri, TG_SELECTED_CHUNK, _iri(selection)),
                _triple(selection, RDF_TYPE, _iri(TG_CHUNK_SELECTION)),
                _triple(selection, TG_DOCUMENT, _iri(chunk_ref)),
            ])

            # The score is optional; it is stored as a string literal.
            rerank_score = entry.get("score")
            if rerank_score is not None:
                out.append(
                    _triple(selection, TG_SCORE, _literal(str(rerank_score)))
                )

    return out
def docrag_synthesis_triples( def docrag_synthesis_triples(
synthesis_uri: str, synthesis_uri: str,
exploration_uri: str, exploration_uri: str,

View file

@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str:
return f"urn:trustgraph:docrag:{session_id}/exploration" return f"urn:trustgraph:docrag:{session_id}/exploration"
def docrag_focus_uri(session_id: str) -> str:
    """
    Build the URI identifying the focus entity of a document RAG session
    (the chunks selected by the cross-encoder reranker).

    Args:
        session_id: The session UUID.

    Returns:
        URN in format: urn:trustgraph:docrag:{uuid}/focus
    """
    session_urn = f"urn:trustgraph:docrag:{session_id}"
    return f"{session_urn}/focus"
def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
    """
    Build the URI for one chunk selection item, the entity that ties a
    reranked chunk to its score.

    Args:
        session_id: The session UUID.
        chunk_index: Index of this chunk in the selection (0-based).

    Returns:
        URN in format: urn:trustgraph:prov:chunk:{uuid}:{index}
    """
    chunk_namespace = "urn:trustgraph:prov:chunk"
    return f"{chunk_namespace}:{session_id}:{chunk_index}"
def docrag_synthesis_uri(session_id: str) -> str: def docrag_synthesis_uri(session_id: str) -> str:
""" """
Generate URI for a document RAG synthesis entity (final answer). Generate URI for a document RAG synthesis entity (final answer).

View file

@ -30,6 +30,7 @@ from . namespaces import (
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SUBAGENT_GOAL, TG_PLAN_STEP, TG_SUBAGENT_GOAL, TG_PLAN_STEP,
TG_EDGE_SELECTION, TG_SCORE, TG_EDGE_SELECTION, TG_SCORE,
TG_CHUNK_SELECTION,
) )
@ -95,6 +96,7 @@ TG_CLASS_LABELS = [
_label_triple(TG_PLAN_TYPE, "Plan"), _label_triple(TG_PLAN_TYPE, "Plan"),
_label_triple(TG_STEP_RESULT, "Step Result"), _label_triple(TG_STEP_RESULT, "Step Result"),
_label_triple(TG_EDGE_SELECTION, "Edge Selection"), _label_triple(TG_EDGE_SELECTION, "Edge Selection"),
_label_triple(TG_CHUNK_SELECTION, "Chunk Selection"),
] ]
# TrustGraph predicate labels # TrustGraph predicate labels

View file

@ -40,7 +40,10 @@ class GraphRagResponse:
class DocumentRagQuery: class DocumentRagQuery:
query: str = "" query: str = ""
collection: str = "" collection: str = ""
doc_limit: int = 0 doc_limit: int = 0 # docs selected into the synthesis prompt
fetch_limit: int = 0 # candidate pool fetched from the vector store
# before reranking (0 = derive from doc_limit;
# values below doc_limit are raised to it)
streaming: bool = False streaming: bool = False
@dataclass @dataclass

View file

@ -21,10 +21,12 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
default_collection = 'default' default_collection = 'default'
default_doc_limit = 10 default_doc_limit = 10
default_fetch_limit = 0
def question_explainable( def question_explainable(
url, flow_id, question_text, collection, doc_limit, token=None, debug=False, url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
token=None, debug=False,
workspace="default", workspace="default",
): ):
"""Execute document RAG with explainability - shows provenance events inline.""" """Execute document RAG with explainability - shows provenance events inline."""
@ -39,6 +41,7 @@ def question_explainable(
query=question_text, query=question_text,
collection=collection, collection=collection,
doc_limit=doc_limit, doc_limit=doc_limit,
fetch_limit=fetch_limit,
): ):
if isinstance(item, RAGChunk): if isinstance(item, RAGChunk):
# Print response content # Print response content
@ -97,7 +100,7 @@ def question_explainable(
def question( def question(
url, flow_id, question_text, collection, doc_limit, url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
streaming=True, token=None, explainable=False, debug=False, streaming=True, token=None, explainable=False, debug=False,
show_usage=False, workspace="default", show_usage=False, workspace="default",
): ):
@ -109,6 +112,7 @@ def question(
question_text=question_text, question_text=question_text,
collection=collection, collection=collection,
doc_limit=doc_limit, doc_limit=doc_limit,
fetch_limit=fetch_limit,
token=token, token=token,
debug=debug, debug=debug,
workspace=workspace, workspace=workspace,
@ -128,6 +132,7 @@ def question(
query=question_text, query=question_text,
collection=collection, collection=collection,
doc_limit=doc_limit, doc_limit=doc_limit,
fetch_limit=fetch_limit,
streaming=True streaming=True
) )
@ -155,6 +160,7 @@ def question(
query=question_text, query=question_text,
collection=collection, collection=collection,
doc_limit=doc_limit, doc_limit=doc_limit,
fetch_limit=fetch_limit,
) )
print(result.text) print(result.text)
@ -214,7 +220,15 @@ def main():
'-d', '--doc-limit', '-d', '--doc-limit',
type=int, type=int,
default=default_doc_limit, default=default_doc_limit,
help=f'Document limit (default: {default_doc_limit})' help=f'Documents selected into the prompt (default: {default_doc_limit})'
)
parser.add_argument(
'--fetch-limit',
type=int,
default=default_fetch_limit,
help='Candidate documents fetched from the vector store before '
'reranking (default: derive from doc-limit)'
) )
parser.add_argument( parser.add_argument(
@ -251,6 +265,7 @@ def main():
question_text=args.question, question_text=args.question,
collection=args.collection, collection=args.collection,
doc_limit=args.doc_limit, doc_limit=args.doc_limit,
fetch_limit=args.fetch_limit,
streaming=not args.no_streaming, streaming=not args.no_streaming,
token=args.token, token=args.token,
explainable=args.explainable, explainable=args.explainable,

View file

@ -9,10 +9,12 @@ from trustgraph.provenance import (
docrag_question_uri, docrag_question_uri,
docrag_grounding_uri, docrag_grounding_uri,
docrag_exploration_uri, docrag_exploration_uri,
docrag_focus_uri,
docrag_synthesis_uri, docrag_synthesis_uri,
docrag_question_triples, docrag_question_triples,
grounding_triples, grounding_triples,
docrag_exploration_triples, docrag_exploration_triples,
docrag_chunk_selection_triples,
docrag_synthesis_triples, docrag_synthesis_triples,
set_graph, set_graph,
GRAPH_RETRIEVAL, GRAPH_RETRIEVAL,
@ -21,19 +23,25 @@ from trustgraph.provenance import (
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# When the caller does not specify a fetch_limit, reranking over-fetches this
# many times the final doc_limit as the candidate pool, so the cross-encoder can
# recover relevant chunks the bi-encoder ranked just outside the top doc_limit.
# This is only the fallback default: an explicit fetch_limit overrides it.
OVERFETCH_FACTOR = 3
LABEL="http://www.w3.org/2000/01/rdf-schema#label" LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class Query: class Query:
def __init__( def __init__(
self, rag, workspace, collection, verbose, self, rag, workspace, collection, verbose,
doc_limit=20, track_usage=None, fetch_limit=20, track_usage=None,
): ):
self.rag = rag self.rag = rag
self.workspace = workspace self.workspace = workspace
self.collection = collection self.collection = collection
self.verbose = verbose self.verbose = verbose
self.doc_limit = doc_limit self.fetch_limit = fetch_limit
self.track_usage = track_usage self.track_usage = track_usage
async def extract_concepts(self, query): async def extract_concepts(self, query):
@ -91,7 +99,7 @@ class Query:
# Query chunk matches for each concept concurrently # Query chunk matches for each concept concurrently
per_concept_limit = max( per_concept_limit = max(
1, self.doc_limit // len(vectors) 1, self.fetch_limit // len(vectors)
) )
async def query_concept(vec): async def query_concept(vec):
@ -140,6 +148,7 @@ class DocumentRag:
def __init__( def __init__(
self, prompt_client, embeddings_client, doc_embeddings_client, self, prompt_client, embeddings_client, doc_embeddings_client,
fetch_chunk, fetch_chunk,
reranker_client=None,
verbose=False, verbose=False,
): ):
@ -150,12 +159,16 @@ class DocumentRag:
self.doc_embeddings_client = doc_embeddings_client self.doc_embeddings_client = doc_embeddings_client
self.fetch_chunk = fetch_chunk self.fetch_chunk = fetch_chunk
# Optional cross-encoder reranker. When None, the retrieval path is
# byte-identical to the pre-reranker behaviour.
self.reranker_client = reranker_client
if self.verbose: if self.verbose:
logger.debug("DocumentRag initialized") logger.debug("DocumentRag initialized")
async def query( async def query(
self, query, workspace="default", collection="default", self, query, workspace="default", collection="default",
doc_limit=20, streaming=False, chunk_callback=None, doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None,
explain_callback=None, save_answer_callback=None, explain_callback=None, save_answer_callback=None,
): ):
""" """
@ -165,7 +178,10 @@ class DocumentRag:
query: The query string query: The query string
workspace: Workspace for isolation (also scopes chunk lookup) workspace: Workspace for isolation (also scopes chunk lookup)
collection: Collection identifier collection: Collection identifier
doc_limit: Max chunks to retrieve doc_limit: Chunks selected into the synthesis prompt (after rerank)
fetch_limit: Candidate pool fetched from the vector store before
reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a
reranker is wired, else doc_limit).
streaming: Enable streaming LLM response streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for explainability explain_callback: async def callback(triples, explain_id) for explainability
@ -197,6 +213,7 @@ class DocumentRag:
q_uri = docrag_question_uri(session_id) q_uri = docrag_question_uri(session_id)
gnd_uri = docrag_grounding_uri(session_id) gnd_uri = docrag_grounding_uri(session_id)
exp_uri = docrag_exploration_uri(session_id) exp_uri = docrag_exploration_uri(session_id)
foc_uri = docrag_focus_uri(session_id)
syn_uri = docrag_synthesis_uri(session_id) syn_uri = docrag_synthesis_uri(session_id)
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
@ -209,10 +226,21 @@ class DocumentRag:
) )
await explain_callback(q_triples, q_uri) await explain_callback(q_triples, q_uri)
# Resolve the candidate-pool size fetched from the vector store. When a
# reranker is wired, honour an explicit fetch_limit; if unset, fall back
# to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit,
# else the rerank could not fill the prompt. Without a reranker, fetch
# doc_limit as before (byte-identical behaviour).
if self.reranker_client is not None:
fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit)
fetch_count = max(fl, doc_limit)
else:
fetch_count = doc_limit
q = Query( q = Query(
rag=self, workspace=workspace, collection=collection, rag=self, workspace=workspace, collection=collection,
verbose=self.verbose, verbose=self.verbose,
doc_limit=doc_limit, track_usage=track_usage, fetch_limit=fetch_count, track_usage=track_usage,
) )
# Extract concepts from query (grounding step) # Extract concepts from query (grounding step)
@ -235,6 +263,7 @@ class DocumentRag:
docs, chunk_ids = await q.get_docs(concepts) docs, chunk_ids = await q.get_docs(concepts)
# Emit exploration explainability after chunks retrieved # Emit exploration explainability after chunks retrieved
# (full candidate set, before any reranking)
if explain_callback: if explain_callback:
exp_triples = set_graph( exp_triples = set_graph(
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids), docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
@ -242,6 +271,45 @@ class DocumentRag:
) )
await explain_callback(exp_triples, exp_uri) await explain_callback(exp_triples, exp_uri)
# Optional cross-encoder reranking pass between retrieval and
# synthesis. Mirrors GraphRAG's reranker usage but with a single
# query (the question). When no reranker is wired, this block is
# skipped entirely and behaviour is byte-identical to before.
reranked = False
if self.reranker_client is not None and docs:
results = await self.reranker_client.rerank(
queries=[{"id": "0", "text": query}],
documents=[
{"id": str(i), "text": d} for i, d in enumerate(docs)
],
# Narrow the over-fetched candidate pool down to the final
# doc_limit requested for synthesis.
limit=doc_limit,
)
# results are sorted desc by score and truncated to limit by the
# reranker service, so order gives the surviving top-N directly.
order = [int(r.document_id) for r in results]
docs = [docs[i] for i in order]
chunk_ids = [chunk_ids[i] for i in order]
reranked = True
# Emit chunk-selection (focus) explainability: surviving chunks
# with their cross-encoder scores, derived from exploration.
if explain_callback:
selected_chunks_with_scores = [
{"chunk_id": chunk_ids[i], "score": r.score}
for i, r in enumerate(results)
]
foc_triples = set_graph(
docrag_chunk_selection_triples(
foc_uri, exp_uri,
selected_chunks_with_scores, session_id,
),
GRAPH_RETRIEVAL
)
await explain_callback(foc_triples, foc_uri)
if self.verbose: if self.verbose:
logger.debug("Invoking LLM...") logger.debug("Invoking LLM...")
logger.debug(f"Documents: {docs}") logger.debug(f"Documents: {docs}")
@ -291,9 +359,15 @@ class DocumentRag:
logger.warning(f"Failed to save answer to librarian: {e}") logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None synthesis_doc_id = None
# When reranking ran, synthesis derives from the focus (the
# reranked chunks actually fed to the LLM), as GraphRAG always does.
# When no reranker is wired, there is no focus stage, so synthesis
# derives from exploration (the unchanged no-op lineage) - a
# deliberate divergence from GraphRAG's always-on focus.
syn_parent = foc_uri if reranked else exp_uri
syn_triples = set_graph( syn_triples = set_graph(
docrag_synthesis_triples( docrag_synthesis_triples(
syn_uri, exp_uri, syn_uri, syn_parent,
document_id=synthesis_doc_id, document_id=synthesis_doc_id,
in_token=synthesis_result.in_token if synthesis_result else None, in_token=synthesis_result.in_token if synthesis_result else None,
out_token=synthesis_result.out_token if synthesis_result else None, out_token=synthesis_result.out_token if synthesis_result else None,

View file

@ -13,6 +13,7 @@ from . document_rag import DocumentRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import DocumentEmbeddingsClientSpec from ... base import DocumentEmbeddingsClientSpec
from ... base import RerankerClientSpec
from ... base import LibrarianSpec from ... base import LibrarianSpec
# Module logger # Module logger
@ -28,14 +29,21 @@ class Processor(FlowProcessor):
doc_limit = params.get("doc_limit", 5) doc_limit = params.get("doc_limit", 5)
# Instance-default candidate-pool size fetched before cross-encoder
# reranking; the rerank step narrows it back down to doc_limit for the
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
fetch_limit = params.get("fetch_limit", 0)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"id": id, "id": id,
"doc_limit": doc_limit, "doc_limit": doc_limit,
"fetch_limit": fetch_limit,
} }
) )
self.doc_limit = doc_limit self.doc_limit = doc_limit
self.fetch_limit = fetch_limit
self.register_specification( self.register_specification(
ConsumerSpec( ConsumerSpec(
@ -66,6 +74,13 @@ class Processor(FlowProcessor):
) )
) )
self.register_specification(
RerankerClientSpec(
request_name = "reranker-request",
response_name = "reranker-response",
)
)
self.register_specification( self.register_specification(
ProducerSpec( ProducerSpec(
name = "response", name = "response",
@ -105,6 +120,7 @@ class Processor(FlowProcessor):
doc_embeddings_client = flow("document-embeddings-request"), doc_embeddings_client = flow("document-embeddings-request"),
prompt_client = flow("prompt-request"), prompt_client = flow("prompt-request"),
fetch_chunk = fetch_chunk, fetch_chunk = fetch_chunk,
reranker_client = flow("reranker-request"),
verbose=True, verbose=True,
) )
@ -113,6 +129,13 @@ class Processor(FlowProcessor):
else: else:
doc_limit = self.doc_limit doc_limit = self.doc_limit
# Candidate-pool size: per-request override, else the instance
# default; 0 lets the core derive it from doc_limit.
if v.fetch_limit:
fetch_limit = v.fetch_limit
else:
fetch_limit = self.fetch_limit
async def send_explainability(triples, explain_id): async def send_explainability(triples, explain_id):
await flow("explainability").send(Triples( await flow("explainability").send(Triples(
metadata=Metadata( metadata=Metadata(
@ -163,6 +186,7 @@ class Processor(FlowProcessor):
workspace=flow.workspace, workspace=flow.workspace,
collection=v.collection, collection=v.collection,
doc_limit=doc_limit, doc_limit=doc_limit,
fetch_limit=fetch_limit,
streaming=True, streaming=True,
chunk_callback=send_chunk, chunk_callback=send_chunk,
explain_callback=send_explainability, explain_callback=send_explainability,
@ -188,6 +212,7 @@ class Processor(FlowProcessor):
workspace=flow.workspace, workspace=flow.workspace,
collection=v.collection, collection=v.collection,
doc_limit=doc_limit, doc_limit=doc_limit,
fetch_limit=fetch_limit,
explain_callback=send_explainability, explain_callback=send_explainability,
save_answer_callback=save_answer, save_answer_callback=save_answer,
) )
@ -243,6 +268,15 @@ class Processor(FlowProcessor):
help=f'Default document fetch limit (default: 10)' help=f'Default document fetch limit (default: 10)'
) )
parser.add_argument(
'--fetch-limit',
type=int,
default=0,
help='Candidate chunks to fetch from the vector store and rerank '
'before keeping the top doc-limit for the LLM '
'(default: derive from doc-limit)'
)
def run(): def run():
Processor.launch(default_ident, __doc__) Processor.launch(default_ident, __doc__)