trustgraph/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
cybermaggedon 4913f8c2eb
feat: data store replication configuration and TLS upgrade (#975)
- Add centralised qdrant_config.py helper with env-var fallback for
  QDRANT_URL, QDRANT_API_KEY, QDRANT_REPLICATION_FACTOR, QDRANT_SHARD_NUMBER
- Update all 6 Qdrant processors to use the helper; writers pass
  replication_factor and shard_number to create_collection
- Fix hardcoded Cassandra replication_factor=1 in cassandra_kg.py,
  write.py, and sparql_cassandra.py to respect CASSANDRA_REPLICATION_FACTOR
- Upgrade Cassandra TLS from deprecated PROTOCOL_TLSv1_2 to
  ssl.create_default_context() across all connectors
2026-06-04 11:49:29 +01:00

644 lines
22 KiB
Python
Executable file

"""
Row writer for Cassandra. Input is ExtractedObject.
Writes structured rows to a unified Cassandra table with multi-index support.
Uses a single 'rows' table with the schema:
- collection: text
- schema_name: text
- index_name: text
- index_value: frozen<list<text>>
- data: map<text, text>
- source: text
Each row is written multiple times - once per indexed field defined in the schema.
"""
import asyncio
import json
import logging
import re
from typing import Dict, Set, Optional, Any, List, Tuple
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from .... schema import ExtractedObject
from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... tables.cassandra_async import async_execute
# Module logger
logger = logging.getLogger(__name__)
# Default processor identity: used by Processor.__init__ when no "id"
# parameter is supplied, and passed to Processor.launch() by run().
default_ident = "rows-write"
class Processor(CollectionConfigHandler, FlowProcessor):
    """
    Consume ExtractedObject messages and persist their rows in Cassandra.

    The (sanitized) workspace name is used as the keyspace name. Each
    keyspace holds two tables:

    - ``rows``: every incoming row is stored once per indexed field.
    - ``row_partitions``: the (collection, schema_name, index_name)
      partitions that have been written, so that a collection can later be
      deleted partition by partition.
    """

    def __init__(self, **params):
        """
        Resolve Cassandra settings, register the input consumer and the
        config handlers, and initialise the in-memory caches.

        Recognised keys in ``params`` (all optional): ``id``,
        ``cassandra_host``, ``cassandra_username``, ``cassandra_password``
        and ``config_type``. The full ``params`` dict is forwarded to the
        base class initialiser.
        """
        processor_id = params.get("id", default_ident)

        # Get Cassandra parameters
        cassandra_host = params.get("cassandra_host")
        cassandra_username = params.get("cassandra_username")
        cassandra_password = params.get("cassandra_password")

        # Resolve configuration with environment variable fallback.
        # The resolved keyspace is not used: this processor derives the
        # keyspace from the workspace of each message instead.
        hosts, username, password, _keyspace, replication_factor = (
            resolve_cassandra_config(
                host=cassandra_host,
                username=cassandra_username,
                password=cassandra_password,
            )
        )

        # Store resolved configuration with proper names
        self.cassandra_host = hosts  # Store as list
        self.cassandra_username = username
        self.cassandra_password = password
        self.replication_factor = replication_factor

        # Config key for schemas
        self.config_key = params.get("config_type", "schema")

        super().__init__(
            **params | {
                "id": processor_id,
                "config_type": self.config_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="input",
                schema=ExtractedObject,
                handler=self.on_object,
            )
        )

        # Register config handlers
        self.register_config_handler(self.on_schema_config, types=["schema"])
        self.register_config_handler(
            self.on_collection_config, types=["collection"]
        )

        # Cache of known keyspaces and whether tables exist
        self.known_keyspaces: Set[str] = set()
        self.tables_initialized: Set[str] = set()

        # Cache of (keyspace, collection, schema_name) triples whose
        # row_partitions entries have been written. The keyspace is part
        # of the key because row_partitions is a per-keyspace table: the
        # same collection/schema pair in another workspace must be
        # registered separately.
        self.registered_partitions: Set[Tuple[str, str, str]] = set()

        # Schema storage: workspace -> schema name -> RowSchema
        self.schemas: Dict[str, Dict[str, RowSchema]] = {}

        # Cassandra session
        self.cluster = None
        self.session = None

        # Protects connection setup and cache mutations
        self._setup_lock = asyncio.Lock()

    def connect_cassandra(self):
        """
        Connect to the Cassandra cluster; a no-op if a session exists.

        Uses plain-text authentication only when both a username and a
        password are configured.

        Raises:
            Exception: whatever the driver raised while connecting. The
                partially constructed cluster is shut down first so that
                a later retry does not leak it.
        """
        if self.session:
            return

        cluster = None
        try:
            if self.cassandra_username and self.cassandra_password:
                auth_provider = PlainTextAuthProvider(
                    username=self.cassandra_username,
                    password=self.cassandra_password,
                )
                cluster = Cluster(
                    contact_points=self.cassandra_host,
                    auth_provider=auth_provider,
                )
            else:
                cluster = Cluster(contact_points=self.cassandra_host)

            session = cluster.connect()
        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
            if cluster is not None:
                # Best-effort cleanup; the original error is what matters
                try:
                    cluster.shutdown()
                except Exception as shutdown_error:
                    logger.debug(
                        f"Ignoring error while shutting down failed "
                        f"cluster: {shutdown_error}"
                    )
            raise

        # Publish the handles only once the connection is usable
        self.cluster = cluster
        self.session = session
        logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")

    async def on_schema_config(self, workspace, config, version):
        """Handle schema configuration updates (serialised by the setup lock)"""
        logger.info(
            f"Loading schema configuration version {version} "
            f"for workspace {workspace}"
        )
        async with self._setup_lock:
            return await self._apply_schema_config(workspace, config, version)

    async def _apply_schema_config(self, workspace, config, version):
        """
        Replace the schemas of one workspace and invalidate stale cache.

        ``config[self.config_key]`` is expected to map schema names to
        JSON schema definitions. A definition that fails to parse is
        logged and skipped. Partition-cache entries of this workspace are
        dropped for every schema that was added, removed, or whose set of
        index names changed, so the next write re-registers them.
        """
        # Track which schemas changed in this workspace
        old_schemas = self.schemas.get(workspace, {})

        # Replace existing schemas for this workspace
        ws_schemas: Dict[str, RowSchema] = {}
        self.schemas[workspace] = ws_schemas

        # Check if our config type exists
        if self.config_key not in config:
            logger.warning(
                f"No '{self.config_key}' type in configuration "
                f"for {workspace}"
            )
            # Every previously known schema has just been removed
            self._invalidate_partition_cache(workspace, old_schemas, ws_schemas)
            return

        # Get the schemas dictionary for our type
        schemas_config = config[self.config_key]

        # Process each schema in the schemas config
        for schema_name, schema_json in schemas_config.items():
            try:
                row_schema = self._parse_schema(schema_name, schema_json)
                ws_schemas[schema_name] = row_schema
                logger.info(
                    f"Loaded schema: {schema_name} with "
                    f"{len(row_schema.fields)} fields for {workspace}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to parse schema {schema_name}: {e}",
                    exc_info=True
                )

        logger.info(
            f"Schema configuration loaded for {workspace}: "
            f"{len(ws_schemas)} schemas"
        )

        self._invalidate_partition_cache(workspace, old_schemas, ws_schemas)

    def _parse_schema(self, schema_name: str, schema_json: str) -> RowSchema:
        """Build a RowSchema from one JSON schema definition."""
        schema_def = json.loads(schema_json)

        fields = [
            Field(
                name=field_def["name"],
                type=field_def["type"],
                size=field_def.get("size", 0),
                primary=field_def.get("primary_key", False),
                description=field_def.get("description", ""),
                required=field_def.get("required", False),
                enum_values=field_def.get("enum", []),
                indexed=field_def.get("indexed", False),
            )
            for field_def in schema_def.get("fields", [])
        ]

        return RowSchema(
            name=schema_def.get("name", schema_name),
            description=schema_def.get("description", ""),
            fields=fields,
        )

    def _invalidate_partition_cache(
        self,
        workspace: str,
        old_schemas: Dict[str, RowSchema],
        new_schemas: Dict[str, RowSchema],
    ):
        """Drop cached partition registrations for schemas that changed."""
        changed_schemas = {
            name
            for name in old_schemas.keys() | new_schemas.keys()
            if name not in old_schemas
            or name not in new_schemas
            # Same name, different indexes: new index names still need
            # a row_partitions entry
            or self.get_index_names(old_schemas[name])
            != self.get_index_names(new_schemas[name])
        }

        if not changed_schemas:
            return

        # Workspace names are used as keyspace names throughout this class
        self.registered_partitions = {
            (ks, col, sch)
            for ks, col, sch in self.registered_partitions
            if not (ks == workspace and sch in changed_schemas)
        }
        logger.info(
            f"Cleared partition cache for changed schemas "
            f"in {workspace}: {changed_schemas}"
        )

    def sanitize_name(self, name: str) -> str:
        """
        Sanitize names for Cassandra compatibility.

        Every character outside ``[a-zA-Z0-9_]`` becomes ``_``; a non-empty
        result that does not start with a letter is prefixed with ``r_``;
        the result is lower-cased.
        """
        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        # Ensure it starts with a letter
        if safe_name and not safe_name[0].isalpha():
            safe_name = 'r_' + safe_name
        return safe_name.lower()

    def ensure_keyspace(self, keyspace: str):
        """
        Ensure keyspace exists in Cassandra.

        Blocking: runs a synchronous driver call. The keyspace is created
        with SimpleStrategy and the configured replication factor.

        Raises:
            Exception: whatever the driver raised; it is logged first.
        """
        if keyspace in self.known_keyspaces:
            return

        # Connect if needed
        self.connect_cassandra()

        # Sanitize keyspace name
        safe_keyspace = self.sanitize_name(keyspace)

        # Create keyspace if not exists
        create_keyspace_cql = f"""
        CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
        WITH REPLICATION = {{
            'class': 'SimpleStrategy',
            'replication_factor': {self.replication_factor}
        }}
        """

        try:
            self.session.execute(create_keyspace_cql)
            self.known_keyspaces.add(keyspace)
            logger.info(f"Ensured keyspace exists: {safe_keyspace}")
        except Exception as e:
            logger.error(
                f"Failed to create keyspace {safe_keyspace}: {e}",
                exc_info=True
            )
            raise

    def ensure_tables(self, keyspace: str):
        """
        Ensure unified rows and row_partitions tables exist.

        Blocking: runs synchronous driver calls.

        Raises:
            Exception: whatever the driver raised; it is logged first.
        """
        # Connect before consulting the cache so that a session is
        # available to callers even after close() dropped the old one.
        self.connect_cassandra()

        if keyspace in self.tables_initialized:
            return

        # Ensure keyspace exists first
        self.ensure_keyspace(keyspace)

        safe_keyspace = self.sanitize_name(keyspace)

        # Create unified rows table
        create_rows_cql = f"""
        CREATE TABLE IF NOT EXISTS {safe_keyspace}.rows (
            collection text,
            schema_name text,
            index_name text,
            index_value frozen<list<text>>,
            data map<text, text>,
            source text,
            PRIMARY KEY ((collection, schema_name, index_name), index_value)
        )
        """

        # Create row_partitions tracking table
        create_partitions_cql = f"""
        CREATE TABLE IF NOT EXISTS {safe_keyspace}.row_partitions (
            collection text,
            schema_name text,
            index_name text,
            PRIMARY KEY ((collection), schema_name, index_name)
        )
        """

        try:
            self.session.execute(create_rows_cql)
            logger.info(f"Ensured rows table exists: {safe_keyspace}.rows")
            self.session.execute(create_partitions_cql)
            logger.info(
                f"Ensured row_partitions table exists: "
                f"{safe_keyspace}.row_partitions"
            )
            self.tables_initialized.add(keyspace)
        except Exception as e:
            logger.error(
                f"Failed to create tables in {safe_keyspace}: {e}",
                exc_info=True
            )
            raise

    def get_index_names(self, schema: RowSchema) -> List[str]:
        """
        Get all index names for a schema.

        A field contributes its name when it is a primary key field or is
        flagged as indexed. Names are returned in field order.
        """
        # TODO: Support composite indexes in the future
        # For now, each indexed field is a single-field index
        return [
            field.name
            for field in schema.fields
            if field.primary or field.indexed
        ]

    def register_partitions(
        self, keyspace: str, collection: str, schema_name: str,
        workspace: str,
    ):
        """
        Register partition entries for a (collection, schema_name) pair.

        Writes one row_partitions entry per index of the schema into the
        given keyspace. Blocking: runs synchronous driver calls and
        expects ``self.session`` to be connected already.

        The result is cached per (keyspace, collection, schema_name) only
        when every insert succeeded; after a partial failure the next
        call retries, so that no written partition stays untracked.
        """
        cache_key = (keyspace, collection, schema_name)
        if cache_key in self.registered_partitions:
            return

        schema = self.schemas.get(workspace, {}).get(schema_name)
        if not schema:
            logger.warning(
                f"Cannot register partitions - schema {schema_name} "
                f"not found in workspace {workspace}"
            )
            return

        safe_keyspace = self.sanitize_name(keyspace)
        index_names = self.get_index_names(schema)

        # Insert partition entries for each index
        insert_cql = f"""
        INSERT INTO {safe_keyspace}.row_partitions (collection, schema_name, index_name)
        VALUES (%s, %s, %s)
        """

        failed_indexes = []
        for index_name in index_names:
            try:
                self.session.execute(
                    insert_cql, (collection, schema_name, index_name)
                )
            except Exception as e:
                failed_indexes.append(index_name)
                logger.warning(
                    f"Failed to register partition "
                    f"{collection}/{schema_name}/{index_name}: {e}"
                )

        if failed_indexes:
            # Do not cache: delete_collection relies on these entries
            logger.warning(
                f"Partition registration incomplete for "
                f"{collection}/{schema_name}; will retry on next write: "
                f"{failed_indexes}"
            )
            return

        self.registered_partitions.add(cache_key)
        logger.info(
            f"Registered partitions for {collection}/{schema_name}: "
            f"{index_names}"
        )

    def build_index_value(
        self, value_map: Dict[str, str], index_name: str
    ) -> List[str]:
        """
        Build the index_value list for a given index.

        For single-field indexes, returns a single-element list.
        For composite indexes (comma-separated), returns multiple elements.
        Missing or None values are represented by the empty string.
        """
        values = []
        for field_name in index_name.split(','):
            value = value_map.get(field_name.strip())
            # Convert to string for storage
            values.append(str(value) if value is not None else "")
        return values

    async def on_object(self, msg, consumer, flow):
        """
        Process incoming ExtractedObject and store in Cassandra.

        Each row of the object is written once per index of its schema;
        an index whose value is entirely empty is skipped for that row.
        Objects whose schema is unknown in the workspace, or whose schema
        has no indexes, are logged and dropped.

        Raises:
            ValueError: if the target collection does not exist.
            Exception: whatever an insert raised; it is logged first.
        """
        obj = msg.value()
        workspace = flow.workspace

        logger.info(
            f"Storing {len(obj.values)} rows for schema {obj.schema_name} "
            f"from {obj.metadata.id} (workspace {workspace})"
        )

        # Validate collection exists before accepting writes
        if not self.collection_exists(workspace, obj.metadata.collection):
            error_msg = (
                f"Collection {obj.metadata.collection} does not exist. "
                f"Create it first via collection management API."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        # Get schema definition for this workspace
        schema = self.schemas.get(workspace, {}).get(obj.schema_name)
        if not schema:
            logger.warning(
                f"No schema found for {obj.schema_name} in "
                f"workspace {workspace} - skipping"
            )
            return

        keyspace = workspace
        collection = obj.metadata.collection
        schema_name = obj.schema_name
        source = getattr(obj.metadata, 'source', '') or ''

        # Setup uses the blocking driver API, so it runs in a worker
        # thread; the lock keeps setup and cache updates serialised.
        async with self._setup_lock:
            await asyncio.to_thread(self.ensure_tables, keyspace)
            await asyncio.to_thread(
                self.register_partitions,
                keyspace, collection, schema_name, workspace,
            )

        safe_keyspace = self.sanitize_name(keyspace)

        # Get all index names for this schema
        index_names = self.get_index_names(schema)
        if not index_names:
            logger.warning(
                f"Schema {schema_name} has no indexed fields - "
                f"rows won't be queryable"
            )
            return

        # Prepare insert statement
        insert_cql = f"""
        INSERT INTO {safe_keyspace}.rows
        (collection, schema_name, index_name, index_value, data, source)
        VALUES (%s, %s, %s, %s, %s, %s)
        """

        # Process each row in the batch
        rows_written = 0
        for row_index, value_map in enumerate(obj.values):
            # Convert all values to strings for the data map
            data_map = {}
            for field in schema.fields:
                raw_value = value_map.get(field.name)
                if raw_value is not None:
                    data_map[field.name] = str(raw_value)

            # Write one copy per index
            for index_name in index_names:
                index_value = self.build_index_value(value_map, index_name)

                # Skip if index value is empty/null
                if not index_value or all(v == "" for v in index_value):
                    logger.debug(
                        f"Skipping index {index_name} for row {row_index} - "
                        f"empty index value"
                    )
                    continue

                try:
                    await async_execute(
                        self.session,
                        insert_cql,
                        (
                            collection, schema_name, index_name,
                            index_value, data_map, source,
                        ),
                    )
                    rows_written += 1
                except Exception as e:
                    logger.error(
                        f"Failed to insert row {row_index} for index {index_name}: {e}",
                        exc_info=True
                    )
                    raise

        logger.info(
            f"Wrote {rows_written} index entries for {len(obj.values)} rows "
            f"({len(index_names)} indexes per row)"
        )

    async def create_collection(
        self, workspace: str, collection: str, metadata: dict
    ):
        """Create/verify collection exists in Cassandra row store"""
        async with self._setup_lock:
            await asyncio.to_thread(self.connect_cassandra)
            await asyncio.to_thread(self.ensure_tables, workspace)
        logger.info(f"Collection {collection} ready for workspace {workspace}")

    async def _keyspace_exists(self, workspace: str) -> bool:
        """
        Tell whether the keyspace of a workspace exists.

        Must be called with ``_setup_lock`` held and a connected session.
        A positive answer is remembered in ``known_keyspaces``.
        """
        if workspace in self.known_keyspaces:
            return True

        safe_ks = self.sanitize_name(workspace)
        check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s"
        result = await async_execute(self.session, check_cql, (safe_ks,))
        if not result:
            return False

        self.known_keyspaces.add(workspace)
        return True

    async def delete_collection(self, workspace: str, collection: str):
        """
        Delete all data for a specific collection using partition tracking.

        Returns without doing anything when the workspace keyspace does
        not exist.

        Raises:
            Exception: whatever a query or delete raised; it is logged
                first.
        """
        safe_keyspace = self.sanitize_name(workspace)

        async with self._setup_lock:
            await asyncio.to_thread(self.connect_cassandra)
            if not await self._keyspace_exists(workspace):
                logger.info(
                    f"Keyspace {safe_keyspace} does not exist, nothing to delete"
                )
                return

        # Discover all partitions for this collection
        select_partitions_cql = f"""
        SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
        WHERE collection = %s
        """

        try:
            partition_list = await async_execute(
                self.session, select_partitions_cql, (collection,)
            )
        except Exception as e:
            logger.error(
                f"Failed to query partitions for collection {collection}: {e}"
            )
            raise

        # Delete each partition from rows table
        delete_rows_cql = f"""
        DELETE FROM {safe_keyspace}.rows
        WHERE collection = %s AND schema_name = %s AND index_name = %s
        """

        partitions_deleted = 0
        for partition in partition_list:
            try:
                await async_execute(
                    self.session,
                    delete_rows_cql,
                    (collection, partition.schema_name, partition.index_name),
                )
                partitions_deleted += 1
            except Exception as e:
                logger.error(
                    f"Failed to delete partition {collection}/{partition.schema_name}/"
                    f"{partition.index_name}: {e}"
                )
                raise

        # Clean up row_partitions entries
        delete_partitions_cql = f"""
        DELETE FROM {safe_keyspace}.row_partitions
        WHERE collection = %s
        """

        try:
            await async_execute(
                self.session, delete_partitions_cql, (collection,)
            )
        except Exception as e:
            logger.error(
                f"Failed to clean up row_partitions for {collection}: {e}"
            )
            raise

        # Forget cached registrations of this collection in this workspace
        async with self._setup_lock:
            self.registered_partitions = {
                (ks, col, sch)
                for ks, col, sch in self.registered_partitions
                if not (ks == workspace and col == collection)
            }

        logger.info(
            f"Deleted collection {collection}: {partitions_deleted} partitions "
            f"from keyspace {safe_keyspace}"
        )

    async def delete_collection_schema(
        self, workspace: str, collection: str, schema_name: str
    ):
        """
        Delete all data for a specific collection + schema combination.

        Returns without doing anything when the workspace keyspace does
        not exist, matching delete_collection.

        Raises:
            Exception: whatever a query or delete raised; it is logged
                first.
        """
        safe_keyspace = self.sanitize_name(workspace)

        async with self._setup_lock:
            await asyncio.to_thread(self.connect_cassandra)
            if not await self._keyspace_exists(workspace):
                logger.info(
                    f"Keyspace {safe_keyspace} does not exist, nothing to delete"
                )
                return

        # Discover partitions for this collection + schema
        select_partitions_cql = f"""
        SELECT index_name FROM {safe_keyspace}.row_partitions
        WHERE collection = %s AND schema_name = %s
        """

        try:
            partition_list = await async_execute(
                self.session, select_partitions_cql, (collection, schema_name)
            )
        except Exception as e:
            logger.error(
                f"Failed to query partitions for {collection}/{schema_name}: {e}"
            )
            raise

        # Delete each partition from rows table
        delete_rows_cql = f"""
        DELETE FROM {safe_keyspace}.rows
        WHERE collection = %s AND schema_name = %s AND index_name = %s
        """

        partitions_deleted = 0
        for partition in partition_list:
            try:
                await async_execute(
                    self.session,
                    delete_rows_cql,
                    (collection, schema_name, partition.index_name),
                )
                partitions_deleted += 1
            except Exception as e:
                logger.error(
                    f"Failed to delete partition {collection}/{schema_name}/"
                    f"{partition.index_name}: {e}"
                )
                raise

        # Clean up row_partitions entries for this schema
        delete_partitions_cql = f"""
        DELETE FROM {safe_keyspace}.row_partitions
        WHERE collection = %s AND schema_name = %s
        """

        try:
            await async_execute(
                self.session,
                delete_partitions_cql,
                (collection, schema_name),
            )
        except Exception as e:
            logger.error(
                f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"
            )
            raise

        async with self._setup_lock:
            self.registered_partitions.discard(
                (workspace, collection, schema_name)
            )

        logger.info(
            f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "
            f"from keyspace {safe_keyspace}"
        )

    def close(self):
        """Clean up Cassandra connections"""
        if self.cluster:
            self.cluster.shutdown()
            # Drop the handles so connect_cassandra() can reconnect
            # instead of short-circuiting on a shut-down session.
            self.cluster = None
            self.session = None
            logger.info("Closed Cassandra connection")

    @staticmethod
    def add_args(parser):
        """Add command-line arguments"""
        FlowProcessor.add_args(parser)
        add_cassandra_args(parser)
        parser.add_argument(
            '--config-type',
            default='schema',
            help='Configuration type prefix for schemas (default: schema)'
        )
def run():
    """Entry point for the rows-write-cassandra command.

    Launches the processor under the default identity, passing the module
    docstring as the second argument to Processor.launch().
    """
    description = __doc__
    Processor.launch(default_ident, description)