mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-09 06:45:13 +02:00
Structured data 2 (#645)
* Structured data refactor - multi-index tables, remove need for manual mods to the Cassandra tables * Tech spec updated to track implementation
This commit is contained in:
parent
5ffad92345
commit
1809c1f56d
87 changed files with 5233 additions and 3235 deletions
|
|
@ -60,27 +60,27 @@ api-gateway = "trustgraph.gateway:run"
|
|||
chunker-recursive = "trustgraph.chunking.recursive:run"
|
||||
chunker-token = "trustgraph.chunking.token:run"
|
||||
config-svc = "trustgraph.config.service:run"
|
||||
de-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
|
||||
de-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
|
||||
de-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
|
||||
de-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
|
||||
de-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
|
||||
de-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
|
||||
doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
|
||||
doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
|
||||
doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
|
||||
doc-embeddings-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
|
||||
doc-embeddings-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
|
||||
doc-embeddings-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
|
||||
document-embeddings = "trustgraph.embeddings.document_embeddings:run"
|
||||
document-rag = "trustgraph.retrieval.document_rag:run"
|
||||
embeddings-fastembed = "trustgraph.embeddings.fastembed:run"
|
||||
embeddings-ollama = "trustgraph.embeddings.ollama:run"
|
||||
ge-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
|
||||
ge-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
|
||||
ge-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
|
||||
ge-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
|
||||
ge-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
|
||||
ge-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
|
||||
graph-embeddings-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
|
||||
graph-embeddings-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
|
||||
graph-embeddings-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
|
||||
graph-embeddings-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
|
||||
graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
|
||||
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
|
||||
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
|
||||
graph-rag = "trustgraph.retrieval.graph_rag:run"
|
||||
kg-extract-agent = "trustgraph.extract.kg.agent:run"
|
||||
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
|
||||
kg-extract-objects = "trustgraph.extract.kg.objects:run"
|
||||
kg-extract-rows = "trustgraph.extract.kg.rows:run"
|
||||
kg-extract-relationships = "trustgraph.extract.kg.relationships:run"
|
||||
kg-extract-topics = "trustgraph.extract.kg.topics:run"
|
||||
kg-extract-ontology = "trustgraph.extract.kg.ontology:run"
|
||||
|
|
@ -90,8 +90,11 @@ librarian = "trustgraph.librarian:run"
|
|||
mcp-tool = "trustgraph.agent.mcp_tool:run"
|
||||
metering = "trustgraph.metering:run"
|
||||
nlp-query = "trustgraph.retrieval.nlp_query:run"
|
||||
objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
|
||||
objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
|
||||
rows-write-cassandra = "trustgraph.storage.rows.cassandra:run"
|
||||
rows-query-cassandra = "trustgraph.query.rows.cassandra:run"
|
||||
row-embeddings = "trustgraph.embeddings.row_embeddings:run"
|
||||
row-embeddings-write-qdrant = "trustgraph.storage.row_embeddings.qdrant:run"
|
||||
row-embeddings-query-qdrant = "trustgraph.query.row_embeddings.qdrant:run"
|
||||
pdf-decoder = "trustgraph.decoding.pdf:run"
|
||||
pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run"
|
||||
prompt-template = "trustgraph.prompt.template:run"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . embeddings import *
|
||||
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
|
||||
from . embeddings import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -0,0 +1,263 @@
|
|||
|
||||
"""
|
||||
Row embeddings processor. Calls the embeddings service to compute embeddings
|
||||
for indexed field values in extracted row data.
|
||||
|
||||
Input is ExtractedObject (structured row data with schema).
|
||||
Output is RowEmbeddings (row data with embeddings for indexed fields).
|
||||
|
||||
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
|
||||
Stage 1 (this processor): Compute embeddings
|
||||
Stage 2 (row-embeddings-write-*): Store embeddings
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from ... schema import ExtractedObject, RowEmbeddings, RowIndexEmbedding
|
||||
from ... schema import RowSchema, Field
|
||||
from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
|
||||
from ... base import ProducerSpec, CollectionConfigHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "row-embeddings"
|
||||
default_batch_size = 10
|
||||
|
||||
|
||||
class Processor(CollectionConfigHandler, FlowProcessor):
    """
    Flow processor that computes embeddings for indexed row fields.

    Consumes ExtractedObject messages on "input", requests an embedding
    through "embeddings-request" for the text of every primary/indexed
    field value, and publishes RowEmbeddings messages on "output" in
    batches of at most ``batch_size`` entries.
    """

    def __init__(self, **params):
        """
        Register flow specifications and configuration handlers.

        Recognised params (all optional):
            id: processor identifier (default: ``default_ident``)
            batch_size: maximum embeddings per output message
                (default: ``default_batch_size``)
            config_type: configuration key under which schema definitions
                are looked up (default: "schema")
        """

        ident = params.get("id", default_ident)
        self.batch_size = params.get("batch_size", default_batch_size)

        # Config key for schemas
        self.config_key = params.get("config_type", "schema")

        super().__init__(
            **params | {
                "id": ident,
                "config_type": self.config_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="input",
                schema=ExtractedObject,
                handler=self.on_message,
            )
        )

        self.register_specification(
            EmbeddingsClientSpec(
                request_name="embeddings-request",
                response_name="embeddings-response",
            )
        )

        self.register_specification(
            ProducerSpec(
                name="output",
                schema=RowEmbeddings
            )
        )

        # Register config handlers
        self.register_config_handler(self.on_schema_config)
        self.register_config_handler(self.on_collection_config)

        # Schema storage: name -> RowSchema
        self.schemas: Dict[str, RowSchema] = {}

    async def on_schema_config(self, config, version):
        """
        Handle schema configuration updates.

        Replaces ``self.schemas`` with the schemas found under
        ``config[self.config_key]``. Each entry is a JSON document; an
        entry that fails to parse is logged and skipped so one bad schema
        does not prevent the others from loading.
        """
        logger.info(f"Loading schema configuration version {version}")

        # Clear existing schemas
        self.schemas = {}

        # Check if our config type exists
        if self.config_key not in config:
            logger.warning(f"No '{self.config_key}' type in configuration")
            return

        # Get the schemas dictionary for our type
        schemas_config = config[self.config_key]

        for schema_name, schema_json in schemas_config.items():
            try:
                # Parse the JSON schema definition
                schema_def = json.loads(schema_json)

                fields = [
                    Field(
                        name=field_def["name"],
                        type=field_def["type"],
                        size=field_def.get("size", 0),
                        primary=field_def.get("primary_key", False),
                        description=field_def.get("description", ""),
                        required=field_def.get("required", False),
                        enum_values=field_def.get("enum", []),
                        indexed=field_def.get("indexed", False)
                    )
                    for field_def in schema_def.get("fields", [])
                ]

                row_schema = RowSchema(
                    name=schema_def.get("name", schema_name),
                    description=schema_def.get("description", ""),
                    fields=fields
                )

                self.schemas[schema_name] = row_schema
                logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

            except Exception as e:
                # Deliberate best-effort: report and continue with the
                # remaining schemas.
                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

        logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")

    def get_index_names(self, schema: RowSchema) -> List[str]:
        """Return the names of all primary or indexed fields of a schema."""
        return [
            field.name
            for field in schema.fields
            if field.primary or field.indexed
        ]

    def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
        """
        Build the index_value list for a given index.

        ``index_name`` is split on commas; each resulting field name is
        looked up in ``value_map``. Missing or None values become "".
        """
        values = []
        for field_name in (part.strip() for part in index_name.split(',')):
            value = value_map.get(field_name)
            values.append(str(value) if value is not None else "")
        return values

    def build_text_for_embedding(self, index_value: List[str]) -> str:
        """Build text representation for embedding from index values."""
        # Space-join the values for composite indexes
        return " ".join(index_value)

    async def on_message(self, msg, consumer, flow):
        """
        Process an incoming ExtractedObject and compute embeddings.

        The message is dropped (with a warning) when its collection is not
        known, its schema is not loaded, or the schema has no primary or
        indexed fields. Otherwise one RowIndexEmbedding is produced per
        distinct (index name, index value) pair and the results are sent
        to "output" in batches.

        Raises:
            Any exception raised while requesting embeddings or sending
            output is logged and re-raised.
        """

        obj = msg.value()
        logger.info(
            f"Computing embeddings for {len(obj.values)} rows, "
            f"schema {obj.schema_name}, doc {obj.metadata.id}"
        )

        # Validate collection exists before processing
        if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
            logger.warning(
                f"Collection {obj.metadata.collection} for user {obj.metadata.user} "
                f"does not exist in config. Dropping message."
            )
            return

        # Get schema definition
        schema = self.schemas.get(obj.schema_name)
        if not schema:
            logger.warning(f"No schema found for {obj.schema_name} - skipping")
            return

        # Get all index names for this schema
        index_names = self.get_index_names(schema)
        if not index_names:
            logger.warning(f"Schema {obj.schema_name} has no indexed fields - skipping")
            return

        # (index_name, index values) -> text to embed.
        # De-duplicating on the index entry rather than on the text alone
        # keeps every index represented when two indexes (or two composite
        # values) happen to produce the same text.
        pending: Dict[tuple, str] = {}

        for value_map in obj.values:
            for index_name in index_names:
                index_value = self.build_index_value(value_map, index_name)

                # Skip empty values
                if not index_value or all(v == "" for v in index_value):
                    continue

                text = self.build_text_for_embedding(index_value)
                if text:
                    pending.setdefault((index_name, tuple(index_value)), text)

        if not pending:
            logger.info("No texts to embed")
            return

        # Compute embeddings
        embeddings_list = []

        try:
            # Each distinct text is sent to the embeddings service once,
            # even when several index entries share it.
            vectors_by_text = {}

            for (index_name, index_key), text in pending.items():
                if text not in vectors_by_text:
                    vectors_by_text[text] = await flow(
                        "embeddings-request"
                    ).embed(text=text)

                embeddings_list.append(
                    RowIndexEmbedding(
                        index_name=index_name,
                        index_value=list(index_key),
                        text=text,
                        vectors=vectors_by_text[text]
                    )
                )

            # Send in batches to avoid oversized messages. A non-positive
            # batch size would make range() raise or yield nothing, so
            # fall back to one embedding per message.
            batch_size = max(1, self.batch_size)

            for start in range(0, len(embeddings_list), batch_size):
                result = RowEmbeddings(
                    metadata=obj.metadata,
                    schema_name=obj.schema_name,
                    embeddings=embeddings_list[start:start + batch_size],
                )
                await flow("output").send(result)

            logger.info(
                f"Computed {len(embeddings_list)} embeddings for "
                f"{len(obj.values)} rows ({len(index_names)} indexes)"
            )

        except Exception:
            logger.error("Exception during embedding computation", exc_info=True)
            # Bare raise keeps the original traceback intact
            raise

    async def create_collection(self, user: str, collection: str, metadata: dict):
        """Collection creation notification - no action needed for embedding stage"""
        logger.debug(f"Row embeddings collection notification for {user}/{collection}")

    async def delete_collection(self, user: str, collection: str):
        """Collection deletion notification - no action needed for embedding stage"""
        logger.debug(f"Row embeddings collection delete notification for {user}/{collection}")

    @staticmethod
    def add_args(parser):
        """Add this processor's command-line arguments to ``parser``."""

        FlowProcessor.add_args(parser)

        parser.add_argument(
            '--batch-size',
            type=int,
            default=default_batch_size,
            help=f'Maximum embeddings per output message (default: {default_batch_size})'
        )

        parser.add_argument(
            '--config-type',
            default='schema',
            help='Configuration type prefix for schemas (default: schema)'
        )
|
||||
|
||||
|
||||
def run():
    """Entry point for the row-embeddings command."""
    # __doc__ here is the module docstring, passed to launch() alongside
    # the default processor identifier.
    Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Object extraction service - extracts structured objects from text chunks
|
||||
Row extraction service - extracts structured rows from text chunks
|
||||
based on configured schemas.
|
||||
"""
|
||||
|
||||
|
|
@ -18,7 +18,7 @@ from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
|||
from .... base import PromptClientSpec
|
||||
from .... messaging.translators import row_schema_translator
|
||||
|
||||
default_ident = "kg-extract-objects"
|
||||
default_ident = "kg-extract-rows"
|
||||
|
||||
|
||||
def convert_values_to_strings(obj: Dict[str, Any]) -> Dict[str, str]:
|
||||
|
|
@ -310,5 +310,5 @@ class Processor(FlowProcessor):
|
|||
FlowProcessor.add_args(parser)
|
||||
|
||||
def run():
|
||||
"""Entry point for kg-extract-objects command"""
|
||||
"""Entry point for kg-extract-rows command"""
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -20,7 +20,7 @@ from . prompt import PromptRequestor
|
|||
from . graph_rag import GraphRagRequestor
|
||||
from . document_rag import DocumentRagRequestor
|
||||
from . triples_query import TriplesQueryRequestor
|
||||
from . objects_query import ObjectsQueryRequestor
|
||||
from . rows_query import RowsQueryRequestor
|
||||
from . nlp_query import NLPQueryRequestor
|
||||
from . structured_query import StructuredQueryRequestor
|
||||
from . structured_diag import StructuredDiagRequestor
|
||||
|
|
@ -40,7 +40,7 @@ from . triples_import import TriplesImport
|
|||
from . graph_embeddings_import import GraphEmbeddingsImport
|
||||
from . document_embeddings_import import DocumentEmbeddingsImport
|
||||
from . entity_contexts_import import EntityContextsImport
|
||||
from . objects_import import ObjectsImport
|
||||
from . rows_import import RowsImport
|
||||
|
||||
from . core_export import CoreExport
|
||||
from . core_import import CoreImport
|
||||
|
|
@ -58,7 +58,7 @@ request_response_dispatchers = {
|
|||
"graph-embeddings": GraphEmbeddingsQueryRequestor,
|
||||
"document-embeddings": DocumentEmbeddingsQueryRequestor,
|
||||
"triples": TriplesQueryRequestor,
|
||||
"objects": ObjectsQueryRequestor,
|
||||
"rows": RowsQueryRequestor,
|
||||
"nlp-query": NLPQueryRequestor,
|
||||
"structured-query": StructuredQueryRequestor,
|
||||
"structured-diag": StructuredDiagRequestor,
|
||||
|
|
@ -89,7 +89,7 @@ import_dispatchers = {
|
|||
"graph-embeddings": GraphEmbeddingsImport,
|
||||
"document-embeddings": DocumentEmbeddingsImport,
|
||||
"entity-contexts": EntityContextsImport,
|
||||
"objects": ObjectsImport,
|
||||
"rows": RowsImport,
|
||||
}
|
||||
|
||||
class DispatcherWrapper:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from . serialize import to_subgraph
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ObjectsImport:
|
||||
class RowsImport:
|
||||
|
||||
def __init__(
|
||||
self, ws, running, backend, queue
|
||||
|
|
@ -20,7 +20,7 @@ class ObjectsImport:
|
|||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
|
||||
|
||||
self.publisher = Publisher(
|
||||
backend, topic = queue, schema = ExtractedObject
|
||||
)
|
||||
|
|
@ -73,4 +73,4 @@ class ObjectsImport:
|
|||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
self.ws = None
|
||||
self.ws = None
|
||||
|
|
@ -1,30 +1,30 @@
|
|||
from ... schema import ObjectsQueryRequest, ObjectsQueryResponse
|
||||
from ... schema import RowsQueryRequest, RowsQueryResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class ObjectsQueryRequestor(ServiceRequestor):
|
||||
class RowsQueryRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, backend, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(ObjectsQueryRequestor, self).__init__(
|
||||
super(RowsQueryRequestor, self).__init__(
|
||||
backend=backend,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=ObjectsQueryRequest,
|
||||
response_schema=ObjectsQueryResponse,
|
||||
request_schema=RowsQueryRequest,
|
||||
response_schema=RowsQueryResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("objects-query")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("objects-query")
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("rows-query")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("rows-query")
|
||||
|
||||
def to_request(self, body):
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
22
trustgraph-flow/trustgraph/query/graphql/__init__.py
Normal file
22
trustgraph-flow/trustgraph/query/graphql/__init__.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""
|
||||
Shared GraphQL utilities for row query services.
|
||||
|
||||
This module provides reusable GraphQL components including:
|
||||
- Filter types (IntFilter, StringFilter, FloatFilter)
|
||||
- Dynamic schema generation from RowSchema definitions
|
||||
- Filter parsing utilities
|
||||
"""
|
||||
|
||||
from .types import IntFilter, StringFilter, FloatFilter, SortDirection
|
||||
from .schema import GraphQLSchemaBuilder
|
||||
from .filters import parse_filter_key, parse_where_clause
|
||||
|
||||
__all__ = [
|
||||
"IntFilter",
|
||||
"StringFilter",
|
||||
"FloatFilter",
|
||||
"SortDirection",
|
||||
"GraphQLSchemaBuilder",
|
||||
"parse_filter_key",
|
||||
"parse_where_clause",
|
||||
]
|
||||
104
trustgraph-flow/trustgraph/query/graphql/filters.py
Normal file
104
trustgraph-flow/trustgraph/query/graphql/filters.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""
|
||||
Filter parsing utilities for GraphQL row queries.
|
||||
|
||||
Provides functions to parse GraphQL filter objects into a normalized
|
||||
format that can be used by different query backends.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_filter_key(filter_key: str) -> Tuple[str, str]:
    """
    Parse GraphQL filter key into field name and operator.

    Supports common GraphQL filter patterns:
    - field_name -> (field_name, "eq")
    - field_name_eq -> (field_name, "eq")
    - field_name_gt -> (field_name, "gt")
    - field_name_gte -> (field_name, "gte")
    - field_name_lt -> (field_name, "lt")
    - field_name_lte -> (field_name, "lte")
    - field_name_in -> (field_name, "in")
    - field_name_not -> (field_name, "not")
    - field_name_not_in -> (field_name, "not_in")
    - field_name_contains -> (field_name, "contains")
    - field_name_startsWith -> (field_name, "startsWith")
    - field_name_endsWith -> (field_name, "endsWith")

    These are the key forms produced by parse_where_clause.

    Args:
        filter_key: The filter key string from GraphQL

    Returns:
        Tuple of (field_name, operator)
    """
    if not filter_key:
        return ("", "eq")

    # "_not_in" must be tested before "_in": otherwise "age_not_in" would
    # be read as field "age_not" with operator "in".
    operator_suffixes = (
        "_startsWith", "_endsWith", "_contains", "_not_in",
        "_gte", "_lte", "_not", "_gt", "_lt", "_in", "_eq",
    )

    for op_suffix in operator_suffixes:
        if filter_key.endswith(op_suffix):
            field_name = filter_key[:-len(op_suffix)]
            # Operator is the suffix without its leading underscore
            return (field_name, op_suffix[1:])

    # Default to equality if no operator suffix found
    return (filter_key, "eq")
|
||||
|
||||
|
||||
def parse_where_clause(where_obj) -> Dict[str, Any]:
    """
    Flatten a nested GraphQL where-clause object into a plain dict.

    Every attribute of ``where_obj`` is treated as a field name whose
    value is a filter object (StringFilter, IntFilter, ...). Each operator
    set on that filter object (value is not None) becomes one entry in the
    result: equality is keyed by the bare field name, every other
    recognised operator by "<field>_<operator>".

    Example:
        Input: where_obj with email.eq = "foo@bar.com"
        Output: {"email": "foo@bar.com"}

        Input: where_obj with age.gt = 21
        Output: {"age_gt": 21}

    Args:
        where_obj: The GraphQL where clause object

    Returns:
        Dictionary mapping field_operator keys to values
    """
    if not where_obj:
        return {}

    # Filter attribute name -> suffix appended to the field name.
    # Attributes not listed here are ignored.
    key_suffix = {
        "eq": "",
        "gt": "_gt",
        "gte": "_gte",
        "lt": "_lt",
        "lte": "_lte",
        "in_": "_in",
        "contains": "_contains",
        "startsWith": "_startsWith",
        "endsWith": "_endsWith",
        "not_": "_not",
        "not_in": "_not_in",
    }

    parsed = {}

    logger.debug(f"Parsing where clause: {where_obj}")

    for field_name, filter_obj in where_obj.__dict__.items():
        if filter_obj is None:
            continue

        logger.debug(f"Processing field {field_name} with filter_obj: {filter_obj}")

        # Only filter objects carry per-operator attributes
        if not hasattr(filter_obj, '__dict__'):
            continue

        for operator, value in filter_obj.__dict__.items():
            if value is None:
                continue

            logger.debug(f"Found operator {operator} with value {value}")

            suffix = key_suffix.get(operator)
            if suffix is not None:
                parsed[f"{field_name}{suffix}"] = value

    logger.debug(f"Final parsed conditions: {parsed}")
    return parsed
|
||||
251
trustgraph-flow/trustgraph/query/graphql/schema.py
Normal file
251
trustgraph-flow/trustgraph/query/graphql/schema.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
"""
|
||||
Dynamic GraphQL schema generation from RowSchema definitions.
|
||||
|
||||
Provides a builder class that creates Strawberry GraphQL schemas
|
||||
from TrustGraph RowSchema definitions, with pluggable query backends.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Callable, Awaitable
|
||||
|
||||
import strawberry
|
||||
from strawberry import Schema
|
||||
from strawberry.types import Info
|
||||
|
||||
from .types import IntFilter, StringFilter, FloatFilter, SortDirection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for query callback function
|
||||
QueryCallback = Callable[
|
||||
[str, str, str, Any, Dict[str, Any], int, Optional[str], Optional[SortDirection]],
|
||||
Awaitable[List[Dict[str, Any]]]
|
||||
]
|
||||
|
||||
|
||||
class GraphQLSchemaBuilder:
    """
    Turns RowSchema definitions into a Strawberry GraphQL schema.

    For each registered schema the builder keeps a generated output type
    and a generated filter input type; build() then assembles a Query
    type with one field per schema, resolved through a caller-supplied
    query callback so the builder is independent of the storage backend.

    Usage:
        builder = GraphQLSchemaBuilder()

        # Add schemas
        for name, row_schema in schemas.items():
            builder.add_schema(name, row_schema)

        # Build with a query callback
        schema = builder.build(query_callback)
    """

    def __init__(self):
        # name -> RowSchema
        self.schemas: Dict[str, Any] = {}
        # name -> generated Strawberry output type
        self.graphql_types: Dict[str, type] = {}
        # name -> generated Strawberry filter input type
        self.filter_types: Dict[str, type] = {}

    def add_schema(self, name: str, row_schema) -> None:
        """
        Register a RowSchema and generate its GraphQL types.

        Args:
            name: The schema name (used as the GraphQL query field name)
            row_schema: The RowSchema object defining fields
        """
        self.schemas[name] = row_schema
        self.graphql_types[name] = self._create_graphql_type(name, row_schema)
        self.filter_types[name] = self._create_filter_type(name, row_schema)
        logger.debug(f"Added schema {name} with {len(row_schema.fields)} fields")

    def clear(self) -> None:
        """Forget every registered schema and its generated types."""
        self.schemas, self.graphql_types, self.filter_types = {}, {}, {}

    def build(self, query_callback: QueryCallback) -> Optional[Schema]:
        """
        Assemble the GraphQL schema around the given query callback.

        When a query field is resolved, the callback is awaited with, in
        order:
        - user: str
        - collection: str
        - schema_name: str
        - row_schema: RowSchema
        - filters: Dict[str, Any]
        - limit: int
        - order_by: Optional[str]
        - direction: Optional[SortDirection]

        and must return a list of row dictionaries.

        Args:
            query_callback: Async function to execute queries

        Returns:
            Strawberry Schema, or None if no schemas are loaded
        """
        if not self.schemas:
            logger.warning("No schemas loaded, cannot generate GraphQL schema")
            return None

        # Namespace of the dynamically created Query class: one resolver
        # field plus one annotation per registered schema.
        field_annotations = {}
        namespace = {'__annotations__': field_annotations}

        for schema_name, row_schema in self.schemas.items():
            output_type = self.graphql_types[schema_name]
            where_type = self.filter_types[schema_name]

            resolver_func = self._make_resolver(
                schema_name, row_schema, output_type, where_type, query_callback
            )

            namespace[schema_name] = strawberry.field(resolver=resolver_func)
            field_annotations[schema_name] = List[output_type]

        Query = strawberry.type(type('Query', (), namespace))

        # auto_camel_case is disabled to keep snake_case field names
        config = strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
        schema = strawberry.Schema(query=Query, config=config)

        logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
        return schema

    def _get_python_type(self, field_type: str):
        """Map a schema field type name to the Python type used in GraphQL."""
        # Everything else - "string", "timestamp", "date", "time", "uuid"
        # and any unrecognised type name - is represented as str.
        non_string_types = {
            "integer": int,
            "float": float,
            "boolean": bool,
        }
        return non_string_types.get(field_type, str)

    def _create_graphql_type(self, schema_name: str, row_schema) -> type:
        """Generate the Strawberry output type for a RowSchema."""
        annotations = {}
        defaults = {}

        for field in row_schema.fields:
            python_type = self._get_python_type(field.type)

            if field.required or field.primary:
                annotations[field.name] = python_type
            else:
                # Neither required nor primary: nullable, defaulting to None
                annotations[field.name] = Optional[python_type]
                defaults[field.name] = None

        namespace = {"__annotations__": annotations}
        namespace.update(defaults)

        graphql_class = type(f"{schema_name.capitalize()}Type", (), namespace)

        return strawberry.type(graphql_class)

    def _create_filter_type(self, schema_name: str, row_schema) -> type:
        """Generate the Strawberry filter input type for a RowSchema."""
        filter_type_name = f"{schema_name.capitalize()}Filter"

        # Schema field type -> filter input type. Fields of any other
        # type get no filter.
        filter_for_type = {
            "integer": IntFilter,
            "float": FloatFilter,
            "string": StringFilter,
        }

        annotations = {}
        defaults = {}

        logger.debug(f"Creating filter type {filter_type_name} for schema {schema_name}")

        for field in row_schema.fields:
            logger.debug(
                f"Field {field.name}: type={field.type}, "
                f"indexed={field.indexed}, primary={field.primary}"
            )

            # Filtering is allowed on any field, indexed or not
            filter_cls = filter_for_type.get(field.type)
            if filter_cls is not None:
                annotations[field.name] = Optional[filter_cls]
                defaults[field.name] = None

        logger.debug(
            f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}"
        )

        namespace = {"__annotations__": annotations}
        namespace.update(defaults)

        return strawberry.input(type(filter_type_name, (), namespace))

    def _make_resolver(
        self,
        schema_name: str,
        row_schema,
        graphql_type: type,
        filter_type: type,
        query_callback: QueryCallback
    ):
        """Create the async resolver function for one schema's query field."""
        from .filters import parse_where_clause

        async def resolver(
            info: Info,
            where: Optional[filter_type] = None,
            order_by: Optional[str] = None,
            direction: Optional[SortDirection] = None,
            limit: Optional[int] = 100
        ) -> List[graphql_type]:
            # User and collection come from the request context
            context = info.context
            user = context["user"]
            collection = context["collection"]

            # Run the query through the backend callback
            rows = await query_callback(
                user, collection, schema_name, row_schema,
                parse_where_clause(where), limit, order_by, direction
            )

            # Wrap each row dictionary in the generated output type
            return [graphql_type(**row) for row in rows]

        return resolver
|
||||
56
trustgraph-flow/trustgraph/query/graphql/types.py
Normal file
56
trustgraph-flow/trustgraph/query/graphql/types.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""
|
||||
GraphQL filter and sort types for row queries.
|
||||
|
||||
These types are used to build dynamic GraphQL schemas for querying
|
||||
structured row data.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from enum import Enum
|
||||
|
||||
import strawberry
|
||||
|
||||
|
||||
@strawberry.input
class IntFilter:
    """Filter type for integer fields."""
    # Every operator is optional; None means the operator is not applied.
    eq: Optional[int] = None
    gt: Optional[int] = None
    gte: Optional[int] = None
    lt: Optional[int] = None
    lte: Optional[int] = None
    # "in" and "not" are Python keywords, so the attributes carry a
    # trailing underscore and the field name is given explicitly.
    in_: Optional[List[int]] = strawberry.field(name="in", default=None)
    not_: Optional[int] = strawberry.field(name="not", default=None)
    not_in: Optional[List[int]] = None
|
||||
|
||||
|
||||
@strawberry.input
class StringFilter:
    """Filter type for string fields."""
    # Every operator is optional and defaults to None.
    eq: Optional[str] = None
    contains: Optional[str] = None
    # camelCase attribute names are part of the declared interface; do not
    # rename them to snake_case.
    startsWith: Optional[str] = None
    endsWith: Optional[str] = None
    # "in" and "not" are Python keywords, so the attributes carry a trailing
    # underscore and are given the names "in" / "not" through name=.
    in_: Optional[List[str]] = strawberry.field(name="in", default=None)
    not_: Optional[str] = strawberry.field(name="not", default=None)
    not_in: Optional[List[str]] = None
|
||||
|
||||
|
||||
@strawberry.input
class FloatFilter:
    """Filter type for float fields."""
    # Every operator is optional and defaults to None.
    eq: Optional[float] = None
    gt: Optional[float] = None
    gte: Optional[float] = None
    lt: Optional[float] = None
    lte: Optional[float] = None
    # "in" and "not" are Python keywords, so the attributes carry a trailing
    # underscore and are given the names "in" / "not" through name=.
    in_: Optional[List[float]] = strawberry.field(name="in", default=None)
    not_: Optional[float] = strawberry.field(name="not", default=None)
    not_in: Optional[List[float]] = None
|
||||
|
||||
|
||||
@strawberry.enum
class SortDirection(Enum):
    """Sort direction for query results."""
    # The member values are the lowercase strings "asc" / "desc".
    ASC = "asc"
    DESC = "desc"
|
||||
|
|
@ -1,738 +0,0 @@
|
|||
"""
|
||||
Objects query service using GraphQL. Input is a GraphQL query with variables.
|
||||
Output is GraphQL response data with any errors.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, List, Set
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
|
||||
import strawberry
|
||||
from strawberry import Schema
|
||||
from strawberry.types import Info
|
||||
from strawberry.scalars import JSON
|
||||
from strawberry.tools import create_type
|
||||
|
||||
from .... schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
||||
from .... schema import Error, RowSchema, Field as SchemaField
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
# Module logger
logger = logging.getLogger(__name__)

# Identifier used when the caller supplies no "id" parameter; also passed to
# Processor.launch() by run().
default_ident = "objects-query"
|
||||
|
||||
# GraphQL filter input types
|
||||
@strawberry.input
class IntFilter:
    # Filter input for integer fields; every operator is optional and
    # defaults to None.
    eq: Optional[int] = None
    gt: Optional[int] = None
    gte: Optional[int] = None
    lt: Optional[int] = None
    lte: Optional[int] = None
    # "in" and "not" are Python keywords, so the attributes carry a trailing
    # underscore and are given the names "in" / "not" through name=.
    in_: Optional[List[int]] = strawberry.field(name="in", default=None)
    not_: Optional[int] = strawberry.field(name="not", default=None)
    not_in: Optional[List[int]] = None
|
||||
|
||||
@strawberry.input
class StringFilter:
    # Filter input for string fields; every operator is optional and
    # defaults to None.
    eq: Optional[str] = None
    contains: Optional[str] = None
    # camelCase attribute names are part of the declared interface; do not
    # rename them to snake_case.
    startsWith: Optional[str] = None
    endsWith: Optional[str] = None
    # "in" and "not" are Python keywords, so the attributes carry a trailing
    # underscore and are given the names "in" / "not" through name=.
    in_: Optional[List[str]] = strawberry.field(name="in", default=None)
    not_: Optional[str] = strawberry.field(name="not", default=None)
    not_in: Optional[List[str]] = None
|
||||
|
||||
@strawberry.input
class FloatFilter:
    # Filter input for float fields; every operator is optional and
    # defaults to None.
    eq: Optional[float] = None
    gt: Optional[float] = None
    gte: Optional[float] = None
    lt: Optional[float] = None
    lte: Optional[float] = None
    # "in" and "not" are Python keywords, so the attributes carry a trailing
    # underscore and are given the names "in" / "not" through name=.
    in_: Optional[List[float]] = strawberry.field(name="in", default=None)
    not_: Optional[float] = strawberry.field(name="not", default=None)
    not_in: Optional[List[float]] = None
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
    """Set up the objects-query processor.

    Resolves the Cassandra connection settings, initialises the base
    processor, registers the request consumer, the response producer and
    the schema-config handler, then creates the (initially empty) schema,
    GraphQL and Cassandra state.
    """

    ident = params.get("id", default_ident)

    # Resolve configuration with environment variable fallback.  The
    # keyspace element of the result is not used by this processor.
    hosts, username, password, _keyspace = resolve_cassandra_config(
        host=params.get("cassandra_host"),
        username=params.get("cassandra_username"),
        password=params.get("cassandra_password"),
    )

    # Store resolved configuration with proper names
    self.cassandra_host = hosts  # kept as a list, used as contact points
    self.cassandra_username = username
    self.cassandra_password = password

    # Config key for schemas
    self.config_key = params.get("config_type", "schema")

    base_params = {**params, "id": ident, "config_type": self.config_key}
    super(Processor, self).__init__(**base_params)

    request_spec = ConsumerSpec(
        name="request",
        schema=ObjectsQueryRequest,
        handler=self.on_message,
    )
    self.register_specification(request_spec)

    response_spec = ProducerSpec(
        name="response",
        schema=ObjectsQueryResponse,
    )
    self.register_specification(response_spec)

    # Register config handler for schema updates
    self.register_config_handler(self.on_schema_config)

    # Schema storage: name -> RowSchema
    self.schemas: Dict[str, RowSchema] = {}

    # GraphQL schema (None until schemas have been loaded)
    self.graphql_schema: Optional[Schema] = None

    # GraphQL types cache: schema name -> generated type
    self.graphql_types: Dict[str, type] = {}

    # Cassandra cluster/session, created on demand by connect_cassandra()
    self.cluster = None
    self.session = None

    # Known keyspaces and tables
    self.known_keyspaces: Set[str] = set()
    self.known_tables: Dict[str, Set[str]] = {}
|
||||
|
||||
def connect_cassandra(self):
    """Open the Cassandra session if one is not already open.

    Uses PlainTextAuthProvider only when both a username and a password
    are configured.  Any connection failure is logged and re-raised.
    """
    if self.session:
        return

    try:
        cluster_args = {"contact_points": self.cassandra_host}

        if self.cassandra_username and self.cassandra_password:
            cluster_args["auth_provider"] = PlainTextAuthProvider(
                username=self.cassandra_username,
                password=self.cassandra_password
            )

        self.cluster = Cluster(**cluster_args)
        self.session = self.cluster.connect()
        logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")

    except Exception as e:
        logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
        raise
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
    """Return *name* rewritten as a lowercase Cassandra-safe identifier.

    Every character outside ``[a-zA-Z0-9_]`` becomes ``_``; a non-empty
    result that does not start with a letter is prefixed with ``o_``.
    """
    import re
    cleaned = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    needs_prefix = bool(cleaned) and not cleaned[0].isalpha()
    prefix = 'o_' if needs_prefix else ''
    return (prefix + cleaned).lower()
|
||||
|
||||
def sanitize_table(self, name: str) -> str:
    """Return the lowercase Cassandra table name for *name*.

    Every character outside ``[a-zA-Z0-9_]`` becomes ``_`` and the result
    is always prefixed with ``o_``.
    """
    import re
    return ('o_' + re.sub(r'[^a-zA-Z0-9_]', '_', name)).lower()
|
||||
|
||||
def parse_filter_key(self, filter_key: str) -> tuple[str, str]:
    """Split a flat filter key into ``(field_name, operator)``.

    A trailing ``_gte``, ``_lte``, ``_gt``, ``_lt``, ``_in`` or ``_eq`` is
    taken as the operator; a key without such a suffix means equality.
    An empty key yields ``("", "eq")``.

    Examples: ``"age_gt"`` -> ``("age", "gt")``, ``"name"`` ->
    ``("name", "eq")``.
    """
    if not filter_key:
        return ("", "eq")

    suffixes = ("_gte", "_lte", "_gt", "_lt", "_in", "_eq")
    matched = next((s for s in suffixes if filter_key.endswith(s)), None)

    # Default to equality if no operator suffix found
    if matched is None:
        return (filter_key, "eq")

    # Strip the suffix from the field and the leading underscore from the op
    return (filter_key[:len(filter_key) - len(matched)], matched[1:])
|
||||
|
||||
async def on_schema_config(self, config, version):
    """Reload all row schemas from a configuration update.

    The schema and GraphQL-type caches are cleared first.  Each entry under
    ``config[self.config_key]`` is a JSON schema definition; entries that
    fail to parse are logged and skipped.  The GraphQL schema is then
    regenerated.
    """
    logger.info(f"Loading schema configuration version {version}")

    # Clear existing schemas
    self.schemas = {}
    self.graphql_types = {}

    # Check if our config type exists
    # NOTE(review): this early return leaves self.graphql_schema untouched.
    if self.config_key not in config:
        logger.warning(f"No '{self.config_key}' type in configuration")
        return

    for schema_name, schema_json in config[self.config_key].items():
        try:
            schema_def = json.loads(schema_json)

            fields = [
                SchemaField(
                    name=field_def["name"],
                    type=field_def["type"],
                    size=field_def.get("size", 0),
                    primary=field_def.get("primary_key", False),
                    description=field_def.get("description", ""),
                    required=field_def.get("required", False),
                    enum_values=field_def.get("enum", []),
                    indexed=field_def.get("indexed", False)
                )
                for field_def in schema_def.get("fields", [])
            ]

            self.schemas[schema_name] = RowSchema(
                name=schema_def.get("name", schema_name),
                description=schema_def.get("description", ""),
                fields=fields
            )
            logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

        except Exception as e:
            # One bad schema must not stop the others from loading
            logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

    logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")

    # Regenerate GraphQL schema
    self.generate_graphql_schema()
|
||||
|
||||
def get_python_type(self, field_type: str):
    """Map a schema field type name to the Python type used in GraphQL.

    Timestamps, dates, times and UUIDs are represented as strings; any
    unrecognised type name also falls back to ``str``.
    """
    if field_type in ("string", "timestamp", "date", "time", "uuid"):
        return str

    scalar_types = {"integer": int, "float": float, "boolean": bool}
    return scalar_types.get(field_type, str)
|
||||
|
||||
def create_graphql_type(self, schema_name: str, row_schema: RowSchema) -> type:
    """Build a strawberry type named ``<Schema_name>Type`` from a RowSchema.

    Fields that are neither required nor primary are annotated as Optional
    and default to None; all other fields are annotated with the plain
    Python type and get no default.
    """
    hints = {}
    defaults = {}

    for column in row_schema.fields:
        base_type = self.get_python_type(column.type)

        if column.required or column.primary:
            hints[column.name] = base_type
            continue

        # Make field optional if not required
        hints[column.name] = Optional[base_type]
        defaults[column.name] = None

    # Create the class dynamically, then apply the strawberry decorator
    namespace = {"__annotations__": hints, **defaults}
    dynamic_class = type(f"{schema_name.capitalize()}Type", (), namespace)
    return strawberry.type(dynamic_class)
|
||||
|
||||
def create_filter_type_for_schema(self, schema_name: str, row_schema: RowSchema):
    """Build a strawberry input type named ``<Schema_name>Filter``.

    Every integer, float or string field of the schema gets an optional
    filter attribute (IntFilter / FloatFilter / StringFilter) defaulting to
    None.  Fields of any other type get no filter attribute.
    """
    filter_type_name = f"{schema_name.capitalize()}Filter"

    # field type -> (filter class, name used in the log message)
    filter_for = {
        "integer": (IntFilter, "IntFilter"),
        "float": (FloatFilter, "FloatFilter"),
        "string": (StringFilter, "StringFilter"),
    }

    annotations = {}

    logger.info(f"Creating filter type {filter_type_name} for schema {schema_name}")

    for column in row_schema.fields:
        logger.info(f"Field {column.name}: type={column.type}, indexed={column.indexed}, primary={column.primary}")

        # Allow filtering on any field for now, not just indexed/primary
        choice = filter_for.get(column.type)
        if choice is None:
            continue

        filter_cls, label = choice
        annotations[column.name] = Optional[filter_cls]
        logger.info(f"Added {label} for {column.name}")

    logger.info(f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}")

    # Each annotated attribute defaults to None
    namespace = {"__annotations__": annotations, **dict.fromkeys(annotations)}

    # Create the class dynamically and apply the strawberry input decorator
    return strawberry.input(type(filter_type_name, (), namespace))
|
||||
|
||||
def create_sort_direction_enum(self):
    """Create sort direction enum

    A new strawberry enum class with members ASC ("asc") and DESC ("desc")
    is defined and returned on every call.
    """
    @strawberry.enum
    class SortDirection(Enum):
        ASC = "asc"
        DESC = "desc"

    return SortDirection
|
||||
|
||||
def parse_idiomatic_where_clause(self, where_obj) -> Dict[str, Any]:
    """Flatten a nested GraphQL filter object into a flat conditions dict.

    ``where_obj`` carries one attribute per field, each either None or a
    filter object (StringFilter, IntFilter, ...) whose attributes are the
    operators.  Operators that carry a value are mapped to keys as follows:

    * ``eq``                          -> ``<field>``
    * ``gt`` / ``gte`` / ``lt`` / ``lte`` -> ``<field>_<op>``
    * ``in_``                         -> ``<field>_in``
    * ``contains``                    -> ``<field>_contains``

    Any other operator with a value (``not_``, ``not_in``, ``startsWith``,
    ``endsWith``) has no internal representation.  It is skipped with a
    warning, so a caller can see that part of the filter was not applied.

    Returns an empty dict when ``where_obj`` is falsy.
    """
    if not where_obj:
        return {}

    conditions = {}

    logger.info(f"Parsing where clause: {where_obj}")

    for field_name, filter_obj in where_obj.__dict__.items():
        if filter_obj is None:
            continue

        logger.info(f"Processing field {field_name} with filter_obj: {filter_obj}")

        # Only filter objects (StringFilter, IntFilter, etc.) are understood
        if not hasattr(filter_obj, '__dict__'):
            continue

        for operator, value in filter_obj.__dict__.items():
            if value is None:
                continue

            logger.info(f"Found operator {operator} with value {value}")

            # Map GraphQL operators to our internal format
            if operator == "eq":
                conditions[field_name] = value
            elif operator in ("gt", "gte", "lt", "lte"):
                conditions[f"{field_name}_{operator}"] = value
            elif operator == "in_":
                conditions[f"{field_name}_in"] = value
            elif operator == "contains":
                conditions[f"{field_name}_contains"] = value
            else:
                # Previously dropped without a trace
                logger.warning(
                    f"Ignoring unsupported filter operator '{operator}' "
                    f"on field '{field_name}'"
                )

    logger.info(f"Final parsed conditions: {conditions}")
    return conditions
|
||||
|
||||
def generate_graphql_schema(self):
    """Generate GraphQL schema from loaded schemas using dynamic filter types.

    For every loaded RowSchema a result type, a filter input type and an
    async resolver are created; the resolvers become the fields of a
    dynamically built ``Query`` type.  With no schemas loaded,
    ``self.graphql_schema`` is set to None.
    """
    if not self.schemas:
        logger.warning("No schemas loaded, cannot generate GraphQL schema")
        self.graphql_schema = None
        return

    sort_direction_enum = self.create_sort_direction_enum()

    # The per-schema values are passed in as arguments so that each resolver
    # is bound to its own schema rather than to the loop variables.
    def make_resolver(s_name, r_schema, g_type, f_type, sort_enum):
        async def resolver(
            info: Info,
            where: Optional[f_type] = None,
            order_by: Optional[str] = None,
            direction: Optional[sort_enum] = None,
            limit: Optional[int] = 100
        ) -> List[g_type]:
            # Get the processor instance and request scope from context
            processor = info.context["processor"]
            user = info.context["user"]
            collection = info.context["collection"]

            # Parse the idiomatic where clause
            filters = processor.parse_idiomatic_where_clause(where)

            # Query Cassandra
            rows = await processor.query_cassandra(
                user, collection, s_name, r_schema,
                filters, limit, order_by, direction
            )

            # Convert to GraphQL types
            return [g_type(**row) for row in rows]

        return resolver

    # Create the Query class with one resolver field per schema
    query_dict = {'__annotations__': {}}

    for schema_name, row_schema in self.schemas.items():
        graphql_type = self.create_graphql_type(schema_name, row_schema)
        filter_type = self.create_filter_type_for_schema(schema_name, row_schema)
        self.graphql_types[schema_name] = graphql_type

        resolver_func = make_resolver(
            schema_name, row_schema, graphql_type, filter_type,
            sort_direction_enum
        )

        # The query field is named after the schema
        query_dict[schema_name] = strawberry.field(resolver=resolver_func)
        query_dict['__annotations__'][schema_name] = List[graphql_type]

    Query = strawberry.type(type('Query', (), query_dict))

    # Create the schema with auto_camel_case disabled to keep snake_case field names
    self.graphql_schema = strawberry.Schema(
        query=Query,
        config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
    )
    logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
|
||||
|
||||
async def query_cassandra(
    self,
    user: str,
    collection: str,
    schema_name: str,
    row_schema: RowSchema,
    filters: Dict[str, Any],
    limit: int,
    order_by: Optional[str] = None,
    direction: Optional[Any] = None
) -> List[Dict[str, Any]]:
    """Execute a query against Cassandra.

    The keyspace is derived from ``user`` and the table from
    ``schema_name``.  Rows are always restricted to ``collection``; each
    entry of ``filters`` (keys as understood by ``parse_filter_key``) adds
    one further condition when it names a field of ``row_schema``.

    Args:
        user: Owner of the data; sanitized into the keyspace name.
        collection: Collection value every returned row must have.
        schema_name: Schema name; sanitized into the table name.
        row_schema: Schema whose fields define the valid filter, ordering
            and output columns.
        filters: Flat ``{filter_key: value}`` conditions; None values and
            keys naming unknown fields are skipped.
        limit: Maximum row count; a falsy value means no LIMIT clause.
        order_by: Optional field name to sort by.
        direction: Optional enum whose ``.value`` is ``"asc"`` or
            ``"desc"``; ordering is applied only if ``order_by`` is also
            given.

    Returns:
        One dict per row, keyed by the original (unsanitized) field names.

    Raises:
        Whatever the Cassandra driver raises when the query fails, after
        the ORDER BY-free retry (if any) has also failed.
    """

    # Connect if needed
    self.connect_cassandra()

    keyspace = self.sanitize_name(user)
    table = self.sanitize_table(schema_name)

    known_fields = {f.name for f in row_schema.fields}
    comparison_ops = {"eq": "=", "gt": ">", "gte": ">=", "lt": "<", "lte": "<="}

    where_clauses = ["collection = %s"]
    params = [collection]

    for filter_key, value in filters.items():
        if value is None:
            continue

        field_name, operator = self.parse_filter_key(filter_key)
        logger.debug(
            f"Filter key '{filter_key}' parsed as field '{field_name}', "
            f"operator '{operator}'"
        )

        # Only fields declared in the schema may reach the query text
        if field_name not in known_fields:
            continue

        safe_field = self.sanitize_name(field_name)

        if operator == "in":
            # A non-list value for "in" is ignored
            if isinstance(value, list):
                placeholders = ",".join(["%s"] * len(value))
                where_clauses.append(f"{safe_field} IN ({placeholders})")
                params.extend(value)
            continue

        # Unknown operators default to equality
        cql_op = comparison_ops.get(operator, "=")
        where_clauses.append(f"{safe_field} {cql_op} %s")
        params.append(value)

    base_query = (
        f"SELECT * FROM {keyspace}.{table}"
        " WHERE " + " AND ".join(where_clauses)
    )

    # LIMIT must come before ALLOW FILTERING.  The tail is built once and
    # shared by the ordered query and the ORDER BY-free retry, so both treat
    # a missing limit the same way.
    # ALLOW FILTERING for now (should optimize with proper indexes later)
    tail = (f" LIMIT {limit}" if limit else "") + " ALLOW FILTERING"

    # ORDER BY is attempted only for a field that exists in the schema
    order_clause = ""
    if order_by and direction and order_by in known_fields:
        direction_str = "ASC" if direction.value == "asc" else "DESC"
        order_clause = f" ORDER BY {self.sanitize_name(order_by)} {direction_str}"

    sorted_by_cassandra = False
    try:
        result = self.session.execute(base_query + order_clause + tail, params)
        sorted_by_cassandra = bool(order_clause)
    except Exception as e:
        if not order_clause:
            raise
        # Retry without ORDER BY and sort client-side instead
        logger.info(f"Cassandra rejected ORDER BY, falling back to post-query sorting: {e}")
        result = self.session.execute(base_query + tail, params)

    # Convert rows to dicts keyed by the original field names
    results = []
    for row in result:
        row_dict = {}
        for column in row_schema.fields:
            safe_field = self.sanitize_name(column.name)
            if hasattr(row, safe_field):
                row_dict[column.name] = getattr(row, safe_field)
        results.append(row_dict)

    # Post-query sorting if Cassandra didn't handle ORDER BY.
    # NOTE: LIMIT has already been applied, so this orders only the rows
    # that were returned.
    if order_clause and not sorted_by_cassandra:
        reverse_order = (direction.value == "desc")
        try:
            # None cannot be compared with real values; group None rows
            # together instead of failing the whole sort.
            results.sort(
                key=lambda r: (r.get(order_by) is None, r.get(order_by)),
                reverse=reverse_order
            )
        except Exception as e:
            logger.warning(f"Failed to sort results by {order_by}: {e}")

    return results
|
||||
|
||||
async def execute_graphql_query(
    self,
    query: str,
    variables: Dict[str, Any],
    operation_name: Optional[str],
    user: str,
    collection: str
) -> Dict[str, Any]:
    """Execute a GraphQL query and return a plain response dict.

    The resolvers receive a context holding this processor, ``user`` and
    ``collection``.  The result always has ``"data"`` (None when the
    execution produced no data) and ``"errors"`` (a possibly empty list of
    ``{"message", "path", "extensions"}`` dicts); ``"extensions"`` is
    present only when the execution result carries a truthy value.

    Raises:
        RuntimeError: if no GraphQL schema has been generated yet.
    """

    if not self.graphql_schema:
        raise RuntimeError("No GraphQL schema available - no schemas loaded")

    outcome = await self.graphql_schema.execute(
        query,
        variable_values=variables,
        operation_name=operation_name,
        context_value={
            "processor": self,
            "user": user,
            "collection": collection
        }
    )

    if outcome.errors:
        errors = [
            {
                "message": str(err),
                "path": getattr(err, "path", []),
                "extensions": getattr(err, "extensions", {})
            }
            for err in outcome.errors
        ]
    else:
        errors = []

    response = {
        "data": outcome.data if outcome.data else None,
        "errors": errors,
    }

    # Add extensions if any
    extensions = getattr(outcome, "extensions", None)
    if extensions:
        response["extensions"] = extensions

    return response
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
    """Handle incoming query request.

    Executes the GraphQL query carried by the message and sends an
    ObjectsQueryResponse on the "response" producer, tagged with the
    sender-produced request ID.  Any failure is reported to the sender as
    an "objects-query-error" response.

    Raises:
        The original exception, if it occurred before the request ID could
        be read from the message (no response can be correlated then).
    """

    # Bound before the try block so the error path can tell whether the ID
    # was ever read; referencing an unassigned local there would replace the
    # real error with an UnboundLocalError.
    request_id = None

    try:
        request = msg.value()

        # Sender-produced ID
        request_id = msg.properties()["id"]

        logger.debug(f"Handling objects query request {request_id}...")

        # Execute GraphQL query
        result = await self.execute_graphql_query(
            query=request.query,
            variables=dict(request.variables) if request.variables else {},
            operation_name=request.operation_name,
            user=request.user,
            collection=request.collection
        )

        # Convert GraphQL-level errors into schema objects
        graphql_errors = [
            GraphQLError(
                message=err.get("message", ""),
                path=err.get("path", []),
                extensions=err.get("extensions", {})
            )
            for err in (result.get("errors") or [])
        ]

        response = ObjectsQueryResponse(
            error=None,
            data=json.dumps(result.get("data")) if result.get("data") else "null",
            errors=graphql_errors,
            extensions=result.get("extensions", {})
        )

        logger.debug("Sending objects query response...")
        await flow("response").send(response, properties={"id": request_id})

        logger.debug("Objects query request completed")

    except Exception as e:

        logger.error(f"Exception in objects query service: {e}", exc_info=True)

        if request_id is None:
            # Without an ID the response cannot be matched to a request
            raise

        logger.info("Sending error response...")

        response = ObjectsQueryResponse(
            error = Error(
                type = "objects-query-error",
                message = str(e),
            ),
            data = None,
            errors = [],
            extensions = {}
        )

        await flow("response").send(response, properties={"id": request_id})
|
||||
|
||||
def close(self):
    """Shut down the Cassandra cluster connection, if one was created."""
    if not self.cluster:
        return

    self.cluster.shutdown()
    logger.info("Closed Cassandra connection")
|
||||
|
||||
@staticmethod
def add_args(parser):
    """Register this processor's command-line arguments on *parser*.

    Adds the FlowProcessor arguments, the Cassandra arguments and the
    ``--config-type`` option.
    """

    FlowProcessor.add_args(parser)
    add_cassandra_args(parser)

    config_type_help = 'Configuration type prefix for schemas (default: schema)'
    parser.add_argument('--config-type', default='schema', help=config_type_help)
|
||||
|
||||
def run():
    """Entry point for objects-query-graphql-cassandra command

    Launches Processor with the default identifier and this module's
    docstring.
    """
    Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Row embeddings query modules.
|
||||
"""
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""
|
||||
Qdrant row embeddings query service.
|
||||
"""
|
||||
|
||||
from .service import Processor, run, default_ident
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
|
||||
# Script-style module: importing or executing it starts the service at once.
from .service import run

run()
|
||||
|
|
@ -0,0 +1,209 @@
|
|||
"""
|
||||
Row embeddings query service for Qdrant.
|
||||
|
||||
Input is query vectors plus user/collection/schema context.
|
||||
Output is matching row index information (index_name, index_value) for
|
||||
use in subsequent Cassandra lookups.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
||||
|
||||
from .... schema import (
|
||||
RowEmbeddingsRequest, RowEmbeddingsResponse,
|
||||
RowIndexMatch, Error
|
||||
)
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
|
||||
# Module logger
logger = logging.getLogger(__name__)

# Identifier used when the caller supplies no "id" parameter; also passed to
# Processor.launch() by run().
default_ident = "row-embeddings-query"
# Qdrant URL used when no "store_uri" parameter / --store-uri is given.
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
    """Set up the row-embeddings query processor.

    Initialises the base processor with the resolved id / store URI / API
    key, registers the request consumer and response producer, and creates
    the Qdrant client.
    """

    ident = params.get("id", default_ident)
    store_uri = params.get("store_uri", default_store_uri)
    api_key = params.get("api_key", None)

    base_params = {
        **params,
        "id": ident,
        "store_uri": store_uri,
        "api_key": api_key,
    }
    super(Processor, self).__init__(**base_params)

    request_spec = ConsumerSpec(
        name="request",
        schema=RowEmbeddingsRequest,
        handler=self.on_message
    )
    self.register_specification(request_spec)

    response_spec = ProducerSpec(
        name="response",
        schema=RowEmbeddingsResponse
    )
    self.register_specification(response_spec)

    self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
    """Return *name* rewritten as a lowercase Qdrant-safe name component.

    Every character outside ``[a-zA-Z0-9_]`` becomes ``_``; a non-empty
    result that does not start with a letter is prefixed with ``r_``.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    needs_prefix = bool(cleaned) and not cleaned[0].isalpha()
    prefix = 'r_' if needs_prefix else ''
    return (prefix + cleaned).lower()
|
||||
|
||||
def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]:
    """Find the Qdrant collection for a given user/collection/schema.

    Returns the name of the first Qdrant collection that starts with
    ``rows_<user>_<collection>_<schema>_`` (each part sanitized), or None
    when nothing matches or the collection listing fails (the failure is
    logged).
    """
    parts = [self.sanitize_name(p) for p in (user, collection, schema_name)]
    prefix = "rows_" + "_".join(parts) + "_"

    try:
        # First match wins (there should typically be only one per dimension).
        # NOTE(review): a pure prefix test would also accept a schema whose
        # sanitized name merely begins with "<schema>_" -- confirm that the
        # naming scheme rules this out.
        for coll in self.qdrant.get_collections().collections:
            if coll.name.startswith(prefix):
                return coll.name

    except Exception as e:
        logger.error(f"Failed to list Qdrant collections: {e}", exc_info=True)

    return None
|
||||
|
||||
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
    """Execute row embeddings query.

    Looks up the Qdrant collection for the request's user / collection /
    schema and queries it once per vector in ``request.vectors``, optionally
    restricted to points whose ``index_name`` payload equals
    ``request.index_name``.

    Returns:
        A list of RowIndexMatch objects built from the payloads of all
        returned points; empty when no collection is found.

    Raises:
        Whatever the Qdrant query raises (logged before re-raising).
    """

    matches = []

    # Find the collection for this user/collection/schema
    target = self.find_collection(
        request.user, request.collection, request.schema_name
    )

    if not target:
        logger.info(
            f"No Qdrant collection found for "
            f"{request.user}/{request.collection}/{request.schema_name}"
        )
        return matches

    for vec in request.vectors:
        try:
            # Build optional filter for index_name
            if request.index_name:
                name_condition = FieldCondition(
                    key="index_name",
                    match=MatchValue(value=request.index_name)
                )
                query_filter = Filter(must=[name_condition])
            else:
                query_filter = None

            # Query Qdrant
            response = self.qdrant.query_points(
                collection_name=target,
                query=vec,
                limit=request.limit,
                with_payload=True,
                query_filter=query_filter,
            )

            # Convert to RowIndexMatch objects
            for point in response.points:
                payload = point.payload or {}
                matches.append(
                    RowIndexMatch(
                        index_name=payload.get("index_name", ""),
                        index_value=payload.get("index_value", []),
                        text=payload.get("text", ""),
                        score=getattr(point, 'score', 0.0)
                    )
                )

        except Exception as e:
            logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
            raise

    return matches
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
    """Handle incoming query request.

    Runs the row-embeddings query for the message and sends a
    RowEmbeddingsResponse on the "response" producer, tagged with the
    sender-produced request ID.  Any failure is reported to the sender as a
    "row-embeddings-query-error" response with no matches.

    Raises:
        The original exception, if it occurred before the request ID could
        be read from the message (no response can be correlated then).
    """

    # Bound before the try block so the error path can tell whether the ID
    # was ever read; referencing an unassigned local there would replace the
    # real error with an UnboundLocalError.
    request_id = None

    try:
        request = msg.value()

        # Sender-produced ID
        request_id = msg.properties()["id"]

        logger.debug(
            f"Handling row embeddings query for "
            f"{request.user}/{request.collection}/{request.schema_name}..."
        )

        # Execute query
        matches = await self.query_row_embeddings(request)

        response = RowEmbeddingsResponse(
            error=None,
            matches=matches
        )

        logger.debug(f"Returning {len(matches)} matches")
        await flow("response").send(response, properties={"id": request_id})

    except Exception as e:
        logger.error(f"Exception in row embeddings query: {e}", exc_info=True)

        if request_id is None:
            # Without an ID the response cannot be matched to a request
            raise

        response = RowEmbeddingsResponse(
            error=Error(
                type="row-embeddings-query-error",
                message=str(e)
            ),
            matches=[]
        )

        await flow("response").send(response, properties={"id": request_id})
|
||||
|
||||
@staticmethod
def add_args(parser):
    """Register this processor's command-line arguments on *parser*.

    Adds the FlowProcessor arguments plus ``-t/--store-uri`` and
    ``-k/--api-key`` for the Qdrant connection.
    """

    FlowProcessor.add_args(parser)

    store_uri_help = f'Qdrant store URI (default: {default_store_uri})'
    parser.add_argument(
        '-t', '--store-uri', default=default_store_uri, help=store_uri_help
    )

    api_key_help = 'API key for Qdrant (default: None)'
    parser.add_argument('-k', '--api-key', default=None, help=api_key_help)
|
||||
|
||||
|
||||
def run():
    """Entry point for row-embeddings-query-qdrant command.

    Launches ``Processor`` with the default identity and this module's
    docstring.
    """
    Processor.launch(default_ident, __doc__)
|
||||
523
trustgraph-flow/trustgraph/query/rows/cassandra/service.py
Normal file
523
trustgraph-flow/trustgraph/query/rows/cassandra/service.py
Normal file
|
|
@ -0,0 +1,523 @@
|
|||
"""
|
||||
Row query service using GraphQL. Input is a GraphQL query with variables.
|
||||
Output is GraphQL response data with any errors.
|
||||
|
||||
Queries against the unified 'rows' table with schema:
|
||||
- collection: text
|
||||
- schema_name: text
|
||||
- index_name: text
|
||||
- index_value: frozen<list<text>>
|
||||
- data: map<text, text>
|
||||
- source: text
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, Any, Optional, List, Set
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
|
||||
from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||
from .... schema import Error, RowSchema, Field as SchemaField
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
from ... graphql import GraphQLSchemaBuilder, SortDirection
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "rows-query"
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
    """Set up the rows query processor.

    Resolves the Cassandra connection settings, registers the "request"
    consumer, the "response" producer and the schema config handler, and
    initialises the (initially empty) schema and connection state.
    """

    ident = params.get("id", default_ident)

    # Resolve configuration with environment variable fallback.  The
    # keyspace element of the result is not used by this processor.
    (
        self.cassandra_host,  # Stored as a list
        self.cassandra_username,
        self.cassandra_password,
        _keyspace,
    ) = resolve_cassandra_config(
        host=params.get("cassandra_host"),
        username=params.get("cassandra_username"),
        password=params.get("cassandra_password"),
    )

    # Config key under which schema definitions are published
    self.config_key = params.get("config_type", "schema")

    merged_params = {
        **params,
        "id": ident,
        "config_type": self.config_key,
    }
    super(Processor, self).__init__(**merged_params)

    # Incoming queries
    request_spec = ConsumerSpec(
        name="request",
        schema=RowsQueryRequest,
        handler=self.on_message
    )
    self.register_specification(request_spec)

    # Outgoing results
    response_spec = ProducerSpec(
        name="response",
        schema=RowsQueryResponse,
    )
    self.register_specification(response_spec)

    # Schema updates arrive through the config handler
    self.register_config_handler(self.on_schema_config)

    # name -> RowSchema
    self.schemas: Dict[str, RowSchema] = {}

    # Builder plus the GraphQL schema generated from it (none yet)
    self.schema_builder = GraphQLSchemaBuilder()
    self.graphql_schema = None

    # Cassandra connection, opened lazily by connect_cassandra()
    self.cluster = None
    self.session = None

    # Known keyspaces
    self.known_keyspaces: Set[str] = set()
|
||||
|
||||
def connect_cassandra(self):
    """Open a Cassandra session unless one is already open.

    Authentication is used only when both a username and a password are
    configured.  Connection failures are logged and re-raised.
    """
    if self.session:
        return

    try:
        cluster_kwargs = {"contact_points": self.cassandra_host}

        # Credentials are only attached when both halves are present.
        if self.cassandra_username and self.cassandra_password:
            cluster_kwargs["auth_provider"] = PlainTextAuthProvider(
                username=self.cassandra_username,
                password=self.cassandra_password
            )

        self.cluster = Cluster(**cluster_kwargs)
        self.session = self.cluster.connect()
        logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")

    except Exception as e:
        logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
        raise
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
    """Return ``name`` rewritten as a lowercase Cassandra-safe identifier.

    Every character outside ``[a-zA-Z0-9_]`` becomes an underscore, and a
    non-empty result that does not start with a letter gets an ``r_``
    prefix.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    starts_with_letter = cleaned[:1].isalpha()
    if cleaned and not starts_with_letter:
        cleaned = 'r_' + cleaned
    return cleaned.lower()
|
||||
|
||||
async def on_schema_config(self, config, version):
    """Handle schema configuration updates.

    Discards all previously loaded schemas, parses every schema definition
    found under ``self.config_key`` in ``config`` and regenerates the
    GraphQL schema from them.  A definition that fails to parse is logged
    and skipped; the remaining definitions are still loaded.

    Args:
        config: mapping of config type -> {schema name -> JSON definition}.
        version: configuration version, used for logging only.
    """
    logger.info(f"Loading schema configuration version {version}")

    # Clear existing schemas.  The generated GraphQL schema is dropped as
    # well so that a config without our type cannot leave a stale schema
    # answering queries for definitions that no longer exist.
    self.schemas = {}
    self.schema_builder.clear()
    self.graphql_schema = None

    # Check if our config type exists
    if self.config_key not in config:
        logger.warning(f"No '{self.config_key}' type in configuration")
        return

    # Get the schemas dictionary for our type
    schemas_config = config[self.config_key]

    # Process each schema in the schemas config
    for schema_name, schema_json in schemas_config.items():
        try:
            # Parse the JSON schema definition
            schema_def = json.loads(schema_json)

            # Create Field objects
            fields = [
                SchemaField(
                    name=field_def["name"],
                    type=field_def["type"],
                    size=field_def.get("size", 0),
                    primary=field_def.get("primary_key", False),
                    description=field_def.get("description", ""),
                    required=field_def.get("required", False),
                    enum_values=field_def.get("enum", []),
                    indexed=field_def.get("indexed", False)
                )
                for field_def in schema_def.get("fields", [])
            ]

            # Create RowSchema
            row_schema = RowSchema(
                name=schema_def.get("name", schema_name),
                description=schema_def.get("description", ""),
                fields=fields
            )

            self.schemas[schema_name] = row_schema
            self.schema_builder.add_schema(schema_name, row_schema)
            logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

        except Exception as e:
            # One bad definition must not prevent the others from loading.
            logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

    logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")

    # Regenerate GraphQL schema
    self.graphql_schema = self.schema_builder.build(self.query_cassandra)
|
||||
|
||||
def get_index_names(self, schema: RowSchema) -> List[str]:
    """Return the names of the schema's primary and indexed fields.

    Names are returned in the order the fields appear in ``schema.fields``.
    """
    return [
        candidate.name
        for candidate in schema.fields
        if candidate.primary or candidate.indexed
    ]
|
||||
|
||||
def find_matching_index(
    self,
    schema: "RowSchema",
    filters: Dict[str, Any]
) -> Optional[tuple]:
    """
    Find an index that can satisfy the query filters.

    Returns (index_name, index_value) for the first indexed field (in
    schema order) that has an exact-match filter with a value, or None
    when no such filter exists.

    A filter whose value is None is treated as "not set" -- the same way
    _matches_filters treats it -- so it never selects an index.
    """
    for index_name in self.get_index_names(schema):
        value = filters.get(index_name)
        if value is None:
            # Absent, or present without a value: not usable for lookup.
            continue

        # Single field index -> single element list
        return (index_name, [str(value)])

    return None
|
||||
|
||||
async def query_cassandra(
    self,
    user: str,
    collection: str,
    schema_name: str,
    row_schema: "RowSchema",
    filters: Dict[str, Any],
    limit: int,
    order_by: Optional[str] = None,
    direction: "Optional[SortDirection]" = None
) -> List[Dict[str, Any]]:
    """
    Execute a query against the unified Cassandra rows table.

    When a filter is an exact match on an indexed field the rows are
    fetched directly through that index; any remaining filters are then
    applied row by row.  Otherwise all rows of the schema's first index
    are scanned and post-filtered.

    Args:
        user: owner of the data; sanitized and used as the keyspace name.
        collection: collection to query.
        schema_name: schema whose rows are queried.
        row_schema: definition of that schema.
        filters: filter name -> value, as understood by _matches_filters.
        limit: maximum number of rows to return (falsy means no limit).
        order_by: optional field name to sort the results by.
        direction: optional sort direction; descending only when its
            value is "desc".

    Returns:
        The matching rows as dicts built from each row's data map.

    NOTE(review): sorting is applied to the rows already cut down to
    ``limit``, not to the full match set -- confirm this is intended.
    """
    # Connect if needed
    self.connect_cassandra()

    safe_keyspace = self.sanitize_name(user)

    # Try to find an index that matches the filters
    index_match = self.find_matching_index(row_schema, filters)

    results = []

    if index_match:
        # Direct query using index
        index_name, index_value = index_match

        # Filters the index lookup does not cover.  They still have to
        # hold for a row to be returned, so they are checked per row.
        remaining = {
            key: value
            for key, value in filters.items()
            if key != index_name and value is not None
        }

        query = f"""
            SELECT data, source FROM {safe_keyspace}.rows
            WHERE collection = %s
            AND schema_name = %s
            AND index_name = %s
            AND index_value = %s
        """
        params = [collection, schema_name, index_name, index_value]

        # LIMIT may only be pushed down to Cassandra when no rows are
        # dropped afterwards; otherwise it would cut off valid matches.
        if limit and not remaining:
            query += f" LIMIT {int(limit)}"

        try:
            rows = self.session.execute(query, params)
            for row in rows:
                # Convert data map to dict with proper field names
                row_dict = dict(row.data) if row.data else {}

                if remaining and not self._matches_filters(
                    row_dict, remaining, row_schema
                ):
                    continue

                results.append(row_dict)

                if limit and len(results) >= limit:
                    break
        except Exception as e:
            logger.error(f"Failed to query rows: {e}", exc_info=True)
            raise

    else:
        # No direct index match - scan all rows for this schema
        # This is less efficient but necessary for non-indexed queries
        logger.warning(
            f"No index match for filters {filters} - scanning all indexes"
        )

        # Get all index names for this schema
        index_names = self.get_index_names(row_schema)

        if not index_names:
            logger.warning(f"Schema {schema_name} has no indexes")
            return []

        # Query using the first index (arbitrary choice for scan)
        primary_index = index_names[0]

        # We need to scan all values for this index
        # This requires ALLOW FILTERING or a different approach
        query = f"""
            SELECT data, source FROM {safe_keyspace}.rows
            WHERE collection = %s
            AND schema_name = %s
            AND index_name = %s
            ALLOW FILTERING
        """
        params = [collection, schema_name, primary_index]

        try:
            rows = self.session.execute(query, params)

            for row in rows:
                row_dict = dict(row.data) if row.data else {}

                # Apply post-filters
                if self._matches_filters(row_dict, filters, row_schema):
                    results.append(row_dict)

                if limit and len(results) >= limit:
                    break

        except Exception as e:
            logger.error(f"Failed to scan rows: {e}", exc_info=True)
            raise

    # Post-query sorting if requested
    if order_by and results:
        # Always a real bool: list.sort() does not accept reverse=None on
        # every Python version, which previously made the sort fail (and
        # be skipped) whenever no direction was supplied.
        reverse_order = bool(direction and direction.value == "desc")
        try:
            results.sort(
                key=lambda x: x.get(order_by, ""),
                reverse=reverse_order
            )
        except Exception as e:
            logger.warning(f"Failed to sort results by {order_by}: {e}")

    return results
|
||||
|
||||
def _matches_filters(
|
||||
self,
|
||||
row_dict: Dict[str, Any],
|
||||
filters: Dict[str, Any],
|
||||
row_schema: RowSchema
|
||||
) -> bool:
|
||||
"""Check if a row matches the given filters."""
|
||||
for filter_key, filter_value in filters.items():
|
||||
if filter_value is None:
|
||||
continue
|
||||
|
||||
# Parse filter key for operator
|
||||
if '_' in filter_key:
|
||||
parts = filter_key.rsplit('_', 1)
|
||||
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
|
||||
field_name = parts[0]
|
||||
operator = parts[1]
|
||||
else:
|
||||
field_name = filter_key
|
||||
operator = 'eq'
|
||||
else:
|
||||
field_name = filter_key
|
||||
operator = 'eq'
|
||||
|
||||
row_value = row_dict.get(field_name)
|
||||
if row_value is None:
|
||||
return False
|
||||
|
||||
# Convert types for comparison
|
||||
try:
|
||||
if operator == 'eq':
|
||||
if str(row_value) != str(filter_value):
|
||||
return False
|
||||
elif operator == 'gt':
|
||||
if float(row_value) <= float(filter_value):
|
||||
return False
|
||||
elif operator == 'gte':
|
||||
if float(row_value) < float(filter_value):
|
||||
return False
|
||||
elif operator == 'lt':
|
||||
if float(row_value) >= float(filter_value):
|
||||
return False
|
||||
elif operator == 'lte':
|
||||
if float(row_value) > float(filter_value):
|
||||
return False
|
||||
elif operator == 'contains':
|
||||
if str(filter_value) not in str(row_value):
|
||||
return False
|
||||
elif operator == 'in':
|
||||
if str(row_value) not in [str(v) for v in filter_value]:
|
||||
return False
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def execute_graphql_query(
    self,
    query: str,
    variables: Dict[str, Any],
    operation_name: Optional[str],
    user: str,
    collection: str
) -> Dict[str, Any]:
    """Execute a GraphQL query against the generated schema.

    Args:
        query: GraphQL query text.
        variables: values for the query's variables.
        operation_name: operation to run, or None.
        user: passed to resolvers through the execution context.
        collection: passed to resolvers through the execution context.

    Returns:
        A dict with "data" (the result data, or None when empty),
        "errors" (a list of {"message", "path", "extensions"} dicts,
        empty when there were none) and, only when present,
        "extensions".

    Raises:
        RuntimeError: if no GraphQL schema has been generated yet.
    """

    if not self.graphql_schema:
        raise RuntimeError("No GraphQL schema available - no schemas loaded")

    # Create context for the query
    context = {
        "processor": self,
        "user": user,
        "collection": collection
    }

    # Execute the query
    result = await self.graphql_schema.execute(
        query,
        variable_values=variables,
        operation_name=operation_name,
        context_value=context
    )

    # Build response
    response = {"data": result.data if result.data else None}

    # An error may carry the attribute with a None value, so the
    # getattr default alone is not enough to guarantee a list / dict.
    response["errors"] = [
        {
            "message": str(error),
            "path": getattr(error, "path", None) or [],
            "extensions": getattr(error, "extensions", None) or {}
        }
        for error in (result.errors or [])
    ]

    # Add extensions if any
    extensions = getattr(result, "extensions", None)
    if extensions:
        response["extensions"] = extensions

    return response
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
    """Handle an incoming rows query request.

    Executes the request's GraphQL query and publishes a
    ``RowsQueryResponse`` on the "response" producer, tagged with the
    sender-supplied ``id`` property.  GraphQL-level errors are returned
    in the response's ``errors`` list; any other failure is returned as
    an error response instead of being raised.

    Raises:
        Exception: re-raised only when the sender's ID could not be read,
            because no response can be correlated without it.
    """

    # Bound before the try block so the error path never reads an
    # unassigned local, which would mask the real failure.
    request_id = None

    try:
        # Sender-produced ID; read first so that later failures can still
        # be reported back to the sender.
        request_id = msg.properties()["id"]

        request = msg.value()

        logger.debug(f"Handling objects query request {request_id}...")

        # Execute GraphQL query
        result = await self.execute_graphql_query(
            query=request.query,
            variables=dict(request.variables) if request.variables else {},
            operation_name=request.operation_name,
            user=request.user,
            collection=request.collection
        )

        # Create response
        graphql_errors = [
            GraphQLError(
                message=err.get("message", ""),
                path=err.get("path", []),
                extensions=err.get("extensions", {})
            )
            for err in (result.get("errors") or [])
        ]

        response = RowsQueryResponse(
            error=None,
            data=json.dumps(result.get("data")) if result.get("data") else "null",
            errors=graphql_errors,
            extensions=result.get("extensions", {})
        )

        logger.debug("Sending objects query response...")
        await flow("response").send(response, properties={"id": request_id})

        logger.debug("Objects query request completed")

    except Exception as e:

        logger.error(f"Exception in rows query service: {e}", exc_info=True)

        if request_id is None:
            # Nothing to address a response to; surface the original error.
            raise

        logger.info("Sending error response...")

        response = RowsQueryResponse(
            error=Error(
                type="rows-query-error",
                message=str(e),
            ),
            data=None,
            errors=[],
            extensions={}
        )

        await flow("response").send(response, properties={"id": request_id})
|
||||
|
||||
def close(self):
    """Shut down the Cassandra cluster connection, if one was opened."""
    if not self.cluster:
        # Never connected; nothing to release.
        return

    self.cluster.shutdown()
    logger.info("Closed Cassandra connection")
|
||||
|
||||
@staticmethod
def add_args(parser):
    """Register this processor's command-line options on ``parser``."""

    # Shared flow-processor options, then the Cassandra connection options.
    FlowProcessor.add_args(parser)
    add_cassandra_args(parser)

    # Which configuration type holds the schema definitions.
    parser.add_argument(
        '--config-type',
        help='Configuration type prefix for schemas (default: schema)',
        default='schema',
    )
|
||||
|
||||
|
||||
def run():
    """Entry point for rows-query-cassandra command.

    Launches ``Processor`` with the default identity and this module's
    docstring.
    """
    Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Structured Query Service - orchestrates natural language question processing.
|
||||
Takes a question, converts it to GraphQL via nlp-query, executes via objects-query,
|
||||
Takes a question, converts it to GraphQL via nlp-query, executes via rows-query,
|
||||
and returns the results.
|
||||
"""
|
||||
|
||||
|
|
@ -10,7 +10,7 @@ from typing import Dict, Any, Optional
|
|||
|
||||
from ...schema import StructuredQueryRequest, StructuredQueryResponse
|
||||
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
|
||||
from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
|
||||
from ...schema import RowsQueryRequest, RowsQueryResponse
|
||||
from ...schema import Error
|
||||
|
||||
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, RequestResponseSpec
|
||||
|
|
@ -57,13 +57,13 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
# Client spec for calling objects query service
|
||||
# Client spec for calling rows query service
|
||||
self.register_specification(
|
||||
RequestResponseSpec(
|
||||
request_name = "objects-query-request",
|
||||
response_name = "objects-query-response",
|
||||
request_schema = ObjectsQueryRequest,
|
||||
response_schema = ObjectsQueryResponse
|
||||
request_name = "rows-query-request",
|
||||
response_name = "rows-query-response",
|
||||
request_schema = RowsQueryRequest,
|
||||
response_schema = RowsQueryResponse
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -112,7 +112,7 @@ class Processor(FlowProcessor):
|
|||
variables_as_strings[key] = str(value)
|
||||
|
||||
# Use user/collection values from request
|
||||
objects_request = ObjectsQueryRequest(
|
||||
objects_request = RowsQueryRequest(
|
||||
user=request.user,
|
||||
collection=request.collection,
|
||||
query=nlp_response.graphql_query,
|
||||
|
|
@ -120,12 +120,12 @@ class Processor(FlowProcessor):
|
|||
operation_name=None
|
||||
)
|
||||
|
||||
objects_response = await flow("objects-query-request").request(objects_request)
|
||||
|
||||
objects_response = await flow("rows-query-request").request(objects_request)
|
||||
|
||||
if objects_response.error is not None:
|
||||
raise Exception(f"Objects query service error: {objects_response.error.message}")
|
||||
|
||||
# Handle GraphQL errors from the objects query service
|
||||
raise Exception(f"Rows query service error: {objects_response.error.message}")
|
||||
|
||||
# Handle GraphQL errors from the rows query service
|
||||
graphql_errors = []
|
||||
if objects_response.errors:
|
||||
for gql_error in objects_response.errors:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "de-write"
|
||||
default_ident = "doc-embeddings-write"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "de-write"
|
||||
default_ident = "doc-embeddings-write"
|
||||
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
|
||||
default_cloud = "aws"
|
||||
default_region = "us-east-1"
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "de-write"
|
||||
default_ident = "doc-embeddings-write"
|
||||
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def get_term_value(term):
|
|||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
||||
default_ident = "ge-write"
|
||||
default_ident = "graph-embeddings-write"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def get_term_value(term):
|
|||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
||||
default_ident = "ge-write"
|
||||
default_ident = "graph-embeddings-write"
|
||||
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
|
||||
default_cloud = "aws"
|
||||
default_region = "us-east-1"
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def get_term_value(term):
|
|||
return term.id or term.value
|
||||
|
||||
|
||||
default_ident = "ge-write"
|
||||
default_ident = "graph-embeddings-write"
|
||||
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
# Objects storage module
|
||||
|
|
@ -1 +0,0 @@
|
|||
from . write import *
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
from . write import run
|
||||
|
||||
run()
|
||||
|
|
@ -1,538 +0,0 @@
|
|||
"""
|
||||
Object writer for Cassandra. Input is ExtractedObject.
|
||||
Writes structured objects to Cassandra tables based on schema definitions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Set, Optional, Any
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
from cassandra.cqlengine import connection
|
||||
from cassandra import ConsistencyLevel
|
||||
|
||||
from .... schema import ExtractedObject
|
||||
from .... schema import RowSchema, Field
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base import CollectionConfigHandler
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "objects-write"
|
||||
|
||||
class Processor(CollectionConfigHandler, FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
    """Set up the objects writer.

    Resolves the Cassandra connection settings, registers the "input"
    consumer and the schema / collection config handlers, and initialises
    the keyspace, table, schema and connection state.
    """

    ident = params.get("id", default_ident)

    # Resolve configuration with environment variable fallback.  The
    # keyspace element of the result is not used by this processor.
    (
        self.cassandra_host,  # Stored as a list
        self.cassandra_username,
        self.cassandra_password,
        _keyspace,
    ) = resolve_cassandra_config(
        host=params.get("cassandra_host"),
        username=params.get("cassandra_username"),
        password=params.get("cassandra_password"),
    )

    # Config key under which schema definitions are published
    self.config_key = params.get("config_type", "schema")

    merged_params = {
        **params,
        "id": ident,
        "config_type": self.config_key,
    }
    super(Processor, self).__init__(**merged_params)

    # Incoming extracted objects
    input_spec = ConsumerSpec(
        name = "input",
        schema = ExtractedObject,
        handler = self.on_object
    )
    self.register_specification(input_spec)

    # Register config handlers
    self.register_config_handler(self.on_schema_config)
    self.register_config_handler(self.on_collection_config)

    # Cache of known keyspaces/tables
    self.known_keyspaces: Set[str] = set()
    self.known_tables: Dict[str, Set[str]] = {}  # keyspace -> set of tables

    # name -> RowSchema
    self.schemas: Dict[str, RowSchema] = {}

    # Cassandra connection, opened lazily by connect_cassandra()
    self.cluster = None
    self.session = None
|
||||
|
||||
def connect_cassandra(self):
    """Open a Cassandra session unless one is already open.

    Authentication is used only when both a username and a password are
    configured.  Connection failures are logged and re-raised.
    """
    if self.session:
        return

    try:
        cluster_kwargs = {"contact_points": self.cassandra_host}

        # Credentials are only attached when both halves are present.
        if self.cassandra_username and self.cassandra_password:
            cluster_kwargs["auth_provider"] = PlainTextAuthProvider(
                username=self.cassandra_username,
                password=self.cassandra_password
            )

        self.cluster = Cluster(**cluster_kwargs)
        self.session = self.cluster.connect()
        logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")

    except Exception as e:
        logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
        raise
|
||||
|
||||
async def on_schema_config(self, config, version):
    """Reload all schema definitions from a configuration update.

    Previously loaded schemas are discarded.  Each JSON definition found
    under ``self.config_key`` is parsed into a ``RowSchema``; a definition
    that fails to parse is logged and skipped.
    """
    logger.info(f"Loading schema configuration version {version}")

    # Start from a clean slate on every update
    self.schemas = {}

    # Nothing to load when our config type is absent
    if self.config_key not in config:
        logger.warning(f"No '{self.config_key}' type in configuration")
        return

    for schema_name, schema_json in config[self.config_key].items():
        try:
            definition = json.loads(schema_json)

            # One Field per entry of the definition's "fields" list
            fields = [
                Field(
                    name=spec["name"],
                    type=spec["type"],
                    size=spec.get("size", 0),
                    primary=spec.get("primary_key", False),
                    description=spec.get("description", ""),
                    required=spec.get("required", False),
                    enum_values=spec.get("enum", []),
                    indexed=spec.get("indexed", False)
                )
                for spec in definition.get("fields", [])
            ]

            self.schemas[schema_name] = RowSchema(
                name=definition.get("name", schema_name),
                description=definition.get("description", ""),
                fields=fields
            )
            logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

        except Exception as e:
            logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

    logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
|
||||
|
||||
def ensure_keyspace(self, keyspace: str):
    """Create the keyspace in Cassandra unless it is already known.

    On success the (unsanitized) keyspace name is recorded in
    ``known_keyspaces`` and its table cache is reset to empty.  Failures
    are logged and re-raised.
    """
    already_known = keyspace in self.known_keyspaces
    if already_known:
        return

    # Connect if needed
    self.connect_cassandra()

    # The name used in CQL is the sanitized form
    safe_keyspace = self.sanitize_name(keyspace)

    statement = f"""
        CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
        WITH REPLICATION = {{
            'class': 'SimpleStrategy',
            'replication_factor': 1
        }}
    """

    try:
        self.session.execute(statement)
        self.known_keyspaces.add(keyspace)
        self.known_tables[keyspace] = set()
        logger.info(f"Ensured keyspace exists: {safe_keyspace}")
    except Exception as e:
        logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
        raise
|
||||
|
||||
def get_cassandra_type(self, field_type: str, size: int = 0) -> str:
    """Map a schema field type onto the Cassandra column type used for it.

    ``integer`` and ``float`` use the wider type (``bigint`` / ``double``)
    when ``size`` is greater than 4.  A ``size`` of None counts as 0, and
    unrecognised field types fall back to ``text``.
    """
    wide = (0 if size is None else size) > 4

    # Types whose Cassandra name equals the schema name
    mapping = {
        same: same
        for same in ("boolean", "timestamp", "date", "time", "uuid")
    }
    mapping["string"] = "text"
    mapping["integer"] = "bigint" if wide else "int"
    mapping["float"] = "double" if wide else "float"

    return mapping.get(field_type, "text")
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
    """Return ``name`` rewritten as a lowercase Cassandra-safe identifier.

    Every character outside ``[a-zA-Z0-9_]`` becomes an underscore, and a
    non-empty result that does not start with a letter gets an ``o_``
    prefix.
    """
    import re

    cleaned = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    starts_with_letter = cleaned[:1].isalpha()
    if cleaned and not starts_with_letter:
        cleaned = 'o_' + cleaned
    return cleaned.lower()
|
||||
|
||||
def sanitize_table(self, name: str) -> str:
    """Return the lowercase Cassandra table name for ``name``.

    Every character outside ``[a-zA-Z0-9_]`` becomes an underscore and the
    result is always prefixed with ``o_``, so it starts with a letter.
    """
    import re

    return ('o_' + re.sub(r'[^a-zA-Z0-9_]', '_', name)).lower()
|
||||
|
||||
def ensure_table(self, keyspace: str, table_name: str, schema: RowSchema):
    """Create the table for ``schema`` unless it is already known.

    The table always has a ``collection`` column, which leads the
    partition key together with the schema's primary fields (or a
    ``synthetic_id`` uuid column when the schema has none).  After the
    table is created, a secondary index is attempted for every indexed
    non-primary field; an index failure is only logged.  A table creation
    failure is logged and re-raised.
    """
    table_key = f"{keyspace}.{table_name}"
    if table_key in self.known_tables.get(keyspace, set()):
        return

    # The keyspace has to exist before the table can be created
    self.ensure_keyspace(keyspace)

    safe_keyspace = self.sanitize_name(keyspace)
    safe_table = self.sanitize_table(table_name)

    # Column definitions; collection is always part of the table
    columns = ["collection text"]
    primary_key_fields = []

    for column in schema.fields:
        column_name = self.sanitize_name(column.name)
        column_type = self.get_cassandra_type(column.type, column.size)
        columns.append(f"{column_name} {column_type}")
        if column.primary:
            primary_key_fields.append(column_name)

    # Primary key - collection is always first in the partition key
    if not primary_key_fields:
        # No primary key defined: fall back to a synthetic id
        columns.append("synthetic_id uuid")
        primary_key = "PRIMARY KEY ((collection, synthetic_id))"
    else:
        primary_key = f"PRIMARY KEY ((collection, {', '.join(primary_key_fields)}))"

    create_table_cql = f"""
        CREATE TABLE IF NOT EXISTS {safe_keyspace}.{safe_table} (
            {', '.join(columns)},
            {primary_key}
        )
    """

    try:
        self.session.execute(create_table_cql)
        self.known_tables.setdefault(keyspace, set()).add(table_key)
        logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}")

        # Secondary indexes for indexed fields that are not primary
        for column in schema.fields:
            if not column.indexed or column.primary:
                continue

            column_name = self.sanitize_name(column.name)
            index_name = f"{safe_table}_{column_name}_idx"
            create_index_cql = f"""
                CREATE INDEX IF NOT EXISTS {index_name}
                ON {safe_keyspace}.{safe_table} ({column_name})
            """
            try:
                self.session.execute(create_index_cql)
                logger.info(f"Created index: {index_name}")
            except Exception as e:
                logger.warning(f"Failed to create index {index_name}: {e}")

    except Exception as e:
        logger.error(f"Failed to create table {safe_keyspace}.{safe_table}: {e}", exc_info=True)
        raise
|
||||
|
||||
def convert_value(self, value: Any, field_type: str) -> Any:
    """Coerce a raw field value into the Python value written to Cassandra.

    ``None`` is passed through unchanged. ``"integer"`` and ``"float"`` are
    converted with ``int()`` / ``float()``. ``"boolean"`` treats the strings
    ``true``, ``1`` and ``yes`` (case-insensitive) as true, any other string
    as false, and applies ``bool()`` to non-strings. ``"timestamp"`` values
    are returned exactly as received. Every other type is stringified.

    If a conversion raises, a warning is logged and ``str(value)`` is
    returned instead of propagating the error.
    """
    if value is None:
        return None

    # Timestamps are forwarded as-is; no conversion is attempted here
    if field_type == "timestamp":
        return value

    try:
        if field_type == "integer":
            return int(value)
        if field_type == "float":
            return float(value)
        if field_type == "boolean":
            accepted_true = ('true', '1', 'yes')
            return value.lower() in accepted_true if isinstance(value, str) else bool(value)
        return str(value)
    except Exception as e:
        logger.warning(f"Failed to convert value {value} to type {field_type}: {e}")
        return str(value)
|
||||
async def on_object(self, msg, consumer, flow):
    """Process incoming ExtractedObject and store in Cassandra.

    Every entry of ``obj.values`` becomes one INSERT into the table named
    after the object's schema, inside the keyspace named after the user.
    The ``collection`` column is always written first. When the schema
    declares no primary-key field, a random ``synthetic_id`` uuid is added
    to each row.

    Rows whose primary-key field is missing are logged and skipped; a
    missing required (non-key) field only produces a warning.

    Raises:
        ValueError: if the collection is unknown, or if the column, value
            and placeholder lists end up with different lengths.
        Exception: any error raised by the INSERT is logged and re-raised.
    """
    # Function-scope import, executed once per message instead of once per row
    import uuid

    obj = msg.value()
    logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")

    # Validate collection exists before accepting writes
    if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
        error_msg = (
            f"Collection {obj.metadata.collection} does not exist. "
            f"Create it first via collection management API."
        )
        logger.error(error_msg)
        raise ValueError(error_msg)

    # Get schema definition
    schema = self.schemas.get(obj.schema_name)
    if not schema:
        logger.warning(f"No schema found for {obj.schema_name} - skipping")
        return

    # Ensure table exists
    keyspace = obj.metadata.user
    table_name = obj.schema_name
    self.ensure_table(keyspace, table_name, schema)

    # Prepare data for insertion
    safe_keyspace = self.sanitize_name(keyspace)
    safe_table = self.sanitize_table(table_name)

    # Loop-invariant: depends only on the schema, not on the individual row
    has_primary_key = any(field.primary for field in schema.fields)

    # Process each object in the batch
    for obj_index, value_map in enumerate(obj.values):
        # Build column names and values for this object
        columns = ["collection"]
        values = [obj.metadata.collection]
        placeholders = ["%s"]

        # Tables without a declared primary key are keyed by a synthetic uuid
        if not has_primary_key:
            columns.append("synthetic_id")
            values.append(uuid.uuid4())
            placeholders.append("%s")

        # Process fields for this object
        skip_object = False
        for field in schema.fields:
            safe_field_name = self.sanitize_name(field.name)
            raw_value = value_map.get(field.name)

            # Handle required fields
            if field.required and raw_value is None:
                logger.warning(f"Required field {field.name} is missing in object {obj_index}")
                # Continue anyway - Cassandra doesn't enforce NOT NULL

            # Check if primary key field is NULL
            if field.primary and raw_value is None:
                logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}")
                skip_object = True
                break

            # Convert value to appropriate type
            converted_value = self.convert_value(raw_value, field.type)

            columns.append(safe_field_name)
            values.append(converted_value)
            placeholders.append("%s")

        # Skip this object if primary key validation failed
        if skip_object:
            continue

        # Build and execute insert query for this object
        insert_cql = f"""
            INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
            VALUES ({', '.join(placeholders)})
        """

        # Debug: Show data being inserted
        logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}")

        if len(columns) != len(values) or len(columns) != len(placeholders):
            raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")

        try:
            # Convert to tuple - Cassandra driver requires tuple for parameters
            self.session.execute(insert_cql, tuple(values))
        except Exception as e:
            logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
            raise
|
||||
|
||||
async def create_collection(self, user: str, collection: str, metadata: dict):
    """Make sure the keyspace backing a user's collections exists.

    Collections are not separate tables in this store - a collection is a
    column value on each row - so the only work is to connect, ensure the
    user's keyspace, and log that the collection is ready. ``metadata`` is
    not used.
    """
    self.connect_cassandra()

    # The keyspace name is derived from the user
    keyspace_name = self.sanitize_name(user)

    already_known = keyspace_name in self.known_keyspaces
    if not already_known:
        self.ensure_keyspace(keyspace_name)
        self.known_keyspaces.add(keyspace_name)

    logger.info(f"Collection {collection} ready for user {user} (using keyspace {keyspace_name})")
|
||||
|
||||
async def delete_collection(self, user: str, collection: str):
    """Delete all data for a specific collection using schema information.

    For every schema in ``self.schemas`` whose table is present in the
    ``known_tables`` cache for this user, the partition-key values of the
    collection's rows are selected (``ALLOW FILTERING``) and each row is
    deleted by its full partition key. The key columns are the schema's
    primary-key fields, or ``synthetic_id`` when the schema declares none.

    Returns early, deleting nothing, if the user's keyspace does not exist.

    Raises:
        Exception: any error from a table's select/delete statements is
            logged with its traceback and re-raised.
    """
    # Connect if not already connected
    self.connect_cassandra()

    # Sanitize names for safety
    safe_keyspace = self.sanitize_name(user)

    # Check if keyspace exists
    if safe_keyspace not in self.known_keyspaces:
        # Query to verify keyspace exists
        check_keyspace_cql = """
            SELECT keyspace_name FROM system_schema.keyspaces
            WHERE keyspace_name = %s
        """
        result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
        if not result.one():
            logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
            return
        self.known_keyspaces.add(safe_keyspace)

    # Iterate over schemas we manage to delete from relevant tables
    tables_deleted = 0

    for schema_name, schema in self.schemas.items():
        safe_table = self.sanitize_table(schema_name)

        # NOTE(review): known_tables looks like an in-process cache, so tables
        # created before a restart would be skipped here - confirm.
        table_key = f"{user}.{schema_name}"
        if table_key not in self.known_tables.get(user, set()):
            logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping")
            continue

        try:
            # Partition key is (collection, <key columns>): the schema's
            # primary-key fields, or the synthetic uuid when there are none.
            key_columns = [
                self.sanitize_name(field.name)
                for field in schema.fields if field.primary
            ] or ["synthetic_id"]

            select_cql = f"""
                SELECT {', '.join(key_columns)}
                FROM {safe_keyspace}.{safe_table}
                WHERE collection = %s
                ALLOW FILTERING
            """

            where_clauses = ["collection = %s"]
            where_clauses.extend(f"{column} = %s" for column in key_columns)
            delete_cql = f"""
                DELETE FROM {safe_keyspace}.{safe_table}
                WHERE {' AND '.join(where_clauses)}
            """

            rows = self.session.execute(select_cql, (collection,))

            # Delete each row using its full partition key
            for row in rows:
                key_values = [getattr(row, column) for column in key_columns]
                self.session.execute(delete_cql, (collection, *key_values))

            tables_deleted += 1
            logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}")

        except Exception as e:
            logger.error(
                f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}",
                exc_info=True
            )
            raise

    logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}")
|
||||
|
||||
def close(self):
    """Shut down the Cassandra cluster object, if one has been created."""
    if not self.cluster:
        # Never connected - nothing to release
        return

    self.cluster.shutdown()
    logger.info("Closed Cassandra connection")
|
||||
|
||||
@staticmethod
def add_args(parser):
    """Register this processor's command-line options on ``parser``.

    Adds the FlowProcessor options, the shared Cassandra options, and
    ``--config-type`` (default ``schema``).
    """
    # Inherited / shared option groups first, in the original order
    for register in (FlowProcessor.add_args, add_cassandra_args):
        register(parser)

    parser.add_argument(
        '--config-type',
        help='Configuration type prefix for schemas (default: schema)',
        default='schema',
    )
|
||||
|
||||
def run():
    """Console entry point for the objects-write-cassandra command.

    Launches ``Processor`` under the default identity, using the module
    docstring as its description.
    """
    ident, description = default_ident, __doc__
    Processor.launch(ident, description)
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Row embeddings storage modules.
|
||||
"""
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""
|
||||
Qdrant storage for row embeddings.
|
||||
"""
|
||||
|
||||
from .write import Processor, run, default_ident
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
|
||||
from .write import run
|
||||
|
||||
run()
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
"""
|
||||
Row embeddings writer for Qdrant (Stage 2).
|
||||
|
||||
Consumes RowEmbeddings messages (which already contain computed vectors)
|
||||
and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair.
|
||||
|
||||
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
|
||||
Stage 1 (row-embeddings): Compute embeddings
|
||||
Stage 2 (this processor): Store embeddings
|
||||
|
||||
Collection naming: rows_{user}_{collection}_{schema_name}_{dimension}
|
||||
|
||||
Payload structure:
|
||||
- index_name: The indexed field(s) this embedding represents
|
||||
- index_value: The original list of values (for Cassandra lookup)
|
||||
- text: The text that was embedded (for debugging/display)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Set, Tuple
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import PointStruct, Distance, VectorParams
|
||||
|
||||
from .... schema import RowEmbeddings
|
||||
from .... base import FlowProcessor, ConsumerSpec
|
||||
from .... base import CollectionConfigHandler
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "row-embeddings-write"
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
|
||||
class Processor(CollectionConfigHandler, FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
    """Set up the row-embeddings writer.

    Resolves ``id``, ``store_uri`` and ``api_key`` from ``params`` (falling
    back to the module defaults), passes them to the base initialiser,
    registers the ``input`` consumer and the collection config handler, and
    creates the Qdrant client.
    """
    resolved = {
        "id": params.get("id", default_ident),
        "store_uri": params.get("store_uri", default_store_uri),
        "api_key": params.get("api_key", None),
    }

    super(Processor, self).__init__(**(params | resolved))

    # Incoming RowEmbeddings messages are handled by on_embeddings
    input_spec = ConsumerSpec(
        name="input",
        schema=RowEmbeddings,
        handler=self.on_embeddings
    )
    self.register_specification(input_spec)

    # Collection lifecycle events arrive through the config handler
    self.register_config_handler(self.on_collection_config)

    # Names of Qdrant collections already ensured by this process
    self.created_collections: Set[str] = set()

    self.qdrant = QdrantClient(
        url=resolved["store_uri"], api_key=resolved["api_key"]
    )
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
    """Return ``name`` reduced to lowercase ``[a-z0-9_]`` for collection naming.

    Every character outside ``a-z``, ``A-Z``, ``0-9`` and ``_`` becomes an
    underscore. A non-empty result that does not start with a letter is
    prefixed with ``r_``. The result is lower-cased.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9_]', '_', name)

    needs_prefix = bool(cleaned) and not cleaned[0].isalpha()
    if needs_prefix:
        cleaned = 'r_' + cleaned

    return cleaned.lower()
|
||||
|
||||
def get_collection_name(
    self, user: str, collection: str, schema_name: str, dimension: int
) -> str:
    """Build the Qdrant collection name for a user/collection/schema/dimension.

    The three names are passed through ``sanitize_name`` and joined as
    ``rows_<user>_<collection>_<schema>_<dimension>``.
    """
    safe_user, safe_collection, safe_schema = map(
        self.sanitize_name, (user, collection, schema_name)
    )
    return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
|
||||
|
||||
def ensure_collection(self, collection_name: str, dimension: int):
    """Make sure a Qdrant collection exists, creating it when missing.

    Names already recorded in ``self.created_collections`` return at once.
    Otherwise Qdrant is asked whether the collection exists; if not, it is
    created with vectors of size ``dimension`` and cosine distance. The
    name is then recorded in the cache.
    """
    if collection_name in self.created_collections:
        return

    already_present = self.qdrant.collection_exists(collection_name)
    if not already_present:
        logger.info(
            f"Creating Qdrant collection {collection_name} "
            f"with dimension {dimension}"
        )
        vector_settings = VectorParams(
            size=dimension,
            distance=Distance.COSINE
        )
        self.qdrant.create_collection(
            collection_name=collection_name,
            vectors_config=vector_settings
        )

    self.created_collections.add(collection_name)
|
||||
|
||||
async def on_embeddings(self, msg, consumer, flow):
    """Process incoming RowEmbeddings and write to Qdrant.

    Messages for a user/collection that is not known to the collection
    config are dropped with a warning. Otherwise every vector of every row
    embedding is upserted as its own point (random uuid id) whose payload
    carries ``index_name``, ``index_value`` and ``text``.

    The target Qdrant collection is resolved per vector from the vector's
    dimension, because the collection name embeds the dimension and a
    collection is created for one vector size. Vectors of different sizes
    within one message therefore go to different collections rather than
    all being sent to the collection chosen for the first vector.
    """
    embeddings = msg.value()
    logger.info(
        f"Writing {len(embeddings.embeddings)} embeddings for schema "
        f"{embeddings.schema_name} from {embeddings.metadata.id}"
    )

    # Validate collection exists in config before processing
    if not self.collection_exists(
        embeddings.metadata.user, embeddings.metadata.collection
    ):
        logger.warning(
            f"Collection {embeddings.metadata.collection} for user "
            f"{embeddings.metadata.user} does not exist in config. "
            f"Dropping message."
        )
        return

    user = embeddings.metadata.user
    collection = embeddings.metadata.collection
    schema_name = embeddings.schema_name

    embeddings_written = 0

    for row_emb in embeddings.embeddings:
        if not row_emb.vectors:
            logger.warning(
                f"No vectors for index {row_emb.index_name} - skipping"
            )
            continue

        # Store every vector (there may be multiple from different models)
        for vector in row_emb.vectors:
            dimension = len(vector)

            # Resolve the collection for this vector's dimension;
            # ensure_collection caches names, so repeats are cheap.
            qdrant_collection = self.get_collection_name(
                user, collection, schema_name, dimension
            )
            self.ensure_collection(qdrant_collection, dimension)

            # Write to Qdrant
            self.qdrant.upsert(
                collection_name=qdrant_collection,
                points=[
                    PointStruct(
                        id=str(uuid.uuid4()),
                        vector=vector,
                        payload={
                            "index_name": row_emb.index_name,
                            "index_value": row_emb.index_value,
                            "text": row_emb.text
                        }
                    )
                ]
            )
            embeddings_written += 1

    logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
|
||||
|
||||
async def create_collection(self, user: str, collection: str, metadata: dict):
    """Acknowledge a collection-create request from the config push.

    Nothing is created here: this handler only logs the request, and the
    log message states that creation happens lazily on first write.
    ``metadata`` is not used.
    """
    notice = (
        f"Row embeddings collection create request for {user}/{collection} - "
        f"will be created lazily on first write"
    )
    logger.info(notice)
|
||||
|
||||
async def delete_collection(self, user: str, collection: str):
    """Delete every Qdrant collection whose name starts with this user/collection prefix.

    Lists all Qdrant collections, keeps those whose name begins with
    ``rows_<user>_<collection>_`` (sanitized), deletes each one and drops
    it from the local cache.

    Raises:
        Exception: any failure is logged with its traceback and re-raised.
    """
    # NOTE(review): matching is by name prefix only, so collection "b" also
    # matches names built for a collection such as "b_c" - confirm that
    # collection names cannot overlap like this.
    try:
        prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_"

        # Get all collections and filter for matches
        matching_collections = [
            coll.name
            for coll in self.qdrant.get_collections().collections
            if coll.name.startswith(prefix)
        ]

        if not matching_collections:
            logger.info(f"No Qdrant collections found matching prefix {prefix}")
            return

        for collection_name in matching_collections:
            self.qdrant.delete_collection(collection_name)
            self.created_collections.discard(collection_name)
            logger.info(f"Deleted Qdrant collection: {collection_name}")

        logger.info(
            f"Deleted {len(matching_collections)} collection(s) "
            f"for {user}/{collection}"
        )

    except Exception as e:
        logger.error(
            f"Failed to delete collection {user}/{collection}: {e}",
            exc_info=True
        )
        raise
|
||||
|
||||
async def delete_collection_schema(
    self, user: str, collection: str, schema_name: str
):
    """Delete the Qdrant collections for one user/collection/schema.

    Lists all Qdrant collections, keeps those whose name begins with
    ``rows_<user>_<collection>_<schema>_`` (sanitized), deletes each one
    and drops it from the local cache.

    Raises:
        Exception: any failure is logged with its traceback and re-raised.
    """
    # NOTE(review): matching is by name prefix only, so a schema whose name
    # extends this one with "_..." would also match - confirm.
    try:
        name_parts = [
            self.sanitize_name(user),
            self.sanitize_name(collection),
            self.sanitize_name(schema_name),
        ]
        prefix = (
            f"rows_{name_parts[0]}_"
            f"{name_parts[1]}_{name_parts[2]}_"
        )

        # Get all collections and filter for matches
        matching_collections = [
            coll.name
            for coll in self.qdrant.get_collections().collections
            if coll.name.startswith(prefix)
        ]

        if not matching_collections:
            logger.info(f"No Qdrant collections found matching prefix {prefix}")
            return

        for collection_name in matching_collections:
            self.qdrant.delete_collection(collection_name)
            self.created_collections.discard(collection_name)
            logger.info(f"Deleted Qdrant collection: {collection_name}")

    except Exception as e:
        logger.error(
            f"Failed to delete collection {user}/{collection}/{schema_name}: {e}",
            exc_info=True
        )
        raise
|
||||
|
||||
@staticmethod
def add_args(parser):
    """Register this processor's command-line options on ``parser``.

    Adds the FlowProcessor options plus ``-t/--store-uri`` and
    ``-k/--api-key`` for the Qdrant connection.
    """
    FlowProcessor.add_args(parser)

    # (flags, default, help) for each Qdrant option, in registration order
    qdrant_options = (
        (('-t', '--store-uri'), default_store_uri,
         f'Qdrant URI (default: {default_store_uri})'),
        (('-k', '--api-key'), None,
         'Qdrant API key (default: None)'),
    )
    for flags, default_value, help_text in qdrant_options:
        parser.add_argument(*flags, default=default_value, help=help_text)
|
||||
|
||||
|
||||
def run():
    """Console entry point for the row-embeddings-write-qdrant command.

    Launches ``Processor`` under the default identity, using the module
    docstring as its description.
    """
    ident, description = default_ident, __doc__
    Processor.launch(ident, description)
|
||||
|
||||
|
|
@ -1,46 +1,49 @@
|
|||
|
||||
"""
|
||||
Graph writer. Input is graph edge. Writes edges to Cassandra graph.
|
||||
Row writer for Cassandra. Input is ExtractedObject.
|
||||
Writes structured rows to a unified Cassandra table with multi-index support.
|
||||
|
||||
Uses a single 'rows' table with the schema:
|
||||
- collection: text
|
||||
- schema_name: text
|
||||
- index_name: text
|
||||
- index_value: frozen<list<text>>
|
||||
- data: map<text, text>
|
||||
- source: text
|
||||
|
||||
Each row is written multiple times - once per indexed field defined in the schema.
|
||||
"""
|
||||
|
||||
raise RuntimeError("This code is no longer in use")
|
||||
|
||||
import pulsar
|
||||
import base64
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, Set, Optional, Any, List, Tuple
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
||||
|
||||
from .... schema import Rows
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
from .... schema import ExtractedObject
|
||||
from .... schema import RowSchema, Field
|
||||
from .... base import FlowProcessor, ConsumerSpec
|
||||
from .... base import CollectionConfigHandler
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
module = "rows-write"
|
||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
||||
default_ident = "rows-write"
|
||||
|
||||
default_input_queue = "rows-store" # Default queue name
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(Consumer):
|
||||
class Processor(CollectionConfigHandler, FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password, keyspace = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
|
|
@ -48,99 +51,549 @@ class Processor(Consumer):
|
|||
password=cassandra_password
|
||||
)
|
||||
|
||||
# Store resolved configuration with proper names
|
||||
self.cassandra_host = hosts # Store as list
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
|
||||
# Config key for schemas
|
||||
self.config_key = params.get("config_type", "schema")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Rows,
|
||||
"cassandra_host": ','.join(hosts),
|
||||
"cassandra_username": username,
|
||||
"cassandra_password": password,
|
||||
"id": id,
|
||||
"config_type": self.config_key,
|
||||
}
|
||||
)
|
||||
|
||||
if username and password:
|
||||
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
||||
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
||||
else:
|
||||
self.cluster = Cluster(hosts)
|
||||
self.session = self.cluster.connect()
|
||||
|
||||
self.tables = set()
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name="input",
|
||||
schema=ExtractedObject,
|
||||
handler=self.on_object
|
||||
)
|
||||
)
|
||||
|
||||
self.session.execute("""
|
||||
create keyspace if not exists trustgraph
|
||||
with replication = {
|
||||
'class' : 'SimpleStrategy',
|
||||
'replication_factor' : 1
|
||||
};
|
||||
""");
|
||||
# Register config handlers
|
||||
self.register_config_handler(self.on_schema_config)
|
||||
self.register_config_handler(self.on_collection_config)
|
||||
|
||||
self.session.execute("use trustgraph");
|
||||
# Cache of known keyspaces and whether tables exist
|
||||
self.known_keyspaces: Set[str] = set()
|
||||
self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables
|
||||
|
||||
async def handle(self, msg):
|
||||
# Cache of registered (collection, schema_name) pairs
|
||||
self.registered_partitions: Set[Tuple[str, str]] = set()
|
||||
|
||||
# Schema storage: name -> RowSchema
|
||||
self.schemas: Dict[str, RowSchema] = {}
|
||||
|
||||
# Cassandra session
|
||||
self.cluster = None
|
||||
self.session = None
|
||||
|
||||
def connect_cassandra(self):
|
||||
"""Connect to Cassandra cluster"""
|
||||
if self.session:
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
name = v.row_schema.name
|
||||
|
||||
if name not in self.tables:
|
||||
|
||||
# FIXME: SQL injection?
|
||||
|
||||
pkey = []
|
||||
|
||||
stmt = "create table if not exists " + name + " ( "
|
||||
|
||||
for field in v.row_schema.fields:
|
||||
|
||||
stmt += field.name + " text, "
|
||||
|
||||
if field.primary:
|
||||
pkey.append(field.name)
|
||||
|
||||
stmt += "PRIMARY KEY (" + ", ".join(pkey) + "));"
|
||||
|
||||
self.session.execute(stmt)
|
||||
|
||||
self.tables.add(name);
|
||||
|
||||
for row in v.rows:
|
||||
|
||||
field_names = []
|
||||
values = []
|
||||
|
||||
for field in v.row_schema.fields:
|
||||
field_names.append(field.name)
|
||||
values.append(row[field.name])
|
||||
|
||||
# FIXME: SQL injection?
|
||||
stmt = (
|
||||
"insert into " + name + " (" + ", ".join(field_names) +
|
||||
") values (" + ",".join(["%s"] * len(values)) + ")"
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
self.cluster = Cluster(
|
||||
contact_points=self.cassandra_host,
|
||||
auth_provider=auth_provider
|
||||
)
|
||||
else:
|
||||
self.cluster = Cluster(contact_points=self.cassandra_host)
|
||||
|
||||
self.session.execute(stmt, values)
|
||||
self.session = self.cluster.connect()
|
||||
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
logger.error(f"Exception: {str(e)}", exc_info=True)
|
||||
async def on_schema_config(self, config, version):
    """Handle schema configuration updates.

    Replaces ``self.schemas`` with the definitions found under
    ``config[self.config_key]``. Each entry is a JSON document that is
    parsed into a ``RowSchema`` with its ``Field`` list; entries that fail
    to parse are logged and left out.

    Afterwards the partition-registration cache is cleared for every schema
    that was added, removed, or whose set of index names changed, so the
    next write registers its partitions again. Comparing names alone would
    miss an existing schema that gained or lost an indexed field.

    If the configuration has no ``self.config_key`` section, the schemas
    are left empty and the method returns after logging a warning.
    """
    logger.info(f"Loading schema configuration version {version}")

    # Keep the previous definitions so changes can be detected below
    old_schemas = self.schemas

    # Clear existing schemas (rebinding leaves old_schemas intact)
    self.schemas = {}

    # Check if our config type exists
    if self.config_key not in config:
        logger.warning(f"No '{self.config_key}' type in configuration")
        return

    # Get the schemas dictionary for our type
    schemas_config = config[self.config_key]

    # Process each schema in the schemas config
    for schema_name, schema_json in schemas_config.items():
        try:
            # Parse the JSON schema definition
            schema_def = json.loads(schema_json)

            # Create Field objects
            fields = []
            for field_def in schema_def.get("fields", []):
                field = Field(
                    name=field_def["name"],
                    type=field_def["type"],
                    size=field_def.get("size", 0),
                    primary=field_def.get("primary_key", False),
                    description=field_def.get("description", ""),
                    required=field_def.get("required", False),
                    enum_values=field_def.get("enum", []),
                    indexed=field_def.get("indexed", False)
                )
                fields.append(field)

            # Create RowSchema
            row_schema = RowSchema(
                name=schema_def.get("name", schema_name),
                description=schema_def.get("description", ""),
                fields=fields
            )

            self.schemas[schema_name] = row_schema
            logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

        except Exception as e:
            # One bad definition must not prevent the others from loading
            logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

    logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")

    # Clear partition cache for schemas that changed
    # This ensures next write will re-register partitions
    old_schema_names = set(old_schemas)
    new_schema_names = set(self.schemas)
    changed_schemas = old_schema_names.symmetric_difference(new_schema_names)

    # Schemas present before and after count as changed when their indexes differ
    for name in old_schema_names & new_schema_names:
        old_indexes = self.get_index_names(old_schemas[name])
        new_indexes = self.get_index_names(self.schemas[name])
        if old_indexes != new_indexes:
            changed_schemas.add(name)

    if changed_schemas:
        self.registered_partitions = {
            (col, sch) for col, sch in self.registered_partitions
            if sch not in changed_schemas
        }
        logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}")
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
    """Return ``name`` reduced to lowercase ``[a-z0-9_]`` for use in Cassandra.

    Every character outside ``a-z``, ``A-Z``, ``0-9`` and ``_`` becomes an
    underscore. A non-empty result that does not start with a letter is
    prefixed with ``r_``. The result is lower-cased.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9_]', '_', name)

    # Prefix anything that does not begin with a letter
    starts_with_letter = bool(cleaned) and cleaned[0].isalpha()
    if cleaned and not starts_with_letter:
        cleaned = 'r_' + cleaned

    return cleaned.lower()
|
||||
|
||||
def ensure_keyspace(self, keyspace: str):
    """Create the keyspace for ``keyspace`` if this process has not done so yet.

    Keyspaces already in ``self.known_keyspaces`` return at once. Otherwise
    the connection is opened if needed and ``CREATE KEYSPACE IF NOT EXISTS``
    is run for the sanitized name with SimpleStrategy and replication
    factor 1. The unsanitized name is then added to the cache.

    Raises:
        Exception: any failure is logged with its traceback and re-raised.
    """
    if keyspace in self.known_keyspaces:
        return

    self.connect_cassandra()

    safe_keyspace = self.sanitize_name(keyspace)

    create_keyspace_cql = f"""
        CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
        WITH REPLICATION = {{
            'class': 'SimpleStrategy',
            'replication_factor': 1
        }}
    """

    try:
        self.session.execute(create_keyspace_cql)
        # Cached under the caller-supplied name, matching the check above
        self.known_keyspaces.add(keyspace)
        logger.info(f"Ensured keyspace exists: {safe_keyspace}")
    except Exception as e:
        logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
        raise
|
||||
|
||||
def ensure_tables(self, keyspace: str):
    """Create the ``rows`` and ``row_partitions`` tables for a keyspace if needed.

    Keyspaces already in ``self.tables_initialized`` return at once.
    Otherwise the keyspace itself is ensured, both tables are created with
    ``CREATE TABLE IF NOT EXISTS`` (rows first), and the keyspace is added
    to the cache.

    Raises:
        Exception: any failure is logged with its traceback and re-raised.
    """
    if keyspace in self.tables_initialized:
        return

    # The keyspace has to exist before tables can be created in it
    self.ensure_keyspace(keyspace)

    safe_keyspace = self.sanitize_name(keyspace)

    # One row per (collection, schema, index) partition, clustered by index value
    create_rows_cql = f"""
        CREATE TABLE IF NOT EXISTS {safe_keyspace}.rows (
            collection text,
            schema_name text,
            index_name text,
            index_value frozen<list<text>>,
            data map<text, text>,
            source text,
            PRIMARY KEY ((collection, schema_name, index_name), index_value)
        )
    """

    # Records which (schema, index) partitions exist for each collection
    create_partitions_cql = f"""
        CREATE TABLE IF NOT EXISTS {safe_keyspace}.row_partitions (
            collection text,
            schema_name text,
            index_name text,
            PRIMARY KEY ((collection), schema_name, index_name)
        )
    """

    statements = (
        (create_rows_cql,
         f"Ensured rows table exists: {safe_keyspace}.rows"),
        (create_partitions_cql,
         f"Ensured row_partitions table exists: {safe_keyspace}.row_partitions"),
    )

    try:
        for cql, confirmation in statements:
            self.session.execute(cql)
            logger.info(confirmation)

        self.tables_initialized.add(keyspace)

    except Exception as e:
        logger.error(f"Failed to create tables in {safe_keyspace}: {e}", exc_info=True)
        raise
|
||||
|
||||
def get_index_names(self, schema: RowSchema) -> List[str]:
    """List the index names of a schema, in field order.

    A field contributes its name when it is flagged ``primary`` or
    ``indexed``. Only single-field indexes are produced; composite indexes
    are not supported yet (TODO).
    """
    return [
        field.name
        for field in schema.fields
        if field.primary or field.indexed
    ]
|
||||
|
||||
def register_partitions(self, keyspace: str, collection: str, schema_name: str):
    """Register partition entries for a (collection, schema_name) pair.

    Inserts one ``row_partitions`` row per index name of the schema. The
    pair is remembered in ``self.registered_partitions`` so later calls
    return immediately.

    The pair is only cached when every insert succeeded. A failed insert is
    logged as a warning and the pair stays uncached, so the next call
    retries the registration.

    Returns without doing anything when the schema is not loaded.
    """
    cache_key = (collection, schema_name)
    if cache_key in self.registered_partitions:
        return

    schema = self.schemas.get(schema_name)
    if not schema:
        logger.warning(f"Cannot register partitions - schema {schema_name} not found")
        return

    safe_keyspace = self.sanitize_name(keyspace)
    index_names = self.get_index_names(schema)

    # Insert partition entries for each index
    insert_cql = f"""
        INSERT INTO {safe_keyspace}.row_partitions (collection, schema_name, index_name)
        VALUES (%s, %s, %s)
    """

    all_registered = True
    for index_name in index_names:
        try:
            self.session.execute(insert_cql, (collection, schema_name, index_name))
        except Exception as e:
            logger.warning(f"Failed to register partition {collection}/{schema_name}/{index_name}: {e}")
            all_registered = False

    if not all_registered:
        # Leave the pair uncached so the next write retries the registration
        return

    self.registered_partitions.add(cache_key)
    logger.info(f"Registered partitions for {collection}/{schema_name}: {index_names}")
|
||||
|
||||
def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
    """
    Produce the index_value list for one index of a row.

    ``index_name`` is split on commas and each part is stripped of
    surrounding whitespace, so a single-field index gives a one-element
    list and a composite index gives one element per field.  Each field is
    looked up in ``value_map`` and stored as ``str(value)``; a missing or
    ``None`` value becomes the empty string.
    """
    looked_up = (
        value_map.get(part.strip())
        for part in index_name.split(',')
    )
    # Everything is stored as text; absent values collapse to ""
    return ["" if item is None else str(item) for item in looked_up]
|
||||
|
||||
async def on_object(self, msg, consumer, flow):
    """
    Process incoming ExtractedObject and store in Cassandra.

    Every row in ``obj.values`` is written to ``<keyspace>.rows`` once per
    index name of its schema.  Each copy carries the same string-valued
    ``data`` map, keyed by that index's value.  An index whose value is
    empty for a given row is skipped for that row.

    Returns early, after a warning, when the schema is unknown or has no
    index names.

    Raises:
        ValueError: if ``self.collection_exists`` reports the target
            collection as missing.
        Exception: any insert failure is logged and re-raised.
    """
    obj = msg.value()
    logger.info(
        f"Storing {len(obj.values)} rows for schema {obj.schema_name} "
        f"from {obj.metadata.id}"
    )

    # Writes are only accepted for collections that already exist
    if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
        error_msg = (
            f"Collection {obj.metadata.collection} does not exist. "
            f"Create it first via collection management API."
        )
        logger.error(error_msg)
        raise ValueError(error_msg)

    schema = self.schemas.get(obj.schema_name)
    if not schema:
        logger.warning(f"No schema found for {obj.schema_name} - skipping")
        return

    keyspace, collection = obj.metadata.user, obj.metadata.collection
    schema_name = obj.schema_name
    # 'source' may be absent or None on the metadata; store "" then
    source = getattr(obj.metadata, 'source', '') or ''

    # Tables first, then partition bookkeeping for this (collection, schema)
    self.ensure_tables(keyspace)
    self.register_partitions(keyspace, collection, schema_name)

    safe_keyspace = self.sanitize_name(keyspace)
    index_names = self.get_index_names(schema)
    if not index_names:
        logger.warning(f"Schema {schema_name} has no indexed fields - rows won't be queryable")
        return

    insert_cql = f"""
        INSERT INTO {safe_keyspace}.rows
        (collection, schema_name, index_name, index_value, data, source)
        VALUES (%s, %s, %s, %s, %s, %s)
    """

    rows_written = 0
    for row_index, value_map in enumerate(obj.values):

        # The data map holds schema fields only, as strings, without Nones
        present = (
            (column.name, value_map.get(column.name))
            for column in schema.fields
        )
        data_map = {name: str(raw) for name, raw in present if raw is not None}

        # One copy of the row per index
        for index_name in index_names:
            index_value = self.build_index_value(value_map, index_name)

            # Nothing to key on when every component is empty
            if not (index_value and any(v != "" for v in index_value)):
                logger.debug(
                    f"Skipping index {index_name} for row {row_index} - "
                    f"empty index value"
                )
                continue

            try:
                self.session.execute(
                    insert_cql,
                    (collection, schema_name, index_name, index_value, data_map, source)
                )
            except Exception as e:
                logger.error(
                    f"Failed to insert row {row_index} for index {index_name}: {e}",
                    exc_info=True
                )
                raise
            else:
                rows_written += 1

    logger.info(
        f"Wrote {rows_written} index entries for {len(obj.values)} rows "
        f"({len(index_names)} indexes per row)"
    )
|
||||
|
||||
async def create_collection(self, user: str, collection: str, metadata: dict):
    """
    Create/verify collection exists in Cassandra row store.

    Calls ``self.connect_cassandra()`` and then ``self.ensure_tables(user)``,
    i.e. the work is done per user, not per collection.  ``collection`` is
    only used in the log message and ``metadata`` is not used at all.
    """
    self.connect_cassandra()
    self.ensure_tables(user)

    ready_message = f"Collection {collection} ready for user {user}"
    logger.info(ready_message)
|
||||
|
||||
async def delete_collection(self, user: str, collection: str):
    """
    Delete all data for a specific collection using partition tracking.

    Steps, in order:
      1. If the user's keyspace is not in ``self.known_keyspaces``, look it
         up in ``system_schema.keyspaces``; when absent there is nothing to
         delete and the method returns.
      2. Read the (schema_name, index_name) pairs recorded for the
         collection in ``row_partitions``.
      3. Delete the matching partition from ``rows`` for each pair.
      4. Delete the collection's ``row_partitions`` entries.
      5. Drop the collection's entries from ``self.registered_partitions``.

    Raises:
        Exception: any failure in steps 2-4 is logged and re-raised.
    """
    self.connect_cassandra()
    safe_keyspace = self.sanitize_name(user)

    keyspace_known = user in self.known_keyspaces
    if not keyspace_known:
        check_keyspace_cql = """
            SELECT keyspace_name FROM system_schema.keyspaces
            WHERE keyspace_name = %s
        """
        lookup = self.session.execute(check_keyspace_cql, (safe_keyspace,))
        if not lookup.one():
            logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
            return
        self.known_keyspaces.add(user)

    # Statements used below
    select_partitions_cql = f"""
        SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
        WHERE collection = %s
    """
    delete_rows_cql = f"""
        DELETE FROM {safe_keyspace}.rows
        WHERE collection = %s AND schema_name = %s AND index_name = %s
    """
    delete_partitions_cql = f"""
        DELETE FROM {safe_keyspace}.row_partitions
        WHERE collection = %s
    """

    # Discover all partitions for this collection
    try:
        partition_list = list(
            self.session.execute(select_partitions_cql, (collection,))
        )
    except Exception as e:
        logger.error(f"Failed to query partitions for collection {collection}: {e}")
        raise

    # Remove each tracked partition from the rows table
    for partition in partition_list:
        try:
            self.session.execute(
                delete_rows_cql,
                (collection, partition.schema_name, partition.index_name)
            )
        except Exception as e:
            logger.error(
                f"Failed to delete partition {collection}/{partition.schema_name}/"
                f"{partition.index_name}: {e}"
            )
            raise
    # Reaching this point means every listed partition was deleted
    partitions_deleted = len(partition_list)

    # Then remove the tracking entries themselves
    try:
        self.session.execute(delete_partitions_cql, (collection,))
    except Exception as e:
        logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
        raise

    # Forget the collection locally so a later write re-registers it
    self.registered_partitions = {
        (owner, schema) for owner, schema in self.registered_partitions
        if owner != collection
    }

    logger.info(
        f"Deleted collection {collection}: {partitions_deleted} partitions "
        f"from keyspace {safe_keyspace}"
    )
|
||||
|
||||
async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
    """
    Delete all data for a specific collection + schema combination.

    The index names recorded in ``row_partitions`` for the pair are read,
    the matching partition is deleted from ``rows`` for each of them, the
    ``row_partitions`` entries are removed, and the pair is dropped from
    ``self.registered_partitions``.

    Like ``delete_collection``, the keyspace is first checked against
    ``self.known_keyspaces`` / ``system_schema.keyspaces``.  When the
    keyspace does not exist there is nothing to delete and the method
    returns quietly instead of failing on the partition query.

    Args:
        user: Owner of the data; sanitized to obtain the keyspace name.
        collection: Collection to delete from.
        schema_name: Schema whose rows are removed.

    Raises:
        Exception: any failure while querying or deleting is logged and
            re-raised.
    """
    # Connect if not already connected
    self.connect_cassandra()

    safe_keyspace = self.sanitize_name(user)

    # Check if keyspace exists (same guard as delete_collection)
    if user not in self.known_keyspaces:
        check_keyspace_cql = """
            SELECT keyspace_name FROM system_schema.keyspaces
            WHERE keyspace_name = %s
        """
        result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
        if not result.one():
            logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
            return
        self.known_keyspaces.add(user)

    # Discover partitions for this collection + schema
    select_partitions_cql = f"""
        SELECT index_name FROM {safe_keyspace}.row_partitions
        WHERE collection = %s AND schema_name = %s
    """

    try:
        partitions = self.session.execute(select_partitions_cql, (collection, schema_name))
        partition_list = list(partitions)
    except Exception as e:
        logger.error(
            f"Failed to query partitions for {collection}/{schema_name}: {e}"
        )
        raise

    # Delete each partition from rows table
    delete_rows_cql = f"""
        DELETE FROM {safe_keyspace}.rows
        WHERE collection = %s AND schema_name = %s AND index_name = %s
    """

    partitions_deleted = 0
    for partition in partition_list:
        try:
            self.session.execute(
                delete_rows_cql,
                (collection, schema_name, partition.index_name)
            )
            partitions_deleted += 1
        except Exception as e:
            logger.error(
                f"Failed to delete partition {collection}/{schema_name}/"
                f"{partition.index_name}: {e}"
            )
            raise

    # Clean up row_partitions entries for this schema
    delete_partitions_cql = f"""
        DELETE FROM {safe_keyspace}.row_partitions
        WHERE collection = %s AND schema_name = %s
    """

    try:
        self.session.execute(delete_partitions_cql, (collection, schema_name))
    except Exception as e:
        logger.error(
            f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"
        )
        raise

    # Clear from local cache so a later write re-registers the partitions
    self.registered_partitions.discard((collection, schema_name))

    logger.info(
        f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "
        f"from keyspace {safe_keyspace}"
    )
|
||||
|
||||
def close(self):
    """
    Clean up Cassandra connections.

    Shuts down ``self.cluster`` when one is set; otherwise does nothing.
    """
    if not self.cluster:
        return

    self.cluster.shutdown()
    logger.info("Closed Cassandra connection")
|
||||
|
||||
@staticmethod
def add_args(parser):
    """
    Add command-line arguments.

    Registers, in order: the Consumer arguments (using the module's
    default input queue and subscriber), the FlowProcessor arguments, the
    shared Cassandra arguments, and this processor's own ``--config-type``
    option.
    """

    # NOTE(review): both Consumer.add_args and FlowProcessor.add_args are
    # invoked on the same parser.  Confirm that this is intended and that
    # the two do not register overlapping options - TODO confirm.
    Consumer.add_args(
        parser, default_input_queue, default_subscriber,
    )
    FlowProcessor.add_args(parser)
    add_cassandra_args(parser)

    # Configuration type prefix for schemas; defaults to 'schema'
    parser.add_argument(
        '--config-type',
        default='schema',
        help='Configuration type prefix for schemas (default: schema)'
    )
|
||||
|
||||
|
||||
def run():
    """Entry point for rows-write-cassandra command"""
    # Launch the processor under its default identity, passing the module
    # docstring through to the launcher.
    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue