diff --git a/docs/tech-specs/cassandra-performance-refactor.md b/docs/tech-specs/cassandra-performance-refactor.md new file mode 100644 index 00000000..4ae49a68 --- /dev/null +++ b/docs/tech-specs/cassandra-performance-refactor.md @@ -0,0 +1,582 @@ +# Tech Spec: Cassandra Knowledge Base Performance Refactor + +**Status:** Draft +**Author:** Assistant +**Date:** 2025-09-18 + +## Overview + +This specification addresses performance issues in the TrustGraph Cassandra knowledge base implementation and proposes optimizations for RDF triple storage and querying. + +## Current Implementation + +### Schema Design + +The current implementation uses a single table design in `trustgraph-flow/trustgraph/direct/cassandra_kg.py`: + +```sql +CREATE TABLE triples ( + collection text, + s text, + p text, + o text, + PRIMARY KEY (collection, s, p, o) +); +``` + +**Secondary Indexes:** +- `triples_s` ON `s` (subject) +- `triples_p` ON `p` (predicate) +- `triples_o` ON `o` (object) + +### Query Patterns + +The current implementation supports 8 distinct query patterns: + +1. **get_all(collection, limit=50)** - Retrieve all triples for a collection + ```sql + SELECT s, p, o FROM triples WHERE collection = ? LIMIT 50 + ``` + +2. **get_s(collection, s, limit=10)** - Query by subject + ```sql + SELECT p, o FROM triples WHERE collection = ? AND s = ? LIMIT 10 + ``` + +3. **get_p(collection, p, limit=10)** - Query by predicate + ```sql + SELECT s, o FROM triples WHERE collection = ? AND p = ? LIMIT 10 + ``` + +4. **get_o(collection, o, limit=10)** - Query by object + ```sql + SELECT s, p FROM triples WHERE collection = ? AND o = ? LIMIT 10 + ``` + +5. **get_sp(collection, s, p, limit=10)** - Query by subject + predicate + ```sql + SELECT o FROM triples WHERE collection = ? AND s = ? AND p = ? LIMIT 10 + ``` + +6. **get_po(collection, p, o, limit=10)** - Query by predicate + object ⚠️ + ```sql + SELECT s FROM triples WHERE collection = ? AND p = ? AND o = ? LIMIT 10 ALLOW FILTERING + ``` + +7. **get_os(collection, o, s, limit=10)** - Query by object + subject ⚠️ + ```sql + SELECT p FROM triples WHERE collection = ? AND o = ? AND s = ? LIMIT 10 ALLOW FILTERING + ``` + +8. **get_spo(collection, s, p, o, limit=10)** - Exact triple match + ```sql + SELECT s as x FROM triples WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT 10 + ``` + +### Current Architecture + +**File: `trustgraph-flow/trustgraph/direct/cassandra_kg.py`** +- Single `KnowledgeGraph` class handling all operations +- Connection pooling through global `_active_clusters` list +- Fixed table name: `"triples"` +- Keyspace per user model +- SimpleStrategy replication with factor 1 + +**Integration Points:** +- **Write Path:** `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py` +- **Query Path:** `trustgraph-flow/trustgraph/query/triples/cassandra/service.py` +- **Knowledge Store:** `trustgraph-flow/trustgraph/tables/knowledge.py` + +## Performance Issues Identified + +### Schema-Level Issues + +1. **Inefficient Primary Key Design** + - Current: `PRIMARY KEY (collection, s, p, o)` + - Results in poor clustering for common access patterns + - Forces expensive secondary index usage + +2. **Secondary Index Overuse** ⚠️ + - Three secondary indexes on high-cardinality columns (s, p, o) + - Secondary indexes in Cassandra are expensive and don't scale well + - Queries 6 & 7 require `ALLOW FILTERING` indicating poor data modeling + +3. **Hot Partition Risk** + - Single partition key `collection` can create hot partitions + - Large collections will concentrate on single nodes + - No distribution strategy for load balancing + +### Query-Level Issues + +1. **ALLOW FILTERING Usage** ⚠️ + - Two query types (get_po, get_os) require `ALLOW FILTERING` + - These queries scan multiple partitions and are extremely expensive + - Performance degrades linearly with data size + +2. **Inefficient Access Patterns** + - No optimization for common RDF query patterns + - Missing compound indexes for frequent query combinations + - No consideration for graph traversal patterns + +3. **Lack of Query Optimization** + - No prepared statements caching + - No query hints or optimization strategies + - No consideration for pagination beyond simple LIMIT + +## Problem Statement + +The current Cassandra knowledge base implementation has two critical performance bottlenecks: + +### 1. Inefficient get_po Query Performance + +The `get_po(collection, p, o)` query is extremely inefficient due to requiring `ALLOW FILTERING`: + +```sql +SELECT s FROM triples WHERE collection = ? AND p = ? AND o = ? LIMIT 10 ALLOW FILTERING +``` + +**Why this is problematic:** +- `ALLOW FILTERING` forces Cassandra to scan all partitions within the collection +- Performance degrades linearly with data size +- This is a common RDF query pattern (finding subjects that have a specific predicate-object relationship) +- Creates significant load on the cluster as data grows + +### 2. Poor Clustering Strategy + +The current primary key `PRIMARY KEY (collection, s, p, o)` provides minimal clustering benefits: + +**Issues with current clustering:** +- `collection` as partition key doesn't distribute data effectively +- Most collections contain diverse data making clustering ineffective +- No consideration for common access patterns in RDF queries +- Large collections create hot partitions on single nodes +- Clustering columns (s, p, o) don't optimize for typical graph traversal patterns + +**Impact:** +- Queries don't benefit from data locality +- Poor cache utilization +- Uneven load distribution across cluster nodes +- Scalability bottlenecks as collections grow + +## Proposed Solution: Multi-Table Denormalization Strategy + +### Overview + +Replace the single `triples` table with three purpose-built tables, each optimized for specific query patterns. This eliminates the need for secondary indexes and ALLOW FILTERING while providing optimal performance for all query types. + +### New Schema Design + +**Table 1: Subject-Centric Queries** +```sql +CREATE TABLE triples_by_subject ( + collection text, + s text, + p text, + o text, + PRIMARY KEY ((collection, s), p, o) +); +``` +- **Optimizes:** get_s, get_sp, get_spo, get_os +- **Partition Key:** (collection, s) - Better distribution than collection alone +- **Clustering:** (p, o) - Enables efficient predicate/object lookups for a subject + +**Table 2: Predicate-Object Queries** +```sql +CREATE TABLE triples_by_po ( + collection text, + p text, + o text, + s text, + PRIMARY KEY ((collection, p), o, s) +); +``` +- **Optimizes:** get_p, get_po (eliminates ALLOW FILTERING!) +- **Partition Key:** (collection, p) - Direct access by predicate +- **Clustering:** (o, s) - Efficient object-subject traversal + +**Table 3: Object-Centric Queries** +```sql +CREATE TABLE triples_by_object ( + collection text, + o text, + s text, + p text, + PRIMARY KEY ((collection, o), s, p) +); +``` +- **Optimizes:** get_o, get_os +- **Partition Key:** (collection, o) - Direct access by object +- **Clustering:** (s, p) - Efficient subject-predicate traversal + +### Query Mapping + +| Original Query | Target Table | Performance Improvement | +|----------------|-------------|------------------------| +| get_all(collection) | triples_by_subject | Token-based pagination | +| get_s(collection, s) | triples_by_subject | Direct partition access | +| get_p(collection, p) | triples_by_po | Direct partition access | +| get_o(collection, o) | triples_by_object | Direct partition access | +| get_sp(collection, s, p) | triples_by_subject | Partition + clustering | +| get_po(collection, p, o) | triples_by_po | **No more ALLOW FILTERING!** | +| get_os(collection, o, s) | triples_by_subject | Partition + clustering | +| get_spo(collection, s, p, o) | triples_by_subject | Exact key lookup | + +### Benefits + +1. **Eliminates ALLOW FILTERING** - Every query has an optimal access path +2. **No Secondary Indexes** - Each table IS the index for its query pattern +3. **Better Data Distribution** - Composite partition keys spread load effectively +4. **Predictable Performance** - Query time proportional to result size, not total data +5. **Leverages Cassandra Strengths** - Designed for Cassandra's architecture + +## Implementation Plan + +### Files Requiring Changes + +#### Primary Implementation File + +**`trustgraph-flow/trustgraph/direct/cassandra_kg.py`** - Complete rewrite required + +**Current Methods to Refactor:** +```python +# Schema initialization +def init(self) -> None # Replace single table with three tables + +# Insert operations +def insert(self, collection, s, p, o) -> None # Write to all three tables + +# Query operations (API unchanged, implementation optimized) +def get_all(self, collection, limit=50) # Use triples_by_subject +def get_s(self, collection, s, limit=10) # Use triples_by_subject +def get_p(self, collection, p, limit=10) # Use triples_by_po +def get_o(self, collection, o, limit=10) # Use triples_by_object +def get_sp(self, collection, s, p, limit=10) # Use triples_by_subject +def get_po(self, collection, p, o, limit=10) # Use triples_by_po (NO ALLOW FILTERING!) +def get_os(self, collection, o, s, limit=10) # Use triples_by_subject +def get_spo(self, collection, s, p, o, limit=10) # Use triples_by_subject + +# Collection management +def delete_collection(self, collection) -> None # Delete from all three tables +``` + +#### Integration Files (No Logic Changes Required) + +**`trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`** +- No changes needed - uses existing KnowledgeGraph API +- Benefits automatically from performance improvements + +**`trustgraph-flow/trustgraph/query/triples/cassandra/service.py`** +- No changes needed - uses existing KnowledgeGraph API +- Benefits automatically from performance improvements + +### Test Files Requiring Updates + +#### Unit Tests +**`tests/unit/test_storage/test_triples_cassandra_storage.py`** +- Update test expectations for schema changes +- Add tests for multi-table consistency +- Verify no ALLOW FILTERING in query plans + +**`tests/unit/test_query/test_triples_cassandra_query.py`** +- Update performance assertions +- Test all 8 query patterns against new tables +- Verify query routing to correct tables + +#### Integration Tests +**`tests/integration/test_cassandra_integration.py`** +- End-to-end testing with new schema +- Performance benchmarking comparisons +- Data consistency verification across tables + +**`tests/unit/test_storage/test_cassandra_config_integration.py`** +- Update schema validation tests +- Test migration scenarios + +### Implementation Strategy + +#### Phase 1: Schema and Core Methods +1. **Rewrite `init()` method** - Create three tables instead of one +2. **Rewrite `insert()` method** - Batch writes to all three tables +3. **Implement prepared statements** - For optimal performance +4. **Add table routing logic** - Direct queries to optimal tables + +#### Phase 2: Query Method Optimization +1. **Rewrite each get_* method** to use optimal table +2. **Remove all ALLOW FILTERING** usage +3. **Implement efficient clustering key usage** +4. **Add query performance logging** + +#### Phase 3: Collection Management +1. **Update `delete_collection()`** - Remove from all three tables +2. **Add consistency verification** - Ensure all tables stay in sync +3. **Implement batch operations** - For atomic multi-table operations + +### Key Implementation Details + +#### Batch Write Strategy +```python +def insert(self, collection, s, p, o): + batch = BatchStatement() + + # Insert into all three tables + batch.add(SimpleStatement( + "INSERT INTO triples_by_subject (collection, s, p, o) VALUES (?, ?, ?, ?)" + ), (collection, s, p, o)) + + batch.add(SimpleStatement( + "INSERT INTO triples_by_po (collection, p, o, s) VALUES (?, ?, ?, ?)" + ), (collection, p, o, s)) + + batch.add(SimpleStatement( + "INSERT INTO triples_by_object (collection, o, s, p) VALUES (?, ?, ?, ?)" + ), (collection, o, s, p)) + + self.session.execute(batch) +``` + +#### Query Routing Logic +```python +def get_po(self, collection, p, o, limit=10): + # Route to triples_by_po table - NO ALLOW FILTERING! + return self.session.execute( + "SELECT s FROM triples_by_po WHERE collection = ? AND p = ? AND o = ? LIMIT ?", + (collection, p, o, limit) + ) +``` + +#### Prepared Statement Optimization +```python +def prepare_statements(self): + # Cache prepared statements for better performance + self.insert_subject_stmt = self.session.prepare( + "INSERT INTO triples_by_subject (collection, s, p, o) VALUES (?, ?, ?, ?)" + ) + self.insert_po_stmt = self.session.prepare( + "INSERT INTO triples_by_po (collection, p, o, s) VALUES (?, ?, ?, ?)" + ) + # ... etc for all tables and queries +``` + +## Migration Strategy + +### Data Migration Approach + +#### Option 1: Blue-Green Deployment (Recommended) +1. **Deploy new schema alongside existing** - Use different table names temporarily +2. **Dual-write period** - Write to both old and new schemas during transition +3. **Background migration** - Copy existing data to new tables +4. **Switch reads** - Route queries to new tables once data is migrated +5. **Drop old tables** - After verification period + +#### Option 2: In-Place Migration +1. **Schema addition** - Create new tables in existing keyspace +2. **Data migration script** - Batch copy from old table to new tables +3. **Application update** - Deploy new code after migration completes +4. **Old table cleanup** - Remove old table and indexes + +### Backward Compatibility + +#### Deployment Strategy +```python +# Environment variable to control table usage during migration +USE_LEGACY_TABLES = os.getenv('CASSANDRA_USE_LEGACY', 'false').lower() == 'true' + +class KnowledgeGraph: + def __init__(self, ...): + if USE_LEGACY_TABLES: + self.init_legacy_schema() + else: + self.init_optimized_schema() +``` + +#### Migration Script +```python +def migrate_data(): + # Read from old table + old_triples = session.execute("SELECT collection, s, p, o FROM triples") + + # Batch write to new tables + for batch in batched(old_triples, 100): + batch_stmt = BatchStatement() + for row in batch: + # Add to all three new tables + batch_stmt.add(insert_subject_stmt, row) + batch_stmt.add(insert_po_stmt, (row.collection, row.p, row.o, row.s)) + batch_stmt.add(insert_object_stmt, (row.collection, row.o, row.s, row.p)) + session.execute(batch_stmt) +``` + +### Validation Strategy + +#### Data Consistency Checks +```python +def validate_migration(): + # Count total records in old vs new tables + old_count = session.execute("SELECT COUNT(*) FROM triples WHERE collection = ?", (collection,)) + new_count = session.execute("SELECT COUNT(*) FROM triples_by_subject WHERE collection = ?", (collection,)) + + assert old_count == new_count, f"Record count mismatch: {old_count} vs {new_count}" + + # Spot check random samples + sample_queries = generate_test_queries() + for query in sample_queries: + old_result = execute_legacy_query(query) + new_result = execute_optimized_query(query) + assert old_result == new_result, f"Query results differ for {query}" +``` + +## Testing Strategy + +### Performance Testing + +#### Benchmark Scenarios +1. **Query Performance Comparison** + - Before/after performance metrics for all 8 query types + - Focus on get_po performance improvement (eliminate ALLOW FILTERING) + - Measure query latency under various data sizes + +2. **Load Testing** + - Concurrent query execution + - Write throughput with batch operations + - Memory and CPU utilization + +3. **Scalability Testing** + - Performance with increasing collection sizes + - Multi-collection query distribution + - Cluster node utilization + +#### Test Data Sets +- **Small:** 10K triples per collection +- **Medium:** 100K triples per collection +- **Large:** 1M+ triples per collection +- **Multiple collections:** Test partition distribution + +### Functional Testing + +#### Unit Test Updates +```python +# Example test structure for new implementation +class TestCassandraKGPerformance: + def test_get_po_no_allow_filtering(self): + # Verify get_po queries don't use ALLOW FILTERING + with patch('cassandra.cluster.Session.execute') as mock_execute: + kg.get_po('test_collection', 'predicate', 'object') + executed_query = mock_execute.call_args[0][0] + assert 'ALLOW FILTERING' not in executed_query + + def test_multi_table_consistency(self): + # Verify all tables stay in sync + kg.insert('test', 's1', 'p1', 'o1') + + # Check all tables contain the triple + assert_triple_exists('triples_by_subject', 'test', 's1', 'p1', 'o1') + assert_triple_exists('triples_by_po', 'test', 'p1', 'o1', 's1') + assert_triple_exists('triples_by_object', 'test', 'o1', 's1', 'p1') +``` + +#### Integration Test Updates +```python +class TestCassandraIntegration: + def test_query_performance_regression(self): + # Ensure new implementation is faster than old + old_time = benchmark_legacy_get_po() + new_time = benchmark_optimized_get_po() + assert new_time < old_time * 0.5 # At least 50% improvement + + def test_end_to_end_workflow(self): + # Test complete write -> query -> delete cycle + # Verify no performance degradation in integration +``` + +### Rollback Plan + +#### Quick Rollback Strategy +1. **Environment variable toggle** - Switch back to legacy tables immediately +2. **Keep legacy tables** - Don't drop until performance is proven +3. **Monitoring alerts** - Automated rollback triggers based on error rates/latency + +#### Rollback Validation +```python +def rollback_to_legacy(): + # Set environment variable + os.environ['CASSANDRA_USE_LEGACY'] = 'true' + + # Restart services to pick up change + restart_cassandra_services() + + # Validate functionality + run_smoke_tests() +``` + +## Risks and Considerations + +### Performance Risks +- **Write latency increase** - 3x write operations per insert +- **Storage overhead** - 3x storage requirement +- **Batch write failures** - Need proper error handling + +### Operational Risks +- **Migration complexity** - Data migration for large datasets +- **Consistency challenges** - Ensuring all tables stay synchronized +- **Monitoring gaps** - Need new metrics for multi-table operations + +### Mitigation Strategies +1. **Gradual rollout** - Start with small collections +2. **Comprehensive monitoring** - Track all performance metrics +3. **Automated validation** - Continuous consistency checking +4. **Quick rollback capability** - Environment-based table selection + +## Success Criteria + +### Performance Improvements +- [ ] **Eliminate ALLOW FILTERING** - get_po and get_os queries run without filtering +- [ ] **Query latency reduction** - 50%+ improvement in query response times +- [ ] **Better load distribution** - No hot partitions, even load across cluster nodes +- [ ] **Scalable performance** - Query time proportional to result size, not total data + +### Functional Requirements +- [ ] **API compatibility** - All existing code continues to work unchanged +- [ ] **Data consistency** - All three tables remain synchronized +- [ ] **Zero data loss** - Migration preserves all existing triples +- [ ] **Backward compatibility** - Ability to rollback to legacy schema + +### Operational Requirements +- [ ] **Safe migration** - Blue-green deployment with rollback capability +- [ ] **Monitoring coverage** - Comprehensive metrics for multi-table operations +- [ ] **Test coverage** - All query patterns tested with performance benchmarks +- [ ] **Documentation** - Updated deployment and operational procedures + +## Timeline + +### Phase 1: Implementation +- [ ] Rewrite `cassandra_kg.py` with multi-table schema +- [ ] Implement batch write operations +- [ ] Add prepared statement optimization +- [ ] Update unit tests + +### Phase 2: Integration Testing +- [ ] Update integration tests +- [ ] Performance benchmarking +- [ ] Load testing with realistic data volumes +- [ ] Validation scripts for data consistency + +### Phase 3: Migration Planning +- [ ] Blue-green deployment scripts +- [ ] Data migration tools +- [ ] Monitoring dashboard updates +- [ ] Rollback procedures + +### Phase 4: Production Deployment +- [ ] Staged rollout to production +- [ ] Performance monitoring and validation +- [ ] Legacy table cleanup +- [ ] Documentation updates + +## Conclusion + +This multi-table denormalization strategy directly addresses the two critical performance bottlenecks: + +1. **Eliminates expensive ALLOW FILTERING** by providing optimal table structures for each query pattern +2. **Improves clustering effectiveness** through composite partition keys that distribute load properly + +The approach leverages Cassandra's strengths while maintaining complete API compatibility, ensuring existing code benefits automatically from the performance improvements. diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index 72871456..f5be4961 100644 --- a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -534,4 +534,203 @@ class TestCassandraQueryProcessor: assert len(result) == 2 assert result[0].o.value == 'object1' - assert result[1].o.value == 'object2' \ No newline at end of file + assert result[1].o.value == 'object2' + + +class TestCassandraQueryPerformanceOptimizations: + """Test cases for multi-table performance optimizations in query service""" + + @pytest.mark.asyncio + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') + async def test_get_po_query_optimization(self, mock_trustgraph): + """Test that get_po queries use optimized table (no ALLOW FILTERING)""" + from trustgraph.schema import TriplesQueryRequest, Value + + mock_tg_instance = MagicMock() + mock_trustgraph.return_value = mock_tg_instance + + mock_result = MagicMock() + mock_result.s = 'result_subject' + mock_tg_instance.get_po.return_value = [mock_result] + + processor = Processor(taskgroup=MagicMock()) + + # PO query pattern (predicate + object, find subjects) + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=None, + p=Value(value='test_predicate', is_uri=False), + o=Value(value='test_object', is_uri=False), + limit=50 + ) + + result = await processor.query_triples(query) + + # Verify get_po was called (should use optimized po_table) + mock_tg_instance.get_po.assert_called_once_with( + 'test_collection', 'test_predicate', 'test_object', limit=50 + ) + + assert len(result) == 1 + assert result[0].s.value == 'result_subject' + assert result[0].p.value == 'test_predicate' + assert result[0].o.value == 'test_object' + + @pytest.mark.asyncio + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') + async def test_get_os_query_optimization(self, mock_trustgraph): + """Test that get_os queries use optimized table (no ALLOW FILTERING)""" + from trustgraph.schema import TriplesQueryRequest, Value + + mock_tg_instance = MagicMock() + mock_trustgraph.return_value = mock_tg_instance + + mock_result = MagicMock() + mock_result.p = 'result_predicate' + mock_tg_instance.get_os.return_value = [mock_result] + + processor = Processor(taskgroup=MagicMock()) + + # OS query pattern (object + subject, find predicates) + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value='test_subject', is_uri=False), + p=None, + o=Value(value='test_object', is_uri=False), + limit=25 + ) + + result = await processor.query_triples(query) + + # Verify get_os was called (should use optimized subject_table with clustering) + mock_tg_instance.get_os.assert_called_once_with( + 'test_collection', 'test_object', 'test_subject', limit=25 + ) + + assert len(result) == 1 + assert result[0].s.value == 'test_subject' + assert result[0].p.value == 'result_predicate' + assert result[0].o.value == 'test_object' + + @pytest.mark.asyncio + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') + async def test_all_query_patterns_use_correct_tables(self, mock_trustgraph): + """Test that all query patterns route to their optimal tables""" + from trustgraph.schema import TriplesQueryRequest, Value + + mock_tg_instance = MagicMock() + mock_trustgraph.return_value = mock_tg_instance + + # Mock empty results for all queries + mock_tg_instance.get_all.return_value = [] + mock_tg_instance.get_s.return_value = [] + mock_tg_instance.get_p.return_value = [] + mock_tg_instance.get_o.return_value = [] + mock_tg_instance.get_sp.return_value = [] + mock_tg_instance.get_po.return_value = [] + mock_tg_instance.get_os.return_value = [] + mock_tg_instance.get_spo.return_value = [] + + processor = Processor(taskgroup=MagicMock()) + + # Test each query pattern + test_patterns = [ + # (s, p, o, expected_method) + (None, None, None, 'get_all'), # All triples + ('s1', None, None, 'get_s'), # Subject only + (None, 'p1', None, 'get_p'), # Predicate only + (None, None, 'o1', 'get_o'), # Object only + ('s1', 'p1', None, 'get_sp'), # Subject + Predicate + (None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION) + ('s1', None, 'o1', 'get_os'), # Object + Subject + ('s1', 'p1', 'o1', 'get_spo'), # All three + ] + + for s, p, o, expected_method in test_patterns: + # Reset mock call counts + mock_tg_instance.reset_mock() + + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value=s, is_uri=False) if s else None, + p=Value(value=p, is_uri=False) if p else None, + o=Value(value=o, is_uri=False) if o else None, + limit=10 + ) + + await processor.query_triples(query) + + # Verify the correct method was called + method = getattr(mock_tg_instance, expected_method) + assert method.called, f"Expected {expected_method} to be called for pattern s={s}, p={p}, o={o}" + + def test_legacy_vs_optimized_mode_configuration(self): + """Test that environment variable controls query optimization mode""" + taskgroup_mock = MagicMock() + + # Test optimized mode (default) + with patch.dict('os.environ', {}, clear=True): + processor = Processor(taskgroup=taskgroup_mock) + # Mode is determined in KnowledgeGraph initialization + + # Test legacy mode + with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}): + processor = Processor(taskgroup=taskgroup_mock) + # Mode is determined in KnowledgeGraph initialization + + # Test explicit optimized mode + with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}): + processor = Processor(taskgroup=taskgroup_mock) + # Mode is determined in KnowledgeGraph initialization + + @pytest.mark.asyncio + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') + async def test_performance_critical_po_query_no_filtering(self, mock_trustgraph): + """Test the performance-critical PO query that eliminates ALLOW FILTERING""" + from trustgraph.schema import TriplesQueryRequest, Value + + mock_tg_instance = MagicMock() + mock_trustgraph.return_value = mock_tg_instance + + # Mock multiple subjects for the same predicate-object pair + mock_results = [] + for i in range(5): + mock_result = MagicMock() + mock_result.s = f'subject_{i}' + mock_results.append(mock_result) + + mock_tg_instance.get_po.return_value = mock_results + + processor = Processor(taskgroup=MagicMock()) + + # This is the query pattern that was slow with ALLOW FILTERING + query = TriplesQueryRequest( + user='large_dataset_user', + collection='massive_collection', + s=None, + p=Value(value='http://www.w3.org/1999/02/22-rdf-syntax-ns#type', is_uri=True), + o=Value(value='http://example.com/Person', is_uri=True), + limit=1000 + ) + + result = await processor.query_triples(query) + + # Verify optimized get_po was used (no ALLOW FILTERING needed!) + mock_tg_instance.get_po.assert_called_once_with( + 'massive_collection', + 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type', + 'http://example.com/Person', + limit=1000 + ) + + # Verify all results were returned + assert len(result) == 5 + for i, triple in enumerate(result): + assert triple.s.value == f'subject_{i}' + assert triple.p.value == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' + assert triple.p.is_uri is True + assert triple.o.value == 'http://example.com/Person' + assert triple.o.is_uri is True \ No newline at end of file diff --git a/tests/unit/test_storage/test_triples_cassandra_storage.py b/tests/unit/test_storage/test_triples_cassandra_storage.py index a6a6a539..54ea1a95 100644 --- a/tests/unit/test_storage/test_triples_cassandra_storage.py +++ b/tests/unit/test_storage/test_triples_cassandra_storage.py @@ -415,4 +415,99 @@ class TestCassandraStorageProcessor: # Table should remain unchanged since self.table = table happens after try/except assert processor.table == ('old_user', 'old_collection') # TrustGraph should be set to None though - assert processor.tg is None \ No newline at end of file + assert processor.tg is None + + +class TestCassandraPerformanceOptimizations: + """Test cases for multi-table performance optimizations""" + + @pytest.mark.asyncio + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') + async def test_legacy_mode_uses_single_table(self, mock_trustgraph): + """Test that legacy mode still works with single table""" + taskgroup_mock = MagicMock() + mock_tg_instance = MagicMock() + mock_trustgraph.return_value = mock_tg_instance + + with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}): + processor = Processor(taskgroup=taskgroup_mock) + + mock_message = MagicMock() + mock_message.metadata.user = 'user1' + mock_message.metadata.collection = 'collection1' + mock_message.triples = [] + + await processor.store_triples(mock_message) + + # Verify KnowledgeGraph instance uses legacy mode + kg_instance = mock_trustgraph.return_value + assert kg_instance is not None + + @pytest.mark.asyncio + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') + async def test_optimized_mode_uses_multi_table(self, mock_trustgraph): + """Test that optimized mode uses multi-table schema""" + taskgroup_mock = MagicMock() + mock_tg_instance = MagicMock() + mock_trustgraph.return_value = mock_tg_instance + + with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}): + processor = Processor(taskgroup=taskgroup_mock) + + mock_message = MagicMock() + mock_message.metadata.user = 'user1' + mock_message.metadata.collection = 'collection1' + mock_message.triples = [] + + await processor.store_triples(mock_message) + + # Verify KnowledgeGraph instance is in optimized mode + kg_instance = mock_trustgraph.return_value + assert kg_instance is not None + + @pytest.mark.asyncio + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') + async def test_batch_write_consistency(self, mock_trustgraph): + """Test that all tables stay consistent during batch writes""" + taskgroup_mock = MagicMock() + mock_tg_instance = MagicMock() + mock_trustgraph.return_value = mock_tg_instance + + processor = Processor(taskgroup=taskgroup_mock) + + # Create test triple + triple = MagicMock() + triple.s.value = 'test_subject' + triple.p.value = 'test_predicate' + triple.o.value = 'test_object' + + mock_message = MagicMock() + mock_message.metadata.user = 'user1' + mock_message.metadata.collection = 'collection1' + mock_message.triples = [triple] + + await processor.store_triples(mock_message) + + # Verify insert was called for the triple (implementation details tested in KnowledgeGraph) + mock_tg_instance.insert.assert_called_once_with( + 'collection1', 'test_subject', 'test_predicate', 'test_object' + ) + + def test_environment_variable_controls_mode(self): + """Test that CASSANDRA_USE_LEGACY environment variable controls operation mode""" + taskgroup_mock = MagicMock() + + # Test legacy mode + with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}): + processor = Processor(taskgroup=taskgroup_mock) + # Mode is determined in KnowledgeGraph initialization + + # Test optimized mode + with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}): + processor = Processor(taskgroup=taskgroup_mock) + # Mode is determined in KnowledgeGraph initialization + + # Test default mode (optimized when env var not set) + with patch.dict('os.environ', {}, clear=True): + processor = Processor(taskgroup=taskgroup_mock) + # Mode is determined in KnowledgeGraph initialization \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py index 93e41230..a4cf12b4 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -1,11 +1,16 @@ from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider +from cassandra.query import BatchStatement, SimpleStatement from ssl import SSLContext, PROTOCOL_TLSv1_2 +import os +import logging # Global list to track clusters for cleanup _active_clusters = [] +logger = logging.getLogger(__name__) + class KnowledgeGraph: def __init__( @@ -17,9 +22,19 @@ class KnowledgeGraph: hosts = ["localhost"] self.keyspace = keyspace - self.table = "triples" # Fixed table name for unified schema self.username = username - + + # Multi-table schema design for optimal performance + self.use_legacy = os.getenv('CASSANDRA_USE_LEGACY', 'false').lower() == 'true' + + if self.use_legacy: + self.table = "triples" # Legacy single table + else: + # New optimized tables + self.subject_table = "triples_by_subject" + self.po_table = "triples_by_po" + self.object_table = "triples_by_object" + if username and password: ssl_context = SSLContext(PROTOCOL_TLSv1_2) auth_provider = PlainTextAuthProvider(username=username, password=password) @@ -27,12 +42,15 @@ class KnowledgeGraph: else: self.cluster = Cluster(hosts) self.session = self.cluster.connect() - + # Track this cluster globally _active_clusters.append(self.cluster) self.init() + if not self.use_legacy: + self.prepare_statements() + def clear(self): self.session.execute(f""" @@ -45,14 +63,21 @@ class KnowledgeGraph: self.session.execute(f""" create keyspace if not exists {self.keyspace} - with replication = {{ - 'class' : 'SimpleStrategy', - 'replication_factor' : 1 + with replication = {{ + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 }}; """); self.session.set_keyspace(self.keyspace) + if self.use_legacy: + self.init_legacy_schema() + else: + self.init_optimized_schema() + + def init_legacy_schema(self): + """Initialize legacy single-table schema for backward compatibility""" self.session.execute(f""" create table if not exists {self.table} ( collection text, @@ -78,67 +103,241 @@ class KnowledgeGraph: ON {self.table} (o); """); + def init_optimized_schema(self): + """Initialize optimized multi-table schema for performance""" + # Table 1: Subject-centric queries (get_s, get_sp, get_spo, get_os) + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.subject_table} ( + collection text, + s text, + p text, + o text, + PRIMARY KEY ((collection, s), p, o) + ); + """); + + # Table 2: Predicate-Object queries (get_p, get_po) - eliminates ALLOW FILTERING! + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.po_table} ( + collection text, + p text, + o text, + s text, + PRIMARY KEY ((collection, p), o, s) + ); + """); + + # Table 3: Object-centric queries (get_o) + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.object_table} ( + collection text, + o text, + s text, + p text, + PRIMARY KEY ((collection, o), s, p) + ); + """); + + logger.info("Optimized multi-table schema initialized") + + def prepare_statements(self): + """Prepare statements for optimal performance""" + # Insert statements for batch operations + self.insert_subject_stmt = self.session.prepare( + f"INSERT INTO {self.subject_table} (collection, s, p, o) VALUES (?, ?, ?, ?)" + ) + + self.insert_po_stmt = self.session.prepare( + f"INSERT INTO {self.po_table} (collection, p, o, s) VALUES (?, ?, ?, ?)" + ) + + self.insert_object_stmt = self.session.prepare( + f"INSERT INTO {self.object_table} (collection, o, s, p) VALUES (?, ?, ?, ?)" + ) + + # Query statements for optimized access + self.get_all_stmt = self.session.prepare( + f"SELECT s, p, o FROM {self.subject_table} WHERE collection = ? LIMIT ?" + ) + + self.get_s_stmt = self.session.prepare( + f"SELECT p, o FROM {self.subject_table} WHERE collection = ? AND s = ? LIMIT ?" + ) + + self.get_p_stmt = self.session.prepare( + f"SELECT s, o FROM {self.po_table} WHERE collection = ? AND p = ? LIMIT ?" + ) + + self.get_o_stmt = self.session.prepare( + f"SELECT s, p FROM {self.object_table} WHERE collection = ? AND o = ? LIMIT ?" + ) + + self.get_sp_stmt = self.session.prepare( + f"SELECT o FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? LIMIT ?" + ) + + # The critical optimization: get_po without ALLOW FILTERING! + self.get_po_stmt = self.session.prepare( + f"SELECT s FROM {self.po_table} WHERE collection = ? AND p = ? AND o = ? LIMIT ?" + ) + + self.get_os_stmt = self.session.prepare( + f"SELECT p FROM {self.subject_table} WHERE collection = ? AND s = ? AND o = ? LIMIT ?" + ) + + self.get_spo_stmt = self.session.prepare( + f"SELECT s as x FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?" + ) + + logger.info("Prepared statements initialized for optimal performance") + def insert(self, collection, s, p, o): - self.session.execute( - f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)", - (collection, s, p, o) - ) + if self.use_legacy: + self.session.execute( + f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)", + (collection, s, p, o) + ) + else: + # Batch write to all three tables for consistency + batch = BatchStatement() + + # Insert into subject table + batch.add(self.insert_subject_stmt, (collection, s, p, o)) + + # Insert into predicate-object table (column order: collection, p, o, s) + batch.add(self.insert_po_stmt, (collection, p, o, s)) + + # Insert into object table (column order: collection, o, s, p) + batch.add(self.insert_object_stmt, (collection, o, s, p)) + + self.session.execute(batch) def get_all(self, collection, limit=50): - return self.session.execute( - f"select s, p, o from {self.table} where collection = %s limit {limit}", - (collection,) - ) + if self.use_legacy: + return self.session.execute( + f"select s, p, o from {self.table} where collection = %s limit {limit}", + (collection,) + ) + else: + # Use subject table for get_all queries + return self.session.execute( + self.get_all_stmt, + (collection, limit) + ) def get_s(self, collection, s, limit=10): - return self.session.execute( - f"select p, o from {self.table} where collection = %s and s = %s limit {limit}", - (collection, s) - ) + if self.use_legacy: + return self.session.execute( + f"select p, o from {self.table} where collection = %s and s = %s limit {limit}", + (collection, s) + ) + else: + # Optimized: Direct partition access with (collection, s) + return self.session.execute( + self.get_s_stmt, + (collection, s, limit) + ) def get_p(self, collection, p, limit=10): - return self.session.execute( - f"select s, o from {self.table} where collection = %s and p = %s limit {limit}", - (collection, p) - ) + if self.use_legacy: + return self.session.execute( + f"select s, o from {self.table} where collection = %s and p = %s limit {limit}", + (collection, p) + ) + else: + # Optimized: Use po_table for direct partition access + return self.session.execute( + self.get_p_stmt, + (collection, p, limit) + ) def get_o(self, collection, o, limit=10): - return self.session.execute( - f"select s, p from {self.table} where collection = %s and o = %s limit {limit}", - (collection, o) - ) + if self.use_legacy: + return self.session.execute( + f"select s, p from {self.table} where collection = %s and o = %s limit {limit}", + (collection, o) + ) + else: + # Optimized: Use object_table for direct partition access + return self.session.execute( + self.get_o_stmt, + (collection, o, limit) + ) def get_sp(self, collection, s, p, limit=10): - return self.session.execute( - f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}", - (collection, s, p) - ) + if self.use_legacy: + return self.session.execute( + f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}", + (collection, s, p) + ) + else: + # Optimized: Use subject_table with clustering key access + return self.session.execute( + self.get_sp_stmt, + (collection, s, p, limit) + ) def get_po(self, collection, p, o, limit=10): - return self.session.execute( - f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering", - (collection, p, o) - ) + if self.use_legacy: + return self.session.execute( + f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering", + (collection, p, o) + ) + else: + # CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING! + return self.session.execute( + self.get_po_stmt, + (collection, p, o, limit) + ) def get_os(self, collection, o, s, limit=10): - return self.session.execute( - f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering", - (collection, o, s) - ) + if self.use_legacy: + return self.session.execute( + f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering", + (collection, o, s) + ) + else: + # Optimized: Use subject_table with clustering access (no more ALLOW FILTERING) + return self.session.execute( + self.get_os_stmt, + (collection, s, o, limit) + ) def get_spo(self, collection, s, p, o, limit=10): - return self.session.execute( - f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""", - (collection, s, p, o) - ) + if self.use_legacy: + return self.session.execute( + f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""", + (collection, s, p, o) + ) + else: + # Optimized: Use subject_table for exact key lookup + return self.session.execute( + self.get_spo_stmt, + (collection, s, p, o, limit) + ) def delete_collection(self, collection): """Delete all triples for a specific collection""" - self.session.execute( - f"delete from {self.table} where collection = %s", - (collection,) - ) + if self.use_legacy: + self.session.execute( + f"delete from {self.table} where collection = %s", + (collection,) + ) + else: + # Delete from all three tables + self.session.execute( + f"delete from {self.subject_table} where collection = %s", + (collection,) + ) + self.session.execute( + f"delete from {self.po_table} where collection = %s", + (collection,) + ) + self.session.execute( + f"delete from {self.object_table} where collection = %s", + (collection,) + ) def close(self): """Close the Cassandra session and cluster connections properly"""