trustgraph/tests/unit/test_retrieval/test_document_rag_service.py
Sunny 6c9a545a06
feat: add cross-encoder reranking to Document-RAG with two-limit control (#878) (#1011)
Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after
vector retrieval, over-fetch a wider candidate pool, rerank with the
cross-encoder, and keep the top doc_limit chunks for synthesis.

Per maintainer review, the fetch and select sizes are two caller-controlled
limits rather than one internal heuristic:

- doc_limit:   chunks selected into the synthesis prompt (unchanged meaning).
- fetch_limit: candidate pool pulled from the vector store before reranking.
  0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are
  raised to it. Lets the caller control how hard the reranker has to work.

Details:
- schema: DocumentRagQuery.fetch_limit (additive, backward compatible).
- document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors
  doc_limit); the core applies the heuristic default and derives synthesis
  provenance from the chunk-selection focus when reranking ran.
- provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection).
- request translator + client SDKs + CLI: fetch-limit / --fetch-limit,
  threaded exactly like doc_limit and the GraphRAG limits.
- tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic
  default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring.

When no reranker role is wired, reranking is skipped and the output is byte-identical to the pre-reranking behavior.
Requires the companion trustgraph-templates change wiring the reranker
topics into the document-rag flow (mirrors #279 for GraphRAG).
2026-07-02 09:50:13 +01:00

133 lines
No EOL
5.1 KiB
Python

"""
Unit tests for the DocumentRAG service request handling.
Tests that the collection (and limit) parameters from the message are
correctly passed to the DocumentRag.query() method, and that non-streaming
responses are sent with end_of_stream and end_of_session set.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch, ANY
from trustgraph.retrieval.document_rag.rag import Processor
from trustgraph.schema import DocumentRagQuery, DocumentRagResponse
class TestDocumentRagService:
    """Tests for the DocumentRAG service's request handling.

    Covers two regressions: the request's collection and limits must reach
    ``DocumentRag.query()`` unchanged, and a non-streaming request must be
    answered with a single response flagged end_of_stream / end_of_session.
    """

    # ------------------------------------------------------------------
    # Shared setup helpers (no ``test`` prefix, so pytest does not collect
    # them). Each test previously duplicated this wiring inline.
    # ------------------------------------------------------------------

    @staticmethod
    def _make_processor():
        """Build the Processor under test with a processor-level doc_limit of 10."""
        return Processor(
            taskgroup=MagicMock(),
            id="test-processor",
            doc_limit=10,
        )

    @staticmethod
    def _stub_rag(mock_document_rag_class, answer):
        """Make the patched DocumentRag class yield an AsyncMock instance.

        The instance's ``query()`` resolves to ``(answer, usage)`` where every
        usage field is None. Returns the instance so tests can inspect calls.
        """
        rag_instance = AsyncMock()
        mock_document_rag_class.return_value = rag_instance
        rag_instance.query.return_value = (
            answer,
            {"in_token": None, "out_token": None, "model": None},
        )
        return rag_instance

    @staticmethod
    def _make_message(request):
        """Wrap ``request`` in a message mock whose properties carry id ``test-id``."""
        msg = MagicMock()
        msg.value.return_value = request
        msg.properties.return_value = {"id": "test-id"}
        return msg

    @staticmethod
    def _make_flow():
        """Return ``(flow, producer)`` for driving ``on_request``.

        ``flow`` is a callable mock: ``flow("response")`` returns ``producer``;
        any other service name (embeddings, doc-embeddings, prompt clients)
        gets a fresh AsyncMock.
        """
        producer = AsyncMock()

        def route(service_name):
            if service_name == "response":
                return producer
            return AsyncMock()  # embeddings, doc-embeddings, prompt clients

        flow = MagicMock()
        flow.side_effect = route
        return flow, producer

    @staticmethod
    def _single_sent_response(producer):
        """Assert exactly one response was sent via ``producer`` and return it."""
        producer.send.assert_called_once()
        return producer.send.call_args[0][0]

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    @patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
    @pytest.mark.asyncio
    async def test_collection_parameter_passed_to_query(self, mock_document_rag_class):
        """
        Test that collection from message is passed to DocumentRag.query().
        This is a regression test for the bug where the collection parameter
        was ignored, causing wrong collection names like 'd_trustgraph_default_384'
        instead of one that reflects the requested collection.
        """
        # The processor is built before the DocumentRag mock is stubbed,
        # matching the order the service is expected to tolerate.
        processor = self._make_processor()
        rag_instance = self._stub_rag(mock_document_rag_class, "test response")
        msg = self._make_message(DocumentRagQuery(
            query="test query",
            collection="test_coll_1",  # Custom collection (not default "default")
            doc_limit=5,
        ))
        consumer = MagicMock()
        flow, producer = self._make_flow()

        await processor.on_request(msg, consumer, flow)

        # DocumentRag.query must receive the request's values, not defaults.
        rag_instance.query.assert_called_once_with(
            "test query",
            workspace=ANY,  # Workspace comes from flow.workspace (mock)
            collection="test_coll_1",  # Must be from message, not hardcoded default
            doc_limit=5,
            fetch_limit=0,  # Unset -> core derives the candidate pool
            explain_callback=ANY,  # Explainability callback is always passed
            save_answer_callback=ANY,  # Librarian save callback is always passed
        )

        sent_response = self._single_sent_response(producer)
        assert isinstance(sent_response, DocumentRagResponse)
        assert sent_response.response == "test response"
        assert sent_response.error is None

    @patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
    @pytest.mark.asyncio
    async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_document_rag_class):
        """
        Test that non-streaming mode sets end_of_stream=True in response.
        This is a regression test for the bug where non-streaming responses
        didn't set end_of_stream, causing clients to hang waiting for more data.
        """
        processor = self._make_processor()
        self._stub_rag(mock_document_rag_class, "A document about cats.")
        msg = self._make_message(DocumentRagQuery(
            query="What is a cat?",
            collection="default",
            doc_limit=10,
            streaming=False,  # Non-streaming mode
        ))
        consumer = MagicMock()
        flow, producer = self._make_flow()

        await processor.on_request(msg, consumer, flow)

        # A non-streaming request is answered by one final response message.
        sent_response = self._single_sent_response(producer)
        assert isinstance(sent_response, DocumentRagResponse)
        assert sent_response.response == "A document about cats."
        assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
        assert sent_response.end_of_session is True
        assert sent_response.error is None