Cassandra performance enhancement (#521)

* Tech spec

* Tech spec complete

* Cassandra multi-table for performance
This commit is contained in:
cybermaggedon 2025-09-18 19:52:05 +01:00 committed by GitHub
parent 13ff7d765d
commit d378db9370
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 1123 additions and 48 deletions

View file

@ -0,0 +1,582 @@
# Tech Spec: Cassandra Knowledge Base Performance Refactor
**Status:** Draft
**Author:** Assistant
**Date:** 2025-09-18
## Overview
This specification addresses performance issues in the TrustGraph Cassandra knowledge base implementation and proposes optimizations for RDF triple storage and querying.
## Current Implementation
### Schema Design
The current implementation uses a single table design in `trustgraph-flow/trustgraph/direct/cassandra_kg.py`:
```sql
CREATE TABLE triples (
collection text,
s text,
p text,
o text,
PRIMARY KEY (collection, s, p, o)
);
```
**Secondary Indexes:**
- `triples_s` ON `s` (subject)
- `triples_p` ON `p` (predicate)
- `triples_o` ON `o` (object)
### Query Patterns
The current implementation supports 8 distinct query patterns:
1. **get_all(collection, limit=50)** - Retrieve all triples for a collection
```sql
SELECT s, p, o FROM triples WHERE collection = ? LIMIT 50
```
2. **get_s(collection, s, limit=10)** - Query by subject
```sql
SELECT p, o FROM triples WHERE collection = ? AND s = ? LIMIT 10
```
3. **get_p(collection, p, limit=10)** - Query by predicate
```sql
SELECT s, o FROM triples WHERE collection = ? AND p = ? LIMIT 10
```
4. **get_o(collection, o, limit=10)** - Query by object
```sql
SELECT s, p FROM triples WHERE collection = ? AND o = ? LIMIT 10
```
5. **get_sp(collection, s, p, limit=10)** - Query by subject + predicate
```sql
SELECT o FROM triples WHERE collection = ? AND s = ? AND p = ? LIMIT 10
```
6. **get_po(collection, p, o, limit=10)** - Query by predicate + object ⚠️
```sql
SELECT s FROM triples WHERE collection = ? AND p = ? AND o = ? LIMIT 10 ALLOW FILTERING
```
7. **get_os(collection, o, s, limit=10)** - Query by object + subject ⚠️
```sql
SELECT p FROM triples WHERE collection = ? AND o = ? AND s = ? LIMIT 10 ALLOW FILTERING
```
8. **get_spo(collection, s, p, o, limit=10)** - Exact triple match
```sql
SELECT s as x FROM triples WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT 10
```
### Current Architecture
**File: `trustgraph-flow/trustgraph/direct/cassandra_kg.py`**
- Single `KnowledgeGraph` class handling all operations
- Connection pooling through global `_active_clusters` list
- Fixed table name: `"triples"`
- Keyspace per user model
- SimpleStrategy replication with factor 1
**Integration Points:**
- **Write Path:** `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
- **Query Path:** `trustgraph-flow/trustgraph/query/triples/cassandra/service.py`
- **Knowledge Store:** `trustgraph-flow/trustgraph/tables/knowledge.py`
## Performance Issues Identified
### Schema-Level Issues
1. **Inefficient Primary Key Design**
- Current: `PRIMARY KEY (collection, s, p, o)`
- Results in poor clustering for common access patterns
- Forces expensive secondary index usage
2. **Secondary Index Overuse** ⚠️
- Three secondary indexes on high-cardinality columns (s, p, o)
- Secondary indexes in Cassandra are expensive and don't scale well
- Queries 6 & 7 require `ALLOW FILTERING` indicating poor data modeling
3. **Hot Partition Risk**
- Single partition key `collection` can create hot partitions
- Large collections will concentrate on single nodes
- No distribution strategy for load balancing
### Query-Level Issues
1. **ALLOW FILTERING Usage** ⚠️
- Two query types (get_po, get_os) require `ALLOW FILTERING`
- These queries scan multiple partitions and are extremely expensive
- Performance degrades linearly with data size
2. **Inefficient Access Patterns**
- No optimization for common RDF query patterns
- Missing compound indexes for frequent query combinations
- No consideration for graph traversal patterns
3. **Lack of Query Optimization**
- No prepared statements caching
- No query hints or optimization strategies
- No consideration for pagination beyond simple LIMIT
## Problem Statement
The current Cassandra knowledge base implementation has two critical performance bottlenecks:
### 1. Inefficient get_po Query Performance
The `get_po(collection, p, o)` query is extremely inefficient due to requiring `ALLOW FILTERING`:
```sql
SELECT s FROM triples WHERE collection = ? AND p = ? AND o = ? LIMIT 10 ALLOW FILTERING
```
**Why this is problematic:**
- `ALLOW FILTERING` forces Cassandra to scan all partitions within the collection
- Performance degrades linearly with data size
- This is a common RDF query pattern (finding subjects that have a specific predicate-object relationship)
- Creates significant load on the cluster as data grows
### 2. Poor Clustering Strategy
The current primary key `PRIMARY KEY (collection, s, p, o)` provides minimal clustering benefits:
**Issues with current clustering:**
- `collection` as partition key doesn't distribute data effectively
- Most collections contain diverse data making clustering ineffective
- No consideration for common access patterns in RDF queries
- Large collections create hot partitions on single nodes
- Clustering columns (s, p, o) don't optimize for typical graph traversal patterns
**Impact:**
- Queries don't benefit from data locality
- Poor cache utilization
- Uneven load distribution across cluster nodes
- Scalability bottlenecks as collections grow
## Proposed Solution: Multi-Table Denormalization Strategy
### Overview
Replace the single `triples` table with three purpose-built tables, each optimized for specific query patterns. This eliminates the need for secondary indexes and ALLOW FILTERING while providing optimal performance for all query types.
### New Schema Design
**Table 1: Subject-Centric Queries**
```sql
CREATE TABLE triples_by_subject (
collection text,
s text,
p text,
o text,
PRIMARY KEY ((collection, s), p, o)
);
```
- **Optimizes:** get_s, get_sp, get_spo, get_os
- **Partition Key:** (collection, s) - Better distribution than collection alone
- **Clustering:** (p, o) - Enables efficient predicate/object lookups for a subject
**Table 2: Predicate-Object Queries**
```sql
CREATE TABLE triples_by_po (
collection text,
p text,
o text,
s text,
PRIMARY KEY ((collection, p), o, s)
);
```
- **Optimizes:** get_p, get_po (eliminates ALLOW FILTERING!)
- **Partition Key:** (collection, p) - Direct access by predicate
- **Clustering:** (o, s) - Efficient object-subject traversal
**Table 3: Object-Centric Queries**
```sql
CREATE TABLE triples_by_object (
collection text,
o text,
s text,
p text,
PRIMARY KEY ((collection, o), s, p)
);
```
- **Optimizes:** get_o, get_os
- **Partition Key:** (collection, o) - Direct access by object
- **Clustering:** (s, p) - Efficient subject-predicate traversal
### Query Mapping
| Original Query | Target Table | Performance Improvement |
|----------------|-------------|------------------------|
| get_all(collection) | triples_by_subject | Token-based pagination |
| get_s(collection, s) | triples_by_subject | Direct partition access |
| get_p(collection, p) | triples_by_po | Direct partition access |
| get_o(collection, o) | triples_by_object | Direct partition access |
| get_sp(collection, s, p) | triples_by_subject | Partition + clustering |
| get_po(collection, p, o) | triples_by_po | **No more ALLOW FILTERING!** |
| get_os(collection, o, s) | triples_by_subject | Partition + clustering |
| get_spo(collection, s, p, o) | triples_by_subject | Exact key lookup |
### Benefits
1. **Eliminates ALLOW FILTERING** - Every query has an optimal access path
2. **No Secondary Indexes** - Each table IS the index for its query pattern
3. **Better Data Distribution** - Composite partition keys spread load effectively
4. **Predictable Performance** - Query time proportional to result size, not total data
5. **Leverages Cassandra Strengths** - Designed for Cassandra's architecture
## Implementation Plan
### Files Requiring Changes
#### Primary Implementation File
**`trustgraph-flow/trustgraph/direct/cassandra_kg.py`** - Complete rewrite required
**Current Methods to Refactor:**
```python
# Schema initialization
def init(self) -> None # Replace single table with three tables
# Insert operations
def insert(self, collection, s, p, o) -> None # Write to all three tables
# Query operations (API unchanged, implementation optimized)
def get_all(self, collection, limit=50) # Use triples_by_subject
def get_s(self, collection, s, limit=10) # Use triples_by_subject
def get_p(self, collection, p, limit=10) # Use triples_by_po
def get_o(self, collection, o, limit=10) # Use triples_by_object
def get_sp(self, collection, s, p, limit=10) # Use triples_by_subject
def get_po(self, collection, p, o, limit=10) # Use triples_by_po (NO ALLOW FILTERING!)
def get_os(self, collection, o, s, limit=10) # Use triples_by_subject
def get_spo(self, collection, s, p, o, limit=10) # Use triples_by_subject
# Collection management
def delete_collection(self, collection) -> None # Delete from all three tables
```
#### Integration Files (No Logic Changes Required)
**`trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`**
- No changes needed - uses existing KnowledgeGraph API
- Benefits automatically from performance improvements
**`trustgraph-flow/trustgraph/query/triples/cassandra/service.py`**
- No changes needed - uses existing KnowledgeGraph API
- Benefits automatically from performance improvements
### Test Files Requiring Updates
#### Unit Tests
**`tests/unit/test_storage/test_triples_cassandra_storage.py`**
- Update test expectations for schema changes
- Add tests for multi-table consistency
- Verify no ALLOW FILTERING in query plans
**`tests/unit/test_query/test_triples_cassandra_query.py`**
- Update performance assertions
- Test all 8 query patterns against new tables
- Verify query routing to correct tables
#### Integration Tests
**`tests/integration/test_cassandra_integration.py`**
- End-to-end testing with new schema
- Performance benchmarking comparisons
- Data consistency verification across tables
**`tests/unit/test_storage/test_cassandra_config_integration.py`**
- Update schema validation tests
- Test migration scenarios
### Implementation Strategy
#### Phase 1: Schema and Core Methods
1. **Rewrite `init()` method** - Create three tables instead of one
2. **Rewrite `insert()` method** - Batch writes to all three tables
3. **Implement prepared statements** - For optimal performance
4. **Add table routing logic** - Direct queries to optimal tables
#### Phase 2: Query Method Optimization
1. **Rewrite each get_* method** to use optimal table
2. **Remove all ALLOW FILTERING** usage
3. **Implement efficient clustering key usage**
4. **Add query performance logging**
#### Phase 3: Collection Management
1. **Update `delete_collection()`** - Remove from all three tables
2. **Add consistency verification** - Ensure all tables stay in sync
3. **Implement batch operations** - For atomic multi-table operations
### Key Implementation Details
#### Batch Write Strategy
```python
def insert(self, collection, s, p, o):
batch = BatchStatement()
# Insert into all three tables
batch.add(SimpleStatement(
"INSERT INTO triples_by_subject (collection, s, p, o) VALUES (?, ?, ?, ?)"
), (collection, s, p, o))
batch.add(SimpleStatement(
"INSERT INTO triples_by_po (collection, p, o, s) VALUES (?, ?, ?, ?)"
), (collection, p, o, s))
batch.add(SimpleStatement(
"INSERT INTO triples_by_object (collection, o, s, p) VALUES (?, ?, ?, ?)"
), (collection, o, s, p))
self.session.execute(batch)
```
#### Query Routing Logic
```python
def get_po(self, collection, p, o, limit=10):
# Route to triples_by_po table - NO ALLOW FILTERING!
return self.session.execute(
"SELECT s FROM triples_by_po WHERE collection = ? AND p = ? AND o = ? LIMIT ?",
(collection, p, o, limit)
)
```
#### Prepared Statement Optimization
```python
def prepare_statements(self):
# Cache prepared statements for better performance
self.insert_subject_stmt = self.session.prepare(
"INSERT INTO triples_by_subject (collection, s, p, o) VALUES (?, ?, ?, ?)"
)
self.insert_po_stmt = self.session.prepare(
"INSERT INTO triples_by_po (collection, p, o, s) VALUES (?, ?, ?, ?)"
)
# ... etc for all tables and queries
```
## Migration Strategy
### Data Migration Approach
#### Option 1: Blue-Green Deployment (Recommended)
1. **Deploy new schema alongside existing** - Use different table names temporarily
2. **Dual-write period** - Write to both old and new schemas during transition
3. **Background migration** - Copy existing data to new tables
4. **Switch reads** - Route queries to new tables once data is migrated
5. **Drop old tables** - After verification period
#### Option 2: In-Place Migration
1. **Schema addition** - Create new tables in existing keyspace
2. **Data migration script** - Batch copy from old table to new tables
3. **Application update** - Deploy new code after migration completes
4. **Old table cleanup** - Remove old table and indexes
### Backward Compatibility
#### Deployment Strategy
```python
# Environment variable to control table usage during migration
USE_LEGACY_TABLES = os.getenv('CASSANDRA_USE_LEGACY', 'false').lower() == 'true'
class KnowledgeGraph:
def __init__(self, ...):
if USE_LEGACY_TABLES:
self.init_legacy_schema()
else:
self.init_optimized_schema()
```
#### Migration Script
```python
def migrate_data():
# Read from old table
old_triples = session.execute("SELECT collection, s, p, o FROM triples")
# Batch write to new tables
for batch in batched(old_triples, 100):
batch_stmt = BatchStatement()
for row in batch:
# Add to all three new tables
batch_stmt.add(insert_subject_stmt, row)
batch_stmt.add(insert_po_stmt, (row.collection, row.p, row.o, row.s))
batch_stmt.add(insert_object_stmt, (row.collection, row.o, row.s, row.p))
session.execute(batch_stmt)
```
### Validation Strategy
#### Data Consistency Checks
```python
def validate_migration():
# Count total records in old vs new tables
old_count = session.execute("SELECT COUNT(*) FROM triples WHERE collection = ?", (collection,))
new_count = session.execute("SELECT COUNT(*) FROM triples_by_subject WHERE collection = ?", (collection,))
assert old_count == new_count, f"Record count mismatch: {old_count} vs {new_count}"
# Spot check random samples
sample_queries = generate_test_queries()
for query in sample_queries:
old_result = execute_legacy_query(query)
new_result = execute_optimized_query(query)
assert old_result == new_result, f"Query results differ for {query}"
```
## Testing Strategy
### Performance Testing
#### Benchmark Scenarios
1. **Query Performance Comparison**
- Before/after performance metrics for all 8 query types
- Focus on get_po performance improvement (eliminate ALLOW FILTERING)
- Measure query latency under various data sizes
2. **Load Testing**
- Concurrent query execution
- Write throughput with batch operations
- Memory and CPU utilization
3. **Scalability Testing**
- Performance with increasing collection sizes
- Multi-collection query distribution
- Cluster node utilization
#### Test Data Sets
- **Small:** 10K triples per collection
- **Medium:** 100K triples per collection
- **Large:** 1M+ triples per collection
- **Multiple collections:** Test partition distribution
### Functional Testing
#### Unit Test Updates
```python
# Example test structure for new implementation
class TestCassandraKGPerformance:
def test_get_po_no_allow_filtering(self):
# Verify get_po queries don't use ALLOW FILTERING
with patch('cassandra.cluster.Session.execute') as mock_execute:
kg.get_po('test_collection', 'predicate', 'object')
executed_query = mock_execute.call_args[0][0]
assert 'ALLOW FILTERING' not in executed_query
def test_multi_table_consistency(self):
# Verify all tables stay in sync
kg.insert('test', 's1', 'p1', 'o1')
# Check all tables contain the triple
assert_triple_exists('triples_by_subject', 'test', 's1', 'p1', 'o1')
assert_triple_exists('triples_by_po', 'test', 'p1', 'o1', 's1')
assert_triple_exists('triples_by_object', 'test', 'o1', 's1', 'p1')
```
#### Integration Test Updates
```python
class TestCassandraIntegration:
def test_query_performance_regression(self):
# Ensure new implementation is faster than old
old_time = benchmark_legacy_get_po()
new_time = benchmark_optimized_get_po()
assert new_time < old_time * 0.5 # At least 50% improvement
def test_end_to_end_workflow(self):
# Test complete write -> query -> delete cycle
# Verify no performance degradation in integration
```
### Rollback Plan
#### Quick Rollback Strategy
1. **Environment variable toggle** - Switch back to legacy tables immediately
2. **Keep legacy tables** - Don't drop until performance is proven
3. **Monitoring alerts** - Automated rollback triggers based on error rates/latency
#### Rollback Validation
```python
def rollback_to_legacy():
# Set environment variable
os.environ['CASSANDRA_USE_LEGACY'] = 'true'
# Restart services to pick up change
restart_cassandra_services()
# Validate functionality
run_smoke_tests()
```
## Risks and Considerations
### Performance Risks
- **Write latency increase** - 3x write operations per insert
- **Storage overhead** - 3x storage requirement
- **Batch write failures** - Need proper error handling
### Operational Risks
- **Migration complexity** - Data migration for large datasets
- **Consistency challenges** - Ensuring all tables stay synchronized
- **Monitoring gaps** - Need new metrics for multi-table operations
### Mitigation Strategies
1. **Gradual rollout** - Start with small collections
2. **Comprehensive monitoring** - Track all performance metrics
3. **Automated validation** - Continuous consistency checking
4. **Quick rollback capability** - Environment-based table selection
## Success Criteria
### Performance Improvements
- [ ] **Eliminate ALLOW FILTERING** - get_po and get_os queries run without filtering
- [ ] **Query latency reduction** - 50%+ improvement in query response times
- [ ] **Better load distribution** - No hot partitions, even load across cluster nodes
- [ ] **Scalable performance** - Query time proportional to result size, not total data
### Functional Requirements
- [ ] **API compatibility** - All existing code continues to work unchanged
- [ ] **Data consistency** - All three tables remain synchronized
- [ ] **Zero data loss** - Migration preserves all existing triples
- [ ] **Backward compatibility** - Ability to rollback to legacy schema
### Operational Requirements
- [ ] **Safe migration** - Blue-green deployment with rollback capability
- [ ] **Monitoring coverage** - Comprehensive metrics for multi-table operations
- [ ] **Test coverage** - All query patterns tested with performance benchmarks
- [ ] **Documentation** - Updated deployment and operational procedures
## Timeline
### Phase 1: Implementation
- [ ] Rewrite `cassandra_kg.py` with multi-table schema
- [ ] Implement batch write operations
- [ ] Add prepared statement optimization
- [ ] Update unit tests
### Phase 2: Integration Testing
- [ ] Update integration tests
- [ ] Performance benchmarking
- [ ] Load testing with realistic data volumes
- [ ] Validation scripts for data consistency
### Phase 3: Migration Planning
- [ ] Blue-green deployment scripts
- [ ] Data migration tools
- [ ] Monitoring dashboard updates
- [ ] Rollback procedures
### Phase 4: Production Deployment
- [ ] Staged rollout to production
- [ ] Performance monitoring and validation
- [ ] Legacy table cleanup
- [ ] Documentation updates
## Conclusion
This multi-table denormalization strategy directly addresses the two critical performance bottlenecks:
1. **Eliminates expensive ALLOW FILTERING** by providing optimal table structures for each query pattern
2. **Improves clustering effectiveness** through composite partition keys that distribute load properly
The approach leverages Cassandra's strengths while maintaining complete API compatibility, ensuring existing code benefits automatically from the performance improvements.

View file

@ -534,4 +534,203 @@ class TestCassandraQueryProcessor:
assert len(result) == 2
assert result[0].o.value == 'object1'
assert result[1].o.value == 'object2'
assert result[1].o.value == 'object2'
class TestCassandraQueryPerformanceOptimizations:
"""Test cases for multi-table performance optimizations in query service"""
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_get_po_query_optimization(self, mock_trustgraph):
"""Test that get_po queries use optimized table (no ALLOW FILTERING)"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'result_subject'
mock_tg_instance.get_po.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
# PO query pattern (predicate + object, find subjects)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
limit=50
)
result = await processor.query_triples(query)
# Verify get_po was called (should use optimized po_table)
mock_tg_instance.get_po.assert_called_once_with(
'test_collection', 'test_predicate', 'test_object', limit=50
)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_get_os_query_optimization(self, mock_trustgraph):
"""Test that get_os queries use optimized table (no ALLOW FILTERING)"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.p = 'result_predicate'
mock_tg_instance.get_os.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
# OS query pattern (object + subject, find predicates)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=None,
o=Value(value='test_object', is_uri=False),
limit=25
)
result = await processor.query_triples(query)
# Verify get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.get_os.assert_called_once_with(
'test_collection', 'test_object', 'test_subject', limit=25
)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_all_query_patterns_use_correct_tables(self, mock_trustgraph):
"""Test that all query patterns route to their optimal tables"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
# Mock empty results for all queries
mock_tg_instance.get_all.return_value = []
mock_tg_instance.get_s.return_value = []
mock_tg_instance.get_p.return_value = []
mock_tg_instance.get_o.return_value = []
mock_tg_instance.get_sp.return_value = []
mock_tg_instance.get_po.return_value = []
mock_tg_instance.get_os.return_value = []
mock_tg_instance.get_spo.return_value = []
processor = Processor(taskgroup=MagicMock())
# Test each query pattern
test_patterns = [
# (s, p, o, expected_method)
(None, None, None, 'get_all'), # All triples
('s1', None, None, 'get_s'), # Subject only
(None, 'p1', None, 'get_p'), # Predicate only
(None, None, 'o1', 'get_o'), # Object only
('s1', 'p1', None, 'get_sp'), # Subject + Predicate
(None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
('s1', None, 'o1', 'get_os'), # Object + Subject
('s1', 'p1', 'o1', 'get_spo'), # All three
]
for s, p, o, expected_method in test_patterns:
# Reset mock call counts
mock_tg_instance.reset_mock()
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value=s, is_uri=False) if s else None,
p=Value(value=p, is_uri=False) if p else None,
o=Value(value=o, is_uri=False) if o else None,
limit=10
)
await processor.query_triples(query)
# Verify the correct method was called
method = getattr(mock_tg_instance, expected_method)
assert method.called, f"Expected {expected_method} to be called for pattern s={s}, p={p}, o={o}"
def test_legacy_vs_optimized_mode_configuration(self):
"""Test that environment variable controls query optimization mode"""
taskgroup_mock = MagicMock()
# Test optimized mode (default)
with patch.dict('os.environ', {}, clear=True):
processor = Processor(taskgroup=taskgroup_mock)
# Mode is determined in KnowledgeGraph initialization
# Test legacy mode
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
processor = Processor(taskgroup=taskgroup_mock)
# Mode is determined in KnowledgeGraph initialization
# Test explicit optimized mode
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
processor = Processor(taskgroup=taskgroup_mock)
# Mode is determined in KnowledgeGraph initialization
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_performance_critical_po_query_no_filtering(self, mock_trustgraph):
"""Test the performance-critical PO query that eliminates ALLOW FILTERING"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
# Mock multiple subjects for the same predicate-object pair
mock_results = []
for i in range(5):
mock_result = MagicMock()
mock_result.s = f'subject_{i}'
mock_results.append(mock_result)
mock_tg_instance.get_po.return_value = mock_results
processor = Processor(taskgroup=MagicMock())
# This is the query pattern that was slow with ALLOW FILTERING
query = TriplesQueryRequest(
user='large_dataset_user',
collection='massive_collection',
s=None,
p=Value(value='http://www.w3.org/1999/02/22-rdf-syntax-ns#type', is_uri=True),
o=Value(value='http://example.com/Person', is_uri=True),
limit=1000
)
result = await processor.query_triples(query)
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
mock_tg_instance.get_po.assert_called_once_with(
'massive_collection',
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
'http://example.com/Person',
limit=1000
)
# Verify all results were returned
assert len(result) == 5
for i, triple in enumerate(result):
assert triple.s.value == f'subject_{i}'
assert triple.p.value == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
assert triple.p.is_uri is True
assert triple.o.value == 'http://example.com/Person'
assert triple.o.is_uri is True

View file

@ -415,4 +415,99 @@ class TestCassandraStorageProcessor:
# Table should remain unchanged since self.table = table happens after try/except
assert processor.table == ('old_user', 'old_collection')
# TrustGraph should be set to None though
assert processor.tg is None
assert processor.tg is None
class TestCassandraPerformanceOptimizations:
"""Test cases for multi-table performance optimizations"""
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_legacy_mode_uses_single_table(self, mock_trustgraph):
"""Test that legacy mode still works with single table"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
processor = Processor(taskgroup=taskgroup_mock)
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify KnowledgeGraph instance uses legacy mode
kg_instance = mock_trustgraph.return_value
assert kg_instance is not None
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_optimized_mode_uses_multi_table(self, mock_trustgraph):
"""Test that optimized mode uses multi-table schema"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
processor = Processor(taskgroup=taskgroup_mock)
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify KnowledgeGraph instance is in optimized mode
kg_instance = mock_trustgraph.return_value
assert kg_instance is not None
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_batch_write_consistency(self, mock_trustgraph):
"""Test that all tables stay consistent during batch writes"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create test triple
triple = MagicMock()
triple.s.value = 'test_subject'
triple.p.value = 'test_predicate'
triple.o.value = 'test_object'
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = [triple]
await processor.store_triples(mock_message)
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
mock_tg_instance.insert.assert_called_once_with(
'collection1', 'test_subject', 'test_predicate', 'test_object'
)
def test_environment_variable_controls_mode(self):
"""Test that CASSANDRA_USE_LEGACY environment variable controls operation mode"""
taskgroup_mock = MagicMock()
# Test legacy mode
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
processor = Processor(taskgroup=taskgroup_mock)
# Mode is determined in KnowledgeGraph initialization
# Test optimized mode
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
processor = Processor(taskgroup=taskgroup_mock)
# Mode is determined in KnowledgeGraph initialization
# Test default mode (optimized when env var not set)
with patch.dict('os.environ', {}, clear=True):
processor = Processor(taskgroup=taskgroup_mock)
# Mode is determined in KnowledgeGraph initialization

View file

@ -1,11 +1,16 @@
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement
from ssl import SSLContext, PROTOCOL_TLSv1_2
import os
import logging
# Global list to track clusters for cleanup
_active_clusters = []
logger = logging.getLogger(__name__)
class KnowledgeGraph:
def __init__(
@ -17,9 +22,19 @@ class KnowledgeGraph:
hosts = ["localhost"]
self.keyspace = keyspace
self.table = "triples" # Fixed table name for unified schema
self.username = username
# Multi-table schema design for optimal performance
self.use_legacy = os.getenv('CASSANDRA_USE_LEGACY', 'false').lower() == 'true'
if self.use_legacy:
self.table = "triples" # Legacy single table
else:
# New optimized tables
self.subject_table = "triples_by_subject"
self.po_table = "triples_by_po"
self.object_table = "triples_by_object"
if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(username=username, password=password)
@ -27,12 +42,15 @@ class KnowledgeGraph:
else:
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
# Track this cluster globally
_active_clusters.append(self.cluster)
self.init()
if not self.use_legacy:
self.prepare_statements()
def clear(self):
self.session.execute(f"""
@ -45,14 +63,21 @@ class KnowledgeGraph:
self.session.execute(f"""
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
}};
""");
self.session.set_keyspace(self.keyspace)
if self.use_legacy:
self.init_legacy_schema()
else:
self.init_optimized_schema()
def init_legacy_schema(self):
"""Initialize legacy single-table schema for backward compatibility"""
self.session.execute(f"""
create table if not exists {self.table} (
collection text,
@ -78,67 +103,241 @@ class KnowledgeGraph:
ON {self.table} (o);
""");
def init_optimized_schema(self):
"""Initialize optimized multi-table schema for performance"""
# Table 1: Subject-centric queries (get_s, get_sp, get_spo, get_os)
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.subject_table} (
collection text,
s text,
p text,
o text,
PRIMARY KEY ((collection, s), p, o)
);
""");
# Table 2: Predicate-Object queries (get_p, get_po) - eliminates ALLOW FILTERING!
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.po_table} (
collection text,
p text,
o text,
s text,
PRIMARY KEY ((collection, p), o, s)
);
""");
# Table 3: Object-centric queries (get_o)
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.object_table} (
collection text,
o text,
s text,
p text,
PRIMARY KEY ((collection, o), s, p)
);
""");
logger.info("Optimized multi-table schema initialized")
def prepare_statements(self):
"""Prepare statements for optimal performance"""
# Insert statements for batch operations
self.insert_subject_stmt = self.session.prepare(
f"INSERT INTO {self.subject_table} (collection, s, p, o) VALUES (?, ?, ?, ?)"
)
self.insert_po_stmt = self.session.prepare(
f"INSERT INTO {self.po_table} (collection, p, o, s) VALUES (?, ?, ?, ?)"
)
self.insert_object_stmt = self.session.prepare(
f"INSERT INTO {self.object_table} (collection, o, s, p) VALUES (?, ?, ?, ?)"
)
# Query statements for optimized access
self.get_all_stmt = self.session.prepare(
f"SELECT s, p, o FROM {self.subject_table} WHERE collection = ? LIMIT ?"
)
self.get_s_stmt = self.session.prepare(
f"SELECT p, o FROM {self.subject_table} WHERE collection = ? AND s = ? LIMIT ?"
)
self.get_p_stmt = self.session.prepare(
f"SELECT s, o FROM {self.po_table} WHERE collection = ? AND p = ? LIMIT ?"
)
self.get_o_stmt = self.session.prepare(
f"SELECT s, p FROM {self.object_table} WHERE collection = ? AND o = ? LIMIT ?"
)
self.get_sp_stmt = self.session.prepare(
f"SELECT o FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? LIMIT ?"
)
# The critical optimization: get_po without ALLOW FILTERING!
self.get_po_stmt = self.session.prepare(
f"SELECT s FROM {self.po_table} WHERE collection = ? AND p = ? AND o = ? LIMIT ?"
)
self.get_os_stmt = self.session.prepare(
f"SELECT p FROM {self.subject_table} WHERE collection = ? AND s = ? AND o = ? LIMIT ?"
)
self.get_spo_stmt = self.session.prepare(
f"SELECT s as x FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"
)
logger.info("Prepared statements initialized for optimal performance")
def insert(self, collection, s, p, o):
self.session.execute(
f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)",
(collection, s, p, o)
)
if self.use_legacy:
self.session.execute(
f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)",
(collection, s, p, o)
)
else:
# Batch write to all three tables for consistency
batch = BatchStatement()
# Insert into subject table
batch.add(self.insert_subject_stmt, (collection, s, p, o))
# Insert into predicate-object table (column order: collection, p, o, s)
batch.add(self.insert_po_stmt, (collection, p, o, s))
# Insert into object table (column order: collection, o, s, p)
batch.add(self.insert_object_stmt, (collection, o, s, p))
self.session.execute(batch)
def get_all(self, collection, limit=50):
return self.session.execute(
f"select s, p, o from {self.table} where collection = %s limit {limit}",
(collection,)
)
if self.use_legacy:
return self.session.execute(
f"select s, p, o from {self.table} where collection = %s limit {limit}",
(collection,)
)
else:
# Use subject table for get_all queries
return self.session.execute(
self.get_all_stmt,
(collection, limit)
)
def get_s(self, collection, s, limit=10):
return self.session.execute(
f"select p, o from {self.table} where collection = %s and s = %s limit {limit}",
(collection, s)
)
if self.use_legacy:
return self.session.execute(
f"select p, o from {self.table} where collection = %s and s = %s limit {limit}",
(collection, s)
)
else:
# Optimized: Direct partition access with (collection, s)
return self.session.execute(
self.get_s_stmt,
(collection, s, limit)
)
def get_p(self, collection, p, limit=10):
return self.session.execute(
f"select s, o from {self.table} where collection = %s and p = %s limit {limit}",
(collection, p)
)
if self.use_legacy:
return self.session.execute(
f"select s, o from {self.table} where collection = %s and p = %s limit {limit}",
(collection, p)
)
else:
# Optimized: Use po_table for direct partition access
return self.session.execute(
self.get_p_stmt,
(collection, p, limit)
)
def get_o(self, collection, o, limit=10):
return self.session.execute(
f"select s, p from {self.table} where collection = %s and o = %s limit {limit}",
(collection, o)
)
if self.use_legacy:
return self.session.execute(
f"select s, p from {self.table} where collection = %s and o = %s limit {limit}",
(collection, o)
)
else:
# Optimized: Use object_table for direct partition access
return self.session.execute(
self.get_o_stmt,
(collection, o, limit)
)
def get_sp(self, collection, s, p, limit=10):
return self.session.execute(
f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}",
(collection, s, p)
)
if self.use_legacy:
return self.session.execute(
f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}",
(collection, s, p)
)
else:
# Optimized: Use subject_table with clustering key access
return self.session.execute(
self.get_sp_stmt,
(collection, s, p, limit)
)
def get_po(self, collection, p, o, limit=10):
return self.session.execute(
f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering",
(collection, p, o)
)
if self.use_legacy:
return self.session.execute(
f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering",
(collection, p, o)
)
else:
# CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING!
return self.session.execute(
self.get_po_stmt,
(collection, p, o, limit)
)
def get_os(self, collection, o, s, limit=10):
return self.session.execute(
f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering",
(collection, o, s)
)
if self.use_legacy:
return self.session.execute(
f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering",
(collection, o, s)
)
else:
# Optimized: Use subject_table with clustering access (no more ALLOW FILTERING)
return self.session.execute(
self.get_os_stmt,
(collection, s, o, limit)
)
def get_spo(self, collection, s, p, o, limit=10):
return self.session.execute(
f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""",
(collection, s, p, o)
)
if self.use_legacy:
return self.session.execute(
f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""",
(collection, s, p, o)
)
else:
# Optimized: Use subject_table for exact key lookup
return self.session.execute(
self.get_spo_stmt,
(collection, s, p, o, limit)
)
def delete_collection(self, collection):
"""Delete all triples for a specific collection"""
self.session.execute(
f"delete from {self.table} where collection = %s",
(collection,)
)
if self.use_legacy:
self.session.execute(
f"delete from {self.table} where collection = %s",
(collection,)
)
else:
# Delete from all three tables
self.session.execute(
f"delete from {self.subject_table} where collection = %s",
(collection,)
)
self.session.execute(
f"delete from {self.po_table} where collection = %s",
(collection,)
)
self.session.execute(
f"delete from {self.object_table} where collection = %s",
(collection,)
)
def close(self):
"""Close the Cassandra session and cluster connections properly"""