mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 06:51:00 +02:00
Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after vector retrieval, over-fetch a wider candidate pool, rerank with the cross-encoder, and keep the top doc_limit chunks for synthesis. Per maintainer review, the fetch and select sizes are two caller-controlled limits rather than one internal heuristic: - doc_limit: chunks selected into the synthesis prompt (unchanged meaning). - fetch_limit: candidate pool pulled from the vector store before reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are raised to it. Lets the caller control how hard the reranker has to work. Details: - schema: DocumentRagQuery.fetch_limit (additive, backward compatible). - document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors doc_limit); the core applies the heuristic default and derives synthesis provenance from the chunk-selection focus when reranking ran. - provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection). - request translator + client SDKs + CLI: fetch-limit / --fetch-limit, threaded exactly like doc_limit and the GraphRAG limits. - tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring. Reranking is skipped byte-identically when no reranker role is wired. Requires the companion trustgraph-templates change wiring the reranker topics into the document-rag flow (mirrors #279 for GraphRAG).
This commit is contained in:
parent
f18d48dc39
commit
6c9a545a06
18 changed files with 853 additions and 26 deletions
|
|
@ -101,27 +101,27 @@ class TestQuery:
|
|||
assert query.rag == mock_rag
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.doc_limit == 20 # Default value
|
||||
assert query.fetch_limit == 20 # Default value
|
||||
|
||||
def test_query_initialization_with_custom_doc_limit(self):
|
||||
"""Test Query initialization with custom doc_limit"""
|
||||
def test_query_initialization_with_custom_fetch_limit(self):
|
||||
"""Test Query initialization with custom fetch_limit"""
|
||||
# Create mock DocumentRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with custom doc_limit
|
||||
# Initialize Query with custom fetch_limit
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
workspace="test_workspace",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
doc_limit=50
|
||||
fetch_limit=50
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.doc_limit == 50
|
||||
assert query.fetch_limit == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_concepts(self):
|
||||
|
|
@ -224,7 +224,7 @@ class TestQuery:
|
|||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=15
|
||||
fetch_limit=15
|
||||
)
|
||||
|
||||
# Call get_docs with concepts list
|
||||
|
|
@ -377,7 +377,7 @@ class TestQuery:
|
|||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=True,
|
||||
doc_limit=5
|
||||
fetch_limit=5
|
||||
)
|
||||
|
||||
# Call get_docs with concepts
|
||||
|
|
@ -615,7 +615,7 @@ class TestQuery:
|
|||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=10
|
||||
fetch_limit=10
|
||||
)
|
||||
|
||||
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])
|
||||
|
|
|
|||
478
tests/unit/test_retrieval/test_document_rag_rerank.py
Normal file
478
tests/unit/test_retrieval/test_document_rag_rerank.py
Normal file
|
|
@ -0,0 +1,478 @@
|
|||
"""
|
||||
Tests for the optional cross-encoder reranking pass in DocumentRag.query().
|
||||
|
||||
Two behaviours are covered:
|
||||
|
||||
1. No-op: when no reranker_client is wired (the default), query() must feed
|
||||
the LLM the exact same chunks, in the same order, that retrieval produced
|
||||
- byte-identical to the pre-reranker behaviour - and must NOT emit a
|
||||
chunk-selection provenance event.
|
||||
|
||||
2. Rerank: when a reranker_client is wired, the retrieved chunks are reordered
|
||||
and truncated according to the reranker's results, the LLM receives the
|
||||
reranked top-N, and a tg:ChunkSelection (focus) provenance event is emitted
|
||||
carrying the per-surviving-chunk scores and chunk references.
|
||||
|
||||
These are pure orchestration tests - the reranker is a stub, so there is no
|
||||
torch / network dependency.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.base import PromptResult
|
||||
from trustgraph.schema import RerankerResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_WAS_DERIVED_FROM,
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_CHUNK_SELECTION, TG_SELECTED_CHUNK, TG_SCORE, TG_DOCUMENT,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate
|
||||
and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
return any(
|
||||
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
|
||||
for t in triples
|
||||
)
|
||||
|
||||
|
||||
def derived_from(triples, subject):
|
||||
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
|
||||
return t.o.iri if t else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkMatch:
|
||||
"""Mimics the result from doc_embeddings_client.query()."""
|
||||
chunk_id: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: three retrievable chunks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
|
||||
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
|
||||
CHUNK_C = "urn:chunk:policy-doc-1:chunk-2"
|
||||
|
||||
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
|
||||
CHUNK_B_CONTENT = "Our stores are open from 9am to 5pm on weekdays."
|
||||
CHUNK_C_CONTENT = "Refunds are processed to the original payment method."
|
||||
|
||||
# Retrieval (post-dedupe) order is A, B, C.
|
||||
ORDERED_CONTENT = [CHUNK_A_CONTENT, CHUNK_B_CONTENT, CHUNK_C_CONTENT]
|
||||
ORDERED_CHUNK_IDS = [CHUNK_A, CHUNK_B, CHUNK_C]
|
||||
|
||||
|
||||
def build_mock_clients():
|
||||
"""
|
||||
Build mock subsidiary clients for a document-rag query returning three
|
||||
distinct chunks (A, B, C) in that order.
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
doc_embeddings_client = AsyncMock()
|
||||
fetch_chunk = AsyncMock()
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="return policy\nrefund")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# Each concept query returns the same three chunks; dedupe keeps A, B, C.
|
||||
doc_embeddings_client.query.return_value = [
|
||||
ChunkMatch(chunk_id=CHUNK_A),
|
||||
ChunkMatch(chunk_id=CHUNK_B),
|
||||
ChunkMatch(chunk_id=CHUNK_C),
|
||||
]
|
||||
|
||||
async def mock_fetch(chunk_id):
|
||||
return {
|
||||
CHUNK_A: CHUNK_A_CONTENT,
|
||||
CHUNK_B: CHUNK_B_CONTENT,
|
||||
CHUNK_C: CHUNK_C_CONTENT,
|
||||
}[chunk_id]
|
||||
|
||||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Items can be returned within 30 days for a full refund.",
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
||||
|
||||
class StubReranker:
|
||||
"""
|
||||
Stub reranker_client mirroring RerankerClient.rerank(): returns a fixed,
|
||||
pre-sorted, truncated list of RerankerResult - exactly the contract the
|
||||
flashrank service guarantees (sorted desc by score, truncated to limit).
|
||||
"""
|
||||
|
||||
def __init__(self, results):
|
||||
self._results = results
|
||||
self.calls = []
|
||||
|
||||
async def rerank(self, queries, documents, limit=10, timeout=300):
|
||||
self.calls.append(
|
||||
{"queries": queries, "documents": documents, "limit": limit}
|
||||
)
|
||||
return self._results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. No-op: reranker_client=None must not change anything
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRerankNoOp:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_documents_passed_to_llm_are_unchanged(self):
|
||||
"""
|
||||
With no reranker wired, document_prompt must receive the retrieved
|
||||
chunks in the original order and length.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients) # reranker_client defaults to None
|
||||
|
||||
await rag.query(query="What is the return policy?")
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
assert passed_docs == ORDERED_CONTENT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_chunk_selection_event_emitted(self):
|
||||
"""
|
||||
Without a reranker, the provenance chain is the original 4 stages:
|
||||
question, grounding, exploration, synthesis - no focus stage.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
assert len(events) == 4
|
||||
types = [
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_SYNTHESIS,
|
||||
]
|
||||
for i, expected in enumerate(types):
|
||||
assert has_type(events[i]["triples"], events[i]["explain_id"], expected)
|
||||
|
||||
# No chunk-selection entity anywhere.
|
||||
for e in events:
|
||||
assert not any(
|
||||
t.o.iri == TG_CHUNK_SELECTION
|
||||
for t in e["triples"]
|
||||
if t.p.iri == RDF_TYPE
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_derives_from_exploration_when_no_rerank(self):
|
||||
"""
|
||||
No-op lineage is unchanged: synthesis derives from exploration
|
||||
(there is no focus stage). Guards the conditional synthesis parent.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
# events: question, grounding, exploration, synthesis
|
||||
exp_uri = events[2]["explain_id"]
|
||||
syn_event = events[3]
|
||||
assert derived_from(syn_event["triples"], syn_event["explain_id"]) == exp_uri
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Rerank: reorder + truncate + provenance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRerankActive:
|
||||
|
||||
def _reranker_keeping_C_then_A(self):
|
||||
# Reranker says chunk index 2 (C) is best, then index 0 (A); B dropped.
|
||||
# Pre-sorted desc by score and truncated to limit, per the contract.
|
||||
return StubReranker([
|
||||
RerankerResult(document_id="2", query_id="0", score=0.95),
|
||||
RerankerResult(document_id="0", query_id="0", score=0.42),
|
||||
])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_documents_reordered_and_truncated(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?")
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
assert passed_docs == [CHUNK_C_CONTENT, CHUNK_A_CONTENT]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reranker_called_with_single_query_and_all_docs(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=2)
|
||||
|
||||
assert len(reranker.calls) == 1
|
||||
c = reranker.calls[0]
|
||||
assert c["queries"] == [{"id": "0", "text": "What is the return policy?"}]
|
||||
assert c["documents"] == [
|
||||
{"id": "0", "text": CHUNK_A_CONTENT},
|
||||
{"id": "1", "text": CHUNK_B_CONTENT},
|
||||
{"id": "2", "text": CHUNK_C_CONTENT},
|
||||
]
|
||||
# The rerank narrows down to the final doc_limit, NOT fetch_limit
|
||||
# (fetch_limit is the over-fetched candidate pool size).
|
||||
assert c["limit"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_fetch_limit_over_fetches_then_narrows(self):
|
||||
"""
|
||||
Semantic guard for the value of reranking AND the maintainer's two-limit
|
||||
contract: an explicit fetch_limit makes retrieval OVER-FETCH a wider
|
||||
candidate pool so the cross-encoder can surface chunks the bi-encoder
|
||||
ranked outside the final doc_limit, then the rerank narrows the pool back
|
||||
down to doc_limit. The fetch_limit is honoured directly (caller controls
|
||||
how hard the reranker works), not overridden by any heuristic.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
# Candidate pool (fetch_limit=60) >> final doc_limit (6).
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?", doc_limit=6, fetch_limit=60,
|
||||
)
|
||||
|
||||
# Over-fetch: the embeddings store is queried with the fetch_limit
|
||||
# budget (60 // 2 concept-vectors = 30 per concept), NOT the doc_limit
|
||||
# budget (6 // 2 = 3). This is the bug guard.
|
||||
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
|
||||
assert q_limit == 30
|
||||
|
||||
# Narrow: the rerank keeps the final doc_limit (6), not fetch_limit.
|
||||
assert reranker.calls[0]["limit"] == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_fetch_limit_derives_overfetch_from_doc_limit(self):
|
||||
"""
|
||||
With no fetch_limit passed to query(), the candidate pool falls back to
|
||||
the OVERFETCH_FACTOR x doc_limit heuristic, so over-fetch scales with
|
||||
doc_limit and reranking keeps its recall benefit out of the box.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
# No fetch_limit -> heuristic default.
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=20)
|
||||
|
||||
# fetch = 3 x 20 = 60 -> 60 // 2 concept-vectors = 30 per concept.
|
||||
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
|
||||
assert q_limit == 30
|
||||
# Rerank narrows to the final doc_limit (20).
|
||||
assert reranker.calls[0]["limit"] == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_limit_floored_at_doc_limit(self):
|
||||
"""
|
||||
A fetch_limit below doc_limit is floored up to doc_limit: retrieval must
|
||||
never fetch fewer candidates than the rerank is asked to keep, else the
|
||||
prompt could not be filled.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?", doc_limit=10, fetch_limit=4,
|
||||
)
|
||||
|
||||
# fetch = max(4, 10) = 10 -> 10 // 2 concept-vectors = 5 per concept.
|
||||
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
|
||||
assert q_limit == 5
|
||||
assert reranker.calls[0]["limit"] == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_selection_event_emitted(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
# Now 5 stages: question, grounding, exploration, focus, synthesis.
|
||||
assert len(events) == 5
|
||||
ordered_types = [
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS,
|
||||
]
|
||||
for i, expected in enumerate(ordered_types):
|
||||
assert has_type(events[i]["triples"], events[i]["explain_id"], expected)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_selection_carries_scores_and_chunk_refs(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
focus_event = events[3]
|
||||
foc_uri = focus_event["explain_id"]
|
||||
triples = focus_event["triples"]
|
||||
|
||||
# focus is derived from exploration
|
||||
exp_uri = events[2]["explain_id"]
|
||||
assert derived_from(triples, foc_uri) == exp_uri
|
||||
|
||||
# Two ChunkSelection sub-entities, linked from focus.
|
||||
sel_links = find_triples(triples, TG_SELECTED_CHUNK, foc_uri)
|
||||
assert len(sel_links) == 2
|
||||
|
||||
# Each selection has a ChunkSelection type, a chunk document ref and a score.
|
||||
chunk_refs = set()
|
||||
scores = set()
|
||||
for link in sel_links:
|
||||
sel_uri = link.o.iri
|
||||
assert has_type(triples, sel_uri, TG_CHUNK_SELECTION)
|
||||
doc_ref = find_triple(triples, TG_DOCUMENT, sel_uri)
|
||||
assert doc_ref is not None
|
||||
chunk_refs.add(doc_ref.o.iri)
|
||||
score_t = find_triple(triples, TG_SCORE, sel_uri)
|
||||
assert score_t is not None
|
||||
scores.add(score_t.o.value)
|
||||
|
||||
# Surviving chunks are C and A (B dropped), with the reranker scores.
|
||||
assert chunk_refs == {CHUNK_C, CHUNK_A}
|
||||
assert scores == {"0.95", "0.42"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_focus_triples_in_retrieval_graph(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
for t in events[3]["triples"]:
|
||||
assert t.g == "urn:graph:retrieval"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_derives_from_focus_when_reranking(self):
|
||||
"""
|
||||
When reranking runs, synthesis must derive from the focus node (the
|
||||
reranked chunks actually fed to the LLM), mirroring GraphRAG - not from
|
||||
exploration, which would leave focus as a dangling branch and
|
||||
misrepresent what fed the answer.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
doc_limit=2,
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
# events: question, grounding, exploration, focus, synthesis
|
||||
foc_uri = events[3]["explain_id"]
|
||||
syn_event = events[4]
|
||||
assert derived_from(syn_event["triples"], syn_event["explain_id"]) == foc_uri
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_docs_skips_reranker(self):
|
||||
"""If retrieval returns no chunks, the reranker is never called."""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
doc_embeddings_client.query.return_value = [] # no matches
|
||||
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?")
|
||||
|
||||
assert reranker.calls == []
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Cross-layer wiring contract for the Document-RAG reranker (issue #878).
|
||||
|
||||
The Document-RAG processor registers a ``RerankerClientSpec`` for the
|
||||
``reranker-request`` / ``reranker-response`` roles (see
|
||||
``retrieval/document_rag/rag.py``). At flow construction every spec runs
|
||||
``spec.add(flow, processor, definition)``, and ``RequestResponseSpec.add``
|
||||
resolves its topics via ``definition["topics"][name]`` - which raises
|
||||
``KeyError`` if the flow blueprint does not provide those topics.
|
||||
|
||||
This means the monorepo code change is only safe to deploy together with the
|
||||
companion ``trustgraph-templates`` change that wires ``reranker-request`` /
|
||||
``reranker-response`` into the Document-RAG flow (mirroring what templates
|
||||
PR #279 did for GraphRAG via ``graph-store.jsonnet``). These tests pin that
|
||||
contract from the monorepo side:
|
||||
|
||||
* with the reranker topics present (as the updated templates compile them),
|
||||
the spec binds cleanly and registers the client;
|
||||
* without them (the pre-companion blueprint), construction fails fast with a
|
||||
KeyError naming the missing role - documenting exactly why the templates
|
||||
change is required.
|
||||
|
||||
No broker/network: the pub/sub backend is mocked (topics are bound at add()
|
||||
time, connections happen later at start()).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.base import RerankerClientSpec
|
||||
|
||||
|
||||
def _flow():
|
||||
f = MagicMock()
|
||||
f.workspace = "ws"
|
||||
f.name = "document-rag"
|
||||
f.id = "proc1"
|
||||
f.consumer = {}
|
||||
return f
|
||||
|
||||
|
||||
def _processor():
|
||||
p = MagicMock()
|
||||
p.pubsub = MagicMock()
|
||||
p.id = "proc1"
|
||||
p.taskgroup = MagicMock()
|
||||
return p
|
||||
|
||||
|
||||
def _spec():
|
||||
return RerankerClientSpec(
|
||||
request_name="reranker-request",
|
||||
response_name="reranker-response",
|
||||
)
|
||||
|
||||
|
||||
# Topics dict as the UPDATED document-store.jsonnet compiles them
|
||||
# (verified by compiling the template: reranker-request -> request:tg:reranker:{workspace}:{id}).
|
||||
DEFINITION_WITH_RERANKER = {
|
||||
"topics": {
|
||||
"request": "request:tg:document-rag:ws:id",
|
||||
"response": "response:tg:document-rag:ws:id",
|
||||
"reranker-request": "request:tg:reranker:ws:id",
|
||||
"reranker-response": "response:tg:reranker:ws:id",
|
||||
}
|
||||
}
|
||||
|
||||
# Pre-companion blueprint: no reranker topics (document-rag before the templates change).
|
||||
DEFINITION_WITHOUT_RERANKER = {
|
||||
"topics": {
|
||||
"request": "request:tg:document-rag:ws:id",
|
||||
"response": "response:tg:document-rag:ws:id",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_reranker_client_binds_when_flow_provides_topics():
|
||||
flow = _flow()
|
||||
_spec().add(flow, _processor(), DEFINITION_WITH_RERANKER)
|
||||
# The client consumer is registered against the reranker role.
|
||||
assert "reranker-request" in flow.consumer
|
||||
|
||||
|
||||
def test_reranker_client_keyerrors_without_companion_template_topics():
|
||||
with pytest.raises(KeyError) as exc:
|
||||
_spec().add(_flow(), _processor(), DEFINITION_WITHOUT_RERANKER)
|
||||
# Fails fast naming the missing role -> the trustgraph-templates companion
|
||||
# change (wire reranker-request/response into the document-rag flow) is required.
|
||||
assert "reranker-request" in str(exc.value)
|
||||
|
|
@ -66,6 +66,7 @@ class TestDocumentRagService:
|
|||
workspace=ANY, # Workspace comes from flow.workspace (mock)
|
||||
collection="test_coll_1", # Must be from message, not hardcoded default
|
||||
doc_limit=5,
|
||||
fetch_limit=0, # Unset -> core derives the candidate pool
|
||||
explain_callback=ANY, # Explainability callback is always passed
|
||||
save_answer_callback=ANY, # Librarian save callback is always passed
|
||||
)
|
||||
|
|
|
|||
|
|
@ -527,7 +527,8 @@ class AsyncFlowInstance:
|
|||
return result.get("response", "")
|
||||
|
||||
async def document_rag(self, query: str, collection: str,
|
||||
doc_limit: int = 10, **kwargs: Any) -> str:
|
||||
doc_limit: int = 10, fetch_limit: int = 0,
|
||||
**kwargs: Any) -> str:
|
||||
"""
|
||||
Execute document-based RAG query (non-streaming).
|
||||
|
||||
|
|
@ -541,7 +542,9 @@ class AsyncFlowInstance:
|
|||
Args:
|
||||
query: User query text
|
||||
collection: Collection identifier containing documents
|
||||
doc_limit: Maximum number of document chunks to retrieve (default: 10)
|
||||
doc_limit: Document chunks selected into the prompt (default: 10)
|
||||
fetch_limit: Candidate chunks fetched from the vector store before
|
||||
reranking (default: 0 = derive from doc_limit)
|
||||
**kwargs: Additional service-specific parameters
|
||||
|
||||
Returns:
|
||||
|
|
@ -564,6 +567,7 @@ class AsyncFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": False
|
||||
}
|
||||
request_data.update(kwargs)
|
||||
|
|
|
|||
|
|
@ -379,12 +379,14 @@ class AsyncSocketFlowInstance:
|
|||
yield chunk.content
|
||||
|
||||
async def document_rag(self, query: str, collection: str,
|
||||
doc_limit: int = 10, streaming: bool = False, **kwargs):
|
||||
doc_limit: int = 10, fetch_limit: int = 0,
|
||||
streaming: bool = False, **kwargs):
|
||||
"""Document RAG with optional streaming"""
|
||||
request = {
|
||||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
|
|||
|
|
@ -415,7 +415,7 @@ class FlowInstance:
|
|||
|
||||
def document_rag(
|
||||
self, query,collection="default",
|
||||
doc_limit=10,
|
||||
doc_limit=10, fetch_limit=0,
|
||||
):
|
||||
"""
|
||||
Execute document-based Retrieval-Augmented Generation (RAG) query.
|
||||
|
|
@ -426,7 +426,9 @@ class FlowInstance:
|
|||
Args:
|
||||
query: Natural language query
|
||||
collection: Collection identifier (default: "default")
|
||||
doc_limit: Maximum document chunks to retrieve (default: 10)
|
||||
doc_limit: Document chunks selected into the prompt (default: 10)
|
||||
fetch_limit: Candidate chunks fetched from the vector store before
|
||||
reranking (default: 0 = derive from doc_limit)
|
||||
|
||||
Returns:
|
||||
str: Generated response incorporating document context
|
||||
|
|
@ -447,6 +449,7 @@ class FlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
}
|
||||
|
||||
result = self.request(
|
||||
|
|
|
|||
|
|
@ -752,6 +752,7 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
collection: str,
|
||||
doc_limit: int = 10,
|
||||
fetch_limit: int = 0,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
|
|
@ -764,6 +765,7 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
@ -785,6 +787,7 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
collection: str,
|
||||
doc_limit: int = 10,
|
||||
fetch_limit: int = 0,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
||||
"""Execute document-based RAG query with explainability support."""
|
||||
|
|
@ -792,6 +795,7 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": True,
|
||||
"explainable": True,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
|||
query=data["query"],
|
||||
collection=data.get("collection", "default"),
|
||||
doc_limit=int(data.get("doc-limit", 20)),
|
||||
fetch_limit=int(data.get("fetch-limit", 0)),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
|||
"query": obj.query,
|
||||
"collection": obj.collection,
|
||||
"doc-limit": obj.doc_limit,
|
||||
"fetch-limit": obj.fetch_limit,
|
||||
"streaming": getattr(obj, "streaming", False)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,8 @@ from . uris import (
|
|||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_focus_uri,
|
||||
chunk_selection_uri,
|
||||
docrag_synthesis_uri,
|
||||
)
|
||||
|
||||
|
|
@ -94,6 +96,8 @@ from . namespaces import (
|
|||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Chunk selection entity type
|
||||
TG_CHUNK_SELECTION,
|
||||
# Explainability entity types
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION,
|
||||
|
|
@ -132,6 +136,7 @@ from . triples import (
|
|||
# Query-time provenance triple builders (DocumentRAG)
|
||||
docrag_question_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_chunk_selection_triples,
|
||||
docrag_synthesis_triples,
|
||||
# Utility
|
||||
set_graph,
|
||||
|
|
@ -196,6 +201,8 @@ __all__ = [
|
|||
"docrag_question_uri",
|
||||
"docrag_grounding_uri",
|
||||
"docrag_exploration_uri",
|
||||
"docrag_focus_uri",
|
||||
"chunk_selection_uri",
|
||||
"docrag_synthesis_uri",
|
||||
# Namespaces
|
||||
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
|
||||
|
|
@ -219,6 +226,8 @@ __all__ = [
|
|||
"TG_EDGE_SELECTION",
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
|
||||
# Chunk selection entity type
|
||||
"TG_CHUNK_SELECTION",
|
||||
# Explainability entity types
|
||||
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
|
||||
"TG_ANALYSIS", "TG_CONCLUSION",
|
||||
|
|
@ -254,6 +263,7 @@ __all__ = [
|
|||
# Query-time provenance triple builders (DocumentRAG)
|
||||
"docrag_question_triples",
|
||||
"docrag_exploration_triples",
|
||||
"docrag_chunk_selection_triples",
|
||||
"docrag_synthesis_triples",
|
||||
# Agent provenance triple builders
|
||||
"agent_session_triples",
|
||||
|
|
|
|||
|
|
@ -76,6 +76,9 @@ TG_EDGE_SELECTION = TG + "EdgeSelection"
|
|||
TG_CHUNK_COUNT = TG + "chunkCount"
|
||||
TG_SELECTED_CHUNK = TG + "selectedChunk"
|
||||
|
||||
# Chunk selection entity type (cross-encoder reranked chunk in Focus)
|
||||
TG_CHUNK_SELECTION = TG + "ChunkSelection"
|
||||
|
||||
# Extraction provenance entity types
|
||||
TG_DOCUMENT_TYPE = TG + "Document"
|
||||
TG_PAGE_TYPE = TG + "Page"
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ from . namespaces import (
|
|||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Chunk selection entity type
|
||||
TG_CHUNK_SELECTION,
|
||||
# Explainability entity types
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
# Unifying types
|
||||
|
|
@ -40,7 +42,10 @@ from . namespaces import (
|
|||
TG_IN_TOKEN, TG_OUT_TOKEN,
|
||||
)
|
||||
|
||||
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri
|
||||
from . uris import (
|
||||
activity_uri, agent_uri, subgraph_uri, edge_selection_uri,
|
||||
chunk_selection_uri,
|
||||
)
|
||||
|
||||
|
||||
def set_graph(triples: List[Triple], graph: str) -> List[Triple]:
|
||||
|
|
@ -718,6 +723,75 @@ def docrag_exploration_triples(
|
|||
return triples
|
||||
|
||||
|
||||
def docrag_chunk_selection_triples(
|
||||
focus_uri: str,
|
||||
exploration_uri: str,
|
||||
selected_chunks_with_scores: List[dict],
|
||||
session_id: str,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for a document RAG focus entity (chunks selected by the
|
||||
cross-encoder reranker).
|
||||
|
||||
Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity
|
||||
derived from exploration, with one ChunkSelection sub-entity per surviving
|
||||
chunk carrying the chunk reference and the reranker score.
|
||||
|
||||
Structure:
|
||||
<focus> a tg:Focus ; prov:wasDerivedFrom <exploration> .
|
||||
<focus> tg:selectedChunk <chunk_sel_0> .
|
||||
<chunk_sel_0> a tg:ChunkSelection .
|
||||
<chunk_sel_0> tg:document <chunk_id> .
|
||||
<chunk_sel_0> tg:score "0.97" .
|
||||
|
||||
Args:
|
||||
focus_uri: URI of the focus entity (from docrag_focus_uri)
|
||||
exploration_uri: URI of the parent exploration entity
|
||||
selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score'
|
||||
session_id: Session UUID for generating chunk selection URIs
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
"""
|
||||
triples = [
|
||||
_triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)),
|
||||
_triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")),
|
||||
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
|
||||
]
|
||||
|
||||
for idx, chunk_info in enumerate(selected_chunks_with_scores):
|
||||
chunk_id = chunk_info.get("chunk_id")
|
||||
if not chunk_id:
|
||||
continue
|
||||
|
||||
chunk_sel_uri = chunk_selection_uri(session_id, idx)
|
||||
|
||||
# Link focus to chunk selection entity
|
||||
triples.append(
|
||||
_triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri))
|
||||
)
|
||||
|
||||
# Type the chunk selection entity
|
||||
triples.append(
|
||||
_triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION))
|
||||
)
|
||||
|
||||
# Reference the actual chunk (in librarian)
|
||||
triples.append(
|
||||
_triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id))
|
||||
)
|
||||
|
||||
# Cross-encoder score
|
||||
score = chunk_info.get("score")
|
||||
if score is not None:
|
||||
triples.append(
|
||||
_triple(chunk_sel_uri, TG_SCORE, _literal(str(score)))
|
||||
)
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
def docrag_synthesis_triples(
|
||||
synthesis_uri: str,
|
||||
exploration_uri: str,
|
||||
|
|
|
|||
|
|
@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str:
|
|||
return f"urn:trustgraph:docrag:{session_id}/exploration"
|
||||
|
||||
|
||||
def docrag_focus_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for a document RAG focus entity (chunks selected by the
|
||||
cross-encoder reranker).
|
||||
|
||||
Args:
|
||||
session_id: The session UUID.
|
||||
|
||||
Returns:
|
||||
URN in format: urn:trustgraph:docrag:{uuid}/focus
|
||||
"""
|
||||
return f"urn:trustgraph:docrag:{session_id}/focus"
|
||||
|
||||
|
||||
def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
|
||||
"""
|
||||
Generate URI for a chunk selection item (links a reranked chunk to its
|
||||
score). Mirrors edge_selection_uri for GraphRAG.
|
||||
|
||||
Args:
|
||||
session_id: The session UUID.
|
||||
chunk_index: Index of this chunk in the selection (0-based).
|
||||
|
||||
Returns:
|
||||
URN in format: urn:trustgraph:prov:chunk:{uuid}:{index}
|
||||
"""
|
||||
return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"
|
||||
|
||||
|
||||
def docrag_synthesis_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for a document RAG synthesis entity (final answer).
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from . namespaces import (
|
|||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
TG_EDGE_SELECTION, TG_SCORE,
|
||||
TG_CHUNK_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -95,6 +96,7 @@ TG_CLASS_LABELS = [
|
|||
_label_triple(TG_PLAN_TYPE, "Plan"),
|
||||
_label_triple(TG_STEP_RESULT, "Step Result"),
|
||||
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
|
||||
_label_triple(TG_CHUNK_SELECTION, "Chunk Selection"),
|
||||
]
|
||||
|
||||
# TrustGraph predicate labels
|
||||
|
|
|
|||
|
|
@ -40,7 +40,10 @@ class GraphRagResponse:
|
|||
class DocumentRagQuery:
|
||||
query: str = ""
|
||||
collection: str = ""
|
||||
doc_limit: int = 0
|
||||
doc_limit: int = 0 # docs selected into the synthesis prompt
|
||||
fetch_limit: int = 0 # candidate pool fetched from the vector store
|
||||
# before reranking (0 = derive from doc_limit;
|
||||
# values below doc_limit are raised to it)
|
||||
streaming: bool = False
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -21,10 +21,12 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
|||
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||
default_collection = 'default'
|
||||
default_doc_limit = 10
|
||||
default_fetch_limit = 0
|
||||
|
||||
|
||||
def question_explainable(
|
||||
url, flow_id, question_text, collection, doc_limit, token=None, debug=False,
|
||||
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
|
||||
token=None, debug=False,
|
||||
workspace="default",
|
||||
):
|
||||
"""Execute document RAG with explainability - shows provenance events inline."""
|
||||
|
|
@ -39,6 +41,7 @@ def question_explainable(
|
|||
query=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
# Print response content
|
||||
|
|
@ -97,7 +100,7 @@ def question_explainable(
|
|||
|
||||
|
||||
def question(
|
||||
url, flow_id, question_text, collection, doc_limit,
|
||||
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
|
||||
streaming=True, token=None, explainable=False, debug=False,
|
||||
show_usage=False, workspace="default",
|
||||
):
|
||||
|
|
@ -109,6 +112,7 @@ def question(
|
|||
question_text=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
token=token,
|
||||
debug=debug,
|
||||
workspace=workspace,
|
||||
|
|
@ -128,6 +132,7 @@ def question(
|
|||
query=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
|
|
@ -155,6 +160,7 @@ def question(
|
|||
query=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
)
|
||||
print(result.text)
|
||||
|
||||
|
|
@ -214,7 +220,15 @@ def main():
|
|||
'-d', '--doc-limit',
|
||||
type=int,
|
||||
default=default_doc_limit,
|
||||
help=f'Document limit (default: {default_doc_limit})'
|
||||
help=f'Documents selected into the prompt (default: {default_doc_limit})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--fetch-limit',
|
||||
type=int,
|
||||
default=default_fetch_limit,
|
||||
help='Candidate documents fetched from the vector store before '
|
||||
'reranking (default: derive from doc-limit)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -251,6 +265,7 @@ def main():
|
|||
question_text=args.question,
|
||||
collection=args.collection,
|
||||
doc_limit=args.doc_limit,
|
||||
fetch_limit=args.fetch_limit,
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
|
|
|
|||
|
|
@ -9,10 +9,12 @@ from trustgraph.provenance import (
|
|||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_focus_uri,
|
||||
docrag_synthesis_uri,
|
||||
docrag_question_triples,
|
||||
grounding_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_chunk_selection_triples,
|
||||
docrag_synthesis_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL,
|
||||
|
|
@ -21,19 +23,25 @@ from trustgraph.provenance import (
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# When the caller does not specify a fetch_limit, reranking over-fetches this
|
||||
# many times the final doc_limit as the candidate pool, so the cross-encoder can
|
||||
# recover relevant chunks the bi-encoder ranked just outside the top doc_limit.
|
||||
# This is only the fallback default: an explicit fetch_limit overrides it.
|
||||
OVERFETCH_FACTOR = 3
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
class Query:
|
||||
|
||||
def __init__(
|
||||
self, rag, workspace, collection, verbose,
|
||||
doc_limit=20, track_usage=None,
|
||||
fetch_limit=20, track_usage=None,
|
||||
):
|
||||
self.rag = rag
|
||||
self.workspace = workspace
|
||||
self.collection = collection
|
||||
self.verbose = verbose
|
||||
self.doc_limit = doc_limit
|
||||
self.fetch_limit = fetch_limit
|
||||
self.track_usage = track_usage
|
||||
|
||||
async def extract_concepts(self, query):
|
||||
|
|
@ -91,7 +99,7 @@ class Query:
|
|||
|
||||
# Query chunk matches for each concept concurrently
|
||||
per_concept_limit = max(
|
||||
1, self.doc_limit // len(vectors)
|
||||
1, self.fetch_limit // len(vectors)
|
||||
)
|
||||
|
||||
async def query_concept(vec):
|
||||
|
|
@ -140,6 +148,7 @@ class DocumentRag:
|
|||
def __init__(
|
||||
self, prompt_client, embeddings_client, doc_embeddings_client,
|
||||
fetch_chunk,
|
||||
reranker_client=None,
|
||||
verbose=False,
|
||||
):
|
||||
|
||||
|
|
@ -150,12 +159,16 @@ class DocumentRag:
|
|||
self.doc_embeddings_client = doc_embeddings_client
|
||||
self.fetch_chunk = fetch_chunk
|
||||
|
||||
# Optional cross-encoder reranker. When None, the retrieval path is
|
||||
# byte-identical to the pre-reranker behaviour.
|
||||
self.reranker_client = reranker_client
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("DocumentRag initialized")
|
||||
|
||||
async def query(
|
||||
self, query, workspace="default", collection="default",
|
||||
doc_limit=20, streaming=False, chunk_callback=None,
|
||||
doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None,
|
||||
explain_callback=None, save_answer_callback=None,
|
||||
):
|
||||
"""
|
||||
|
|
@ -165,7 +178,10 @@ class DocumentRag:
|
|||
query: The query string
|
||||
workspace: Workspace for isolation (also scopes chunk lookup)
|
||||
collection: Collection identifier
|
||||
doc_limit: Max chunks to retrieve
|
||||
doc_limit: Chunks selected into the synthesis prompt (after rerank)
|
||||
fetch_limit: Candidate pool fetched from the vector store before
|
||||
reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a
|
||||
reranker is wired, else doc_limit).
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for explainability
|
||||
|
|
@ -197,6 +213,7 @@ class DocumentRag:
|
|||
q_uri = docrag_question_uri(session_id)
|
||||
gnd_uri = docrag_grounding_uri(session_id)
|
||||
exp_uri = docrag_exploration_uri(session_id)
|
||||
foc_uri = docrag_focus_uri(session_id)
|
||||
syn_uri = docrag_synthesis_uri(session_id)
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
|
@ -209,10 +226,21 @@ class DocumentRag:
|
|||
)
|
||||
await explain_callback(q_triples, q_uri)
|
||||
|
||||
# Resolve the candidate-pool size fetched from the vector store. When a
|
||||
# reranker is wired, honour an explicit fetch_limit; if unset, fall back
|
||||
# to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit,
|
||||
# else the rerank could not fill the prompt. Without a reranker, fetch
|
||||
# doc_limit as before (byte-identical behaviour).
|
||||
if self.reranker_client is not None:
|
||||
fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit)
|
||||
fetch_count = max(fl, doc_limit)
|
||||
else:
|
||||
fetch_count = doc_limit
|
||||
|
||||
q = Query(
|
||||
rag=self, workspace=workspace, collection=collection,
|
||||
verbose=self.verbose,
|
||||
doc_limit=doc_limit, track_usage=track_usage,
|
||||
fetch_limit=fetch_count, track_usage=track_usage,
|
||||
)
|
||||
|
||||
# Extract concepts from query (grounding step)
|
||||
|
|
@ -235,6 +263,7 @@ class DocumentRag:
|
|||
docs, chunk_ids = await q.get_docs(concepts)
|
||||
|
||||
# Emit exploration explainability after chunks retrieved
|
||||
# (full candidate set, before any reranking)
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
|
||||
|
|
@ -242,6 +271,45 @@ class DocumentRag:
|
|||
)
|
||||
await explain_callback(exp_triples, exp_uri)
|
||||
|
||||
# Optional cross-encoder reranking pass between retrieval and
|
||||
# synthesis. Mirrors GraphRAG's reranker usage but with a single
|
||||
# query (the question). When no reranker is wired, this block is
|
||||
# skipped entirely and behaviour is byte-identical to before.
|
||||
reranked = False
|
||||
if self.reranker_client is not None and docs:
|
||||
results = await self.reranker_client.rerank(
|
||||
queries=[{"id": "0", "text": query}],
|
||||
documents=[
|
||||
{"id": str(i), "text": d} for i, d in enumerate(docs)
|
||||
],
|
||||
# Narrow the over-fetched candidate pool down to the final
|
||||
# doc_limit requested for synthesis.
|
||||
limit=doc_limit,
|
||||
)
|
||||
|
||||
# results are sorted desc by score and truncated to limit by the
|
||||
# reranker service, so order gives the surviving top-N directly.
|
||||
order = [int(r.document_id) for r in results]
|
||||
docs = [docs[i] for i in order]
|
||||
chunk_ids = [chunk_ids[i] for i in order]
|
||||
reranked = True
|
||||
|
||||
# Emit chunk-selection (focus) explainability: surviving chunks
|
||||
# with their cross-encoder scores, derived from exploration.
|
||||
if explain_callback:
|
||||
selected_chunks_with_scores = [
|
||||
{"chunk_id": chunk_ids[i], "score": r.score}
|
||||
for i, r in enumerate(results)
|
||||
]
|
||||
foc_triples = set_graph(
|
||||
docrag_chunk_selection_triples(
|
||||
foc_uri, exp_uri,
|
||||
selected_chunks_with_scores, session_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(foc_triples, foc_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
logger.debug(f"Documents: {docs}")
|
||||
|
|
@ -291,9 +359,15 @@ class DocumentRag:
|
|||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
synthesis_doc_id = None
|
||||
|
||||
# When reranking ran, synthesis derives from the focus (the
|
||||
# reranked chunks actually fed to the LLM), as GraphRAG always does.
|
||||
# When no reranker is wired, there is no focus stage, so synthesis
|
||||
# derives from exploration (the unchanged no-op lineage) - a
|
||||
# deliberate divergence from GraphRAG's always-on focus.
|
||||
syn_parent = foc_uri if reranked else exp_uri
|
||||
syn_triples = set_graph(
|
||||
docrag_synthesis_triples(
|
||||
syn_uri, exp_uri,
|
||||
syn_uri, syn_parent,
|
||||
document_id=synthesis_doc_id,
|
||||
in_token=synthesis_result.in_token if synthesis_result else None,
|
||||
out_token=synthesis_result.out_token if synthesis_result else None,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from . document_rag import DocumentRag
|
|||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import DocumentEmbeddingsClientSpec
|
||||
from ... base import RerankerClientSpec
|
||||
from ... base import LibrarianSpec
|
||||
|
||||
# Module logger
|
||||
|
|
@ -28,14 +29,21 @@ class Processor(FlowProcessor):
|
|||
|
||||
doc_limit = params.get("doc_limit", 5)
|
||||
|
||||
# Instance-default candidate-pool size fetched before cross-encoder
|
||||
# reranking; the rerank step narrows it back down to doc_limit for the
|
||||
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
|
||||
fetch_limit = params.get("fetch_limit", 0)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"doc_limit": doc_limit,
|
||||
"fetch_limit": fetch_limit,
|
||||
}
|
||||
)
|
||||
|
||||
self.doc_limit = doc_limit
|
||||
self.fetch_limit = fetch_limit
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
|
|
@ -66,6 +74,13 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RerankerClientSpec(
|
||||
request_name = "reranker-request",
|
||||
response_name = "reranker-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
|
|
@ -105,6 +120,7 @@ class Processor(FlowProcessor):
|
|||
doc_embeddings_client = flow("document-embeddings-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
fetch_chunk = fetch_chunk,
|
||||
reranker_client = flow("reranker-request"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
@ -113,6 +129,13 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
doc_limit = self.doc_limit
|
||||
|
||||
# Candidate-pool size: per-request override, else the instance
|
||||
# default; 0 lets the core derive it from doc_limit.
|
||||
if v.fetch_limit:
|
||||
fetch_limit = v.fetch_limit
|
||||
else:
|
||||
fetch_limit = self.fetch_limit
|
||||
|
||||
async def send_explainability(triples, explain_id):
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
|
|
@ -163,6 +186,7 @@ class Processor(FlowProcessor):
|
|||
workspace=flow.workspace,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
streaming=True,
|
||||
chunk_callback=send_chunk,
|
||||
explain_callback=send_explainability,
|
||||
|
|
@ -188,6 +212,7 @@ class Processor(FlowProcessor):
|
|||
workspace=flow.workspace,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
explain_callback=send_explainability,
|
||||
save_answer_callback=save_answer,
|
||||
)
|
||||
|
|
@ -243,6 +268,15 @@ class Processor(FlowProcessor):
|
|||
help=f'Default document fetch limit (default: 10)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--fetch-limit',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Candidate chunks to fetch from the vector store and rerank '
|
||||
'before keeping the top doc-limit for the LLM '
|
||||
'(default: derive from doc-limit)'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue