diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py
new file mode 100644
index 00000000..55b9b97f
--- /dev/null
+++ b/tests/unit/test_retrieval/test_document_rag_service.py
@@ -0,0 +1,126 @@
+"""
+Unit tests for DocumentRAG service parameter passing.
+
+Tests that user and collection parameters from the message are correctly
+passed to the DocumentRag.query() method, and that empty values are left
+out of the call so DocumentRag.query() can apply its own defaults.
+"""
+
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+
+from trustgraph.retrieval.document_rag.rag import Processor
+from trustgraph.schema import DocumentRagQuery, DocumentRagResponse
+
+
+def make_processor():
+    """Build the Processor under test with mocked infrastructure."""
+    return Processor(
+        taskgroup=MagicMock(),
+        id="test-processor",
+        doc_limit=10
+    )
+
+
+def make_message(query):
+    """Wrap a DocumentRagQuery in a mock of the incoming message."""
+    msg = MagicMock()
+    msg.value.return_value = query
+    msg.properties.return_value = {"id": "test-id"}
+    return msg
+
+
+def make_flow(producer):
+    """Build a flow mock that returns `producer` for "response"."""
+
+    def flow_router(service_name):
+        if service_name == "response":
+            return producer
+        return AsyncMock()  # embeddings, doc-embeddings, prompt clients
+
+    return MagicMock(side_effect=flow_router)
+
+
+class TestDocumentRagService:
+    """Test DocumentRAG service parameter passing"""
+
+    @patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
+    @pytest.mark.asyncio
+    async def test_user_and_collection_parameters_passed_to_query(self, mock_document_rag_class):
+        """
+        Test that user and collection from message are passed to DocumentRag.query().
+
+        This is a regression test for the bug where user/collection parameters
+        were ignored, causing wrong collection names like 'd_trustgraph_default_384'
+        instead of 'd_my_user_test_coll_1_384'.
+        """
+        processor = make_processor()
+
+        # Setup mock DocumentRag instance
+        mock_rag_instance = AsyncMock()
+        mock_document_rag_class.return_value = mock_rag_instance
+        mock_rag_instance.query.return_value = "test response"
+
+        # Setup message with custom user/collection
+        msg = make_message(DocumentRagQuery(
+            query="test query",
+            user="my_user",  # Custom user (not default "trustgraph")
+            collection="test_coll_1",  # Custom collection (not default "default")
+            doc_limit=5
+        ))
+        mock_producer = AsyncMock()
+
+        # Execute
+        await processor.on_request(msg, MagicMock(), make_flow(mock_producer))
+
+        # Verify: DocumentRag.query was called with correct parameters
+        mock_rag_instance.query.assert_called_once_with(
+            "test query",
+            user="my_user",  # Must be from message, not hardcoded default
+            collection="test_coll_1",  # Must be from message, not hardcoded default
+            doc_limit=5
+        )
+
+        # Verify response was sent
+        mock_producer.send.assert_called_once()
+        sent_response = mock_producer.send.call_args[0][0]
+        assert isinstance(sent_response, DocumentRagResponse)
+        assert sent_response.response == "test response"
+        assert sent_response.error is None
+
+    @patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
+    @pytest.mark.asyncio
+    async def test_empty_user_and_collection_not_forwarded(self, mock_document_rag_class):
+        """
+        Test that empty user/collection values are not passed to DocumentRag.query().
+
+        Leaving them out of the call lets DocumentRag.query() apply its own
+        defaults instead of receiving an empty user or collection.
+        """
+        processor = make_processor()
+
+        # Setup mock DocumentRag instance
+        mock_rag_instance = AsyncMock()
+        mock_document_rag_class.return_value = mock_rag_instance
+        mock_rag_instance.query.return_value = "test response"
+
+        # Setup message whose user/collection are empty
+        msg = make_message(DocumentRagQuery(
+            query="test query",
+            user="",
+            collection="",
+            doc_limit=5
+        ))
+        mock_producer = AsyncMock()
+
+        # Execute
+        await processor.on_request(msg, MagicMock(), make_flow(mock_producer))
+
+        # Verify: no user/collection keyword reaches DocumentRag.query
+        mock_rag_instance.query.assert_called_once_with(
+            "test query",
+            doc_limit=5
+        )
+
+        # Verify response was sent
+        mock_producer.send.assert_called_once()
diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py
index 0cca2cff..2e5149c9 100755
--- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py
+++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py
@@ -92,7 +92,13 @@ class Processor(FlowProcessor):
             else:
                 doc_limit = self.doc_limit
 
-            response = await self.rag.query(v.query, doc_limit=doc_limit)
+            # Forward user/collection only when the message sets them, so
+            # that query() falls back to its own defaults otherwise; this
+            # mirrors the doc_limit fallback above.
+            query_args = {"doc_limit": doc_limit}
+            if v.user:
+                query_args["user"] = v.user
+            if v.collection:
+                query_args["collection"] = v.collection
+
+            response = await self.rag.query(v.query, **query_args)
 
             await flow("response").send(
                 DocumentRagResponse(