Structured data 2 (#645)

* Structured data refactor - multi-index tables, remove need for manual mods to the Cassandra tables

* Tech spec updated to track implementation
This commit is contained in:
cybermaggedon 2026-02-23 15:56:29 +00:00 committed by GitHub
parent 5ffad92345
commit 1809c1f56d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
87 changed files with 5233 additions and 3235 deletions

View file

@ -60,27 +60,27 @@ api-gateway = "trustgraph.gateway:run"
chunker-recursive = "trustgraph.chunking.recursive:run"
chunker-token = "trustgraph.chunking.token:run"
config-svc = "trustgraph.config.service:run"
de-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
de-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
de-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
de-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
de-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
de-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
doc-embeddings-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
doc-embeddings-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
doc-embeddings-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
document-embeddings = "trustgraph.embeddings.document_embeddings:run"
document-rag = "trustgraph.retrieval.document_rag:run"
embeddings-fastembed = "trustgraph.embeddings.fastembed:run"
embeddings-ollama = "trustgraph.embeddings.ollama:run"
ge-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
ge-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
ge-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
ge-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
ge-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
ge-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
graph-embeddings-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
graph-embeddings-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
graph-embeddings-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
graph-rag = "trustgraph.retrieval.graph_rag:run"
kg-extract-agent = "trustgraph.extract.kg.agent:run"
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
kg-extract-objects = "trustgraph.extract.kg.objects:run"
kg-extract-rows = "trustgraph.extract.kg.rows:run"
kg-extract-relationships = "trustgraph.extract.kg.relationships:run"
kg-extract-topics = "trustgraph.extract.kg.topics:run"
kg-extract-ontology = "trustgraph.extract.kg.ontology:run"
@ -90,8 +90,11 @@ librarian = "trustgraph.librarian:run"
mcp-tool = "trustgraph.agent.mcp_tool:run"
metering = "trustgraph.metering:run"
nlp-query = "trustgraph.retrieval.nlp_query:run"
objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
rows-write-cassandra = "trustgraph.storage.rows.cassandra:run"
rows-query-cassandra = "trustgraph.query.rows.cassandra:run"
row-embeddings = "trustgraph.embeddings.row_embeddings:run"
row-embeddings-write-qdrant = "trustgraph.storage.row_embeddings.qdrant:run"
row-embeddings-query-qdrant = "trustgraph.query.row_embeddings.qdrant:run"
pdf-decoder = "trustgraph.decoding.pdf:run"
pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run"
prompt-template = "trustgraph.prompt.template:run"

View file

@ -0,0 +1,3 @@
from . embeddings import *

View file

@ -0,0 +1,6 @@
from . embeddings import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,263 @@
"""
Row embeddings processor. Calls the embeddings service to compute embeddings
for indexed field values in extracted row data.
Input is ExtractedObject (structured row data with schema).
Output is RowEmbeddings (row data with embeddings for indexed fields).
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (this processor): Compute embeddings
Stage 2 (row-embeddings-write-*): Store embeddings
"""
import json
import logging
from typing import Dict, List, Set
from ... schema import ExtractedObject, RowEmbeddings, RowIndexEmbedding
from ... schema import RowSchema, Field
from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
from ... base import ProducerSpec, CollectionConfigHandler
logger = logging.getLogger(__name__)
default_ident = "row-embeddings"
default_batch_size = 10
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
self.batch_size = params.get("batch_size", default_batch_size)
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name="input",
schema=ExtractedObject,
handler=self.on_message,
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name="embeddings-request",
response_name="embeddings-response",
)
)
self.register_specification(
ProducerSpec(
name="output",
schema=RowEmbeddings
)
)
# Register config handlers
self.register_config_handler(self.on_schema_config)
self.register_config_handler(self.on_collection_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = Field(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
def get_index_names(self, schema: RowSchema) -> List[str]:
"""Get all index names for a schema."""
index_names = []
for field in schema.fields:
if field.primary or field.indexed:
index_names.append(field.name)
return index_names
def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
"""Build the index_value list for a given index."""
field_names = [f.strip() for f in index_name.split(',')]
values = []
for field_name in field_names:
value = value_map.get(field_name)
values.append(str(value) if value is not None else "")
return values
def build_text_for_embedding(self, index_value: List[str]) -> str:
"""Build text representation for embedding from index values."""
# Space-join the values for composite indexes
return " ".join(index_value)
async def on_message(self, msg, consumer, flow):
"""Process incoming ExtractedObject and compute embeddings"""
obj = msg.value()
logger.info(
f"Computing embeddings for {len(obj.values)} rows, "
f"schema {obj.schema_name}, doc {obj.metadata.id}"
)
# Validate collection exists before processing
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
logger.warning(
f"Collection {obj.metadata.collection} for user {obj.metadata.user} "
f"does not exist in config. Dropping message."
)
return
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
# Get all index names for this schema
index_names = self.get_index_names(schema)
if not index_names:
logger.warning(f"Schema {obj.schema_name} has no indexed fields - skipping")
return
# Track unique texts to avoid duplicate embeddings
# text -> (index_name, index_value)
texts_to_embed: Dict[str, tuple] = {}
# Collect all texts that need embeddings
for value_map in obj.values:
for index_name in index_names:
index_value = self.build_index_value(value_map, index_name)
# Skip empty values
if not index_value or all(v == "" for v in index_value):
continue
text = self.build_text_for_embedding(index_value)
if text and text not in texts_to_embed:
texts_to_embed[text] = (index_name, index_value)
if not texts_to_embed:
logger.info("No texts to embed")
return
# Compute embeddings
embeddings_list = []
try:
for text, (index_name, index_value) in texts_to_embed.items():
vectors = await flow("embeddings-request").embed(text=text)
embeddings_list.append(
RowIndexEmbedding(
index_name=index_name,
index_value=index_value,
text=text,
vectors=vectors
)
)
# Send in batches to avoid oversized messages
for i in range(0, len(embeddings_list), self.batch_size):
batch = embeddings_list[i:i + self.batch_size]
result = RowEmbeddings(
metadata=obj.metadata,
schema_name=obj.schema_name,
embeddings=batch,
)
await flow("output").send(result)
logger.info(
f"Computed {len(embeddings_list)} embeddings for "
f"{len(obj.values)} rows ({len(index_names)} indexes)"
)
except Exception as e:
logger.error("Exception during embedding computation", exc_info=True)
raise e
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Collection creation notification - no action needed for embedding stage"""
logger.debug(f"Row embeddings collection notification for {user}/{collection}")
async def delete_collection(self, user: str, collection: str):
"""Collection deletion notification - no action needed for embedding stage"""
logger.debug(f"Row embeddings collection delete notification for {user}/{collection}")
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
parser.add_argument(
'--batch-size',
type=int,
default=default_batch_size,
help=f'Maximum embeddings per output message (default: {default_batch_size})'
)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -1,5 +1,5 @@
"""
Object extraction service - extracts structured objects from text chunks
Row extraction service - extracts structured rows from text chunks
based on configured schemas.
"""
@ -18,7 +18,7 @@ from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
from .... messaging.translators import row_schema_translator
default_ident = "kg-extract-objects"
default_ident = "kg-extract-rows"
def convert_values_to_strings(obj: Dict[str, Any]) -> Dict[str, str]:
@ -310,5 +310,5 @@ class Processor(FlowProcessor):
FlowProcessor.add_args(parser)
def run():
"""Entry point for kg-extract-objects command"""
"""Entry point for kg-extract-rows command"""
Processor.launch(default_ident, __doc__)

View file

@ -20,7 +20,7 @@ from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor
from . objects_query import ObjectsQueryRequestor
from . rows_query import RowsQueryRequestor
from . nlp_query import NLPQueryRequestor
from . structured_query import StructuredQueryRequestor
from . structured_diag import StructuredDiagRequestor
@ -40,7 +40,7 @@ from . triples_import import TriplesImport
from . graph_embeddings_import import GraphEmbeddingsImport
from . document_embeddings_import import DocumentEmbeddingsImport
from . entity_contexts_import import EntityContextsImport
from . objects_import import ObjectsImport
from . rows_import import RowsImport
from . core_export import CoreExport
from . core_import import CoreImport
@ -58,7 +58,7 @@ request_response_dispatchers = {
"graph-embeddings": GraphEmbeddingsQueryRequestor,
"document-embeddings": DocumentEmbeddingsQueryRequestor,
"triples": TriplesQueryRequestor,
"objects": ObjectsQueryRequestor,
"rows": RowsQueryRequestor,
"nlp-query": NLPQueryRequestor,
"structured-query": StructuredQueryRequestor,
"structured-diag": StructuredDiagRequestor,
@ -89,7 +89,7 @@ import_dispatchers = {
"graph-embeddings": GraphEmbeddingsImport,
"document-embeddings": DocumentEmbeddingsImport,
"entity-contexts": EntityContextsImport,
"objects": ObjectsImport,
"rows": RowsImport,
}
class DispatcherWrapper:

View file

@ -12,7 +12,7 @@ from . serialize import to_subgraph
# Module logger
logger = logging.getLogger(__name__)
class ObjectsImport:
class RowsImport:
def __init__(
self, ws, running, backend, queue
@ -20,7 +20,7 @@ class ObjectsImport:
self.ws = ws
self.running = running
self.publisher = Publisher(
backend, topic = queue, schema = ExtractedObject
)
@ -73,4 +73,4 @@ class ObjectsImport:
if self.ws:
await self.ws.close()
self.ws = None
self.ws = None

View file

@ -1,30 +1,30 @@
from ... schema import ObjectsQueryRequest, ObjectsQueryResponse
from ... schema import RowsQueryRequest, RowsQueryResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class ObjectsQueryRequestor(ServiceRequestor):
class RowsQueryRequestor(ServiceRequestor):
def __init__(
self, backend, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(ObjectsQueryRequestor, self).__init__(
super(RowsQueryRequestor, self).__init__(
backend=backend,
request_queue=request_queue,
response_queue=response_queue,
request_schema=ObjectsQueryRequest,
response_schema=ObjectsQueryResponse,
request_schema=RowsQueryRequest,
response_schema=RowsQueryResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("objects-query")
self.response_translator = TranslatorRegistry.get_response_translator("objects-query")
self.request_translator = TranslatorRegistry.get_request_translator("rows-query")
self.response_translator = TranslatorRegistry.get_response_translator("rows-query")
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return self.response_translator.from_response_with_completion(message)
return self.response_translator.from_response_with_completion(message)

View file

@ -0,0 +1,22 @@
"""
Shared GraphQL utilities for row query services.
This module provides reusable GraphQL components including:
- Filter types (IntFilter, StringFilter, FloatFilter)
- Dynamic schema generation from RowSchema definitions
- Filter parsing utilities
"""
from .types import IntFilter, StringFilter, FloatFilter, SortDirection
from .schema import GraphQLSchemaBuilder
from .filters import parse_filter_key, parse_where_clause
__all__ = [
"IntFilter",
"StringFilter",
"FloatFilter",
"SortDirection",
"GraphQLSchemaBuilder",
"parse_filter_key",
"parse_where_clause",
]

View file

@ -0,0 +1,104 @@
"""
Filter parsing utilities for GraphQL row queries.
Provides functions to parse GraphQL filter objects into a normalized
format that can be used by different query backends.
"""
import logging
from typing import Dict, Any, Tuple
logger = logging.getLogger(__name__)
def parse_filter_key(filter_key: str) -> Tuple[str, str]:
"""
Parse GraphQL filter key into field name and operator.
Supports common GraphQL filter patterns:
- field_name -> (field_name, "eq")
- field_name_gt -> (field_name, "gt")
- field_name_gte -> (field_name, "gte")
- field_name_lt -> (field_name, "lt")
- field_name_lte -> (field_name, "lte")
- field_name_in -> (field_name, "in")
Args:
filter_key: The filter key string from GraphQL
Returns:
Tuple of (field_name, operator)
"""
if not filter_key:
return ("", "eq")
operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]
for op_suffix in operators:
if filter_key.endswith(op_suffix):
field_name = filter_key[:-len(op_suffix)]
operator = op_suffix[1:] # Remove the leading underscore
return (field_name, operator)
# Default to equality if no operator suffix found
return (filter_key, "eq")
def parse_where_clause(where_obj) -> Dict[str, Any]:
"""
Parse the idiomatic nested GraphQL filter structure into a flat dict.
Converts Strawberry filter objects (StringFilter, IntFilter, etc.)
into a dictionary mapping field names with operators to values.
Example:
Input: where_obj with email.eq = "foo@bar.com"
Output: {"email": "foo@bar.com"}
Input: where_obj with age.gt = 21
Output: {"age_gt": 21}
Args:
where_obj: The GraphQL where clause object
Returns:
Dictionary mapping field_operator keys to values
"""
if not where_obj:
return {}
conditions = {}
logger.debug(f"Parsing where clause: {where_obj}")
for field_name, filter_obj in where_obj.__dict__.items():
if filter_obj is None:
continue
logger.debug(f"Processing field {field_name} with filter_obj: {filter_obj}")
if hasattr(filter_obj, '__dict__'):
# This is a filter object (StringFilter, IntFilter, etc.)
for operator, value in filter_obj.__dict__.items():
if value is not None:
logger.debug(f"Found operator {operator} with value {value}")
# Map GraphQL operators to our internal format
if operator == "eq":
conditions[field_name] = value
elif operator in ["gt", "gte", "lt", "lte"]:
conditions[f"{field_name}_{operator}"] = value
elif operator == "in_":
conditions[f"{field_name}_in"] = value
elif operator == "contains":
conditions[f"{field_name}_contains"] = value
elif operator == "startsWith":
conditions[f"{field_name}_startsWith"] = value
elif operator == "endsWith":
conditions[f"{field_name}_endsWith"] = value
elif operator == "not_":
conditions[f"{field_name}_not"] = value
elif operator == "not_in":
conditions[f"{field_name}_not_in"] = value
logger.debug(f"Final parsed conditions: {conditions}")
return conditions

View file

@ -0,0 +1,251 @@
"""
Dynamic GraphQL schema generation from RowSchema definitions.
Provides a builder class that creates Strawberry GraphQL schemas
from TrustGraph RowSchema definitions, with pluggable query backends.
"""
import logging
from typing import Dict, Any, Optional, List, Callable, Awaitable
import strawberry
from strawberry import Schema
from strawberry.types import Info
from .types import IntFilter, StringFilter, FloatFilter, SortDirection
logger = logging.getLogger(__name__)
# Type alias for query callback function
QueryCallback = Callable[
[str, str, str, Any, Dict[str, Any], int, Optional[str], Optional[SortDirection]],
Awaitable[List[Dict[str, Any]]]
]
class GraphQLSchemaBuilder:
"""
Builds GraphQL schemas from RowSchema definitions.
This class extracts the GraphQL schema generation logic so it can be
reused across different query backends (Cassandra, etc.).
Usage:
builder = GraphQLSchemaBuilder()
# Add schemas
for name, row_schema in schemas.items():
builder.add_schema(name, row_schema)
# Build with a query callback
schema = builder.build(query_callback)
"""
def __init__(self):
self.schemas: Dict[str, Any] = {} # name -> RowSchema
self.graphql_types: Dict[str, type] = {}
self.filter_types: Dict[str, type] = {}
def add_schema(self, name: str, row_schema) -> None:
"""
Add a RowSchema to the builder.
Args:
name: The schema name (used as the GraphQL query field name)
row_schema: The RowSchema object defining fields
"""
self.schemas[name] = row_schema
self.graphql_types[name] = self._create_graphql_type(name, row_schema)
self.filter_types[name] = self._create_filter_type(name, row_schema)
logger.debug(f"Added schema {name} with {len(row_schema.fields)} fields")
def clear(self) -> None:
"""Clear all schemas from the builder."""
self.schemas = {}
self.graphql_types = {}
self.filter_types = {}
def build(self, query_callback: QueryCallback) -> Optional[Schema]:
"""
Build the GraphQL schema with the provided query callback.
The query callback will be invoked when resolving queries, with:
- user: str
- collection: str
- schema_name: str
- row_schema: RowSchema
- filters: Dict[str, Any]
- limit: int
- order_by: Optional[str]
- direction: Optional[SortDirection]
It should return a list of row dictionaries.
Args:
query_callback: Async function to execute queries
Returns:
Strawberry Schema, or None if no schemas are loaded
"""
if not self.schemas:
logger.warning("No schemas loaded, cannot generate GraphQL schema")
return None
# Create the Query class with resolvers
query_dict = {'__annotations__': {}}
for schema_name, row_schema in self.schemas.items():
graphql_type = self.graphql_types[schema_name]
filter_type = self.filter_types[schema_name]
# Create resolver function for this schema
resolver_func = self._make_resolver(
schema_name, row_schema, graphql_type, filter_type, query_callback
)
# Add field to query dictionary
query_dict[schema_name] = strawberry.field(resolver=resolver_func)
query_dict['__annotations__'][schema_name] = List[graphql_type]
# Create the Query class
Query = type('Query', (), query_dict)
Query = strawberry.type(Query)
# Create the schema with auto_camel_case disabled to keep snake_case field names
schema = strawberry.Schema(
query=Query,
config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
)
logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
return schema
def _get_python_type(self, field_type: str):
"""Convert schema field type to Python type for GraphQL."""
type_mapping = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"timestamp": str, # Use string for timestamps in GraphQL
"date": str,
"time": str,
"uuid": str
}
return type_mapping.get(field_type, str)
def _create_graphql_type(self, schema_name: str, row_schema) -> type:
"""Create a GraphQL output type from a RowSchema."""
# Create annotations for the GraphQL type
annotations = {}
defaults = {}
for field in row_schema.fields:
python_type = self._get_python_type(field.type)
# Make field optional if not required
if not field.required and not field.primary:
annotations[field.name] = Optional[python_type]
defaults[field.name] = None
else:
annotations[field.name] = python_type
# Create the class dynamically
type_name = f"{schema_name.capitalize()}Type"
graphql_class = type(
type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry decorator
return strawberry.type(graphql_class)
def _create_filter_type(self, schema_name: str, row_schema) -> type:
"""Create a dynamic filter input type for a schema."""
filter_type_name = f"{schema_name.capitalize()}Filter"
# Add __annotations__ and defaults for the fields
annotations = {}
defaults = {}
logger.debug(f"Creating filter type {filter_type_name} for schema {schema_name}")
for field in row_schema.fields:
logger.debug(
f"Field {field.name}: type={field.type}, "
f"indexed={field.indexed}, primary={field.primary}"
)
# Allow filtering on any field
if field.type == "integer":
annotations[field.name] = Optional[IntFilter]
defaults[field.name] = None
elif field.type == "float":
annotations[field.name] = Optional[FloatFilter]
defaults[field.name] = None
elif field.type == "string":
annotations[field.name] = Optional[StringFilter]
defaults[field.name] = None
logger.debug(
f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}"
)
# Create the class dynamically
FilterType = type(
filter_type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry input decorator
FilterType = strawberry.input(FilterType)
return FilterType
def _make_resolver(
self,
schema_name: str,
row_schema,
graphql_type: type,
filter_type: type,
query_callback: QueryCallback
):
"""Create a resolver function for a schema."""
from .filters import parse_where_clause
async def resolver(
info: Info,
where: Optional[filter_type] = None,
order_by: Optional[str] = None,
direction: Optional[SortDirection] = None,
limit: Optional[int] = 100
) -> List[graphql_type]:
# Get context values
user = info.context["user"]
collection = info.context["collection"]
# Parse the where clause
filters = parse_where_clause(where)
# Call the query backend
results = await query_callback(
user, collection, schema_name, row_schema,
filters, limit, order_by, direction
)
# Convert to GraphQL types
graphql_results = []
for row in results:
graphql_obj = graphql_type(**row)
graphql_results.append(graphql_obj)
return graphql_results
return resolver

View file

@ -0,0 +1,56 @@
"""
GraphQL filter and sort types for row queries.
These types are used to build dynamic GraphQL schemas for querying
structured row data.
"""
from typing import Optional, List
from enum import Enum
import strawberry
@strawberry.input
class IntFilter:
"""Filter type for integer fields."""
eq: Optional[int] = None
gt: Optional[int] = None
gte: Optional[int] = None
lt: Optional[int] = None
lte: Optional[int] = None
in_: Optional[List[int]] = strawberry.field(name="in", default=None)
not_: Optional[int] = strawberry.field(name="not", default=None)
not_in: Optional[List[int]] = None
@strawberry.input
class StringFilter:
"""Filter type for string fields."""
eq: Optional[str] = None
contains: Optional[str] = None
startsWith: Optional[str] = None
endsWith: Optional[str] = None
in_: Optional[List[str]] = strawberry.field(name="in", default=None)
not_: Optional[str] = strawberry.field(name="not", default=None)
not_in: Optional[List[str]] = None
@strawberry.input
class FloatFilter:
"""Filter type for float fields."""
eq: Optional[float] = None
gt: Optional[float] = None
gte: Optional[float] = None
lt: Optional[float] = None
lte: Optional[float] = None
in_: Optional[List[float]] = strawberry.field(name="in", default=None)
not_: Optional[float] = strawberry.field(name="not", default=None)
not_in: Optional[List[float]] = None
@strawberry.enum
class SortDirection(Enum):
"""Sort direction for query results."""
ASC = "asc"
DESC = "desc"

View file

@ -1,738 +0,0 @@
"""
Objects query service using GraphQL. Input is a GraphQL query with variables.
Output is GraphQL response data with any errors.
"""
import json
import logging
import asyncio
from typing import Dict, Any, Optional, List, Set
from enum import Enum
from dataclasses import dataclass, field
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
import strawberry
from strawberry import Schema
from strawberry.types import Info
from strawberry.scalars import JSON
from strawberry.tools import create_type
from .... schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "objects-query"
# GraphQL filter input types
@strawberry.input
class IntFilter:
eq: Optional[int] = None
gt: Optional[int] = None
gte: Optional[int] = None
lt: Optional[int] = None
lte: Optional[int] = None
in_: Optional[List[int]] = strawberry.field(name="in", default=None)
not_: Optional[int] = strawberry.field(name="not", default=None)
not_in: Optional[List[int]] = None
@strawberry.input
class StringFilter:
eq: Optional[str] = None
contains: Optional[str] = None
startsWith: Optional[str] = None
endsWith: Optional[str] = None
in_: Optional[List[str]] = strawberry.field(name="in", default=None)
not_: Optional[str] = strawberry.field(name="not", default=None)
not_in: Optional[List[str]] = None
@strawberry.input
class FloatFilter:
eq: Optional[float] = None
gt: Optional[float] = None
gte: Optional[float] = None
lt: Optional[float] = None
lte: Optional[float] = None
in_: Optional[List[float]] = strawberry.field(name="in", default=None)
not_: Optional[float] = strawberry.field(name="not", default=None)
not_in: Optional[List[float]] = None
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = ObjectsQueryRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = ObjectsQueryResponse,
)
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# GraphQL schema
self.graphql_schema: Optional[Schema] = None
# GraphQL types cache
self.graphql_types: Dict[str, type] = {}
# Cassandra session
self.cluster = None
self.session = None
# Known keyspaces and tables
self.known_keyspaces: Set[str] = set()
self.known_tables: Dict[str, Set[str]] = {}
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'o_' + safe_name
return safe_name.lower()
def sanitize_table(self, name: str) -> str:
"""Sanitize table names for Cassandra compatibility"""
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
safe_name = 'o_' + safe_name
return safe_name.lower()
def parse_filter_key(self, filter_key: str) -> tuple[str, str]:
"""Parse GraphQL filter key into field name and operator"""
if not filter_key:
return ("", "eq")
# Support common GraphQL filter patterns:
# field_name -> (field_name, "eq")
# field_name_gt -> (field_name, "gt")
# field_name_gte -> (field_name, "gte")
# field_name_lt -> (field_name, "lt")
# field_name_lte -> (field_name, "lte")
# field_name_in -> (field_name, "in")
operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]
for op_suffix in operators:
if filter_key.endswith(op_suffix):
field_name = filter_key[:-len(op_suffix)]
operator = op_suffix[1:] # Remove the leading underscore
return (field_name, operator)
# Default to equality if no operator suffix found
return (filter_key, "eq")
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
self.graphql_types = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
# Regenerate GraphQL schema
self.generate_graphql_schema()
def get_python_type(self, field_type: str):
"""Convert schema field type to Python type for GraphQL"""
type_mapping = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"timestamp": str, # Use string for timestamps in GraphQL
"date": str,
"time": str,
"uuid": str
}
return type_mapping.get(field_type, str)
def create_graphql_type(self, schema_name: str, row_schema: RowSchema) -> type:
"""Create a GraphQL type from a RowSchema"""
# Create annotations for the GraphQL type
annotations = {}
defaults = {}
for field in row_schema.fields:
python_type = self.get_python_type(field.type)
# Make field optional if not required
if not field.required and not field.primary:
annotations[field.name] = Optional[python_type]
defaults[field.name] = None
else:
annotations[field.name] = python_type
# Create the class dynamically
type_name = f"{schema_name.capitalize()}Type"
graphql_class = type(
type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry decorator
return strawberry.type(graphql_class)
def create_filter_type_for_schema(self, schema_name: str, row_schema: RowSchema):
"""Create a dynamic filter input type for a schema"""
# Create the filter type dynamically
filter_type_name = f"{schema_name.capitalize()}Filter"
# Add __annotations__ and defaults for the fields
annotations = {}
defaults = {}
logger.info(f"Creating filter type {filter_type_name} for schema {schema_name}")
for field in row_schema.fields:
logger.info(f"Field {field.name}: type={field.type}, indexed={field.indexed}, primary={field.primary}")
# Allow filtering on any field for now, not just indexed/primary
# if field.indexed or field.primary:
if field.type == "integer":
annotations[field.name] = Optional[IntFilter]
defaults[field.name] = None
logger.info(f"Added IntFilter for {field.name}")
elif field.type == "float":
annotations[field.name] = Optional[FloatFilter]
defaults[field.name] = None
logger.info(f"Added FloatFilter for {field.name}")
elif field.type == "string":
annotations[field.name] = Optional[StringFilter]
defaults[field.name] = None
logger.info(f"Added StringFilter for {field.name}")
logger.info(f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}")
# Create the class dynamically
FilterType = type(
filter_type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry input decorator
FilterType = strawberry.input(FilterType)
return FilterType
def create_sort_direction_enum(self):
"""Create sort direction enum"""
@strawberry.enum
class SortDirection(Enum):
ASC = "asc"
DESC = "desc"
return SortDirection
def parse_idiomatic_where_clause(self, where_obj) -> Dict[str, Any]:
"""Parse the idiomatic nested filter structure"""
if not where_obj:
return {}
conditions = {}
logger.info(f"Parsing where clause: {where_obj}")
for field_name, filter_obj in where_obj.__dict__.items():
if filter_obj is None:
continue
logger.info(f"Processing field {field_name} with filter_obj: {filter_obj}")
if hasattr(filter_obj, '__dict__'):
# This is a filter object (StringFilter, IntFilter, etc.)
for operator, value in filter_obj.__dict__.items():
if value is not None:
logger.info(f"Found operator {operator} with value {value}")
# Map GraphQL operators to our internal format
if operator == "eq":
conditions[field_name] = value
elif operator in ["gt", "gte", "lt", "lte"]:
conditions[f"{field_name}_{operator}"] = value
elif operator == "in_":
conditions[f"{field_name}_in"] = value
elif operator == "contains":
conditions[f"{field_name}_contains"] = value
logger.info(f"Final parsed conditions: {conditions}")
return conditions
def generate_graphql_schema(self):
"""Generate GraphQL schema from loaded schemas using dynamic filter types"""
if not self.schemas:
logger.warning("No schemas loaded, cannot generate GraphQL schema")
self.graphql_schema = None
return
# Create GraphQL types and filter types for each schema
filter_types = {}
sort_direction_enum = self.create_sort_direction_enum()
for schema_name, row_schema in self.schemas.items():
graphql_type = self.create_graphql_type(schema_name, row_schema)
filter_type = self.create_filter_type_for_schema(schema_name, row_schema)
self.graphql_types[schema_name] = graphql_type
filter_types[schema_name] = filter_type
# Create the Query class with resolvers
query_dict = {'__annotations__': {}}
for schema_name, row_schema in self.schemas.items():
graphql_type = self.graphql_types[schema_name]
filter_type = filter_types[schema_name]
# Create resolver function for this schema
def make_resolver(s_name, r_schema, g_type, f_type, sort_enum):
async def resolver(
info: Info,
where: Optional[f_type] = None,
order_by: Optional[str] = None,
direction: Optional[sort_enum] = None,
limit: Optional[int] = 100
) -> List[g_type]:
# Get the processor instance from context
processor = info.context["processor"]
user = info.context["user"]
collection = info.context["collection"]
# Parse the idiomatic where clause
filters = processor.parse_idiomatic_where_clause(where)
# Query Cassandra
results = await processor.query_cassandra(
user, collection, s_name, r_schema,
filters, limit, order_by, direction
)
# Convert to GraphQL types
graphql_results = []
for row in results:
graphql_obj = g_type(**row)
graphql_results.append(graphql_obj)
return graphql_results
return resolver
# Add resolver to query
resolver_name = schema_name
resolver_func = make_resolver(schema_name, row_schema, graphql_type, filter_type, sort_direction_enum)
# Add field to query dictionary
query_dict[resolver_name] = strawberry.field(resolver=resolver_func)
query_dict['__annotations__'][resolver_name] = List[graphql_type]
# Create the Query class
Query = type('Query', (), query_dict)
Query = strawberry.type(Query)
# Create the schema with auto_camel_case disabled to keep snake_case field names
self.graphql_schema = strawberry.Schema(
query=Query,
config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
)
logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
async def query_cassandra(
self,
user: str,
collection: str,
schema_name: str,
row_schema: RowSchema,
filters: Dict[str, Any],
limit: int,
order_by: Optional[str] = None,
direction: Optional[Any] = None
) -> List[Dict[str, Any]]:
"""Execute a query against Cassandra"""
# Connect if needed
self.connect_cassandra()
# Build the query
keyspace = self.sanitize_name(user)
table = self.sanitize_table(schema_name)
# Start with basic SELECT
query = f"SELECT * FROM {keyspace}.{table}"
# Add WHERE clauses
where_clauses = [f"collection = %s"]
params = [collection]
# Add filters for indexed or primary key fields
for filter_key, value in filters.items():
if value is not None:
# Parse field name and operator from filter key
logger.debug(f"Parsing filter key: '{filter_key}' (type: {type(filter_key)})")
result = self.parse_filter_key(filter_key)
logger.debug(f"parse_filter_key returned: {result} (type: {type(result)}, len: {len(result) if hasattr(result, '__len__') else 'N/A'})")
if not result or len(result) != 2:
logger.error(f"parse_filter_key returned invalid result: {result}")
continue # Skip this filter
field_name, operator = result
# Find the field in schema
schema_field = None
for f in row_schema.fields:
if f.name == field_name:
schema_field = f
break
if schema_field:
safe_field = self.sanitize_name(field_name)
# Build WHERE clause based on operator
if operator == "eq":
where_clauses.append(f"{safe_field} = %s")
params.append(value)
elif operator == "gt":
where_clauses.append(f"{safe_field} > %s")
params.append(value)
elif operator == "gte":
where_clauses.append(f"{safe_field} >= %s")
params.append(value)
elif operator == "lt":
where_clauses.append(f"{safe_field} < %s")
params.append(value)
elif operator == "lte":
where_clauses.append(f"{safe_field} <= %s")
params.append(value)
elif operator == "in":
if isinstance(value, list):
placeholders = ",".join(["%s"] * len(value))
where_clauses.append(f"{safe_field} IN ({placeholders})")
params.extend(value)
else:
# Default to equality for unknown operators
where_clauses.append(f"{safe_field} = %s")
params.append(value)
if where_clauses:
query += " WHERE " + " AND ".join(where_clauses)
# Add ORDER BY if requested (will try Cassandra first, then fall back to post-query sort)
cassandra_order_by_added = False
if order_by and direction:
# Validate that order_by field exists in schema
order_field_exists = any(f.name == order_by for f in row_schema.fields)
if order_field_exists:
safe_order_field = self.sanitize_name(order_by)
direction_str = "ASC" if direction.value == "asc" else "DESC"
# Add ORDER BY - if Cassandra rejects it, we'll catch the error during execution
query += f" ORDER BY {safe_order_field} {direction_str}"
# Add limit first (must come before ALLOW FILTERING)
if limit:
query += f" LIMIT {limit}"
# Add ALLOW FILTERING for now (should optimize with proper indexes later)
query += " ALLOW FILTERING"
# Execute query
try:
result = self.session.execute(query, params)
cassandra_order_by_added = True # If we get here, Cassandra handled ORDER BY
except Exception as e:
# If ORDER BY fails, try without it
if order_by and direction and "ORDER BY" in query:
logger.info(f"Cassandra rejected ORDER BY, falling back to post-query sorting: {e}")
# Remove ORDER BY clause and retry
query_parts = query.split(" ORDER BY ")
if len(query_parts) == 2:
query_without_order = query_parts[0] + " LIMIT " + str(limit) + " ALLOW FILTERING" if limit else " ALLOW FILTERING"
result = self.session.execute(query_without_order, params)
cassandra_order_by_added = False
else:
raise
else:
raise
# Convert rows to dicts
results = []
for row in result:
row_dict = {}
for field in row_schema.fields:
safe_field = self.sanitize_name(field.name)
if hasattr(row, safe_field):
value = getattr(row, safe_field)
# Use original field name in result
row_dict[field.name] = value
results.append(row_dict)
# Post-query sorting if Cassandra didn't handle ORDER BY
if order_by and direction and not cassandra_order_by_added:
reverse_order = (direction.value == "desc")
try:
results.sort(key=lambda x: x.get(order_by, 0), reverse=reverse_order)
except Exception as e:
logger.warning(f"Failed to sort results by {order_by}: {e}")
return results
async def execute_graphql_query(
self,
query: str,
variables: Dict[str, Any],
operation_name: Optional[str],
user: str,
collection: str
) -> Dict[str, Any]:
"""Execute a GraphQL query"""
if not self.graphql_schema:
raise RuntimeError("No GraphQL schema available - no schemas loaded")
# Create context for the query
context = {
"processor": self,
"user": user,
"collection": collection
}
# Execute the query
result = await self.graphql_schema.execute(
query,
variable_values=variables,
operation_name=operation_name,
context_value=context
)
# Build response
response = {}
if result.data:
response["data"] = result.data
else:
response["data"] = None
if result.errors:
response["errors"] = [
{
"message": str(error),
"path": getattr(error, "path", []),
"extensions": getattr(error, "extensions", {})
}
for error in result.errors
]
else:
response["errors"] = []
# Add extensions if any
if hasattr(result, "extensions") and result.extensions:
response["extensions"] = result.extensions
return response
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.debug(f"Handling objects query request {id}...")
# Execute GraphQL query
result = await self.execute_graphql_query(
query=request.query,
variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name,
user=request.user,
collection=request.collection
)
# Create response
graphql_errors = []
if "errors" in result and result["errors"]:
for err in result["errors"]:
graphql_error = GraphQLError(
message=err.get("message", ""),
path=err.get("path", []),
extensions=err.get("extensions", {})
)
graphql_errors.append(graphql_error)
response = ObjectsQueryResponse(
error=None,
data=json.dumps(result.get("data")) if result.get("data") else "null",
errors=graphql_errors,
extensions=result.get("extensions", {})
)
logger.debug("Sending objects query response...")
await flow("response").send(response, properties={"id": id})
logger.debug("Objects query request completed")
except Exception as e:
logger.error(f"Exception in objects query service: {e}", exc_info=True)
logger.info("Sending error response...")
response = ObjectsQueryResponse(
error = Error(
type = "objects-query-error",
message = str(e),
),
data = None,
errors = [],
extensions = {}
)
await flow("response").send(response, properties={"id": id})
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
"""Entry point for objects-query-graphql-cassandra command"""
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,3 @@
"""
Row embeddings query modules.
"""

View file

@ -0,0 +1,5 @@
"""
Qdrant row embeddings query service.
"""
from .service import Processor, run, default_ident

View file

@ -0,0 +1,4 @@
from .service import run
run()

View file

@ -0,0 +1,209 @@
"""
Row embeddings query service for Qdrant.
Input is query vectors plus user/collection/schema context.
Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""
import logging
import re
from typing import Optional
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from .... schema import (
RowEmbeddingsRequest, RowEmbeddingsResponse,
RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-query"
default_store_uri = 'http://localhost:6333'
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.register_specification(
ConsumerSpec(
name="request",
schema=RowEmbeddingsRequest,
handler=self.on_message
)
)
self.register_specification(
ProducerSpec(
name="response",
schema=RowEmbeddingsResponse
)
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given user/collection/schema"""
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
try:
all_collections = self.qdrant.get_collections().collections
matching = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if matching:
# Return first match (there should typically be only one per dimension)
return matching[0]
except Exception as e:
logger.error(f"Failed to list Qdrant collections: {e}", exc_info=True)
return None
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
"""Execute row embeddings query"""
matches = []
# Find the collection for this user/collection/schema
qdrant_collection = self.find_collection(
request.user, request.collection, request.schema_name
)
if not qdrant_collection:
logger.info(
f"No Qdrant collection found for "
f"{request.user}/{request.collection}/{request.schema_name}"
)
return matches
for vec in request.vectors:
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
return matches
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.debug(
f"Handling row embeddings query for "
f"{request.user}/{request.collection}/{request.schema_name}..."
)
# Execute query
matches = await self.query_row_embeddings(request)
response = RowEmbeddingsResponse(
error=None,
matches=matches
)
logger.debug(f"Returning {len(matches)} matches")
await flow("response").send(response, properties={"id": id})
except Exception as e:
logger.error(f"Exception in row embeddings query: {e}", exc_info=True)
response = RowEmbeddingsResponse(
error=Error(
type="row-embeddings-query-error",
message=str(e)
),
matches=[]
)
await flow("response").send(response, properties={"id": id})
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='API key for Qdrant (default: None)'
)
def run():
"""Entry point for row-embeddings-query-qdrant command"""
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,523 @@
"""
Row query service using GraphQL. Input is a GraphQL query with variables.
Output is GraphQL response data with any errors.
Queries against the unified 'rows' table with schema:
- collection: text
- schema_name: text
- index_name: text
- index_value: frozen<list<text>>
- data: map<text, text>
- source: text
"""
import json
import logging
import re
from typing import Dict, Any, Optional, List, Set
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from ... graphql import GraphQLSchemaBuilder, SortDirection
# Module logger
logger = logging.getLogger(__name__)
default_ident = "rows-query"
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name="request",
schema=RowsQueryRequest,
handler=self.on_message
)
)
self.register_specification(
ProducerSpec(
name="response",
schema=RowsQueryResponse,
)
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# GraphQL schema builder and generated schema
self.schema_builder = GraphQLSchemaBuilder()
self.graphql_schema = None
# Cassandra session
self.cluster = None
self.session = None
# Known keyspaces
self.known_keyspaces: Set[str] = set()
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
self.schema_builder.clear()
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
self.schema_builder.add_schema(schema_name, row_schema)
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
# Regenerate GraphQL schema
self.graphql_schema = self.schema_builder.build(self.query_cassandra)
def get_index_names(self, schema: RowSchema) -> List[str]:
"""Get all index names for a schema."""
index_names = []
for field in schema.fields:
if field.primary or field.indexed:
index_names.append(field.name)
return index_names
def find_matching_index(
self,
schema: RowSchema,
filters: Dict[str, Any]
) -> Optional[tuple]:
"""
Find an index that can satisfy the query filters.
Returns (index_name, index_value) if found, None otherwise.
For exact match queries, we need a filter on an indexed field.
"""
index_names = self.get_index_names(schema)
# Look for an exact match filter on an indexed field
for index_name in index_names:
if index_name in filters:
value = filters[index_name]
# Single field index -> single element list
index_value = [str(value)]
return (index_name, index_value)
return None
async def query_cassandra(
self,
user: str,
collection: str,
schema_name: str,
row_schema: RowSchema,
filters: Dict[str, Any],
limit: int,
order_by: Optional[str] = None,
direction: Optional[SortDirection] = None
) -> List[Dict[str, Any]]:
"""
Execute a query against the unified Cassandra rows table.
For exact match queries on indexed fields, we can query directly.
For other queries, we need to scan and post-filter.
"""
# Connect if needed
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
# Try to find an index that matches the filters
index_match = self.find_matching_index(row_schema, filters)
results = []
if index_match:
# Direct query using index
index_name, index_value = index_match
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
AND schema_name = %s
AND index_name = %s
AND index_value = %s
"""
params = [collection, schema_name, index_name, index_value]
if limit:
query += f" LIMIT {limit}"
try:
rows = self.session.execute(query, params)
for row in rows:
# Convert data map to dict with proper field names
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
except Exception as e:
logger.error(f"Failed to query rows: {e}", exc_info=True)
raise
else:
# No direct index match - scan all rows for this schema
# This is less efficient but necessary for non-indexed queries
logger.warning(
f"No index match for filters {filters} - scanning all indexes"
)
# Get all index names for this schema
index_names = self.get_index_names(row_schema)
if not index_names:
logger.warning(f"Schema {schema_name} has no indexes")
return []
# Query using the first index (arbitrary choice for scan)
primary_index = index_names[0]
# We need to scan all values for this index
# This requires ALLOW FILTERING or a different approach
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
AND schema_name = %s
AND index_name = %s
ALLOW FILTERING
"""
params = [collection, schema_name, primary_index]
try:
rows = self.session.execute(query, params)
for row in rows:
row_dict = dict(row.data) if row.data else {}
# Apply post-filters
if self._matches_filters(row_dict, filters, row_schema):
results.append(row_dict)
if limit and len(results) >= limit:
break
except Exception as e:
logger.error(f"Failed to scan rows: {e}", exc_info=True)
raise
# Post-query sorting if requested
if order_by and results:
reverse_order = direction and direction.value == "desc"
try:
results.sort(
key=lambda x: x.get(order_by, ""),
reverse=reverse_order
)
except Exception as e:
logger.warning(f"Failed to sort results by {order_by}: {e}")
return results
def _matches_filters(
self,
row_dict: Dict[str, Any],
filters: Dict[str, Any],
row_schema: RowSchema
) -> bool:
"""Check if a row matches the given filters."""
for filter_key, filter_value in filters.items():
if filter_value is None:
continue
# Parse filter key for operator
if '_' in filter_key:
parts = filter_key.rsplit('_', 1)
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
field_name = parts[0]
operator = parts[1]
else:
field_name = filter_key
operator = 'eq'
else:
field_name = filter_key
operator = 'eq'
row_value = row_dict.get(field_name)
if row_value is None:
return False
# Convert types for comparison
try:
if operator == 'eq':
if str(row_value) != str(filter_value):
return False
elif operator == 'gt':
if float(row_value) <= float(filter_value):
return False
elif operator == 'gte':
if float(row_value) < float(filter_value):
return False
elif operator == 'lt':
if float(row_value) >= float(filter_value):
return False
elif operator == 'lte':
if float(row_value) > float(filter_value):
return False
elif operator == 'contains':
if str(filter_value) not in str(row_value):
return False
elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]:
return False
except (ValueError, TypeError):
return False
return True
async def execute_graphql_query(
self,
query: str,
variables: Dict[str, Any],
operation_name: Optional[str],
user: str,
collection: str
) -> Dict[str, Any]:
"""Execute a GraphQL query"""
if not self.graphql_schema:
raise RuntimeError("No GraphQL schema available - no schemas loaded")
# Create context for the query
context = {
"processor": self,
"user": user,
"collection": collection
}
# Execute the query
result = await self.graphql_schema.execute(
query,
variable_values=variables,
operation_name=operation_name,
context_value=context
)
# Build response
response = {}
if result.data:
response["data"] = result.data
else:
response["data"] = None
if result.errors:
response["errors"] = [
{
"message": str(error),
"path": getattr(error, "path", []),
"extensions": getattr(error, "extensions", {})
}
for error in result.errors
]
else:
response["errors"] = []
# Add extensions if any
if hasattr(result, "extensions") and result.extensions:
response["extensions"] = result.extensions
return response
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.debug(f"Handling objects query request {id}...")
# Execute GraphQL query
result = await self.execute_graphql_query(
query=request.query,
variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name,
user=request.user,
collection=request.collection
)
# Create response
graphql_errors = []
if "errors" in result and result["errors"]:
for err in result["errors"]:
graphql_error = GraphQLError(
message=err.get("message", ""),
path=err.get("path", []),
extensions=err.get("extensions", {})
)
graphql_errors.append(graphql_error)
response = RowsQueryResponse(
error=None,
data=json.dumps(result.get("data")) if result.get("data") else "null",
errors=graphql_errors,
extensions=result.get("extensions", {})
)
logger.debug("Sending objects query response...")
await flow("response").send(response, properties={"id": id})
logger.debug("Objects query request completed")
except Exception as e:
logger.error(f"Exception in rows query service: {e}", exc_info=True)
logger.info("Sending error response...")
response = RowsQueryResponse(
error=Error(
type="rows-query-error",
message=str(e),
),
data=None,
errors=[],
extensions={}
)
await flow("response").send(response, properties={"id": id})
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
"""Entry point for rows-query-cassandra command"""
Processor.launch(default_ident, __doc__)

View file

@ -1,6 +1,6 @@
"""
Structured Query Service - orchestrates natural language question processing.
Takes a question, converts it to GraphQL via nlp-query, executes via objects-query,
Takes a question, converts it to GraphQL via nlp-query, executes via rows-query,
and returns the results.
"""
@ -10,7 +10,7 @@ from typing import Dict, Any, Optional
from ...schema import StructuredQueryRequest, StructuredQueryResponse
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
from ...schema import RowsQueryRequest, RowsQueryResponse
from ...schema import Error
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, RequestResponseSpec
@ -57,13 +57,13 @@ class Processor(FlowProcessor):
)
)
# Client spec for calling objects query service
# Client spec for calling rows query service
self.register_specification(
RequestResponseSpec(
request_name = "objects-query-request",
response_name = "objects-query-response",
request_schema = ObjectsQueryRequest,
response_schema = ObjectsQueryResponse
request_name = "rows-query-request",
response_name = "rows-query-response",
request_schema = RowsQueryRequest,
response_schema = RowsQueryResponse
)
)
@ -112,7 +112,7 @@ class Processor(FlowProcessor):
variables_as_strings[key] = str(value)
# Use user/collection values from request
objects_request = ObjectsQueryRequest(
objects_request = RowsQueryRequest(
user=request.user,
collection=request.collection,
query=nlp_response.graphql_query,
@ -120,12 +120,12 @@ class Processor(FlowProcessor):
operation_name=None
)
objects_response = await flow("objects-query-request").request(objects_request)
objects_response = await flow("rows-query-request").request(objects_request)
if objects_response.error is not None:
raise Exception(f"Objects query service error: {objects_response.error.message}")
# Handle GraphQL errors from the objects query service
raise Exception(f"Rows query service error: {objects_response.error.message}")
# Handle GraphQL errors from the rows query service
graphql_errors = []
if objects_response.errors:
for gql_error in objects_response.errors:

View file

@ -13,7 +13,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_ident = "doc-embeddings-write"
default_store_uri = 'http://localhost:19530'
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):

View file

@ -18,7 +18,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_ident = "doc-embeddings-write"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws"
default_region = "us-east-1"

View file

@ -16,7 +16,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_ident = "doc-embeddings-write"
default_store_uri = 'http://localhost:6333'

View file

@ -27,7 +27,7 @@ def get_term_value(term):
# For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "ge-write"
default_ident = "graph-embeddings-write"
default_store_uri = 'http://localhost:19530'
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):

View file

@ -32,7 +32,7 @@ def get_term_value(term):
# For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "ge-write"
default_ident = "graph-embeddings-write"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws"
default_region = "us-east-1"

View file

@ -31,7 +31,7 @@ def get_term_value(term):
return term.id or term.value
default_ident = "ge-write"
default_ident = "graph-embeddings-write"
default_store_uri = 'http://localhost:6333'

View file

@ -1 +0,0 @@
# Objects storage module

View file

@ -1 +0,0 @@
from . write import *

View file

@ -1,3 +0,0 @@
from . write import run
run()

View file

@ -1,538 +0,0 @@
"""
Object writer for Cassandra. Input is ExtractedObject.
Writes structured objects to Cassandra tables based on schema definitions.
"""
import json
import logging
from typing import Dict, Set, Optional, Any
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.cqlengine import connection
from cassandra import ConsistencyLevel
from .... schema import ExtractedObject
from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import CollectionConfigHandler
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "objects-write"
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = ExtractedObject,
handler = self.on_object
)
)
# Register config handlers
self.register_config_handler(self.on_schema_config)
self.register_config_handler(self.on_collection_config)
# Cache of known keyspaces/tables
self.known_keyspaces: Set[str] = set()
self.known_tables: Dict[str, Set[str]] = {} # keyspace -> set of tables
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Cassandra session
self.cluster = None
self.session = None
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = Field(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
def ensure_keyspace(self, keyspace: str):
"""Ensure keyspace exists in Cassandra"""
if keyspace in self.known_keyspaces:
return
# Connect if needed
self.connect_cassandra()
# Sanitize keyspace name
safe_keyspace = self.sanitize_name(keyspace)
# Create keyspace if not exists
create_keyspace_cql = f"""
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
WITH REPLICATION = {{
'class': 'SimpleStrategy',
'replication_factor': 1
}}
"""
try:
self.session.execute(create_keyspace_cql)
self.known_keyspaces.add(keyspace)
self.known_tables[keyspace] = set()
logger.info(f"Ensured keyspace exists: {safe_keyspace}")
except Exception as e:
logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
raise
def get_cassandra_type(self, field_type: str, size: int = 0) -> str:
"""Convert schema field type to Cassandra type"""
# Handle None size
if size is None:
size = 0
type_mapping = {
"string": "text",
"integer": "bigint" if size > 4 else "int",
"float": "double" if size > 4 else "float",
"boolean": "boolean",
"timestamp": "timestamp",
"date": "date",
"time": "time",
"uuid": "uuid"
}
return type_mapping.get(field_type, "text")
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
# Replace non-alphanumeric characters with underscore
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter
if safe_name and not safe_name[0].isalpha():
safe_name = 'o_' + safe_name
return safe_name.lower()
def sanitize_table(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
# Replace non-alphanumeric characters with underscore
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter
safe_name = 'o_' + safe_name
return safe_name.lower()
def ensure_table(self, keyspace: str, table_name: str, schema: RowSchema):
"""Ensure table exists with proper structure"""
table_key = f"{keyspace}.{table_name}"
if table_key in self.known_tables.get(keyspace, set()):
return
# Ensure keyspace exists first
self.ensure_keyspace(keyspace)
safe_keyspace = self.sanitize_name(keyspace)
safe_table = self.sanitize_table(table_name)
# Build column definitions
columns = ["collection text"] # Collection is always part of table
primary_key_fields = []
clustering_fields = []
for field in schema.fields:
safe_field_name = self.sanitize_name(field.name)
cassandra_type = self.get_cassandra_type(field.type, field.size)
columns.append(f"{safe_field_name} {cassandra_type}")
if field.primary:
primary_key_fields.append(safe_field_name)
# Build primary key - collection is always first in partition key
if primary_key_fields:
primary_key = f"PRIMARY KEY ((collection, {', '.join(primary_key_fields)}))"
else:
# If no primary key defined, use collection and a synthetic id
columns.append("synthetic_id uuid")
primary_key = "PRIMARY KEY ((collection, synthetic_id))"
# Create table
create_table_cql = f"""
CREATE TABLE IF NOT EXISTS {safe_keyspace}.{safe_table} (
{', '.join(columns)},
{primary_key}
)
"""
try:
self.session.execute(create_table_cql)
if keyspace not in self.known_tables:
self.known_tables[keyspace] = set()
self.known_tables[keyspace].add(table_key)
logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}")
# Create secondary indexes for indexed fields
for field in schema.fields:
if field.indexed and not field.primary:
safe_field_name = self.sanitize_name(field.name)
index_name = f"{safe_table}_{safe_field_name}_idx"
create_index_cql = f"""
CREATE INDEX IF NOT EXISTS {index_name}
ON {safe_keyspace}.{safe_table} ({safe_field_name})
"""
try:
self.session.execute(create_index_cql)
logger.info(f"Created index: {index_name}")
except Exception as e:
logger.warning(f"Failed to create index {index_name}: {e}")
except Exception as e:
logger.error(f"Failed to create table {safe_keyspace}.{safe_table}: {e}", exc_info=True)
raise
def convert_value(self, value: Any, field_type: str) -> Any:
"""Convert value to appropriate type for Cassandra"""
if value is None:
return None
try:
if field_type == "integer":
return int(value)
elif field_type == "float":
return float(value)
elif field_type == "boolean":
if isinstance(value, str):
return value.lower() in ('true', '1', 'yes')
return bool(value)
elif field_type == "timestamp":
# Handle timestamp conversion if needed
return value
else:
return str(value)
except Exception as e:
logger.warning(f"Failed to convert value {value} to type {field_type}: {e}")
return str(value)
async def on_object(self, msg, consumer, flow):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
# Validate collection exists before accepting writes
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
error_msg = (
f"Collection {obj.metadata.collection} does not exist. "
f"Create it first via collection management API."
)
logger.error(error_msg)
raise ValueError(error_msg)
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
# Ensure table exists
keyspace = obj.metadata.user
table_name = obj.schema_name
self.ensure_table(keyspace, table_name, schema)
# Prepare data for insertion
safe_keyspace = self.sanitize_name(keyspace)
safe_table = self.sanitize_table(table_name)
# Process each object in the batch
for obj_index, value_map in enumerate(obj.values):
# Build column names and values for this object
columns = ["collection"]
values = [obj.metadata.collection]
placeholders = ["%s"]
# Check if we need a synthetic ID
has_primary_key = any(field.primary for field in schema.fields)
if not has_primary_key:
import uuid
columns.append("synthetic_id")
values.append(uuid.uuid4())
placeholders.append("%s")
# Process fields for this object
skip_object = False
for field in schema.fields:
safe_field_name = self.sanitize_name(field.name)
raw_value = value_map.get(field.name)
# Handle required fields
if field.required and raw_value is None:
logger.warning(f"Required field {field.name} is missing in object {obj_index}")
# Continue anyway - Cassandra doesn't enforce NOT NULL
# Check if primary key field is NULL
if field.primary and raw_value is None:
logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}")
skip_object = True
break
# Convert value to appropriate type
converted_value = self.convert_value(raw_value, field.type)
columns.append(safe_field_name)
values.append(converted_value)
placeholders.append("%s")
# Skip this object if primary key validation failed
if skip_object:
continue
# Build and execute insert query for this object
insert_cql = f"""
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
VALUES ({', '.join(placeholders)})
"""
# Debug: Show data being inserted
logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}")
if len(columns) != len(values) or len(columns) != len(placeholders):
raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
try:
# Convert to tuple - Cassandra driver requires tuple for parameters
self.session.execute(insert_cql, tuple(values))
except Exception as e:
logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
raise
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra object store"""
# Connect if not already connected
self.connect_cassandra()
# Sanitize names for safety
safe_keyspace = self.sanitize_name(user)
# Ensure keyspace exists
if safe_keyspace not in self.known_keyspaces:
self.ensure_keyspace(safe_keyspace)
self.known_keyspaces.add(safe_keyspace)
# For Cassandra objects, collection is just a property in rows
# No need to create separate tables per collection
# Just mark that we've seen this collection
logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})")
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection using schema information"""
# Connect if not already connected
self.connect_cassandra()
# Sanitize names for safety
safe_keyspace = self.sanitize_name(user)
# Check if keyspace exists
if safe_keyspace not in self.known_keyspaces:
# Query to verify keyspace exists
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
if not result.one():
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(safe_keyspace)
# Iterate over schemas we manage to delete from relevant tables
tables_deleted = 0
for schema_name, schema in self.schemas.items():
safe_table = self.sanitize_table(schema_name)
# Check if table exists
table_key = f"{user}.{schema_name}"
if table_key not in self.known_tables.get(user, set()):
logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping")
continue
try:
# Get primary key fields from schema
primary_key_fields = [field for field in schema.fields if field.primary]
if primary_key_fields:
# Schema has primary keys: need to query for partition keys first
# Build SELECT query for primary key fields
pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields]
select_cql = f"""
SELECT {', '.join(pk_field_names)}
FROM {safe_keyspace}.{safe_table}
WHERE collection = %s
ALLOW FILTERING
"""
rows = self.session.execute(select_cql, (collection,))
# Delete each row using full partition key
for row in rows:
where_clauses = ["collection = %s"]
values = [collection]
for field_name in pk_field_names:
where_clauses.append(f"{field_name} = %s")
values.append(getattr(row, field_name))
delete_cql = f"""
DELETE FROM {safe_keyspace}.{safe_table}
WHERE {' AND '.join(where_clauses)}
"""
self.session.execute(delete_cql, tuple(values))
else:
# No primary keys, uses synthetic_id
# Need to query for synthetic_ids first
select_cql = f"""
SELECT synthetic_id
FROM {safe_keyspace}.{safe_table}
WHERE collection = %s
ALLOW FILTERING
"""
rows = self.session.execute(select_cql, (collection,))
# Delete each row using collection and synthetic_id
for row in rows:
delete_cql = f"""
DELETE FROM {safe_keyspace}.{safe_table}
WHERE collection = %s AND synthetic_id = %s
"""
self.session.execute(delete_cql, (collection, row.synthetic_id))
tables_deleted += 1
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}")
except Exception as e:
logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}")
raise
logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}")
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
"""Entry point for objects-write-cassandra command"""
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,3 @@
"""
Row embeddings storage modules.
"""

View file

@ -0,0 +1,5 @@
"""
Qdrant storage for row embeddings.
"""
from .write import Processor, run, default_ident

View file

@ -0,0 +1,4 @@
from .write import run
run()

View file

@ -0,0 +1,264 @@
"""
Row embeddings writer for Qdrant (Stage 2).
Consumes RowEmbeddings messages (which already contain computed vectors)
and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair.
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (row-embeddings): Compute embeddings
Stage 2 (this processor): Store embeddings
Collection naming: rows_{user}_{collection}_{schema_name}_{dimension}
Payload structure:
- index_name: The indexed field(s) this embedding represents
- index_value: The original list of values (for Cassandra lookup)
- text: The text that was embedded (for debugging/display)
"""
import logging
import re
import uuid
from typing import Set, Tuple
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Distance, VectorParams
from .... schema import RowEmbeddings
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-write"
default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.register_specification(
ConsumerSpec(
name="input",
schema=RowEmbeddings,
handler=self.on_embeddings
)
)
# Register config handler for collection management
self.register_config_handler(self.on_collection_config)
# Cache of created Qdrant collections
self.created_collections: Set[str] = set()
# Qdrant client
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
def get_collection_name(
self, user: str, collection: str, schema_name: str, dimension: int
) -> str:
"""Generate Qdrant collection name"""
safe_user = self.sanitize_name(user)
safe_collection = self.sanitize_name(collection)
safe_schema = self.sanitize_name(schema_name)
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
def ensure_collection(self, collection_name: str, dimension: int):
"""Create Qdrant collection if it doesn't exist"""
if collection_name in self.created_collections:
return
if not self.qdrant.collection_exists(collection_name):
logger.info(
f"Creating Qdrant collection {collection_name} "
f"with dimension {dimension}"
)
self.qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=dimension,
distance=Distance.COSINE
)
)
self.created_collections.add(collection_name)
async def on_embeddings(self, msg, consumer, flow):
"""Process incoming RowEmbeddings and write to Qdrant"""
embeddings = msg.value()
logger.info(
f"Writing {len(embeddings.embeddings)} embeddings for schema "
f"{embeddings.schema_name} from {embeddings.metadata.id}"
)
# Validate collection exists in config before processing
if not self.collection_exists(
embeddings.metadata.user, embeddings.metadata.collection
):
logger.warning(
f"Collection {embeddings.metadata.collection} for user "
f"{embeddings.metadata.user} does not exist in config. "
f"Dropping message."
)
return
user = embeddings.metadata.user
collection = embeddings.metadata.collection
schema_name = embeddings.schema_name
embeddings_written = 0
qdrant_collection = None
for row_emb in embeddings.embeddings:
if not row_emb.vectors:
logger.warning(
f"No vectors for index {row_emb.index_name} - skipping"
)
continue
# Use first vector (there may be multiple from different models)
for vector in row_emb.vectors:
dimension = len(vector)
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
)
embeddings_written += 1
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Collection creation via config push - collections created lazily on first write"""
logger.info(
f"Row embeddings collection create request for {user}/{collection} - "
f"will be created lazily on first write"
)
async def delete_collection(self, user: str, collection: str):
"""Delete all Qdrant collections for a given user/collection"""
try:
prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if not matching_collections:
logger.info(f"No Qdrant collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
self.created_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(
f"Deleted {len(matching_collections)} collection(s) "
f"for {user}/{collection}"
)
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}: {e}",
exc_info=True
)
raise
async def delete_collection_schema(
self, user: str, collection: str, schema_name: str
):
"""Delete Qdrant collection for a specific user/collection/schema"""
try:
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if not matching_collections:
logger.info(f"No Qdrant collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
self.created_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}/{schema_name}: {e}",
exc_info=True
)
raise
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='Qdrant API key (default: None)'
)
def run():
"""Entry point for row-embeddings-write-qdrant command"""
Processor.launch(default_ident, __doc__)

View file

@ -1,46 +1,49 @@
"""
Graph writer. Input is graph edge. Writes edges to Cassandra graph.
Row writer for Cassandra. Input is ExtractedObject.
Writes structured rows to a unified Cassandra table with multi-index support.
Uses a single 'rows' table with the schema:
- collection: text
- schema_name: text
- index_name: text
- index_value: frozen<list<text>>
- data: map<text, text>
- source: text
Each row is written multiple times - once per indexed field defined in the schema.
"""
raise RuntimeError("This code is no longer in use")
import pulsar
import base64
import os
import argparse
import time
import json
import logging
import re
from typing import Dict, Set, Optional, Any, List, Tuple
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from ssl import SSLContext, PROTOCOL_TLSv1_2
from .... schema import Rows
from .... log_level import LogLevel
from .... base import Consumer
from .... schema import ExtractedObject
from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
module = "rows-write"
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
default_ident = "rows-write"
default_input_queue = "rows-store" # Default queue name
default_subscriber = module
class Processor(Consumer):
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
@ -48,99 +51,549 @@ class Processor(Consumer):
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Rows,
"cassandra_host": ','.join(hosts),
"cassandra_username": username,
"cassandra_password": password,
"id": id,
"config_type": self.config_key,
}
)
if username and password:
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
self.tables = set()
self.register_specification(
ConsumerSpec(
name="input",
schema=ExtractedObject,
handler=self.on_object
)
)
self.session.execute("""
create keyspace if not exists trustgraph
with replication = {
'class' : 'SimpleStrategy',
'replication_factor' : 1
};
""");
# Register config handlers
self.register_config_handler(self.on_schema_config)
self.register_config_handler(self.on_collection_config)
self.session.execute("use trustgraph");
# Cache of known keyspaces and whether tables exist
self.known_keyspaces: Set[str] = set()
self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables
async def handle(self, msg):
# Cache of registered (collection, schema_name) pairs
self.registered_partitions: Set[Tuple[str, str]] = set()
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Cassandra session
self.cluster = None
self.session = None
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
v = msg.value()
name = v.row_schema.name
if name not in self.tables:
# FIXME: SQL injection?
pkey = []
stmt = "create table if not exists " + name + " ( "
for field in v.row_schema.fields:
stmt += field.name + " text, "
if field.primary:
pkey.append(field.name)
stmt += "PRIMARY KEY (" + ", ".join(pkey) + "));"
self.session.execute(stmt)
self.tables.add(name);
for row in v.rows:
field_names = []
values = []
for field in v.row_schema.fields:
field_names.append(field.name)
values.append(row[field.name])
# FIXME: SQL injection?
stmt = (
"insert into " + name + " (" + ", ".join(field_names) +
") values (" + ",".join(["%s"] * len(values)) + ")"
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session.execute(stmt, values)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
logger.error(f"Exception: {str(e)}", exc_info=True)
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# If there's an error make sure to do table creation etc.
self.tables.remove(name)
# Track which schemas changed so we can clear partition cache
old_schema_names = set(self.schemas.keys())
raise e
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = Field(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
# Clear partition cache for schemas that changed
# This ensures next write will re-register partitions
new_schema_names = set(self.schemas.keys())
changed_schemas = old_schema_names.symmetric_difference(new_schema_names)
if changed_schemas:
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if sch not in changed_schemas
}
logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}")
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
def ensure_keyspace(self, keyspace: str):
"""Ensure keyspace exists in Cassandra"""
if keyspace in self.known_keyspaces:
return
# Connect if needed
self.connect_cassandra()
# Sanitize keyspace name
safe_keyspace = self.sanitize_name(keyspace)
# Create keyspace if not exists
create_keyspace_cql = f"""
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
WITH REPLICATION = {{
'class': 'SimpleStrategy',
'replication_factor': 1
}}
"""
try:
self.session.execute(create_keyspace_cql)
self.known_keyspaces.add(keyspace)
logger.info(f"Ensured keyspace exists: {safe_keyspace}")
except Exception as e:
logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
raise
def ensure_tables(self, keyspace: str):
"""Ensure unified rows and row_partitions tables exist"""
if keyspace in self.tables_initialized:
return
# Ensure keyspace exists first
self.ensure_keyspace(keyspace)
safe_keyspace = self.sanitize_name(keyspace)
# Create unified rows table
create_rows_cql = f"""
CREATE TABLE IF NOT EXISTS {safe_keyspace}.rows (
collection text,
schema_name text,
index_name text,
index_value frozen<list<text>>,
data map<text, text>,
source text,
PRIMARY KEY ((collection, schema_name, index_name), index_value)
)
"""
# Create row_partitions tracking table
create_partitions_cql = f"""
CREATE TABLE IF NOT EXISTS {safe_keyspace}.row_partitions (
collection text,
schema_name text,
index_name text,
PRIMARY KEY ((collection), schema_name, index_name)
)
"""
try:
self.session.execute(create_rows_cql)
logger.info(f"Ensured rows table exists: {safe_keyspace}.rows")
self.session.execute(create_partitions_cql)
logger.info(f"Ensured row_partitions table exists: {safe_keyspace}.row_partitions")
self.tables_initialized.add(keyspace)
except Exception as e:
logger.error(f"Failed to create tables in {safe_keyspace}: {e}", exc_info=True)
raise
def get_index_names(self, schema: RowSchema) -> List[str]:
"""
Get all index names for a schema.
Returns list of index_name strings (single field names or comma-joined composites).
"""
index_names = []
for field in schema.fields:
# Primary key fields are treated as indexes
if field.primary:
index_names.append(field.name)
# Indexed fields
elif field.indexed:
index_names.append(field.name)
# TODO: Support composite indexes in the future
# For now, each indexed field is a single-field index
return index_names
def register_partitions(self, keyspace: str, collection: str, schema_name: str):
"""
Register partition entries for a (collection, schema_name) pair.
Called once on first row for each pair.
"""
cache_key = (collection, schema_name)
if cache_key in self.registered_partitions:
return
schema = self.schemas.get(schema_name)
if not schema:
logger.warning(f"Cannot register partitions - schema {schema_name} not found")
return
safe_keyspace = self.sanitize_name(keyspace)
index_names = self.get_index_names(schema)
# Insert partition entries for each index
insert_cql = f"""
INSERT INTO {safe_keyspace}.row_partitions (collection, schema_name, index_name)
VALUES (%s, %s, %s)
"""
for index_name in index_names:
try:
self.session.execute(insert_cql, (collection, schema_name, index_name))
except Exception as e:
logger.warning(f"Failed to register partition {collection}/{schema_name}/{index_name}: {e}")
self.registered_partitions.add(cache_key)
logger.info(f"Registered partitions for {collection}/{schema_name}: {index_names}")
def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
"""
Build the index_value list for a given index.
For single-field indexes, returns a single-element list.
For composite indexes (comma-separated), returns multiple elements.
"""
field_names = [f.strip() for f in index_name.split(',')]
values = []
for field_name in field_names:
value = value_map.get(field_name)
# Convert to string for storage
values.append(str(value) if value is not None else "")
return values
async def on_object(self, msg, consumer, flow):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
logger.info(
f"Storing {len(obj.values)} rows for schema {obj.schema_name} "
f"from {obj.metadata.id}"
)
# Validate collection exists before accepting writes
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
error_msg = (
f"Collection {obj.metadata.collection} does not exist. "
f"Create it first via collection management API."
)
logger.error(error_msg)
raise ValueError(error_msg)
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
keyspace = obj.metadata.user
collection = obj.metadata.collection
schema_name = obj.schema_name
source = getattr(obj.metadata, 'source', '') or ''
# Ensure tables exist
self.ensure_tables(keyspace)
# Register partitions if first time seeing this (collection, schema_name)
self.register_partitions(keyspace, collection, schema_name)
safe_keyspace = self.sanitize_name(keyspace)
# Get all index names for this schema
index_names = self.get_index_names(schema)
if not index_names:
logger.warning(f"Schema {schema_name} has no indexed fields - rows won't be queryable")
return
# Prepare insert statement
insert_cql = f"""
INSERT INTO {safe_keyspace}.rows
(collection, schema_name, index_name, index_value, data, source)
VALUES (%s, %s, %s, %s, %s, %s)
"""
# Process each row in the batch
rows_written = 0
for row_index, value_map in enumerate(obj.values):
# Convert all values to strings for the data map
data_map = {}
for field in schema.fields:
raw_value = value_map.get(field.name)
if raw_value is not None:
data_map[field.name] = str(raw_value)
# Write one copy per index
for index_name in index_names:
index_value = self.build_index_value(value_map, index_name)
# Skip if index value is empty/null
if not index_value or all(v == "" for v in index_value):
logger.debug(
f"Skipping index {index_name} for row {row_index} - "
f"empty index value"
)
continue
try:
self.session.execute(
insert_cql,
(collection, schema_name, index_name, index_value, data_map, source)
)
rows_written += 1
except Exception as e:
logger.error(
f"Failed to insert row {row_index} for index {index_name}: {e}",
exc_info=True
)
raise
logger.info(
f"Wrote {rows_written} index entries for {len(obj.values)} rows "
f"({len(index_names)} indexes per row)"
)
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store"""
# Connect if not already connected
self.connect_cassandra()
# Ensure tables exist
self.ensure_tables(user)
logger.info(f"Collection {collection} ready for user {user}")
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection using partition tracking"""
# Connect if not already connected
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
# Check if keyspace exists
if user not in self.known_keyspaces:
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
if not result.one():
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(user)
# Discover all partitions for this collection
select_partitions_cql = f"""
SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
WHERE collection = %s
"""
try:
partitions = self.session.execute(select_partitions_cql, (collection,))
partition_list = list(partitions)
except Exception as e:
logger.error(f"Failed to query partitions for collection {collection}: {e}")
raise
# Delete each partition from rows table
delete_rows_cql = f"""
DELETE FROM {safe_keyspace}.rows
WHERE collection = %s AND schema_name = %s AND index_name = %s
"""
partitions_deleted = 0
for partition in partition_list:
try:
self.session.execute(
delete_rows_cql,
(collection, partition.schema_name, partition.index_name)
)
partitions_deleted += 1
except Exception as e:
logger.error(
f"Failed to delete partition {collection}/{partition.schema_name}/"
f"{partition.index_name}: {e}"
)
raise
# Clean up row_partitions entries
delete_partitions_cql = f"""
DELETE FROM {safe_keyspace}.row_partitions
WHERE collection = %s
"""
try:
self.session.execute(delete_partitions_cql, (collection,))
except Exception as e:
logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
raise
# Clear from local cache
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if col != collection
}
logger.info(
f"Deleted collection {collection}: {partitions_deleted} partitions "
f"from keyspace {safe_keyspace}"
)
async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination"""
# Connect if not already connected
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
# Discover partitions for this collection + schema
select_partitions_cql = f"""
SELECT index_name FROM {safe_keyspace}.row_partitions
WHERE collection = %s AND schema_name = %s
"""
try:
partitions = self.session.execute(select_partitions_cql, (collection, schema_name))
partition_list = list(partitions)
except Exception as e:
logger.error(
f"Failed to query partitions for {collection}/{schema_name}: {e}"
)
raise
# Delete each partition from rows table
delete_rows_cql = f"""
DELETE FROM {safe_keyspace}.rows
WHERE collection = %s AND schema_name = %s AND index_name = %s
"""
partitions_deleted = 0
for partition in partition_list:
try:
self.session.execute(
delete_rows_cql,
(collection, schema_name, partition.index_name)
)
partitions_deleted += 1
except Exception as e:
logger.error(
f"Failed to delete partition {collection}/{schema_name}/"
f"{partition.index_name}: {e}"
)
raise
# Clean up row_partitions entries for this schema
delete_partitions_cql = f"""
DELETE FROM {safe_keyspace}.row_partitions
WHERE collection = %s AND schema_name = %s
"""
try:
self.session.execute(delete_partitions_cql, (collection, schema_name))
except Exception as e:
logger.error(
f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"
)
raise
# Clear from local cache
self.registered_partitions.discard((collection, schema_name))
logger.info(
f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "
f"from keyspace {safe_keyspace}"
)
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
Processor.launch(module, __doc__)
"""Entry point for rows-write-cassandra command"""
Processor.launch(default_ident, __doc__)