trustgraph/trustgraph-flow/trustgraph/query/rows/cassandra/service.py
Cyber MacGeddon 5bb00f34b9 fix: structured data query and auth fixes
- Pass auth token to schema discovery and descriptor generation in
  tg-load-structured-data, fixing 401 errors with IAM enabled
- Fix row query pagination: replace single-page async_execute with
  async_scan that streams pages and applies filters without
  materialising the full result set (OOM on large datasets)
- Add missing filter operators (not, startsWith, endsWith, not_in)
  to row query post-filter matching
- Fall back to scan path when an indexed field is queried with an
  empty string value, since empty index values are not stored
- Revert top-level indexes array support — the current table schema
  overwrites rows with duplicate index values, so only primary_key
  fields are safe to index until the schema is redesigned
2026-06-08 15:21:12 +01:00

571 lines
19 KiB
Python

"""
Row query service using GraphQL. Input is a GraphQL query with variables.
Output is GraphQL response data with any errors.
Queries against the unified 'rows' table with schema:
- collection: text
- schema_name: text
- index_name: text
- index_value: frozen<list<text>>
- data: map<text, text>
- source: text
"""
import asyncio
import json
import logging
import re
from typing import Dict, Any, Optional, List, Set
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan
from ... graphql import GraphQLSchemaBuilder, SortDirection
# Module-level logger, named after this module
logger = logging.getLogger(__name__)
# Processor identifier used when the params carry no "id" (also passed to
# Processor.launch by run())
default_ident = "rows-query"
# Value used for the request consumer's "concurrency" when none is supplied
default_concurrency = 10
class Processor(FlowProcessor):
def __init__(self, **params):
    """
    Configure the rows-query processor.

    Reads "id", "concurrency", "config_type" and the cassandra_* entries
    from params, resolves the Cassandra settings, registers the request
    consumer, the response producer and the schema config handler, and
    creates empty per-workspace state. No Cassandra connection is opened
    here; that happens in connect_cassandra().
    """
    ident = params.get("id", default_ident)
    worker_count = params.get("concurrency", default_concurrency)

    # resolve_cassandra_config returns the effective settings; only the
    # first three of its five results are used by this service.
    hosts, username, password, _keyspace, _ = resolve_cassandra_config(
        host=params.get("cassandra_host"),
        username=params.get("cassandra_username"),
        password=params.get("cassandra_password"),
    )
    self.cassandra_host = hosts  # used as Cluster contact_points
    self.cassandra_username = username
    self.cassandra_password = password

    # Key under which schemas are looked up in a configuration push
    self.config_key = params.get("config_type", "schema")

    merged_params = params | {
        "id": ident,
        "config_type": self.config_key,
    }
    super(Processor, self).__init__(**merged_params)

    request_spec = ConsumerSpec(
        name="request",
        schema=RowsQueryRequest,
        handler=self.on_message,
        concurrency=worker_count,
    )
    self.register_specification(request_spec)

    response_spec = ProducerSpec(
        name="response",
        schema=RowsQueryResponse,
    )
    self.register_specification(response_spec)

    # NOTE(review): the handler is registered for the literal "schema"
    # type while _apply_schema_config looks up self.config_key - confirm
    # this is intended when --config-type is not "schema".
    self.register_config_handler(self.on_schema_config, types=["schema"])

    # {workspace: {schema name: RowSchema}}
    self.schemas: Dict[str, Dict[str, RowSchema]] = {}

    # Per-workspace schema builders and the schemas they produced
    self.schema_builders: Dict[str, GraphQLSchemaBuilder] = {}
    self.graphql_schemas: Dict[str, Any] = {}

    # Cassandra handles, populated by connect_cassandra()
    self.cluster = None
    self.session = None
    # Guards both connection setup and schema (re)loading
    self._setup_lock = asyncio.Lock()

    # Known keyspaces
    self.known_keyspaces: Set[str] = set()
async def connect_cassandra(self):
    """
    Open the Cassandra session once; later calls return immediately.

    Runs under self._setup_lock. An auth provider is attached only when
    both a username and a password are configured. Any failure is logged
    and re-raised, leaving self.cluster / self.session untouched.
    """
    async with self._setup_lock:
        if self.session:
            return

        cluster_kwargs = {"contact_points": self.cassandra_host}
        try:
            if self.cassandra_username and self.cassandra_password:
                cluster_kwargs["auth_provider"] = PlainTextAuthProvider(
                    username=self.cassandra_username,
                    password=self.cassandra_password
                )
            new_cluster = Cluster(**cluster_kwargs)

            # cluster.connect is a blocking call, so it runs in a worker
            # thread rather than on the event loop
            new_session = await asyncio.to_thread(new_cluster.connect)

            self.cluster = new_cluster
            self.session = new_session
            logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
            raise
def sanitize_name(self, name: str) -> str:
    """
    Return a lower-cased identifier built from name.

    Every character outside [a-zA-Z0-9_] becomes an underscore, and a
    non-empty result that does not start with a letter gets an 'r_'
    prefix. An empty input yields an empty string.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    needs_prefix = bool(cleaned) and not cleaned[0].isalpha()
    if needs_prefix:
        cleaned = 'r_' + cleaned
    return cleaned.lower()
async def on_schema_config(self, workspace, config, version):
    """
    Config-handler callback: reload the schemas of one workspace.

    Logs the incoming version, then applies the configuration while
    holding self._setup_lock.
    """
    announcement = (
        f"Loading schema configuration version {version} "
        f"for workspace {workspace}"
    )
    logger.info(announcement)

    async with self._setup_lock:
        await self._apply_schema_config(workspace, config)
async def _apply_schema_config(self, workspace, config):
    """
    Rebuild the row schemas and the GraphQL schema for one workspace.

    The workspace's schema dict and builder are replaced first. If the
    configuration has no entry for self.config_key, the workspace's
    GraphQL schema is set to None. Otherwise each JSON schema definition
    is parsed into a RowSchema and added to the builder; a definition
    that fails to parse is logged and skipped. Finally the builder
    output is stored, built with self.query_cassandra.
    """
    parsed: Dict[str, RowSchema] = {}
    self.schemas[workspace] = parsed

    gql_builder = GraphQLSchemaBuilder()
    self.schema_builders[workspace] = gql_builder

    if self.config_key not in config:
        logger.warning(
            f"No '{self.config_key}' type in configuration "
            f"for {workspace}"
        )
        self.graphql_schemas[workspace] = None
        return

    for schema_name, schema_json in config[self.config_key].items():
        try:
            schema_def = json.loads(schema_json)

            fields = [
                SchemaField(
                    name=spec["name"],
                    type=spec["type"],
                    size=spec.get("size", 0),
                    primary=spec.get("primary_key", False),
                    description=spec.get("description", ""),
                    required=spec.get("required", False),
                    enum_values=spec.get("enum", []),
                    indexed=spec.get("indexed", False),
                )
                for spec in schema_def.get("fields", [])
            ]

            row_schema = RowSchema(
                name=schema_def.get("name", schema_name),
                description=schema_def.get("description", ""),
                fields=fields
            )

            parsed[schema_name] = row_schema
            gql_builder.add_schema(schema_name, row_schema)
            logger.info(
                f"Loaded schema: {schema_name} with "
                f"{len(fields)} fields for {workspace}"
            )
        except Exception as e:
            # One bad definition must not prevent the others from loading
            logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

    logger.info(
        f"Schema configuration loaded for {workspace}: "
        f"{len(parsed)} schemas"
    )

    self.graphql_schemas[workspace] = gql_builder.build(self.query_cassandra)
def get_index_names(self, schema: RowSchema) -> List[str]:
    """
    List the names of the fields flagged primary or indexed, in the
    order they appear in schema.fields.
    """
    return [
        candidate.name
        for candidate in schema.fields
        if candidate.primary or candidate.indexed
    ]
def find_matching_index(
    self,
    schema: RowSchema,
    filters: Dict[str, Any]
) -> Optional[tuple]:
    """
    Pick the first index that an exact-match filter can use.

    Walks the schema's index names in order and returns
    (index_name, [str(value)]) for the first one that appears as a
    filter key with a usable value. A filter whose value is None or the
    empty string is passed over. Returns None when no index qualifies.
    """
    for candidate in self.get_index_names(schema):
        if candidate not in filters:
            continue

        wanted = filters[candidate]
        if wanted == "" or wanted is None:
            continue

        # Single-field index: the lookup value is a one-element list
        return (candidate, [str(wanted)])

    return None
async def query_cassandra(
    self,
    workspace: str,
    collection: str,
    schema_name: str,
    row_schema: RowSchema,
    filters: Dict[str, Any],
    limit: int,
    order_by: Optional[str] = None,
    direction: Optional[SortDirection] = None
) -> List[Dict[str, Any]]:
    """
    Execute a query against the unified Cassandra rows table.

    When a filter is an exact match on an indexed field the rows are
    fetched directly by index value; any remaining filters are then
    applied to the fetched rows. Otherwise every row stored under the
    schema's first index is scanned and post-filtered.

    Args:
        workspace: Workspace name; sanitised into the keyspace name.
        collection: Collection to query.
        schema_name: Name of the schema the rows belong to.
        row_schema: Schema definition, used to find indexes and to
            interpret filters.
        filters: Filter key -> value mapping; None values are ignored.
        limit: Maximum number of rows to return; falsy means no limit.
        order_by: Optional data key to sort the returned rows by.
        direction: Optional sort direction; a value of "desc" reverses.

    Returns:
        A list of row data dicts.

    Raises:
        Any exception raised while connecting or querying, after logging.
    """
    # Connect if needed
    await self.connect_cassandra()

    safe_keyspace = self.sanitize_name(workspace)

    # Try to find an index that matches the filters
    index_match = self.find_matching_index(row_schema, filters)

    results: List[Dict[str, Any]] = []

    if index_match:
        # Direct query using index
        index_name, index_value = index_match

        # Filters other than the matched index key are not part of the
        # CQL WHERE clause, so they have to be checked on the fetched
        # rows - otherwise they would be silently ignored.
        residual_filters = {
            key: value
            for key, value in filters.items()
            if key != index_name and value is not None
        }

        query = f"""
            SELECT data, source FROM {safe_keyspace}.rows
            WHERE collection = %s
            AND schema_name = %s
            AND index_name = %s
            AND index_value = %s
        """
        params = [collection, schema_name, index_name, index_value]

        # A CQL LIMIT is only safe when nothing is filtered afterwards;
        # with residual filters the limit is applied after filtering.
        if limit and not residual_filters:
            query += f" LIMIT {int(limit)}"

        try:
            pages = await async_execute_paged(
                self.session, query, params
            )
            for page in pages:
                for row in page:
                    row_dict = dict(row.data) if row.data else {}
                    if residual_filters and not self._matches_filters(
                        row_dict, residual_filters, row_schema
                    ):
                        continue
                    results.append(row_dict)
        except Exception as e:
            logger.error(f"Failed to query rows: {e}", exc_info=True)
            raise

        if limit and residual_filters:
            results = results[:int(limit)]

    else:
        # No direct index match - scan all rows for this schema
        # This is less efficient but necessary for non-indexed queries
        logger.warning(
            f"No index match for filters {filters} - scanning all indexes"
        )

        # Get all index names for this schema
        index_names = self.get_index_names(row_schema)
        if not index_names:
            logger.warning(f"Schema {schema_name} has no indexes")
            return []

        # Query using the first index (arbitrary choice for scan)
        primary_index = index_names[0]
        query = f"""
            SELECT data, source FROM {safe_keyspace}.rows
            WHERE collection = %s
            AND schema_name = %s
            AND index_name = %s
            ALLOW FILTERING
        """
        params = [collection, schema_name, primary_index]

        try:
            def row_filter(row):
                row_dict = dict(row.data) if row.data else {}
                return self._matches_filters(row_dict, filters, row_schema)

            matched_rows = await async_scan(
                self.session, query, params,
                row_filter=row_filter,
                limit=limit,
            )
            for row in matched_rows:
                row_dict = dict(row.data) if row.data else {}
                results.append(row_dict)
        except Exception as e:
            logger.error(f"Failed to scan rows: {e}", exc_info=True)
            raise

    # Post-query sorting if requested.
    # NOTE(review): sorting happens after the limit has been applied and
    # compares the stored values as-is - confirm this is the intended
    # ordering semantics.
    if order_by and results:
        # Must be a real bool: with direction=None the bare `and`
        # expression is None, which list.sort(reverse=...) rejects on
        # some interpreter versions, leaving the results unsorted.
        reverse_order = bool(direction and direction.value == "desc")
        try:
            results.sort(
                key=lambda x: x.get(order_by, ""),
                reverse=reverse_order
            )
        except Exception as e:
            logger.warning(f"Failed to sort results by {order_by}: {e}")

    return results
def _matches_filters(
self,
row_dict: Dict[str, Any],
filters: Dict[str, Any],
row_schema: RowSchema
) -> bool:
"""Check if a row matches the given filters."""
for filter_key, filter_value in filters.items():
if filter_value is None:
continue
# Parse filter key for operator
if '_' in filter_key:
parts = filter_key.rsplit('_', 1)
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
field_name = parts[0]
operator = parts[1]
else:
field_name = filter_key
operator = 'eq'
else:
field_name = filter_key
operator = 'eq'
row_value = row_dict.get(field_name)
if row_value is None:
return False
# Convert types for comparison
try:
if operator == 'eq':
if str(row_value) != str(filter_value):
return False
elif operator == 'gt':
if float(row_value) <= float(filter_value):
return False
elif operator == 'gte':
if float(row_value) < float(filter_value):
return False
elif operator == 'lt':
if float(row_value) >= float(filter_value):
return False
elif operator == 'lte':
if float(row_value) > float(filter_value):
return False
elif operator == 'contains':
if str(filter_value) not in str(row_value):
return False
elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]:
return False
elif operator == 'not':
if str(row_value) == str(filter_value):
return False
elif operator == 'startsWith':
if not str(row_value).startswith(str(filter_value)):
return False
elif operator == 'endsWith':
if not str(row_value).endswith(str(filter_value)):
return False
elif operator == 'not_in':
if str(row_value) in [str(v) for v in filter_value]:
return False
except (ValueError, TypeError):
return False
return True
async def execute_graphql_query(
    self,
    workspace: str,
    query: str,
    variables: Dict[str, Any],
    operation_name: Optional[str],
    collection: str
) -> Dict[str, Any]:
    """
    Execute a GraphQL query against the workspace's schema.

    Args:
        workspace: Workspace whose compiled schema is used.
        query: GraphQL query text.
        variables: Variable values for the query.
        operation_name: Operation to run, or None.
        collection: Collection name, passed to resolvers via the context.

    Returns:
        A dict with "data" (None when the result carries no data),
        "errors" (a list of {"message", "path", "extensions"} dicts,
        empty when there are none) and, only when the result has any,
        "extensions".

    Raises:
        RuntimeError: if no schema is loaded for the workspace.
    """
    graphql_schema = self.graphql_schemas.get(workspace)
    if not graphql_schema:
        raise RuntimeError(
            f"No GraphQL schema available for workspace {workspace} "
            f"- no schemas loaded"
        )

    # Create context for the query
    context = {
        "processor": self,
        "workspace": workspace,
        "collection": collection
    }

    # Execute the query
    result = await graphql_schema.execute(
        query,
        variable_values=variables,
        operation_name=operation_name,
        context_value=context
    )

    # Build response. `or` is used rather than a getattr default because
    # an error may carry path / extensions attributes that are present
    # but None; downstream code expects a list and a dict.
    errors = [
        {
            "message": str(error),
            "path": getattr(error, "path", None) or [],
            "extensions": getattr(error, "extensions", None) or {}
        }
        for error in (result.errors or [])
    ]

    response = {
        "data": result.data if result.data else None,
        "errors": errors,
    }

    # Add extensions if any
    if getattr(result, "extensions", None):
        response["extensions"] = result.extensions

    return response
async def on_message(self, msg, consumer, flow):
    """
    Handle one incoming query request and send exactly one response.

    Runs the request's GraphQL query for flow.workspace and sends a
    RowsQueryResponse tagged with the sender's "id" property. If
    anything fails after the id has been read, an error response with
    type "rows-query-error" is sent instead. If the failure happens
    before the id is available, no response can be correlated with the
    request, so the original exception is re-raised.
    """
    # Tracked explicitly so the error path never reads an unbound name
    request_id = None
    have_request_id = False

    try:
        request = msg.value()

        # Sender-produced ID
        request_id = msg.properties()["id"]
        have_request_id = True

        logger.debug(f"Handling objects query request {request_id}...")

        # Execute GraphQL query
        result = await self.execute_graphql_query(
            workspace=flow.workspace,
            query=request.query,
            variables=dict(request.variables) if request.variables else {},
            operation_name=request.operation_name,
            collection=request.collection
        )

        # Create response
        graphql_errors = [
            GraphQLError(
                message=err.get("message", ""),
                path=err.get("path", []),
                extensions=err.get("extensions", {})
            )
            for err in (result.get("errors") or [])
        ]

        data = result.get("data")
        response = RowsQueryResponse(
            error=None,
            data=json.dumps(data) if data else "null",
            errors=graphql_errors,
            extensions=result.get("extensions", {})
        )

        logger.debug("Sending objects query response...")
        await flow("response").send(response, properties={"id": request_id})
        logger.debug("Objects query request completed")

    except Exception as e:
        logger.error(f"Exception in rows query service: {e}", exc_info=True)

        if not have_request_id:
            # Surface the real failure instead of a secondary error
            # raised while building the response properties.
            raise

        logger.info("Sending error response...")
        response = RowsQueryResponse(
            error=Error(
                type="rows-query-error",
                message=str(e),
            ),
            data=None,
            errors=[],
            extensions={}
        )
        await flow("response").send(response, properties={"id": request_id})
def close(self):
    """Shut down the Cassandra cluster handle, if one was created."""
    active_cluster = self.cluster
    if not active_cluster:
        return

    active_cluster.shutdown()
    logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
    """
    Add command-line arguments to parser.

    Adds the FlowProcessor arguments and the Cassandra arguments first,
    then this service's own --config-type and -c/--concurrency options.
    """
    FlowProcessor.add_args(parser)
    add_cassandra_args(parser)
    # Configuration type under which schemas are looked up
    parser.add_argument(
        '--config-type',
        default='schema',
        help='Configuration type prefix for schemas (default: schema)'
    )
    # Integer option; defaults to the module-level default_concurrency
    parser.add_argument(
        '-c', '--concurrency',
        type=int,
        default=default_concurrency,
        help=f'Number of concurrent requests (default: {default_concurrency})'
    )
def run():
    """Entry point for rows-query-cassandra command"""
    # Launch the processor under the default identifier, passing this
    # module's docstring along.
    Processor.launch(default_ident, __doc__)