From 4781c2a08b74549bcbc0d92f49482a23b31cf8b9 Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Tue, 21 Apr 2026 11:37:28 +0100 Subject: [PATCH] Fixing tests --- .../test_doc_embeddings_milvus_query.py | 26 ++--- .../test_doc_embeddings_pinecone_query.py | 28 +++--- .../test_doc_embeddings_qdrant_query.py | 20 ++-- .../test_graph_embeddings_milvus_query.py | 22 ++--- .../test_graph_embeddings_pinecone_query.py | 22 ++--- .../test_graph_embeddings_qdrant_query.py | 16 +-- ...st_memgraph_workspace_collection_query.py} | 38 +++---- ... test_neo4j_workspace_collection_query.py} | 98 +++++++++---------- .../test_query/test_rows_cassandra_query.py | 2 +- 9 files changed, 136 insertions(+), 136 deletions(-) rename tests/unit/test_query/{test_memgraph_user_collection_query.py => test_memgraph_workspace_collection_query.py} (92%) rename tests/unit/test_query/{test_neo4j_user_collection_query.py => test_neo4j_workspace_collection_query.py} (78%) diff --git a/tests/unit/test_query/test_doc_embeddings_milvus_query.py b/tests/unit/test_query/test_doc_embeddings_milvus_query.py index 1cddce97..4e129748 100644 --- a/tests/unit/test_query/test_doc_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_doc_embeddings_milvus_query.py @@ -83,7 +83,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called with correct parameters including user/collection processor.vecstore.search.assert_called_once_with( @@ -115,7 +115,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called once with the full vector processor.vecstore.search.assert_called_once_with( @@ -148,7 +148,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called with the specified limit processor.vecstore.search.assert_called_once_with( @@ -168,7 +168,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: limit=5 ) - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify no search was called processor.vecstore.search.assert_not_called() @@ -189,7 +189,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Mock empty search results processor.vecstore.search.return_value = [] - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called processor.vecstore.search.assert_called_once_with( @@ -217,7 +217,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify Unicode content is preserved in ChunkMatch objects assert len(result) == 3 @@ -244,7 +244,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify large content is preserved in ChunkMatch objects assert len(result) == 2 @@ -270,7 +270,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify special characters are preserved in ChunkMatch objects assert len(result) == 3 @@ -289,7 +289,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: limit=0 ) - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify no search was called (optimization for zero limit) processor.vecstore.search.assert_not_called() @@ -307,7 +307,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: limit=-1 ) - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify no search was called (optimization for negative limit) processor.vecstore.search.assert_not_called() @@ -330,7 +330,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Milvus connection failed"): - await processor.query_document_embeddings(query) + await processor.query_document_embeddings('test_user', query) @pytest.mark.asyncio async def test_query_document_embeddings_different_vector_dimensions(self, processor): @@ -349,7 +349,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called with the vector processor.vecstore.search.assert_called_once() @@ -378,7 +378,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify results are ChunkMatch objects assert len(result) == 3 diff --git a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py index 397bdf1b..fd090b0a 100644 --- a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py @@ -117,7 +117,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify index was accessed correctly (with dimension suffix) expected_index_name = "d-test_user-test_collection-3" # 3 dimensions @@ -160,7 +160,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = [mock_results1, mock_results2] - chunks = await processor.query_document_embeddings(mock_query_message) + chunks = await processor.query_document_embeddings('default', mock_query_message) # Verify both queries were made assert mock_index.query.call_count == 2 @@ -191,7 +191,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify limit is passed to query mock_index.query.assert_called_once() @@ -213,7 +213,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -231,7 +231,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -259,7 +259,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = [mock_results_2d, mock_results_4d] - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify different indexes used for different dimensions assert processor.pinecone.Index.call_count == 2 @@ -287,7 +287,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify no queries were made and empty result returned processor.pinecone.Index.assert_not_called() @@ -310,7 +310,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_results.matches = [] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify empty results assert chunks == [] @@ -334,7 +334,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify Unicode content is properly handled assert len(chunks) == 2 @@ -361,7 +361,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify large content is properly handled assert len(chunks) == 1 @@ -389,7 +389,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify all content types are properly handled assert len(chunks) == 5 @@ -413,7 +413,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = Exception("Query failed") with pytest.raises(Exception, match="Query failed"): - await processor.query_document_embeddings(message) + await processor.query_document_embeddings('default', message) @pytest.mark.asyncio async def test_query_document_embeddings_index_access_failure(self, processor): @@ -427,7 +427,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: processor.pinecone.Index.side_effect = Exception("Index access failed") with pytest.raises(Exception, match="Index access failed"): - await processor.query_document_embeddings(message) + await processor.query_document_embeddings('default', message) @pytest.mark.asyncio async def test_query_document_embeddings_vector_accumulation(self, processor): @@ -462,7 +462,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3] - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('default', message) # Verify all queries were made assert mock_index.query.call_count == 3 diff --git a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py index 1d2f0e6d..614efe74 100644 --- a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py @@ -101,7 +101,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'test_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('default', mock_message) # Assert # Verify query was called with correct parameters (with dimension suffix) @@ -158,7 +158,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'multi_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('default', mock_message) # Assert # Verify query was called once @@ -212,7 +212,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'limit_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('default', mock_message) # Assert # Verify query was called with exact limit (no multiplication) @@ -252,7 +252,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'empty_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('default', mock_message) # Assert assert result == [] @@ -291,7 +291,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'dim_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('default', mock_message) # Assert # Verify query was called once with correct collection @@ -342,7 +342,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'utf8_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('default', mock_message) # Assert assert len(result) == 2 @@ -380,7 +380,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Act & Assert with pytest.raises(Exception, match="Qdrant connection failed"): - await processor.query_document_embeddings(mock_message) + await processor.query_document_embeddings('default', mock_message) @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -413,7 +413,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'zero_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('default', mock_message) # Assert # Should still query (with limit 0) @@ -460,7 +460,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'large_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('default', mock_message) # Assert # Should query with full limit @@ -512,7 +512,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Act & Assert # This should raise a KeyError when trying to access payload['chunk_id'] with pytest.raises(KeyError): - await processor.query_document_embeddings(mock_message) + await processor.query_document_embeddings('default', mock_message) @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py index f2b8be7e..2bf7d7e9 100644 --- a/tests/unit/test_query/test_graph_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -131,7 +131,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called with correct parameters including user/collection processor.vecstore.search.assert_called_once_with( @@ -168,7 +168,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called once with the full vector processor.vecstore.search.assert_called_once_with( @@ -201,7 +201,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called with 2*limit for better deduplication processor.vecstore.search.assert_called_once_with( @@ -229,7 +229,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify results are in the same order as returned by the store assert len(result) == 3 @@ -255,7 +255,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called with the full vector processor.vecstore.search.assert_called_once_with( @@ -275,7 +275,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: limit=5 ) - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify no search was called processor.vecstore.search.assert_not_called() @@ -296,7 +296,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Mock empty search results processor.vecstore.search.return_value = [] - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called processor.vecstore.search.assert_called_once_with( @@ -325,7 +325,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify all results are properly typed assert len(result) == 4 @@ -359,7 +359,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Milvus connection failed"): - await processor.query_graph_embeddings(query) + await processor.query_graph_embeddings('test_user', query) def test_add_args_method(self): """Test that add_args properly configures argument parser""" @@ -436,7 +436,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: limit=0 ) - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify no search was called (optimization for zero limit) processor.vecstore.search.assert_not_called() @@ -461,7 +461,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called once with the full vector processor.vecstore.search.assert_called_once() diff --git a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py index 2c1a673a..eb64ec8b 100644 --- a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py @@ -145,7 +145,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('default', message) # Verify index was accessed correctly (with dimension suffix) expected_index_name = "t-test_user-test_collection-3" # 3 dimensions @@ -185,7 +185,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(mock_query_message) + entities = await processor.query_graph_embeddings('default', mock_query_message) # Verify query was made once assert mock_index.query.call_count == 1 @@ -216,7 +216,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('default', message) # Verify limit is respected assert len(entities) == 2 @@ -233,7 +233,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('default', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -251,7 +251,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('default', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -276,7 +276,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('default', message) # Verify correct index used for 2D vector processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2") @@ -300,7 +300,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('default', message) # Verify no queries were made and empty result returned processor.pinecone.Index.assert_not_called() @@ -323,7 +323,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_results.matches = [] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('default', message) # Verify empty results assert entities == [] @@ -352,7 +352,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('default', message) # Should get exactly 3 unique entities (respecting limit) assert len(entities) == 3 @@ -380,7 +380,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('default', message) # Should only return 2 entities (respecting limit) mock_index.query.assert_called_once() @@ -400,7 +400,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.side_effect = Exception("Query failed") with pytest.raises(Exception, match="Query failed"): - await processor.query_graph_embeddings(message) + await processor.query_graph_embeddings('default', message) def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py index 9362a8dd..13924bef 100644 --- a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py @@ -173,7 +173,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'test_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('default', mock_message) # Assert # Verify query was called with correct parameters (with dimension suffix) @@ -230,7 +230,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'multi_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('default', mock_message) # Assert # Verify query was called once @@ -283,7 +283,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'limit_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('default', mock_message) # Assert # Verify query was called with limit * 2 @@ -323,7 +323,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'empty_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('default', mock_message) # Assert assert result == [] @@ -364,7 +364,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'dim_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('default', mock_message) # Assert # Verify query was called once @@ -415,7 +415,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'uri_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('default', mock_message) # Assert assert len(result) == 3 @@ -460,7 +460,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Act & Assert with pytest.raises(Exception, match="Qdrant connection failed"): - await processor.query_graph_embeddings(mock_message) + await processor.query_graph_embeddings('default', mock_message) @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -493,7 +493,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'zero_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('default', mock_message) # Assert # Should still query (with limit 0) diff --git a/tests/unit/test_query/test_memgraph_user_collection_query.py b/tests/unit/test_query/test_memgraph_workspace_collection_query.py similarity index 92% rename from tests/unit/test_query/test_memgraph_user_collection_query.py rename to tests/unit/test_query/test_memgraph_workspace_collection_query.py index 038fb438..ea902d05 100644 --- a/tests/unit/test_query/test_memgraph_user_collection_query.py +++ b/tests/unit/test_query/test_memgraph_workspace_collection_query.py @@ -9,12 +9,12 @@ from trustgraph.query.triples.memgraph.service import Processor from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL -class TestMemgraphQueryUserCollectionIsolation: +class TestMemgraphQueryWorkspaceCollectionIsolation: """Test cases for Memgraph query service with user/collection isolation""" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_spo_query_with_user_collection(self, mock_graph_db): + async def test_spo_query_with_workspace_collection(self, mock_graph_db): """Test SPO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -32,7 +32,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SPO query for literal includes user/collection expected_query = ( @@ -55,7 +55,7 @@ class TestMemgraphQueryUserCollectionIsolation: @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_sp_query_with_user_collection(self, mock_graph_db): + async def test_sp_query_with_workspace_collection(self, mock_graph_db): """Test SP query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -73,7 +73,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SP query for literals includes user/collection expected_literal_query = ( @@ -95,7 +95,7 @@ class TestMemgraphQueryUserCollectionIsolation: @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_so_query_with_user_collection(self, mock_graph_db): + async def test_so_query_with_workspace_collection(self, mock_graph_db): """Test SO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -113,7 +113,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SO query for nodes includes user/collection expected_query = ( @@ -135,7 +135,7 @@ class TestMemgraphQueryUserCollectionIsolation: @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_s_only_query_with_user_collection(self, mock_graph_db): + async def test_s_only_query_with_workspace_collection(self, mock_graph_db): """Test S-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -153,7 +153,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify S query includes user/collection expected_query = ( @@ -174,7 +174,7 @@ class TestMemgraphQueryUserCollectionIsolation: @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_po_query_with_user_collection(self, mock_graph_db): + async def test_po_query_with_workspace_collection(self, mock_graph_db): """Test PO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -192,7 +192,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify PO query for literals includes user/collection expected_query = ( @@ -214,7 +214,7 @@ class TestMemgraphQueryUserCollectionIsolation: @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_p_only_query_with_user_collection(self, mock_graph_db): + async def test_p_only_query_with_workspace_collection(self, mock_graph_db): """Test P-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -232,7 +232,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify P query includes user/collection expected_query = ( @@ -253,7 +253,7 @@ class TestMemgraphQueryUserCollectionIsolation: @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_o_only_query_with_user_collection(self, mock_graph_db): + async def test_o_only_query_with_workspace_collection(self, mock_graph_db): """Test O-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -271,7 +271,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify O query for literals includes user/collection expected_query = ( @@ -292,7 +292,7 @@ class TestMemgraphQueryUserCollectionIsolation: @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_wildcard_query_with_user_collection(self, mock_graph_db): + async def test_wildcard_query_with_workspace_collection(self, mock_graph_db): """Test wildcard query (all None) includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -310,7 +310,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify wildcard query for literals includes user/collection expected_literal_query = ( @@ -363,7 +363,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples('default', query) # Verify defaults were used calls = mock_driver.execute_query.call_args_list @@ -410,7 +410,7 @@ class TestMemgraphQueryUserCollectionIsolation: ([mock_record2], MagicMock(), MagicMock()) # Node query ] - result = await processor.query_triples(query) + result = await processor.query_triples("test_user", query) # Verify results are proper Triple objects assert len(result) == 2 diff --git a/tests/unit/test_query/test_neo4j_user_collection_query.py b/tests/unit/test_query/test_neo4j_workspace_collection_query.py similarity index 78% rename from tests/unit/test_query/test_neo4j_user_collection_query.py rename to tests/unit/test_query/test_neo4j_workspace_collection_query.py index 12beb714..8c8c1e22 100644 --- a/tests/unit/test_query/test_neo4j_user_collection_query.py +++ b/tests/unit/test_query/test_neo4j_workspace_collection_query.py @@ -9,12 +9,12 @@ from trustgraph.query.triples.neo4j.service import Processor from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL -class TestNeo4jQueryUserCollectionIsolation: +class TestNeo4jQueryWorkspaceCollectionIsolation: """Test cases for Neo4j query service with user/collection isolation""" @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_spo_query_with_user_collection(self, mock_graph_db): + async def test_spo_query_with_workspace_collection(self, mock_graph_db): """Test SPO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -36,9 +36,9 @@ class TestNeo4jQueryUserCollectionIsolation: # Verify SPO query for literal includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT 10" ) @@ -48,14 +48,14 @@ class TestNeo4jQueryUserCollectionIsolation: src="http://example.com/s", rel="http://example.com/p", value="test_object", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_sp_query_with_user_collection(self, mock_graph_db): + async def test_sp_query_with_workspace_collection(self, mock_graph_db): """Test SP query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -77,9 +77,9 @@ class TestNeo4jQueryUserCollectionIsolation: # Verify SP query for literals includes user/collection expected_literal_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest " "LIMIT 10" ) @@ -88,16 +88,16 @@ class TestNeo4jQueryUserCollectionIsolation: expected_literal_query, src="http://example.com/s", rel="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) # Verify SP query for nodes includes user/collection expected_node_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN dest.uri as dest " "LIMIT 10" ) @@ -106,14 +106,14 @@ class TestNeo4jQueryUserCollectionIsolation: expected_node_query, src="http://example.com/s", rel="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_so_query_with_user_collection(self, mock_graph_db): + async def test_so_query_with_workspace_collection(self, mock_graph_db): """Test SO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -135,9 +135,9 @@ class TestNeo4jQueryUserCollectionIsolation: # Verify SO query for nodes includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT 10" ) @@ -146,14 +146,14 @@ class TestNeo4jQueryUserCollectionIsolation: expected_query, src="http://example.com/s", uri="http://example.com/o", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_s_only_query_with_user_collection(self, mock_graph_db): + async def test_s_only_query_with_workspace_collection(self, mock_graph_db): """Test S-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -175,9 +175,9 @@ class TestNeo4jQueryUserCollectionIsolation: # Verify S query includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT 10" ) @@ -185,14 +185,14 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, src="http://example.com/s", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_po_query_with_user_collection(self, mock_graph_db): + async def test_po_query_with_workspace_collection(self, mock_graph_db): """Test PO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -214,9 +214,9 @@ class TestNeo4jQueryUserCollectionIsolation: # Verify PO query for literals includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT 10" ) @@ -225,14 +225,14 @@ class TestNeo4jQueryUserCollectionIsolation: expected_query, uri="http://example.com/p", value="literal", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_p_only_query_with_user_collection(self, mock_graph_db): + async def test_p_only_query_with_workspace_collection(self, mock_graph_db): """Test P-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -254,9 +254,9 @@ class TestNeo4jQueryUserCollectionIsolation: # Verify P query includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT 10" ) @@ -264,14 +264,14 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, uri="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_o_only_query_with_user_collection(self, mock_graph_db): + async def test_o_only_query_with_workspace_collection(self, mock_graph_db): """Test O-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -293,9 +293,9 @@ class TestNeo4jQueryUserCollectionIsolation: # Verify O query for literals includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT 10" ) @@ -303,14 +303,14 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, value="test_value", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_wildcard_query_with_user_collection(self, mock_graph_db): + async def test_wildcard_query_with_workspace_collection(self, mock_graph_db): """Test wildcard query (all None) includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -332,32 +332,32 @@ class TestNeo4jQueryUserCollectionIsolation: # Verify wildcard query for literals includes user/collection expected_literal_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " "LIMIT 10" ) mock_driver.execute_query.assert_any_call( expected_literal_query, - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) # Verify wildcard query for nodes includes user/collection expected_node_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "LIMIT 10" ) mock_driver.execute_query.assert_any_call( expected_node_query, - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) diff --git a/tests/unit/test_query/test_rows_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py index 6933c29c..bb6bbe84 100644 --- a/tests/unit/test_query/test_rows_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -374,7 +374,7 @@ class TestUnifiedTableQueries: query = call_args[0][1] params = call_args[0][2] - assert "SELECT data, source FROM test_user.rows" in query + assert "SELECT data, source FROM test_workspace.rows" in query assert "collection = %s" in query assert "schema_name = %s" in query assert "index_name = %s" in query