diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 51e41524..92a62222 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -175,9 +175,14 @@ class TestQuery: test_vectors = [[0.1, 0.2, 0.3]] mock_embeddings_client.embed.return_value = [test_vectors] - # Mock document embeddings returns chunk_ids - test_chunk_ids = ["doc/c1", "doc/c2"] - mock_doc_embeddings_client.query.return_value = test_chunk_ids + # Mock document embeddings returns ChunkMatch objects + mock_match1 = MagicMock() + mock_match1.chunk_id = "doc/c1" + mock_match1.score = 0.95 + mock_match2 = MagicMock() + mock_match2.chunk_id = "doc/c2" + mock_match2.score = 0.85 + mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] # Initialize Query query = Query( @@ -195,9 +200,9 @@ class TestQuery: # Verify embeddings client was called (now expects list) mock_embeddings_client.embed.assert_called_once_with([test_query]) - # Verify doc embeddings client was called correctly (with extracted vectors) + # Verify doc embeddings client was called correctly (with extracted vector) mock_doc_embeddings_client.query.assert_called_once_with( - test_vectors, + vector=test_vectors, limit=15, user="test_user", collection="test_collection" @@ -218,11 +223,16 @@ class TestQuery: # Mock embeddings and document embeddings responses # New batch format: [[[vectors]]] - get_vector extracts [0] test_vectors = [[0.1, 0.2, 0.3]] - test_chunk_ids = ["doc/c3", "doc/c4"] + mock_match1 = MagicMock() + mock_match1.chunk_id = "doc/c3" + mock_match1.score = 0.9 + mock_match2 = MagicMock() + mock_match2.chunk_id = "doc/c4" + mock_match2.score = 0.8 expected_response = "This is the document RAG response" mock_embeddings_client.embed.return_value = [test_vectors] - mock_doc_embeddings_client.query.return_value = test_chunk_ids + mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] mock_prompt_client.document_prompt.return_value = expected_response # Initialize DocumentRag @@ -245,9 +255,9 @@ class TestQuery: # Verify embeddings client was called (now expects list) mock_embeddings_client.embed.assert_called_once_with(["test query"]) - # Verify doc embeddings client was called (with extracted vectors) + # Verify doc embeddings client was called (with extracted vector) mock_doc_embeddings_client.query.assert_called_once_with( - test_vectors, + vector=test_vectors, limit=10, user="test_user", collection="test_collection" @@ -275,7 +285,10 @@ class TestQuery: # Mock responses (batch format) mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] - mock_doc_embeddings_client.query.return_value = ["doc/c5"] + mock_match = MagicMock() + mock_match.chunk_id = "doc/c5" + mock_match.score = 0.9 + mock_doc_embeddings_client.query.return_value = [mock_match] mock_prompt_client.document_prompt.return_value = "Default response" # Initialize DocumentRag @@ -289,9 +302,9 @@ class TestQuery: # Call DocumentRag.query with minimal parameters result = await document_rag.query("simple query") - # Verify default parameters were used (vectors extracted from batch) + # Verify default parameters were used (vector extracted from batch) mock_doc_embeddings_client.query.assert_called_once_with( - [[0.1, 0.2]], + vector=[[0.1, 0.2]], limit=20, # Default doc_limit user="trustgraph", # Default user collection="default" # Default collection @@ -316,7 +329,10 @@ class TestQuery: # Mock responses (batch format) mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]] - mock_doc_embeddings_client.query.return_value = ["doc/c6"] + mock_match = MagicMock() + mock_match.chunk_id = "doc/c6" + mock_match.score = 0.88 + mock_doc_embeddings_client.query.return_value = [mock_match] # Initialize Query with verbose=True query = Query( @@ -347,7 +363,10 @@ class TestQuery: # Mock responses (batch format) mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] - mock_doc_embeddings_client.query.return_value = ["doc/c7"] + mock_match = MagicMock() + mock_match.chunk_id = "doc/c7" + mock_match.score = 0.92 + mock_doc_embeddings_client.query.return_value = [mock_match] mock_prompt_client.document_prompt.return_value = "Verbose RAG response" # Initialize DocumentRag with verbose=True @@ -487,7 +506,13 @@ class TestQuery: final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." mock_embeddings_client.embed.return_value = [query_vectors] - mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids + mock_matches = [] + for chunk_id in retrieved_chunk_ids: + mock_match = MagicMock() + mock_match.chunk_id = chunk_id + mock_match.score = 0.9 + mock_matches.append(mock_match) + mock_doc_embeddings_client.query.return_value = mock_matches mock_prompt_client.document_prompt.return_value = final_response # Initialize DocumentRag @@ -511,7 +536,7 @@ class TestQuery: mock_embeddings_client.embed.assert_called_once_with([query_text]) mock_doc_embeddings_client.query.assert_called_once_with( - query_vectors, + vector=query_vectors, limit=25, user="research_user", collection="ml_knowledge"