diff --git a/tests/contract/test_document_embeddings_contract.py b/tests/contract/test_document_embeddings_contract.py index e0939aaa..c35dde4b 100644 --- a/tests/contract/test_document_embeddings_contract.py +++ b/tests/contract/test_document_embeddings_contract.py @@ -25,13 +25,13 @@ class TestDocumentEmbeddingsRequestContract: user="test_user", collection="test_collection" ) - + # Verify all expected fields exist assert hasattr(request, 'vectors') assert hasattr(request, 'limit') assert hasattr(request, 'user') assert hasattr(request, 'collection') - + # Verify field values assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] assert request.limit == 10 @@ -41,16 +41,16 @@ class TestDocumentEmbeddingsRequestContract: def test_request_translator_to_pulsar(self): """Test request translator converts dict to Pulsar schema""" translator = DocumentEmbeddingsRequestTranslator() - + data = { "vectors": [[0.1, 0.2], [0.3, 0.4]], "limit": 5, "user": "custom_user", "collection": "custom_collection" } - + result = translator.to_pulsar(data) - + assert isinstance(result, DocumentEmbeddingsRequest) assert result.vectors == [[0.1, 0.2], [0.3, 0.4]] assert result.limit == 5 @@ -60,14 +60,14 @@ class TestDocumentEmbeddingsRequestContract: def test_request_translator_to_pulsar_with_defaults(self): """Test request translator uses correct defaults""" translator = DocumentEmbeddingsRequestTranslator() - + data = { "vectors": [[0.1, 0.2]] # No limit, user, or collection provided } - + result = translator.to_pulsar(data) - + assert isinstance(result, DocumentEmbeddingsRequest) assert result.vectors == [[0.1, 0.2]] assert result.limit == 10 # Default @@ -77,16 +77,16 @@ class TestDocumentEmbeddingsRequestContract: def test_request_translator_from_pulsar(self): """Test request translator converts Pulsar schema to dict""" translator = DocumentEmbeddingsRequestTranslator() - + request = DocumentEmbeddingsRequest( vectors=[[0.5, 0.6]], limit=20, user="test_user", collection="test_collection" ) - + result = translator.from_pulsar(request) - + assert isinstance(result, dict) assert result["vectors"] == [[0.5, 0.6]] assert result["limit"] == 20 @@ -99,19 +99,19 @@ class TestDocumentEmbeddingsResponseContract: def test_response_schema_fields(self): """Test that DocumentEmbeddingsResponse has expected fields""" - # Create a response with chunks + # Create a response with chunk_ids response = DocumentEmbeddingsResponse( error=None, - chunks=["chunk1", "chunk2", "chunk3"] + chunk_ids=["chunk1", "chunk2", "chunk3"] ) - + # Verify all expected fields exist assert hasattr(response, 'error') - assert hasattr(response, 'chunks') - + assert hasattr(response, 'chunk_ids') + # Verify field values assert response.error is None - assert response.chunks == ["chunk1", "chunk2", "chunk3"] + assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"] def test_response_schema_with_error(self): """Test response schema with error""" @@ -119,90 +119,79 @@ class TestDocumentEmbeddingsResponseContract: type="query_error", message="Database connection failed" ) - + response = DocumentEmbeddingsResponse( error=error, - chunks=None + chunk_ids=[] ) - - assert response.error == error - assert response.chunks is None - def test_response_translator_from_pulsar_with_chunks(self): - """Test response translator converts Pulsar schema with chunks to dict""" + assert response.error == error + assert response.chunk_ids == [] + + def test_response_translator_from_pulsar_with_chunk_ids(self): + """Test response translator converts Pulsar schema with chunk_ids to dict""" translator = DocumentEmbeddingsResponseTranslator() - + response = DocumentEmbeddingsResponse( error=None, - chunks=["doc1", "doc2", "doc3"] + chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"] ) - - result = translator.from_pulsar(response) - - assert isinstance(result, dict) - assert "chunks" in result - assert result["chunks"] == ["doc1", "doc2", "doc3"] - def test_response_translator_from_pulsar_with_bytes(self): - """Test response translator handles byte chunks correctly""" - translator = DocumentEmbeddingsResponseTranslator() - - response = MagicMock() - response.chunks = [b"byte_chunk1", b"byte_chunk2"] - result = translator.from_pulsar(response) - - assert isinstance(result, dict) - assert "chunks" in result - assert result["chunks"] == ["byte_chunk1", "byte_chunk2"] - def test_response_translator_from_pulsar_with_empty_chunks(self): - """Test response translator handles empty chunks list""" - translator = DocumentEmbeddingsResponseTranslator() - - response = MagicMock() - response.chunks = [] - - result = translator.from_pulsar(response) - assert isinstance(result, dict) - assert "chunks" in result - assert result["chunks"] == [] + assert "chunk_ids" in result + assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"] - def test_response_translator_from_pulsar_with_none_chunks(self): - """Test response translator handles None chunks""" + def test_response_translator_from_pulsar_with_empty_chunk_ids(self): + """Test response translator handles empty chunk_ids list""" translator = DocumentEmbeddingsResponseTranslator() - - response = MagicMock() - response.chunks = None - + + response = DocumentEmbeddingsResponse( + error=None, + chunk_ids=[] + ) + result = translator.from_pulsar(response) - + assert isinstance(result, dict) - assert "chunks" not in result or result.get("chunks") is None + assert "chunk_ids" in result + assert result["chunk_ids"] == [] + + def test_response_translator_from_pulsar_with_none_chunk_ids(self): + """Test response translator handles None chunk_ids""" + translator = DocumentEmbeddingsResponseTranslator() + + response = MagicMock() + response.chunk_ids = None + + result = translator.from_pulsar(response) + + assert isinstance(result, dict) + assert "chunk_ids" not in result or result.get("chunk_ids") is None def test_response_translator_from_response_with_completion(self): """Test response translator with completion flag""" translator = DocumentEmbeddingsResponseTranslator() - + response = DocumentEmbeddingsResponse( error=None, - chunks=["chunk1", "chunk2"] + chunk_ids=["chunk1", "chunk2"] ) - + result, is_final = translator.from_response_with_completion(response) - + assert isinstance(result, dict) - assert "chunks" in result - assert result["chunks"] == ["chunk1", "chunk2"] + assert "chunk_ids" in result + assert result["chunk_ids"] == ["chunk1", "chunk2"] assert is_final is True # Document embeddings responses are always final def test_response_translator_to_pulsar_not_implemented(self): """Test that to_pulsar raises NotImplementedError for responses""" translator = DocumentEmbeddingsResponseTranslator() - + with pytest.raises(NotImplementedError): - translator.to_pulsar({"chunks": ["test"]}) + translator.to_pulsar({"chunk_ids": ["test"]}) class TestDocumentEmbeddingsMessageCompatibility: @@ -217,26 +206,26 @@ class TestDocumentEmbeddingsMessageCompatibility: "user": "test_user", "collection": "test_collection" } - + # Convert to Pulsar request req_translator = DocumentEmbeddingsRequestTranslator() pulsar_request = req_translator.to_pulsar(request_data) - + # Simulate service processing and creating response response = DocumentEmbeddingsResponse( error=None, - chunks=["relevant chunk 1", "relevant chunk 2"] + chunk_ids=["doc1/c1", "doc2/c2"] ) - + # Convert response back to dict resp_translator = DocumentEmbeddingsResponseTranslator() response_data = resp_translator.from_pulsar(response) - + # Verify data integrity assert isinstance(pulsar_request, DocumentEmbeddingsRequest) assert isinstance(response_data, dict) - assert "chunks" in response_data - assert len(response_data["chunks"]) == 2 + assert "chunk_ids" in response_data + assert len(response_data["chunk_ids"]) == 2 def test_error_response_flow(self): """Test error response flow""" @@ -245,17 +234,18 @@ class TestDocumentEmbeddingsMessageCompatibility: type="vector_db_error", message="Collection not found" ) - + response = DocumentEmbeddingsResponse( error=error, - chunks=None + chunk_ids=[] ) - + # Convert response to dict translator = DocumentEmbeddingsResponseTranslator() response_data = translator.from_pulsar(response) - + # Verify error handling assert isinstance(response_data, dict) - # The translator doesn't include error in the dict, only chunks - assert "chunks" not in response_data or response_data.get("chunks") is None \ No newline at end of file + # The translator doesn't include error in the dict, only chunk_ids + assert "chunk_ids" in response_data + assert response_data["chunk_ids"] == [] diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py index 3db22c4d..75f148f3 100644 --- a/tests/integration/test_document_rag_integration.py +++ b/tests/integration/test_document_rag_integration.py @@ -11,6 +11,14 @@ from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag +# Sample chunk content for testing - maps chunk_id to content +CHUNK_CONTENT = { + "doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.", + "doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.", + "doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.", +} + + @pytest.mark.integration class TestDocumentRagIntegration: """Integration tests for DocumentRAG system coordination""" @@ -27,15 +35,19 @@ class TestDocumentRagIntegration: @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client that returns realistic document chunks""" + """Mock document embeddings client that returns chunk IDs""" client = AsyncMock() - client.query.return_value = [ - "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.", - "Deep learning uses neural networks with multiple layers to model complex patterns in data.", - "Supervised learning algorithms learn from labeled training data to make predictions on new data." - ] + # Now returns chunk_ids instead of actual content + client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"] return client + @pytest.fixture + def mock_fetch_chunk(self): + """Mock fetch_chunk function that retrieves chunk content from librarian""" + async def fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + return fetch + @pytest.fixture def mock_prompt_client(self): """Mock prompt client that generates realistic responses""" @@ -48,17 +60,19 @@ class TestDocumentRagIntegration: return client @pytest.fixture - def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client): + def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, + mock_prompt_client, mock_fetch_chunk): """Create DocumentRag instance with mocked dependencies""" return DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=True ) @pytest.mark.asyncio - async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client, + async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client): """Test complete DocumentRAG pipeline from query to response""" # Arrange @@ -77,14 +91,15 @@ class TestDocumentRagIntegration: # Assert - Verify service coordination mock_embeddings_client.embed.assert_called_once_with(query) - + mock_doc_embeddings_client.query.assert_called_once_with( [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]], limit=doc_limit, user=user, collection=collection ) - + + # Documents are fetched from librarian using chunk_ids mock_prompt_client.document_prompt.assert_called_once_with( query=query, documents=[ @@ -101,17 +116,19 @@ class TestDocumentRagIntegration: assert "artificial intelligence" in result.lower() @pytest.mark.asyncio - async def test_document_rag_with_no_documents_found(self, mock_embeddings_client, - mock_doc_embeddings_client, mock_prompt_client): + async def test_document_rag_with_no_documents_found(self, mock_embeddings_client, + mock_doc_embeddings_client, mock_prompt_client, + mock_fetch_chunk): """Test DocumentRAG behavior when no documents are retrieved""" # Arrange - mock_doc_embeddings_client.query.return_value = [] # No documents found + mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query." - + document_rag = DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) @@ -125,92 +142,98 @@ class TestDocumentRagIntegration: query="very obscure query", documents=[] ) - + assert result == "I couldn't find any relevant documents for your query." @pytest.mark.asyncio - async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client, - mock_doc_embeddings_client, mock_prompt_client): + async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client, + mock_doc_embeddings_client, mock_prompt_client, + mock_fetch_chunk): """Test DocumentRAG error handling when embeddings service fails""" # Arrange mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable") - + document_rag = DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) # Act & Assert with pytest.raises(Exception) as exc_info: await document_rag.query("test query") - + assert "Embeddings service unavailable" in str(exc_info.value) mock_embeddings_client.embed.assert_called_once() mock_doc_embeddings_client.query.assert_not_called() mock_prompt_client.document_prompt.assert_not_called() @pytest.mark.asyncio - async def test_document_rag_document_service_failure(self, mock_embeddings_client, - mock_doc_embeddings_client, mock_prompt_client): + async def test_document_rag_document_service_failure(self, mock_embeddings_client, + mock_doc_embeddings_client, mock_prompt_client, + mock_fetch_chunk): """Test DocumentRAG error handling when document service fails""" # Arrange mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed") - + document_rag = DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) # Act & Assert with pytest.raises(Exception) as exc_info: await document_rag.query("test query") - + assert "Document service connection failed" in str(exc_info.value) mock_embeddings_client.embed.assert_called_once() mock_doc_embeddings_client.query.assert_called_once() mock_prompt_client.document_prompt.assert_not_called() @pytest.mark.asyncio - async def test_document_rag_prompt_service_failure(self, mock_embeddings_client, - mock_doc_embeddings_client, mock_prompt_client): + async def test_document_rag_prompt_service_failure(self, mock_embeddings_client, + mock_doc_embeddings_client, mock_prompt_client, + mock_fetch_chunk): """Test DocumentRAG error handling when prompt service fails""" # Arrange mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited") - + document_rag = DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) # Act & Assert with pytest.raises(Exception) as exc_info: await document_rag.query("test query") - + assert "LLM service rate limited" in str(exc_info.value) mock_embeddings_client.embed.assert_called_once() mock_doc_embeddings_client.query.assert_called_once() mock_prompt_client.document_prompt.assert_called_once() @pytest.mark.asyncio - async def test_document_rag_with_different_document_limits(self, document_rag, + async def test_document_rag_with_different_document_limits(self, document_rag, mock_doc_embeddings_client): """Test DocumentRAG with various document limit configurations""" # Test different document limits test_cases = [1, 5, 10, 25, 50] - + for limit in test_cases: # Reset mock call history mock_doc_embeddings_client.reset_mock() - + # Act await document_rag.query(f"query with limit {limit}", doc_limit=limit) - + # Assert mock_doc_embeddings_client.query.assert_called_once() call_args = mock_doc_embeddings_client.query.call_args @@ -230,14 +253,14 @@ class TestDocumentRagIntegration: for user, collection in test_scenarios: # Reset mock call history mock_doc_embeddings_client.reset_mock() - + # Act await document_rag.query( f"query from {user} in {collection}", user=user, collection=collection ) - + # Assert mock_doc_embeddings_client.query.assert_called_once() call_args = mock_doc_embeddings_client.query.call_args @@ -245,19 +268,21 @@ class TestDocumentRagIntegration: assert call_args.kwargs['collection'] == collection @pytest.mark.asyncio - async def test_document_rag_verbose_logging(self, mock_embeddings_client, - mock_doc_embeddings_client, mock_prompt_client, + async def test_document_rag_verbose_logging(self, mock_embeddings_client, + mock_doc_embeddings_client, mock_prompt_client, + mock_fetch_chunk, caplog): """Test DocumentRAG verbose logging functionality""" import logging - + # Arrange - Configure logging to capture debug messages caplog.set_level(logging.DEBUG) - + document_rag = DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=True ) @@ -269,25 +294,25 @@ class TestDocumentRagIntegration: assert "DocumentRag initialized" in log_messages assert "Constructing prompt..." in log_messages assert "Computing embeddings..." in log_messages - assert "Getting documents..." in log_messages + assert "chunk_ids" in log_messages.lower() assert "Invoking LLM..." in log_messages assert "Query processing complete" in log_messages @pytest.mark.asyncio @pytest.mark.slow - async def test_document_rag_performance_with_large_document_set(self, document_rag, + async def test_document_rag_performance_with_large_document_set(self, document_rag, mock_doc_embeddings_client): """Test DocumentRAG performance with large document retrieval""" - # Arrange - Mock large document set (100 documents) - large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)] - mock_doc_embeddings_client.query.return_value = large_doc_set + # Arrange - Mock large chunk_id set (100 chunks) + large_chunk_ids = [f"doc/c{i}" for i in range(100)] + mock_doc_embeddings_client.query.return_value = large_chunk_ids # Act import time start_time = time.time() - + result = await document_rag.query("performance test query", doc_limit=100) - + end_time = time.time() execution_time = end_time - start_time @@ -309,4 +334,4 @@ class TestDocumentRagIntegration: call_args = mock_doc_embeddings_client.query.call_args assert call_args.kwargs['user'] == "trustgraph" assert call_args.kwargs['collection'] == "default" - assert call_args.kwargs['limit'] == 20 \ No newline at end of file + assert call_args.kwargs['limit'] == 20 diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index 84061add..52b86caf 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -14,6 +14,14 @@ from tests.utils.streaming_assertions import ( ) +# Sample chunk content for testing - maps chunk_id to content +CHUNK_CONTENT = { + "doc/c1": "Machine learning is a subset of AI.", + "doc/c2": "Deep learning uses neural networks.", + "doc/c3": "Supervised learning needs labeled data.", +} + + @pytest.mark.integration class TestDocumentRagStreaming: """Integration tests for DocumentRAG streaming""" @@ -27,15 +35,19 @@ class TestDocumentRagStreaming: @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client""" + """Mock document embeddings client that returns chunk IDs""" client = AsyncMock() - client.query.return_value = [ - "Machine learning is a subset of AI.", - "Deep learning uses neural networks.", - "Supervised learning needs labeled data." - ] + # Now returns chunk_ids instead of actual content + client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"] return client + @pytest.fixture + def mock_fetch_chunk(self): + """Mock fetch_chunk function that retrieves chunk content from librarian""" + async def fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + return fetch + @pytest.fixture def mock_streaming_prompt_client(self, mock_streaming_llm_response): """Mock prompt client with streaming support""" @@ -66,12 +78,13 @@ class TestDocumentRagStreaming: @pytest.fixture def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client, - mock_streaming_prompt_client): + mock_streaming_prompt_client, mock_fetch_chunk): """Create DocumentRag instance with streaming support""" return DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_streaming_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=True ) @@ -190,7 +203,7 @@ class TestDocumentRagStreaming: mock_doc_embeddings_client): """Test streaming with no documents found""" # Arrange - mock_doc_embeddings_client.query.return_value = [] # No documents + mock_doc_embeddings_client.query.return_value = [] # No chunk_ids callback = AsyncMock() # Act diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index d2ceea95..ba492687 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -202,11 +202,18 @@ class TestDocumentRagStreamingProtocol: @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client""" + """Mock document embeddings client that returns chunk IDs""" client = AsyncMock() - client.query.return_value = ["doc1", "doc2"] + client.query.return_value = ["doc/c1", "doc/c2"] return client + @pytest.fixture + def mock_fetch_chunk(self): + """Mock fetch_chunk function that retrieves chunk content from librarian""" + async def fetch(chunk_id, user): + return f"Content for {chunk_id}" + return fetch + @pytest.fixture def mock_streaming_prompt_client(self): """Mock prompt client with streaming support""" @@ -227,12 +234,13 @@ class TestDocumentRagStreamingProtocol: @pytest.fixture def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, - mock_streaming_prompt_client): + mock_streaming_prompt_client, mock_fetch_chunk): """Create DocumentRag instance with mocked dependencies""" return DocumentRag( embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, prompt_client=mock_streaming_prompt_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) diff --git a/tests/unit/test_base/test_document_embeddings_client.py b/tests/unit/test_base/test_document_embeddings_client.py index 1c91408d..81d4a98e 100644 --- a/tests/unit/test_base/test_document_embeddings_client.py +++ b/tests/unit/test_base/test_document_embeddings_client.py @@ -22,7 +22,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunks = ["chunk1", "chunk2", "chunk3"] + mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"] # Mock the request method client.request = AsyncMock(return_value=mock_response) @@ -75,7 +75,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunks = [] + mock_response.chunk_ids = [] client.request = AsyncMock(return_value=mock_response) @@ -93,7 +93,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunks = ["test_chunk"] + mock_response.chunk_ids = ["test_chunk"] client.request = AsyncMock(return_value=mock_response) @@ -115,7 +115,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunks = ["chunk1"] + mock_response.chunk_ids = ["chunk1"] client.request = AsyncMock(return_value=mock_response) @@ -136,7 +136,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunks = ["test_chunk"] + mock_response.chunk_ids = ["test_chunk"] client.request = AsyncMock(return_value=mock_response) diff --git a/tests/unit/test_query/test_doc_embeddings_milvus_query.py b/tests/unit/test_query/test_doc_embeddings_milvus_query.py index 622529e5..b4c954d8 100644 --- a/tests/unit/test_query/test_doc_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_doc_embeddings_milvus_query.py @@ -77,9 +77,9 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Mock search results mock_results = [ - {"entity": {"doc": "First document chunk"}}, - {"entity": {"doc": "Second document chunk"}}, - {"entity": {"doc": "Third document chunk"}}, + {"entity": {"chunk_id": "First document chunk"}}, + {"entity": {"chunk_id": "Second document chunk"}}, + {"entity": {"chunk_id": "Third document chunk"}}, ] processor.vecstore.search.return_value = mock_results @@ -108,11 +108,11 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Mock search results - different results for each vector mock_results_1 = [ - {"entity": {"doc": "Document from first vector"}}, - {"entity": {"doc": "Another doc from first vector"}}, + {"entity": {"chunk_id": "Document from first vector"}}, + {"entity": {"chunk_id": "Another doc from first vector"}}, ] mock_results_2 = [ - {"entity": {"doc": "Document from second vector"}}, + {"entity": {"chunk_id": "Document from second vector"}}, ] processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] @@ -147,10 +147,10 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Mock search results - more results than limit mock_results = [ - {"entity": {"doc": "Document 1"}}, - {"entity": {"doc": "Document 2"}}, - {"entity": {"doc": "Document 3"}}, - {"entity": {"doc": "Document 4"}}, + {"entity": {"chunk_id": "Document 1"}}, + {"entity": {"chunk_id": "Document 2"}}, + {"entity": {"chunk_id": "Document 3"}}, + {"entity": {"chunk_id": "Document 4"}}, ] processor.vecstore.search.return_value = mock_results @@ -217,9 +217,9 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Mock search results with Unicode content mock_results = [ - {"entity": {"doc": "Document with Unicode: éñ中文🚀"}}, - {"entity": {"doc": "Regular ASCII document"}}, - {"entity": {"doc": "Document with émojis: 😀🎉"}}, + {"entity": {"chunk_id": "Document with Unicode: éñ中文🚀"}}, + {"entity": {"chunk_id": "Regular ASCII document"}}, + {"entity": {"chunk_id": "Document with émojis: 😀🎉"}}, ] processor.vecstore.search.return_value = mock_results @@ -244,8 +244,8 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Mock search results with large content large_doc = "A" * 10000 # 10KB of content mock_results = [ - {"entity": {"doc": large_doc}}, - {"entity": {"doc": "Small document"}}, + {"entity": {"chunk_id": large_doc}}, + {"entity": {"chunk_id": "Small document"}}, ] processor.vecstore.search.return_value = mock_results @@ -268,9 +268,9 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Mock search results with special characters mock_results = [ - {"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}}, - {"entity": {"doc": "Document with\nnewlines\tand\ttabs"}}, - {"entity": {"doc": "Document with special chars: @#$%^&*()"}}, + {"entity": {"chunk_id": "Document with \"quotes\" and 'apostrophes'"}}, + {"entity": {"chunk_id": "Document with\nnewlines\tand\ttabs"}}, + {"entity": {"chunk_id": "Document with special chars: @#$%^&*()"}}, ] processor.vecstore.search.return_value = mock_results @@ -350,9 +350,9 @@ class TestMilvusDocEmbeddingsQueryProcessor: ) # Mock search results for each vector - mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}] - mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}] - mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}] + mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}] + mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}] + mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}] processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] result = await processor.query_document_embeddings(query) @@ -378,12 +378,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Mock search results with duplicates across vectors mock_results_1 = [ - {"entity": {"doc": "Document A"}}, - {"entity": {"doc": "Document B"}}, + {"entity": {"chunk_id": "Document A"}}, + {"entity": {"chunk_id": "Document B"}}, ] mock_results_2 = [ - {"entity": {"doc": "Document B"}}, # Duplicate - {"entity": {"doc": "Document C"}}, + {"entity": {"chunk_id": "Document B"}}, # Duplicate + {"entity": {"chunk_id": "Document C"}}, ] processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] @@ -458,5 +458,5 @@ class TestMilvusDocEmbeddingsQueryProcessor: mock_launch.assert_called_once_with( default_ident, - "\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n" + "\nDocument embeddings query service. Input is vector, output is an array\nof chunk_ids\n" ) \ No newline at end of file diff --git a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py index ad337f73..204c4cc3 100644 --- a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py @@ -77,9 +77,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Mock query response mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'first document chunk'} + mock_point1.payload = {'chunk_id': 'first document chunk'} mock_point2 = MagicMock() - mock_point2.payload = {'doc': 'second document chunk'} + mock_point2.payload = {'chunk_id': 'second document chunk'} mock_response = MagicMock() mock_response.points = [mock_point1, mock_point2] @@ -132,11 +132,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Mock query responses for different vectors mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'document from vector 1'} + mock_point1.payload = {'chunk_id': 'document from vector 1'} mock_point2 = MagicMock() - mock_point2.payload = {'doc': 'document from vector 2'} + mock_point2.payload = {'chunk_id': 'document from vector 2'} mock_point3 = MagicMock() - mock_point3.payload = {'doc': 'another document from vector 2'} + mock_point3.payload = {'chunk_id': 'another document from vector 2'} mock_response1 = MagicMock() mock_response1.points = [mock_point1] @@ -192,7 +192,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_points = [] for i in range(10): mock_point = MagicMock() - mock_point.payload = {'doc': f'document chunk {i}'} + mock_point.payload = {'chunk_id': f'document chunk {i}'} mock_points.append(mock_point) mock_response = MagicMock() @@ -270,9 +270,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Mock query responses mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'document from 2D vector'} + mock_point1.payload = {'chunk_id': 'document from 2D vector'} mock_point2 = MagicMock() - mock_point2.payload = {'doc': 'document from 3D vector'} + mock_point2.payload = {'chunk_id': 'document from 3D vector'} mock_response1 = MagicMock() mock_response1.points = [mock_point1] @@ -326,9 +326,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Mock query response with UTF-8 content mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'} + mock_point1.payload = {'chunk_id': 'Document with UTF-8: café, naïve, résumé'} mock_point2 = MagicMock() - mock_point2.payload = {'doc': 'Chinese text: 你好世界'} + mock_point2.payload = {'chunk_id': 'Chinese text: 你好世界'} mock_response = MagicMock() mock_response.points = [mock_point1, mock_point2] @@ -399,7 +399,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Mock query response mock_point = MagicMock() - mock_point.payload = {'doc': 'document chunk'} + mock_point.payload = {'chunk_id': 'document chunk'} mock_response = MagicMock() mock_response.points = [mock_point] mock_qdrant_instance.query_points.return_value = mock_response @@ -442,9 +442,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Mock query response with fewer results than limit mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'document 1'} + mock_point1.payload = {'chunk_id': 'document 1'} mock_point2 = MagicMock() - mock_point2.payload = {'doc': 'document 2'} + mock_point2.payload = {'chunk_id': 'document 2'} mock_response = MagicMock() mock_response.points = [mock_point1, mock_point2] @@ -487,11 +487,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - # Mock query response with missing 'doc' key + # Mock query response with missing 'chunk_id' key mock_point1 = MagicMock() - mock_point1.payload = {'doc': 'valid document'} + mock_point1.payload = {'chunk_id': 'valid document'} mock_point2 = MagicMock() - mock_point2.payload = {} # Missing 'doc' key + mock_point2.payload = {} # Missing 'chunk_id' key mock_point3 = MagicMock() mock_point3.payload = {'other_key': 'invalid'} # Wrong key @@ -514,7 +514,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'payload_collection' # Act & Assert - # This should raise a KeyError when trying to access payload['doc'] + # This should raise a KeyError when trying to access payload['chunk_id'] with pytest.raises(KeyError): await processor.query_document_embeddings(mock_message) diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 590572bc..c9f5e8e1 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -8,48 +8,75 @@ from unittest.mock import MagicMock, AsyncMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query +# Sample chunk content mapping for tests +CHUNK_CONTENT = { + "doc/c1": "Document 1 content", + "doc/c2": "Document 2 content", + "doc/c3": "Relevant document content", + "doc/c4": "Another document", + "doc/c5": "Default doc", + "doc/c6": "Verbose test doc", + "doc/c7": "Verbose doc content", + "doc/ml1": "Machine learning is a subset of artificial intelligence...", + "doc/ml2": "ML algorithms learn patterns from data to make predictions...", + "doc/ml3": "Common ML techniques include supervised and unsupervised learning...", +} + + +@pytest.fixture +def mock_fetch_chunk(): + """Create a mock fetch_chunk function""" + async def fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + return fetch + + class TestDocumentRag: """Test cases for DocumentRag class""" - def test_document_rag_initialization_with_defaults(self): + def test_document_rag_initialization_with_defaults(self, mock_fetch_chunk): """Test DocumentRag initialization with default verbose setting""" # Create mock clients mock_prompt_client = MagicMock() mock_embeddings_client = MagicMock() mock_doc_embeddings_client = MagicMock() - + # Initialize DocumentRag document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, - doc_embeddings_client=mock_doc_embeddings_client + doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk ) - + # Verify initialization assert document_rag.prompt_client == mock_prompt_client assert document_rag.embeddings_client == mock_embeddings_client assert document_rag.doc_embeddings_client == mock_doc_embeddings_client + assert document_rag.fetch_chunk == mock_fetch_chunk assert document_rag.verbose is False # Default value - def test_document_rag_initialization_with_verbose(self): + def test_document_rag_initialization_with_verbose(self, mock_fetch_chunk): """Test DocumentRag initialization with verbose enabled""" # Create mock clients mock_prompt_client = MagicMock() mock_embeddings_client = MagicMock() mock_doc_embeddings_client = MagicMock() - + # Initialize DocumentRag with verbose=True document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk, verbose=True ) - + # Verify initialization assert document_rag.prompt_client == mock_prompt_client assert document_rag.embeddings_client == mock_embeddings_client assert document_rag.doc_embeddings_client == mock_doc_embeddings_client + assert document_rag.fetch_chunk == mock_fetch_chunk assert document_rag.verbose is True @@ -60,7 +87,7 @@ class TestQuery: """Test Query initialization with default parameters""" # Create mock DocumentRag mock_rag = MagicMock() - + # Initialize Query with defaults query = Query( rag=mock_rag, @@ -68,7 +95,7 @@ class TestQuery: collection="test_collection", verbose=False ) - + # Verify initialization assert query.rag == mock_rag assert query.user == "test_user" @@ -80,7 +107,7 @@ class TestQuery: """Test Query initialization with custom doc_limit""" # Create mock DocumentRag mock_rag = MagicMock() - + # Initialize Query with custom doc_limit query = Query( rag=mock_rag, @@ -89,7 +116,7 @@ class TestQuery: verbose=True, doc_limit=50 ) - + # Verify initialization assert query.rag == mock_rag assert query.user == "custom_user" @@ -104,11 +131,11 @@ class TestQuery: mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - + # Mock the embed method to return test vectors expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] mock_embeddings_client.embed.return_value = expected_vectors - + # Initialize Query query = Query( rag=mock_rag, @@ -116,14 +143,14 @@ class TestQuery: collection="test_collection", verbose=False ) - + # Call get_vector test_query = "What documents are relevant?" result = await query.get_vector(test_query) - + # Verify embeddings client was called correctly mock_embeddings_client.embed.assert_called_once_with(test_query) - + # Verify result matches expected vectors assert result == expected_vectors @@ -136,15 +163,20 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client - + + # Mock fetch_chunk function + async def mock_fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + mock_rag.fetch_chunk = mock_fetch + # Mock the embedding and document query responses test_vectors = [[0.1, 0.2, 0.3]] mock_embeddings_client.embed.return_value = test_vectors - - # Mock document results - test_docs = ["Document 1 content", "Document 2 content"] - mock_doc_embeddings_client.query.return_value = test_docs - + + # Mock document embeddings returns chunk_ids + test_chunk_ids = ["doc/c1", "doc/c2"] + mock_doc_embeddings_client.query.return_value = test_chunk_ids + # Initialize Query query = Query( rag=mock_rag, @@ -153,14 +185,14 @@ class TestQuery: verbose=False, doc_limit=15 ) - + # Call get_docs test_query = "Find relevant documents" result = await query.get_docs(test_query) - + # Verify embeddings client was called mock_embeddings_client.embed.assert_called_once_with(test_query) - + # Verify doc embeddings client was called correctly mock_doc_embeddings_client.query.assert_called_once_with( test_vectors, @@ -168,35 +200,37 @@ class TestQuery: user="test_user", collection="test_collection" ) - - # Verify result is list of documents - assert result == test_docs + + # Verify result is list of fetched document content + assert "Document 1 content" in result + assert "Document 2 content" in result @pytest.mark.asyncio - async def test_document_rag_query_method(self): + async def test_document_rag_query_method(self, mock_fetch_chunk): """Test DocumentRag.query method orchestrates full document RAG pipeline""" # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - - # Mock embeddings and document responses + + # Mock embeddings and document embeddings responses test_vectors = [[0.1, 0.2, 0.3]] - test_docs = ["Relevant document content", "Another document"] + test_chunk_ids = ["doc/c3", "doc/c4"] expected_response = "This is the document RAG response" - + mock_embeddings_client.embed.return_value = test_vectors - mock_doc_embeddings_client.query.return_value = test_docs + mock_doc_embeddings_client.query.return_value = test_chunk_ids mock_prompt_client.document_prompt.return_value = expected_response - + # Initialize DocumentRag document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) - + # Call DocumentRag.query result = await document_rag.query( query="test query", @@ -204,10 +238,10 @@ class TestQuery: collection="test_collection", doc_limit=10 ) - + # Verify embeddings client was called mock_embeddings_client.embed.assert_called_once_with("test query") - + # Verify doc embeddings client was called mock_doc_embeddings_client.query.assert_called_once_with( test_vectors, @@ -215,39 +249,43 @@ class TestQuery: user="test_user", collection="test_collection" ) - - # Verify prompt client was called with documents and query - mock_prompt_client.document_prompt.assert_called_once_with( - query="test query", - documents=test_docs - ) - + + # Verify prompt client was called with fetched documents and query + mock_prompt_client.document_prompt.assert_called_once() + call_args = mock_prompt_client.document_prompt.call_args + assert call_args.kwargs["query"] == "test query" + # Documents should be fetched content, not chunk_ids + docs = call_args.kwargs["documents"] + assert "Relevant document content" in docs + assert "Another document" in docs + # Verify result assert result == expected_response @pytest.mark.asyncio - async def test_document_rag_query_with_defaults(self): + async def test_document_rag_query_with_defaults(self, mock_fetch_chunk): """Test DocumentRag.query method with default parameters""" # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - + # Mock responses mock_embeddings_client.embed.return_value = [[0.1, 0.2]] - mock_doc_embeddings_client.query.return_value = ["Default doc"] + mock_doc_embeddings_client.query.return_value = ["doc/c5"] mock_prompt_client.document_prompt.return_value = "Default response" - + # Initialize DocumentRag document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, - doc_embeddings_client=mock_doc_embeddings_client + doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk ) - + # Call DocumentRag.query with minimal parameters result = await document_rag.query("simple query") - + # Verify default parameters were used mock_doc_embeddings_client.query.assert_called_once_with( [[0.1, 0.2]], @@ -255,7 +293,7 @@ class TestQuery: user="trustgraph", # Default user collection="default" # Default collection ) - + assert result == "Default response" @pytest.mark.asyncio @@ -267,11 +305,16 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client - + + # Mock fetch_chunk + async def mock_fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + mock_rag.fetch_chunk = mock_fetch + # Mock responses mock_embeddings_client.embed.return_value = [[0.7, 0.8]] - mock_doc_embeddings_client.query.return_value = ["Verbose test doc"] - + mock_doc_embeddings_client.query.return_value = ["doc/c6"] + # Initialize Query with verbose=True query = Query( rag=mock_rag, @@ -280,49 +323,51 @@ class TestQuery: verbose=True, doc_limit=5 ) - + # Call get_docs result = await query.get_docs("verbose test") - + # Verify calls were made mock_embeddings_client.embed.assert_called_once_with("verbose test") mock_doc_embeddings_client.query.assert_called_once() - - # Verify result - assert result == ["Verbose test doc"] + + # Verify result contains fetched content + assert "Verbose test doc" in result @pytest.mark.asyncio - async def test_document_rag_query_with_verbose(self): + async def test_document_rag_query_with_verbose(self, mock_fetch_chunk): """Test DocumentRag.query method with verbose logging enabled""" # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - + # Mock responses mock_embeddings_client.embed.return_value = [[0.3, 0.4]] - mock_doc_embeddings_client.query.return_value = ["Verbose doc content"] + mock_doc_embeddings_client.query.return_value = ["doc/c7"] mock_prompt_client.document_prompt.return_value = "Verbose RAG response" - + # Initialize DocumentRag with verbose=True document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk, verbose=True ) - + # Call DocumentRag.query result = await document_rag.query("verbose query test") - + # Verify all clients were called mock_embeddings_client.embed.assert_called_once_with("verbose query test") mock_doc_embeddings_client.query.assert_called_once() - mock_prompt_client.document_prompt.assert_called_once_with( - query="verbose query test", - documents=["Verbose doc content"] - ) - + + # Verify prompt client was called with fetched content + call_args = mock_prompt_client.document_prompt.call_args + assert call_args.kwargs["query"] == "verbose query test" + assert "Verbose doc content" in call_args.kwargs["documents"] + assert result == "Verbose RAG response" @pytest.mark.asyncio @@ -334,11 +379,16 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client - - # Mock responses - empty document list + + # Mock fetch_chunk (won't be called if no chunk_ids) + async def mock_fetch(chunk_id, user): + return f"Content for {chunk_id}" + mock_rag.fetch_chunk = mock_fetch + + # Mock responses - empty chunk_id list mock_embeddings_client.embed.return_value = [[0.1, 0.2]] - mock_doc_embeddings_client.query.return_value = [] # No documents found - + mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found + # Initialize Query query = Query( rag=mock_rag, @@ -346,47 +396,48 @@ class TestQuery: collection="test_collection", verbose=False ) - + # Call get_docs result = await query.get_docs("query with no results") - + # Verify calls were made mock_embeddings_client.embed.assert_called_once_with("query with no results") mock_doc_embeddings_client.query.assert_called_once() - + # Verify empty result is returned assert result == [] @pytest.mark.asyncio - async def test_document_rag_query_with_empty_documents(self): + async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk): """Test DocumentRag.query method when no documents are retrieved""" # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - - # Mock responses - no documents found + + # Mock responses - no chunk_ids found mock_embeddings_client.embed.return_value = [[0.5, 0.6]] - mock_doc_embeddings_client.query.return_value = [] # Empty document list + mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list mock_prompt_client.document_prompt.return_value = "No documents found response" - + # Initialize DocumentRag document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) - + # Call DocumentRag.query result = await document_rag.query("query with no matching docs") - + # Verify prompt client was called with empty document list mock_prompt_client.document_prompt.assert_called_once_with( query="query with no matching docs", documents=[] ) - + assert result == "No documents found response" @pytest.mark.asyncio @@ -396,11 +447,11 @@ class TestQuery: mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - + # Mock the embed method expected_vectors = [[0.9, 1.0, 1.1]] mock_embeddings_client.embed.return_value = expected_vectors - + # Initialize Query with verbose=True query = Query( rag=mock_rag, @@ -408,68 +459,71 @@ class TestQuery: collection="test_collection", verbose=True ) - + # Call get_vector result = await query.get_vector("verbose vector test") - + # Verify embeddings client was called mock_embeddings_client.embed.assert_called_once_with("verbose vector test") - + # Verify result assert result == expected_vectors @pytest.mark.asyncio - async def test_document_rag_integration_flow(self): + async def test_document_rag_integration_flow(self, mock_fetch_chunk): """Test complete DocumentRag integration with realistic data flow""" # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - + # Mock realistic responses query_text = "What is machine learning?" query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] - retrieved_docs = [ - "Machine learning is a subset of artificial intelligence...", - "ML algorithms learn patterns from data to make predictions...", - "Common ML techniques include supervised and unsupervised learning..." - ] + retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"] final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." - + mock_embeddings_client.embed.return_value = query_vectors - mock_doc_embeddings_client.query.return_value = retrieved_docs + mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids mock_prompt_client.document_prompt.return_value = final_response - + # Initialize DocumentRag document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client, + fetch_chunk=mock_fetch_chunk, verbose=False ) - + # Execute full pipeline result = await document_rag.query( query=query_text, - user="research_user", + user="research_user", collection="ml_knowledge", doc_limit=25 ) - + # Verify complete pipeline execution mock_embeddings_client.embed.assert_called_once_with(query_text) - + mock_doc_embeddings_client.query.assert_called_once_with( query_vectors, limit=25, user="research_user", collection="ml_knowledge" ) - - mock_prompt_client.document_prompt.assert_called_once_with( - query=query_text, - documents=retrieved_docs - ) - + + # Verify prompt client was called with fetched document content + mock_prompt_client.document_prompt.assert_called_once() + call_args = mock_prompt_client.document_prompt.call_args + assert call_args.kwargs["query"] == query_text + + # Verify documents were fetched from chunk_ids + docs = call_args.kwargs["documents"] + assert "Machine learning is a subset of artificial intelligence..." in docs + assert "ML algorithms learn patterns from data to make predictions..." in docs + assert "Common ML techniques include supervised and unsupervised learning..." in docs + # Verify final result - assert result == final_response \ No newline at end of file + assert result == final_response diff --git a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py index d957d711..ef66b741 100644 --- a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py @@ -22,11 +22,11 @@ class TestMilvusDocEmbeddingsStorageProcessor: # Create test document embeddings chunk1 = ChunkEmbeddings( - chunk=b"This is the first document chunk", + chunk_id="This is the first document chunk", vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] ) chunk2 = ChunkEmbeddings( - chunk=b"This is the second document chunk", + chunk_id="This is the second document chunk", vectors=[[0.7, 0.8, 0.9]] ) message.chunks = [chunk1, chunk2] @@ -84,7 +84,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( - chunk=b"Test document content", + chunk_id="Test document content", vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] ) message.chunks = [chunk] @@ -136,7 +136,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( - chunk=b"", + chunk_id="", vectors=[[0.1, 0.2, 0.3]] ) message.chunks = [chunk] @@ -148,51 +148,62 @@ class TestMilvusDocEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_document_embeddings_none_chunk(self, processor): - """Test storing document embeddings with None chunk (should be skipped)""" + """Test storing document embeddings with None chunk_id""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + chunk = ChunkEmbeddings( - chunk=None, + chunk_id=None, vectors=[[0.1, 0.2, 0.3]] ) message.chunks = [chunk] - + await processor.store_document_embeddings(message) - - # Verify no insert was called for None chunk - processor.vecstore.insert.assert_not_called() + + # Note: Implementation passes through None chunk_ids (only skips empty string "") + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3], None, 'test_user', 'test_collection' + ) @pytest.mark.asyncio async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor): - """Test storing document embeddings with mix of valid and invalid chunks""" + """Test storing document embeddings with mix of valid and empty chunks""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + valid_chunk = ChunkEmbeddings( - chunk=b"Valid document content", + chunk_id="Valid document content", vectors=[[0.1, 0.2, 0.3]] ) empty_chunk = ChunkEmbeddings( - chunk=b"", + chunk_id="", vectors=[[0.4, 0.5, 0.6]] ) - none_chunk = ChunkEmbeddings( - chunk=None, + another_valid = ChunkEmbeddings( + chunk_id="Another valid chunk", vectors=[[0.7, 0.8, 0.9]] ) - message.chunks = [valid_chunk, empty_chunk, none_chunk] - + message.chunks = [valid_chunk, empty_chunk, another_valid] + await processor.store_document_embeddings(message) - - # Verify only valid chunk was inserted with user/collection parameters - processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection' - ) + + # Verify valid chunks were inserted, empty string chunk was skipped + expected_calls = [ + ([0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection'), + ([0.7, 0.8, 0.9], "Another valid chunk", 'test_user', 'test_collection'), + ] + + assert processor.vecstore.insert.call_count == 2 + for i, (expected_vec, expected_chunk_id, expected_user, expected_collection) in enumerate(expected_calls): + actual_call = processor.vecstore.insert.call_args_list[i] + assert actual_call[0][0] == expected_vec + assert actual_call[0][1] == expected_chunk_id + assert actual_call[0][2] == expected_user + assert actual_call[0][3] == expected_collection @pytest.mark.asyncio async def test_store_document_embeddings_empty_chunks_list(self, processor): @@ -217,7 +228,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( - chunk=b"Document with no vectors", + chunk_id="Document with no vectors", vectors=[] ) message.chunks = [chunk] @@ -236,7 +247,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( - chunk=b"Document with mixed dimensions", + chunk_id="Document with mixed dimensions", vectors=[ [0.1, 0.2], # 2D vector [0.3, 0.4, 0.5, 0.6], # 4D vector @@ -264,46 +275,46 @@ class TestMilvusDocEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_document_embeddings_unicode_content(self, processor): - """Test storing document embeddings with Unicode content""" + """Test storing document embeddings with Unicode content in chunk_id""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + chunk = ChunkEmbeddings( - chunk="Document with Unicode: éñ中文🚀".encode('utf-8'), + chunk_id="chunk/doc/unicode-éñ中文🚀", vectors=[[0.1, 0.2, 0.3]] ) message.chunks = [chunk] - + await processor.store_document_embeddings(message) - - # Verify Unicode content was properly decoded and inserted with user/collection parameters + + # Verify Unicode chunk_id was stored correctly with user/collection parameters processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀", 'test_user', 'test_collection' + [0.1, 0.2, 0.3], "chunk/doc/unicode-éñ中文🚀", 'test_user', 'test_collection' ) @pytest.mark.asyncio - async def test_store_document_embeddings_large_chunks(self, processor): - """Test storing document embeddings with large document chunks""" + async def test_store_document_embeddings_large_chunk_id(self, processor): + """Test storing document embeddings with long chunk_id""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - - # Create a large document chunk - large_content = "A" * 10000 # 10KB of content + + # Create a long chunk_id + long_chunk_id = "chunk/doc/" + "a" * 200 chunk = ChunkEmbeddings( - chunk=large_content.encode('utf-8'), + chunk_id=long_chunk_id, vectors=[[0.1, 0.2, 0.3]] ) message.chunks = [chunk] - + await processor.store_document_embeddings(message) - - # Verify large content was inserted with user/collection parameters + + # Verify long chunk_id was inserted with user/collection parameters processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], large_content, 'test_user', 'test_collection' + [0.1, 0.2, 0.3], long_chunk_id, 'test_user', 'test_collection' ) @pytest.mark.asyncio @@ -315,7 +326,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( - chunk=b" \n\t ", + chunk_id=" \n\t ", vectors=[[0.1, 0.2, 0.3]] ) message.chunks = [chunk] @@ -346,7 +357,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = collection chunk = ChunkEmbeddings( - chunk=b"Test content", + chunk_id="Test content", vectors=[[0.1, 0.2, 0.3]] ) message.chunks = [chunk] @@ -367,7 +378,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message1.metadata.user = 'user1' message1.metadata.collection = 'collection1' chunk1 = ChunkEmbeddings( - chunk=b"User1 content", + chunk_id="User1 content", vectors=[[0.1, 0.2, 0.3]] ) message1.chunks = [chunk1] @@ -378,7 +389,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message2.metadata.user = 'user2' message2.metadata.collection = 'collection2' chunk2 = ChunkEmbeddings( - chunk=b"User2 content", + chunk_id="User2 content", vectors=[[0.4, 0.5, 0.6]] ) message2.chunks = [chunk2] @@ -409,7 +420,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata.collection = 'test-collection.v1' # Collection with special chars chunk = ChunkEmbeddings( - chunk=b"Special chars test", + chunk_id="Special chars test", vectors=[[0.1, 0.2, 0.3]] ) message.chunks = [chunk] diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index fc839482..e3498ce9 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -20,7 +20,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -34,7 +34,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Assert # Verify QdrantClient was created with correct parameters mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key') - + # Verify processor attributes assert hasattr(processor, 'qdrant') assert processor.qdrant == mock_qdrant_instance @@ -45,7 +45,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - + config = { 'taskgroup': AsyncMock(), 'id': 'test-doc-qdrant-processor' @@ -69,7 +69,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_qdrant_client.return_value = mock_qdrant_instance mock_uuid.uuid4.return_value = MagicMock() mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123') - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -86,13 +86,13 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' - + mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'test document chunk' + mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions - + mock_message.chunks = [mock_chunk] - + # Act await processor.store_document_embeddings(mock_message) @@ -100,18 +100,18 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Verify collection existence was checked (with dimension suffix) expected_collection = 'd_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3] mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) - + # Verify upsert was called mock_qdrant_instance.upsert.assert_called_once() - + # Verify upsert parameters upsert_call_args = mock_qdrant_instance.upsert.call_args assert upsert_call_args[1]['collection_name'] == 'd_test_user_test_collection_3' assert len(upsert_call_args[1]['points']) == 1 - + point = upsert_call_args[1]['points'][0] assert point.vector == [0.1, 0.2, 0.3] - assert point.payload['doc'] == 'test document chunk' + assert point.payload['chunk_id'] == 'doc/c1' @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @@ -123,7 +123,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_qdrant_client.return_value = mock_qdrant_instance mock_uuid.uuid4.return_value = MagicMock() mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid') - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -140,38 +140,38 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_message.metadata.user = 'multi_user' mock_message.metadata.collection = 'multi_collection' - + mock_chunk1 = MagicMock() - mock_chunk1.chunk.decode.return_value = 'first document chunk' + mock_chunk1.chunk_id = 'doc/c1' mock_chunk1.vectors = [[0.1, 0.2]] - + mock_chunk2 = MagicMock() - mock_chunk2.chunk.decode.return_value = 'second document chunk' + mock_chunk2.chunk_id = 'doc/c2' mock_chunk2.vectors = [[0.3, 0.4]] - + mock_message.chunks = [mock_chunk1, mock_chunk2] - + # Act await processor.store_document_embeddings(mock_message) # Assert # Should be called twice (once per chunk) assert mock_qdrant_instance.upsert.call_count == 2 - + # Verify both chunks were processed upsert_calls = mock_qdrant_instance.upsert.call_args_list - + # First chunk first_call = upsert_calls[0] first_point = first_call[1]['points'][0] assert first_point.vector == [0.1, 0.2] - assert first_point.payload['doc'] == 'first document chunk' - + assert first_point.payload['chunk_id'] == 'doc/c1' + # Second chunk second_call = upsert_calls[1] second_point = second_call[1]['points'][0] assert second_point.vector == [0.3, 0.4] - assert second_point.payload['doc'] == 'second document chunk' + assert second_point.payload['chunk_id'] == 'doc/c2' @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @@ -183,7 +183,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_qdrant_client.return_value = mock_qdrant_instance mock_uuid.uuid4.return_value = MagicMock() mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid') - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -200,41 +200,41 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_message.metadata.user = 'vector_user' mock_message.metadata.collection = 'vector_collection' - + mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'multi-vector document chunk' + mock_chunk.chunk_id = 'doc/multi-vector' mock_chunk.vectors = [ [0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9] ] - + mock_message.chunks = [mock_chunk] - + # Act await processor.store_document_embeddings(mock_message) # Assert # Should be called 3 times (once per vector) assert mock_qdrant_instance.upsert.call_count == 3 - + # Verify all vectors were processed upsert_calls = mock_qdrant_instance.upsert.call_args_list - + expected_vectors = [ [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], + [0.4, 0.5, 0.6], [0.7, 0.8, 0.9] ] - + for i, call in enumerate(upsert_calls): point = call[1]['points'][0] assert point.vector == expected_vectors[i] - assert point.payload['doc'] == 'multi-vector document chunk' + assert point.payload['chunk_id'] == 'doc/multi-vector' @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') - async def test_store_document_embeddings_empty_chunk(self, mock_qdrant_client): - """Test storing document embeddings skips empty chunks""" + async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client): + """Test storing document embeddings skips empty chunk_ids""" # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True # Collection exists @@ -249,13 +249,13 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) - # Create mock message with empty chunk + # Create mock message with empty chunk_id mock_message = MagicMock() mock_message.metadata.user = 'empty_user' mock_message.metadata.collection = 'empty_collection' mock_chunk_empty = MagicMock() - mock_chunk_empty.chunk.decode.return_value = "" # Empty string + mock_chunk_empty.chunk_id = "" # Empty chunk_id mock_chunk_empty.vectors = [[0.1, 0.2]] mock_message.chunks = [mock_chunk_empty] @@ -264,9 +264,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.store_document_embeddings(mock_message) # Assert - # Should not call upsert for empty chunks + # Should not call upsert for empty chunk_ids mock_qdrant_instance.upsert.assert_not_called() - # collection_exists should NOT be called since we return early for empty chunks + # collection_exists should NOT be called since we return early for empty chunk_ids mock_qdrant_instance.collection_exists.assert_not_called() @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @@ -298,7 +298,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.metadata.collection = 'new_collection' mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'test chunk' + mock_chunk.chunk_id = 'doc/test-chunk' mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions mock_message.chunks = [mock_chunk] @@ -350,7 +350,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.metadata.collection = 'error_collection' mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'test chunk' + mock_chunk.chunk_id = 'doc/test-chunk' mock_chunk.vectors = [[0.1, 0.2]] mock_message.chunks = [mock_chunk] @@ -388,7 +388,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message1.metadata.collection = 'cache_collection' mock_chunk1 = MagicMock() - mock_chunk1.chunk.decode.return_value = 'first chunk' + mock_chunk1.chunk_id = 'doc/c1' mock_chunk1.vectors = [[0.1, 0.2, 0.3]] mock_message1.chunks = [mock_chunk1] @@ -406,7 +406,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message2.metadata.collection = 'cache_collection' mock_chunk2 = MagicMock() - mock_chunk2.chunk.decode.return_value = 'second chunk' + mock_chunk2.chunk_id = 'doc/c2' mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3) mock_message2.chunks = [mock_chunk2] @@ -452,7 +452,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.metadata.collection = 'dim_collection' mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'dimension test chunk' + mock_chunk.chunk_id = 'doc/dim-test' mock_chunk.vectors = [ [0.1, 0.2], # 2 dimensions [0.3, 0.4, 0.5] # 3 dimensions @@ -485,28 +485,28 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Arrange mock_qdrant_client.return_value = MagicMock() mock_parser = MagicMock() - + # Act with patch('trustgraph.base.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args: Processor.add_args(mock_parser) # Assert mock_parent_add_args.assert_called_once_with(mock_parser) - + # Verify processor-specific arguments were added assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - async def test_utf8_decoding_handling(self, mock_uuid, mock_qdrant_client): - """Test proper UTF-8 decoding of chunk text""" + async def test_chunk_id_with_special_characters(self, mock_uuid, mock_qdrant_client): + """Test storing chunk_id with special characters (URIs)""" # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance mock_uuid.uuid4.return_value = MagicMock() mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid') - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -517,65 +517,28 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) # Add collection to known_collections (simulates config push) - processor.known_collections[('utf8_user', 'utf8_collection')] = {} + processor.known_collections[('uri_user', 'uri_collection')] = {} - # Create mock message with UTF-8 encoded text + # Create mock message with URI-style chunk_id mock_message = MagicMock() - mock_message.metadata.user = 'utf8_user' - mock_message.metadata.collection = 'utf8_collection' - + mock_message.metadata.user = 'uri_user' + mock_message.metadata.collection = 'uri_collection' + mock_chunk = MagicMock() - mock_chunk.chunk.decode.return_value = 'UTF-8 text with special chars: café, naïve, résumé' + mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3' mock_chunk.vectors = [[0.1, 0.2]] - + mock_message.chunks = [mock_chunk] - + # Act await processor.store_document_embeddings(mock_message) # Assert - # Verify chunk.decode was called with 'utf-8' - mock_chunk.chunk.decode.assert_called_with('utf-8') - - # Verify the decoded text was stored in payload + # Verify the chunk_id was stored correctly upsert_call_args = mock_qdrant_instance.upsert.call_args point = upsert_call_args[1]['points'][0] - assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé' - - @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') - async def test_chunk_decode_exception_handling(self, mock_qdrant_client): - """Test handling of chunk decode exceptions""" - # Arrange - mock_qdrant_instance = MagicMock() - mock_qdrant_client.return_value = mock_qdrant_instance - - config = { - 'store_uri': 'http://localhost:6333', - 'api_key': 'test-api-key', - 'taskgroup': AsyncMock(), - 'id': 'test-doc-qdrant-processor' - } - - processor = Processor(**config) - - # Add collection to known_collections (simulates config push) - processor.known_collections[('decode_user', 'decode_collection')] = {} - - # Create mock message with decode error - mock_message = MagicMock() - mock_message.metadata.user = 'decode_user' - mock_message.metadata.collection = 'decode_collection' - - mock_chunk = MagicMock() - mock_chunk.chunk.decode.side_effect = UnicodeDecodeError('utf-8', b'', 0, 1, 'invalid start byte') - mock_chunk.vectors = [[0.1, 0.2]] - - mock_message.chunks = [mock_chunk] - - # Act & Assert - with pytest.raises(UnicodeDecodeError): - await processor.store_document_embeddings(mock_message) + assert point.payload['chunk_id'] == 'https://trustgraph.ai/doc/my-document/p1/c3' if __name__ == '__main__': - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__])