Fixing tests

This commit is contained in:
Cyber MacGeddon 2026-04-21 11:37:28 +01:00
parent affdb06a1e
commit 4781c2a08b
9 changed files with 136 additions and 136 deletions

View file

@ -83,7 +83,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with correct parameters including user/collection
processor.vecstore.search.assert_called_once_with(
@ -115,7 +115,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
@ -148,7 +148,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with the specified limit
processor.vecstore.search.assert_called_once_with(
@ -168,7 +168,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
limit=5
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called
processor.vecstore.search.assert_not_called()
@ -189,7 +189,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock empty search results
processor.vecstore.search.return_value = []
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called
processor.vecstore.search.assert_called_once_with(
@ -217,7 +217,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify Unicode content is preserved in ChunkMatch objects
assert len(result) == 3
@ -244,7 +244,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify large content is preserved in ChunkMatch objects
assert len(result) == 2
@ -270,7 +270,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify special characters are preserved in ChunkMatch objects
assert len(result) == 3
@ -289,7 +289,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
limit=0
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called (optimization for zero limit)
processor.vecstore.search.assert_not_called()
@ -307,7 +307,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
limit=-1
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called (optimization for negative limit)
processor.vecstore.search.assert_not_called()
@ -330,7 +330,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Milvus connection failed"):
await processor.query_document_embeddings(query)
await processor.query_document_embeddings('test_user', query)
@pytest.mark.asyncio
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
@ -349,7 +349,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with the vector
processor.vecstore.search.assert_called_once()
@ -378,7 +378,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify results are ChunkMatch objects
assert len(result) == 3

View file

@ -117,7 +117,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify index was accessed correctly (with dimension suffix)
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -160,7 +160,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results1, mock_results2]
chunks = await processor.query_document_embeddings(mock_query_message)
chunks = await processor.query_document_embeddings('default', mock_query_message)
# Verify both queries were made
assert mock_index.query.call_count == 2
@ -191,7 +191,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify limit is passed to query
mock_index.query.assert_called_once()
@ -213,7 +213,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -231,7 +231,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -259,7 +259,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify different indexes used for different dimensions
assert processor.pinecone.Index.call_count == 2
@ -287,7 +287,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify no queries were made and empty result returned
processor.pinecone.Index.assert_not_called()
@ -310,7 +310,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_results.matches = []
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify empty results
assert chunks == []
@ -334,7 +334,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify Unicode content is properly handled
assert len(chunks) == 2
@ -361,7 +361,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify large content is properly handled
assert len(chunks) == 1
@ -389,7 +389,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify all content types are properly handled
assert len(chunks) == 5
@ -413,7 +413,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = Exception("Query failed")
with pytest.raises(Exception, match="Query failed"):
await processor.query_document_embeddings(message)
await processor.query_document_embeddings('default', message)
@pytest.mark.asyncio
async def test_query_document_embeddings_index_access_failure(self, processor):
@ -427,7 +427,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
processor.pinecone.Index.side_effect = Exception("Index access failed")
with pytest.raises(Exception, match="Index access failed"):
await processor.query_document_embeddings(message)
await processor.query_document_embeddings('default', message)
@pytest.mark.asyncio
async def test_query_document_embeddings_vector_accumulation(self, processor):
@ -462,7 +462,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3]
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('default', message)
# Verify all queries were made
assert mock_index.query.call_count == 3

View file

@ -101,7 +101,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'test_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('default', mock_message)
# Assert
# Verify query was called with correct parameters (with dimension suffix)
@ -158,7 +158,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('default', mock_message)
# Assert
# Verify query was called once
@ -212,7 +212,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'limit_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('default', mock_message)
# Assert
# Verify query was called with exact limit (no multiplication)
@ -252,7 +252,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'empty_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('default', mock_message)
# Assert
assert result == []
@ -291,7 +291,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('default', mock_message)
# Assert
# Verify query was called once with correct collection
@ -342,7 +342,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'utf8_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('default', mock_message)
# Assert
assert len(result) == 2
@ -380,7 +380,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.query_document_embeddings(mock_message)
await processor.query_document_embeddings('default', mock_message)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -413,7 +413,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'zero_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('default', mock_message)
# Assert
# Should still query (with limit 0)
@ -460,7 +460,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'large_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('default', mock_message)
# Assert
# Should query with full limit
@ -512,7 +512,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
# This should raise a KeyError when trying to access payload['chunk_id']
with pytest.raises(KeyError):
await processor.query_document_embeddings(mock_message)
await processor.query_document_embeddings('default', mock_message)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')

View file

@ -131,7 +131,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with correct parameters including user/collection
processor.vecstore.search.assert_called_once_with(
@ -168,7 +168,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
@ -201,7 +201,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with 2*limit for better deduplication
processor.vecstore.search.assert_called_once_with(
@ -229,7 +229,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify results are in the same order as returned by the store
assert len(result) == 3
@ -255,7 +255,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with the full vector
processor.vecstore.search.assert_called_once_with(
@ -275,7 +275,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
limit=5
)
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify no search was called
processor.vecstore.search.assert_not_called()
@ -296,7 +296,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Mock empty search results
processor.vecstore.search.return_value = []
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called
processor.vecstore.search.assert_called_once_with(
@ -325,7 +325,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify all results are properly typed
assert len(result) == 4
@ -359,7 +359,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Milvus connection failed"):
await processor.query_graph_embeddings(query)
await processor.query_graph_embeddings('test_user', query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
@ -436,7 +436,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
limit=0
)
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify no search was called (optimization for zero limit)
processor.vecstore.search.assert_not_called()
@ -461,7 +461,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once()

View file

@ -145,7 +145,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('default', message)
# Verify index was accessed correctly (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -185,7 +185,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(mock_query_message)
entities = await processor.query_graph_embeddings('default', mock_query_message)
# Verify query was made once
assert mock_index.query.call_count == 1
@ -216,7 +216,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('default', message)
# Verify limit is respected
assert len(entities) == 2
@ -233,7 +233,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('default', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -251,7 +251,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('default', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -276,7 +276,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('default', message)
# Verify correct index used for 2D vector
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
@ -300,7 +300,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('default', message)
# Verify no queries were made and empty result returned
processor.pinecone.Index.assert_not_called()
@ -323,7 +323,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_results.matches = []
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('default', message)
# Verify empty results
assert entities == []
@ -352,7 +352,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('default', message)
# Should get exactly 3 unique entities (respecting limit)
assert len(entities) == 3
@ -380,7 +380,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('default', message)
# Should only return 2 entities (respecting limit)
mock_index.query.assert_called_once()
@ -400,7 +400,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.side_effect = Exception("Query failed")
with pytest.raises(Exception, match="Query failed"):
await processor.query_graph_embeddings(message)
await processor.query_graph_embeddings('default', message)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

View file

@ -173,7 +173,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'test_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('default', mock_message)
# Assert
# Verify query was called with correct parameters (with dimension suffix)
@ -230,7 +230,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('default', mock_message)
# Assert
# Verify query was called once
@ -283,7 +283,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'limit_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('default', mock_message)
# Assert
# Verify query was called with limit * 2
@ -323,7 +323,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'empty_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('default', mock_message)
# Assert
assert result == []
@ -364,7 +364,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('default', mock_message)
# Assert
# Verify query was called once
@ -415,7 +415,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'uri_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('default', mock_message)
# Assert
assert len(result) == 3
@ -460,7 +460,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.query_graph_embeddings(mock_message)
await processor.query_graph_embeddings('default', mock_message)
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -493,7 +493,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'zero_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('default', mock_message)
# Assert
# Should still query (with limit 0)

View file

@ -9,12 +9,12 @@ from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestMemgraphQueryUserCollectionIsolation:
class TestMemgraphQueryWorkspaceCollectionIsolation:
"""Test cases for Memgraph query service with user/collection isolation"""
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_spo_query_with_user_collection(self, mock_graph_db):
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
"""Test SPO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -32,7 +32,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SPO query for literal includes user/collection
expected_query = (
@ -55,7 +55,7 @@ class TestMemgraphQueryUserCollectionIsolation:
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_sp_query_with_user_collection(self, mock_graph_db):
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
"""Test SP query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -73,7 +73,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SP query for literals includes user/collection
expected_literal_query = (
@ -95,7 +95,7 @@ class TestMemgraphQueryUserCollectionIsolation:
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_so_query_with_user_collection(self, mock_graph_db):
async def test_so_query_with_workspace_collection(self, mock_graph_db):
"""Test SO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -113,7 +113,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SO query for nodes includes user/collection
expected_query = (
@ -135,7 +135,7 @@ class TestMemgraphQueryUserCollectionIsolation:
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_s_only_query_with_user_collection(self, mock_graph_db):
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
"""Test S-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -153,7 +153,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify S query includes user/collection
expected_query = (
@ -174,7 +174,7 @@ class TestMemgraphQueryUserCollectionIsolation:
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_po_query_with_user_collection(self, mock_graph_db):
async def test_po_query_with_workspace_collection(self, mock_graph_db):
"""Test PO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -192,7 +192,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify PO query for literals includes user/collection
expected_query = (
@ -214,7 +214,7 @@ class TestMemgraphQueryUserCollectionIsolation:
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_p_only_query_with_user_collection(self, mock_graph_db):
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
"""Test P-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -232,7 +232,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify P query includes user/collection
expected_query = (
@ -253,7 +253,7 @@ class TestMemgraphQueryUserCollectionIsolation:
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_o_only_query_with_user_collection(self, mock_graph_db):
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
"""Test O-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -271,7 +271,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify O query for literals includes user/collection
expected_query = (
@ -292,7 +292,7 @@ class TestMemgraphQueryUserCollectionIsolation:
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
"""Test wildcard query (all None) includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -310,7 +310,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify wildcard query for literals includes user/collection
expected_literal_query = (
@ -363,7 +363,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples('default', query)
# Verify defaults were used
calls = mock_driver.execute_query.call_args_list
@ -410,7 +410,7 @@ class TestMemgraphQueryUserCollectionIsolation:
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
result = await processor.query_triples("test_user", query)
# Verify results are proper Triple objects
assert len(result) == 2

View file

@ -9,12 +9,12 @@ from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestNeo4jQueryUserCollectionIsolation:
class TestNeo4jQueryWorkspaceCollectionIsolation:
"""Test cases for Neo4j query service with user/collection isolation"""
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_spo_query_with_user_collection(self, mock_graph_db):
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
"""Test SPO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -36,9 +36,9 @@ class TestNeo4jQueryUserCollectionIsolation:
# Verify SPO query for literal includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT 10"
)
@ -48,14 +48,14 @@ class TestNeo4jQueryUserCollectionIsolation:
src="http://example.com/s",
rel="http://example.com/p",
value="test_object",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_sp_query_with_user_collection(self, mock_graph_db):
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
"""Test SP query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -77,9 +77,9 @@ class TestNeo4jQueryUserCollectionIsolation:
# Verify SP query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT 10"
)
@ -88,16 +88,16 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_literal_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify SP query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT 10"
)
@ -106,14 +106,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_node_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_so_query_with_user_collection(self, mock_graph_db):
async def test_so_query_with_workspace_collection(self, mock_graph_db):
"""Test SO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -135,9 +135,9 @@ class TestNeo4jQueryUserCollectionIsolation:
# Verify SO query for nodes includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT 10"
)
@ -146,14 +146,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_query,
src="http://example.com/s",
uri="http://example.com/o",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_s_only_query_with_user_collection(self, mock_graph_db):
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
"""Test S-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -175,9 +175,9 @@ class TestNeo4jQueryUserCollectionIsolation:
# Verify S query includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
@ -185,14 +185,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_po_query_with_user_collection(self, mock_graph_db):
async def test_po_query_with_workspace_collection(self, mock_graph_db):
"""Test PO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -214,9 +214,9 @@ class TestNeo4jQueryUserCollectionIsolation:
# Verify PO query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT 10"
)
@ -225,14 +225,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_query,
uri="http://example.com/p",
value="literal",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_p_only_query_with_user_collection(self, mock_graph_db):
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
"""Test P-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -254,9 +254,9 @@ class TestNeo4jQueryUserCollectionIsolation:
# Verify P query includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT 10"
)
@ -264,14 +264,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_o_only_query_with_user_collection(self, mock_graph_db):
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
"""Test O-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -293,9 +293,9 @@ class TestNeo4jQueryUserCollectionIsolation:
# Verify O query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT 10"
)
@ -303,14 +303,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
value="test_value",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
"""Test wildcard query (all None) includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -332,32 +332,32 @@ class TestNeo4jQueryUserCollectionIsolation:
# Verify wildcard query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify wildcard query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_node_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)

View file

@ -374,7 +374,7 @@ class TestUnifiedTableQueries:
query = call_args[0][1]
params = call_args[0][2]
assert "SELECT data, source FROM test_user.rows" in query
assert "SELECT data, source FROM test_workspace.rows" in query
assert "collection = %s" in query
assert "schema_name = %s" in query
assert "index_name = %s" in query