mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-28 07:59:37 +02:00
fix: structured data query and auth fixes (#978)
- Pass auth token to schema discovery and descriptor generation in tg-load-structured-data, fixing 401 errors with IAM enabled - Fix row query pagination: replace single-page async_execute with async_scan that streams pages and applies filters without materialising the full result set (OOM on large datasets) - Add missing filter operators (not, startsWith, endsWith, not_in) to row query post-filter matching - Fall back to scan path when an indexed field is queried with an empty string value, since empty index values are not stored - Revert top-level indexes array support — the current table schema overwrites rows with duplicate index values, so only primary_key fields are safe to index until the schema is redesigned
This commit is contained in:
parent
08bfec1539
commit
dbc21c0bb9
4 changed files with 93 additions and 31 deletions
|
|
@ -78,7 +78,7 @@ def load_structured_data(
|
||||||
logger.info("Step 1: Analyzing data to discover best matching schema...")
|
logger.info("Step 1: Analyzing data to discover best matching schema...")
|
||||||
|
|
||||||
# Step 1: Auto-discover schema (reuse discover_schema logic)
|
# Step 1: Auto-discover schema (reuse discover_schema logic)
|
||||||
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
|
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
if not discovered_schema:
|
if not discovered_schema:
|
||||||
logger.error("Failed to discover suitable schema automatically")
|
logger.error("Failed to discover suitable schema automatically")
|
||||||
print("❌ Could not automatically determine the best schema for your data.")
|
print("❌ Could not automatically determine the best schema for your data.")
|
||||||
|
|
@ -90,7 +90,7 @@ def load_structured_data(
|
||||||
|
|
||||||
# Step 2: Auto-generate descriptor
|
# Step 2: Auto-generate descriptor
|
||||||
logger.info("Step 2: Generating descriptor configuration...")
|
logger.info("Step 2: Generating descriptor configuration...")
|
||||||
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
|
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
if not auto_descriptor:
|
if not auto_descriptor:
|
||||||
logger.error("Failed to generate descriptor automatically")
|
logger.error("Failed to generate descriptor automatically")
|
||||||
print("❌ Could not automatically generate descriptor configuration.")
|
print("❌ Could not automatically generate descriptor configuration.")
|
||||||
|
|
@ -172,7 +172,7 @@ def load_structured_data(
|
||||||
logger.info(f"Sample chars: {sample_chars} characters")
|
logger.info(f"Sample chars: {sample_chars} characters")
|
||||||
|
|
||||||
# Use the helper function to discover schema (get raw response for display)
|
# Use the helper function to discover schema (get raw response for display)
|
||||||
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
|
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
# Debug: print response type and content
|
# Debug: print response type and content
|
||||||
|
|
@ -203,7 +203,7 @@ def load_structured_data(
|
||||||
# If no schema specified, discover it first
|
# If no schema specified, discover it first
|
||||||
if not schema_name:
|
if not schema_name:
|
||||||
logger.info("No schema specified, auto-discovering...")
|
logger.info("No schema specified, auto-discovering...")
|
||||||
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
|
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
if not schema_name:
|
if not schema_name:
|
||||||
print("Error: Could not determine schema automatically.")
|
print("Error: Could not determine schema automatically.")
|
||||||
print("Please specify a schema using --schema-name or run --discover-schema first.")
|
print("Please specify a schema using --schema-name or run --discover-schema first.")
|
||||||
|
|
@ -213,7 +213,7 @@ def load_structured_data(
|
||||||
logger.info(f"Target schema: {schema_name}")
|
logger.info(f"Target schema: {schema_name}")
|
||||||
|
|
||||||
# Generate descriptor using helper function
|
# Generate descriptor using helper function
|
||||||
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
|
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
|
|
||||||
if descriptor:
|
if descriptor:
|
||||||
# Output the generated descriptor
|
# Output the generated descriptor
|
||||||
|
|
@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp
|
||||||
|
|
||||||
|
|
||||||
# Helper functions for auto mode
|
# Helper functions for auto mode
|
||||||
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
|
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"):
|
||||||
"""Auto-discover the best matching schema for the input data
|
"""Auto-discover the best matching schema for the input data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
|
||||||
# Import API modules
|
# Import API modules
|
||||||
from trustgraph.api import Api
|
from trustgraph.api import Api
|
||||||
from trustgraph.api.types import ConfigKey
|
from trustgraph.api.types import ConfigKey
|
||||||
api = Api(api_url, workspace=workspace)
|
api = Api(api_url, token=token, workspace=workspace)
|
||||||
config_api = api.config()
|
config_api = api.config()
|
||||||
|
|
||||||
# Get available schemas
|
# Get available schemas
|
||||||
|
|
@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
|
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"):
|
||||||
"""Auto-generate descriptor configuration for the discovered schema"""
|
"""Auto-generate descriptor configuration for the discovered schema"""
|
||||||
try:
|
try:
|
||||||
# Read sample data
|
# Read sample data
|
||||||
|
|
@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
|
||||||
# Import API modules
|
# Import API modules
|
||||||
from trustgraph.api import Api
|
from trustgraph.api import Api
|
||||||
from trustgraph.api.types import ConfigKey
|
from trustgraph.api.types import ConfigKey
|
||||||
api = Api(api_url, workspace=workspace)
|
api = Api(api_url, token=token, workspace=workspace)
|
||||||
config_api = api.config()
|
config_api = api.config()
|
||||||
|
|
||||||
# Get schema definition
|
# Get schema definition
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||||
from .... schema import Error, RowSchema, Field as SchemaField
|
from .... schema import Error, RowSchema, Field as SchemaField
|
||||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||||
from .... tables.cassandra_async import async_execute
|
from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan
|
||||||
|
|
||||||
from ... graphql import GraphQLSchemaBuilder, SortDirection
|
from ... graphql import GraphQLSchemaBuilder, SortDirection
|
||||||
|
|
||||||
|
|
@ -180,7 +180,7 @@ class Processor(FlowProcessor):
|
||||||
description=field_def.get("description", ""),
|
description=field_def.get("description", ""),
|
||||||
required=field_def.get("required", False),
|
required=field_def.get("required", False),
|
||||||
enum_values=field_def.get("enum", []),
|
enum_values=field_def.get("enum", []),
|
||||||
indexed=field_def.get("indexed", False)
|
indexed=field_def.get("indexed", False),
|
||||||
)
|
)
|
||||||
fields.append(field)
|
fields.append(field)
|
||||||
|
|
||||||
|
|
@ -232,6 +232,8 @@ class Processor(FlowProcessor):
|
||||||
for index_name in index_names:
|
for index_name in index_names:
|
||||||
if index_name in filters:
|
if index_name in filters:
|
||||||
value = filters[index_name]
|
value = filters[index_name]
|
||||||
|
if value == "" or value is None:
|
||||||
|
continue
|
||||||
# Single field index -> single element list
|
# Single field index -> single element list
|
||||||
index_value = [str(value)]
|
index_value = [str(value)]
|
||||||
return (index_name, index_value)
|
return (index_name, index_value)
|
||||||
|
|
@ -282,11 +284,13 @@ class Processor(FlowProcessor):
|
||||||
query += f" LIMIT {limit}"
|
query += f" LIMIT {limit}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rows = await async_execute(self.session, query, params)
|
pages = await async_execute_paged(
|
||||||
for row in rows:
|
self.session, query, params
|
||||||
# Convert data map to dict with proper field names
|
)
|
||||||
row_dict = dict(row.data) if row.data else {}
|
for page in pages:
|
||||||
results.append(row_dict)
|
for row in page:
|
||||||
|
row_dict = dict(row.data) if row.data else {}
|
||||||
|
results.append(row_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to query rows: {e}", exc_info=True)
|
logger.error(f"Failed to query rows: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
@ -308,8 +312,6 @@ class Processor(FlowProcessor):
|
||||||
# Query using the first index (arbitrary choice for scan)
|
# Query using the first index (arbitrary choice for scan)
|
||||||
primary_index = index_names[0]
|
primary_index = index_names[0]
|
||||||
|
|
||||||
# We need to scan all values for this index
|
|
||||||
# This requires ALLOW FILTERING or a different approach
|
|
||||||
query = f"""
|
query = f"""
|
||||||
SELECT data, source FROM {safe_keyspace}.rows
|
SELECT data, source FROM {safe_keyspace}.rows
|
||||||
WHERE collection = %s
|
WHERE collection = %s
|
||||||
|
|
@ -320,17 +322,18 @@ class Processor(FlowProcessor):
|
||||||
params = [collection, schema_name, primary_index]
|
params = [collection, schema_name, primary_index]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rows = await async_execute(self.session, query, params)
|
def row_filter(row):
|
||||||
|
|
||||||
for row in rows:
|
|
||||||
row_dict = dict(row.data) if row.data else {}
|
row_dict = dict(row.data) if row.data else {}
|
||||||
|
return self._matches_filters(row_dict, filters, row_schema)
|
||||||
|
|
||||||
# Apply post-filters
|
matched_rows = await async_scan(
|
||||||
if self._matches_filters(row_dict, filters, row_schema):
|
self.session, query, params,
|
||||||
results.append(row_dict)
|
row_filter=row_filter,
|
||||||
|
limit=limit,
|
||||||
if limit and len(results) >= limit:
|
)
|
||||||
break
|
for row in matched_rows:
|
||||||
|
row_dict = dict(row.data) if row.data else {}
|
||||||
|
results.append(row_dict)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to scan rows: {e}", exc_info=True)
|
logger.error(f"Failed to scan rows: {e}", exc_info=True)
|
||||||
|
|
@ -363,7 +366,7 @@ class Processor(FlowProcessor):
|
||||||
# Parse filter key for operator
|
# Parse filter key for operator
|
||||||
if '_' in filter_key:
|
if '_' in filter_key:
|
||||||
parts = filter_key.rsplit('_', 1)
|
parts = filter_key.rsplit('_', 1)
|
||||||
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
|
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
|
||||||
field_name = parts[0]
|
field_name = parts[0]
|
||||||
operator = parts[1]
|
operator = parts[1]
|
||||||
else:
|
else:
|
||||||
|
|
@ -400,6 +403,18 @@ class Processor(FlowProcessor):
|
||||||
elif operator == 'in':
|
elif operator == 'in':
|
||||||
if str(row_value) not in [str(v) for v in filter_value]:
|
if str(row_value) not in [str(v) for v in filter_value]:
|
||||||
return False
|
return False
|
||||||
|
elif operator == 'not':
|
||||||
|
if str(row_value) == str(filter_value):
|
||||||
|
return False
|
||||||
|
elif operator == 'startsWith':
|
||||||
|
if not str(row_value).startswith(str(filter_value)):
|
||||||
|
return False
|
||||||
|
elif operator == 'endsWith':
|
||||||
|
if not str(row_value).endswith(str(filter_value)):
|
||||||
|
return False
|
||||||
|
elif operator == 'not_in':
|
||||||
|
if str(row_value) in [str(v) for v in filter_value]:
|
||||||
|
return False
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -172,7 +172,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
description=field_def.get("description", ""),
|
description=field_def.get("description", ""),
|
||||||
required=field_def.get("required", False),
|
required=field_def.get("required", False),
|
||||||
enum_values=field_def.get("enum", []),
|
enum_values=field_def.get("enum", []),
|
||||||
indexed=field_def.get("indexed", False)
|
indexed=field_def.get("indexed", False),
|
||||||
)
|
)
|
||||||
fields.append(field)
|
fields.append(field)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -80,14 +80,14 @@ def _set_exception_if_pending(fut, exc):
|
||||||
fut.set_exception(exc)
|
fut.set_exception(exc)
|
||||||
|
|
||||||
|
|
||||||
async def async_execute_paged(session, query, parameters=None, fetch_size=100):
|
async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
|
||||||
"""Execute a CQL query with page-by-page iteration.
|
"""Execute a CQL query with page-by-page iteration.
|
||||||
|
|
||||||
Uses synchronous session.execute() inside run_in_executor so that
|
Uses synchronous session.execute() inside run_in_executor so that
|
||||||
the driver's ResultSet paging works correctly without materialising
|
the driver's ResultSet paging works correctly without materialising
|
||||||
the entire result set in memory.
|
the entire result set in memory.
|
||||||
|
|
||||||
Yields one page of rows at a time (as a list).
|
Returns all pages as a list of lists.
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
|
@ -111,3 +111,50 @@ async def async_execute_paged(session, query, parameters=None, fetch_size=100):
|
||||||
return await loop.run_in_executor(
|
return await loop.run_in_executor(
|
||||||
None, _fetch_all_pages
|
None, _fetch_all_pages
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_scan(
|
||||||
|
session, query, parameters=None, row_filter=None,
|
||||||
|
limit=None, fetch_size=5000,
|
||||||
|
):
|
||||||
|
"""Scan a CQL query page-by-page, applying a filter and limit.
|
||||||
|
|
||||||
|
Only matching rows accumulate in memory. Each page is discarded
|
||||||
|
after processing, so peak memory is bounded by fetch_size plus
|
||||||
|
the number of matching rows (capped by limit).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: cassandra.cluster.Session
|
||||||
|
query: CQL statement string
|
||||||
|
parameters: bind params
|
||||||
|
row_filter: callable(row) -> bool, or None to accept all
|
||||||
|
limit: max results to return, or None for unlimited
|
||||||
|
fetch_size: rows per Cassandra page fetch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching rows.
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
if isinstance(query, str):
|
||||||
|
stmt = SimpleStatement(query, fetch_size=fetch_size)
|
||||||
|
else:
|
||||||
|
stmt = query
|
||||||
|
stmt.fetch_size = fetch_size
|
||||||
|
|
||||||
|
def _scan():
|
||||||
|
results = []
|
||||||
|
result_set = session.execute(stmt, parameters)
|
||||||
|
while True:
|
||||||
|
for row in result_set.current_rows:
|
||||||
|
if row_filter is None or row_filter(row):
|
||||||
|
results.append(row)
|
||||||
|
if limit and len(results) >= limit:
|
||||||
|
return results
|
||||||
|
if result_set.has_more_pages:
|
||||||
|
result_set.fetch_next_page()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return results
|
||||||
|
|
||||||
|
return await loop.run_in_executor(None, _scan)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue