diff --git a/tests/unit/test_query/test_memgraph_user_collection_query.py b/tests/unit/test_query/test_memgraph_user_collection_query.py new file mode 100644 index 00000000..772d4f84 --- /dev/null +++ b/tests/unit/test_query/test_memgraph_user_collection_query.py @@ -0,0 +1,432 @@ +""" +Tests for Memgraph user/collection isolation in query service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.query.triples.memgraph.service import Processor +from trustgraph.schema import TriplesQueryRequest, Value + + +class TestMemgraphQueryUserCollectionIsolation: + """Test cases for Memgraph query service with user/collection isolation""" + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_spo_query_with_user_collection(self, mock_graph_db): + """Test SPO query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/s", is_uri=True), + p=Value(value="http://example.com/p", is_uri=True), + o=Value(value="test_object", is_uri=False), + limit=1000 + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify SPO query for literal includes user/collection + expected_query = ( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "RETURN $src as src " + "LIMIT 1000" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + src="http://example.com/s", + rel="http://example.com/p", + value="test_object", + user="test_user", + collection="test_collection", + database_='memgraph' + ) + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_sp_query_with_user_collection(self, mock_graph_db): + """Test SP query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/s", is_uri=True), + p=Value(value="http://example.com/p", is_uri=True), + o=None, + limit=1000 + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify SP query for literals includes user/collection + expected_literal_query = ( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " + "RETURN dest.value as dest " + "LIMIT 1000" + ) + + mock_driver.execute_query.assert_any_call( + expected_literal_query, + src="http://example.com/s", + rel="http://example.com/p", + user="test_user", + collection="test_collection", + database_='memgraph' + ) + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_so_query_with_user_collection(self, mock_graph_db): + """Test SO query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/s", is_uri=True), + p=None, + o=Value(value="http://example.com/o", is_uri=True), + limit=1000 + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify SO query for nodes includes user/collection + expected_query = ( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "RETURN rel.uri as rel " + "LIMIT 1000" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + src="http://example.com/s", + uri="http://example.com/o", + user="test_user", + collection="test_collection", + database_='memgraph' + ) + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_s_only_query_with_user_collection(self, mock_graph_db): + """Test S-only query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/s", is_uri=True), + p=None, + o=None, + limit=1000 + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify S query includes user/collection + expected_query = ( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " + "RETURN rel.uri as rel, dest.value as dest " + "LIMIT 1000" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + src="http://example.com/s", + user="test_user", + collection="test_collection", + database_='memgraph' + ) + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_po_query_with_user_collection(self, mock_graph_db): + """Test PO query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=None, + p=Value(value="http://example.com/p", is_uri=True), + o=Value(value="literal", is_uri=False), + limit=1000 + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify PO query for literals includes user/collection + expected_query = ( + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "RETURN src.uri as src " + "LIMIT 1000" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + uri="http://example.com/p", + value="literal", + user="test_user", + collection="test_collection", + database_='memgraph' + ) + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_p_only_query_with_user_collection(self, mock_graph_db): + """Test P-only query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=None, + p=Value(value="http://example.com/p", is_uri=True), + o=None, + limit=1000 + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify P query includes user/collection + expected_query = ( + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " + "RETURN src.uri as src, dest.value as dest " + "LIMIT 1000" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + uri="http://example.com/p", + user="test_user", + collection="test_collection", + database_='memgraph' + ) + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_o_only_query_with_user_collection(self, mock_graph_db): + """Test O-only query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=None, + p=None, + o=Value(value="test_value", is_uri=False), + limit=1000 + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify O query for literals includes user/collection + expected_query = ( + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "RETURN src.uri as src, rel.uri as rel " + "LIMIT 1000" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + value="test_value", + user="test_user", + collection="test_collection", + database_='memgraph' + ) + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_wildcard_query_with_user_collection(self, mock_graph_db): + """Test wildcard query (all None) includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=None, + p=None, + o=None, + limit=1000 + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify wildcard query for literals includes user/collection + expected_literal_query = ( + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " + "RETURN src.uri as src, rel.uri as rel, dest.value as dest " + "LIMIT 1000" + ) + + mock_driver.execute_query.assert_any_call( + expected_literal_query, + user="test_user", + collection="test_collection", + database_='memgraph' + ) + + # Verify wildcard query for nodes includes user/collection + expected_node_query = ( + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " + "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " + "LIMIT 1000" + ) + + mock_driver.execute_query.assert_any_call( + expected_node_query, + user="test_user", + collection="test_collection", + database_='memgraph' + ) + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_with_defaults_when_not_provided(self, mock_graph_db): + """Test that defaults are used when user/collection not provided""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + # Query without user/collection fields + query = TriplesQueryRequest( + s=Value(value="http://example.com/s", is_uri=True), + p=None, + o=None, + limit=1000 + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify defaults were used + calls = mock_driver.execute_query.call_args_list + for call in calls: + if 'user' in call.kwargs: + assert call.kwargs['user'] == 'default' + if 'collection' in call.kwargs: + assert call.kwargs['collection'] == 'default' + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_results_properly_converted_to_triples(self, mock_graph_db): + """Test that query results are properly converted to Triple objects""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/s", is_uri=True), + p=None, + o=None, + limit=1000 + ) + + # Mock some results + mock_record1 = MagicMock() + mock_record1.data.return_value = { + "rel": "http://example.com/p1", + "dest": "literal_value" + } + + mock_record2 = MagicMock() + mock_record2.data.return_value = { + "rel": "http://example.com/p2", + "dest": "http://example.com/o" + } + + # Return results for literal query, empty for node query + mock_driver.execute_query.side_effect = [ + ([mock_record1], MagicMock(), MagicMock()), # Literal query + ([mock_record2], MagicMock(), MagicMock()) # Node query + ] + + result = await processor.query_triples(query) + + # Verify results are proper Triple objects + assert len(result) == 2 + + # First triple (literal object) + assert result[0].s.value == "http://example.com/s" + assert result[0].s.is_uri == True + assert result[0].p.value == "http://example.com/p1" + assert result[0].p.is_uri == True + assert result[0].o.value == "literal_value" + assert result[0].o.is_uri == False + + # Second triple (URI object) + assert result[1].s.value == "http://example.com/s" + assert result[1].s.is_uri == True + assert result[1].p.value == "http://example.com/p2" + assert result[1].p.is_uri == True + assert result[1].o.value == "http://example.com/o" + assert result[1].o.is_uri == True \ No newline at end of file diff --git a/tests/unit/test_storage/test_memgraph_user_collection_isolation.py b/tests/unit/test_storage/test_memgraph_user_collection_isolation.py new file mode 100644 index 00000000..fdc7fb4e --- /dev/null +++ b/tests/unit/test_storage/test_memgraph_user_collection_isolation.py @@ -0,0 +1,363 @@ +""" +Tests for Memgraph user/collection isolation in storage service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.storage.triples.memgraph.write import Processor + + +class TestMemgraphUserCollectionIsolation: + """Test cases for Memgraph storage service with user/collection isolation""" + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + def test_storage_creates_indexes_with_user_collection(self, mock_graph_db): + """Test that storage creates both legacy and user/collection indexes""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = Processor(taskgroup=MagicMock()) + + # Verify all indexes were attempted (4 legacy + 4 user/collection = 8 total) + assert mock_session.run.call_count == 8 + + # Check some specific index creation calls + expected_calls = [ + "CREATE INDEX ON :Node", + "CREATE INDEX ON :Node(uri)", + "CREATE INDEX ON :Literal", + "CREATE INDEX ON :Literal(value)", + "CREATE INDEX ON :Node(user)", + "CREATE INDEX ON :Node(collection)", + "CREATE INDEX ON :Literal(user)", + "CREATE INDEX ON :Literal(collection)" + ] + + for expected_call in expected_calls: + mock_session.run.assert_any_call(expected_call) + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + @pytest.mark.asyncio + async def test_store_triples_with_user_collection(self, mock_graph_db): + """Test that store_triples includes user/collection in all operations""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=MagicMock()) + + # Create mock triple with URI object + triple = MagicMock() + triple.s.value = "http://example.com/subject" + triple.p.value = "http://example.com/predicate" + triple.o.value = "http://example.com/object" + triple.o.is_uri = True + + # Create mock message with metadata + mock_message = MagicMock() + mock_message.triples = [triple] + mock_message.metadata.user = "test_user" + mock_message.metadata.collection = "test_collection" + + await processor.store_triples(mock_message) + + # Verify user/collection parameters were passed to all operations + # Should have: create_node (subject), create_node (object), relate_node = 3 calls + assert mock_driver.execute_query.call_count == 3 + + # Check that user and collection were included in all calls + for call in mock_driver.execute_query.call_args_list: + call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] + assert 'user' in call_kwargs + assert 'collection' in call_kwargs + assert call_kwargs['user'] == "test_user" + assert call_kwargs['collection'] == "test_collection" + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + @pytest.mark.asyncio + async def test_store_triples_with_default_user_collection(self, mock_graph_db): + """Test that defaults are used when user/collection not provided in metadata""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=MagicMock()) + + # Create mock triple + triple = MagicMock() + triple.s.value = "http://example.com/subject" + triple.p.value = "http://example.com/predicate" + triple.o.value = "literal_value" + triple.o.is_uri = False + + # Create mock message without user/collection metadata + mock_message = MagicMock() + mock_message.triples = [triple] + mock_message.metadata.user = None + mock_message.metadata.collection = None + + await processor.store_triples(mock_message) + + # Verify defaults were used + for call in mock_driver.execute_query.call_args_list: + call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] + assert call_kwargs['user'] == "default" + assert call_kwargs['collection'] == "default" + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + def test_create_node_includes_user_collection(self, mock_graph_db): + """Test that create_node includes user/collection properties""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=MagicMock()) + + processor.create_node("http://example.com/node", "test_user", "test_collection") + + mock_driver.execute_query.assert_called_with( + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + uri="http://example.com/node", + user="test_user", + collection="test_collection", + database_="memgraph" + ) + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + def test_create_literal_includes_user_collection(self, mock_graph_db): + """Test that create_literal includes user/collection properties""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=MagicMock()) + + processor.create_literal("test_value", "test_user", "test_collection") + + mock_driver.execute_query.assert_called_with( + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + value="test_value", + user="test_user", + collection="test_collection", + database_="memgraph" + ) + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + def test_relate_node_includes_user_collection(self, mock_graph_db): + """Test that relate_node includes user/collection properties""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 0 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=MagicMock()) + + processor.relate_node( + "http://example.com/subject", + "http://example.com/predicate", + "http://example.com/object", + "test_user", + "test_collection" + ) + + mock_driver.execute_query.assert_called_with( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + src="http://example.com/subject", + dest="http://example.com/object", + uri="http://example.com/predicate", + user="test_user", + collection="test_collection", + database_="memgraph" + ) + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + def test_relate_literal_includes_user_collection(self, mock_graph_db): + """Test that relate_literal includes user/collection properties""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 0 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=MagicMock()) + + processor.relate_literal( + "http://example.com/subject", + "http://example.com/predicate", + "literal_value", + "test_user", + "test_collection" + ) + + mock_driver.execute_query.assert_called_with( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + src="http://example.com/subject", + dest="literal_value", + uri="http://example.com/predicate", + user="test_user", + collection="test_collection", + database_="memgraph" + ) + + def test_add_args_includes_memgraph_parameters(self): + """Test that add_args properly configures Memgraph-specific parameters""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added with Memgraph defaults + args = parser.parse_args([]) + + assert hasattr(args, 'graph_host') + assert args.graph_host == 'bolt://memgraph:7687' + assert hasattr(args, 'username') + assert args.username == 'memgraph' + assert hasattr(args, 'password') + assert args.password == 'password' + assert hasattr(args, 'database') + assert args.database == 'memgraph' + + +class TestMemgraphUserCollectionRegression: + """Regression tests to ensure user/collection isolation prevents data leakage""" + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + @pytest.mark.asyncio + async def test_regression_no_cross_user_data_access(self, mock_graph_db): + """Regression test: Ensure users cannot access each other's data""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=MagicMock()) + + # Store data for user1 + triple = MagicMock() + triple.s.value = "http://example.com/subject" + triple.p.value = "http://example.com/predicate" + triple.o.value = "user1_data" + triple.o.is_uri = False + + message_user1 = MagicMock() + message_user1.triples = [triple] + message_user1.metadata.user = "user1" + message_user1.metadata.collection = "collection1" + + await processor.store_triples(message_user1) + + # Verify that all storage operations included user1/collection1 parameters + for call in mock_driver.execute_query.call_args_list: + call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] + if 'user' in call_kwargs: + assert call_kwargs['user'] == "user1" + assert call_kwargs['collection'] == "collection1" + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + @pytest.mark.asyncio + async def test_regression_same_uri_different_users(self, mock_graph_db): + """Regression test: Same URI can exist for different users without conflict""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=MagicMock()) + + # Same URI for different users should create separate nodes + processor.create_node("http://example.com/same-uri", "user1", "collection1") + processor.create_node("http://example.com/same-uri", "user2", "collection2") + + # Verify both calls were made with different user/collection parameters + calls = mock_driver.execute_query.call_args_list[-2:] # Get last 2 calls + + call1_kwargs = calls[0].kwargs if hasattr(calls[0], 'kwargs') else calls[0][1] + call2_kwargs = calls[1].kwargs if hasattr(calls[1], 'kwargs') else calls[1][1] + + assert call1_kwargs['user'] == "user1" and call1_kwargs['collection'] == "collection1" + assert call2_kwargs['user'] == "user2" and call2_kwargs['collection'] == "collection2" + + # Both should have the same URI but different user/collection + assert call1_kwargs['uri'] == call2_kwargs['uri'] == "http://example.com/same-uri" \ No newline at end of file diff --git a/tests/unit/test_storage/test_triples_memgraph_storage.py b/tests/unit/test_storage/test_triples_memgraph_storage.py index 83dfdbc4..4cced655 100644 --- a/tests/unit/test_storage/test_triples_memgraph_storage.py +++ b/tests/unit/test_storage/test_triples_memgraph_storage.py @@ -99,12 +99,16 @@ class TestMemgraphStorageProcessor: processor = Processor(taskgroup=taskgroup_mock) - # Verify index creation calls + # Verify index creation calls (now includes user/collection indexes) expected_calls = [ "CREATE INDEX ON :Node", "CREATE INDEX ON :Node(uri)", "CREATE INDEX ON :Literal", - "CREATE INDEX ON :Literal(value)" + "CREATE INDEX ON :Literal(value)", + "CREATE INDEX ON :Node(user)", + "CREATE INDEX ON :Node(collection)", + "CREATE INDEX ON :Literal(user)", + "CREATE INDEX ON :Literal(collection)" ] assert mock_session.run.call_count == len(expected_calls) @@ -127,8 +131,8 @@ class TestMemgraphStorageProcessor: # Should not raise an exception processor = Processor(taskgroup=taskgroup_mock) - # Verify all index creation calls were attempted - assert mock_session.run.call_count == 4 + # Verify all index creation calls were attempted (8 total) + assert mock_session.run.call_count == 8 def test_create_node(self, processor): """Test node creation""" @@ -141,11 +145,13 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.create_node(test_uri) + processor.create_node(test_uri, "test_user", "test_collection") processor.io.execute_query.assert_called_once_with( - "MERGE (n:Node {uri: $uri})", + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", uri=test_uri, + user="test_user", + collection="test_collection", database_=processor.db ) @@ -160,11 +166,13 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.create_literal(test_value) + processor.create_literal(test_value, "test_user", "test_collection") processor.io.execute_query.assert_called_once_with( - "MERGE (n:Literal {value: $value})", + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", value=test_value, + user="test_user", + collection="test_collection", database_=processor.db ) @@ -182,13 +190,14 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.relate_node(src_uri, pred_uri, dest_uri) + processor.relate_node(src_uri, pred_uri, dest_uri, "test_user", "test_collection") processor.io.execute_query.assert_called_once_with( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Node {uri: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", src=src_uri, dest=dest_uri, uri=pred_uri, + user="test_user", collection="test_collection", database_=processor.db ) @@ -206,13 +215,14 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.relate_literal(src_uri, pred_uri, literal_value) + processor.relate_literal(src_uri, pred_uri, literal_value, "test_user", "test_collection") processor.io.execute_query.assert_called_once_with( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", src=src_uri, dest=literal_value, uri=pred_uri, + user="test_user", collection="test_collection", database_=processor.db ) @@ -226,19 +236,22 @@ class TestMemgraphStorageProcessor: o=Value(value='http://example.com/object', is_uri=True) ) - processor.create_triple(mock_tx, triple) + processor.create_triple(mock_tx, triple, "test_user", "test_collection") # Verify transaction calls expected_calls = [ # Create subject node - ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}), + ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + {'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}), # Create object node - ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/object'}), + ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + {'uri': 'http://example.com/object', 'user': 'test_user', 'collection': 'test_collection'}), # Create relationship - ("MATCH (src:Node {uri: $src}) " - "MATCH (dest:Node {uri: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", - {'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate'}) + ("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + {'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate', + 'user': 'test_user', 'collection': 'test_collection'}) ] assert mock_tx.run.call_count == 3 @@ -257,19 +270,22 @@ class TestMemgraphStorageProcessor: o=Value(value='literal object', is_uri=False) ) - processor.create_triple(mock_tx, triple) + processor.create_triple(mock_tx, triple, "test_user", "test_collection") # Verify transaction calls expected_calls = [ # Create subject node - ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}), + ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + {'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}), # Create literal object - ("MERGE (n:Literal {value: $value})", {'value': 'literal object'}), + ("MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + {'value': 'literal object', 'user': 'test_user', 'collection': 'test_collection'}), # Create relationship - ("MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", - {'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate'}) + ("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + {'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate', + 'user': 'test_user', 'collection': 'test_collection'}) ] assert mock_tx.run.call_count == 3 @@ -281,33 +297,42 @@ class TestMemgraphStorageProcessor: @pytest.mark.asyncio async def test_store_triples_single_triple(self, processor, mock_message): """Test storing a single triple""" - mock_session = MagicMock() - processor.io.session.return_value.__enter__.return_value = mock_session + # Mock the execute_query method used by the direct methods + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + processor.io.execute_query.return_value = mock_result - # Reset the mock to clear the initialization call - processor.io.session.reset_mock() + # Reset the mock to clear initialization calls + processor.io.execute_query.reset_mock() await processor.store_triples(mock_message) - # Verify session was created with correct database - processor.io.session.assert_called_once_with(database=processor.db) + # Verify execute_query was called for create_node, create_literal, and relate_literal + # (since mock_message has a literal object) + assert processor.io.execute_query.call_count == 3 - # Verify execute_write was called once per triple - mock_session.execute_write.assert_called_once() - - # Verify the triple was passed to create_triple - call_args = mock_session.execute_write.call_args - assert call_args[0][0] == processor.create_triple - assert call_args[0][1] == mock_message.triples[0] + # Verify user/collection parameters were included + for call in processor.io.execute_query.call_args_list: + call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] + assert 'user' in call_kwargs + assert 'collection' in call_kwargs @pytest.mark.asyncio async def test_store_triples_multiple_triples(self, processor): """Test storing multiple triples""" - mock_session = MagicMock() - processor.io.session.return_value.__enter__.return_value = mock_session + # Mock the execute_query method used by the direct methods + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + processor.io.execute_query.return_value = mock_result - # Reset the mock to clear the initialization call - processor.io.session.reset_mock() + # Reset the mock to clear initialization calls + processor.io.execute_query.reset_mock() # Create message with multiple triples message = MagicMock() @@ -329,16 +354,17 @@ class TestMemgraphStorageProcessor: await processor.store_triples(message) - # Verify session was called twice (once per triple) - assert processor.io.session.call_count == 2 + # Verify execute_query was called: + # Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls + # Triple2: create_node(s) + create_node(o) + relate_node = 3 calls + # Total: 6 calls + assert processor.io.execute_query.call_count == 6 - # Verify execute_write was called once per triple - assert mock_session.execute_write.call_count == 2 - - # Verify each triple was processed - call_args_list = mock_session.execute_write.call_args_list - assert call_args_list[0][0][1] == triple1 - assert call_args_list[1][0][1] == triple2 + # Verify user/collection parameters were included in all calls + for call in processor.io.execute_query.call_args_list: + call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] + assert call_kwargs['user'] == 'test_user' + assert call_kwargs['collection'] == 'test_collection' @pytest.mark.asyncio async def test_store_triples_empty_list(self, processor): diff --git a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py index dcf00281..262f89ab 100755 --- a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py +++ b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py @@ -55,6 +55,10 @@ class Processor(TriplesQueryService): try: + # Extract user and collection, use defaults if not provided + user = query.user if query.user else "default" + collection = query.collection if query.collection else "default" + triples = [] if query.s is not None: @@ -64,10 +68,13 @@ class Processor(TriplesQueryService): # SPO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), src=query.s.value, rel=query.p.value, value=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -75,10 +82,13 @@ class Processor(TriplesQueryService): triples.append((query.s.value, query.p.value, query.o.value)) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Node {uri: $uri, user: $user, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), src=query.s.value, rel=query.p.value, uri=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -90,10 +100,13 @@ class Processor(TriplesQueryService): # SP records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " "RETURN dest.value as dest " "LIMIT " + str(query.limit), src=query.s.value, rel=query.p.value, + user=user, collection=collection, database_=self.db, ) @@ -102,10 +115,13 @@ class Processor(TriplesQueryService): triples.append((query.s.value, query.p.value, data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " "RETURN dest.uri as dest " "LIMIT " + str(query.limit), src=query.s.value, rel=query.p.value, + user=user, collection=collection, database_=self.db, ) @@ -120,10 +136,13 @@ class Processor(TriplesQueryService): # SO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), src=query.s.value, value=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -132,10 +151,13 @@ class Processor(TriplesQueryService): triples.append((query.s.value, data["rel"], query.o.value)) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {uri: $uri, user: $user, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), src=query.s.value, uri=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -148,10 +170,13 @@ class Processor(TriplesQueryService): # S records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), src=query.s.value, + user=user, collection=collection, database_=self.db, ) @@ -160,10 +185,13 @@ class Processor(TriplesQueryService): triples.append((query.s.value, data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " "RETURN rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), src=query.s.value, + user=user, collection=collection, database_=self.db, ) @@ -181,10 +209,13 @@ class Processor(TriplesQueryService): # PO records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), uri=query.p.value, value=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -193,10 +224,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], query.p.value, query.o.value)) records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Node {uri: $dest, user: $user, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), uri=query.p.value, dest=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -209,10 +243,13 @@ class Processor(TriplesQueryService): # P records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT " + str(query.limit), uri=query.p.value, + user=user, collection=collection, database_=self.db, ) @@ -221,10 +258,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], query.p.value, data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " "RETURN src.uri as src, dest.uri as dest " "LIMIT " + str(query.limit), uri=query.p.value, + user=user, collection=collection, database_=self.db, ) @@ -239,10 +279,13 @@ class Processor(TriplesQueryService): # O records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), value=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -251,10 +294,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], query.o.value)) records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {uri: $uri, user: $user, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), uri=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -267,9 +313,12 @@ class Processor(TriplesQueryService): # * records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Literal) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), + user=user, collection=collection, database_=self.db, ) @@ -278,9 +327,12 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Node) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), + user=user, collection=collection, database_=self.db, ) diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index fa0260ac..0996111d 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -61,6 +61,7 @@ class Processor(TriplesStoreService): logger.info("Create indexes...") + # Legacy indexes for backwards compatibility try: session.run( "CREATE INDEX ON :Node", @@ -97,15 +98,48 @@ class Processor(TriplesStoreService): # Maybe index already exists logger.warning("Index create failure ignored") + # New indexes for user/collection filtering + try: + session.run( + "CREATE INDEX ON :Node(user)" + ) + except Exception as e: + logger.warning(f"User index create failure: {e}") + logger.warning("Index create failure ignored") + + try: + session.run( + "CREATE INDEX ON :Node(collection)" + ) + except Exception as e: + logger.warning(f"Collection index create failure: {e}") + logger.warning("Index create failure ignored") + + try: + session.run( + "CREATE INDEX ON :Literal(user)" + ) + except Exception as e: + logger.warning(f"User index create failure: {e}") + logger.warning("Index create failure ignored") + + try: + session.run( + "CREATE INDEX ON :Literal(collection)" + ) + except Exception as e: + logger.warning(f"Collection index create failure: {e}") + logger.warning("Index create failure ignored") + logger.info("Index creation done") - def create_node(self, uri): + def create_node(self, uri, user, collection): - logger.debug(f"Create node {uri}") + logger.debug(f"Create node {uri} for user={user}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Node {uri: $uri})", - uri=uri, + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + uri=uri, user=user, collection=collection, database_=self.db, ).summary @@ -114,13 +148,13 @@ class Processor(TriplesStoreService): time=summary.result_available_after )) - def create_literal(self, value): + def create_literal(self, value, user, collection): - logger.debug(f"Create literal {value}") + logger.debug(f"Create literal {value} for user={user}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Literal {value: $value})", - value=value, + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + value=value, user=user, collection=collection, database_=self.db, ).summary @@ -129,15 +163,15 @@ class Processor(TriplesStoreService): time=summary.result_available_after )) - def relate_node(self, src, uri, dest): + def relate_node(self, src, uri, dest, user, collection): - logger.debug(f"Create node rel {src} {uri} {dest}") + logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Node {uri: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", - src=src, dest=dest, uri=uri, + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, user=user, collection=collection, database_=self.db, ).summary @@ -146,15 +180,15 @@ class Processor(TriplesStoreService): time=summary.result_available_after )) - def relate_literal(self, src, uri, dest): + def relate_literal(self, src, uri, dest, user, collection): - logger.debug(f"Create literal rel {src} {uri} {dest}") + logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", - src=src, dest=dest, uri=uri, + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, user=user, collection=collection, database_=self.db, ).summary @@ -163,59 +197,64 @@ class Processor(TriplesStoreService): time=summary.result_available_after )) - def create_triple(self, tx, t): + def create_triple(self, tx, t, user, collection): # Create new s node with given uri, if not exists result = tx.run( - "MERGE (n:Node {uri: $uri})", - uri=t.s.value + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + uri=t.s.value, user=user, collection=collection ) if t.o.is_uri: # Create new o node with given uri, if not exists result = tx.run( - "MERGE (n:Node {uri: $uri})", - uri=t.o.value + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + uri=t.o.value, user=user, collection=collection ) result = tx.run( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Node {uri: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", - src=t.s.value, dest=t.o.value, uri=t.p.value, + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection, ) else: # Create new o literal with given uri, if not exists result = tx.run( - "MERGE (n:Literal {value: $value})", - value=t.o.value + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + value=t.o.value, user=user, collection=collection ) result = tx.run( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", - src=t.s.value, dest=t.o.value, uri=t.p.value, + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection, ) async def store_triples(self, message): + # Extract user and collection from metadata + user = message.metadata.user if message.metadata.user else "default" + collection = message.metadata.collection if message.metadata.collection else "default" + for t in message.triples: - # self.create_node(t.s.value) + self.create_node(t.s.value, user, collection) - # if t.o.is_uri: - # self.create_node(t.o.value) - # self.relate_node(t.s.value, t.p.value, t.o.value) - # else: - # self.create_literal(t.o.value) - # self.relate_literal(t.s.value, t.p.value, t.o.value) + if t.o.is_uri: + self.create_node(t.o.value, user, collection) + self.relate_node(t.s.value, t.p.value, t.o.value, user, collection) + else: + self.create_literal(t.o.value, user, collection) + self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection) - with self.io.session(database=self.db) as session: - session.execute_write(self.create_triple, t) + # Alternative implementation using transactions + # with self.io.session(database=self.db) as session: + # session.execute_write(self.create_triple, t, user, collection) @staticmethod def add_args(parser):