diff --git a/docs/tech-specs/neo4j-user-collection-isolation.md b/docs/tech-specs/neo4j-user-collection-isolation.md new file mode 100644 index 00000000..62623c07 --- /dev/null +++ b/docs/tech-specs/neo4j-user-collection-isolation.md @@ -0,0 +1,359 @@ +# Neo4j User/Collection Isolation Support + +## Problem Statement + +The Neo4j triples storage and query implementation currently lacks user/collection isolation, which creates a multi-tenancy security issue. All triples are stored in the same graph space without any mechanism to prevent users from accessing other users' data or mixing collections. + +Unlike other storage backends in TrustGraph: +- **Cassandra**: Uses separate keyspaces per user and tables per collection +- **Vector stores** (Milvus, Qdrant, Pinecone): Use collection-specific namespaces +- **Neo4j**: Currently shares all data in a single graph (security vulnerability) + +## Current Architecture + +### Data Model +- **Nodes**: `:Node` label with `uri` property, `:Literal` label with `value` property +- **Relationships**: `:Rel` type with `uri` property +- **Indexes**: `Node.uri`, `Literal.value`, `Rel.uri` + +### Message Flow +- `Triples` messages contain `metadata.user` and `metadata.collection` fields +- Storage service receives user/collection info but ignores it +- Query service expects `user` and `collection` in `TriplesQueryRequest` but ignores them + +### Current Security Issue +```cypher +// Any user can query any data - no isolation +MATCH (src:Node)-[rel:Rel]->(dest:Node) +RETURN src.uri, rel.uri, dest.uri +``` + +## Proposed Solution: Property-Based Filtering (Recommended) + +### Overview +Add `user` and `collection` properties to all nodes and relationships, then filter all operations by these properties. This approach provides strong isolation while maintaining query flexibility and backwards compatibility. 
+ +### Data Model Changes + +#### Enhanced Node Structure +```cypher +// Node entities +CREATE (n:Node { + uri: "http://example.com/entity1", + user: "john_doe", + collection: "production_v1" +}) + +// Literal entities +CREATE (n:Literal { + value: "literal value", + user: "john_doe", + collection: "production_v1" +}) +``` + +#### Enhanced Relationship Structure +```cypher +// Relationships with user/collection properties +CREATE (src)-[:Rel { + uri: "http://example.com/predicate1", + user: "john_doe", + collection: "production_v1" +}]->(dest) +``` + +#### Updated Indexes +```cypher +// Compound indexes for efficient filtering +CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri); +CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value); +CREATE INDEX rel_user_collection_uri FOR ()-[r:Rel]-() ON (r.user, r.collection, r.uri); + +// Maintain existing indexes for backwards compatibility (optional) +CREATE INDEX Node_uri FOR (n:Node) ON (n.uri); +CREATE INDEX Literal_value FOR (n:Literal) ON (n.value); +CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri); +``` + +### Implementation Changes + +#### Storage Service (`write.py`) + +**Current Code:** +```python +def create_node(self, uri): + summary = self.io.execute_query( + "MERGE (n:Node {uri: $uri})", + uri=uri, database_=self.db, + ).summary +``` + +**Updated Code:** +```python +def create_node(self, uri, user, collection): + summary = self.io.execute_query( + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + uri=uri, user=user, collection=collection, database_=self.db, + ).summary +``` + +**Enhanced store_triples Method:** +```python +async def store_triples(self, message): + user = message.metadata.user + collection = message.metadata.collection + + for t in message.triples: + self.create_node(t.s.value, user, collection) + + if t.o.is_uri: + self.create_node(t.o.value, user, collection) + self.relate_node(t.s.value, t.p.value, 
t.o.value, user, collection) + else: + self.create_literal(t.o.value, user, collection) + self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection) +``` + +#### Query Service (`service.py`) + +**Current Code:** +```python +records, summary, keys = self.io.execute_query( + "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) " + "RETURN dest.uri as dest", + src=query.s.value, rel=query.p.value, database_=self.db, +) +``` + +**Updated Code:** +```python +records, summary, keys = self.io.execute_query( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " + "RETURN dest.uri as dest", + src=query.s.value, rel=query.p.value, + user=query.user, collection=query.collection, + database_=self.db, +) +``` + +### Migration Strategy + +#### Phase 1: Add Properties to New Data +1. Update storage service to add user/collection properties to new triples +2. Maintain backwards compatibility by not requiring properties in queries +3. Existing data remains accessible but not isolated + +#### Phase 2: Migrate Existing Data +```cypher +// Migrate existing nodes (requires default user/collection assignment) +MATCH (n:Node) WHERE n.user IS NULL +SET n.user = 'legacy_user', n.collection = 'default_collection'; + +MATCH (n:Literal) WHERE n.user IS NULL +SET n.user = 'legacy_user', n.collection = 'default_collection'; + +MATCH ()-[r:Rel]->() WHERE r.user IS NULL +SET r.user = 'legacy_user', r.collection = 'default_collection'; +``` + +#### Phase 3: Enforce Isolation +1. Update query service to require user/collection filtering +2. Add validation to reject queries without proper user/collection context +3. 
Remove legacy data access paths + +### Security Considerations + +#### Query Validation +```python +async def query_triples(self, query): + # Validate user/collection parameters + if not query.user or not query.collection: + raise ValueError("User and collection must be specified") + + # All queries must include user/collection filters + # ... rest of implementation +``` + +#### Preventing Parameter Injection +- Use parameterized queries exclusively +- Validate user/collection values against allowed patterns +- Consider sanitization for Neo4j property name requirements + +#### Audit Trail +```python +logger.info(f"Query executed - User: {query.user}, Collection: {query.collection}, " + f"Pattern: {query.s}/{query.p}/{query.o}") +``` + +## Alternative Approaches Considered + +### Option 2: Label-Based Isolation + +**Approach**: Use dynamic labels like `User_john_Collection_prod` + +**Pros:** +- Strong isolation through label filtering +- Efficient query performance with label indexes +- Clear data separation + +**Cons:** +- Neo4j has practical limits on number of labels (~1000s) +- Complex label name generation and sanitization +- Difficult to query across collections when needed + +**Implementation Example:** +```cypher +CREATE (n:Node:User_john_Collection_prod {uri: "http://example.com/entity"}) +MATCH (n:User_john_Collection_prod) WHERE n:Node RETURN n +``` + +### Option 3: Database-Per-User + +**Approach**: Create separate Neo4j databases for each user or user/collection combination + +**Pros:** +- Complete data isolation +- No risk of cross-contamination +- Independent scaling per user + +**Cons:** +- Resource overhead (each database consumes memory) +- Complex database lifecycle management +- Neo4j Community Edition database limits +- Difficult cross-user analytics + +### Option 4: Composite Key Strategy + +**Approach**: Prefix all URIs and values with user/collection information + +**Pros:** +- Backwards compatible with existing queries +- Simple 
implementation +- No schema changes required + +**Cons:** +- URI pollution affects data semantics +- Less efficient queries (string prefix matching) +- Breaks RDF/semantic web standards + +**Implementation Example:** +```python +def make_composite_uri(uri, user, collection): + return f"usr:{user}:col:{collection}:uri:{uri}" +``` + +## Implementation Plan + +### Phase 1: Foundation (Week 1) +1. [ ] Update storage service to accept and store user/collection properties +2. [ ] Add compound indexes for efficient querying +3. [ ] Implement backwards compatibility layer +4. [ ] Create unit tests for new functionality + +### Phase 2: Query Updates (Week 2) +1. [ ] Update all query patterns to include user/collection filters +2. [ ] Add query validation and security checks +3. [ ] Update integration tests +4. [ ] Performance testing with filtered queries + +### Phase 3: Migration & Deployment (Week 3) +1. [ ] Create data migration scripts for existing Neo4j instances +2. [ ] Deployment documentation and runbooks +3. [ ] Monitoring and alerting for isolation violations +4. [ ] End-to-end testing with multiple users/collections + +### Phase 4: Hardening (Week 4) +1. [ ] Remove legacy compatibility mode +2. [ ] Add comprehensive audit logging +3. [ ] Security review and penetration testing +4. 
[ ] Performance optimization + +## Testing Strategy + +### Unit Tests +```python +def test_user_collection_isolation(): + # Store triples for user1/collection1 + processor.store_triples(triples_user1_coll1) + + # Store triples for user2/collection2 + processor.store_triples(triples_user2_coll2) + + # Query as user1 should only return user1's data + results = processor.query_triples(query_user1_coll1) + assert all_results_belong_to_user1_coll1(results) + + # Query as user2 should only return user2's data + results = processor.query_triples(query_user2_coll2) + assert all_results_belong_to_user2_coll2(results) +``` + +### Integration Tests +- Multi-user scenarios with overlapping data +- Cross-collection queries (should fail) +- Migration testing with existing data +- Performance benchmarks with large datasets + +### Security Tests +- Attempt to query other users' data +- Cypher injection style attacks on user/collection parameters +- Verify complete isolation under various query patterns + +## Performance Considerations + +### Index Strategy +- Compound indexes on `(user, collection, uri)` for optimal filtering +- Consider partial indexes if some collections are much larger +- Monitor index usage and query performance + +### Query Optimization +- Use EXPLAIN to verify index usage in filtered queries +- Consider query result caching for frequently accessed data +- Profile memory usage with large numbers of users/collections + +### Scalability +- Each user/collection combination creates separate data islands +- Monitor database size and connection pool usage +- Consider horizontal scaling strategies if needed + +## Security & Compliance + +### Data Isolation Guarantees +- **Physical**: All user data stored with explicit user/collection properties +- **Logical**: All queries filtered by user/collection context +- **Access Control**: Service-level validation prevents unauthorized access + +### Audit Requirements +- Log all data access with user/collection context +- Track 
migration activities and data movements +- Monitor for isolation violation attempts + +### Compliance Considerations +- GDPR: Enhanced ability to locate and delete user-specific data +- SOC2: Clear data isolation and access controls +- HIPAA: Strong tenant isolation for healthcare data + +## Risks & Mitigations + +| Risk | Impact | Likelihood | Mitigation | +|------|--------|------------|------------| +| Query missing user/collection filter | High | Medium | Mandatory validation, comprehensive testing | +| Performance degradation | Medium | Low | Index optimization, query profiling | +| Migration data corruption | High | Low | Backup strategy, rollback procedures | +| Complex multi-collection queries | Medium | Medium | Document query patterns, provide examples | + +## Success Criteria + +1. **Security**: Zero cross-user data access in production +2. **Performance**: <10% query performance impact vs unfiltered queries +3. **Migration**: 100% existing data successfully migrated with zero loss +4. **Usability**: All existing query patterns work with user/collection context +5. **Compliance**: Full audit trail of user/collection data access + +## Conclusion + +The property-based filtering approach provides the best balance of security, performance, and maintainability for adding user/collection isolation to Neo4j. It aligns with TrustGraph's existing multi-tenancy patterns while leveraging Neo4j's strengths in graph querying and indexing. + +This solution ensures TrustGraph's Neo4j backend meets the same security standards as other storage backends, preventing data isolation vulnerabilities while maintaining the flexibility and power of graph queries. 
\ No newline at end of file diff --git a/tests/unit/test_direct/test_milvus_collection_naming.py b/tests/unit/test_direct/test_milvus_collection_naming.py new file mode 100644 index 00000000..9c6b0a90 --- /dev/null +++ b/tests/unit/test_direct/test_milvus_collection_naming.py @@ -0,0 +1,231 @@ +""" +Unit tests for Milvus collection name sanitization functionality +""" + +import pytest +from trustgraph.direct.milvus_doc_embeddings import make_safe_collection_name + + +class TestMilvusCollectionNaming: + """Test cases for Milvus collection name generation and sanitization""" + + def test_make_safe_collection_name_basic(self): + """Test basic collection name creation""" + result = make_safe_collection_name( + user="test_user", + collection="test_collection", + dimension=384, + prefix="doc" + ) + assert result == "doc_test_user_test_collection_384" + + def test_make_safe_collection_name_with_special_characters(self): + """Test collection name creation with special characters that need sanitization""" + result = make_safe_collection_name( + user="user@domain.com", + collection="test-collection.v2", + dimension=768, + prefix="entity" + ) + assert result == "entity_user_domain_com_test_collection_v2_768" + + def test_make_safe_collection_name_with_unicode(self): + """Test collection name creation with Unicode characters""" + result = make_safe_collection_name( + user="测试用户", + collection="colección_española", + dimension=512, + prefix="doc" + ) + assert result == "doc_default_colecci_n_espa_ola_512" + + def test_make_safe_collection_name_with_spaces(self): + """Test collection name creation with spaces""" + result = make_safe_collection_name( + user="test user", + collection="my test collection", + dimension=256, + prefix="entity" + ) + assert result == "entity_test_user_my_test_collection_256" + + def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self): + """Test collection name creation with multiple consecutive special characters""" + result = 
make_safe_collection_name( + user="user@@@domain!!!", + collection="test---collection...v2", + dimension=384, + prefix="doc" + ) + assert result == "doc_user_domain_test_collection_v2_384" + + def test_make_safe_collection_name_with_leading_trailing_underscores(self): + """Test collection name creation with leading/trailing special characters""" + result = make_safe_collection_name( + user="__test_user__", + collection="@@test_collection##", + dimension=128, + prefix="entity" + ) + assert result == "entity_test_user_test_collection_128" + + def test_make_safe_collection_name_empty_user(self): + """Test collection name creation with empty user (should fallback to 'default')""" + result = make_safe_collection_name( + user="", + collection="test_collection", + dimension=384, + prefix="doc" + ) + assert result == "doc_default_test_collection_384" + + def test_make_safe_collection_name_empty_collection(self): + """Test collection name creation with empty collection (should fallback to 'default')""" + result = make_safe_collection_name( + user="test_user", + collection="", + dimension=384, + prefix="doc" + ) + assert result == "doc_test_user_default_384" + + def test_make_safe_collection_name_both_empty(self): + """Test collection name creation with both user and collection empty""" + result = make_safe_collection_name( + user="", + collection="", + dimension=384, + prefix="doc" + ) + assert result == "doc_default_default_384" + + def test_make_safe_collection_name_only_special_characters(self): + """Test collection name creation with only special characters (should fallback to 'default')""" + result = make_safe_collection_name( + user="@@@!!!", + collection="---###", + dimension=512, + prefix="entity" + ) + assert result == "entity_default_default_512" + + def test_make_safe_collection_name_whitespace_only(self): + """Test collection name creation with whitespace-only strings""" + result = make_safe_collection_name( + user=" \n\t ", + collection=" \r\n ", + 
dimension=256, + prefix="doc" + ) + assert result == "doc_default_default_256" + + def test_make_safe_collection_name_mixed_valid_invalid_chars(self): + """Test collection name creation with mixed valid and invalid characters""" + result = make_safe_collection_name( + user="user123@test", + collection="coll_2023.v1", + dimension=384, + prefix="entity" + ) + assert result == "entity_user123_test_coll_2023_v1_384" + + def test_make_safe_collection_name_different_prefixes(self): + """Test collection name creation with different prefixes""" + user = "test_user" + collection = "test_collection" + dimension = 384 + + doc_result = make_safe_collection_name(user, collection, dimension, "doc") + entity_result = make_safe_collection_name(user, collection, dimension, "entity") + custom_result = make_safe_collection_name(user, collection, dimension, "custom") + + assert doc_result == "doc_test_user_test_collection_384" + assert entity_result == "entity_test_user_test_collection_384" + assert custom_result == "custom_test_user_test_collection_384" + + def test_make_safe_collection_name_different_dimensions(self): + """Test collection name creation with different dimensions""" + user = "test_user" + collection = "test_collection" + prefix = "doc" + + result_128 = make_safe_collection_name(user, collection, 128, prefix) + result_384 = make_safe_collection_name(user, collection, 384, prefix) + result_768 = make_safe_collection_name(user, collection, 768, prefix) + + assert result_128 == "doc_test_user_test_collection_128" + assert result_384 == "doc_test_user_test_collection_384" + assert result_768 == "doc_test_user_test_collection_768" + + def test_make_safe_collection_name_long_names(self): + """Test collection name creation with very long user/collection names""" + long_user = "a" * 100 + long_collection = "b" * 100 + + result = make_safe_collection_name( + user=long_user, + collection=long_collection, + dimension=384, + prefix="doc" + ) + + expected = 
f"doc_{long_user}_{long_collection}_384" + assert result == expected + assert len(result) > 200 # Verify it handles long names + + def test_make_safe_collection_name_numeric_values(self): + """Test collection name creation with numeric user/collection values""" + result = make_safe_collection_name( + user="user123", + collection="collection456", + dimension=384, + prefix="doc" + ) + assert result == "doc_user123_collection456_384" + + def test_make_safe_collection_name_case_sensitivity(self): + """Test that collection name creation preserves case""" + result = make_safe_collection_name( + user="TestUser", + collection="TestCollection", + dimension=384, + prefix="Doc" + ) + assert result == "Doc_TestUser_TestCollection_384" + + def test_make_safe_collection_name_realistic_examples(self): + """Test collection name creation with realistic user/collection combinations""" + test_cases = [ + # (user, collection, expected_safe_user, expected_safe_collection) + ("john.doe", "production-2024", "john_doe", "production_2024"), + ("team@company.com", "ml_models.v1", "team_company_com", "ml_models_v1"), + ("user_123", "test_collection", "user_123", "test_collection"), + ("αβγ-user", "测试集合", "user", "default"), + ] + + for user, collection, expected_user, expected_collection in test_cases: + result = make_safe_collection_name(user, collection, 384, "doc") + assert result == f"doc_{expected_user}_{expected_collection}_384" + + def test_make_safe_collection_name_matches_qdrant_pattern(self): + """Test that Milvus collection names follow similar pattern to Qdrant""" + # Qdrant uses: "d_{user}_{collection}_{dimension}" and "t_{user}_{collection}_{dimension}" + # Milvus should use: "{prefix}_{safe_user}_{safe_collection}_{dimension}" + + user = "test.user@domain.com" + collection = "test-collection.v2" + dimension = 384 + + doc_result = make_safe_collection_name(user, collection, dimension, "doc") + entity_result = make_safe_collection_name(user, collection, dimension, "entity") + + 
# Should follow the pattern but with sanitized names + assert doc_result == "doc_test_user_domain_com_test_collection_v2_384" + assert entity_result == "entity_test_user_domain_com_test_collection_v2_384" + + # Verify structure matches expected pattern (may have more underscores due to sanitization) + # The important thing is that it follows prefix_user_collection_dimension structure + assert doc_result.startswith("doc_") + assert doc_result.endswith("_384") + assert entity_result.startswith("entity_") + assert entity_result.endswith("_384") \ No newline at end of file diff --git a/tests/unit/test_direct/test_milvus_user_collection_integration.py b/tests/unit/test_direct/test_milvus_user_collection_integration.py new file mode 100644 index 00000000..931332e4 --- /dev/null +++ b/tests/unit/test_direct/test_milvus_user_collection_integration.py @@ -0,0 +1,309 @@ +""" +Integration tests for Milvus user/collection functionality +Tests the complete flow of the new user/collection parameter handling +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.direct.milvus_doc_embeddings import DocVectors, make_safe_collection_name +from trustgraph.direct.milvus_graph_embeddings import EntityVectors + + +class TestMilvusUserCollectionIntegration: + """Test cases for Milvus user/collection integration functionality""" + + @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') + def test_doc_vectors_collection_creation_with_user_collection(self, mock_milvus_client): + """Test DocVectors creates collections with proper user/collection names""" + mock_client = MagicMock() + mock_milvus_client.return_value = mock_client + + doc_vectors = DocVectors(uri="http://test:19530", prefix="doc") + + # Test collection creation for different user/collection combinations + test_cases = [ + ("user1", "collection1", [0.1, 0.2, 0.3]), + ("user2", "collection2", [0.1, 0.2, 0.3, 0.4]), + ("user@domain.com", "test-collection.v1", [0.1, 0.2, 0.3]), + ] + + for 
user, collection, vector in test_cases: + doc_vectors.insert(vector, "test document", user, collection) + + expected_collection_name = make_safe_collection_name( + user, collection, len(vector), "doc" + ) + + # Verify collection was created with correct name + assert (len(vector), user, collection) in doc_vectors.collections + assert doc_vectors.collections[(len(vector), user, collection)] == expected_collection_name + + @patch('trustgraph.direct.milvus_graph_embeddings.MilvusClient') + def test_entity_vectors_collection_creation_with_user_collection(self, mock_milvus_client): + """Test EntityVectors creates collections with proper user/collection names""" + mock_client = MagicMock() + mock_milvus_client.return_value = mock_client + + entity_vectors = EntityVectors(uri="http://test:19530", prefix="entity") + + # Test collection creation for different user/collection combinations + test_cases = [ + ("user1", "collection1", [0.1, 0.2, 0.3]), + ("user2", "collection2", [0.1, 0.2, 0.3, 0.4]), + ("user@domain.com", "test-collection.v1", [0.1, 0.2, 0.3]), + ] + + for user, collection, vector in test_cases: + entity_vectors.insert(vector, "test entity", user, collection) + + expected_collection_name = make_safe_collection_name( + user, collection, len(vector), "entity" + ) + + # Verify collection was created with correct name + assert (len(vector), user, collection) in entity_vectors.collections + assert entity_vectors.collections[(len(vector), user, collection)] == expected_collection_name + + @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') + def test_doc_vectors_search_uses_correct_collection(self, mock_milvus_client): + """Test DocVectors search uses the correct collection for user/collection""" + mock_client = MagicMock() + mock_milvus_client.return_value = mock_client + + # Mock search results + mock_client.search.return_value = [ + {"entity": {"doc": "test document"}} + ] + + doc_vectors = DocVectors(uri="http://test:19530", prefix="doc") + + # First 
insert to create collection + vector = [0.1, 0.2, 0.3] + user = "test_user" + collection = "test_collection" + + doc_vectors.insert(vector, "test doc", user, collection) + + # Now search + result = doc_vectors.search(vector, user, collection, limit=5) + + # Verify search was called with correct collection name + expected_collection_name = make_safe_collection_name(user, collection, 3, "doc") + mock_client.search.assert_called_once() + search_call = mock_client.search.call_args + assert search_call[1]["collection_name"] == expected_collection_name + + @patch('trustgraph.direct.milvus_graph_embeddings.MilvusClient') + def test_entity_vectors_search_uses_correct_collection(self, mock_milvus_client): + """Test EntityVectors search uses the correct collection for user/collection""" + mock_client = MagicMock() + mock_milvus_client.return_value = mock_client + + # Mock search results + mock_client.search.return_value = [ + {"entity": {"entity": "test entity"}} + ] + + entity_vectors = EntityVectors(uri="http://test:19530", prefix="entity") + + # First insert to create collection + vector = [0.1, 0.2, 0.3] + user = "test_user" + collection = "test_collection" + + entity_vectors.insert(vector, "test entity", user, collection) + + # Now search + result = entity_vectors.search(vector, user, collection, limit=5) + + # Verify search was called with correct collection name + expected_collection_name = make_safe_collection_name(user, collection, 3, "entity") + mock_client.search.assert_called_once() + search_call = mock_client.search.call_args + assert search_call[1]["collection_name"] == expected_collection_name + + @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') + def test_doc_vectors_collection_isolation(self, mock_milvus_client): + """Test that different user/collection combinations create separate collections""" + mock_client = MagicMock() + mock_milvus_client.return_value = mock_client + + doc_vectors = DocVectors(uri="http://test:19530", prefix="doc") + + # 
Insert same vector for different user/collection combinations + vector = [0.1, 0.2, 0.3] + doc_vectors.insert(vector, "user1 doc", "user1", "collection1") + doc_vectors.insert(vector, "user2 doc", "user2", "collection2") + doc_vectors.insert(vector, "user1 doc2", "user1", "collection2") + + # Verify three separate collections were created + assert len(doc_vectors.collections) == 3 + + collection_names = set(doc_vectors.collections.values()) + expected_names = { + "doc_user1_collection1_3", + "doc_user2_collection2_3", + "doc_user1_collection2_3" + } + assert collection_names == expected_names + + @patch('trustgraph.direct.milvus_graph_embeddings.MilvusClient') + def test_entity_vectors_collection_isolation(self, mock_milvus_client): + """Test that different user/collection combinations create separate collections""" + mock_client = MagicMock() + mock_milvus_client.return_value = mock_client + + entity_vectors = EntityVectors(uri="http://test:19530", prefix="entity") + + # Insert same vector for different user/collection combinations + vector = [0.1, 0.2, 0.3] + entity_vectors.insert(vector, "user1 entity", "user1", "collection1") + entity_vectors.insert(vector, "user2 entity", "user2", "collection2") + entity_vectors.insert(vector, "user1 entity2", "user1", "collection2") + + # Verify three separate collections were created + assert len(entity_vectors.collections) == 3 + + collection_names = set(entity_vectors.collections.values()) + expected_names = { + "entity_user1_collection1_3", + "entity_user2_collection2_3", + "entity_user1_collection2_3" + } + assert collection_names == expected_names + + @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') + def test_doc_vectors_dimension_isolation(self, mock_milvus_client): + """Test that different dimensions create separate collections even with same user/collection""" + mock_client = MagicMock() + mock_milvus_client.return_value = mock_client + + doc_vectors = DocVectors(uri="http://test:19530", prefix="doc") 
+ + user = "test_user" + collection = "test_collection" + + # Insert vectors with different dimensions + doc_vectors.insert([0.1, 0.2, 0.3], "3D doc", user, collection) # 3D + doc_vectors.insert([0.1, 0.2, 0.3, 0.4], "4D doc", user, collection) # 4D + doc_vectors.insert([0.1, 0.2], "2D doc", user, collection) # 2D + + # Verify three separate collections were created for different dimensions + assert len(doc_vectors.collections) == 3 + + collection_names = set(doc_vectors.collections.values()) + expected_names = { + "doc_test_user_test_collection_3", # 3D + "doc_test_user_test_collection_4", # 4D + "doc_test_user_test_collection_2" # 2D + } + assert collection_names == expected_names + + @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') + def test_doc_vectors_collection_reuse(self, mock_milvus_client): + """Test that same user/collection/dimension reuses existing collection""" + mock_client = MagicMock() + mock_milvus_client.return_value = mock_client + + doc_vectors = DocVectors(uri="http://test:19530", prefix="doc") + + user = "test_user" + collection = "test_collection" + vector = [0.1, 0.2, 0.3] + + # Insert multiple documents with same user/collection/dimension + doc_vectors.insert(vector, "doc1", user, collection) + doc_vectors.insert(vector, "doc2", user, collection) + doc_vectors.insert(vector, "doc3", user, collection) + + # Verify only one collection was created + assert len(doc_vectors.collections) == 1 + + expected_collection_name = "doc_test_user_test_collection_3" + assert doc_vectors.collections[(3, user, collection)] == expected_collection_name + + @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') + def test_doc_vectors_special_characters_handling(self, mock_milvus_client): + """Test that special characters in user/collection names are handled correctly""" + mock_client = MagicMock() + mock_milvus_client.return_value = mock_client + + doc_vectors = DocVectors(uri="http://test:19530", prefix="doc") + + # Test various special 
character combinations + test_cases = [ + ("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1_3"), + ("user_123", "collection_456", "doc_user_123_collection_456_3"), + ("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces_3"), + ("user@@@test", "collection---test", "doc_user_test_collection_test_3"), + ] + + vector = [0.1, 0.2, 0.3] + + for user, collection, expected_name in test_cases: + doc_vectors_instance = DocVectors(uri="http://test:19530", prefix="doc") + doc_vectors_instance.insert(vector, "test doc", user, collection) + + assert doc_vectors_instance.collections[(3, user, collection)] == expected_name + + def test_collection_name_backward_compatibility(self): + """Test that new collection names don't conflict with old pattern""" + # Old pattern was: {prefix}_{dimension} + # New pattern is: {prefix}_{safe_user}_{safe_collection}_{dimension} + + # The new pattern should never generate names that match the old pattern + old_pattern_examples = ["doc_384", "entity_768", "doc_512"] + + test_cases = [ + ("user", "collection", 384, "doc"), + ("test", "test", 768, "entity"), + ("a", "b", 512, "doc"), + ] + + for user, collection, dimension, prefix in test_cases: + new_name = make_safe_collection_name(user, collection, dimension, prefix) + + # New names should have at least 3 underscores (prefix_user_collection_dimension) + # Old names had only 1 underscore (prefix_dimension) + assert new_name.count('_') >= 3, f"New name {new_name} doesn't have enough underscores" + + # New names should not match old pattern + assert new_name not in old_pattern_examples, f"New name {new_name} conflicts with old pattern" + + def test_user_collection_isolation_regression(self): + """ + Regression test to ensure user/collection parameters prevent data mixing. + + This test guards against the bug where all users shared the same Milvus + collections, causing data contamination between users/collections. 
+ """ + + # Test the specific case that was broken before the fix + user1, collection1 = "my_user", "test_coll_1" + user2, collection2 = "other_user", "production_data" + + dimension = 384 + + # Generate collection names + doc_name1 = make_safe_collection_name(user1, collection1, dimension, "doc") + doc_name2 = make_safe_collection_name(user2, collection2, dimension, "doc") + + entity_name1 = make_safe_collection_name(user1, collection1, dimension, "entity") + entity_name2 = make_safe_collection_name(user2, collection2, dimension, "entity") + + # Verify complete isolation + assert doc_name1 != doc_name2, "Document collections should be isolated" + assert entity_name1 != entity_name2, "Entity collections should be isolated" + + # Verify names match expected pattern from Qdrant + # Qdrant uses: d_{user}_{collection}_{dimension}, t_{user}_{collection}_{dimension} + # Milvus uses: doc_{safe_user}_{safe_collection}_{dimension}, entity_{safe_user}_{safe_collection}_{dimension} + assert doc_name1 == "doc_my_user_test_coll_1_384" + assert doc_name2 == "doc_other_user_production_data_384" + assert entity_name1 == "entity_my_user_test_coll_1_384" + assert entity_name2 == "entity_other_user_production_data_384" + + # This test would have FAILED with the old implementation that used: + # - doc_384 for all document embeddings (no user/collection differentiation) + # - entity_384 for all graph embeddings (no user/collection differentiation) \ No newline at end of file diff --git a/tests/unit/test_query/test_neo4j_user_collection_query.py b/tests/unit/test_query/test_neo4j_user_collection_query.py new file mode 100644 index 00000000..bf23680c --- /dev/null +++ b/tests/unit/test_query/test_neo4j_user_collection_query.py @@ -0,0 +1,430 @@ +""" +Tests for Neo4j user/collection isolation in query service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.query.triples.neo4j.service import Processor +from trustgraph.schema import TriplesQueryRequest, 
Value + + +class TestNeo4jQueryUserCollectionIsolation: + """Test cases for Neo4j query service with user/collection isolation""" + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_spo_query_with_user_collection(self, mock_graph_db): + """Test SPO query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/s", is_uri=True), + p=Value(value="http://example.com/p", is_uri=True), + o=Value(value="test_object", is_uri=False) + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify SPO query for literal includes user/collection + expected_query = ( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "RETURN $src as src" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + src="http://example.com/s", + rel="http://example.com/p", + value="test_object", + user="test_user", + collection="test_collection", + database_='neo4j' + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_sp_query_with_user_collection(self, mock_graph_db): + """Test SP query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/s", is_uri=True), + p=Value(value="http://example.com/p", is_uri=True), + o=None + ) + + mock_driver.execute_query.return_value = ([], 
MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify SP query for literals includes user/collection + expected_literal_query = ( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " + "RETURN dest.value as dest" + ) + + mock_driver.execute_query.assert_any_call( + expected_literal_query, + src="http://example.com/s", + rel="http://example.com/p", + user="test_user", + collection="test_collection", + database_='neo4j' + ) + + # Verify SP query for nodes includes user/collection + expected_node_query = ( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " + "RETURN dest.uri as dest" + ) + + mock_driver.execute_query.assert_any_call( + expected_node_query, + src="http://example.com/s", + rel="http://example.com/p", + user="test_user", + collection="test_collection", + database_='neo4j' + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_so_query_with_user_collection(self, mock_graph_db): + """Test SO query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/s", is_uri=True), + p=None, + o=Value(value="http://example.com/o", is_uri=True) + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify SO query for nodes includes user/collection + expected_query = ( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + 
"(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "RETURN rel.uri as rel" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + src="http://example.com/s", + uri="http://example.com/o", + user="test_user", + collection="test_collection", + database_='neo4j' + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_s_only_query_with_user_collection(self, mock_graph_db): + """Test S-only query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/s", is_uri=True), + p=None, + o=None + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify S query includes user/collection + expected_query = ( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " + "RETURN rel.uri as rel, dest.value as dest" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + src="http://example.com/s", + user="test_user", + collection="test_collection", + database_='neo4j' + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_po_query_with_user_collection(self, mock_graph_db): + """Test PO query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=None, + p=Value(value="http://example.com/p", is_uri=True), + o=Value(value="literal", is_uri=False) + ) + + mock_driver.execute_query.return_value = 
([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify PO query for literals includes user/collection + expected_query = ( + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "RETURN src.uri as src" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + uri="http://example.com/p", + value="literal", + user="test_user", + collection="test_collection", + database_='neo4j' + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_p_only_query_with_user_collection(self, mock_graph_db): + """Test P-only query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=None, + p=Value(value="http://example.com/p", is_uri=True), + o=None + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify P query includes user/collection + expected_query = ( + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " + "RETURN src.uri as src, dest.value as dest" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + uri="http://example.com/p", + user="test_user", + collection="test_collection", + database_='neo4j' + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_o_only_query_with_user_collection(self, mock_graph_db): + """Test O-only query pattern includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + 
processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=None, + p=None, + o=Value(value="test_value", is_uri=False) + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify O query for literals includes user/collection + expected_query = ( + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "RETURN src.uri as src, rel.uri as rel" + ) + + mock_driver.execute_query.assert_any_call( + expected_query, + value="test_value", + user="test_user", + collection="test_collection", + database_='neo4j' + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_wildcard_query_with_user_collection(self, mock_graph_db): + """Test wildcard query (all None) includes user/collection filtering""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=None, + p=None, + o=None + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify wildcard query for literals includes user/collection + expected_literal_query = ( + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " + "RETURN src.uri as src, rel.uri as rel, dest.value as dest" + ) + + mock_driver.execute_query.assert_any_call( + expected_literal_query, + user="test_user", + collection="test_collection", + database_='neo4j' + ) + + # Verify wildcard query for nodes includes user/collection + expected_node_query = ( + "MATCH (src:Node {user: 
$user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " + "RETURN src.uri as src, rel.uri as rel, dest.uri as dest" + ) + + mock_driver.execute_query.assert_any_call( + expected_node_query, + user="test_user", + collection="test_collection", + database_='neo4j' + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_with_defaults_when_not_provided(self, mock_graph_db): + """Test that defaults are used when user/collection not provided""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + # Query without user/collection fields + query = TriplesQueryRequest( + s=Value(value="http://example.com/s", is_uri=True), + p=None, + o=None + ) + + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + await processor.query_triples(query) + + # Verify defaults were used + calls = mock_driver.execute_query.call_args_list + for call in calls: + if 'user' in call.kwargs: + assert call.kwargs['user'] == 'default' + if 'collection' in call.kwargs: + assert call.kwargs['collection'] == 'default' + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_results_properly_converted_to_triples(self, mock_graph_db): + """Test that query results are properly converted to Triple objects""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=MagicMock()) + + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/s", is_uri=True), + p=None, + o=None + ) + + # Mock some results + mock_record1 = MagicMock() + mock_record1.data.return_value = { + "rel": "http://example.com/p1", + "dest": "literal_value" + } + + mock_record2 = MagicMock() + mock_record2.data.return_value = { + 
"rel": "http://example.com/p2", + "dest": "http://example.com/o" + } + + # Return results for literal query, empty for node query + mock_driver.execute_query.side_effect = [ + ([mock_record1], MagicMock(), MagicMock()), # Literal query + ([mock_record2], MagicMock(), MagicMock()) # Node query + ] + + result = await processor.query_triples(query) + + # Verify results are proper Triple objects + assert len(result) == 2 + + # First triple (literal object) + assert result[0].s.value == "http://example.com/s" + assert result[0].s.is_uri == True + assert result[0].p.value == "http://example.com/p1" + assert result[0].p.is_uri == True + assert result[0].o.value == "literal_value" + assert result[0].o.is_uri == False + + # Second triple (URI object) + assert result[1].s.value == "http://example.com/s" + assert result[1].s.is_uri == True + assert result[1].p.value == "http://example.com/p2" + assert result[1].p.is_uri == True + assert result[1].o.value == "http://example.com/o" + assert result[1].o.is_uri == True \ No newline at end of file diff --git a/tests/unit/test_storage/test_neo4j_user_collection_isolation.py b/tests/unit/test_storage/test_neo4j_user_collection_isolation.py new file mode 100644 index 00000000..b3d5c79a --- /dev/null +++ b/tests/unit/test_storage/test_neo4j_user_collection_isolation.py @@ -0,0 +1,470 @@ +""" +Tests for Neo4j user/collection isolation in triples storage and query +""" + +import pytest +from unittest.mock import MagicMock, patch, call + +from trustgraph.storage.triples.neo4j.write import Processor as StorageProcessor +from trustgraph.query.triples.neo4j.service import Processor as QueryProcessor +from trustgraph.schema import Triples, Triple, Value, Metadata +from trustgraph.schema import TriplesQueryRequest + + +class TestNeo4jUserCollectionIsolation: + """Test cases for Neo4j user/collection isolation functionality""" + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + def 
test_storage_creates_indexes_with_user_collection(self, mock_graph_db): + """Test that storage service creates compound indexes for user/collection""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = StorageProcessor(taskgroup=taskgroup_mock) + + # Verify both legacy and new compound indexes are created + expected_indexes = [ + "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", + "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)", + "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)", + "CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", + "CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", + "CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", + "CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)" + ] + + # Check that all expected indexes were created + for expected_query in expected_indexes: + mock_session.run.assert_any_call(expected_query) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + @pytest.mark.asyncio + async def test_store_triples_with_user_collection(self, mock_graph_db): + """Test that triples are stored with user/collection properties""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = StorageProcessor(taskgroup=taskgroup_mock) + + # Create test message with user/collection metadata + metadata = Metadata( + id="test-id", + user="test_user", + collection="test_collection" + ) + + triple = Triple( + s=Value(value="http://example.com/subject", is_uri=True), + p=Value(value="http://example.com/predicate", is_uri=True), + o=Value(value="literal_value", is_uri=False) + ) + + message = 
Triples( + metadata=metadata, + triples=[triple] + ) + + # Mock execute_query to return summaries + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_driver.execute_query.return_value.summary = mock_summary + + await processor.store_triples(message) + + # Verify nodes and relationships were created with user/collection properties + expected_calls = [ + call( + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + uri="http://example.com/subject", + user="test_user", + collection="test_collection", + database_='neo4j' + ), + call( + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + value="literal_value", + user="test_user", + collection="test_collection", + database_='neo4j' + ), + call( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + src="http://example.com/subject", + dest="literal_value", + uri="http://example.com/predicate", + user="test_user", + collection="test_collection", + database_='neo4j' + ) + ] + + for expected_call in expected_calls: + mock_driver.execute_query.assert_any_call(*expected_call.args, **expected_call.kwargs) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + @pytest.mark.asyncio + async def test_store_triples_with_default_user_collection(self, mock_graph_db): + """Test that default user/collection are used when not provided""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = StorageProcessor(taskgroup=taskgroup_mock) + + # Create test message without user/collection + metadata = Metadata(id="test-id") + + triple = Triple( + 
s=Value(value="http://example.com/subject", is_uri=True), + p=Value(value="http://example.com/predicate", is_uri=True), + o=Value(value="http://example.com/object", is_uri=True) + ) + + message = Triples( + metadata=metadata, + triples=[triple] + ) + + # Mock execute_query + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_driver.execute_query.return_value.summary = mock_summary + + await processor.store_triples(message) + + # Verify defaults were used + mock_driver.execute_query.assert_any_call( + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + uri="http://example.com/subject", + user="default", + collection="default", + database_='neo4j' + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_filters_by_user_collection(self, mock_graph_db): + """Test that query service filters results by user/collection""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = QueryProcessor(taskgroup=MagicMock()) + + # Create test query + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=Value(value="http://example.com/subject", is_uri=True), + p=Value(value="http://example.com/predicate", is_uri=True), + o=None + ) + + # Mock query results + mock_records = [ + MagicMock(data=lambda: {"dest": "http://example.com/object1"}), + MagicMock(data=lambda: {"dest": "literal_value"}) + ] + + mock_driver.execute_query.return_value = (mock_records, MagicMock(), MagicMock()) + + result = await processor.query_triples(query) + + # Verify queries include user/collection filters + expected_literal_query = ( + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " + "RETURN dest.value as dest" + ) + + expected_node_query = ( + 
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " + "RETURN dest.uri as dest" + ) + + # Check that queries were executed with user/collection parameters + calls = mock_driver.execute_query.call_args_list + assert any( + expected_literal_query in str(call) and + "user='test_user'" in str(call) and + "collection='test_collection'" in str(call) + for call in calls + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_with_default_user_collection(self, mock_graph_db): + """Test that query service uses defaults when user/collection not provided""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = QueryProcessor(taskgroup=MagicMock()) + + # Create test query without user/collection + query = TriplesQueryRequest( + s=None, + p=None, + o=None + ) + + # Mock empty results + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + result = await processor.query_triples(query) + + # Verify defaults were used in queries + calls = mock_driver.execute_query.call_args_list + assert any( + "user='default'" in str(call) and "collection='default'" in str(call) + for call in calls + ) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + @pytest.mark.asyncio + async def test_data_isolation_between_users(self, mock_graph_db): + """Test that data from different users is properly isolated""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = StorageProcessor(taskgroup=taskgroup_mock) + + # Create messages for different users + message_user1 = Triples( + metadata=Metadata(user="user1", collection="coll1"), + triples=[ + Triple( + 
s=Value(value="http://example.com/user1/subject", is_uri=True), + p=Value(value="http://example.com/predicate", is_uri=True), + o=Value(value="user1_data", is_uri=False) + ) + ] + ) + + message_user2 = Triples( + metadata=Metadata(user="user2", collection="coll2"), + triples=[ + Triple( + s=Value(value="http://example.com/user2/subject", is_uri=True), + p=Value(value="http://example.com/predicate", is_uri=True), + o=Value(value="user2_data", is_uri=False) + ) + ] + ) + + # Mock execute_query + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_driver.execute_query.return_value.summary = mock_summary + + # Store data for both users + await processor.store_triples(message_user1) + await processor.store_triples(message_user2) + + # Verify user1 data was stored with user1/coll1 + mock_driver.execute_query.assert_any_call( + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + value="user1_data", + user="user1", + collection="coll1", + database_='neo4j' + ) + + # Verify user2 data was stored with user2/coll2 + mock_driver.execute_query.assert_any_call( + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + value="user2_data", + user="user2", + collection="coll2", + database_='neo4j' + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_wildcard_query_respects_user_collection(self, mock_graph_db): + """Test that wildcard queries still filter by user/collection""" + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = QueryProcessor(taskgroup=MagicMock()) + + # Create wildcard query (all nulls) with user/collection + query = TriplesQueryRequest( + user="test_user", + collection="test_collection", + s=None, + p=None, + o=None + ) + + # Mock results + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + result = await 
processor.query_triples(query) + + # Verify wildcard queries include user/collection filters + wildcard_query = ( + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " + "RETURN src.uri as src, rel.uri as rel, dest.value as dest" + ) + + calls = mock_driver.execute_query.call_args_list + assert any( + wildcard_query in str(call) and + "user='test_user'" in str(call) and + "collection='test_collection'" in str(call) + for call in calls + ) + + def test_add_args_includes_neo4j_parameters(self): + """Test that add_args includes Neo4j-specific parameters""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'): + StorageProcessor.add_args(parser) + + args = parser.parse_args([]) + + assert hasattr(args, 'graph_host') + assert hasattr(args, 'username') + assert hasattr(args, 'password') + assert hasattr(args, 'database') + + # Check defaults + assert args.graph_host == 'bolt://neo4j:7687' + assert args.username == 'neo4j' + assert args.password == 'password' + assert args.database == 'neo4j' + + +class TestNeo4jUserCollectionRegression: + """Regression tests to ensure user/collection isolation prevents data leaks""" + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_regression_no_cross_user_data_access(self, mock_graph_db): + """ + Regression test: Ensure user1 cannot access user2's data + + This test guards against the bug where all users shared the same + Neo4j graph space, causing data contamination between users. 
+ """ + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = QueryProcessor(taskgroup=MagicMock()) + + # User1 queries for all triples + query_user1 = TriplesQueryRequest( + user="user1", + collection="collection1", + s=None, p=None, o=None + ) + + # Mock that the database has data but none matching user1/collection1 + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) + + result = await processor.query_triples(query_user1) + + # Verify empty results (user1 cannot see other users' data) + assert len(result) == 0 + + # Verify the query included user/collection filters + calls = mock_driver.execute_query.call_args_list + for call in calls: + query_str = str(call) + if "MATCH" in query_str: + assert "user: $user" in query_str or "user='user1'" in query_str + assert "collection: $collection" in query_str or "collection='collection1'" in query_str + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + @pytest.mark.asyncio + async def test_regression_same_uri_different_users(self, mock_graph_db): + """ + Regression test: Same URI in different user contexts should create separate nodes + + This ensures that http://example.com/entity for user1 is completely separate + from http://example.com/entity for user2. 
+ """ + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = StorageProcessor(taskgroup=taskgroup_mock) + + # Same URI for different users + shared_uri = "http://example.com/shared_entity" + + message_user1 = Triples( + metadata=Metadata(user="user1", collection="coll1"), + triples=[ + Triple( + s=Value(value=shared_uri, is_uri=True), + p=Value(value="http://example.com/p", is_uri=True), + o=Value(value="user1_value", is_uri=False) + ) + ] + ) + + message_user2 = Triples( + metadata=Metadata(user="user2", collection="coll2"), + triples=[ + Triple( + s=Value(value=shared_uri, is_uri=True), + p=Value(value="http://example.com/p", is_uri=True), + o=Value(value="user2_value", is_uri=False) + ) + ] + ) + + # Mock execute_query + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_driver.execute_query.return_value.summary = mock_summary + + await processor.store_triples(message_user1) + await processor.store_triples(message_user2) + + # Verify two separate nodes were created with same URI but different user/collection + user1_node_call = call( + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + uri=shared_uri, + user="user1", + collection="coll1", + database_='neo4j' + ) + + user2_node_call = call( + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + uri=shared_uri, + user="user2", + collection="coll2", + database_='neo4j' + ) + + mock_driver.execute_query.assert_has_calls([user1_node_call, user2_node_call], any_order=True) \ No newline at end of file diff --git a/tests/unit/test_storage/test_triples_neo4j_storage.py b/tests/unit/test_storage/test_triples_neo4j_storage.py index a84706ee..e600d227 100644 --- a/tests/unit/test_storage/test_triples_neo4j_storage.py +++ 
b/tests/unit/test_storage/test_triples_neo4j_storage.py @@ -62,14 +62,18 @@ class TestNeo4jStorageProcessor: processor = Processor(taskgroup=taskgroup_mock) - # Verify index creation queries were executed + # Verify index creation queries were executed (now includes 7 indexes) expected_calls = [ "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)", - "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)" + "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)", + "CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", + "CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", + "CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", + "CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)" ] - assert mock_session.run.call_count == 3 + assert mock_session.run.call_count == 7 for expected_query in expected_calls: mock_session.run.assert_any_call(expected_query) @@ -88,8 +92,8 @@ class TestNeo4jStorageProcessor: # Should not raise exception - they should be caught and ignored processor = Processor(taskgroup=taskgroup_mock) - # Should have tried to create all 3 indexes despite exceptions - assert mock_session.run.call_count == 3 + # Should have tried to create all 7 indexes despite exceptions + assert mock_session.run.call_count == 7 @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') def test_create_node(self, mock_graph_db): @@ -111,11 +115,13 @@ class TestNeo4jStorageProcessor: processor = Processor(taskgroup=taskgroup_mock) # Test create_node - processor.create_node("http://example.com/node") + processor.create_node("http://example.com/node", "test_user", "test_collection") mock_driver.execute_query.assert_called_with( - "MERGE (n:Node {uri: $uri})", + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", uri="http://example.com/node", + user="test_user", + collection="test_collection", database_="neo4j" ) @@ 
-139,11 +145,13 @@ class TestNeo4jStorageProcessor: processor = Processor(taskgroup=taskgroup_mock) # Test create_literal - processor.create_literal("literal value") + processor.create_literal("literal value", "test_user", "test_collection") mock_driver.execute_query.assert_called_with( - "MERGE (n:Literal {value: $value})", + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", value="literal value", + user="test_user", + collection="test_collection", database_="neo4j" ) @@ -170,16 +178,20 @@ class TestNeo4jStorageProcessor: processor.relate_node( "http://example.com/subject", "http://example.com/predicate", - "http://example.com/object" + "http://example.com/object", + "test_user", + "test_collection" ) mock_driver.execute_query.assert_called_with( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Node {uri: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", src="http://example.com/subject", dest="http://example.com/object", uri="http://example.com/predicate", + user="test_user", + collection="test_collection", database_="neo4j" ) @@ -206,16 +218,20 @@ class TestNeo4jStorageProcessor: processor.relate_literal( "http://example.com/subject", "http://example.com/predicate", - "literal value" + "literal value", + "test_user", + "test_collection" ) mock_driver.execute_query.assert_called_with( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", src="http://example.com/subject", dest="literal value", 
uri="http://example.com/predicate", + user="test_user", + collection="test_collection", database_="neo4j" ) @@ -246,9 +262,11 @@ class TestNeo4jStorageProcessor: triple.o.value = "http://example.com/object" triple.o.is_uri = True - # Create mock message + # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple] + mock_message.metadata.user = "test_user" + mock_message.metadata.collection = "test_collection" await processor.store_triples(mock_message) @@ -257,23 +275,25 @@ class TestNeo4jStorageProcessor: expected_calls = [ # Subject node creation ( - "MERGE (n:Node {uri: $uri})", - {"uri": "http://example.com/subject", "database_": "neo4j"} + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} ), # Object node creation ( - "MERGE (n:Node {uri: $uri})", - {"uri": "http://example.com/object", "database_": "neo4j"} + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} ), # Relationship creation ( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Node {uri: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", { "src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", + "user": "test_user", + "collection": "test_collection", "database_": "neo4j" } ) @@ -310,9 +330,11 @@ class TestNeo4jStorageProcessor: triple.o.value = "literal value" triple.o.is_uri = False - # Create mock message + # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple] + 
mock_message.metadata.user = "test_user" + mock_message.metadata.collection = "test_collection" await processor.store_triples(mock_message) @@ -322,23 +344,25 @@ class TestNeo4jStorageProcessor: expected_calls = [ # Subject node creation ( - "MERGE (n:Node {uri: $uri})", - {"uri": "http://example.com/subject", "database_": "neo4j"} + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} ), # Literal creation ( - "MERGE (n:Literal {value: $value})", - {"value": "literal value", "database_": "neo4j"} + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + {"value": "literal value", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} ), # Relationship creation ( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", { "src": "http://example.com/subject", "dest": "literal value", "uri": "http://example.com/predicate", + "user": "test_user", + "collection": "test_collection", "database_": "neo4j" } ) @@ -381,9 +405,11 @@ class TestNeo4jStorageProcessor: triple2.o.value = "literal value" triple2.o.is_uri = False - # Create mock message + # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple1, triple2] + mock_message.metadata.user = "test_user" + mock_message.metadata.collection = "test_collection" await processor.store_triples(mock_message) @@ -405,9 +431,11 @@ class TestNeo4jStorageProcessor: processor = Processor(taskgroup=taskgroup_mock) - # Create mock message with empty triples + # Create mock message with empty triples and metadata mock_message = MagicMock() 
mock_message.triples = [] + mock_message.metadata.user = "test_user" + mock_message.metadata.collection = "test_collection" await processor.store_triples(mock_message) @@ -521,28 +549,36 @@ class TestNeo4jStorageProcessor: mock_message = MagicMock() mock_message.triples = [triple] + mock_message.metadata.user = "test_user" + mock_message.metadata.collection = "test_collection" await processor.store_triples(mock_message) # Verify the triple was processed with special characters preserved mock_driver.execute_query.assert_any_call( - "MERGE (n:Node {uri: $uri})", + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", uri="http://example.com/subject with spaces", + user="test_user", + collection="test_collection", database_="neo4j" ) mock_driver.execute_query.assert_any_call( - "MERGE (n:Literal {value: $value})", + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", value='literal with "quotes" and unicode: ñáéíóú', + user="test_user", + collection="test_collection", database_="neo4j" ) mock_driver.execute_query.assert_any_call( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", src="http://example.com/subject with spaces", dest='literal with "quotes" and unicode: ñáéíóú', uri="http://example.com/predicate:with/symbols", + user="test_user", + collection="test_collection", database_="neo4j" ) diff --git a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py index 69e10d62..0e84d733 100755 --- a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py +++ b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py @@ -55,6 +55,10 @@ class Processor(TriplesQueryService): 
try: + # Extract user and collection, use defaults if not provided + user = query.user if query.user else "default" + collection = query.collection if query.collection else "default" + triples = [] if query.s is not None: @@ -64,9 +68,12 @@ class Processor(TriplesQueryService): # SPO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN $src as src", src=query.s.value, rel=query.p.value, value=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -74,9 +81,12 @@ class Processor(TriplesQueryService): triples.append((query.s.value, query.p.value, query.o.value)) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Node {uri: $uri, user: $user, collection: $collection}) " "RETURN $src as src", src=query.s.value, rel=query.p.value, uri=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -88,9 +98,12 @@ class Processor(TriplesQueryService): # SP records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " "RETURN dest.value as dest", src=query.s.value, rel=query.p.value, + user=user, collection=collection, database_=self.db, ) @@ -99,9 +112,12 @@ class Processor(TriplesQueryService): triples.append((query.s.value, query.p.value, data["dest"])) records, summary, 
keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " "RETURN dest.uri as dest", src=query.s.value, rel=query.p.value, + user=user, collection=collection, database_=self.db, ) @@ -116,9 +132,12 @@ class Processor(TriplesQueryService): # SO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN rel.uri as rel", src=query.s.value, value=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -127,9 +146,12 @@ class Processor(TriplesQueryService): triples.append((query.s.value, data["rel"], query.o.value)) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {uri: $uri, user: $user, collection: $collection}) " "RETURN rel.uri as rel", src=query.s.value, uri=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -142,9 +164,12 @@ class Processor(TriplesQueryService): # S records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest", src=query.s.value, + user=user, collection=collection, database_=self.db, ) @@ -153,9 +178,12 @@ class 
Processor(TriplesQueryService): triples.append((query.s.value, data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) " + "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " "RETURN rel.uri as rel, dest.uri as dest", src=query.s.value, + user=user, collection=collection, database_=self.db, ) @@ -173,9 +201,12 @@ class Processor(TriplesQueryService): # PO records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN src.uri as src", uri=query.p.value, value=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -184,9 +215,12 @@ class Processor(TriplesQueryService): triples.append((data["src"], query.p.value, query.o.value)) records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Node {uri: $dest, user: $user, collection: $collection}) " "RETURN src.uri as src", uri=query.p.value, dest=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -199,9 +233,12 @@ class Processor(TriplesQueryService): # P records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " "RETURN src.uri as src, dest.value as dest", uri=query.p.value, + user=user, 
collection=collection, database_=self.db, ) @@ -210,9 +247,12 @@ class Processor(TriplesQueryService): triples.append((data["src"], query.p.value, data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " "RETURN src.uri as src, dest.uri as dest", uri=query.p.value, + user=user, collection=collection, database_=self.db, ) @@ -227,9 +267,12 @@ class Processor(TriplesQueryService): # O records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {value: $value, user: $user, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel", value=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -238,9 +281,12 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], query.o.value)) records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {uri: $uri, user: $user, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel", uri=query.o.value, + user=user, collection=collection, database_=self.db, ) @@ -253,8 +299,11 @@ class Processor(TriplesQueryService): # * records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Literal) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Literal {user: $user, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest", + user=user, 
collection=collection, database_=self.db, ) @@ -263,8 +312,11 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Node) " + "MATCH (src:Node {user: $user, collection: $collection})-" + "[rel:Rel {user: $user, collection: $collection}]->" + "(dest:Node {user: $user, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest", + user=user, collection=collection, database_=self.db, ) diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index e1913c14..c33478eb 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -61,6 +61,7 @@ class Processor(TriplesStoreService): logger.info("Create indexes...") + # Legacy indexes for backwards compatibility try: session.run( "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", @@ -88,15 +89,50 @@ class Processor(TriplesStoreService): # Maybe index already exists logger.warning("Index create failure ignored") + # New compound indexes for user/collection filtering + try: + session.run( + "CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", + ) + except Exception as e: + logger.warning(f"Compound index create failure: {e}") + logger.warning("Index create failure ignored") + + try: + session.run( + "CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", + ) + except Exception as e: + logger.warning(f"Compound index create failure: {e}") + logger.warning("Index create failure ignored") + + # Note: Neo4j doesn't support compound indexes on relationships in all versions + # Try to create individual indexes on relationship properties + try: + session.run( + "CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", + ) + except Exception as e: + 
logger.warning(f"Relationship index create failure: {e}") + logger.warning("Index create failure ignored") + + try: + session.run( + "CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)", + ) + except Exception as e: + logger.warning(f"Relationship index create failure: {e}") + logger.warning("Index create failure ignored") + logger.info("Index creation done") - def create_node(self, uri): + def create_node(self, uri, user, collection): - logger.debug(f"Create node {uri}") + logger.debug(f"Create node {uri} for user={user}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Node {uri: $uri})", - uri=uri, + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + uri=uri, user=user, collection=collection, database_=self.db, ).summary @@ -105,13 +141,13 @@ class Processor(TriplesStoreService): time=summary.result_available_after )) - def create_literal(self, value): + def create_literal(self, value, user, collection): - logger.debug(f"Create literal {value}") + logger.debug(f"Create literal {value} for user={user}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Literal {value: $value})", - value=value, + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + value=value, user=user, collection=collection, database_=self.db, ).summary @@ -120,15 +156,15 @@ class Processor(TriplesStoreService): time=summary.result_available_after )) - def relate_node(self, src, uri, dest): + def relate_node(self, src, uri, dest, user, collection): - logger.debug(f"Create node rel {src} {uri} {dest}") + logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Node {uri: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", - src=src, dest=dest, uri=uri, + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: 
$collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, user=user, collection=collection, database_=self.db, ).summary @@ -137,15 +173,15 @@ class Processor(TriplesStoreService): time=summary.result_available_after )) - def relate_literal(self, src, uri, dest): + def relate_literal(self, src, uri, dest, user, collection): - logger.debug(f"Create literal rel {src} {uri} {dest}") + logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", - src=src, dest=dest, uri=uri, + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, user=user, collection=collection, database_=self.db, ).summary @@ -156,16 +192,20 @@ class Processor(TriplesStoreService): async def store_triples(self, message): + # Extract user and collection from metadata + user = message.metadata.user if message.metadata.user else "default" + collection = message.metadata.collection if message.metadata.collection else "default" + for t in message.triples: - self.create_node(t.s.value) + self.create_node(t.s.value, user, collection) if t.o.is_uri: - self.create_node(t.o.value) - self.relate_node(t.s.value, t.p.value, t.o.value) + self.create_node(t.o.value, user, collection) + self.relate_node(t.s.value, t.p.value, t.o.value, user, collection) else: - self.create_literal(t.o.value) - self.relate_literal(t.s.value, t.p.value, t.o.value) + self.create_literal(t.o.value, user, collection) + self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection) @staticmethod def add_args(parser):