mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 06:51:00 +02:00
Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after vector retrieval, over-fetch a wider candidate pool, rerank with the cross-encoder, and keep the top doc_limit chunks for synthesis.

Per maintainer review, the fetch and select sizes are two caller-controlled limits rather than one internal heuristic:
- doc_limit: chunks selected into the synthesis prompt (unchanged meaning).
- fetch_limit: candidate pool pulled from the vector store before reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are raised to it. Lets the caller control how hard the reranker has to work.

Details:
- schema: DocumentRagQuery.fetch_limit (additive, backward compatible).
- document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors doc_limit); the core applies the heuristic default and derives synthesis provenance from the chunk-selection focus when reranking ran.
- provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection).
- request translator + client SDKs + CLI: fetch-limit / --fetch-limit, threaded exactly like doc_limit and the GraphRAG limits.
- tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring.

When no reranker role is wired, reranking is skipped and the behaviour is byte-identical to the pre-reranker behaviour.

Requires the companion trustgraph-templates change wiring the reranker topics into the document-rag flow (mirrors #279 for GraphRAG).
This commit is contained in:
parent
f18d48dc39
commit
6c9a545a06
18 changed files with 853 additions and 26 deletions
|
|
@ -101,27 +101,27 @@ class TestQuery:
|
||||||
assert query.rag == mock_rag
|
assert query.rag == mock_rag
|
||||||
assert query.collection == "test_collection"
|
assert query.collection == "test_collection"
|
||||||
assert query.verbose is False
|
assert query.verbose is False
|
||||||
assert query.doc_limit == 20 # Default value
|
assert query.fetch_limit == 20 # Default value
|
||||||
|
|
||||||
def test_query_initialization_with_custom_doc_limit(self):
|
def test_query_initialization_with_custom_fetch_limit(self):
|
||||||
"""Test Query initialization with custom doc_limit"""
|
"""Test Query initialization with custom fetch_limit"""
|
||||||
# Create mock DocumentRag
|
# Create mock DocumentRag
|
||||||
mock_rag = MagicMock()
|
mock_rag = MagicMock()
|
||||||
|
|
||||||
# Initialize Query with custom doc_limit
|
# Initialize Query with custom fetch_limit
|
||||||
query = Query(
|
query = Query(
|
||||||
rag=mock_rag,
|
rag=mock_rag,
|
||||||
workspace="test_workspace",
|
workspace="test_workspace",
|
||||||
collection="custom_collection",
|
collection="custom_collection",
|
||||||
verbose=True,
|
verbose=True,
|
||||||
doc_limit=50
|
fetch_limit=50
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify initialization
|
# Verify initialization
|
||||||
assert query.rag == mock_rag
|
assert query.rag == mock_rag
|
||||||
assert query.collection == "custom_collection"
|
assert query.collection == "custom_collection"
|
||||||
assert query.verbose is True
|
assert query.verbose is True
|
||||||
assert query.doc_limit == 50
|
assert query.fetch_limit == 50
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_concepts(self):
|
async def test_extract_concepts(self):
|
||||||
|
|
@ -224,7 +224,7 @@ class TestQuery:
|
||||||
workspace="test_workspace",
|
workspace="test_workspace",
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
verbose=False,
|
verbose=False,
|
||||||
doc_limit=15
|
fetch_limit=15
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call get_docs with concepts list
|
# Call get_docs with concepts list
|
||||||
|
|
@ -377,7 +377,7 @@ class TestQuery:
|
||||||
workspace="test_workspace",
|
workspace="test_workspace",
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
verbose=True,
|
verbose=True,
|
||||||
doc_limit=5
|
fetch_limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call get_docs with concepts
|
# Call get_docs with concepts
|
||||||
|
|
@ -615,7 +615,7 @@ class TestQuery:
|
||||||
workspace="test_workspace",
|
workspace="test_workspace",
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
verbose=False,
|
verbose=False,
|
||||||
doc_limit=10
|
fetch_limit=10
|
||||||
)
|
)
|
||||||
|
|
||||||
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])
|
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])
|
||||||
|
|
|
||||||
478
tests/unit/test_retrieval/test_document_rag_rerank.py
Normal file
478
tests/unit/test_retrieval/test_document_rag_rerank.py
Normal file
|
|
@ -0,0 +1,478 @@
|
||||||
|
"""
|
||||||
|
Tests for the optional cross-encoder reranking pass in DocumentRag.query().
|
||||||
|
|
||||||
|
Two behaviours are covered:
|
||||||
|
|
||||||
|
1. No-op: when no reranker_client is wired (the default), query() must feed
|
||||||
|
the LLM the exact same chunks, in the same order, that retrieval produced
|
||||||
|
- byte-identical to the pre-reranker behaviour - and must NOT emit a
|
||||||
|
chunk-selection provenance event.
|
||||||
|
|
||||||
|
2. Rerank: when a reranker_client is wired, the retrieved chunks are reordered
|
||||||
|
and truncated according to the reranker's results, the LLM receives the
|
||||||
|
reranked top-N, and a tg:ChunkSelection (focus) provenance event is emitted
|
||||||
|
carrying the per-surviving-chunk scores and chunk references.
|
||||||
|
|
||||||
|
These are pure orchestration tests - the reranker is a stub, so there is no
|
||||||
|
torch / network dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||||
|
from trustgraph.base import PromptResult
|
||||||
|
from trustgraph.schema import RerankerResult
|
||||||
|
|
||||||
|
from trustgraph.provenance.namespaces import (
|
||||||
|
RDF_TYPE, PROV_WAS_DERIVED_FROM,
|
||||||
|
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||||
|
TG_FOCUS, TG_SYNTHESIS,
|
||||||
|
TG_CHUNK_SELECTION, TG_SELECTED_CHUNK, TG_SCORE, TG_DOCUMENT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def find_triple(triples, predicate, subject=None):
    """Return the first triple whose predicate IRI equals *predicate*.

    Args:
        triples: Iterable of triple-like objects exposing ``.s.iri`` and
            ``.p.iri``.
        predicate: Predicate IRI to match against ``t.p.iri``.
        subject: Optional subject IRI; when given, ``t.s.iri`` must match too.

    Returns:
        The first matching triple in iteration order, or ``None`` when
        nothing matches.
    """
    return next(
        (
            t for t in triples
            if t.p.iri == predicate
            and (subject is None or t.s.iri == subject)
        ),
        None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def find_triples(triples, predicate, subject=None):
    """Return every triple whose predicate IRI equals *predicate*.

    Args:
        triples: Iterable of triple-like objects exposing ``.s.iri`` and
            ``.p.iri``.
        predicate: Predicate IRI to match against ``t.p.iri``.
        subject: Optional subject IRI; when given, ``t.s.iri`` must match too.

    Returns:
        A list of the matching triples, in iteration order (empty when
        nothing matches).
    """
    return [
        t for t in triples
        if t.p.iri == predicate
        and (subject is None or t.s.iri == subject)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def has_type(triples, subject, rdf_type):
    """Report whether *triples* assert that *subject* has type *rdf_type*.

    A triple counts when its subject IRI is *subject*, its predicate IRI is
    ``RDF_TYPE`` and its object IRI is *rdf_type*.

    Returns:
        ``True`` if at least one such triple exists, otherwise ``False``.
    """
    return any(
        t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
        for t in triples
    )
|
||||||
|
|
||||||
|
|
||||||
|
def derived_from(triples, subject):
    """Return the IRI that *subject* is recorded as derived from.

    Looks up the first ``PROV_WAS_DERIVED_FROM`` triple for *subject* via
    ``find_triple``.

    Returns:
        The object IRI of that triple, or ``None`` when no such triple exists.
    """
    t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    # Explicit None check: the "not found" signal is None, not falsiness.
    return t.o.iri if t is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ChunkMatch:
    """Mimics the result from doc_embeddings_client.query()."""

    # Identifier of the matched chunk (the tests use urn:chunk:... strings).
    chunk_id: str
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures: three retrievable chunks
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
CHUNK_C = "urn:chunk:policy-doc-1:chunk-2"

CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
CHUNK_B_CONTENT = "Our stores are open from 9am to 5pm on weekdays."
CHUNK_C_CONTENT = "Refunds are processed to the original payment method."

# Retrieval (post-dedupe) order is A, B, C.
ORDERED_CONTENT = [CHUNK_A_CONTENT, CHUNK_B_CONTENT, CHUNK_C_CONTENT]
ORDERED_CHUNK_IDS = [CHUNK_A, CHUNK_B, CHUNK_C]
|
||||||
|
|
||||||
|
|
||||||
|
def build_mock_clients():
    """
    Build mock subsidiary clients for a document-rag query returning three
    distinct chunks (A, B, C) in that order.

    Returns:
        A 4-tuple ``(prompt_client, embeddings_client, doc_embeddings_client,
        fetch_chunk)`` of ``AsyncMock`` objects, in the positional order the
        tests splat into ``DocumentRag(*clients)``.
    """
    prompt_client = AsyncMock()
    embeddings_client = AsyncMock()
    doc_embeddings_client = AsyncMock()
    fetch_chunk = AsyncMock()

    async def mock_prompt(template_id, variables=None, **kwargs):
        # Concept extraction yields two concepts, one per line; every other
        # template returns an empty text result.
        if template_id == "extract-concepts":
            return PromptResult(response_type="text", text="return policy\nrefund")
        return PromptResult(response_type="text", text="")

    prompt_client.prompt.side_effect = mock_prompt

    embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]

    # Each concept query returns the same three chunks; dedupe keeps A, B, C.
    doc_embeddings_client.query.return_value = [
        ChunkMatch(chunk_id=CHUNK_A),
        ChunkMatch(chunk_id=CHUNK_B),
        ChunkMatch(chunk_id=CHUNK_C),
    ]

    # Built once rather than on every fetch call.
    content_by_chunk_id = {
        CHUNK_A: CHUNK_A_CONTENT,
        CHUNK_B: CHUNK_B_CONTENT,
        CHUNK_C: CHUNK_C_CONTENT,
    }

    async def mock_fetch(chunk_id):
        # Unknown ids raise KeyError, exactly as the inline dict lookup did.
        return content_by_chunk_id[chunk_id]

    fetch_chunk.side_effect = mock_fetch

    prompt_client.document_prompt.return_value = PromptResult(
        response_type="text",
        text="Items can be returned within 30 days for a full refund.",
    )

    return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||||
|
|
||||||
|
|
||||||
|
class StubReranker:
    """
    Stub reranker_client mirroring RerankerClient.rerank(): returns a fixed,
    pre-sorted, truncated list of RerankerResult - exactly the contract the
    flashrank service guarantees (sorted desc by score, truncated to limit).

    Every call is recorded in ``self.calls`` as a dict with the keys
    ``queries``, ``documents`` and ``limit`` so tests can inspect what the
    code under test asked for.
    """

    def __init__(self, results):
        # Snapshot the canned results so later mutation of the caller's list
        # cannot change what the stub returns.
        self._results = list(results)
        self.calls = []

    async def rerank(self, queries, documents, limit=10, timeout=300):
        """Record the call and return the canned results.

        ``limit`` is recorded but not applied, and ``timeout`` is accepted
        only for signature compatibility: the canned list is returned as-is.
        """
        self.calls.append(
            {"queries": queries, "documents": documents, "limit": limit}
        )
        # Hand out a fresh list so a consumer that mutates the result in
        # place cannot corrupt subsequent calls.
        return list(self._results)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 1. No-op: reranker_client=None must not change anything
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRerankNoOp:
    """DocumentRag.query() behaviour when no reranker_client is supplied."""

    @staticmethod
    async def _collect_events(rag, **query_kwargs):
        """Run the standard query and return the explain events it emitted.

        Each event is a dict with ``triples`` and ``explain_id``, appended in
        the order the explain callback was invoked.
        """
        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
            **query_kwargs,
        )
        return events

    @pytest.mark.asyncio
    async def test_documents_passed_to_llm_are_unchanged(self):
        """
        With no reranker wired, document_prompt must receive the retrieved
        chunks in the original order and length.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)  # reranker_client defaults to None

        await rag.query(query="What is the return policy?")

        call = rag.prompt_client.document_prompt.call_args
        passed_docs = call.kwargs["documents"]
        assert passed_docs == ORDERED_CONTENT

    @pytest.mark.asyncio
    async def test_no_chunk_selection_event_emitted(self):
        """
        Without a reranker, the provenance chain is the original 4 stages:
        question, grounding, exploration, synthesis - no focus stage.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)

        events = await self._collect_events(rag)

        assert len(events) == 4
        types = [
            TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_SYNTHESIS,
        ]
        # Lengths are asserted equal above, so zip pairs every event.
        for event, expected in zip(events, types):
            assert has_type(event["triples"], event["explain_id"], expected)

        # No chunk-selection entity anywhere.
        for e in events:
            assert not any(
                t.o.iri == TG_CHUNK_SELECTION
                for t in e["triples"]
                if t.p.iri == RDF_TYPE
            )

    @pytest.mark.asyncio
    async def test_synthesis_derives_from_exploration_when_no_rerank(self):
        """
        No-op lineage is unchanged: synthesis derives from exploration
        (there is no focus stage). Guards the conditional synthesis parent.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)

        events = await self._collect_events(rag)

        # events: question, grounding, exploration, synthesis
        # Checked up front so a missing stage fails as an assertion rather
        # than an IndexError on the positional lookups below.
        assert len(events) == 4
        exp_uri = events[2]["explain_id"]
        syn_event = events[3]
        assert derived_from(syn_event["triples"], syn_event["explain_id"]) == exp_uri
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 2. Rerank: reorder + truncate + provenance
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRerankActive:
    """DocumentRag.query() behaviour when a reranker_client is wired in."""

    def _reranker_keeping_C_then_A(self):
        # Reranker says chunk index 2 (C) is best, then index 0 (A); B dropped.
        # Pre-sorted desc by score and truncated to limit, per the contract.
        return StubReranker([
            RerankerResult(document_id="2", query_id="0", score=0.95),
            RerankerResult(document_id="0", query_id="0", score=0.42),
        ])

    @staticmethod
    async def _collect_events(rag, **query_kwargs):
        """Run the standard query and return the explain events it emitted.

        Each event is a dict with ``triples`` and ``explain_id``, appended in
        the order the explain callback was invoked.
        """
        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
            **query_kwargs,
        )
        return events

    @pytest.mark.asyncio
    async def test_documents_reordered_and_truncated(self):
        """The LLM receives the reranked survivors (C, then A), not A, B, C."""
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?")

        call = rag.prompt_client.document_prompt.call_args
        passed_docs = call.kwargs["documents"]
        assert passed_docs == [CHUNK_C_CONTENT, CHUNK_A_CONTENT]

    @pytest.mark.asyncio
    async def test_reranker_called_with_single_query_and_all_docs(self):
        """One rerank call: the user query plus every retrieved chunk."""
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?", doc_limit=2)

        assert len(reranker.calls) == 1
        c = reranker.calls[0]
        assert c["queries"] == [{"id": "0", "text": "What is the return policy?"}]
        assert c["documents"] == [
            {"id": "0", "text": CHUNK_A_CONTENT},
            {"id": "1", "text": CHUNK_B_CONTENT},
            {"id": "2", "text": CHUNK_C_CONTENT},
        ]
        # The rerank narrows down to the final doc_limit, NOT fetch_limit
        # (fetch_limit is the over-fetched candidate pool size).
        assert c["limit"] == 2

    @pytest.mark.asyncio
    async def test_explicit_fetch_limit_over_fetches_then_narrows(self):
        """
        Semantic guard for the value of reranking AND the maintainer's two-limit
        contract: an explicit fetch_limit makes retrieval OVER-FETCH a wider
        candidate pool so the cross-encoder can surface chunks the bi-encoder
        ranked outside the final doc_limit, then the rerank narrows the pool back
        down to doc_limit. The fetch_limit is honoured directly (caller controls
        how hard the reranker works), not overridden by any heuristic.
        """
        clients = build_mock_clients()
        # Only the embeddings-store mock is inspected below.
        doc_embeddings_client = clients[2]
        reranker = self._reranker_keeping_C_then_A()
        # Candidate pool (fetch_limit=60) >> final doc_limit (6).
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(
            query="What is the return policy?", doc_limit=6, fetch_limit=60,
        )

        # Over-fetch: the embeddings store is queried with the fetch_limit
        # budget (60 // 2 concept-vectors = 30 per concept), NOT the doc_limit
        # budget (6 // 2 = 3). This is the bug guard.
        q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
        assert q_limit == 30

        # Narrow: the rerank keeps the final doc_limit (6), not fetch_limit.
        assert reranker.calls[0]["limit"] == 6

    @pytest.mark.asyncio
    async def test_default_fetch_limit_derives_overfetch_from_doc_limit(self):
        """
        With no fetch_limit passed to query(), the candidate pool falls back to
        the OVERFETCH_FACTOR x doc_limit heuristic, so over-fetch scales with
        doc_limit and reranking keeps its recall benefit out of the box.
        """
        clients = build_mock_clients()
        doc_embeddings_client = clients[2]
        reranker = self._reranker_keeping_C_then_A()
        # No fetch_limit -> heuristic default.
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?", doc_limit=20)

        # fetch = 3 x 20 = 60 -> 60 // 2 concept-vectors = 30 per concept.
        q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
        assert q_limit == 30
        # Rerank narrows to the final doc_limit (20).
        assert reranker.calls[0]["limit"] == 20

    @pytest.mark.asyncio
    async def test_fetch_limit_floored_at_doc_limit(self):
        """
        A fetch_limit below doc_limit is floored up to doc_limit: retrieval must
        never fetch fewer candidates than the rerank is asked to keep, else the
        prompt could not be filled.
        """
        clients = build_mock_clients()
        doc_embeddings_client = clients[2]
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(
            query="What is the return policy?", doc_limit=10, fetch_limit=4,
        )

        # fetch = max(4, 10) = 10 -> 10 // 2 concept-vectors = 5 per concept.
        q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
        assert q_limit == 5
        assert reranker.calls[0]["limit"] == 10

    @pytest.mark.asyncio
    async def test_chunk_selection_event_emitted(self):
        """Reranking inserts a focus stage between exploration and synthesis."""
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = await self._collect_events(rag)

        # Now 5 stages: question, grounding, exploration, focus, synthesis.
        assert len(events) == 5
        ordered_types = [
            TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
            TG_FOCUS, TG_SYNTHESIS,
        ]
        # Lengths are asserted equal above, so zip pairs every event.
        for event, expected in zip(events, ordered_types):
            assert has_type(event["triples"], event["explain_id"], expected)

    @pytest.mark.asyncio
    async def test_chunk_selection_carries_scores_and_chunk_refs(self):
        """The focus event links one scored ChunkSelection per surviving chunk."""
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = await self._collect_events(rag)

        # Fail as an assertion, not an IndexError, if a stage is missing.
        assert len(events) == 5
        focus_event = events[3]
        foc_uri = focus_event["explain_id"]
        triples = focus_event["triples"]

        # focus is derived from exploration
        exp_uri = events[2]["explain_id"]
        assert derived_from(triples, foc_uri) == exp_uri

        # Two ChunkSelection sub-entities, linked from focus.
        sel_links = find_triples(triples, TG_SELECTED_CHUNK, foc_uri)
        assert len(sel_links) == 2

        # Each selection has a ChunkSelection type, a chunk document ref and a score.
        chunk_refs = set()
        scores = set()
        for link in sel_links:
            sel_uri = link.o.iri
            assert has_type(triples, sel_uri, TG_CHUNK_SELECTION)
            doc_ref = find_triple(triples, TG_DOCUMENT, sel_uri)
            assert doc_ref is not None
            chunk_refs.add(doc_ref.o.iri)
            score_t = find_triple(triples, TG_SCORE, sel_uri)
            assert score_t is not None
            scores.add(score_t.o.value)

        # Surviving chunks are C and A (B dropped), with the reranker scores.
        assert chunk_refs == {CHUNK_C, CHUNK_A}
        assert scores == {"0.95", "0.42"}

    @pytest.mark.asyncio
    async def test_all_focus_triples_in_retrieval_graph(self):
        """Every triple of the focus event is in the retrieval named graph."""
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = await self._collect_events(rag)

        assert len(events) == 5
        for t in events[3]["triples"]:
            assert t.g == "urn:graph:retrieval"

    @pytest.mark.asyncio
    async def test_synthesis_derives_from_focus_when_reranking(self):
        """
        When reranking runs, synthesis must derive from the focus node (the
        reranked chunks actually fed to the LLM), mirroring GraphRAG - not from
        exploration, which would leave focus as a dangling branch and
        misrepresent what fed the answer.
        """
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = await self._collect_events(rag, doc_limit=2)

        # events: question, grounding, exploration, focus, synthesis
        assert len(events) == 5
        foc_uri = events[3]["explain_id"]
        syn_event = events[4]
        assert derived_from(syn_event["triples"], syn_event["explain_id"]) == foc_uri

    @pytest.mark.asyncio
    async def test_empty_docs_skips_reranker(self):
        """If retrieval returns no chunks, the reranker is never called."""
        clients = build_mock_clients()
        doc_embeddings_client = clients[2]
        doc_embeddings_client.query.return_value = []  # no matches

        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?")

        assert reranker.calls == []
|
||||||
|
|
@ -0,0 +1,89 @@
|
||||||
|
"""
|
||||||
|
Cross-layer wiring contract for the Document-RAG reranker (issue #878).
|
||||||
|
|
||||||
|
The Document-RAG processor registers a ``RerankerClientSpec`` for the
|
||||||
|
``reranker-request`` / ``reranker-response`` roles (see
|
||||||
|
``retrieval/document_rag/rag.py``). At flow construction every spec runs
|
||||||
|
``spec.add(flow, processor, definition)``, and ``RequestResponseSpec.add``
|
||||||
|
resolves its topics via ``definition["topics"][name]`` - which raises
|
||||||
|
``KeyError`` if the flow blueprint does not provide those topics.
|
||||||
|
|
||||||
|
This means the monorepo code change is only safe to deploy together with the
|
||||||
|
companion ``trustgraph-templates`` change that wires ``reranker-request`` /
|
||||||
|
``reranker-response`` into the Document-RAG flow (mirroring what templates
|
||||||
|
PR #279 did for GraphRAG via ``graph-store.jsonnet``). These tests pin that
|
||||||
|
contract from the monorepo side:
|
||||||
|
|
||||||
|
* with the reranker topics present (as the updated templates compile them),
|
||||||
|
the spec binds cleanly and registers the client;
|
||||||
|
* without them (the pre-companion blueprint), construction fails fast with a
|
||||||
|
KeyError naming the missing role - documenting exactly why the templates
|
||||||
|
change is required.
|
||||||
|
|
||||||
|
No broker/network: the pub/sub backend is mocked (topics are bound at add()
|
||||||
|
time, connections happen later at start()).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from trustgraph.base import RerankerClientSpec
|
||||||
|
|
||||||
|
|
||||||
|
def _flow():
|
||||||
|
f = MagicMock()
|
||||||
|
f.workspace = "ws"
|
||||||
|
f.name = "document-rag"
|
||||||
|
f.id = "proc1"
|
||||||
|
f.consumer = {}
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
def _processor():
|
||||||
|
p = MagicMock()
|
||||||
|
p.pubsub = MagicMock()
|
||||||
|
p.id = "proc1"
|
||||||
|
p.taskgroup = MagicMock()
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def _spec():
    """Return a RerankerClientSpec bound to the reranker request/response roles."""
    return RerankerClientSpec(
        request_name="reranker-request",
        response_name="reranker-response",
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Topics dict as the UPDATED document-store.jsonnet compiles them
# (verified by compiling the template: reranker-request -> request:tg:reranker:{workspace}:{id}).
DEFINITION_WITH_RERANKER = {
    "topics": {
        "request": "request:tg:document-rag:ws:id",
        "response": "response:tg:document-rag:ws:id",
        "reranker-request": "request:tg:reranker:ws:id",
        "reranker-response": "response:tg:reranker:ws:id",
    }
}

# Pre-companion blueprint: no reranker topics (document-rag before the templates change).
DEFINITION_WITHOUT_RERANKER = {
    "topics": {
        "request": "request:tg:document-rag:ws:id",
        "response": "response:tg:document-rag:ws:id",
    }
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_reranker_client_binds_when_flow_provides_topics():
    """With the reranker topics present, add() registers the client consumer."""
    flow = _flow()
    _spec().add(flow, _processor(), DEFINITION_WITH_RERANKER)
    # The client consumer is registered against the reranker role.
    assert "reranker-request" in flow.consumer
|
||||||
|
|
||||||
|
|
||||||
|
def test_reranker_client_keyerrors_without_companion_template_topics():
    """Without the reranker topics, add() raises KeyError naming the role."""
    with pytest.raises(KeyError) as exc:
        _spec().add(_flow(), _processor(), DEFINITION_WITHOUT_RERANKER)
    # Fails fast naming the missing role -> the trustgraph-templates companion
    # change (wire reranker-request/response into the document-rag flow) is required.
    # This check sits outside the ``with`` block: statements placed after the
    # raising call inside it would never run.
    assert "reranker-request" in str(exc.value)
|
||||||
|
|
@ -66,6 +66,7 @@ class TestDocumentRagService:
|
||||||
workspace=ANY, # Workspace comes from flow.workspace (mock)
|
workspace=ANY, # Workspace comes from flow.workspace (mock)
|
||||||
collection="test_coll_1", # Must be from message, not hardcoded default
|
collection="test_coll_1", # Must be from message, not hardcoded default
|
||||||
doc_limit=5,
|
doc_limit=5,
|
||||||
|
fetch_limit=0, # Unset -> core derives the candidate pool
|
||||||
explain_callback=ANY, # Explainability callback is always passed
|
explain_callback=ANY, # Explainability callback is always passed
|
||||||
save_answer_callback=ANY, # Librarian save callback is always passed
|
save_answer_callback=ANY, # Librarian save callback is always passed
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -527,7 +527,8 @@ class AsyncFlowInstance:
|
||||||
return result.get("response", "")
|
return result.get("response", "")
|
||||||
|
|
||||||
async def document_rag(self, query: str, collection: str,
|
async def document_rag(self, query: str, collection: str,
|
||||||
doc_limit: int = 10, **kwargs: Any) -> str:
|
doc_limit: int = 10, fetch_limit: int = 0,
|
||||||
|
**kwargs: Any) -> str:
|
||||||
"""
|
"""
|
||||||
Execute document-based RAG query (non-streaming).
|
Execute document-based RAG query (non-streaming).
|
||||||
|
|
||||||
|
|
@ -541,7 +542,9 @@ class AsyncFlowInstance:
|
||||||
Args:
|
Args:
|
||||||
query: User query text
|
query: User query text
|
||||||
collection: Collection identifier containing documents
|
collection: Collection identifier containing documents
|
||||||
doc_limit: Maximum number of document chunks to retrieve (default: 10)
|
doc_limit: Document chunks selected into the prompt (default: 10)
|
||||||
|
fetch_limit: Candidate chunks fetched from the vector store before
|
||||||
|
reranking (default: 0 = derive from doc_limit)
|
||||||
**kwargs: Additional service-specific parameters
|
**kwargs: Additional service-specific parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -564,6 +567,7 @@ class AsyncFlowInstance:
|
||||||
"query": query,
|
"query": query,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"doc-limit": doc_limit,
|
"doc-limit": doc_limit,
|
||||||
|
"fetch-limit": fetch_limit,
|
||||||
"streaming": False
|
"streaming": False
|
||||||
}
|
}
|
||||||
request_data.update(kwargs)
|
request_data.update(kwargs)
|
||||||
|
|
|
||||||
|
|
@ -379,12 +379,14 @@ class AsyncSocketFlowInstance:
|
||||||
yield chunk.content
|
yield chunk.content
|
||||||
|
|
||||||
async def document_rag(self, query: str, collection: str,
|
async def document_rag(self, query: str, collection: str,
|
||||||
doc_limit: int = 10, streaming: bool = False, **kwargs):
|
doc_limit: int = 10, fetch_limit: int = 0,
|
||||||
|
streaming: bool = False, **kwargs):
|
||||||
"""Document RAG with optional streaming"""
|
"""Document RAG with optional streaming"""
|
||||||
request = {
|
request = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"doc-limit": doc_limit,
|
"doc-limit": doc_limit,
|
||||||
|
"fetch-limit": fetch_limit,
|
||||||
"streaming": streaming
|
"streaming": streaming
|
||||||
}
|
}
|
||||||
request.update(kwargs)
|
request.update(kwargs)
|
||||||
|
|
|
||||||
|
|
@ -415,7 +415,7 @@ class FlowInstance:
|
||||||
|
|
||||||
def document_rag(
|
def document_rag(
|
||||||
self, query,collection="default",
|
self, query,collection="default",
|
||||||
doc_limit=10,
|
doc_limit=10, fetch_limit=0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Execute document-based Retrieval-Augmented Generation (RAG) query.
|
Execute document-based Retrieval-Augmented Generation (RAG) query.
|
||||||
|
|
@ -426,7 +426,9 @@ class FlowInstance:
|
||||||
Args:
|
Args:
|
||||||
query: Natural language query
|
query: Natural language query
|
||||||
collection: Collection identifier (default: "default")
|
collection: Collection identifier (default: "default")
|
||||||
doc_limit: Maximum document chunks to retrieve (default: 10)
|
doc_limit: Document chunks selected into the prompt (default: 10)
|
||||||
|
fetch_limit: Candidate chunks fetched from the vector store before
|
||||||
|
reranking (default: 0 = derive from doc_limit)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Generated response incorporating document context
|
str: Generated response incorporating document context
|
||||||
|
|
@ -447,6 +449,7 @@ class FlowInstance:
|
||||||
"query": query,
|
"query": query,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"doc-limit": doc_limit,
|
"doc-limit": doc_limit,
|
||||||
|
"fetch-limit": fetch_limit,
|
||||||
}
|
}
|
||||||
|
|
||||||
result = self.request(
|
result = self.request(
|
||||||
|
|
|
||||||
|
|
@ -752,6 +752,7 @@ class SocketFlowInstance:
|
||||||
query: str,
|
query: str,
|
||||||
collection: str,
|
collection: str,
|
||||||
doc_limit: int = 10,
|
doc_limit: int = 10,
|
||||||
|
fetch_limit: int = 0,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||||
|
|
@ -764,6 +765,7 @@ class SocketFlowInstance:
|
||||||
"query": query,
|
"query": query,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"doc-limit": doc_limit,
|
"doc-limit": doc_limit,
|
||||||
|
"fetch-limit": fetch_limit,
|
||||||
"streaming": streaming
|
"streaming": streaming
|
||||||
}
|
}
|
||||||
request.update(kwargs)
|
request.update(kwargs)
|
||||||
|
|
@ -785,6 +787,7 @@ class SocketFlowInstance:
|
||||||
query: str,
|
query: str,
|
||||||
collection: str,
|
collection: str,
|
||||||
doc_limit: int = 10,
|
doc_limit: int = 10,
|
||||||
|
fetch_limit: int = 0,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
||||||
"""Execute document-based RAG query with explainability support."""
|
"""Execute document-based RAG query with explainability support."""
|
||||||
|
|
@ -792,6 +795,7 @@ class SocketFlowInstance:
|
||||||
"query": query,
|
"query": query,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"doc-limit": doc_limit,
|
"doc-limit": doc_limit,
|
||||||
|
"fetch-limit": fetch_limit,
|
||||||
"streaming": True,
|
"streaming": True,
|
||||||
"explainable": True,
|
"explainable": True,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
||||||
query=data["query"],
|
query=data["query"],
|
||||||
collection=data.get("collection", "default"),
|
collection=data.get("collection", "default"),
|
||||||
doc_limit=int(data.get("doc-limit", 20)),
|
doc_limit=int(data.get("doc-limit", 20)),
|
||||||
|
fetch_limit=int(data.get("fetch-limit", 0)),
|
||||||
streaming=data.get("streaming", False)
|
streaming=data.get("streaming", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
||||||
"query": obj.query,
|
"query": obj.query,
|
||||||
"collection": obj.collection,
|
"collection": obj.collection,
|
||||||
"doc-limit": obj.doc_limit,
|
"doc-limit": obj.doc_limit,
|
||||||
|
"fetch-limit": obj.fetch_limit,
|
||||||
"streaming": getattr(obj, "streaming", False)
|
"streaming": getattr(obj, "streaming", False)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,8 @@ from . uris import (
|
||||||
docrag_question_uri,
|
docrag_question_uri,
|
||||||
docrag_grounding_uri,
|
docrag_grounding_uri,
|
||||||
docrag_exploration_uri,
|
docrag_exploration_uri,
|
||||||
|
docrag_focus_uri,
|
||||||
|
chunk_selection_uri,
|
||||||
docrag_synthesis_uri,
|
docrag_synthesis_uri,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -94,6 +96,8 @@ from . namespaces import (
|
||||||
TG_EDGE_SELECTION,
|
TG_EDGE_SELECTION,
|
||||||
# Query-time provenance predicates (DocumentRAG)
|
# Query-time provenance predicates (DocumentRAG)
|
||||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||||
|
# Chunk selection entity type
|
||||||
|
TG_CHUNK_SELECTION,
|
||||||
# Explainability entity types
|
# Explainability entity types
|
||||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||||
TG_ANALYSIS, TG_CONCLUSION,
|
TG_ANALYSIS, TG_CONCLUSION,
|
||||||
|
|
@ -132,6 +136,7 @@ from . triples import (
|
||||||
# Query-time provenance triple builders (DocumentRAG)
|
# Query-time provenance triple builders (DocumentRAG)
|
||||||
docrag_question_triples,
|
docrag_question_triples,
|
||||||
docrag_exploration_triples,
|
docrag_exploration_triples,
|
||||||
|
docrag_chunk_selection_triples,
|
||||||
docrag_synthesis_triples,
|
docrag_synthesis_triples,
|
||||||
# Utility
|
# Utility
|
||||||
set_graph,
|
set_graph,
|
||||||
|
|
@ -196,6 +201,8 @@ __all__ = [
|
||||||
"docrag_question_uri",
|
"docrag_question_uri",
|
||||||
"docrag_grounding_uri",
|
"docrag_grounding_uri",
|
||||||
"docrag_exploration_uri",
|
"docrag_exploration_uri",
|
||||||
|
"docrag_focus_uri",
|
||||||
|
"chunk_selection_uri",
|
||||||
"docrag_synthesis_uri",
|
"docrag_synthesis_uri",
|
||||||
# Namespaces
|
# Namespaces
|
||||||
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
|
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
|
||||||
|
|
@ -219,6 +226,8 @@ __all__ = [
|
||||||
"TG_EDGE_SELECTION",
|
"TG_EDGE_SELECTION",
|
||||||
# Query-time provenance predicates (DocumentRAG)
|
# Query-time provenance predicates (DocumentRAG)
|
||||||
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
|
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
|
||||||
|
# Chunk selection entity type
|
||||||
|
"TG_CHUNK_SELECTION",
|
||||||
# Explainability entity types
|
# Explainability entity types
|
||||||
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
|
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
|
||||||
"TG_ANALYSIS", "TG_CONCLUSION",
|
"TG_ANALYSIS", "TG_CONCLUSION",
|
||||||
|
|
@ -254,6 +263,7 @@ __all__ = [
|
||||||
# Query-time provenance triple builders (DocumentRAG)
|
# Query-time provenance triple builders (DocumentRAG)
|
||||||
"docrag_question_triples",
|
"docrag_question_triples",
|
||||||
"docrag_exploration_triples",
|
"docrag_exploration_triples",
|
||||||
|
"docrag_chunk_selection_triples",
|
||||||
"docrag_synthesis_triples",
|
"docrag_synthesis_triples",
|
||||||
# Agent provenance triple builders
|
# Agent provenance triple builders
|
||||||
"agent_session_triples",
|
"agent_session_triples",
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,9 @@ TG_EDGE_SELECTION = TG + "EdgeSelection"
|
||||||
TG_CHUNK_COUNT = TG + "chunkCount"
|
TG_CHUNK_COUNT = TG + "chunkCount"
|
||||||
TG_SELECTED_CHUNK = TG + "selectedChunk"
|
TG_SELECTED_CHUNK = TG + "selectedChunk"
|
||||||
|
|
||||||
|
# Chunk selection entity type (cross-encoder reranked chunk in Focus)
|
||||||
|
TG_CHUNK_SELECTION = TG + "ChunkSelection"
|
||||||
|
|
||||||
# Extraction provenance entity types
|
# Extraction provenance entity types
|
||||||
TG_DOCUMENT_TYPE = TG + "Document"
|
TG_DOCUMENT_TYPE = TG + "Document"
|
||||||
TG_PAGE_TYPE = TG + "Page"
|
TG_PAGE_TYPE = TG + "Page"
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,8 @@ from . namespaces import (
|
||||||
TG_EDGE_SELECTION,
|
TG_EDGE_SELECTION,
|
||||||
# Query-time provenance predicates (DocumentRAG)
|
# Query-time provenance predicates (DocumentRAG)
|
||||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||||
|
# Chunk selection entity type
|
||||||
|
TG_CHUNK_SELECTION,
|
||||||
# Explainability entity types
|
# Explainability entity types
|
||||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||||
# Unifying types
|
# Unifying types
|
||||||
|
|
@ -40,7 +42,10 @@ from . namespaces import (
|
||||||
TG_IN_TOKEN, TG_OUT_TOKEN,
|
TG_IN_TOKEN, TG_OUT_TOKEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri
|
from . uris import (
|
||||||
|
activity_uri, agent_uri, subgraph_uri, edge_selection_uri,
|
||||||
|
chunk_selection_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_graph(triples: List[Triple], graph: str) -> List[Triple]:
|
def set_graph(triples: List[Triple], graph: str) -> List[Triple]:
|
||||||
|
|
@ -718,6 +723,75 @@ def docrag_exploration_triples(
|
||||||
return triples
|
return triples
|
||||||
|
|
||||||
|
|
||||||
|
def docrag_chunk_selection_triples(
        focus_uri: str,
        exploration_uri: str,
        selected_chunks_with_scores: List[dict],
        session_id: str,
) -> List[Triple]:
    """
    Build triples for a document RAG focus entity (chunks selected by the
    cross-encoder reranker).

    Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity
    derived from exploration, with one ChunkSelection sub-entity per surviving
    chunk carrying the chunk reference and the reranker score.

    Structure:
        <focus> a tg:Focus ; prov:wasDerivedFrom <exploration> .
        <focus> tg:selectedChunk <chunk_sel_0> .
        <chunk_sel_0> a tg:ChunkSelection .
        <chunk_sel_0> tg:document <chunk_id> .
        <chunk_sel_0> tg:score "0.97" .

    Args:
        focus_uri: URI of the focus entity (from docrag_focus_uri)
        exploration_uri: URI of the parent exploration entity
        selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score'
        session_id: Session UUID for generating chunk selection URIs

    Returns:
        List of Triple objects
    """
    # The focus entity itself: typed, labelled and linked to its parent stage.
    triples = [
        _triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)),
        _triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)),
        _triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")),
        _triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
    ]

    # The index comes from the position in the input list, so an entry that
    # is skipped for lacking a chunk_id still consumes its index and the
    # remaining selection URIs keep their original positions.
    for idx, chunk_info in enumerate(selected_chunks_with_scores):
        chunk_id = chunk_info.get("chunk_id")
        if not chunk_id:
            continue

        chunk_sel_uri = chunk_selection_uri(session_id, idx)

        triples.extend([
            # Link focus to chunk selection entity
            _triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri)),
            # Type the chunk selection entity
            _triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION)),
            # Reference the actual chunk (in librarian)
            _triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id)),
        ])

        # Cross-encoder score, stored as a string literal. A score of 0 is
        # still recorded; only a missing score (None) is omitted.
        score = chunk_info.get("score")
        if score is not None:
            triples.append(
                _triple(chunk_sel_uri, TG_SCORE, _literal(str(score)))
            )

    return triples
|
||||||
|
|
||||||
|
|
||||||
def docrag_synthesis_triples(
|
def docrag_synthesis_triples(
|
||||||
synthesis_uri: str,
|
synthesis_uri: str,
|
||||||
exploration_uri: str,
|
exploration_uri: str,
|
||||||
|
|
|
||||||
|
|
@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str:
|
||||||
return f"urn:trustgraph:docrag:{session_id}/exploration"
|
return f"urn:trustgraph:docrag:{session_id}/exploration"
|
||||||
|
|
||||||
|
|
||||||
|
def docrag_focus_uri(session_id: str) -> str:
    """
    Generate URI for a document RAG focus entity (chunks selected by the
    cross-encoder reranker).

    Args:
        session_id: The session UUID.

    Returns:
        URN in format: urn:trustgraph:docrag:{uuid}/focus
    """
    return f"urn:trustgraph:docrag:{session_id}/focus"
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
    """
    Generate URI for a chunk selection item (links a reranked chunk to its
    score). Mirrors edge_selection_uri for GraphRAG.

    Args:
        session_id: The session UUID.
        chunk_index: Index of this chunk in the selection (0-based).

    Returns:
        URN in format: urn:trustgraph:prov:chunk:{uuid}:{index}
    """
    return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"
|
||||||
|
|
||||||
|
|
||||||
def docrag_synthesis_uri(session_id: str) -> str:
|
def docrag_synthesis_uri(session_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Generate URI for a document RAG synthesis entity (final answer).
|
Generate URI for a document RAG synthesis entity (final answer).
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ from . namespaces import (
|
||||||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||||
TG_EDGE_SELECTION, TG_SCORE,
|
TG_EDGE_SELECTION, TG_SCORE,
|
||||||
|
TG_CHUNK_SELECTION,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -95,6 +96,7 @@ TG_CLASS_LABELS = [
|
||||||
_label_triple(TG_PLAN_TYPE, "Plan"),
|
_label_triple(TG_PLAN_TYPE, "Plan"),
|
||||||
_label_triple(TG_STEP_RESULT, "Step Result"),
|
_label_triple(TG_STEP_RESULT, "Step Result"),
|
||||||
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
|
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
|
||||||
|
_label_triple(TG_CHUNK_SELECTION, "Chunk Selection"),
|
||||||
]
|
]
|
||||||
|
|
||||||
# TrustGraph predicate labels
|
# TrustGraph predicate labels
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,10 @@ class GraphRagResponse:
|
||||||
class DocumentRagQuery:
|
class DocumentRagQuery:
|
||||||
query: str = ""
|
query: str = ""
|
||||||
collection: str = ""
|
collection: str = ""
|
||||||
doc_limit: int = 0
|
doc_limit: int = 0 # docs selected into the synthesis prompt
|
||||||
|
fetch_limit: int = 0 # candidate pool fetched from the vector store
|
||||||
|
# before reranking (0 = derive from doc_limit;
|
||||||
|
# values below doc_limit are raised to it)
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -21,10 +21,12 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||||
default_collection = 'default'
|
default_collection = 'default'
|
||||||
default_doc_limit = 10
|
default_doc_limit = 10
|
||||||
|
default_fetch_limit = 0
|
||||||
|
|
||||||
|
|
||||||
def question_explainable(
|
def question_explainable(
|
||||||
url, flow_id, question_text, collection, doc_limit, token=None, debug=False,
|
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
|
||||||
|
token=None, debug=False,
|
||||||
workspace="default",
|
workspace="default",
|
||||||
):
|
):
|
||||||
"""Execute document RAG with explainability - shows provenance events inline."""
|
"""Execute document RAG with explainability - shows provenance events inline."""
|
||||||
|
|
@ -39,6 +41,7 @@ def question_explainable(
|
||||||
query=question_text,
|
query=question_text,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
doc_limit=doc_limit,
|
doc_limit=doc_limit,
|
||||||
|
fetch_limit=fetch_limit,
|
||||||
):
|
):
|
||||||
if isinstance(item, RAGChunk):
|
if isinstance(item, RAGChunk):
|
||||||
# Print response content
|
# Print response content
|
||||||
|
|
@ -97,7 +100,7 @@ def question_explainable(
|
||||||
|
|
||||||
|
|
||||||
def question(
|
def question(
|
||||||
url, flow_id, question_text, collection, doc_limit,
|
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
|
||||||
streaming=True, token=None, explainable=False, debug=False,
|
streaming=True, token=None, explainable=False, debug=False,
|
||||||
show_usage=False, workspace="default",
|
show_usage=False, workspace="default",
|
||||||
):
|
):
|
||||||
|
|
@ -109,6 +112,7 @@ def question(
|
||||||
question_text=question_text,
|
question_text=question_text,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
doc_limit=doc_limit,
|
doc_limit=doc_limit,
|
||||||
|
fetch_limit=fetch_limit,
|
||||||
token=token,
|
token=token,
|
||||||
debug=debug,
|
debug=debug,
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
|
|
@ -128,6 +132,7 @@ def question(
|
||||||
query=question_text,
|
query=question_text,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
doc_limit=doc_limit,
|
doc_limit=doc_limit,
|
||||||
|
fetch_limit=fetch_limit,
|
||||||
streaming=True
|
streaming=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -155,6 +160,7 @@ def question(
|
||||||
query=question_text,
|
query=question_text,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
doc_limit=doc_limit,
|
doc_limit=doc_limit,
|
||||||
|
fetch_limit=fetch_limit,
|
||||||
)
|
)
|
||||||
print(result.text)
|
print(result.text)
|
||||||
|
|
||||||
|
|
@ -214,7 +220,15 @@ def main():
|
||||||
'-d', '--doc-limit',
|
'-d', '--doc-limit',
|
||||||
type=int,
|
type=int,
|
||||||
default=default_doc_limit,
|
default=default_doc_limit,
|
||||||
help=f'Document limit (default: {default_doc_limit})'
|
help=f'Documents selected into the prompt (default: {default_doc_limit})'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--fetch-limit',
|
||||||
|
type=int,
|
||||||
|
default=default_fetch_limit,
|
||||||
|
help='Candidate documents fetched from the vector store before '
|
||||||
|
'reranking (default: derive from doc-limit)'
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -251,6 +265,7 @@ def main():
|
||||||
question_text=args.question,
|
question_text=args.question,
|
||||||
collection=args.collection,
|
collection=args.collection,
|
||||||
doc_limit=args.doc_limit,
|
doc_limit=args.doc_limit,
|
||||||
|
fetch_limit=args.fetch_limit,
|
||||||
streaming=not args.no_streaming,
|
streaming=not args.no_streaming,
|
||||||
token=args.token,
|
token=args.token,
|
||||||
explainable=args.explainable,
|
explainable=args.explainable,
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,12 @@ from trustgraph.provenance import (
|
||||||
docrag_question_uri,
|
docrag_question_uri,
|
||||||
docrag_grounding_uri,
|
docrag_grounding_uri,
|
||||||
docrag_exploration_uri,
|
docrag_exploration_uri,
|
||||||
|
docrag_focus_uri,
|
||||||
docrag_synthesis_uri,
|
docrag_synthesis_uri,
|
||||||
docrag_question_triples,
|
docrag_question_triples,
|
||||||
grounding_triples,
|
grounding_triples,
|
||||||
docrag_exploration_triples,
|
docrag_exploration_triples,
|
||||||
|
docrag_chunk_selection_triples,
|
||||||
docrag_synthesis_triples,
|
docrag_synthesis_triples,
|
||||||
set_graph,
|
set_graph,
|
||||||
GRAPH_RETRIEVAL,
|
GRAPH_RETRIEVAL,
|
||||||
|
|
@ -21,19 +23,25 @@ from trustgraph.provenance import (
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# When the caller does not specify a fetch_limit, reranking over-fetches this
|
||||||
|
# many times the final doc_limit as the candidate pool, so the cross-encoder can
|
||||||
|
# recover relevant chunks the bi-encoder ranked just outside the top doc_limit.
|
||||||
|
# This is only the fallback default: an explicit fetch_limit overrides it.
|
||||||
|
OVERFETCH_FACTOR = 3
|
||||||
|
|
||||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||||
|
|
||||||
class Query:
|
class Query:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, rag, workspace, collection, verbose,
|
self, rag, workspace, collection, verbose,
|
||||||
doc_limit=20, track_usage=None,
|
fetch_limit=20, track_usage=None,
|
||||||
):
|
):
|
||||||
self.rag = rag
|
self.rag = rag
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.collection = collection
|
self.collection = collection
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.doc_limit = doc_limit
|
self.fetch_limit = fetch_limit
|
||||||
self.track_usage = track_usage
|
self.track_usage = track_usage
|
||||||
|
|
||||||
async def extract_concepts(self, query):
|
async def extract_concepts(self, query):
|
||||||
|
|
@ -91,7 +99,7 @@ class Query:
|
||||||
|
|
||||||
# Query chunk matches for each concept concurrently
|
# Query chunk matches for each concept concurrently
|
||||||
per_concept_limit = max(
|
per_concept_limit = max(
|
||||||
1, self.doc_limit // len(vectors)
|
1, self.fetch_limit // len(vectors)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def query_concept(vec):
|
async def query_concept(vec):
|
||||||
|
|
@ -140,6 +148,7 @@ class DocumentRag:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, prompt_client, embeddings_client, doc_embeddings_client,
|
self, prompt_client, embeddings_client, doc_embeddings_client,
|
||||||
fetch_chunk,
|
fetch_chunk,
|
||||||
|
reranker_client=None,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
@ -150,12 +159,16 @@ class DocumentRag:
|
||||||
self.doc_embeddings_client = doc_embeddings_client
|
self.doc_embeddings_client = doc_embeddings_client
|
||||||
self.fetch_chunk = fetch_chunk
|
self.fetch_chunk = fetch_chunk
|
||||||
|
|
||||||
|
# Optional cross-encoder reranker. When None, the retrieval path is
|
||||||
|
# byte-identical to the pre-reranker behaviour.
|
||||||
|
self.reranker_client = reranker_client
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("DocumentRag initialized")
|
logger.debug("DocumentRag initialized")
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self, query, workspace="default", collection="default",
|
self, query, workspace="default", collection="default",
|
||||||
doc_limit=20, streaming=False, chunk_callback=None,
|
doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None,
|
||||||
explain_callback=None, save_answer_callback=None,
|
explain_callback=None, save_answer_callback=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -165,7 +178,10 @@ class DocumentRag:
|
||||||
query: The query string
|
query: The query string
|
||||||
workspace: Workspace for isolation (also scopes chunk lookup)
|
workspace: Workspace for isolation (also scopes chunk lookup)
|
||||||
collection: Collection identifier
|
collection: Collection identifier
|
||||||
doc_limit: Max chunks to retrieve
|
doc_limit: Chunks selected into the synthesis prompt (after rerank)
|
||||||
|
fetch_limit: Candidate pool fetched from the vector store before
|
||||||
|
reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a
|
||||||
|
reranker is wired, else doc_limit).
|
||||||
streaming: Enable streaming LLM response
|
streaming: Enable streaming LLM response
|
||||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||||
explain_callback: async def callback(triples, explain_id) for explainability
|
explain_callback: async def callback(triples, explain_id) for explainability
|
||||||
|
|
@ -197,6 +213,7 @@ class DocumentRag:
|
||||||
q_uri = docrag_question_uri(session_id)
|
q_uri = docrag_question_uri(session_id)
|
||||||
gnd_uri = docrag_grounding_uri(session_id)
|
gnd_uri = docrag_grounding_uri(session_id)
|
||||||
exp_uri = docrag_exploration_uri(session_id)
|
exp_uri = docrag_exploration_uri(session_id)
|
||||||
|
foc_uri = docrag_focus_uri(session_id)
|
||||||
syn_uri = docrag_synthesis_uri(session_id)
|
syn_uri = docrag_synthesis_uri(session_id)
|
||||||
|
|
||||||
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||||
|
|
@ -209,10 +226,21 @@ class DocumentRag:
|
||||||
)
|
)
|
||||||
await explain_callback(q_triples, q_uri)
|
await explain_callback(q_triples, q_uri)
|
||||||
|
|
||||||
|
# Resolve the candidate-pool size fetched from the vector store. When a
|
||||||
|
# reranker is wired, honour an explicit fetch_limit; if unset, fall back
|
||||||
|
# to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit,
|
||||||
|
# else the rerank could not fill the prompt. Without a reranker, fetch
|
||||||
|
# doc_limit as before (byte-identical behaviour).
|
||||||
|
if self.reranker_client is not None:
|
||||||
|
fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit)
|
||||||
|
fetch_count = max(fl, doc_limit)
|
||||||
|
else:
|
||||||
|
fetch_count = doc_limit
|
||||||
|
|
||||||
q = Query(
|
q = Query(
|
||||||
rag=self, workspace=workspace, collection=collection,
|
rag=self, workspace=workspace, collection=collection,
|
||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
doc_limit=doc_limit, track_usage=track_usage,
|
fetch_limit=fetch_count, track_usage=track_usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract concepts from query (grounding step)
|
# Extract concepts from query (grounding step)
|
||||||
|
|
@ -235,6 +263,7 @@ class DocumentRag:
|
||||||
docs, chunk_ids = await q.get_docs(concepts)
|
docs, chunk_ids = await q.get_docs(concepts)
|
||||||
|
|
||||||
# Emit exploration explainability after chunks retrieved
|
# Emit exploration explainability after chunks retrieved
|
||||||
|
# (full candidate set, before any reranking)
|
||||||
if explain_callback:
|
if explain_callback:
|
||||||
exp_triples = set_graph(
|
exp_triples = set_graph(
|
||||||
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
|
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
|
||||||
|
|
@ -242,6 +271,45 @@ class DocumentRag:
|
||||||
)
|
)
|
||||||
await explain_callback(exp_triples, exp_uri)
|
await explain_callback(exp_triples, exp_uri)
|
||||||
|
|
||||||
|
# Optional cross-encoder reranking pass between retrieval and
|
||||||
|
# synthesis. Mirrors GraphRAG's reranker usage but with a single
|
||||||
|
# query (the question). When no reranker is wired, this block is
|
||||||
|
# skipped entirely and behaviour is byte-identical to before.
|
||||||
|
reranked = False
|
||||||
|
if self.reranker_client is not None and docs:
|
||||||
|
results = await self.reranker_client.rerank(
|
||||||
|
queries=[{"id": "0", "text": query}],
|
||||||
|
documents=[
|
||||||
|
{"id": str(i), "text": d} for i, d in enumerate(docs)
|
||||||
|
],
|
||||||
|
# Narrow the over-fetched candidate pool down to the final
|
||||||
|
# doc_limit requested for synthesis.
|
||||||
|
limit=doc_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
# results are sorted desc by score and truncated to limit by the
|
||||||
|
# reranker service, so order gives the surviving top-N directly.
|
||||||
|
order = [int(r.document_id) for r in results]
|
||||||
|
docs = [docs[i] for i in order]
|
||||||
|
chunk_ids = [chunk_ids[i] for i in order]
|
||||||
|
reranked = True
|
||||||
|
|
||||||
|
# Emit chunk-selection (focus) explainability: surviving chunks
|
||||||
|
# with their cross-encoder scores, derived from exploration.
|
||||||
|
if explain_callback:
|
||||||
|
selected_chunks_with_scores = [
|
||||||
|
{"chunk_id": chunk_ids[i], "score": r.score}
|
||||||
|
for i, r in enumerate(results)
|
||||||
|
]
|
||||||
|
foc_triples = set_graph(
|
||||||
|
docrag_chunk_selection_triples(
|
||||||
|
foc_uri, exp_uri,
|
||||||
|
selected_chunks_with_scores, session_id,
|
||||||
|
),
|
||||||
|
GRAPH_RETRIEVAL
|
||||||
|
)
|
||||||
|
await explain_callback(foc_triples, foc_uri)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("Invoking LLM...")
|
logger.debug("Invoking LLM...")
|
||||||
logger.debug(f"Documents: {docs}")
|
logger.debug(f"Documents: {docs}")
|
||||||
|
|
@ -291,9 +359,15 @@ class DocumentRag:
|
||||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||||
synthesis_doc_id = None
|
synthesis_doc_id = None
|
||||||
|
|
||||||
|
# When reranking ran, synthesis derives from the focus (the
|
||||||
|
# reranked chunks actually fed to the LLM), as GraphRAG always does.
|
||||||
|
# When no reranker is wired, there is no focus stage, so synthesis
|
||||||
|
# derives from exploration (the unchanged no-op lineage) - a
|
||||||
|
# deliberate divergence from GraphRAG's always-on focus.
|
||||||
|
syn_parent = foc_uri if reranked else exp_uri
|
||||||
syn_triples = set_graph(
|
syn_triples = set_graph(
|
||||||
docrag_synthesis_triples(
|
docrag_synthesis_triples(
|
||||||
syn_uri, exp_uri,
|
syn_uri, syn_parent,
|
||||||
document_id=synthesis_doc_id,
|
document_id=synthesis_doc_id,
|
||||||
in_token=synthesis_result.in_token if synthesis_result else None,
|
in_token=synthesis_result.in_token if synthesis_result else None,
|
||||||
out_token=synthesis_result.out_token if synthesis_result else None,
|
out_token=synthesis_result.out_token if synthesis_result else None,
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from . document_rag import DocumentRag
|
||||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||||
from ... base import DocumentEmbeddingsClientSpec
|
from ... base import DocumentEmbeddingsClientSpec
|
||||||
|
from ... base import RerankerClientSpec
|
||||||
from ... base import LibrarianSpec
|
from ... base import LibrarianSpec
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
|
|
@ -28,14 +29,21 @@ class Processor(FlowProcessor):
|
||||||
|
|
||||||
doc_limit = params.get("doc_limit", 5)
|
doc_limit = params.get("doc_limit", 5)
|
||||||
|
|
||||||
|
# Instance-default candidate-pool size fetched before cross-encoder
|
||||||
|
# reranking; the rerank step narrows it back down to doc_limit for the
|
||||||
|
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
|
||||||
|
fetch_limit = params.get("fetch_limit", 0)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"id": id,
|
"id": id,
|
||||||
"doc_limit": doc_limit,
|
"doc_limit": doc_limit,
|
||||||
|
"fetch_limit": fetch_limit,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.doc_limit = doc_limit
|
self.doc_limit = doc_limit
|
||||||
|
self.fetch_limit = fetch_limit
|
||||||
|
|
||||||
self.register_specification(
|
self.register_specification(
|
||||||
ConsumerSpec(
|
ConsumerSpec(
|
||||||
|
|
@ -66,6 +74,13 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.register_specification(
|
||||||
|
RerankerClientSpec(
|
||||||
|
request_name = "reranker-request",
|
||||||
|
response_name = "reranker-response",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.register_specification(
|
self.register_specification(
|
||||||
ProducerSpec(
|
ProducerSpec(
|
||||||
name = "response",
|
name = "response",
|
||||||
|
|
@ -105,6 +120,7 @@ class Processor(FlowProcessor):
|
||||||
doc_embeddings_client = flow("document-embeddings-request"),
|
doc_embeddings_client = flow("document-embeddings-request"),
|
||||||
prompt_client = flow("prompt-request"),
|
prompt_client = flow("prompt-request"),
|
||||||
fetch_chunk = fetch_chunk,
|
fetch_chunk = fetch_chunk,
|
||||||
|
reranker_client = flow("reranker-request"),
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -113,6 +129,13 @@ class Processor(FlowProcessor):
|
||||||
else:
|
else:
|
||||||
doc_limit = self.doc_limit
|
doc_limit = self.doc_limit
|
||||||
|
|
||||||
|
# Candidate-pool size: per-request override, else the instance
|
||||||
|
# default; 0 lets the core derive it from doc_limit.
|
||||||
|
if v.fetch_limit:
|
||||||
|
fetch_limit = v.fetch_limit
|
||||||
|
else:
|
||||||
|
fetch_limit = self.fetch_limit
|
||||||
|
|
||||||
async def send_explainability(triples, explain_id):
|
async def send_explainability(triples, explain_id):
|
||||||
await flow("explainability").send(Triples(
|
await flow("explainability").send(Triples(
|
||||||
metadata=Metadata(
|
metadata=Metadata(
|
||||||
|
|
@ -163,6 +186,7 @@ class Processor(FlowProcessor):
|
||||||
workspace=flow.workspace,
|
workspace=flow.workspace,
|
||||||
collection=v.collection,
|
collection=v.collection,
|
||||||
doc_limit=doc_limit,
|
doc_limit=doc_limit,
|
||||||
|
fetch_limit=fetch_limit,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
chunk_callback=send_chunk,
|
chunk_callback=send_chunk,
|
||||||
explain_callback=send_explainability,
|
explain_callback=send_explainability,
|
||||||
|
|
@ -188,6 +212,7 @@ class Processor(FlowProcessor):
|
||||||
workspace=flow.workspace,
|
workspace=flow.workspace,
|
||||||
collection=v.collection,
|
collection=v.collection,
|
||||||
doc_limit=doc_limit,
|
doc_limit=doc_limit,
|
||||||
|
fetch_limit=fetch_limit,
|
||||||
explain_callback=send_explainability,
|
explain_callback=send_explainability,
|
||||||
save_answer_callback=save_answer,
|
save_answer_callback=save_answer,
|
||||||
)
|
)
|
||||||
|
|
@ -243,6 +268,15 @@ class Processor(FlowProcessor):
|
||||||
help=f'Default document fetch limit (default: 10)'
|
help=f'Default document fetch limit (default: 10)'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--fetch-limit',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='Candidate chunks to fetch from the vector store and rerank '
|
||||||
|
'before keeping the top doc-limit for the LLM '
|
||||||
|
'(default: derive from doc-limit)'
|
||||||
|
)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
||||||
Processor.launch(default_ident, __doc__)
|
Processor.launch(default_ident, __doc__)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue