mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-13 01:02:37 +02:00
* Fix publisher resource leak in librarian submit_document (#883) Wrap pub.start()/pub.send() in try/finally to guarantee pub.stop() is called on error. Remove unnecessary asyncio.sleep(1) kludge. * Make Cassandra replication factor configurable (issue #787) (#887) Add CASSANDRA_REPLICATION_FACTOR environment variable and --cassandra-replication-factor CLI argument to cassandra_config.py. Update all four table store constructors (ConfigTableStore, KnowledgeTableStore, LibraryTableStore, IamTableStore) to accept an optional replication_factor parameter and use it in keyspace creation CQL queries. Thread the replication factor through all service constructors: Configuration, KnowledgeManager, Librarian, IamService, and knowledge store Processor. * Update tests --------- Co-authored-by: gittihub-jpg <rico@springer-mail.net>
554 lines
18 KiB
Python
"""
|
|
Row query service using GraphQL. Input is a GraphQL query with variables.
|
|
Output is GraphQL response data with any errors.
|
|
|
|
Queries against the unified 'rows' table with schema:
|
|
- collection: text
|
|
- schema_name: text
|
|
- index_name: text
|
|
- index_value: frozen<list<text>>
|
|
- data: map<text, text>
|
|
- source: text
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from typing import Dict, Any, Optional, List, Set
|
|
|
|
from cassandra.cluster import Cluster
|
|
from cassandra.auth import PlainTextAuthProvider
|
|
|
|
from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
|
from .... schema import Error, RowSchema, Field as SchemaField
|
|
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
|
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
|
from .... tables.cassandra_async import async_execute
|
|
|
|
from ... graphql import GraphQLSchemaBuilder, SortDirection
|
|
|
|
# Module logger
logger = logging.getLogger(__name__)

# Default service identifier used when no "id" parameter is supplied.
default_ident = "rows-query"

# Default number of concurrently handled query requests.
default_concurrency = 10
|
|
|
class Processor(FlowProcessor):
    """
    Row query service using GraphQL.

    Consumes RowsQueryRequest messages, executes the contained GraphQL
    query against the per-workspace unified Cassandra 'rows' table, and
    produces RowsQueryResponse messages.  One keyspace per workspace;
    one compiled GraphQL schema per workspace, rebuilt on schema
    configuration updates.
    """

    def __init__(self, **params):
        """
        Initialise the processor.

        Recognised params (all optional): id, concurrency,
        cassandra_host, cassandra_username, cassandra_password,
        config_type.  Cassandra settings fall back to environment
        variables via resolve_cassandra_config().
        """

        ident = params.get("id", default_ident)
        concurrency = params.get("concurrency", default_concurrency)

        # Raw Cassandra parameters as supplied by CLI/params (may be None).
        cassandra_host = params.get("cassandra_host")
        cassandra_username = params.get("cassandra_username")
        cassandra_password = params.get("cassandra_password")

        # Resolve configuration with environment variable fallback.
        # The keyspace component is ignored here: the keyspace used at
        # query time is derived from the workspace name instead.
        hosts, username, password, keyspace, _ = resolve_cassandra_config(
            host=cassandra_host,
            username=cassandra_username,
            password=cassandra_password
        )

        # Store resolved configuration with proper names
        self.cassandra_host = hosts  # list of contact points
        self.cassandra_username = username
        self.cassandra_password = password

        # Config key under which row schemas are published
        self.config_key = params.get("config_type", "schema")

        super(Processor, self).__init__(
            **params | {
                "id": ident,
                "config_type": self.config_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="request",
                schema=RowsQueryRequest,
                handler=self.on_message,
                concurrency=concurrency,
            )
        )

        self.register_specification(
            ProducerSpec(
                name="response",
                schema=RowsQueryResponse,
            )
        )

        # Register config handler for schema updates
        self.register_config_handler(self.on_schema_config, types=["schema"])

        # Per-workspace schema storage: {workspace: {name: RowSchema}}
        self.schemas: Dict[str, Dict[str, RowSchema]] = {}

        # Per-workspace GraphQL schema builders and compiled schemas
        self.schema_builders: Dict[str, GraphQLSchemaBuilder] = {}
        self.graphql_schemas: Dict[str, Any] = {}

        # Cassandra cluster/session, created lazily on first query
        self.cluster = None
        self.session = None

        # Keyspaces seen so far (reserved for keyspace management)
        self.known_keyspaces: Set[str] = set()

    def connect_cassandra(self):
        """
        Connect to the Cassandra cluster (idempotent).

        Uses PlainTextAuthProvider when both username and password are
        configured; otherwise connects unauthenticated.  Raises on
        connection failure after logging.
        """
        if self.session:
            return

        try:
            if self.cassandra_username and self.cassandra_password:
                auth_provider = PlainTextAuthProvider(
                    username=self.cassandra_username,
                    password=self.cassandra_password
                )
                self.cluster = Cluster(
                    contact_points=self.cassandra_host,
                    auth_provider=auth_provider
                )
            else:
                self.cluster = Cluster(contact_points=self.cassandra_host)

            self.session = self.cluster.connect()
            logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")

        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
            raise

    def sanitize_name(self, name: str) -> str:
        """
        Sanitize a name for Cassandra identifier compatibility.

        Replaces non-alphanumeric characters with underscores, prefixes
        'r_' when the result does not start with a letter, and lowercases.
        """
        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        if safe_name and not safe_name[0].isalpha():
            safe_name = 'r_' + safe_name
        return safe_name.lower()

    async def on_schema_config(self, workspace, config, version):
        """
        Handle a schema configuration update for a workspace.

        Replaces the workspace's schema set, rebuilds its GraphQL schema
        builder, and compiles a fresh GraphQL schema.  Schemas that fail
        to parse are logged and skipped; they do not abort the update.
        """
        logger.info(
            f"Loading schema configuration version {version} "
            f"for workspace {workspace}"
        )

        # Replace existing schemas for this workspace
        ws_schemas: Dict[str, RowSchema] = {}
        self.schemas[workspace] = ws_schemas

        builder = GraphQLSchemaBuilder()
        self.schema_builders[workspace] = builder

        # Check if our config type exists
        if self.config_key not in config:
            logger.warning(
                f"No '{self.config_key}' type in configuration "
                f"for {workspace}"
            )
            self.graphql_schemas[workspace] = None
            return

        # Get the schemas dictionary for our type
        schemas_config = config[self.config_key]

        # Process each schema in the schemas config
        for schema_name, schema_json in schemas_config.items():
            try:
                # Parse the JSON schema definition
                schema_def = json.loads(schema_json)

                # Create Field objects
                fields = []
                for field_def in schema_def.get("fields", []):
                    field = SchemaField(
                        name=field_def["name"],
                        type=field_def["type"],
                        size=field_def.get("size", 0),
                        primary=field_def.get("primary_key", False),
                        description=field_def.get("description", ""),
                        required=field_def.get("required", False),
                        enum_values=field_def.get("enum", []),
                        indexed=field_def.get("indexed", False)
                    )
                    fields.append(field)

                # Create RowSchema
                row_schema = RowSchema(
                    name=schema_def.get("name", schema_name),
                    description=schema_def.get("description", ""),
                    fields=fields
                )

                ws_schemas[schema_name] = row_schema
                builder.add_schema(schema_name, row_schema)
                logger.info(
                    f"Loaded schema: {schema_name} with "
                    f"{len(fields)} fields for {workspace}"
                )

            except Exception as e:
                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

        logger.info(
            f"Schema configuration loaded for {workspace}: "
            f"{len(ws_schemas)} schemas"
        )

        # Regenerate GraphQL schema for this workspace
        self.graphql_schemas[workspace] = builder.build(self.query_cassandra)

    def get_index_names(self, schema: RowSchema) -> List[str]:
        """Return the names of all primary or indexed fields of a schema."""
        index_names = []
        for field in schema.fields:
            if field.primary or field.indexed:
                index_names.append(field.name)
        return index_names

    def find_matching_index(
        self,
        schema: RowSchema,
        filters: Dict[str, Any]
    ) -> Optional[tuple]:
        """
        Find an index that can satisfy the query filters.

        Returns (index_name, index_value) if found, None otherwise.

        For exact match queries, we need a filter on an indexed field.
        """
        index_names = self.get_index_names(schema)

        # Look for an exact match filter on an indexed field
        for index_name in index_names:
            if index_name in filters:
                value = filters[index_name]
                # Single field index -> single element list
                index_value = [str(value)]
                return (index_name, index_value)

        return None

    async def query_cassandra(
        self,
        workspace: str,
        collection: str,
        schema_name: str,
        row_schema: RowSchema,
        filters: Dict[str, Any],
        limit: int,
        order_by: Optional[str] = None,
        direction: Optional[SortDirection] = None
    ) -> List[Dict[str, Any]]:
        """
        Execute a query against the unified Cassandra rows table.

        For exact match queries on indexed fields, we can query directly.
        For other queries, we need to scan and post-filter.

        Sorting, when requested via order_by/direction, is applied
        in-process after the Cassandra query.
        """
        # Connect if needed.
        # NOTE(review): connect_cassandra() is a blocking call invoked
        # from an async context; acceptable as a one-time setup cost,
        # but worth moving to an executor if connect latency matters.
        self.connect_cassandra()

        safe_keyspace = self.sanitize_name(workspace)

        # Try to find an index that matches the filters
        index_match = self.find_matching_index(row_schema, filters)

        results = []

        if index_match:
            # Direct query using index
            index_name, index_value = index_match

            query = f"""
                SELECT data, source FROM {safe_keyspace}.rows
                WHERE collection = %s
                AND schema_name = %s
                AND index_name = %s
                AND index_value = %s
            """
            params = [collection, schema_name, index_name, index_value]

            if limit:
                # int() guards the f-string interpolation: LIMIT cannot be
                # bound as a parameter, so force a numeric value.
                query += f" LIMIT {int(limit)}"

            try:
                rows = await async_execute(self.session, query, params)
                for row in rows:
                    # Convert data map to dict with proper field names
                    row_dict = dict(row.data) if row.data else {}
                    results.append(row_dict)
            except Exception as e:
                logger.error(f"Failed to query rows: {e}", exc_info=True)
                raise

        else:
            # No direct index match - scan all rows for this schema
            # This is less efficient but necessary for non-indexed queries
            logger.warning(
                f"No index match for filters {filters} - scanning all indexes"
            )

            # Get all index names for this schema
            index_names = self.get_index_names(row_schema)

            if not index_names:
                logger.warning(f"Schema {schema_name} has no indexes")
                return []

            # Query using the first index (arbitrary choice for scan)
            primary_index = index_names[0]

            # We need to scan all values for this index
            # This requires ALLOW FILTERING or a different approach
            query = f"""
                SELECT data, source FROM {safe_keyspace}.rows
                WHERE collection = %s
                AND schema_name = %s
                AND index_name = %s
                ALLOW FILTERING
            """
            params = [collection, schema_name, primary_index]

            try:
                rows = await async_execute(self.session, query, params)

                for row in rows:
                    row_dict = dict(row.data) if row.data else {}

                    # Apply post-filters
                    if self._matches_filters(row_dict, filters, row_schema):
                        results.append(row_dict)

                    if limit and len(results) >= limit:
                        break

            except Exception as e:
                logger.error(f"Failed to scan rows: {e}", exc_info=True)
                raise

        # Post-query sorting if requested
        if order_by and results:
            reverse_order = direction and direction.value == "desc"
            try:
                results.sort(
                    key=lambda x: x.get(order_by, ""),
                    reverse=reverse_order
                )
            except Exception as e:
                logger.warning(f"Failed to sort results by {order_by}: {e}")

        return results

    def _matches_filters(
        self,
        row_dict: Dict[str, Any],
        filters: Dict[str, Any],
        row_schema: RowSchema
    ) -> bool:
        """
        Check if a row matches the given filters.

        Filter keys may carry an operator suffix (field_gt, field_gte,
        field_lt, field_lte, field_contains, field_in); bare keys mean
        equality.  Comparisons that cannot be coerced (ValueError /
        TypeError) count as a non-match.  None-valued filters are skipped.
        """
        for filter_key, filter_value in filters.items():
            if filter_value is None:
                continue

            # Parse filter key for operator suffix
            if '_' in filter_key:
                parts = filter_key.rsplit('_', 1)
                if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
                    field_name = parts[0]
                    operator = parts[1]
                else:
                    field_name = filter_key
                    operator = 'eq'
            else:
                field_name = filter_key
                operator = 'eq'

            row_value = row_dict.get(field_name)
            if row_value is None:
                return False

            # Convert types for comparison; equality and containment
            # compare as strings, range operators compare as floats.
            try:
                if operator == 'eq':
                    if str(row_value) != str(filter_value):
                        return False
                elif operator == 'gt':
                    if float(row_value) <= float(filter_value):
                        return False
                elif operator == 'gte':
                    if float(row_value) < float(filter_value):
                        return False
                elif operator == 'lt':
                    if float(row_value) >= float(filter_value):
                        return False
                elif operator == 'lte':
                    if float(row_value) > float(filter_value):
                        return False
                elif operator == 'contains':
                    if str(filter_value) not in str(row_value):
                        return False
                elif operator == 'in':
                    if str(row_value) not in [str(v) for v in filter_value]:
                        return False
            except (ValueError, TypeError):
                return False

        return True

    async def execute_graphql_query(
        self,
        workspace: str,
        query: str,
        variables: Dict[str, Any],
        operation_name: Optional[str],
        collection: str
    ) -> Dict[str, Any]:
        """
        Execute a GraphQL query against the workspace's compiled schema.

        Returns a dict with "data", "errors" and (when present)
        "extensions" keys.  Raises RuntimeError when no schema has been
        loaded for the workspace.
        """

        graphql_schema = self.graphql_schemas.get(workspace)
        if not graphql_schema:
            raise RuntimeError(
                f"No GraphQL schema available for workspace {workspace} "
                f"- no schemas loaded"
            )

        # Create context for the query; resolvers reach back into this
        # processor via the context to run query_cassandra.
        context = {
            "processor": self,
            "workspace": workspace,
            "collection": collection
        }

        # Execute the query
        result = await graphql_schema.execute(
            query,
            variable_values=variables,
            operation_name=operation_name,
            context_value=context
        )

        # Build response
        response = {}

        if result.data:
            response["data"] = result.data
        else:
            response["data"] = None

        if result.errors:
            response["errors"] = [
                {
                    "message": str(error),
                    "path": getattr(error, "path", []),
                    "extensions": getattr(error, "extensions", {})
                }
                for error in result.errors
            ]
        else:
            response["errors"] = []

        # Add extensions if any
        if hasattr(result, "extensions") and result.extensions:
            response["extensions"] = result.extensions

        return response

    async def on_message(self, msg, consumer, flow):
        """
        Handle an incoming query request and always send a response.

        On success, sends a RowsQueryResponse carrying the GraphQL
        data/errors; on failure, sends an error response.  The
        correlation ID is captured defensively so that the error path
        can still reply even when the failure occurs before (or while)
        reading the message properties — previously this raised a
        NameError in the except handler because `id` was only assigned
        inside the try block.
        """

        # Sender-produced correlation ID; None if absent/unreadable so
        # the error path below never fails on an undefined name.
        request_id = None

        try:
            request = msg.value()

            request_id = msg.properties()["id"]

            logger.debug(f"Handling objects query request {request_id}...")

            # Execute GraphQL query
            result = await self.execute_graphql_query(
                workspace=flow.workspace,
                query=request.query,
                variables=dict(request.variables) if request.variables else {},
                operation_name=request.operation_name,
                collection=request.collection
            )

            # Convert GraphQL errors into schema objects
            graphql_errors = []
            if "errors" in result and result["errors"]:
                for err in result["errors"]:
                    graphql_error = GraphQLError(
                        message=err.get("message", ""),
                        path=err.get("path", []),
                        extensions=err.get("extensions", {})
                    )
                    graphql_errors.append(graphql_error)

            response = RowsQueryResponse(
                error=None,
                data=json.dumps(result.get("data")) if result.get("data") else "null",
                errors=graphql_errors,
                extensions=result.get("extensions", {})
            )

            logger.debug("Sending objects query response...")
            await flow("response").send(response, properties={"id": request_id})

            logger.debug("Objects query request completed")

        except Exception as e:

            logger.error(f"Exception in rows query service: {e}", exc_info=True)

            logger.info("Sending error response...")

            response = RowsQueryResponse(
                error=Error(
                    type="rows-query-error",
                    message=str(e),
                ),
                data=None,
                errors=[],
                extensions={}
            )

            await flow("response").send(response, properties={"id": request_id})

    def close(self):
        """Shut down the Cassandra cluster connection, if one was opened."""
        if self.cluster:
            self.cluster.shutdown()
            logger.info("Closed Cassandra connection")

    @staticmethod
    def add_args(parser):
        """Add command-line arguments for this processor."""

        FlowProcessor.add_args(parser)
        add_cassandra_args(parser)

        parser.add_argument(
            '--config-type',
            default='schema',
            help='Configuration type prefix for schemas (default: schema)'
        )

        parser.add_argument(
            '-c', '--concurrency',
            type=int,
            default=default_concurrency,
            help=f'Number of concurrent requests (default: {default_concurrency})'
        )
|
|
|
|
def run():
    """Entry point for the rows-query-cassandra command.

    Delegates to Processor.launch with the default service identity and
    the module docstring as the command description.
    """
    Processor.launch(default_ident, __doc__)