mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-12 00:05:13 +02:00
- Add centralised qdrant_config.py helper with env-var fallback for QDRANT_URL, QDRANT_API_KEY, QDRANT_REPLICATION_FACTOR, QDRANT_SHARD_NUMBER - Update all 6 Qdrant processors to use the helper; writers pass replication_factor and shard_number to create_collection - Fix hardcoded Cassandra replication_factor=1 in cassandra_kg.py, write.py, and sparql_cassandra.py to respect CASSANDRA_REPLICATION_FACTOR - Upgrade Cassandra TLS from deprecated PROTOCOL_TLSv1_2 to ssl.create_default_context() across all connectors
644 lines
22 KiB
Python
Executable file
644 lines
22 KiB
Python
Executable file
"""
Row writer for Cassandra. Input is ExtractedObject.
Writes structured rows to a unified Cassandra table with multi-index support.

Uses a single 'rows' table with the schema:
- collection: text
- schema_name: text
- index_name: text
- index_value: frozen<list<text>>
- data: map<text, text>
- source: text

Each row is written multiple times - once per indexed field defined in the schema.
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import re
|
|
from typing import Dict, Set, Optional, Any, List, Tuple
|
|
|
|
from cassandra.cluster import Cluster
|
|
from cassandra.auth import PlainTextAuthProvider
|
|
|
|
from .... schema import ExtractedObject
|
|
from .... schema import RowSchema, Field
|
|
from .... base import FlowProcessor, ConsumerSpec
|
|
from .... base import CollectionConfigHandler
|
|
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
|
from .... tables.cassandra_async import async_execute
|
|
|
|
# Logger for this module, named after the module's import path.
logger = logging.getLogger(__name__)

# Processor identity used when the caller supplies no "id" parameter.
default_ident = "rows-write"
|
|
|
|
|
|
class Processor(CollectionConfigHandler, FlowProcessor):
    """
    Cassandra row writer.

    Consumes ExtractedObject messages and stores every row in the ``rows``
    table of a keyspace named after the flow's workspace, writing one copy
    of the row per index (primary-key or indexed field) of its schema.

    A companion ``row_partitions`` table records which
    (collection, schema_name, index_name) partitions have been written so
    that a collection can later be deleted partition by partition.
    """

    def __init__(self, **params):
        """
        Resolve Cassandra settings, register the input consumer and the
        config handlers, and initialise the in-memory caches.

        Params read here: ``id``, ``cassandra_host``, ``cassandra_username``,
        ``cassandra_password`` and ``config_type``.  The full ``params``
        mapping is forwarded to the parent constructor.
        """

        # Named "ident" so the builtin id() is not shadowed
        ident = params.get("id", default_ident)

        # Resolve configuration with environment variable fallback.  The
        # resolved keyspace is deliberately unused: this processor derives
        # the keyspace from the workspace of each message instead.
        hosts, username, password, _keyspace, replication_factor = (
            resolve_cassandra_config(
                host=params.get("cassandra_host"),
                username=params.get("cassandra_username"),
                password=params.get("cassandra_password"),
            )
        )

        # Store resolved configuration with proper names
        self.cassandra_host = hosts  # Stored as list of contact points
        self.cassandra_username = username
        self.cassandra_password = password
        self.replication_factor = replication_factor

        # Config key for schemas
        self.config_key = params.get("config_type", "schema")

        super().__init__(
            **params | {
                "id": ident,
                "config_type": self.config_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="input",
                schema=ExtractedObject,
                handler=self.on_object,
            )
        )

        # Register config handlers
        self.register_config_handler(self.on_schema_config, types=["schema"])
        self.register_config_handler(
            self.on_collection_config, types=["collection"]
        )

        # Keyspaces known to exist / keyspaces whose tables were created.
        # Both sets hold the *unsanitized* keyspace (workspace) name.
        self.known_keyspaces: Set[str] = set()
        self.tables_initialized: Set[str] = set()

        # Cache of (keyspace, collection, schema_name) triples whose
        # row_partitions entries have been written.  The keyspace is part
        # of the key because each workspace has its own row_partitions
        # table; without it, a collection/schema name reused in a second
        # workspace would never be registered there.
        self.registered_partitions: Set[Tuple[str, str, str]] = set()

        # Schema storage: workspace -> schema name -> RowSchema
        self.schemas: Dict[str, Dict[str, RowSchema]] = {}

        # Cassandra cluster / session, created lazily by connect_cassandra
        self.cluster = None
        self.session = None

        # Protects connection setup and cache mutations
        self._setup_lock = asyncio.Lock()

    def connect_cassandra(self):
        """
        Connect to the Cassandra cluster; a no-op when a session exists.

        Authentication is only configured when both a username and a
        password were resolved.

        Raises:
            Exception: any error raised while connecting is logged and
                re-raised.  No half-open cluster is left behind.
        """
        if self.session:
            return

        cluster = None
        try:
            cluster_args = {"contact_points": self.cassandra_host}
            if self.cassandra_username and self.cassandra_password:
                cluster_args["auth_provider"] = PlainTextAuthProvider(
                    username=self.cassandra_username,
                    password=self.cassandra_password,
                )

            cluster = Cluster(**cluster_args)
            session = cluster.connect()

        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
            # Release the partially initialised cluster so a failed
            # attempt does not leak driver resources.
            if cluster is not None:
                try:
                    cluster.shutdown()
                except Exception as shutdown_error:
                    logger.warning(
                        f"Failed to shut down Cassandra cluster after "
                        f"connection error: {shutdown_error}"
                    )
            raise

        # Publish only after a fully successful connect
        self.cluster = cluster
        self.session = session
        logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")

    async def on_schema_config(self, workspace, config, version):
        """Handle schema configuration updates for one workspace."""
        logger.info(
            f"Loading schema configuration version {version} "
            f"for workspace {workspace}"
        )

        # Schema and partition caches are shared with on_object
        async with self._setup_lock:
            return await self._apply_schema_config(workspace, config, version)

    def _parse_fields(self, schema_def: Dict[str, Any]) -> List[Field]:
        """Build the Field list described by a decoded schema definition."""
        return [
            Field(
                name=field_def["name"],
                type=field_def["type"],
                size=field_def.get("size", 0),
                primary=field_def.get("primary_key", False),
                description=field_def.get("description", ""),
                required=field_def.get("required", False),
                enum_values=field_def.get("enum", []),
                indexed=field_def.get("indexed", False),
            )
            for field_def in schema_def.get("fields", [])
        ]

    async def _apply_schema_config(self, workspace, config, version):
        """
        Replace the schemas of ``workspace`` with those found in ``config``
        and drop partition-cache entries of schemas that changed.

        ``config[self.config_key]`` is expected to map schema names to JSON
        documents.  A schema that fails to parse is logged and skipped.
        The caller must hold ``self._setup_lock``.
        """

        # Remember the previous schemas so changes can be detected
        old_schemas = self.schemas.get(workspace, {})

        # Replace existing schemas for this workspace
        ws_schemas: Dict[str, RowSchema] = {}
        self.schemas[workspace] = ws_schemas

        if self.config_key in config:

            # Process each schema in the schemas config
            for schema_name, schema_json in config[self.config_key].items():
                try:
                    schema_def = json.loads(schema_json)
                    fields = self._parse_fields(schema_def)

                    ws_schemas[schema_name] = RowSchema(
                        name=schema_def.get("name", schema_name),
                        description=schema_def.get("description", ""),
                        fields=fields,
                    )
                    logger.info(
                        f"Loaded schema: {schema_name} with "
                        f"{len(fields)} fields for {workspace}"
                    )

                except Exception as e:
                    logger.error(
                        f"Failed to parse schema {schema_name}: {e}",
                        exc_info=True
                    )

            logger.info(
                f"Schema configuration loaded for {workspace}: "
                f"{len(ws_schemas)} schemas"
            )

        else:
            logger.warning(
                f"No '{self.config_key}' type in configuration "
                f"for {workspace}"
            )

        # A schema counts as changed when it was added, removed, or when
        # its set of indexes differs: the index names are exactly what
        # register_partitions writes to row_partitions.
        old_names = set(old_schemas)
        new_names = set(ws_schemas)
        changed_schemas = old_names.symmetric_difference(new_names)
        for name in old_names & new_names:
            old_indexes = self.get_index_names(old_schemas[name])
            new_indexes = self.get_index_names(ws_schemas[name])
            if old_indexes != new_indexes:
                changed_schemas.add(name)

        if changed_schemas:
            # Only this workspace's entries are affected
            self.registered_partitions = {
                (ks, col, sch)
                for ks, col, sch in self.registered_partitions
                if ks != workspace or sch not in changed_schemas
            }
            logger.info(
                f"Cleared partition cache for changed schemas "
                f"in {workspace}: {changed_schemas}"
            )

    def sanitize_name(self, name: str) -> str:
        """
        Sanitize names for Cassandra compatibility.

        Every character outside ``[a-zA-Z0-9_]`` becomes ``_``, a name not
        starting with a letter is prefixed with ``r_``, and the result is
        lower-cased.  An empty name is returned unchanged.
        """
        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        # Ensure it starts with a letter
        if safe_name and not safe_name[0].isalpha():
            safe_name = 'r_' + safe_name
        return safe_name.lower()

    def ensure_keyspace(self, keyspace: str):
        """
        Ensure the keyspace exists in Cassandra, connecting if needed.

        Blocking; successful creations are remembered in
        ``self.known_keyspaces``.  Driver errors are logged and re-raised.
        """
        if keyspace in self.known_keyspaces:
            return

        # Connect if needed
        self.connect_cassandra()

        safe_keyspace = self.sanitize_name(keyspace)

        create_keyspace_cql = f"""
            CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
            WITH REPLICATION = {{
                'class': 'SimpleStrategy',
                'replication_factor': {self.replication_factor}
            }}
        """

        try:
            self.session.execute(create_keyspace_cql)
            self.known_keyspaces.add(keyspace)
            logger.info(f"Ensured keyspace exists: {safe_keyspace}")
        except Exception as e:
            logger.error(
                f"Failed to create keyspace {safe_keyspace}: {e}",
                exc_info=True
            )
            raise

    def ensure_tables(self, keyspace: str):
        """
        Ensure the unified rows and row_partitions tables exist.

        Blocking; creates the keyspace first when necessary.  Successful
        runs are remembered in ``self.tables_initialized``.  Driver errors
        are logged and re-raised.
        """
        if keyspace in self.tables_initialized:
            return

        # Ensure keyspace exists first
        self.ensure_keyspace(keyspace)

        safe_keyspace = self.sanitize_name(keyspace)

        # Unified rows table: one partition per (collection, schema, index)
        create_rows_cql = f"""
            CREATE TABLE IF NOT EXISTS {safe_keyspace}.rows (
                collection text,
                schema_name text,
                index_name text,
                index_value frozen<list<text>>,
                data map<text, text>,
                source text,
                PRIMARY KEY ((collection, schema_name, index_name), index_value)
            )
        """

        # Tracking table listing the partitions written per collection
        create_partitions_cql = f"""
            CREATE TABLE IF NOT EXISTS {safe_keyspace}.row_partitions (
                collection text,
                schema_name text,
                index_name text,
                PRIMARY KEY ((collection), schema_name, index_name)
            )
        """

        try:
            self.session.execute(create_rows_cql)
            logger.info(f"Ensured rows table exists: {safe_keyspace}.rows")

            self.session.execute(create_partitions_cql)
            logger.info(
                f"Ensured row_partitions table exists: "
                f"{safe_keyspace}.row_partitions"
            )

            self.tables_initialized.add(keyspace)

        except Exception as e:
            logger.error(
                f"Failed to create tables in {safe_keyspace}: {e}",
                exc_info=True
            )
            raise

    def get_index_names(self, schema: RowSchema) -> List[str]:
        """
        Get all index names for a schema.

        Primary-key fields and fields flagged as indexed each yield one
        single-field index, in schema field order.
        """
        # TODO: Support composite indexes in the future
        # For now, each indexed field is a single-field index
        return [
            field.name
            for field in schema.fields
            if field.primary or field.indexed
        ]

    def register_partitions(
            self, keyspace: str, collection: str, schema_name: str,
            workspace: str,
    ):
        """
        Register partition entries for a (collection, schema_name) pair.

        Writes one row_partitions entry per index of the schema into
        ``keyspace``.  Blocking; expects an open session.  The pair is
        cached per keyspace only when every insert succeeded, so a failed
        registration is retried on the next call rather than silently lost.
        A schema unknown to ``workspace`` is logged and nothing is written.
        """
        cache_key = (keyspace, collection, schema_name)
        if cache_key in self.registered_partitions:
            return

        schema = self.schemas.get(workspace, {}).get(schema_name)
        if not schema:
            logger.warning(
                f"Cannot register partitions - schema {schema_name} "
                f"not found in workspace {workspace}"
            )
            return

        safe_keyspace = self.sanitize_name(keyspace)
        index_names = self.get_index_names(schema)

        insert_cql = f"""
            INSERT INTO {safe_keyspace}.row_partitions (collection, schema_name, index_name)
            VALUES (%s, %s, %s)
        """

        all_registered = True
        for index_name in index_names:
            try:
                self.session.execute(
                    insert_cql, (collection, schema_name, index_name)
                )
            except Exception as e:
                all_registered = False
                logger.warning(
                    f"Failed to register partition "
                    f"{collection}/{schema_name}/{index_name}: {e}"
                )

        if not all_registered:
            # Leave the pair uncached: deletion relies on row_partitions
            # being complete, so the inserts must be attempted again.
            return

        self.registered_partitions.add(cache_key)
        logger.info(
            f"Registered partitions for {collection}/{schema_name}: "
            f"{index_names}"
        )

    def build_index_value(
            self, value_map: Dict[str, str], index_name: str
    ) -> List[str]:
        """
        Build the index_value list for a given index.

        ``index_name`` is split on commas; each named field contributes
        ``str(value)``, or ``""`` when the field is missing or None.
        A single-field index therefore yields a one-element list.
        """
        values = []
        for field_name in index_name.split(','):
            value = value_map.get(field_name.strip())
            # Convert to string for storage
            values.append(str(value) if value is not None else "")
        return values

    async def on_object(self, msg, consumer, flow):
        """
        Process an incoming ExtractedObject and store it in Cassandra.

        Each row is inserted once per index of its schema; an index whose
        value is empty for a row is skipped.  Messages for a schema unknown
        to the workspace, or for a schema without indexes, are dropped with
        a warning.

        Raises:
            ValueError: the target collection does not exist.
            Exception: any insert failure is logged and re-raised.
        """

        obj = msg.value()
        workspace = flow.workspace
        logger.info(
            f"Storing {len(obj.values)} rows for schema {obj.schema_name} "
            f"from {obj.metadata.id} (workspace {workspace})"
        )

        # Validate collection exists before accepting writes
        if not self.collection_exists(workspace, obj.metadata.collection):
            error_msg = (
                f"Collection {obj.metadata.collection} does not exist. "
                f"Create it first via collection management API."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        # Get schema definition for this workspace
        schema = self.schemas.get(workspace, {}).get(obj.schema_name)
        if not schema:
            logger.warning(
                f"No schema found for {obj.schema_name} in "
                f"workspace {workspace} - skipping"
            )
            return

        # Each workspace is stored in a keyspace of the same name
        keyspace = workspace
        collection = obj.metadata.collection
        schema_name = obj.schema_name
        source = getattr(obj.metadata, 'source', '') or ''

        # The setup helpers block on the driver, so run them in a worker
        # thread; the lock keeps concurrent messages from racing on the
        # caches they mutate.
        async with self._setup_lock:
            await asyncio.to_thread(self.ensure_tables, keyspace)
            await asyncio.to_thread(
                self.register_partitions,
                keyspace, collection, schema_name, workspace,
            )

        safe_keyspace = self.sanitize_name(keyspace)

        # Get all index names for this schema
        index_names = self.get_index_names(schema)

        if not index_names:
            logger.warning(
                f"Schema {schema_name} has no indexed fields - "
                f"rows won't be queryable"
            )
            return

        insert_cql = f"""
            INSERT INTO {safe_keyspace}.rows
            (collection, schema_name, index_name, index_value, data, source)
            VALUES (%s, %s, %s, %s, %s, %s)
        """

        # Process each row in the batch
        rows_written = 0
        for row_index, value_map in enumerate(obj.values):

            # The data map holds every non-null schema field as a string
            data_map = {}
            for field in schema.fields:
                raw_value = value_map.get(field.name)
                if raw_value is not None:
                    data_map[field.name] = str(raw_value)

            # Write one copy per index
            for index_name in index_names:
                index_value = self.build_index_value(value_map, index_name)

                # Skip if index value is empty/null
                if not index_value or all(v == "" for v in index_value):
                    logger.debug(
                        f"Skipping index {index_name} for row {row_index} - "
                        f"empty index value"
                    )
                    continue

                try:
                    await async_execute(
                        self.session,
                        insert_cql,
                        (
                            collection, schema_name, index_name,
                            index_value, data_map, source,
                        ),
                    )
                    rows_written += 1
                except Exception as e:
                    logger.error(
                        f"Failed to insert row {row_index} for index {index_name}: {e}",
                        exc_info=True
                    )
                    raise

        logger.info(
            f"Wrote {rows_written} index entries for {len(obj.values)} rows "
            f"({len(index_names)} indexes per row)"
        )

    async def create_collection(
            self, workspace: str, collection: str, metadata: dict
    ):
        """
        Create/verify collection exists in Cassandra row store.

        Only the workspace keyspace and its tables are ensured here;
        nothing collection-specific is written.
        """
        async with self._setup_lock:
            await asyncio.to_thread(self.connect_cassandra)
            await asyncio.to_thread(self.ensure_tables, workspace)

        logger.info(f"Collection {collection} ready for workspace {workspace}")

    async def _keyspace_available(self, workspace: str) -> bool:
        """
        Report whether the keyspace of ``workspace`` exists.

        Consults ``self.known_keyspaces`` first and otherwise queries
        ``system_schema.keyspaces``; a positive answer is cached.  The
        caller must hold ``self._setup_lock`` and have an open session.
        """
        if workspace in self.known_keyspaces:
            return True

        safe_ks = self.sanitize_name(workspace)
        check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s"
        result = await async_execute(self.session, check_cql, (safe_ks,))
        if not result:
            logger.info(f"Keyspace {safe_ks} does not exist, nothing to delete")
            return False

        self.known_keyspaces.add(workspace)
        return True

    async def delete_collection(self, workspace: str, collection: str):
        """
        Delete all data for a specific collection using partition tracking.

        Reads the collection's entries from row_partitions, deletes each
        listed partition from rows, then removes the tracking entries.
        Returns without doing anything when the keyspace does not exist.

        Raises:
            Exception: any query or delete failure is logged and re-raised.
        """
        async with self._setup_lock:
            await asyncio.to_thread(self.connect_cassandra)
            if not await self._keyspace_available(workspace):
                return

        safe_keyspace = self.sanitize_name(workspace)

        # Discover all partitions for this collection
        select_partitions_cql = f"""
            SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
            WHERE collection = %s
        """

        try:
            partition_list = await async_execute(
                self.session, select_partitions_cql, (collection,)
            )
        except Exception as e:
            logger.error(
                f"Failed to query partitions for collection {collection}: {e}"
            )
            raise

        # Delete each partition from rows table
        delete_rows_cql = f"""
            DELETE FROM {safe_keyspace}.rows
            WHERE collection = %s AND schema_name = %s AND index_name = %s
        """

        partitions_deleted = 0
        for partition in partition_list:
            try:
                await async_execute(
                    self.session,
                    delete_rows_cql,
                    (collection, partition.schema_name, partition.index_name),
                )
                partitions_deleted += 1
            except Exception as e:
                logger.error(
                    f"Failed to delete partition {collection}/{partition.schema_name}/"
                    f"{partition.index_name}: {e}"
                )
                raise

        # Clean up row_partitions entries
        delete_partitions_cql = f"""
            DELETE FROM {safe_keyspace}.row_partitions
            WHERE collection = %s
        """

        try:
            await async_execute(
                self.session, delete_partitions_cql, (collection,)
            )
        except Exception as e:
            logger.error(
                f"Failed to clean up row_partitions for {collection}: {e}"
            )
            raise

        # Forget the cached registrations of this workspace's collection
        # only; other workspaces keep their own row_partitions entries.
        async with self._setup_lock:
            self.registered_partitions = {
                (ks, col, sch)
                for ks, col, sch in self.registered_partitions
                if ks != workspace or col != collection
            }

        logger.info(
            f"Deleted collection {collection}: {partitions_deleted} partitions "
            f"from keyspace {safe_keyspace}"
        )

    async def delete_collection_schema(
            self, workspace: str, collection: str, schema_name: str
    ):
        """
        Delete all data for a specific collection + schema combination.

        Mirrors delete_collection, restricted to one schema, including
        returning quietly when the workspace keyspace does not exist.

        Raises:
            Exception: any query or delete failure is logged and re-raised.
        """
        async with self._setup_lock:
            await asyncio.to_thread(self.connect_cassandra)
            if not await self._keyspace_available(workspace):
                return

        safe_keyspace = self.sanitize_name(workspace)

        # Discover partitions for this collection + schema
        select_partitions_cql = f"""
            SELECT index_name FROM {safe_keyspace}.row_partitions
            WHERE collection = %s AND schema_name = %s
        """

        try:
            partition_list = await async_execute(
                self.session, select_partitions_cql, (collection, schema_name)
            )
        except Exception as e:
            logger.error(
                f"Failed to query partitions for {collection}/{schema_name}: {e}"
            )
            raise

        # Delete each partition from rows table
        delete_rows_cql = f"""
            DELETE FROM {safe_keyspace}.rows
            WHERE collection = %s AND schema_name = %s AND index_name = %s
        """

        partitions_deleted = 0
        for partition in partition_list:
            try:
                await async_execute(
                    self.session,
                    delete_rows_cql,
                    (collection, schema_name, partition.index_name),
                )
                partitions_deleted += 1
            except Exception as e:
                logger.error(
                    f"Failed to delete partition {collection}/{schema_name}/"
                    f"{partition.index_name}: {e}"
                )
                raise

        # Clean up row_partitions entries for this schema
        delete_partitions_cql = f"""
            DELETE FROM {safe_keyspace}.row_partitions
            WHERE collection = %s AND schema_name = %s
        """

        try:
            await async_execute(
                self.session,
                delete_partitions_cql,
                (collection, schema_name),
            )
        except Exception as e:
            logger.error(
                f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"
            )
            raise

        async with self._setup_lock:
            self.registered_partitions.discard(
                (workspace, collection, schema_name)
            )

        logger.info(
            f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "
            f"from keyspace {safe_keyspace}"
        )

    def close(self):
        """
        Clean up Cassandra connections.

        The handles are reset so that a later connect_cassandra call opens
        a fresh connection instead of reusing the shut-down session.
        """
        if self.cluster:
            self.cluster.shutdown()
            self.cluster = None
            self.session = None
            logger.info("Closed Cassandra connection")

    @staticmethod
    def add_args(parser):
        """Add command-line arguments (flow, Cassandra and --config-type)."""

        FlowProcessor.add_args(parser)
        add_cassandra_args(parser)

        parser.add_argument(
            '--config-type',
            default='schema',
            help='Configuration type prefix for schemas (default: schema)'
        )
|
|
|
|
|
|
def run():
    """Launch the processor; entry point for the rows-write-cassandra command.

    The module docstring is handed to ``Processor.launch`` together with
    the default processor identity.
    """
    Processor.launch(default_ident, __doc__)
|