Better proc group logging and concurrency (#810)

- Silence pika, cassandra etc. logging at INFO (too much chatter) 
- Add per processor log tags so that logs can be understood in
  processor group.
- Deal with RabbitMQ lag weirdness
- Added more processor group examples
This commit is contained in:
cybermaggedon 2026-04-15 14:52:01 +01:00 committed by GitHub
parent ce3c8b421b
commit 2bf4af294e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 1021 additions and 647 deletions

View file

@ -0,0 +1,47 @@
# Control plane. Stateful "always on" services that every flow depends on.
# Cassandra-heavy, low traffic.

_defaults: &defaults
  pubsub_backend: rabbitmq
  rabbitmq_host: localhost
  log_level: INFO

processors:

  - class: trustgraph.config.service.Processor
    params:
      <<: *defaults
      id: config-svc
      cassandra_host: localhost

  - class: trustgraph.librarian.Processor
    params:
      <<: *defaults
      id: librarian
      cassandra_host: localhost
      object_store_endpoint: localhost:3900
      object_store_access_key: GK000000000000000000000001
      object_store_secret_key: b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427
      object_store_region: garage

  - class: trustgraph.cores.service.Processor
    params:
      <<: *defaults
      id: knowledge
      cassandra_host: localhost

  - class: trustgraph.storage.knowledge.store.Processor
    params:
      <<: *defaults
      id: kg-store
      cassandra_host: localhost

  - class: trustgraph.metering.Processor
    params:
      <<: *defaults
      id: metering

  - class: trustgraph.metering.Processor
    params:
      <<: *defaults
      id: metering-rag

View file

@ -0,0 +1,45 @@
# Embeddings store. All Qdrant-backed vector query/write processors.
# One process owns the Qdrant driver pool for the whole group.

_defaults: &defaults
  pubsub_backend: rabbitmq
  rabbitmq_host: localhost
  log_level: INFO

processors:

  - class: trustgraph.query.doc_embeddings.qdrant.Processor
    params:
      <<: *defaults
      id: doc-embeddings-query
      store_uri: http://localhost:6333

  - class: trustgraph.storage.doc_embeddings.qdrant.Processor
    params:
      <<: *defaults
      id: doc-embeddings-write
      store_uri: http://localhost:6333

  - class: trustgraph.query.graph_embeddings.qdrant.Processor
    params:
      <<: *defaults
      id: graph-embeddings-query
      store_uri: http://localhost:6333

  - class: trustgraph.storage.graph_embeddings.qdrant.Processor
    params:
      <<: *defaults
      id: graph-embeddings-write
      store_uri: http://localhost:6333

  - class: trustgraph.query.row_embeddings.qdrant.Processor
    params:
      <<: *defaults
      id: row-embeddings-query
      store_uri: http://localhost:6333

  - class: trustgraph.storage.row_embeddings.qdrant.Processor
    params:
      <<: *defaults
      id: row-embeddings-write
      store_uri: http://localhost:6333

View file

@ -0,0 +1,31 @@
# Embeddings. Memory-hungry — fastembed loads an ML model at startup.
# Keep isolated from other groups so its memory footprint and restart
# latency don't affect siblings.

_defaults: &defaults
  pubsub_backend: rabbitmq
  rabbitmq_host: localhost
  log_level: INFO

processors:

  - class: trustgraph.embeddings.fastembed.Processor
    params:
      <<: *defaults
      id: embeddings
      concurrency: 1

  - class: trustgraph.embeddings.document_embeddings.Processor
    params:
      <<: *defaults
      id: document-embeddings

  - class: trustgraph.embeddings.graph_embeddings.Processor
    params:
      <<: *defaults
      id: graph-embeddings

  - class: trustgraph.embeddings.row_embeddings.Processor
    params:
      <<: *defaults
      id: row-embeddings

View file

@ -0,0 +1,52 @@
# Ingest pipeline. Document-processing hot path. Bursty, correlated
# failures — if chunker dies the extractors have nothing to do anyway.

_defaults: &defaults
  pubsub_backend: rabbitmq
  rabbitmq_host: localhost
  log_level: INFO

processors:

  - class: trustgraph.chunking.recursive.Processor
    params:
      <<: *defaults
      id: chunker
      chunk_size: 2000
      chunk_overlap: 50

  - class: trustgraph.extract.kg.agent.Processor
    params:
      <<: *defaults
      id: kg-extract-agent
      concurrency: 1

  - class: trustgraph.extract.kg.definitions.Processor
    params:
      <<: *defaults
      id: kg-extract-definitions
      concurrency: 1

  - class: trustgraph.extract.kg.ontology.Processor
    params:
      <<: *defaults
      id: kg-extract-ontology
      concurrency: 1

  - class: trustgraph.extract.kg.relationships.Processor
    params:
      <<: *defaults
      id: kg-extract-relationships
      concurrency: 1

  - class: trustgraph.extract.kg.rows.Processor
    params:
      <<: *defaults
      id: kg-extract-rows
      concurrency: 1

  - class: trustgraph.prompt.template.Processor
    params:
      <<: *defaults
      id: prompt
      concurrency: 1

View file

@ -0,0 +1,24 @@
# LLM. Outbound text-completion calls. Isolated because the upstream
# LLM API is often the bottleneck and the most likely thing to need
# restart (provider changes, model changes, API flakiness).

_defaults: &defaults
  pubsub_backend: rabbitmq
  rabbitmq_host: localhost
  log_level: INFO

processors:

  - class: trustgraph.model.text_completion.openai.Processor
    params:
      <<: *defaults
      id: text-completion
      max_output: 8192
      temperature: 0.0

  - class: trustgraph.model.text_completion.openai.Processor
    params:
      <<: *defaults
      id: text-completion-rag
      max_output: 8192
      temperature: 0.0

View file

@ -0,0 +1,64 @@
# RAG / retrieval / agent. Query-time serving path. Drives outbound
# LLM calls via prompt-rag. sparql-query lives here because it's a
# read-side serving endpoint, not a backend writer.

_defaults: &defaults
  pubsub_backend: rabbitmq
  rabbitmq_host: localhost
  log_level: INFO

processors:

  - class: trustgraph.agent.orchestrator.Processor
    params:
      <<: *defaults
      id: agent-manager

  - class: trustgraph.retrieval.graph_rag.Processor
    params:
      <<: *defaults
      id: graph-rag
      concurrency: 1
      entity_limit: 50
      triple_limit: 30
      edge_limit: 30
      edge_score_limit: 10
      max_subgraph_size: 100
      max_path_length: 2

  - class: trustgraph.retrieval.document_rag.Processor
    params:
      <<: *defaults
      id: document-rag
      doc_limit: 20

  - class: trustgraph.retrieval.nlp_query.Processor
    params:
      <<: *defaults
      id: nlp-query

  - class: trustgraph.retrieval.structured_query.Processor
    params:
      <<: *defaults
      id: structured-query

  - class: trustgraph.retrieval.structured_diag.Processor
    params:
      <<: *defaults
      id: structured-diag

  - class: trustgraph.query.sparql.Processor
    params:
      <<: *defaults
      id: sparql-query

  - class: trustgraph.prompt.template.Processor
    params:
      <<: *defaults
      id: prompt-rag
      concurrency: 1

  - class: trustgraph.agent.mcp_tool.Service
    params:
      <<: *defaults
      id: mcp-tool

View file

@ -0,0 +1,20 @@
# Rows store. Cassandra-backed structured row query/write.

_defaults: &defaults
  pubsub_backend: rabbitmq
  rabbitmq_host: localhost
  log_level: INFO

processors:

  - class: trustgraph.query.rows.cassandra.Processor
    params:
      <<: *defaults
      id: rows-query
      cassandra_host: localhost

  - class: trustgraph.storage.rows.cassandra.Processor
    params:
      <<: *defaults
      id: rows-write
      cassandra_host: localhost

View file

@ -0,0 +1,20 @@
# Triples store. Cassandra-backed RDF triple query/write.

_defaults: &defaults
  pubsub_backend: rabbitmq
  rabbitmq_host: localhost
  log_level: INFO

processors:

  - class: trustgraph.query.triples.cassandra.Processor
    params:
      <<: *defaults
      id: triples-query
      cassandra_host: localhost

  - class: trustgraph.storage.triples.cassandra.Processor
    params:
      <<: *defaults
      id: triples-write
      cassandra_host: localhost

View file

@ -8,12 +8,51 @@ ensuring consistent log formats, levels, and command-line arguments.
Supports dual output to console and Loki for centralized log aggregation. Supports dual output to console and Loki for centralized log aggregation.
""" """
import contextvars
import logging import logging
import logging.handlers import logging.handlers
from queue import Queue from queue import Queue
import os import os
# The current processor id for this task context. Read by
# _ProcessorIdFilter to stamp every LogRecord with its owning
# processor, and read by logging_loki's emitter via record.tags
# to label log lines in Loki. ContextVar so asyncio subtasks
# inherit their parent supervisor's processor id automatically.
current_processor_id = contextvars.ContextVar(
"current_processor_id", default="unknown"
)
def set_processor_id(pid):
"""Set the processor id for the current task context.
All subsequent log records emitted from this task and any
asyncio tasks spawned from it will be tagged with this id
in the console format and in Loki labels.
"""
current_processor_id.set(pid)
class _ProcessorIdFilter(logging.Filter):
"""Stamps every LogRecord with processor_id from the contextvar.
Attaches two fields to each record:
record.processor_id used by the console format string
record.tags merged into Loki labels by logging_loki's
emitter (it reads record.tags and combines
with the handler's static tags)
"""
def filter(self, record):
pid = current_processor_id.get()
record.processor_id = pid
existing = getattr(record, "tags", None) or {}
record.tags = {**existing, "processor": pid}
return True
def add_logging_args(parser): def add_logging_args(parser):
""" """
Add standard logging arguments to an argument parser. Add standard logging arguments to an argument parser.
@ -87,12 +126,15 @@ def setup_logging(args):
loki_url = args.get('loki_url', 'http://loki:3100/loki/api/v1/push') loki_url = args.get('loki_url', 'http://loki:3100/loki/api/v1/push')
loki_username = args.get('loki_username') loki_username = args.get('loki_username')
loki_password = args.get('loki_password') loki_password = args.get('loki_password')
processor_id = args.get('id') # Processor identity (e.g., "config-svc", "text-completion")
try: try:
from logging_loki import LokiHandler from logging_loki import LokiHandler
# Create Loki handler with optional authentication and processor label # Create Loki handler with optional authentication. The
# processor label is NOT baked in here — it's stamped onto
# each record by _ProcessorIdFilter reading the task-local
# contextvar, and logging_loki's emitter reads record.tags
# to build per-record Loki labels.
loki_handler_kwargs = { loki_handler_kwargs = {
'url': loki_url, 'url': loki_url,
'version': "1", 'version': "1",
@ -101,10 +143,6 @@ def setup_logging(args):
if loki_username and loki_password: if loki_username and loki_password:
loki_handler_kwargs['auth'] = (loki_username, loki_password) loki_handler_kwargs['auth'] = (loki_username, loki_password)
# Add processor label if available (for consistency with Prometheus metrics)
if processor_id:
loki_handler_kwargs['tags'] = {'processor': processor_id}
loki_handler = LokiHandler(**loki_handler_kwargs) loki_handler = LokiHandler(**loki_handler_kwargs)
# Wrap in QueueHandler for non-blocking operation # Wrap in QueueHandler for non-blocking operation
@ -133,23 +171,44 @@ def setup_logging(args):
print(f"WARNING: Failed to setup Loki logging: {e}") print(f"WARNING: Failed to setup Loki logging: {e}")
print("Continuing with console-only logging") print("Continuing with console-only logging")
# Get processor ID for log formatting (use 'unknown' if not available) # Configure logging with all handlers. The processor id comes
processor_id = args.get('id', 'unknown') # from _ProcessorIdFilter (via contextvar) and is injected into
# each record as record.processor_id. The format string reads
# Configure logging with all handlers # that attribute on every emit.
# Use processor ID as the primary identifier in logs
logging.basicConfig( logging.basicConfig(
level=getattr(logging, log_level.upper()), level=getattr(logging, log_level.upper()),
format=f'%(asctime)s - {processor_id} - %(levelname)s - %(message)s', format='%(asctime)s - %(processor_id)s - %(levelname)s - %(message)s',
handlers=handlers, handlers=handlers,
force=True # Force reconfiguration if already configured force=True # Force reconfiguration if already configured
) )
# Prevent recursive logging from Loki's HTTP client # Attach the processor-id filter to every handler so all records
if loki_enabled and queue_listener: # passing through any sink get stamped (console, queue→loki,
# Disable urllib3 logging to prevent infinite loop # future handlers). Filters on handlers run regardless of which
logging.getLogger('urllib3').setLevel(logging.WARNING) # logger originated the record, so logs from pika, cassandra,
logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING) # processor code, etc. all pass through it.
processor_filter = _ProcessorIdFilter()
for h in handlers:
h.addFilter(processor_filter)
# Seed the contextvar from --id if one was supplied. In group
# mode --id isn't present; the processor_group supervisor sets
# it per task. In standalone mode AsyncProcessor.launch provides
# it via argparse default.
if args.get('id'):
set_processor_id(args['id'])
# Silence noisy third-party library loggers. These emit INFO-level
# chatter (connection churn, channel open/close, driver warnings) that
# drowns the useful signal and can't be attributed to a specific
# processor anyway. WARNING and above still propagate.
for noisy in (
'pika',
'cassandra',
'urllib3',
'urllib3.connectionpool',
):
logging.getLogger(noisy).setLevel(logging.WARNING)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.info(f"Logging configured with level: {log_level}") logger.info(f"Logging configured with level: {log_level}")

View file

@ -29,7 +29,7 @@ import time
from prometheus_client import start_http_server from prometheus_client import start_http_server
from . logging import add_logging_args, setup_logging from . logging import add_logging_args, setup_logging, set_processor_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,6 +64,13 @@ async def _supervise(entry):
pid = entry["params"]["id"] pid = entry["params"]["id"]
class_path = entry["class"] class_path = entry["class"]
# Stamp the contextvar for this supervisor task. Every log
# record emitted from this task — and from any inner TaskGroup
# child created by the processor — inherits this id via
# contextvar propagation. Siblings in the outer group set
# their own id in their own task context and do not interfere.
set_processor_id(pid)
while True: while True:
try: try:

View file

@ -227,15 +227,30 @@ class RabbitMQBackendConsumer:
self._connect() self._connect()
def receive(self, timeout_millis: int = 2000) -> Message: def receive(self, timeout_millis: int = 2000) -> Message:
"""Receive a message. Raises TimeoutError if none available.""" """Receive a message. Raises TimeoutError if none available.
Loop ordering matters: check _incoming at the TOP of each
iteration, not as the loop condition. process_data_events
may dispatch a message via the _on_message callback during
the pump; we must re-check _incoming on the next iteration
before giving up on the deadline. The previous control
flow (`while deadline: check; pump`) could lose a wakeup if
the pump consumed the remainder of the window the
`while` check would fail before `_incoming` was re-read,
leaving a just-dispatched message stranded until the next
receive() call one full poll cycle later.
"""
if not self._is_alive(): if not self._is_alive():
self._connect() self._connect()
timeout_seconds = timeout_millis / 1000.0 timeout_seconds = timeout_millis / 1000.0
deadline = time.monotonic() + timeout_seconds deadline = time.monotonic() + timeout_seconds
while time.monotonic() < deadline: while True:
# Check if a message was already delivered # Check if a message has been dispatched to our queue.
# This catches both (a) messages dispatched before this
# receive() was called and (b) messages dispatched
# during the previous iteration's process_data_events.
try: try:
method, properties, body = self._incoming.get_nowait() method, properties, body = self._incoming.get_nowait()
return RabbitMQMessage( return RabbitMQMessage(
@ -244,14 +259,16 @@ class RabbitMQBackendConsumer:
except queue.Empty: except queue.Empty:
pass pass
# Drive pika's I/O — delivers messages and processes heartbeats
remaining = deadline - time.monotonic() remaining = deadline - time.monotonic()
if remaining > 0: if remaining <= 0:
self._connection.process_data_events( raise TimeoutError("No message received within timeout")
time_limit=min(0.1, remaining),
)
raise TimeoutError("No message received within timeout") # Drive pika's I/O. Any messages delivered during this
# call land in _incoming via _on_message; the next
# iteration of this loop catches them at the top.
self._connection.process_data_events(
time_limit=min(0.1, remaining),
)
def acknowledge(self, message: Message) -> None: def acknowledge(self, message: Message) -> None:
if isinstance(message, RabbitMQMessage) and message._method: if isinstance(message, RabbitMQMessage) and message._method:

View file

@ -4,6 +4,7 @@ Embeddings service, applies an embeddings model using fastembed
Input is text, output is embeddings vector. Input is text, output is embeddings vector.
""" """
import asyncio
import logging import logging
from ... base import EmbeddingsService from ... base import EmbeddingsService
@ -37,7 +38,13 @@ class Processor(EmbeddingsService):
self._load_model(model) self._load_model(model)
def _load_model(self, model_name): def _load_model(self, model_name):
"""Load a model, caching it for reuse""" """Load a model, caching it for reuse.
Synchronous CPU and I/O heavy. Callers that run on the
event loop must dispatch via asyncio.to_thread to avoid
freezing the loop (which, in processor-group deployments,
freezes every sibling processor in the same process).
"""
if self.cached_model_name != model_name: if self.cached_model_name != model_name:
logger.info(f"Loading FastEmbed model: {model_name}") logger.info(f"Loading FastEmbed model: {model_name}")
self.embeddings = TextEmbedding(model_name=model_name) self.embeddings = TextEmbedding(model_name=model_name)
@ -46,6 +53,11 @@ class Processor(EmbeddingsService):
else: else:
logger.debug(f"Using cached model: {model_name}") logger.debug(f"Using cached model: {model_name}")
def _run_embed(self, texts):
"""Synchronous embed call. Runs in a worker thread via
asyncio.to_thread from on_embeddings."""
return list(self.embeddings.embed(texts))
async def on_embeddings(self, texts, model=None): async def on_embeddings(self, texts, model=None):
if not texts: if not texts:
@ -53,11 +65,18 @@ class Processor(EmbeddingsService):
use_model = model or self.default_model use_model = model or self.default_model
# Reload model if it has changed # Reload model if it has changed. Model loading is sync
self._load_model(use_model) # and can take seconds; push it to a worker thread so the
# event loop (and any sibling processors in group mode)
# stay responsive.
if self.cached_model_name != use_model:
await asyncio.to_thread(self._load_model, use_model)
# FastEmbed processes the full batch efficiently # FastEmbed inference is synchronous ONNX runtime work.
vecs = list(self.embeddings.embed(texts)) # Dispatch to a worker thread so the event loop stays
# responsive for other tasks (important in group mode
# where the loop is shared across many processors).
vecs = await asyncio.to_thread(self._run_embed, texts)
# Return list of vectors, one per input text # Return list of vectors, one per input text
return [v.tolist() for v in vecs] return [v.tolist() for v in vecs]

View file

@ -23,6 +23,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... tables.cassandra_async import async_execute
from ... graphql import GraphQLSchemaBuilder, SortDirection from ... graphql import GraphQLSchemaBuilder, SortDirection
@ -263,7 +264,7 @@ class Processor(FlowProcessor):
query += f" LIMIT {limit}" query += f" LIMIT {limit}"
try: try:
rows = self.session.execute(query, params) rows = await async_execute(self.session, query, params)
for row in rows: for row in rows:
# Convert data map to dict with proper field names # Convert data map to dict with proper field names
row_dict = dict(row.data) if row.data else {} row_dict = dict(row.data) if row.data else {}
@ -301,7 +302,7 @@ class Processor(FlowProcessor):
params = [collection, schema_name, primary_index] params = [collection, schema_name, primary_index]
try: try:
rows = self.session.execute(query, params) rows = await async_execute(self.session, query, params)
for row in rows: for row in rows:
row_dict = dict(row.data) if row.data else {} row_dict = dict(row.data) if row.data else {}

View file

@ -4,6 +4,7 @@ Triples query service. Input is a (s, p, o, g) quad pattern, some values may be
null. Output is a list of quads. null. Output is a list of quads.
""" """
import asyncio
import logging import logging
import json import json
@ -200,7 +201,11 @@ class Processor(TriplesQueryService):
try: try:
self.ensure_connection(query.user) # ensure_connection may construct a fresh
# EntityCentricKnowledgeGraph which does sync schema
# setup against Cassandra. Push it to a worker thread
# so the event loop doesn't block on first-use per user.
await asyncio.to_thread(self.ensure_connection, query.user)
# Extract values from query # Extract values from query
s_val = get_term_value(query.s) s_val = get_term_value(query.s)
@ -218,14 +223,21 @@ class Processor(TriplesQueryService):
quads = [] quads = []
# All self.tg.get_* calls below are sync wrappers around
# cassandra session.execute. Materialise inside a worker
# thread so iteration never triggers sync paging back on
# the event loop.
# Route to appropriate query method based on which fields are specified # Route to appropriate query method based on which fields are specified
if s_val is not None: if s_val is not None:
if p_val is not None: if p_val is not None:
if o_val is not None: if o_val is not None:
# SPO specified - find matching graphs # SPO specified - find matching graphs
resp = self.tg.get_spo( resp = await asyncio.to_thread(
query.collection, s_val, p_val, o_val, g=g_val, lambda: list(self.tg.get_spo(
limit=query.limit query.collection, s_val, p_val, o_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -233,9 +245,11 @@ class Processor(TriplesQueryService):
quads.append((s_val, p_val, o_val, g, term_type, datatype, language)) quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
else: else:
# SP specified # SP specified
resp = self.tg.get_sp( resp = await asyncio.to_thread(
query.collection, s_val, p_val, g=g_val, lambda: list(self.tg.get_sp(
limit=query.limit query.collection, s_val, p_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -244,9 +258,11 @@ class Processor(TriplesQueryService):
else: else:
if o_val is not None: if o_val is not None:
# SO specified # SO specified
resp = self.tg.get_os( resp = await asyncio.to_thread(
query.collection, o_val, s_val, g=g_val, lambda: list(self.tg.get_os(
limit=query.limit query.collection, o_val, s_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -254,9 +270,11 @@ class Processor(TriplesQueryService):
quads.append((s_val, t.p, o_val, g, term_type, datatype, language)) quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
else: else:
# S only # S only
resp = self.tg.get_s( resp = await asyncio.to_thread(
query.collection, s_val, g=g_val, lambda: list(self.tg.get_s(
limit=query.limit query.collection, s_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -266,9 +284,11 @@ class Processor(TriplesQueryService):
if p_val is not None: if p_val is not None:
if o_val is not None: if o_val is not None:
# PO specified # PO specified
resp = self.tg.get_po( resp = await asyncio.to_thread(
query.collection, p_val, o_val, g=g_val, lambda: list(self.tg.get_po(
limit=query.limit query.collection, p_val, o_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -276,9 +296,11 @@ class Processor(TriplesQueryService):
quads.append((t.s, p_val, o_val, g, term_type, datatype, language)) quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
else: else:
# P only # P only
resp = self.tg.get_p( resp = await asyncio.to_thread(
query.collection, p_val, g=g_val, lambda: list(self.tg.get_p(
limit=query.limit query.collection, p_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -287,9 +309,11 @@ class Processor(TriplesQueryService):
else: else:
if o_val is not None: if o_val is not None:
# O only # O only
resp = self.tg.get_o( resp = await asyncio.to_thread(
query.collection, o_val, g=g_val, lambda: list(self.tg.get_o(
limit=query.limit query.collection, o_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -297,9 +321,10 @@ class Processor(TriplesQueryService):
quads.append((t.s, t.p, o_val, g, term_type, datatype, language)) quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
else: else:
# Nothing specified - get all # Nothing specified - get all
resp = self.tg.get_all( resp = await asyncio.to_thread(
query.collection, lambda: list(self.tg.get_all(
limit=query.limit query.collection, limit=query.limit,
))
) )
for t in resp: for t in resp:
# Note: quads_by_collection uses 'd' for graph field # Note: quads_by_collection uses 'd' for graph field
@ -340,7 +365,7 @@ class Processor(TriplesQueryService):
Uses Cassandra's paging to fetch results incrementally. Uses Cassandra's paging to fetch results incrementally.
""" """
try: try:
self.ensure_connection(query.user) await asyncio.to_thread(self.ensure_connection, query.user)
batch_size = query.batch_size if query.batch_size > 0 else 20 batch_size = query.batch_size if query.batch_size > 0 else 20
limit = query.limit if query.limit > 0 else 10000 limit = query.limit if query.limit > 0 else 10000
@ -374,9 +399,16 @@ class Processor(TriplesQueryService):
yield batch, is_final yield batch, is_final
return return
# Create statement with fetch_size for true streaming # Materialise in a worker thread. We lose true streaming
# paging (the driver fetches all pages eagerly inside the
# thread) but the event loop stays responsive, and result
# sets at this layer are typically small enough that this
# is acceptable. If true async paging is needed later,
# revisit using ResponseFuture page callbacks.
statement = SimpleStatement(cql, fetch_size=batch_size) statement = SimpleStatement(cql, fetch_size=batch_size)
result_set = self.tg.session.execute(statement, params) result_set = await asyncio.to_thread(
lambda: list(self.tg.session.execute(statement, params))
)
batch = [] batch = []
count = 0 count = 0

View file

@ -13,6 +13,7 @@ Uses a single 'rows' table with the schema:
Each row is written multiple times - once per indexed field defined in the schema. Each row is written multiple times - once per indexed field defined in the schema.
""" """
import asyncio
import json import json
import logging import logging
import re import re
@ -26,6 +27,7 @@ from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler from .... base import CollectionConfigHandler
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... tables.cassandra_async import async_execute
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -361,11 +363,15 @@ class Processor(CollectionConfigHandler, FlowProcessor):
schema_name = obj.schema_name schema_name = obj.schema_name
source = getattr(obj.metadata, 'source', '') or '' source = getattr(obj.metadata, 'source', '') or ''
# Ensure tables exist # Ensure tables exist (sync DDL — push to a worker thread
self.ensure_tables(keyspace) # so the event loop stays responsive when running in a
# processor group sharing the loop with siblings).
await asyncio.to_thread(self.ensure_tables, keyspace)
# Register partitions if first time seeing this (collection, schema_name) # Register partitions if first time seeing this (collection, schema_name)
self.register_partitions(keyspace, collection, schema_name) await asyncio.to_thread(
self.register_partitions, keyspace, collection, schema_name
)
safe_keyspace = self.sanitize_name(keyspace) safe_keyspace = self.sanitize_name(keyspace)
@ -406,9 +412,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
continue continue
try: try:
self.session.execute( await async_execute(
self.session,
insert_cql, insert_cql,
(collection, schema_name, index_name, index_value, data_map, source) (collection, schema_name, index_name, index_value, data_map, source),
) )
rows_written += 1 rows_written += 1
except Exception as e: except Exception as e:
@ -425,18 +432,18 @@ class Processor(CollectionConfigHandler, FlowProcessor):
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, user: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store""" """Create/verify collection exists in Cassandra row store"""
# Connect if not already connected # Connect if not already connected (sync, push to thread)
self.connect_cassandra() await asyncio.to_thread(self.connect_cassandra)
# Ensure tables exist # Ensure tables exist (sync DDL, push to thread)
self.ensure_tables(user) await asyncio.to_thread(self.ensure_tables, user)
logger.info(f"Collection {collection} ready for user {user}") logger.info(f"Collection {collection} ready for user {user}")
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection using partition tracking""" """Delete all data for a specific collection using partition tracking"""
# Connect if not already connected # Connect if not already connected
self.connect_cassandra() await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user) safe_keyspace = self.sanitize_name(user)
@ -446,8 +453,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
SELECT keyspace_name FROM system_schema.keyspaces SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s WHERE keyspace_name = %s
""" """
result = self.session.execute(check_keyspace_cql, (safe_keyspace,)) result = await async_execute(
if not result.one(): self.session, check_keyspace_cql, (safe_keyspace,)
)
if not result:
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return return
self.known_keyspaces.add(user) self.known_keyspaces.add(user)
@ -459,8 +468,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
""" """
try: try:
partitions = self.session.execute(select_partitions_cql, (collection,)) partition_list = await async_execute(
partition_list = list(partitions) self.session, select_partitions_cql, (collection,)
)
except Exception as e: except Exception as e:
logger.error(f"Failed to query partitions for collection {collection}: {e}") logger.error(f"Failed to query partitions for collection {collection}: {e}")
raise raise
@ -474,9 +484,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
partitions_deleted = 0 partitions_deleted = 0
for partition in partition_list: for partition in partition_list:
try: try:
self.session.execute( await async_execute(
self.session,
delete_rows_cql, delete_rows_cql,
(collection, partition.schema_name, partition.index_name) (collection, partition.schema_name, partition.index_name),
) )
partitions_deleted += 1 partitions_deleted += 1
except Exception as e: except Exception as e:
@ -493,7 +504,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
""" """
try: try:
self.session.execute(delete_partitions_cql, (collection,)) await async_execute(
self.session, delete_partitions_cql, (collection,)
)
except Exception as e: except Exception as e:
logger.error(f"Failed to clean up row_partitions for {collection}: {e}") logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
raise raise
@ -512,7 +525,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
async def delete_collection_schema(self, user: str, collection: str, schema_name: str): async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination""" """Delete all data for a specific collection + schema combination"""
# Connect if not already connected # Connect if not already connected
self.connect_cassandra() await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user) safe_keyspace = self.sanitize_name(user)
@ -523,8 +536,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
""" """
try: try:
partitions = self.session.execute(select_partitions_cql, (collection, schema_name)) partition_list = await async_execute(
partition_list = list(partitions) self.session, select_partitions_cql, (collection, schema_name)
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to query partitions for {collection}/{schema_name}: {e}" f"Failed to query partitions for {collection}/{schema_name}: {e}"
@ -540,9 +554,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
partitions_deleted = 0 partitions_deleted = 0
for partition in partition_list: for partition in partition_list:
try: try:
self.session.execute( await async_execute(
self.session,
delete_rows_cql, delete_rows_cql,
(collection, schema_name, partition.index_name) (collection, schema_name, partition.index_name),
) )
partitions_deleted += 1 partitions_deleted += 1
except Exception as e: except Exception as e:
@ -559,7 +574,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
""" """
try: try:
self.session.execute(delete_partitions_cql, (collection, schema_name)) await async_execute(
self.session,
delete_partitions_cql,
(collection, schema_name),
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}" f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"

View file

@ -3,6 +3,7 @@
Graph writer. Input is graph edge. Writes edges to Cassandra graph. Graph writer. Input is graph edge. Writes edges to Cassandra graph.
""" """
import asyncio
import base64 import base64
import os import os
import argparse import argparse
@ -150,59 +151,71 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
user = message.metadata.user user = message.metadata.user
if self.table is None or self.table != user: # The cassandra-driver work below — connection, schema
# setup, and per-triple inserts — is all synchronous.
# Wrap the whole batch in a worker thread so the event
# loop stays responsive for sibling processors when
# running in a processor group.
self.tg = None def _do_store():
# Use factory function to select implementation if self.table is None or self.table != user:
KGClass = EntityCentricKnowledgeGraph
try: self.tg = None
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
username=self.cassandra_username, password=self.cassandra_password
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
)
except Exception as e:
logger.error(f"Exception: {e}", exc_info=True)
time.sleep(1)
raise e
self.table = user # Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
for t in message.triples: try:
# Extract values from Term objects if self.cassandra_username and self.cassandra_password:
s_val = get_term_value(t.s) self.tg = KGClass(
p_val = get_term_value(t.p) hosts=self.cassandra_host,
o_val = get_term_value(t.o) keyspace=message.metadata.user,
# t.g is None for default graph, or a graph IRI username=self.cassandra_username,
g_val = t.g if t.g is not None else DEFAULT_GRAPH password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
)
except Exception as e:
logger.error(f"Exception: {e}", exc_info=True)
time.sleep(1)
raise e
# Extract object type metadata for entity-centric storage self.table = user
otype = get_term_otype(t.o)
dtype = get_term_dtype(t.o)
lang = get_term_lang(t.o)
self.tg.insert( for t in message.triples:
message.metadata.collection, # Extract values from Term objects
s_val, s_val = get_term_value(t.s)
p_val, p_val = get_term_value(t.p)
o_val, o_val = get_term_value(t.o)
g=g_val, # t.g is None for default graph, or a graph IRI
otype=otype, g_val = t.g if t.g is not None else DEFAULT_GRAPH
dtype=dtype,
lang=lang # Extract object type metadata for entity-centric storage
) otype = get_term_otype(t.o)
dtype = get_term_dtype(t.o)
lang = get_term_lang(t.o)
self.tg.insert(
message.metadata.collection,
s_val,
p_val,
o_val,
g=g_val,
otype=otype,
dtype=dtype,
lang=lang,
)
await asyncio.to_thread(_do_store)
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, user: str, collection: str, metadata: dict):
"""Create a collection in Cassandra triple store via config push""" """Create a collection in Cassandra triple store via config push"""
try:
def _do_create():
# Create or reuse connection for this user's keyspace # Create or reuse connection for this user's keyspace
if self.table is None or self.table != user: if self.table is None or self.table != user:
self.tg = None self.tg = None
@ -216,7 +229,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
hosts=self.cassandra_host, hosts=self.cassandra_host,
keyspace=user, keyspace=user,
username=self.cassandra_username, username=self.cassandra_username,
password=self.cassandra_password password=self.cassandra_password,
) )
else: else:
self.tg = KGClass( self.tg = KGClass(
@ -238,13 +251,16 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
self.tg.create_collection(collection) self.tg.create_collection(collection)
logger.info(f"Created collection {collection}") logger.info(f"Created collection {collection}")
try:
await asyncio.to_thread(_do_create)
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection from the unified triples table""" """Delete all data for a specific collection from the unified triples table"""
try:
def _do_delete():
# Create or reuse connection for this user's keyspace # Create or reuse connection for this user's keyspace
if self.table is None or self.table != user: if self.table is None or self.table != user:
self.tg = None self.tg = None
@ -258,7 +274,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
hosts=self.cassandra_host, hosts=self.cassandra_host,
keyspace=user, keyspace=user,
username=self.cassandra_username, username=self.cassandra_username,
password=self.cassandra_password password=self.cassandra_password,
) )
else: else:
self.tg = KGClass( self.tg = KGClass(
@ -275,6 +291,8 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
self.tg.delete_collection(collection) self.tg.delete_collection(collection)
logger.info(f"Deleted all triples for collection {collection} from keyspace {user}") logger.info(f"Deleted all triples for collection {collection} from keyspace {user}")
try:
await asyncio.to_thread(_do_delete)
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
raise raise

View file

@ -0,0 +1,78 @@
"""
Async wrapper for cassandra-driver sessions.
The cassandra driver exposes a callback-based async API via
session.execute_async, returning a ResponseFuture that fires
on_result / on_error from the driver's own worker thread.
This module bridges that into an awaitable interface.
Usage:
from ..tables.cassandra_async import async_execute
rows = await async_execute(self.cassandra, stmt, (param1, param2))
for row in rows:
...
Notes:
- Rows are materialised into a list inside the driver callback
thread before the future is resolved, so subsequent iteration
in the caller never triggers a sync page-fetch on the asyncio
loop. This is safe for single-page results (the common case
in this codebase); if a query needs pagination, handle it
explicitly.
- Callbacks fire on a driver worker thread; call_soon_threadsafe
is used to hand the result back to the asyncio loop.
- Errors from the driver are re-raised in the awaiting coroutine.
"""
import asyncio
async def async_execute(session, query, parameters=None):
    """Execute a CQL statement asynchronously.

    Args:
        session: cassandra.cluster.Session (self.cassandra)
        query: statement string or PreparedStatement
        parameters: tuple/list of bind params, or None

    Returns:
        A list of rows (materialised from the first result page).
    """

    loop = asyncio.get_running_loop()
    future = loop.create_future()

    def _resolve(value, error):
        # Runs on the asyncio loop thread (scheduled via
        # call_soon_threadsafe). Guard against double-resolution
        # in case the driver fires more than one callback.
        if future.done():
            return
        if error is not None:
            future.set_exception(error)
        else:
            future.set_result(value)

    def _on_rows(rows):
        # Materialise on the driver thread so the loop thread
        # never touches a lazy iterator that might trigger
        # further sync I/O.
        try:
            materialised = [] if rows is None else list(rows)
        except Exception as exc:
            loop.call_soon_threadsafe(_resolve, None, exc)
        else:
            loop.call_soon_threadsafe(_resolve, materialised)

    def _on_failure(exc):
        # Errors from the driver are re-raised in the awaiting
        # coroutine via the future.
        loop.call_soon_threadsafe(_resolve, None, exc)

    response_future = session.execute_async(query, parameters)
    response_future.add_callbacks(_on_rows, _on_failure)

    return await future

View file

@ -11,6 +11,8 @@ import time
import asyncio import asyncio
import logging import logging
from . cassandra_async import async_execute
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ConfigTableStore: class ConfigTableStore:
@ -102,21 +104,20 @@ class ConfigTableStore:
async def inc_version(self): async def inc_version(self):
self.cassandra.execute(""" await async_execute(self.cassandra, """
UPDATE version set version = version + 1 UPDATE version set version = version + 1
WHERE id = 'version' WHERE id = 'version'
""") """)
async def get_version(self): async def get_version(self):
resp = self.cassandra.execute(""" rows = await async_execute(self.cassandra, """
SELECT version FROM version SELECT version FROM version
WHERE id = 'version' WHERE id = 'version'
""") """)
row = resp.one() if rows:
return rows[0][0]
if row: return row[0]
return None return None
@ -153,150 +154,91 @@ class ConfigTableStore:
""") """)
async def put_config(self, cls, key, value): async def put_config(self, cls, key, value):
try:
while True: await async_execute(
self.cassandra,
try: self.put_config_stmt,
(cls, key, value),
resp = self.cassandra.execute( )
self.put_config_stmt, except Exception:
( cls, key, value ) logger.error("Exception occurred", exc_info=True)
) raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
async def get_value(self, cls, key): async def get_value(self, cls, key):
try:
rows = await async_execute(
self.cassandra,
self.get_value_stmt,
(cls, key),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
while True: for row in rows:
try:
resp = self.cassandra.execute(
self.get_value_stmt,
( cls, key )
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
for row in resp:
return row[0] return row[0]
return None return None
async def get_values(self, cls): async def get_values(self, cls):
try:
rows = await async_execute(
self.cassandra,
self.get_values_stmt,
(cls,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
while True: return [[row[0], row[1]] for row in rows]
try:
resp = self.cassandra.execute(
self.get_values_stmt,
( cls, )
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
return [
[row[0], row[1]]
for row in resp
]
async def get_classes(self): async def get_classes(self):
try:
rows = await async_execute(
self.cassandra,
self.get_classes_stmt,
(),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
while True: return [row[0] for row in rows]
try:
resp = self.cassandra.execute(
self.get_classes_stmt,
()
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
return [
row[0] for row in resp
]
async def get_all(self): async def get_all(self):
try:
rows = await async_execute(
self.cassandra,
self.get_all_stmt,
(),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
while True: return [(row[0], row[1], row[2]) for row in rows]
try:
resp = self.cassandra.execute(
self.get_all_stmt,
()
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
return [
(row[0], row[1], row[2])
for row in resp
]
async def get_keys(self, cls): async def get_keys(self, cls):
try:
rows = await async_execute(
self.cassandra,
self.get_keys_stmt,
(cls,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
while True: return [row[0] for row in rows]
try:
resp = self.cassandra.execute(
self.get_keys_stmt,
( cls, )
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
return [
row[0] for row in resp
]
async def delete_key(self, cls, key): async def delete_key(self, cls, key):
try:
while True: await async_execute(
self.cassandra,
try: self.delete_key_stmt,
(cls, key),
resp = self.cassandra.execute( )
self.delete_key_stmt, except Exception:
(cls, key) logger.error("Exception occurred", exc_info=True)
) raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e

View file

@ -4,6 +4,8 @@ from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from . cassandra_async import async_execute
def term_to_tuple(term): def term_to_tuple(term):
"""Convert Term to (value, is_uri) tuple for database storage.""" """Convert Term to (value, is_uri) tuple for database storage."""
@ -225,25 +227,19 @@ class KnowledgeTableStore:
for v in m.triples for v in m.triples
] ]
while True: try:
await async_execute(
try: self.cassandra,
self.insert_triples_stmt,
resp = self.cassandra.execute( (
self.insert_triples_stmt, uuid.uuid4(), m.metadata.user,
( m.metadata.root or m.metadata.id, when,
uuid.uuid4(), m.metadata.user, [], triples,
m.metadata.root or m.metadata.id, when, ),
[], triples, )
) except Exception:
) logger.error("Exception occurred", exc_info=True)
raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
async def add_graph_embeddings(self, m): async def add_graph_embeddings(self, m):
@ -257,25 +253,19 @@ class KnowledgeTableStore:
for v in m.entities for v in m.entities
] ]
while True: try:
await async_execute(
try: self.cassandra,
self.insert_graph_embeddings_stmt,
resp = self.cassandra.execute( (
self.insert_graph_embeddings_stmt, uuid.uuid4(), m.metadata.user,
( m.metadata.root or m.metadata.id, when,
uuid.uuid4(), m.metadata.user, [], entities,
m.metadata.root or m.metadata.id, when, ),
[], entities, )
) except Exception:
) logger.error("Exception occurred", exc_info=True)
raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
async def add_document_embeddings(self, m): async def add_document_embeddings(self, m):
@ -289,50 +279,35 @@ class KnowledgeTableStore:
for v in m.chunks for v in m.chunks
] ]
while True: try:
await async_execute(
try: self.cassandra,
self.insert_document_embeddings_stmt,
resp = self.cassandra.execute( (
self.insert_document_embeddings_stmt, uuid.uuid4(), m.metadata.user,
( m.metadata.root or m.metadata.id, when,
uuid.uuid4(), m.metadata.user, [], chunks,
m.metadata.root or m.metadata.id, when, ),
[], chunks, )
) except Exception:
) logger.error("Exception occurred", exc_info=True)
raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
async def list_kg_cores(self, user): async def list_kg_cores(self, user):
logger.debug("List kg cores...") logger.debug("List kg cores...")
while True: try:
rows = await async_execute(
self.cassandra,
self.list_cores_stmt,
(user,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
try: lst = [row[1] for row in rows]
resp = self.cassandra.execute(
self.list_cores_stmt,
(user,)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
lst = [
row[1]
for row in resp
]
logger.debug("Done") logger.debug("Done")
@ -342,56 +317,41 @@ class KnowledgeTableStore:
logger.debug("Delete kg cores...") logger.debug("Delete kg cores...")
while True: try:
await async_execute(
self.cassandra,
self.delete_triples_stmt,
(user, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
try: try:
await async_execute(
resp = self.cassandra.execute( self.cassandra,
self.delete_triples_stmt, self.delete_graph_embeddings_stmt,
(user, document_id) (user, document_id),
) )
except Exception:
break logger.error("Exception occurred", exc_info=True)
raise
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
while True:
try:
resp = self.cassandra.execute(
self.delete_graph_embeddings_stmt,
(user, document_id)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
async def get_triples(self, user, document_id, receiver): async def get_triples(self, user, document_id, receiver):
logger.debug("Get triples...") logger.debug("Get triples...")
while True: try:
rows = await async_execute(
self.cassandra,
self.get_triples_stmt,
(user, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
try: for row in rows:
resp = self.cassandra.execute(
self.get_triples_stmt,
(user, document_id)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
for row in resp:
if row[3]: if row[3]:
triples = [ triples = [
@ -422,22 +382,17 @@ class KnowledgeTableStore:
logger.debug("Get GE...") logger.debug("Get GE...")
while True: try:
rows = await async_execute(
self.cassandra,
self.get_graph_embeddings_stmt,
(user, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
try: for row in rows:
resp = self.cassandra.execute(
self.get_graph_embeddings_stmt,
(user, document_id)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
for row in resp:
if row[3]: if row[3]:
entities = [ entities = [

View file

@ -31,6 +31,8 @@ import time
import asyncio import asyncio
import logging import logging
from . cassandra_async import async_execute
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LibraryTableStore: class LibraryTableStore:
@ -321,18 +323,13 @@ class LibraryTableStore:
async def document_exists(self, user, id): async def document_exists(self, user, id):
resp = self.cassandra.execute( rows = await async_execute(
self.cassandra,
self.test_document_exists_stmt, self.test_document_exists_stmt,
( user, id ) (user, id),
) )
# If a row exists, document exists. It's a cursor, can't just return bool(rows)
# count the length
for row in resp:
return True
return False
async def add_document(self, document, object_id): async def add_document(self, document, object_id):
@ -349,26 +346,20 @@ class LibraryTableStore:
parent_id = getattr(document, 'parent_id', '') or '' parent_id = getattr(document, 'parent_id', '') or ''
document_type = getattr(document, 'document_type', 'source') or 'source' document_type = getattr(document, 'document_type', 'source') or 'source'
while True: try:
await async_execute(
try: self.cassandra,
self.insert_document_stmt,
resp = self.cassandra.execute( (
self.insert_document_stmt, document.id, document.user, int(document.time * 1000),
( document.kind, document.title, document.comments,
document.id, document.user, int(document.time * 1000), metadata, document.tags, object_id,
document.kind, document.title, document.comments, parent_id, document_type
metadata, document.tags, object_id, ),
parent_id, document_type )
) except Exception:
) logger.error("Exception occurred", exc_info=True)
raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
logger.debug("Add complete") logger.debug("Add complete")
@ -383,25 +374,19 @@ class LibraryTableStore:
for v in document.metadata for v in document.metadata
] ]
while True: try:
await async_execute(
try: self.cassandra,
self.update_document_stmt,
resp = self.cassandra.execute( (
self.update_document_stmt, int(document.time * 1000), document.title,
( document.comments, metadata, document.tags,
int(document.time * 1000), document.title, document.user, document.id
document.comments, metadata, document.tags, ),
document.user, document.id )
) except Exception:
) logger.error("Exception occurred", exc_info=True)
raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
logger.debug("Update complete") logger.debug("Update complete")
@ -409,23 +394,15 @@ class LibraryTableStore:
logger.info(f"Removing document {document_id}") logger.info(f"Removing document {document_id}")
while True: try:
await async_execute(
try: self.cassandra,
self.delete_document_stmt,
resp = self.cassandra.execute( (user, document_id),
self.delete_document_stmt, )
( except Exception:
user, document_id logger.error("Exception occurred", exc_info=True)
) raise
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
logger.debug("Delete complete") logger.debug("Delete complete")
@ -433,21 +410,15 @@ class LibraryTableStore:
logger.debug("List documents...") logger.debug("List documents...")
while True: try:
rows = await async_execute(
try: self.cassandra,
self.list_document_stmt,
resp = self.cassandra.execute( (user,),
self.list_document_stmt, )
(user,) except Exception:
) logger.error("Exception occurred", exc_info=True)
raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
lst = [ lst = [
DocumentMetadata( DocumentMetadata(
@ -469,7 +440,7 @@ class LibraryTableStore:
parent_id = row[8] if row[8] else "", parent_id = row[8] if row[8] else "",
document_type = row[9] if row[9] else "source", document_type = row[9] if row[9] else "source",
) )
for row in resp for row in rows
] ]
logger.debug("Done") logger.debug("Done")
@ -481,20 +452,15 @@ class LibraryTableStore:
logger.debug(f"List children for parent {parent_id}") logger.debug(f"List children for parent {parent_id}")
while True: try:
rows = await async_execute(
try: self.cassandra,
self.list_children_stmt,
resp = self.cassandra.execute( (parent_id,),
self.list_children_stmt, )
(parent_id,) except Exception:
) logger.error("Exception occurred", exc_info=True)
raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
lst = [ lst = [
DocumentMetadata( DocumentMetadata(
@ -516,7 +482,7 @@ class LibraryTableStore:
parent_id = row[9] if row[9] else "", parent_id = row[9] if row[9] else "",
document_type = row[10] if row[10] else "source", document_type = row[10] if row[10] else "source",
) )
for row in resp for row in rows
] ]
logger.debug("Done") logger.debug("Done")
@ -527,23 +493,17 @@ class LibraryTableStore:
logger.debug("Get document") logger.debug("Get document")
while True: try:
rows = await async_execute(
self.cassandra,
self.get_document_stmt,
(user, id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
try: for row in rows:
resp = self.cassandra.execute(
self.get_document_stmt,
(user, id)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
for row in resp:
doc = DocumentMetadata( doc = DocumentMetadata(
id = id, id = id,
user = user, user = user,
@ -573,23 +533,17 @@ class LibraryTableStore:
logger.debug("Get document obj ID") logger.debug("Get document obj ID")
while True: try:
rows = await async_execute(
self.cassandra,
self.get_document_stmt,
(user, id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
try: for row in rows:
resp = self.cassandra.execute(
self.get_document_stmt,
(user, id)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
for row in resp:
logger.debug("Done") logger.debug("Done")
return row[6] return row[6]
@ -597,43 +551,32 @@ class LibraryTableStore:
async def processing_exists(self, user, id): async def processing_exists(self, user, id):
resp = self.cassandra.execute( rows = await async_execute(
self.cassandra,
self.test_processing_exists_stmt, self.test_processing_exists_stmt,
( user, id ) (user, id),
) )
# If a row exists, document exists. It's a cursor, can't just return bool(rows)
# count the length
for row in resp:
return True
return False
async def add_processing(self, processing): async def add_processing(self, processing):
logger.info(f"Adding processing {processing.id}") logger.info(f"Adding processing {processing.id}")
while True: try:
await async_execute(
try: self.cassandra,
self.insert_processing_stmt,
resp = self.cassandra.execute( (
self.insert_processing_stmt, processing.id, processing.document_id,
( int(processing.time * 1000), processing.flow,
processing.id, processing.document_id, processing.user, processing.collection,
int(processing.time * 1000), processing.flow, processing.tags
processing.user, processing.collection, ),
processing.tags )
) except Exception:
) logger.error("Exception occurred", exc_info=True)
raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
logger.debug("Add complete") logger.debug("Add complete")
@ -641,23 +584,15 @@ class LibraryTableStore:
logger.info(f"Removing processing {processing_id}") logger.info(f"Removing processing {processing_id}")
while True: try:
await async_execute(
try: self.cassandra,
self.delete_processing_stmt,
resp = self.cassandra.execute( (user, processing_id),
self.delete_processing_stmt, )
( except Exception:
user, processing_id logger.error("Exception occurred", exc_info=True)
) raise
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
logger.debug("Delete complete") logger.debug("Delete complete")
@ -665,21 +600,15 @@ class LibraryTableStore:
logger.debug("List processing objects") logger.debug("List processing objects")
while True: try:
rows = await async_execute(
try: self.cassandra,
self.list_processing_stmt,
resp = self.cassandra.execute( (user,),
self.list_processing_stmt, )
(user,) except Exception:
) logger.error("Exception occurred", exc_info=True)
raise
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
lst = [ lst = [
ProcessingMetadata( ProcessingMetadata(
@ -691,7 +620,7 @@ class LibraryTableStore:
collection = row[4], collection = row[4],
tags = row[5] if row[5] else [], tags = row[5] if row[5] else [],
) )
for row in resp for row in rows
] ]
logger.debug("Done") logger.debug("Done")
@ -718,20 +647,19 @@ class LibraryTableStore:
now = int(time.time() * 1000) now = int(time.time() * 1000)
while True: try:
try: await async_execute(
self.cassandra.execute( self.cassandra,
self.insert_upload_session_stmt, self.insert_upload_session_stmt,
( (
upload_id, user, document_id, document_metadata, upload_id, user, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size, s3_upload_id, object_id, total_size, chunk_size,
total_chunks, {}, now, now total_chunks, {}, now, now
) ),
) )
break except Exception:
except Exception as e: logger.error("Exception occurred", exc_info=True)
logger.error("Exception occurred", exc_info=True) raise
raise e
logger.debug("Upload session created") logger.debug("Upload session created")
@ -740,18 +668,17 @@ class LibraryTableStore:
logger.debug(f"Get upload session {upload_id}") logger.debug(f"Get upload session {upload_id}")
while True: try:
try: rows = await async_execute(
resp = self.cassandra.execute( self.cassandra,
self.get_upload_session_stmt, self.get_upload_session_stmt,
(upload_id,) (upload_id,),
) )
break except Exception:
except Exception as e: logger.error("Exception occurred", exc_info=True)
logger.error("Exception occurred", exc_info=True) raise
raise e
for row in resp: for row in rows:
session = { session = {
"upload_id": row[0], "upload_id": row[0],
"user": row[1], "user": row[1],
@ -778,20 +705,19 @@ class LibraryTableStore:
now = int(time.time() * 1000) now = int(time.time() * 1000)
while True: try:
try: await async_execute(
self.cassandra.execute( self.cassandra,
self.update_upload_session_chunk_stmt, self.update_upload_session_chunk_stmt,
( (
{chunk_index: etag}, {chunk_index: etag},
now, now,
upload_id upload_id
) ),
) )
break except Exception:
except Exception as e: logger.error("Exception occurred", exc_info=True)
logger.error("Exception occurred", exc_info=True) raise
raise e
logger.debug("Chunk recorded") logger.debug("Chunk recorded")
@ -800,16 +726,15 @@ class LibraryTableStore:
logger.info(f"Deleting upload session {upload_id}") logger.info(f"Deleting upload session {upload_id}")
while True: try:
try: await async_execute(
self.cassandra.execute( self.cassandra,
self.delete_upload_session_stmt, self.delete_upload_session_stmt,
(upload_id,) (upload_id,),
) )
break except Exception:
except Exception as e: logger.error("Exception occurred", exc_info=True)
logger.error("Exception occurred", exc_info=True) raise
raise e
logger.debug("Upload session deleted") logger.debug("Upload session deleted")
@ -818,19 +743,18 @@ class LibraryTableStore:
logger.debug(f"List upload sessions for {user}") logger.debug(f"List upload sessions for {user}")
while True: try:
try: rows = await async_execute(
resp = self.cassandra.execute( self.cassandra,
self.list_upload_sessions_stmt, self.list_upload_sessions_stmt,
(user,) (user,),
) )
break except Exception:
except Exception as e: logger.error("Exception occurred", exc_info=True)
logger.error("Exception occurred", exc_info=True) raise
raise e
sessions = [] sessions = []
for row in resp: for row in rows:
chunks_received = row[6] if row[6] else {} chunks_received = row[6] if row[6] else {}
sessions.append({ sessions.append({
"upload_id": row[0], "upload_id": row[0],