diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 7762b543..a08bc718 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -101,27 +101,27 @@ class TestQuery: assert query.rag == mock_rag assert query.collection == "test_collection" assert query.verbose is False - assert query.doc_limit == 20 # Default value + assert query.fetch_limit == 20 # Default value - def test_query_initialization_with_custom_doc_limit(self): - """Test Query initialization with custom doc_limit""" + def test_query_initialization_with_custom_fetch_limit(self): + """Test Query initialization with custom fetch_limit""" # Create mock DocumentRag mock_rag = MagicMock() - # Initialize Query with custom doc_limit + # Initialize Query with custom fetch_limit query = Query( rag=mock_rag, workspace="test_workspace", collection="custom_collection", verbose=True, - doc_limit=50 + fetch_limit=50 ) # Verify initialization assert query.rag == mock_rag assert query.collection == "custom_collection" assert query.verbose is True - assert query.doc_limit == 50 + assert query.fetch_limit == 50 @pytest.mark.asyncio async def test_extract_concepts(self): @@ -224,7 +224,7 @@ class TestQuery: workspace="test_workspace", collection="test_collection", verbose=False, - doc_limit=15 + fetch_limit=15 ) # Call get_docs with concepts list @@ -377,7 +377,7 @@ class TestQuery: workspace="test_workspace", collection="test_collection", verbose=True, - doc_limit=5 + fetch_limit=5 ) # Call get_docs with concepts @@ -615,7 +615,7 @@ class TestQuery: workspace="test_workspace", collection="test_collection", verbose=False, - doc_limit=10 + fetch_limit=10 ) docs, chunk_ids = await query.get_docs(["concept A", "concept B"]) diff --git a/tests/unit/test_retrieval/test_document_rag_rerank.py b/tests/unit/test_retrieval/test_document_rag_rerank.py new file mode 100644 index 00000000..d711d57c --- /dev/null +++ 
b/tests/unit/test_retrieval/test_document_rag_rerank.py @@ -0,0 +1,478 @@ +""" +Tests for the optional cross-encoder reranking pass in DocumentRag.query(). + +Two behaviours are covered: + + 1. No-op: when no reranker_client is wired (the default), query() must feed + the LLM the exact same chunks, in the same order, that retrieval produced + - byte-identical to the pre-reranker behaviour - and must NOT emit a + chunk-selection provenance event. + + 2. Rerank: when a reranker_client is wired, the retrieved chunks are reordered + and truncated according to the reranker's results, the LLM receives the + reranked top-N, and a tg:ChunkSelection (focus) provenance event is emitted + carrying the per-surviving-chunk scores and chunk references. + +These are pure orchestration tests - the reranker is a stub, so there is no +torch / network dependency. +""" + +import pytest +from unittest.mock import AsyncMock +from dataclasses import dataclass + +from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.base import PromptResult +from trustgraph.schema import RerankerResult + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_WAS_DERIVED_FROM, + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, + TG_FOCUS, TG_SYNTHESIS, + TG_CHUNK_SELECTION, TG_SELECTED_CHUNK, TG_SCORE, TG_DOCUMENT, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + return [ + t for t in triples + if t.p.iri == predicate + and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in 
triples + ) + + +def derived_from(triples, subject): + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return t.o.iri if t else None + + +@dataclass +class ChunkMatch: + """Mimics the result from doc_embeddings_client.query().""" + chunk_id: str + + +# --------------------------------------------------------------------------- +# Fixtures: three retrievable chunks +# --------------------------------------------------------------------------- + +CHUNK_A = "urn:chunk:policy-doc-1:chunk-0" +CHUNK_B = "urn:chunk:policy-doc-1:chunk-1" +CHUNK_C = "urn:chunk:policy-doc-1:chunk-2" + +CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase." +CHUNK_B_CONTENT = "Our stores are open from 9am to 5pm on weekdays." +CHUNK_C_CONTENT = "Refunds are processed to the original payment method." + +# Retrieval (post-dedupe) order is A, B, C. +ORDERED_CONTENT = [CHUNK_A_CONTENT, CHUNK_B_CONTENT, CHUNK_C_CONTENT] +ORDERED_CHUNK_IDS = [CHUNK_A, CHUNK_B, CHUNK_C] + + +def build_mock_clients(): + """ + Build mock subsidiary clients for a document-rag query returning three + distinct chunks (A, B, C) in that order. + """ + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + doc_embeddings_client = AsyncMock() + fetch_chunk = AsyncMock() + + async def mock_prompt(template_id, variables=None, **kwargs): + if template_id == "extract-concepts": + return PromptResult(response_type="text", text="return policy\nrefund") + return PromptResult(response_type="text", text="") + + prompt_client.prompt.side_effect = mock_prompt + + embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + + # Each concept query returns the same three chunks; dedupe keeps A, B, C. 
+ doc_embeddings_client.query.return_value = [ + ChunkMatch(chunk_id=CHUNK_A), + ChunkMatch(chunk_id=CHUNK_B), + ChunkMatch(chunk_id=CHUNK_C), + ] + + async def mock_fetch(chunk_id): + return { + CHUNK_A: CHUNK_A_CONTENT, + CHUNK_B: CHUNK_B_CONTENT, + CHUNK_C: CHUNK_C_CONTENT, + }[chunk_id] + + fetch_chunk.side_effect = mock_fetch + + prompt_client.document_prompt.return_value = PromptResult( + response_type="text", + text="Items can be returned within 30 days for a full refund.", + ) + + return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk + + +class StubReranker: + """ + Stub reranker_client mirroring RerankerClient.rerank(): returns a fixed, + pre-sorted, truncated list of RerankerResult - exactly the contract the + flashrank service guarantees (sorted desc by score, truncated to limit). + """ + + def __init__(self, results): + self._results = results + self.calls = [] + + async def rerank(self, queries, documents, limit=10, timeout=300): + self.calls.append( + {"queries": queries, "documents": documents, "limit": limit} + ) + return self._results + + +# --------------------------------------------------------------------------- +# 1. No-op: reranker_client=None must not change anything +# --------------------------------------------------------------------------- + +class TestRerankNoOp: + + @pytest.mark.asyncio + async def test_documents_passed_to_llm_are_unchanged(self): + """ + With no reranker wired, document_prompt must receive the retrieved + chunks in the original order and length. 
+ """ + clients = build_mock_clients() + rag = DocumentRag(*clients) # reranker_client defaults to None + + await rag.query(query="What is the return policy?") + + call = rag.prompt_client.document_prompt.call_args + passed_docs = call.kwargs["documents"] + assert passed_docs == ORDERED_CONTENT + + @pytest.mark.asyncio + async def test_no_chunk_selection_event_emitted(self): + """ + Without a reranker, the provenance chain is the original 4 stages: + question, grounding, exploration, synthesis - no focus stage. + """ + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + assert len(events) == 4 + types = [ + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_SYNTHESIS, + ] + for i, expected in enumerate(types): + assert has_type(events[i]["triples"], events[i]["explain_id"], expected) + + # No chunk-selection entity anywhere. + for e in events: + assert not any( + t.o.iri == TG_CHUNK_SELECTION + for t in e["triples"] + if t.p.iri == RDF_TYPE + ) + + @pytest.mark.asyncio + async def test_synthesis_derives_from_exploration_when_no_rerank(self): + """ + No-op lineage is unchanged: synthesis derives from exploration + (there is no focus stage). Guards the conditional synthesis parent. 
+ """ + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + # events: question, grounding, exploration, synthesis + exp_uri = events[2]["explain_id"] + syn_event = events[3] + assert derived_from(syn_event["triples"], syn_event["explain_id"]) == exp_uri + + +# --------------------------------------------------------------------------- +# 2. Rerank: reorder + truncate + provenance +# --------------------------------------------------------------------------- + +class TestRerankActive: + + def _reranker_keeping_C_then_A(self): + # Reranker says chunk index 2 (C) is best, then index 0 (A); B dropped. + # Pre-sorted desc by score and truncated to limit, per the contract. + return StubReranker([ + RerankerResult(document_id="2", query_id="0", score=0.95), + RerankerResult(document_id="0", query_id="0", score=0.42), + ]) + + @pytest.mark.asyncio + async def test_documents_reordered_and_truncated(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query(query="What is the return policy?") + + call = rag.prompt_client.document_prompt.call_args + passed_docs = call.kwargs["documents"] + assert passed_docs == [CHUNK_C_CONTENT, CHUNK_A_CONTENT] + + @pytest.mark.asyncio + async def test_reranker_called_with_single_query_and_all_docs(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query(query="What is the return policy?", doc_limit=2) + + assert len(reranker.calls) == 1 + c = reranker.calls[0] + assert c["queries"] == [{"id": "0", "text": "What is the return policy?"}] + assert c["documents"] == [ + {"id": "0", "text": 
CHUNK_A_CONTENT}, + {"id": "1", "text": CHUNK_B_CONTENT}, + {"id": "2", "text": CHUNK_C_CONTENT}, + ] + # The rerank narrows down to the final doc_limit, NOT fetch_limit + # (fetch_limit is the over-fetched candidate pool size). + assert c["limit"] == 2 + + @pytest.mark.asyncio + async def test_explicit_fetch_limit_over_fetches_then_narrows(self): + """ + Semantic guard for the value of reranking AND the maintainer's two-limit + contract: an explicit fetch_limit makes retrieval OVER-FETCH a wider + candidate pool so the cross-encoder can surface chunks the bi-encoder + ranked outside the final doc_limit, then the rerank narrows the pool back + down to doc_limit. The fetch_limit is honoured directly (caller controls + how hard the reranker works), not overridden by any heuristic. + """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + reranker = self._reranker_keeping_C_then_A() + # Candidate pool (fetch_limit=60) >> final doc_limit (6). + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query( + query="What is the return policy?", doc_limit=6, fetch_limit=60, + ) + + # Over-fetch: the embeddings store is queried with the fetch_limit + # budget (60 // 2 concept-vectors = 30 per concept), NOT the doc_limit + # budget (6 // 2 = 3). This is the bug guard. + q_limit = doc_embeddings_client.query.call_args.kwargs["limit"] + assert q_limit == 30 + + # Narrow: the rerank keeps the final doc_limit (6), not fetch_limit. + assert reranker.calls[0]["limit"] == 6 + + @pytest.mark.asyncio + async def test_default_fetch_limit_derives_overfetch_from_doc_limit(self): + """ + With no fetch_limit passed to query(), the candidate pool falls back to + the OVERFETCH_FACTOR x doc_limit heuristic, so over-fetch scales with + doc_limit and reranking keeps its recall benefit out of the box. 
+ """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + reranker = self._reranker_keeping_C_then_A() + # No fetch_limit -> heuristic default. + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query(query="What is the return policy?", doc_limit=20) + + # fetch = 3 x 20 = 60 -> 60 // 2 concept-vectors = 30 per concept. + q_limit = doc_embeddings_client.query.call_args.kwargs["limit"] + assert q_limit == 30 + # Rerank narrows to the final doc_limit (20). + assert reranker.calls[0]["limit"] == 20 + + @pytest.mark.asyncio + async def test_fetch_limit_floored_at_doc_limit(self): + """ + A fetch_limit below doc_limit is floored up to doc_limit: retrieval must + never fetch fewer candidates than the rerank is asked to keep, else the + prompt could not be filled. + """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query( + query="What is the return policy?", doc_limit=10, fetch_limit=4, + ) + + # fetch = max(4, 10) = 10 -> 10 // 2 concept-vectors = 5 per concept. + q_limit = doc_embeddings_client.query.call_args.kwargs["limit"] + assert q_limit == 5 + assert reranker.calls[0]["limit"] == 10 + + @pytest.mark.asyncio + async def test_chunk_selection_event_emitted(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + # Now 5 stages: question, grounding, exploration, focus, synthesis. 
+ assert len(events) == 5 + ordered_types = [ + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, + TG_FOCUS, TG_SYNTHESIS, + ] + for i, expected in enumerate(ordered_types): + assert has_type(events[i]["triples"], events[i]["explain_id"], expected) + + @pytest.mark.asyncio + async def test_chunk_selection_carries_scores_and_chunk_refs(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + focus_event = events[3] + foc_uri = focus_event["explain_id"] + triples = focus_event["triples"] + + # focus is derived from exploration + exp_uri = events[2]["explain_id"] + assert derived_from(triples, foc_uri) == exp_uri + + # Two ChunkSelection sub-entities, linked from focus. + sel_links = find_triples(triples, TG_SELECTED_CHUNK, foc_uri) + assert len(sel_links) == 2 + + # Each selection has a ChunkSelection type, a chunk document ref and a score. + chunk_refs = set() + scores = set() + for link in sel_links: + sel_uri = link.o.iri + assert has_type(triples, sel_uri, TG_CHUNK_SELECTION) + doc_ref = find_triple(triples, TG_DOCUMENT, sel_uri) + assert doc_ref is not None + chunk_refs.add(doc_ref.o.iri) + score_t = find_triple(triples, TG_SCORE, sel_uri) + assert score_t is not None + scores.add(score_t.o.value) + + # Surviving chunks are C and A (B dropped), with the reranker scores. 
+ assert chunk_refs == {CHUNK_C, CHUNK_A} + assert scores == {"0.95", "0.42"} + + @pytest.mark.asyncio + async def test_all_focus_triples_in_retrieval_graph(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + for t in events[3]["triples"]: + assert t.g == "urn:graph:retrieval" + + @pytest.mark.asyncio + async def test_synthesis_derives_from_focus_when_reranking(self): + """ + When reranking runs, synthesis must derive from the focus node (the + reranked chunks actually fed to the LLM), mirroring GraphRAG - not from + exploration, which would leave focus as a dangling branch and + misrepresent what fed the answer. + """ + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + doc_limit=2, + explain_callback=explain_callback, + ) + + # events: question, grounding, exploration, focus, synthesis + foc_uri = events[3]["explain_id"] + syn_event = events[4] + assert derived_from(syn_event["triples"], syn_event["explain_id"]) == foc_uri + + @pytest.mark.asyncio + async def test_empty_docs_skips_reranker(self): + """If retrieval returns no chunks, the reranker is never called.""" + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + doc_embeddings_client.query.return_value = [] # no matches + + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await 
rag.query(query="What is the return policy?") + + assert reranker.calls == [] diff --git a/tests/unit/test_retrieval/test_document_rag_reranker_wiring.py b/tests/unit/test_retrieval/test_document_rag_reranker_wiring.py new file mode 100644 index 00000000..bf4337b4 --- /dev/null +++ b/tests/unit/test_retrieval/test_document_rag_reranker_wiring.py @@ -0,0 +1,89 @@ +""" +Cross-layer wiring contract for the Document-RAG reranker (issue #878). + +The Document-RAG processor registers a ``RerankerClientSpec`` for the +``reranker-request`` / ``reranker-response`` roles (see +``retrieval/document_rag/rag.py``). At flow construction every spec runs +``spec.add(flow, processor, definition)``, and ``RequestResponseSpec.add`` +resolves its topics via ``definition["topics"][name]`` - which raises +``KeyError`` if the flow blueprint does not provide those topics. + +This means the monorepo code change is only safe to deploy together with the +companion ``trustgraph-templates`` change that wires ``reranker-request`` / +``reranker-response`` into the Document-RAG flow (mirroring what templates +PR #279 did for GraphRAG via ``graph-store.jsonnet``). These tests pin that +contract from the monorepo side: + + * with the reranker topics present (as the updated templates compile them), + the spec binds cleanly and registers the client; + * without them (the pre-companion blueprint), construction fails fast with a + KeyError naming the missing role - documenting exactly why the templates + change is required. + +No broker/network: the pub/sub backend is mocked (topics are bound at add() +time, connections happen later at start()). 
+""" + +import pytest +from unittest.mock import MagicMock + +from trustgraph.base import RerankerClientSpec + + +def _flow(): + f = MagicMock() + f.workspace = "ws" + f.name = "document-rag" + f.id = "proc1" + f.consumer = {} + return f + + +def _processor(): + p = MagicMock() + p.pubsub = MagicMock() + p.id = "proc1" + p.taskgroup = MagicMock() + return p + + +def _spec(): + return RerankerClientSpec( + request_name="reranker-request", + response_name="reranker-response", + ) + + +# Topics dict as the UPDATED document-store.jsonnet compiles them +# (verified by compiling the template: reranker-request -> request:tg:reranker:{workspace}:{id}). +DEFINITION_WITH_RERANKER = { + "topics": { + "request": "request:tg:document-rag:ws:id", + "response": "response:tg:document-rag:ws:id", + "reranker-request": "request:tg:reranker:ws:id", + "reranker-response": "response:tg:reranker:ws:id", + } +} + +# Pre-companion blueprint: no reranker topics (document-rag before the templates change). +DEFINITION_WITHOUT_RERANKER = { + "topics": { + "request": "request:tg:document-rag:ws:id", + "response": "response:tg:document-rag:ws:id", + } +} + + +def test_reranker_client_binds_when_flow_provides_topics(): + flow = _flow() + _spec().add(flow, _processor(), DEFINITION_WITH_RERANKER) + # The client consumer is registered against the reranker role. + assert "reranker-request" in flow.consumer + + +def test_reranker_client_keyerrors_without_companion_template_topics(): + with pytest.raises(KeyError) as exc: + _spec().add(_flow(), _processor(), DEFINITION_WITHOUT_RERANKER) + # Fails fast naming the missing role -> the trustgraph-templates companion + # change (wire reranker-request/response into the document-rag flow) is required. 
+ assert "reranker-request" in str(exc.value) diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index dde3acc1..2bdf3959 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -66,6 +66,7 @@ class TestDocumentRagService: workspace=ANY, # Workspace comes from flow.workspace (mock) collection="test_coll_1", # Must be from message, not hardcoded default doc_limit=5, + fetch_limit=0, # Unset -> core derives the candidate pool explain_callback=ANY, # Explainability callback is always passed save_answer_callback=ANY, # Librarian save callback is always passed ) diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index de592b59..afd48f1b 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -527,7 +527,8 @@ class AsyncFlowInstance: return result.get("response", "") async def document_rag(self, query: str, collection: str, - doc_limit: int = 10, **kwargs: Any) -> str: + doc_limit: int = 10, fetch_limit: int = 0, + **kwargs: Any) -> str: """ Execute document-based RAG query (non-streaming). 
@@ -541,7 +542,9 @@ class AsyncFlowInstance: Args: query: User query text collection: Collection identifier containing documents - doc_limit: Maximum number of document chunks to retrieve (default: 10) + doc_limit: Document chunks selected into the prompt (default: 10) + fetch_limit: Candidate chunks fetched from the vector store before + reranking (default: 0 = derive from doc_limit) **kwargs: Additional service-specific parameters Returns: @@ -564,6 +567,7 @@ class AsyncFlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": False } request_data.update(kwargs) diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 78b608a7..9eff3d60 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -379,12 +379,14 @@ class AsyncSocketFlowInstance: yield chunk.content async def document_rag(self, query: str, collection: str, - doc_limit: int = 10, streaming: bool = False, **kwargs): + doc_limit: int = 10, fetch_limit: int = 0, + streaming: bool = False, **kwargs): """Document RAG with optional streaming""" request = { "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": streaming } request.update(kwargs) diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 886306b3..b9e9487b 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -415,7 +415,7 @@ class FlowInstance: def document_rag( self, query,collection="default", - doc_limit=10, + doc_limit=10, fetch_limit=0, ): """ Execute document-based Retrieval-Augmented Generation (RAG) query. 
@@ -426,7 +426,9 @@ class FlowInstance: Args: query: Natural language query collection: Collection identifier (default: "default") - doc_limit: Maximum document chunks to retrieve (default: 10) + doc_limit: Document chunks selected into the prompt (default: 10) + fetch_limit: Candidate chunks fetched from the vector store before + reranking (default: 0 = derive from doc_limit) Returns: str: Generated response incorporating document context @@ -447,6 +449,7 @@ class FlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, } result = self.request( diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 3a06e0d8..efa887a1 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -752,6 +752,7 @@ class SocketFlowInstance: query: str, collection: str, doc_limit: int = 10, + fetch_limit: int = 0, streaming: bool = False, **kwargs: Any ) -> Union[TextCompletionResult, Iterator[RAGChunk]]: @@ -764,6 +765,7 @@ class SocketFlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": streaming } request.update(kwargs) @@ -785,6 +787,7 @@ class SocketFlowInstance: query: str, collection: str, doc_limit: int = 10, + fetch_limit: int = 0, **kwargs: Any ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: """Execute document-based RAG query with explainability support.""" @@ -792,6 +795,7 @@ class SocketFlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": True, "explainable": True, } diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index fe766522..f2a0b29a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ 
b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator): query=data["query"], collection=data.get("collection", "default"), doc_limit=int(data.get("doc-limit", 20)), + fetch_limit=int(data.get("fetch-limit", 0)), streaming=data.get("streaming", False) ) @@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator): "query": obj.query, "collection": obj.collection, "doc-limit": obj.doc_limit, + "fetch-limit": obj.fetch_limit, "streaming": getattr(obj, "streaming", False) } diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index ce91a3cb..d96bad1e 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -64,6 +64,8 @@ from . uris import ( docrag_question_uri, docrag_grounding_uri, docrag_exploration_uri, + docrag_focus_uri, + chunk_selection_uri, docrag_synthesis_uri, ) @@ -94,6 +96,8 @@ from . namespaces import ( TG_EDGE_SELECTION, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, + # Chunk selection entity type + TG_CHUNK_SELECTION, # Explainability entity types TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_ANALYSIS, TG_CONCLUSION, @@ -132,6 +136,7 @@ from . 
triples import ( # Query-time provenance triple builders (DocumentRAG) docrag_question_triples, docrag_exploration_triples, + docrag_chunk_selection_triples, docrag_synthesis_triples, # Utility set_graph, @@ -196,6 +201,8 @@ __all__ = [ "docrag_question_uri", "docrag_grounding_uri", "docrag_exploration_uri", + "docrag_focus_uri", + "chunk_selection_uri", "docrag_synthesis_uri", # Namespaces "PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT", @@ -219,6 +226,8 @@ __all__ = [ "TG_EDGE_SELECTION", # Query-time provenance predicates (DocumentRAG) "TG_CHUNK_COUNT", "TG_SELECTED_CHUNK", + # Chunk selection entity type + "TG_CHUNK_SELECTION", # Explainability entity types "TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS", "TG_ANALYSIS", "TG_CONCLUSION", @@ -254,6 +263,7 @@ __all__ = [ # Query-time provenance triple builders (DocumentRAG) "docrag_question_triples", "docrag_exploration_triples", + "docrag_chunk_selection_triples", "docrag_synthesis_triples", # Agent provenance triple builders "agent_session_triples", diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 6f81f122..da6e30b2 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -76,6 +76,9 @@ TG_EDGE_SELECTION = TG + "EdgeSelection" TG_CHUNK_COUNT = TG + "chunkCount" TG_SELECTED_CHUNK = TG + "selectedChunk" +# Chunk selection entity type (cross-encoder reranked chunk in Focus) +TG_CHUNK_SELECTION = TG + "ChunkSelection" + # Extraction provenance entity types TG_DOCUMENT_TYPE = TG + "Document" TG_PAGE_TYPE = TG + "Page" diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index 8e4871c3..d2374d54 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -30,6 +30,8 @@ from . 
namespaces import ( TG_EDGE_SELECTION, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, + # Chunk selection entity type + TG_CHUNK_SELECTION, # Explainability entity types TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, # Unifying types @@ -40,7 +42,10 @@ from . namespaces import ( TG_IN_TOKEN, TG_OUT_TOKEN, ) -from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri +from . uris import ( + activity_uri, agent_uri, subgraph_uri, edge_selection_uri, + chunk_selection_uri, +) def set_graph(triples: List[Triple], graph: str) -> List[Triple]: @@ -718,6 +723,75 @@ def docrag_exploration_triples( return triples +def docrag_chunk_selection_triples( + focus_uri: str, + exploration_uri: str, + selected_chunks_with_scores: List[dict], + session_id: str, +) -> List[Triple]: + """ + Build triples for a document RAG focus entity (chunks selected by the + cross-encoder reranker). + + Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity + derived from exploration, with one ChunkSelection sub-entity per surviving + chunk carrying the chunk reference and the reranker score. + + Structure: + a tg:Focus ; prov:wasDerivedFrom . + tg:selectedChunk . + a tg:ChunkSelection . + tg:document . + tg:score "0.97" . 
+ + Args: + focus_uri: URI of the focus entity (from docrag_focus_uri) + exploration_uri: URI of the parent exploration entity + selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score' + session_id: Session UUID for generating chunk selection URIs + + Returns: + List of Triple objects + """ + triples = [ + _triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)), + _triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")), + _triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)), + ] + + for idx, chunk_info in enumerate(selected_chunks_with_scores): + chunk_id = chunk_info.get("chunk_id") + if not chunk_id: + continue + + chunk_sel_uri = chunk_selection_uri(session_id, idx) + + # Link focus to chunk selection entity + triples.append( + _triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri)) + ) + + # Type the chunk selection entity + triples.append( + _triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION)) + ) + + # Reference the actual chunk (in librarian) + triples.append( + _triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id)) + ) + + # Cross-encoder score + score = chunk_info.get("score") + if score is not None: + triples.append( + _triple(chunk_sel_uri, TG_SCORE, _literal(str(score))) + ) + + return triples + + def docrag_synthesis_triples( synthesis_uri: str, exploration_uri: str, diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py index a26ac867..00beacbe 100644 --- a/trustgraph-base/trustgraph/provenance/uris.py +++ b/trustgraph-base/trustgraph/provenance/uris.py @@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str: return f"urn:trustgraph:docrag:{session_id}/exploration" +def docrag_focus_uri(session_id: str) -> str: + """ + Generate URI for a document RAG focus entity (chunks selected by the + cross-encoder reranker). + + Args: + session_id: The session UUID. 
def docrag_focus_uri(session_id: str) -> str:
    """
    Generate the URI of a document RAG focus entity (the chunks selected by
    the cross-encoder reranker).

    Args:
        session_id: The session UUID.

    Returns:
        URN in format: urn:trustgraph:docrag:{uuid}/focus
    """
    focus_urn = f"urn:trustgraph:docrag:{session_id}/focus"
    return focus_urn


def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
    """
    Generate the URI of a single chunk selection item (the entity that ties
    a reranked chunk to its score).

    Args:
        session_id: The session UUID.
        chunk_index: Index of this chunk in the selection (0-based).

    Returns:
        URN in format: urn:trustgraph:prov:chunk:{uuid}:{index}
    """
    selection_urn = f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"
    return selection_urn
doc_limit; + # values below doc_limit are raised to it) streaming: bool = False @dataclass diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index 01512ac8..04f4deda 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -21,10 +21,12 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' default_doc_limit = 10 +default_fetch_limit = 0 def question_explainable( - url, flow_id, question_text, collection, doc_limit, token=None, debug=False, + url, flow_id, question_text, collection, doc_limit, fetch_limit=0, + token=None, debug=False, workspace="default", ): """Execute document RAG with explainability - shows provenance events inline.""" @@ -39,6 +41,7 @@ def question_explainable( query=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, ): if isinstance(item, RAGChunk): # Print response content @@ -97,7 +100,7 @@ def question_explainable( def question( - url, flow_id, question_text, collection, doc_limit, + url, flow_id, question_text, collection, doc_limit, fetch_limit=0, streaming=True, token=None, explainable=False, debug=False, show_usage=False, workspace="default", ): @@ -109,6 +112,7 @@ def question( question_text=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, token=token, debug=debug, workspace=workspace, @@ -128,6 +132,7 @@ def question( query=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, streaming=True ) @@ -155,6 +160,7 @@ def question( query=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, ) print(result.text) @@ -214,7 +220,15 @@ def main(): '-d', '--doc-limit', type=int, default=default_doc_limit, - help=f'Document limit (default: {default_doc_limit})' + help=f'Documents 
selected into the prompt (default: {default_doc_limit})' + ) + + parser.add_argument( + '--fetch-limit', + type=int, + default=default_fetch_limit, + help='Candidate documents fetched from the vector store before ' + 'reranking (default: derive from doc-limit)' ) parser.add_argument( @@ -251,6 +265,7 @@ def main(): question_text=args.question, collection=args.collection, doc_limit=args.doc_limit, + fetch_limit=args.fetch_limit, streaming=not args.no_streaming, token=args.token, explainable=args.explainable, diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index ecfa7936..a3730eb9 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -9,10 +9,12 @@ from trustgraph.provenance import ( docrag_question_uri, docrag_grounding_uri, docrag_exploration_uri, + docrag_focus_uri, docrag_synthesis_uri, docrag_question_triples, grounding_triples, docrag_exploration_triples, + docrag_chunk_selection_triples, docrag_synthesis_triples, set_graph, GRAPH_RETRIEVAL, @@ -21,19 +23,25 @@ from trustgraph.provenance import ( # Module logger logger = logging.getLogger(__name__) +# When the caller does not specify a fetch_limit, reranking over-fetches this +# many times the final doc_limit as the candidate pool, so the cross-encoder can +# recover relevant chunks the bi-encoder ranked just outside the top doc_limit. +# This is only the fallback default: an explicit fetch_limit overrides it. 
+OVERFETCH_FACTOR = 3 + LABEL="http://www.w3.org/2000/01/rdf-schema#label" class Query: def __init__( self, rag, workspace, collection, verbose, - doc_limit=20, track_usage=None, + fetch_limit=20, track_usage=None, ): self.rag = rag self.workspace = workspace self.collection = collection self.verbose = verbose - self.doc_limit = doc_limit + self.fetch_limit = fetch_limit self.track_usage = track_usage async def extract_concepts(self, query): @@ -91,7 +99,7 @@ class Query: # Query chunk matches for each concept concurrently per_concept_limit = max( - 1, self.doc_limit // len(vectors) + 1, self.fetch_limit // len(vectors) ) async def query_concept(vec): @@ -140,6 +148,7 @@ class DocumentRag: def __init__( self, prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk, + reranker_client=None, verbose=False, ): @@ -150,12 +159,16 @@ class DocumentRag: self.doc_embeddings_client = doc_embeddings_client self.fetch_chunk = fetch_chunk + # Optional cross-encoder reranker. When None, the retrieval path is + # byte-identical to the pre-reranker behaviour. + self.reranker_client = reranker_client + if self.verbose: logger.debug("DocumentRag initialized") async def query( self, query, workspace="default", collection="default", - doc_limit=20, streaming=False, chunk_callback=None, + doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None, explain_callback=None, save_answer_callback=None, ): """ @@ -165,7 +178,10 @@ class DocumentRag: query: The query string workspace: Workspace for isolation (also scopes chunk lookup) collection: Collection identifier - doc_limit: Max chunks to retrieve + doc_limit: Chunks selected into the synthesis prompt (after rerank) + fetch_limit: Candidate pool fetched from the vector store before + reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a + reranker is wired, else doc_limit). 
streaming: Enable streaming LLM response chunk_callback: async def callback(chunk, end_of_stream) for streaming explain_callback: async def callback(triples, explain_id) for explainability @@ -197,6 +213,7 @@ class DocumentRag: q_uri = docrag_question_uri(session_id) gnd_uri = docrag_grounding_uri(session_id) exp_uri = docrag_exploration_uri(session_id) + foc_uri = docrag_focus_uri(session_id) syn_uri = docrag_synthesis_uri(session_id) timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") @@ -209,10 +226,21 @@ class DocumentRag: ) await explain_callback(q_triples, q_uri) + # Resolve the candidate-pool size fetched from the vector store. When a + # reranker is wired, honour an explicit fetch_limit; if unset, fall back + # to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit, + # else the rerank could not fill the prompt. Without a reranker, fetch + # doc_limit as before (byte-identical behaviour). + if self.reranker_client is not None: + fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit) + fetch_count = max(fl, doc_limit) + else: + fetch_count = doc_limit + q = Query( rag=self, workspace=workspace, collection=collection, verbose=self.verbose, - doc_limit=doc_limit, track_usage=track_usage, + fetch_limit=fetch_count, track_usage=track_usage, ) # Extract concepts from query (grounding step) @@ -235,6 +263,7 @@ class DocumentRag: docs, chunk_ids = await q.get_docs(concepts) # Emit exploration explainability after chunks retrieved + # (full candidate set, before any reranking) if explain_callback: exp_triples = set_graph( docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids), @@ -242,6 +271,45 @@ class DocumentRag: ) await explain_callback(exp_triples, exp_uri) + # Optional cross-encoder reranking pass between retrieval and + # synthesis. Mirrors GraphRAG's reranker usage but with a single + # query (the question). When no reranker is wired, this block is + # skipped entirely and behaviour is byte-identical to before. 
+ reranked = False + if self.reranker_client is not None and docs: + results = await self.reranker_client.rerank( + queries=[{"id": "0", "text": query}], + documents=[ + {"id": str(i), "text": d} for i, d in enumerate(docs) + ], + # Narrow the over-fetched candidate pool down to the final + # doc_limit requested for synthesis. + limit=doc_limit, + ) + + # results are sorted desc by score and truncated to limit by the + # reranker service, so order gives the surviving top-N directly. + order = [int(r.document_id) for r in results] + docs = [docs[i] for i in order] + chunk_ids = [chunk_ids[i] for i in order] + reranked = True + + # Emit chunk-selection (focus) explainability: surviving chunks + # with their cross-encoder scores, derived from exploration. + if explain_callback: + selected_chunks_with_scores = [ + {"chunk_id": chunk_ids[i], "score": r.score} + for i, r in enumerate(results) + ] + foc_triples = set_graph( + docrag_chunk_selection_triples( + foc_uri, exp_uri, + selected_chunks_with_scores, session_id, + ), + GRAPH_RETRIEVAL + ) + await explain_callback(foc_triples, foc_uri) + if self.verbose: logger.debug("Invoking LLM...") logger.debug(f"Documents: {docs}") @@ -291,9 +359,15 @@ class DocumentRag: logger.warning(f"Failed to save answer to librarian: {e}") synthesis_doc_id = None + # When reranking ran, synthesis derives from the focus (the + # reranked chunks actually fed to the LLM), as GraphRAG always does. + # When no reranker is wired, there is no focus stage, so synthesis + # derives from exploration (the unchanged no-op lineage) - a + # deliberate divergence from GraphRAG's always-on focus. 
+ syn_parent = foc_uri if reranked else exp_uri syn_triples = set_graph( docrag_synthesis_triples( - syn_uri, exp_uri, + syn_uri, syn_parent, document_id=synthesis_doc_id, in_token=synthesis_result.in_token if synthesis_result else None, out_token=synthesis_result.out_token if synthesis_result else None, diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index c80f4172..158cbefc 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -13,6 +13,7 @@ from . document_rag import DocumentRag from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import DocumentEmbeddingsClientSpec +from ... base import RerankerClientSpec from ... base import LibrarianSpec # Module logger @@ -28,14 +29,21 @@ class Processor(FlowProcessor): doc_limit = params.get("doc_limit", 5) + # Instance-default candidate-pool size fetched before cross-encoder + # reranking; the rerank step narrows it back down to doc_limit for the + # LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit). 
+ fetch_limit = params.get("fetch_limit", 0) + super(Processor, self).__init__( **params | { "id": id, "doc_limit": doc_limit, + "fetch_limit": fetch_limit, } ) self.doc_limit = doc_limit + self.fetch_limit = fetch_limit self.register_specification( ConsumerSpec( @@ -66,6 +74,13 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + RerankerClientSpec( + request_name = "reranker-request", + response_name = "reranker-response", + ) + ) + self.register_specification( ProducerSpec( name = "response", @@ -105,6 +120,7 @@ class Processor(FlowProcessor): doc_embeddings_client = flow("document-embeddings-request"), prompt_client = flow("prompt-request"), fetch_chunk = fetch_chunk, + reranker_client = flow("reranker-request"), verbose=True, ) @@ -113,6 +129,13 @@ class Processor(FlowProcessor): else: doc_limit = self.doc_limit + # Candidate-pool size: per-request override, else the instance + # default; 0 lets the core derive it from doc_limit. + if v.fetch_limit: + fetch_limit = v.fetch_limit + else: + fetch_limit = self.fetch_limit + async def send_explainability(triples, explain_id): await flow("explainability").send(Triples( metadata=Metadata( @@ -163,6 +186,7 @@ class Processor(FlowProcessor): workspace=flow.workspace, collection=v.collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, streaming=True, chunk_callback=send_chunk, explain_callback=send_explainability, @@ -188,6 +212,7 @@ class Processor(FlowProcessor): workspace=flow.workspace, collection=v.collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, explain_callback=send_explainability, save_answer_callback=save_answer, ) @@ -243,6 +268,15 @@ class Processor(FlowProcessor): help=f'Default document fetch limit (default: 10)' ) + parser.add_argument( + '--fetch-limit', + type=int, + default=0, + help='Candidate chunks to fetch from the vector store and rerank ' + 'before keeping the top doc-limit for the LLM ' + '(default: derive from doc-limit)' + ) + def run(): 
Processor.launch(default_ident, __doc__)