mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-10 15:25:14 +02:00
fix: structured data query and auth fixes
- Pass auth token to schema discovery and descriptor generation in tg-load-structured-data, fixing 401 errors with IAM enabled
- Fix row query pagination: replace single-page async_execute with async_scan that streams pages and applies filters without materialising the full result set (OOM on large datasets)
- Add missing filter operators (not, startsWith, endsWith, not_in) to row query post-filter matching
- Fall back to scan path when an indexed field is queried with an empty string value, since empty index values are not stored
- Revert top-level indexes array support — the current table schema overwrites rows with duplicate index values, so only primary_key fields are safe to index until the schema is redesigned
This commit is contained in:
parent
08bfec1539
commit
5bb00f34b9
4 changed files with 93 additions and 31 deletions
|
|
@ -78,7 +78,7 @@ def load_structured_data(
|
|||
logger.info("Step 1: Analyzing data to discover best matching schema...")
|
||||
|
||||
# Step 1: Auto-discover schema (reuse discover_schema logic)
|
||||
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
|
||||
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||
if not discovered_schema:
|
||||
logger.error("Failed to discover suitable schema automatically")
|
||||
print("❌ Could not automatically determine the best schema for your data.")
|
||||
|
|
@ -90,7 +90,7 @@ def load_structured_data(
|
|||
|
||||
# Step 2: Auto-generate descriptor
|
||||
logger.info("Step 2: Generating descriptor configuration...")
|
||||
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
|
||||
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||
if not auto_descriptor:
|
||||
logger.error("Failed to generate descriptor automatically")
|
||||
print("❌ Could not automatically generate descriptor configuration.")
|
||||
|
|
@ -172,7 +172,7 @@ def load_structured_data(
|
|||
logger.info(f"Sample chars: {sample_chars} characters")
|
||||
|
||||
# Use the helper function to discover schema (get raw response for display)
|
||||
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
|
||||
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace)
|
||||
|
||||
if response:
|
||||
# Debug: print response type and content
|
||||
|
|
@ -203,7 +203,7 @@ def load_structured_data(
|
|||
# If no schema specified, discover it first
|
||||
if not schema_name:
|
||||
logger.info("No schema specified, auto-discovering...")
|
||||
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
|
||||
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||
if not schema_name:
|
||||
print("Error: Could not determine schema automatically.")
|
||||
print("Please specify a schema using --schema-name or run --discover-schema first.")
|
||||
|
|
@ -213,7 +213,7 @@ def load_structured_data(
|
|||
logger.info(f"Target schema: {schema_name}")
|
||||
|
||||
# Generate descriptor using helper function
|
||||
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
|
||||
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||
|
||||
if descriptor:
|
||||
# Output the generated descriptor
|
||||
|
|
@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp
|
|||
|
||||
|
||||
# Helper functions for auto mode
|
||||
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
|
||||
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"):
|
||||
"""Auto-discover the best matching schema for the input data
|
||||
|
||||
Args:
|
||||
|
|
@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
|
|||
# Import API modules
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api.types import ConfigKey
|
||||
api = Api(api_url, workspace=workspace)
|
||||
api = Api(api_url, token=token, workspace=workspace)
|
||||
config_api = api.config()
|
||||
|
||||
# Get available schemas
|
||||
|
|
@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
|
|||
return None
|
||||
|
||||
|
||||
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
|
||||
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"):
|
||||
"""Auto-generate descriptor configuration for the discovered schema"""
|
||||
try:
|
||||
# Read sample data
|
||||
|
|
@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
|
|||
# Import API modules
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api.types import ConfigKey
|
||||
api = Api(api_url, workspace=workspace)
|
||||
api = Api(api_url, token=token, workspace=workspace)
|
||||
config_api = api.config()
|
||||
|
||||
# Get schema definition
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
|||
from .... schema import Error, RowSchema, Field as SchemaField
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
from .... tables.cassandra_async import async_execute
|
||||
from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan
|
||||
|
||||
from ... graphql import GraphQLSchemaBuilder, SortDirection
|
||||
|
||||
|
|
@ -180,7 +180,7 @@ class Processor(FlowProcessor):
|
|||
description=field_def.get("description", ""),
|
||||
required=field_def.get("required", False),
|
||||
enum_values=field_def.get("enum", []),
|
||||
indexed=field_def.get("indexed", False)
|
||||
indexed=field_def.get("indexed", False),
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
|
|
@ -232,6 +232,8 @@ class Processor(FlowProcessor):
|
|||
for index_name in index_names:
|
||||
if index_name in filters:
|
||||
value = filters[index_name]
|
||||
if value == "" or value is None:
|
||||
continue
|
||||
# Single field index -> single element list
|
||||
index_value = [str(value)]
|
||||
return (index_name, index_value)
|
||||
|
|
@ -282,11 +284,13 @@ class Processor(FlowProcessor):
|
|||
query += f" LIMIT {limit}"
|
||||
|
||||
try:
|
||||
rows = await async_execute(self.session, query, params)
|
||||
for row in rows:
|
||||
# Convert data map to dict with proper field names
|
||||
row_dict = dict(row.data) if row.data else {}
|
||||
results.append(row_dict)
|
||||
pages = await async_execute_paged(
|
||||
self.session, query, params
|
||||
)
|
||||
for page in pages:
|
||||
for row in page:
|
||||
row_dict = dict(row.data) if row.data else {}
|
||||
results.append(row_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query rows: {e}", exc_info=True)
|
||||
raise
|
||||
|
|
@ -308,8 +312,6 @@ class Processor(FlowProcessor):
|
|||
# Query using the first index (arbitrary choice for scan)
|
||||
primary_index = index_names[0]
|
||||
|
||||
# We need to scan all values for this index
|
||||
# This requires ALLOW FILTERING or a different approach
|
||||
query = f"""
|
||||
SELECT data, source FROM {safe_keyspace}.rows
|
||||
WHERE collection = %s
|
||||
|
|
@ -320,17 +322,18 @@ class Processor(FlowProcessor):
|
|||
params = [collection, schema_name, primary_index]
|
||||
|
||||
try:
|
||||
rows = await async_execute(self.session, query, params)
|
||||
|
||||
for row in rows:
|
||||
def row_filter(row):
|
||||
row_dict = dict(row.data) if row.data else {}
|
||||
return self._matches_filters(row_dict, filters, row_schema)
|
||||
|
||||
# Apply post-filters
|
||||
if self._matches_filters(row_dict, filters, row_schema):
|
||||
results.append(row_dict)
|
||||
|
||||
if limit and len(results) >= limit:
|
||||
break
|
||||
matched_rows = await async_scan(
|
||||
self.session, query, params,
|
||||
row_filter=row_filter,
|
||||
limit=limit,
|
||||
)
|
||||
for row in matched_rows:
|
||||
row_dict = dict(row.data) if row.data else {}
|
||||
results.append(row_dict)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to scan rows: {e}", exc_info=True)
|
||||
|
|
@ -363,7 +366,7 @@ class Processor(FlowProcessor):
|
|||
# Parse filter key for operator
|
||||
if '_' in filter_key:
|
||||
parts = filter_key.rsplit('_', 1)
|
||||
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
|
||||
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
|
||||
field_name = parts[0]
|
||||
operator = parts[1]
|
||||
else:
|
||||
|
|
@ -400,6 +403,18 @@ class Processor(FlowProcessor):
|
|||
elif operator == 'in':
|
||||
if str(row_value) not in [str(v) for v in filter_value]:
|
||||
return False
|
||||
elif operator == 'not':
|
||||
if str(row_value) == str(filter_value):
|
||||
return False
|
||||
elif operator == 'startsWith':
|
||||
if not str(row_value).startswith(str(filter_value)):
|
||||
return False
|
||||
elif operator == 'endsWith':
|
||||
if not str(row_value).endswith(str(filter_value)):
|
||||
return False
|
||||
elif operator == 'not_in':
|
||||
if str(row_value) in [str(v) for v in filter_value]:
|
||||
return False
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
|||
description=field_def.get("description", ""),
|
||||
required=field_def.get("required", False),
|
||||
enum_values=field_def.get("enum", []),
|
||||
indexed=field_def.get("indexed", False)
|
||||
indexed=field_def.get("indexed", False),
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
|
|
|
|||
|
|
@ -80,14 +80,14 @@ def _set_exception_if_pending(fut, exc):
|
|||
fut.set_exception(exc)
|
||||
|
||||
|
||||
async def async_execute_paged(session, query, parameters=None, fetch_size=100):
|
||||
async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
|
||||
"""Execute a CQL query with page-by-page iteration.
|
||||
|
||||
Uses synchronous session.execute() inside run_in_executor so that
|
||||
the driver's ResultSet paging works correctly without materialising
|
||||
the entire result set in memory.
|
||||
|
||||
Yields one page of rows at a time (as a list).
|
||||
Returns all pages as a list of lists.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
|
|
@ -111,3 +111,50 @@ async def async_execute_paged(session, query, parameters=None, fetch_size=100):
|
|||
return await loop.run_in_executor(
|
||||
None, _fetch_all_pages
|
||||
)
|
||||
|
||||
|
||||
async def async_scan(
    session, query, parameters=None, row_filter=None,
    limit=None, fetch_size=5000,
):
    """Walk a CQL result set one page at a time, keeping only matching rows.

    The query is run with the synchronous ``session.execute()`` inside the
    event loop's default executor.  Rows are inspected page by page; only
    rows accepted by ``row_filter`` are retained, and the walk stops as
    soon as ``limit`` rows have been collected.

    Args:
        session: object providing ``execute(statement, parameters)``
            (a cassandra.cluster.Session in production use)
        query: CQL text, or an already-built statement object.  Text is
            wrapped in a ``SimpleStatement``; a statement object has its
            ``fetch_size`` attribute overwritten in place.
        parameters: bind parameters handed to ``session.execute()``
        row_filter: callable(row) -> bool; ``None`` accepts every row
        limit: maximum number of rows to return; any falsy value
            (``None`` or ``0``) means no cap
        fetch_size: number of rows requested per page

    Returns:
        List of rows accepted by the filter, in the order they were read.
    """
    if isinstance(query, str):
        statement = SimpleStatement(query, fetch_size=fetch_size)
    else:
        # Caller supplied a statement object: reuse it, but force our
        # page size onto it.
        statement = query
        statement.fetch_size = fetch_size

    def _collect_matches():
        # Runs on an executor thread, so blocking driver calls are fine.
        matched = []
        cursor = session.execute(statement, parameters)

        more_pages = True
        while more_pages:
            for candidate in cursor.current_rows:
                if row_filter is not None and not row_filter(candidate):
                    continue
                matched.append(candidate)
                # Stop early once the cap is reached so later pages are
                # never fetched.
                if limit and len(matched) >= limit:
                    return matched

            more_pages = cursor.has_more_pages
            if more_pages:
                cursor.fetch_next_page()

        return matched

    event_loop = asyncio.get_running_loop()
    return await event_loop.run_in_executor(None, _collect_matches)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue