fix: structured data query and auth fixes

- Pass auth token to schema discovery and descriptor generation in
  tg-load-structured-data, fixing 401 errors with IAM enabled
- Fix row query pagination: replace single-page async_execute with
  async_scan that streams pages and applies filters without
  materialising the full result set (OOM on large datasets)
- Add missing filter operators (not, startsWith, endsWith, not_in)
  to row query post-filter matching
- Fall back to scan path when an indexed field is queried with an
  empty string value, since empty index values are not stored
- Revert top-level indexes array support — the current table schema
  overwrites rows with duplicate index values, so only primary_key
  fields are safe to index until the schema is redesigned
This commit is contained in:
Cyber MacGeddon 2026-06-08 13:49:09 +01:00
parent 08bfec1539
commit 5bb00f34b9
4 changed files with 93 additions and 31 deletions

View file

@ -78,7 +78,7 @@ def load_structured_data(
logger.info("Step 1: Analyzing data to discover best matching schema...")
# Step 1: Auto-discover schema (reuse discover_schema logic)
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
if not discovered_schema:
logger.error("Failed to discover suitable schema automatically")
print("❌ Could not automatically determine the best schema for your data.")
@ -90,7 +90,7 @@ def load_structured_data(
# Step 2: Auto-generate descriptor
logger.info("Step 2: Generating descriptor configuration...")
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace)
if not auto_descriptor:
logger.error("Failed to generate descriptor automatically")
print("❌ Could not automatically generate descriptor configuration.")
@ -172,7 +172,7 @@ def load_structured_data(
logger.info(f"Sample chars: {sample_chars} characters")
# Use the helper function to discover schema (get raw response for display)
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace)
if response:
# Debug: print response type and content
@ -203,7 +203,7 @@ def load_structured_data(
# If no schema specified, discover it first
if not schema_name:
logger.info("No schema specified, auto-discovering...")
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
if not schema_name:
print("Error: Could not determine schema automatically.")
print("Please specify a schema using --schema-name or run --discover-schema first.")
@ -213,7 +213,7 @@ def load_structured_data(
logger.info(f"Target schema: {schema_name}")
# Generate descriptor using helper function
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace)
if descriptor:
# Output the generated descriptor
@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp
# Helper functions for auto mode
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"):
"""Auto-discover the best matching schema for the input data
Args:
@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
api = Api(api_url, workspace=workspace)
api = Api(api_url, token=token, workspace=workspace)
config_api = api.config()
# Get available schemas
@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
return None
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"):
"""Auto-generate descriptor configuration for the discovered schema"""
try:
# Read sample data
@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
api = Api(api_url, workspace=workspace)
api = Api(api_url, token=token, workspace=workspace)
config_api = api.config()
# Get schema definition

View file

@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... tables.cassandra_async import async_execute
from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan
from ... graphql import GraphQLSchemaBuilder, SortDirection
@ -180,7 +180,7 @@ class Processor(FlowProcessor):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
indexed=field_def.get("indexed", False),
)
fields.append(field)
@ -232,6 +232,8 @@ class Processor(FlowProcessor):
for index_name in index_names:
if index_name in filters:
value = filters[index_name]
if value == "" or value is None:
continue
# Single field index -> single element list
index_value = [str(value)]
return (index_name, index_value)
@ -282,11 +284,13 @@ class Processor(FlowProcessor):
query += f" LIMIT {limit}"
try:
rows = await async_execute(self.session, query, params)
for row in rows:
# Convert data map to dict with proper field names
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
pages = await async_execute_paged(
self.session, query, params
)
for page in pages:
for row in page:
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
except Exception as e:
logger.error(f"Failed to query rows: {e}", exc_info=True)
raise
@ -308,8 +312,6 @@ class Processor(FlowProcessor):
# Query using the first index (arbitrary choice for scan)
primary_index = index_names[0]
# We need to scan all values for this index
# This requires ALLOW FILTERING or a different approach
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
@ -320,17 +322,18 @@ class Processor(FlowProcessor):
params = [collection, schema_name, primary_index]
try:
rows = await async_execute(self.session, query, params)
for row in rows:
def row_filter(row):
row_dict = dict(row.data) if row.data else {}
return self._matches_filters(row_dict, filters, row_schema)
# Apply post-filters
if self._matches_filters(row_dict, filters, row_schema):
results.append(row_dict)
if limit and len(results) >= limit:
break
matched_rows = await async_scan(
self.session, query, params,
row_filter=row_filter,
limit=limit,
)
for row in matched_rows:
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
except Exception as e:
logger.error(f"Failed to scan rows: {e}", exc_info=True)
@ -363,7 +366,7 @@ class Processor(FlowProcessor):
# Parse filter key for operator
if '_' in filter_key:
parts = filter_key.rsplit('_', 1)
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
field_name = parts[0]
operator = parts[1]
else:
@ -400,6 +403,18 @@ class Processor(FlowProcessor):
elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]:
return False
elif operator == 'not':
if str(row_value) == str(filter_value):
return False
elif operator == 'startsWith':
if not str(row_value).startswith(str(filter_value)):
return False
elif operator == 'endsWith':
if not str(row_value).endswith(str(filter_value)):
return False
elif operator == 'not_in':
if str(row_value) in [str(v) for v in filter_value]:
return False
except (ValueError, TypeError):
return False

View file

@ -172,7 +172,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
indexed=field_def.get("indexed", False),
)
fields.append(field)

View file

@ -80,14 +80,14 @@ def _set_exception_if_pending(fut, exc):
fut.set_exception(exc)
async def async_execute_paged(session, query, parameters=None, fetch_size=100):
async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
"""Execute a CQL query with page-by-page iteration.
Uses synchronous session.execute() inside run_in_executor so that
the driver's ResultSet paging works correctly without materialising
the entire result set in memory.
Yields one page of rows at a time (as a list).
Returns all pages as a list of lists.
"""
loop = asyncio.get_running_loop()
@ -111,3 +111,50 @@ async def async_execute_paged(session, query, parameters=None, fetch_size=100):
return await loop.run_in_executor(
None, _fetch_all_pages
)
async def async_scan(
session, query, parameters=None, row_filter=None,
limit=None, fetch_size=5000,
):
"""Scan a CQL query page-by-page, applying a filter and limit.
Only matching rows accumulate in memory. Each page is discarded
after processing, so peak memory is bounded by fetch_size plus
the number of matching rows (capped by limit).
Args:
session: cassandra.cluster.Session
query: CQL statement string
parameters: bind params
row_filter: callable(row) -> bool, or None to accept all
limit: max results to return, or None for unlimited
fetch_size: rows per Cassandra page fetch
Returns:
List of matching rows.
"""
loop = asyncio.get_running_loop()
if isinstance(query, str):
stmt = SimpleStatement(query, fetch_size=fetch_size)
else:
stmt = query
stmt.fetch_size = fetch_size
def _scan():
results = []
result_set = session.execute(stmt, parameters)
while True:
for row in result_set.current_rows:
if row_filter is None or row_filter(row):
results.append(row)
if limit and len(results) >= limit:
return results
if result_set.has_more_pages:
result_set.fetch_next_page()
else:
break
return results
return await loop.run_in_executor(None, _scan)