release/v2.4 -> master (#924)

* CLI auth migration, document embeddings core lifecycle (#913)

Migrate get_kg_core and put_kg_core CLI tools to use Api/SocketClient
with first-frame auth (fixes broken raw websocket path). Fix wire
format field names (root/vector). Remove ~600 lines of dead raw
websocket code from invoke_graph_rag.py.

Add document embeddings core lifecycle to the knowledge service:
list/get/put/delete/load operations across schema, translator,
Cassandra table store, knowledge manager, gateway registry, REST API,
socket client, and CLI (tg-get-de-core, tg-put-de-core).

Fix delete_kg_core to also clean up document embeddings rows.
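
For reference, the manager-side handlers follow the same respond-callback
pattern as the existing kg-core handlers; a condensed sketch of the
streaming get path (the full version is in the KnowledgeManager diff
below):

  async def get_de_core(self, request, respond, workspace):
      # Stream each stored DocumentEmbeddings message back to the caller...
      async def publish_de(de):
          await respond(KnowledgeResponse(
              error=None, ids=None, eos=False,
              triples=None, graph_embeddings=None,
              document_embeddings=de,
          ))
      await self.table_store.get_document_embeddings(
          workspace, request.id, publish_de,
      )
      # ...then signal end-of-stream with an eos response.
      await respond(KnowledgeResponse(
          error=None, ids=None, eos=True,
          triples=None, graph_embeddings=None,
      ))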

* Remove spurious workspace parameter from SPARQL algebra evaluator (#915)

Fix threading of the workspace parameter:
- The SPARQL algebra evaluator was threading a workspace parameter
  through every function and passing it to TriplesClient.query(),
  which doesn't accept it. Workspace isolation is handled by pub/sub
  topic routing — the TriplesClient is already scoped to a
  workspace-specific flow, same as GraphRAG. Passing workspace
  explicitly was both incorrect and unnecessary.

Update tests:
- tests/unit/test_query/test_sparql_algebra.py (new) — Tests
  _query_pattern, _eval_bgp, and evaluate() with various algebra
  nodes. Key tests assert workspace is never in tc.query() kwargs (see
  the sketch below),
  plus correctness tests for BGP, JOIN, UNION, SLICE, DISTINCT, and
  edge cases.
- tests/unit/test_retrieval/test_graph_rag.py — Added
  test_triples_query_never_passes_workspace (checks query()) and
  test_follow_edges_never_passes_workspace (checks query_stream()).
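
The workspace-exclusion assertion in those tests looks roughly like this
(a minimal sketch, assuming pytest-asyncio and a mocked TriplesClient;
bgp_node stands in for a parsed rdflib algebra node, and evaluate() is
the function under test):

  from unittest.mock import AsyncMock

  import pytest

  @pytest.mark.asyncio
  async def test_bgp_never_passes_workspace(bgp_node):
      tc = AsyncMock()
      tc.query.return_value = []   # result content is irrelevant here
      await evaluate(bgp_node, tc, collection="default", limit=100)
      assert tc.query.called
      for call in tc.query.call_args_list:
          # The evaluator must never forward a workspace kwarg
          assert "workspace" not in call.kwargs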

* Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916)

Cassandra triples services were using synchronous EntityCentricKnowledgeGraph
methods from async contexts, and connection state was managed with
threading.local which is wrong for asyncio coroutines sharing a single
thread. Qdrant services had no async wrapping at all, blocking the event
loop on every network call. Rows services had unprotected shared state
mutations across concurrent coroutines.

- Add async methods to EntityCentricKnowledgeGraph (async_insert,
  async_get_s/p/o/sp/po/os/spo/all, async_collection_exists,
  async_create_collection, async_delete_collection) using the existing
  cassandra_async.async_execute bridge
- Rewrite triples write + query services: replace threading.local with
  asyncio.Lock + dict cache for per-workspace connections, use async
  ECKG methods for all data operations, keep asyncio.to_thread only for
  one-time blocking ECKG construction (pattern sketched after this list)
- Wrap all Qdrant calls in asyncio.to_thread across all 6 services
  (doc/graph/row embeddings write + query), add asyncio.Lock + set cache
  for collection existence checks
- Add asyncio.Lock to rows write + query services to protect shared
  state (schemas, sessions, config caches) from concurrent mutation
- Update all affected tests to match new async patterns
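
The per-workspace connection cache that replaces threading.local looks
roughly like this (a condensed sketch of the pattern in the rewritten
triples query service, shown in full in the diff below;
EntityCentricKnowledgeGraph comes from the existing cassandra_kg module,
import elided, and the class name here is illustrative):

  import asyncio

  class TriplesQueryProcessor:
      def __init__(self, cassandra_host):
          self.cassandra_host = cassandra_host
          self._connections = {}              # workspace -> ECKG instance
          self._conn_lock = asyncio.Lock()    # coroutines share one thread

      async def _get_connection(self, workspace):
          async with self._conn_lock:
              if workspace not in self._connections:
                  # Constructing the graph object does blocking schema
                  # setup, so push it to a worker thread once per workspace.
                  self._connections[workspace] = await asyncio.to_thread(
                      EntityCentricKnowledgeGraph,
                      hosts=self.cassandra_host,
                      keyspace=workspace,
                  )
          return self._connections[workspace]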

* Fix bug where only a single page of results was returned (#921)

The root cause: async_execute only materialises the first result
page (by design — it says so in its docstring). The streaming query
set fetch_size=20 and expected to iterate all results, but only got
the first 20 rows back.

The fix uses
  asyncio.to_thread(lambda: list(tg.session.execute(...)))
which lets the sync driver iterate all pages in a worker thread —
exactly what the pre-async code did.
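
In other words, the streaming path now materialises the full result set
off the event loop instead of relying on async_execute; roughly (a
sketch assuming an existing Cassandra session; fetch_all_rows is an
illustrative helper name):

  import asyncio
  from cassandra.query import SimpleStatement

  async def fetch_all_rows(session, cql, params, batch_size=20):
      statement = SimpleStatement(cql, fetch_size=batch_size)
      # The sync driver fetches successive pages transparently while the
      # result set is iterated; doing that in a worker thread keeps the
      # event loop responsive and still returns every row.
      return await asyncio.to_thread(
          lambda: list(session.execute(statement, params))
      )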

* Optional test warning suppression (#923)

* Fix test collection module errors & silence upstream Pytest warnings (#823)

* chore: add virtual environment and .env directories to gitignore

* test: filter upstream DeprecationWarning and UserWarning messages

* fix(namespace): remove empty __init__.py files to fix PEP 420 implicit namespace routing for trustgraph sub-packages

* Revert __init__.py deletions

* Add .ini changes, commented out for now; may be useful at times

---------

Co-authored-by: Salil M <d2kyt@protonmail.com>
cybermaggedon 2026-05-15 13:02:51 +01:00 committed by GitHub
parent 159b1e2824
commit 142dd0231c
42 changed files with 1910 additions and 1492 deletions


@ -1,5 +1,6 @@
from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings
from .. schema import DocumentEmbeddings
from .. knowledge import hash
from .. exceptions import RequestError
from .. tables.knowledge import KnowledgeTableStore
@ -157,6 +158,98 @@ class KnowledgeManager:
)
)
async def list_de_cores(self, request, respond, workspace):
ids = await self.table_store.list_de_cores(workspace)
await respond(
KnowledgeResponse(
error = None,
ids = ids,
eos = False,
triples = None,
graph_embeddings = None,
)
)
async def get_de_core(self, request, respond, workspace):
logger.info("Getting document embeddings core...")
async def publish_de(de):
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None,
document_embeddings = de,
)
)
await self.table_store.get_document_embeddings(
workspace,
request.id,
publish_de,
)
logger.debug("Document embeddings core retrieval complete")
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = True,
triples = None,
graph_embeddings = None,
)
)
async def put_de_core(self, request, respond, workspace):
if request.document_embeddings:
await self.table_store.add_document_embeddings(
workspace, request.document_embeddings
)
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None,
)
)
async def delete_de_core(self, request, respond, workspace):
logger.info("Deleting document embeddings core...")
await self.table_store.delete_document_embeddings(
workspace, request.id
)
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None,
)
)
async def load_de_core(self, request, respond, workspace):
if self.background_task is None:
self.background_task = asyncio.create_task(
self.core_loader()
)
await self.loader_queue.put((request, respond, workspace))
async def core_loader(self):
logger.info("Knowledge background processor running...")
@ -165,7 +258,7 @@ class KnowledgeManager:
logger.debug("Waiting for next load...")
request, respond, workspace = await self.loader_queue.get()
logger.info(f"Loading knowledge: {request.id}")
logger.info(f"Loading: {request.operation} {request.id}")
try:
@ -187,25 +280,14 @@ class KnowledgeManager:
if "interfaces" not in flow:
raise RuntimeError("No defined interfaces")
if "triples-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no triples-store")
if "graph-embeddings-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no graph-embeddings-store")
t_q = flow["interfaces"]["triples-store"]["flow"]
ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"]
# Got this far, it should all work
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None
if request.operation == "load-de-core":
await self._load_de_core(
request, respond, workspace, flow,
)
else:
await self._load_kg_core(
request, respond, workspace, flow,
)
)
except Exception as e:
@ -223,72 +305,145 @@ class KnowledgeManager:
)
)
logger.debug("Starting knowledge loading process...")
try:
t_pub = None
ge_pub = None
logger.debug(f"Triples queue: {t_q}")
logger.debug(f"Graph embeddings queue: {ge_q}")
t_pub = Publisher(
self.flow_config.pubsub, t_q,
schema=Triples,
)
ge_pub = Publisher(
self.flow_config.pubsub, ge_q,
schema=GraphEmbeddings
)
logger.debug("Starting publishers...")
await t_pub.start()
await ge_pub.start()
async def publish_triples(t):
# Override collection with request collection
if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'):
t.metadata.collection = request.collection or "default"
await t_pub.send(None, t)
logger.debug("Publishing triples...")
await self.table_store.get_triples(
workspace,
request.id,
publish_triples,
)
async def publish_ge(g):
# Override collection with request collection
if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'):
g.metadata.collection = request.collection or "default"
await ge_pub.send(None, g)
logger.debug("Publishing graph embeddings...")
await self.table_store.get_graph_embeddings(
workspace,
request.id,
publish_ge,
)
logger.debug("Knowledge loading completed")
except Exception as e:
logger.error(f"Knowledge exception: {e}", exc_info=True)
finally:
logger.debug("Stopping publishers...")
if t_pub: await t_pub.stop()
if ge_pub: await ge_pub.stop()
logger.debug("Knowledge processing done")
continue
async def _load_kg_core(self, request, respond, workspace, flow):
if "triples-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no triples-store")
if "graph-embeddings-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no graph-embeddings-store")
t_q = flow["interfaces"]["triples-store"]["flow"]
ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"]
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None
)
)
t_pub = None
ge_pub = None
try:
logger.debug(f"Triples queue: {t_q}")
logger.debug(f"Graph embeddings queue: {ge_q}")
t_pub = Publisher(
self.flow_config.pubsub, t_q,
schema=Triples,
)
ge_pub = Publisher(
self.flow_config.pubsub, ge_q,
schema=GraphEmbeddings
)
logger.debug("Starting publishers...")
await t_pub.start()
await ge_pub.start()
async def publish_triples(t):
if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'):
t.metadata.collection = request.collection or "default"
await t_pub.send(None, t)
logger.debug("Publishing triples...")
await self.table_store.get_triples(
workspace,
request.id,
publish_triples,
)
async def publish_ge(g):
if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'):
g.metadata.collection = request.collection or "default"
await ge_pub.send(None, g)
logger.debug("Publishing graph embeddings...")
await self.table_store.get_graph_embeddings(
workspace,
request.id,
publish_ge,
)
logger.debug("Knowledge core loading completed")
except Exception as e:
logger.error(f"Knowledge exception: {e}", exc_info=True)
finally:
logger.debug("Stopping publishers...")
if t_pub: await t_pub.stop()
if ge_pub: await ge_pub.stop()
async def _load_de_core(self, request, respond, workspace, flow):
if "document-embeddings-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no document-embeddings-store")
de_q = flow["interfaces"]["document-embeddings-store"]["flow"]
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None
)
)
de_pub = None
try:
logger.debug(f"Document embeddings queue: {de_q}")
de_pub = Publisher(
self.flow_config.pubsub, de_q,
schema=DocumentEmbeddings,
)
logger.debug("Starting publisher...")
await de_pub.start()
async def publish_de(de):
if hasattr(de, 'metadata') and hasattr(de.metadata, 'collection'):
de.metadata.collection = request.collection or "default"
await de_pub.send(None, de)
logger.debug("Publishing document embeddings...")
await self.table_store.get_document_embeddings(
workspace,
request.id,
publish_de,
)
logger.debug("Document embeddings core loading completed")
except Exception as e:
logger.error(f"Knowledge exception: {e}", exc_info=True)
finally:
logger.debug("Stopping publisher...")
if de_pub: await de_pub.stop()


@ -187,6 +187,11 @@ class Processor(WorkspaceProcessor):
"put-kg-core": self.knowledge.put_kg_core,
"load-kg-core": self.knowledge.load_kg_core,
"unload-kg-core": self.knowledge.unload_kg_core,
"list-de-cores": self.knowledge.list_de_cores,
"get-de-core": self.knowledge.get_de_core,
"delete-de-core": self.knowledge.delete_de_core,
"put-de-core": self.knowledge.put_de_core,
"load-de-core": self.knowledge.load_de_core,
}
if v.operation not in impls:


@ -1,10 +1,14 @@
import datetime
import os
import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement
from ssl import SSLContext, PROTOCOL_TLSv1_2
import os
import logging
from ..tables.cassandra_async import async_execute
# Global list to track clusters for cleanup
_active_clusters = []
@ -461,7 +465,6 @@ class KnowledgeGraph:
def create_collection(self, collection):
"""Create collection by inserting metadata row"""
try:
import datetime
self.session.execute(
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
(collection, datetime.datetime.now())
@ -954,7 +957,6 @@ class EntityCentricKnowledgeGraph:
def create_collection(self, collection):
"""Create collection by inserting metadata row"""
try:
import datetime
self.session.execute(
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
(collection, datetime.datetime.now())
@ -1045,6 +1047,222 @@ class EntityCentricKnowledgeGraph:
logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")
# ========================================================================
# Async methods — use cassandra driver's native async API via async_execute
# ========================================================================
async def async_insert(self, collection, s, p, o, g=None, otype=None, dtype="", lang=""):
if g is None:
g = DEFAULT_GRAPH
if otype is None:
if o.startswith("http://") or o.startswith("https://"):
otype = "u"
else:
otype = "l"
batch = BatchStatement()
batch.add(self.insert_entity_stmt, (collection, s, 'S', p, otype, s, o, g, dtype, lang))
batch.add(self.insert_entity_stmt, (collection, p, 'P', p, otype, s, o, g, dtype, lang))
if otype == 'u' or otype == 't':
batch.add(self.insert_entity_stmt, (collection, o, 'O', p, otype, s, o, g, dtype, lang))
if g != DEFAULT_GRAPH:
batch.add(self.insert_entity_stmt, (collection, g, 'G', p, otype, s, o, g, dtype, lang))
batch.add(self.insert_collection_stmt, (collection, g, s, p, o, otype, dtype, lang))
await async_execute(self.session, batch)
async def async_get_all(self, collection, limit=50):
return await async_execute(
self.session, self.get_collection_all_stmt, (collection, limit)
)
async def async_get_s(self, collection, s, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_s_stmt, (collection, s, limit)
)
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=row.s, p=row.p, o=row.o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_p(self, collection, p, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_p_stmt, (collection, p, limit)
)
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=row.s, p=row.p, o=row.o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_o(self, collection, o, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_o_stmt, (collection, o, limit)
)
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=row.s, p=row.p, o=row.o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_sp(self, collection, s, p, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit)
)
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=s, p=p, o=row.o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_po(self, collection, p, o, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_o_p_stmt, (collection, o, p, limit)
)
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=row.s, p=p, o=o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_os(self, collection, o, s, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_s_stmt, (collection, s, limit)
)
results = []
for row in rows:
if row.o != o:
continue
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=s, p=row.p, o=o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_spo(self, collection, s, p, o, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit)
)
results = []
for row in rows:
if row.o != o:
continue
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=s, p=p, o=o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_g(self, collection, g, limit=50):
if g is None:
g = DEFAULT_GRAPH
return await async_execute(
self.session, self.get_collection_by_graph_stmt, (collection, g, limit)
)
async def async_collection_exists(self, collection):
try:
result = await async_execute(
self.session,
f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1",
(collection,)
)
return bool(result)
except Exception as e:
logger.error(f"Error checking collection existence: {e}")
return False
async def async_create_collection(self, collection):
await async_execute(
self.session,
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
(collection, datetime.datetime.now())
)
logger.info(f"Created collection metadata for {collection}")
async def async_delete_collection(self, collection):
rows = await async_execute(
self.session,
f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s",
(collection,)
)
entities = set()
quads = []
for row in rows:
d, s, p, o = row.d, row.s, row.p, row.o
otype = row.otype
dtype = row.dtype if hasattr(row, 'dtype') else ''
lang = row.lang if hasattr(row, 'lang') else ''
quads.append((d, s, p, o, otype, dtype, lang))
entities.add(s)
entities.add(p)
if otype == 'u' or otype == 't':
entities.add(o)
if d != DEFAULT_GRAPH:
entities.add(d)
batch = BatchStatement()
count = 0
for entity in entities:
batch.add(self.delete_entity_partition_stmt, (collection, entity))
count += 1
if count % 50 == 0:
await async_execute(self.session, batch)
batch = BatchStatement()
if count % 50 != 0:
await async_execute(self.session, batch)
batch = BatchStatement()
count = 0
for d, s, p, o, otype, dtype, lang in quads:
batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang))
count += 1
if count % 50 == 0:
await async_execute(self.session, batch)
batch = BatchStatement()
if count % 50 != 0:
await async_execute(self.session, batch)
await async_execute(
self.session,
f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s",
(collection,)
)
logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")
def close(self):
"""Close connections"""
if hasattr(self, 'session') and self.session:


@ -457,6 +457,12 @@ for _op in ("put-kg-core", "delete-kg-core",
"load-kg-core", "unload-kg-core"):
_register_kind_op("knowledge", _op, "knowledge:write")
# knowledge: document-embeddings core service.
for _op in ("get-de-core", "list-de-cores"):
_register_kind_op("knowledge", _op, "knowledge:read")
for _op in ("put-de-core", "delete-de-core", "load-de-core"):
_register_kind_op("knowledge", _op, "knowledge:write")
# collection-management: workspace collection lifecycle.
_register_kind_op("collection-management", "list-collections", "collections:read")


@ -4,11 +4,10 @@ Document embeddings query service. Input is vector, output is an array
of chunk_ids
"""
import asyncio
import logging
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
@ -38,32 +37,6 @@ class Processor(DocumentEmbeddingsQueryService):
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.last_collection = None
def ensure_collection_exists(self, collection, dim):
"""Ensure collection exists, create if it doesn't"""
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
try:
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
logger.info(f"Created collection: {collection}")
except Exception as e:
logger.error(f"Qdrant collection creation failed: {e}")
raise e
self.last_collection = collection
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
async def query_document_embeddings(self, workspace, msg):
@ -73,21 +46,24 @@ class Processor(DocumentEmbeddingsQueryService):
if not vec:
return []
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{workspace}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection
)
if not exists:
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
search_result = self.qdrant.query_points(
result = await asyncio.to_thread(
self.qdrant.query_points,
collection_name=collection,
query=vec,
limit=msg.limit,
with_payload=True,
).points
)
search_result = result.points
chunks = []
for r in search_result:


@ -4,11 +4,10 @@ Graph embeddings query service. Input is vector, output is list of
entities
"""
import asyncio
import logging
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
@ -38,32 +37,6 @@ class Processor(GraphEmbeddingsQueryService):
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.last_collection = None
def ensure_collection_exists(self, collection, dim):
"""Ensure collection exists, create if it doesn't"""
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
try:
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
logger.info(f"Created collection: {collection}")
except Exception as e:
logger.error(f"Qdrant collection creation failed: {e}")
raise e
self.last_collection = collection
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@ -79,23 +52,26 @@ class Processor(GraphEmbeddingsQueryService):
if not vec:
return []
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{workspace}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection
)
if not exists:
logger.info(f"Collection {collection} does not exist")
return []
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
search_result = self.qdrant.query_points(
result = await asyncio.to_thread(
self.qdrant.query_points,
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
)
search_result = result.points
entity_set = set()
entities = []


@ -6,6 +6,7 @@ Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""
import asyncio
import logging
import re
from typing import Optional
@ -70,7 +71,7 @@ class Processor(FlowProcessor):
safe_name = 'r_' + safe_name
return safe_name.lower()
def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
async def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given workspace/collection/schema"""
prefix = (
f"rows_{self.sanitize_name(workspace)}_"
@ -78,14 +79,15 @@ class Processor(FlowProcessor):
)
try:
all_collections = self.qdrant.get_collections().collections
all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if matching:
# Return first match (there should typically be only one per dimension)
return matching[0]
except Exception as e:
@ -100,8 +102,7 @@ class Processor(FlowProcessor):
if not vec:
return []
# Find the collection for this workspace/collection/schema
qdrant_collection = self.find_collection(
qdrant_collection = await self.find_collection(
workspace, request.collection, request.schema_name
)
@ -113,7 +114,6 @@ class Processor(FlowProcessor):
return []
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
@ -125,16 +125,16 @@ class Processor(FlowProcessor):
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
result = await asyncio.to_thread(
self.qdrant.query_points,
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
)
search_result = result.points
# Convert to RowIndexMatch objects
matches = []
for point in search_result:
payload = point.payload or {}


@ -11,6 +11,7 @@ Queries against the unified 'rows' table with schema:
- source: text
"""
import asyncio
import json
import logging
import re
@ -97,34 +98,38 @@ class Processor(FlowProcessor):
# Cassandra session
self.cluster = None
self.session = None
self._setup_lock = asyncio.Lock()
# Known keyspaces
self.known_keyspaces: Set[str] = set()
def connect_cassandra(self):
async def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
async with self._setup_lock:
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
session = await asyncio.to_thread(cluster.connect)
self.cluster = cluster
self.session = session
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
@ -140,14 +145,17 @@ class Processor(FlowProcessor):
f"for workspace {workspace}"
)
# Replace existing schemas for this workspace
async with self._setup_lock:
await self._apply_schema_config(workspace, config)
async def _apply_schema_config(self, workspace, config):
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
builder = GraphQLSchemaBuilder()
self.schema_builders[workspace] = builder
# Check if our config type exists
if self.config_key not in config:
logger.warning(
f"No '{self.config_key}' type in configuration "
@ -156,16 +164,12 @@ class Processor(FlowProcessor):
self.graphql_schemas[workspace] = None
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
@ -180,7 +184,6 @@ class Processor(FlowProcessor):
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
@ -202,7 +205,6 @@ class Processor(FlowProcessor):
f"{len(ws_schemas)} schemas"
)
# Regenerate GraphQL schema for this workspace
self.graphql_schemas[workspace] = builder.build(self.query_cassandra)
def get_index_names(self, schema: RowSchema) -> List[str]:
@ -254,7 +256,7 @@ class Processor(FlowProcessor):
For other queries, we need to scan and post-filter.
"""
# Connect if needed
self.connect_cassandra()
await self.connect_cassandra()
safe_keyspace = self.sanitize_name(workspace)


@ -30,14 +30,13 @@ class EvaluationError(Exception):
pass
async def evaluate(node, triples_client, workspace, collection, limit=10000):
async def evaluate(node, triples_client, collection, limit=10000):
"""
Evaluate a SPARQL algebra node.
Args:
node: rdflib CompValue algebra node
triples_client: TriplesClient instance for triple pattern queries
workspace: workspace/keyspace identifier
collection: collection identifier
limit: safety limit on results
@ -55,24 +54,24 @@ async def evaluate(node, triples_client, workspace, collection, limit=10000):
logger.warning(f"Unsupported algebra node: {name}")
return [{}]
return await handler(node, triples_client, workspace, collection, limit)
return await handler(node, triples_client, collection, limit)
# --- Node handlers ---
async def _eval_select_query(node, tc, workspace, collection, limit):
async def _eval_select_query(node, tc, collection, limit):
"""Evaluate a SelectQuery node."""
return await evaluate(node.p, tc, workspace, collection, limit)
return await evaluate(node.p, tc, collection, limit)
async def _eval_project(node, tc, workspace, collection, limit):
async def _eval_project(node, tc, collection, limit):
"""Evaluate a Project node (SELECT variable projection)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
variables = [str(v) for v in node.PV]
return project(solutions, variables)
async def _eval_bgp(node, tc, workspace, collection, limit):
async def _eval_bgp(node, tc, collection, limit):
"""
Evaluate a Basic Graph Pattern.
@ -107,7 +106,7 @@ async def _eval_bgp(node, tc, workspace, collection, limit):
# Query the triples store
results = await _query_pattern(
tc, s_val, p_val, o_val, workspace, collection, limit
tc, s_val, p_val, o_val, collection, limit
)
# Map results back to variable bindings,
@ -130,17 +129,17 @@ async def _eval_bgp(node, tc, workspace, collection, limit):
return solutions[:limit]
async def _eval_join(node, tc, workspace, collection, limit):
async def _eval_join(node, tc, collection, limit):
"""Evaluate a Join node."""
left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit)
left = await evaluate(node.p1, tc, collection, limit)
right = await evaluate(node.p2, tc, collection, limit)
return hash_join(left, right)[:limit]
async def _eval_left_join(node, tc, workspace, collection, limit):
async def _eval_left_join(node, tc, collection, limit):
"""Evaluate a LeftJoin node (OPTIONAL)."""
left_sols = await evaluate(node.p1, tc, workspace, collection, limit)
right_sols = await evaluate(node.p2, tc, workspace, collection, limit)
left_sols = await evaluate(node.p1, tc, collection, limit)
right_sols = await evaluate(node.p2, tc, collection, limit)
filter_fn = None
if hasattr(node, "expr") and node.expr is not None:
@ -153,16 +152,16 @@ async def _eval_left_join(node, tc, workspace, collection, limit):
return left_join(left_sols, right_sols, filter_fn)[:limit]
async def _eval_union(node, tc, workspace, collection, limit):
async def _eval_union(node, tc, collection, limit):
"""Evaluate a Union node."""
left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit)
left = await evaluate(node.p1, tc, collection, limit)
right = await evaluate(node.p2, tc, collection, limit)
return union(left, right)[:limit]
async def _eval_filter(node, tc, workspace, collection, limit):
async def _eval_filter(node, tc, collection, limit):
"""Evaluate a Filter node."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
expr = node.expr
return [
sol for sol in solutions
@ -170,22 +169,22 @@ async def _eval_filter(node, tc, workspace, collection, limit):
]
async def _eval_distinct(node, tc, workspace, collection, limit):
async def _eval_distinct(node, tc, collection, limit):
"""Evaluate a Distinct node."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
return distinct(solutions)
async def _eval_reduced(node, tc, workspace, collection, limit):
async def _eval_reduced(node, tc, collection, limit):
"""Evaluate a Reduced node (like Distinct but implementation-defined)."""
# Treat same as Distinct
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
return distinct(solutions)
async def _eval_order_by(node, tc, workspace, collection, limit):
async def _eval_order_by(node, tc, collection, limit):
"""Evaluate an OrderBy node."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
key_fns = []
for cond in node.expr:
@ -206,7 +205,7 @@ async def _eval_order_by(node, tc, workspace, collection, limit):
return order_by(solutions, key_fns)
async def _eval_slice(node, tc, workspace, collection, limit):
async def _eval_slice(node, tc, collection, limit):
"""Evaluate a Slice node (LIMIT/OFFSET)."""
# Pass tighter limit downstream if possible
inner_limit = limit
@ -214,13 +213,13 @@ async def _eval_slice(node, tc, workspace, collection, limit):
offset = node.start or 0
inner_limit = min(limit, offset + node.length)
solutions = await evaluate(node.p, tc, workspace, collection, inner_limit)
solutions = await evaluate(node.p, tc, collection, inner_limit)
return slice_solutions(solutions, node.start or 0, node.length)
async def _eval_extend(node, tc, workspace, collection, limit):
async def _eval_extend(node, tc, collection, limit):
"""Evaluate an Extend node (BIND)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
var_name = str(node.var)
expr = node.expr
@ -246,9 +245,9 @@ async def _eval_extend(node, tc, workspace, collection, limit):
return result
async def _eval_group(node, tc, workspace, collection, limit):
async def _eval_group(node, tc, collection, limit):
"""Evaluate a Group node (GROUP BY with aggregation)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
# Extract grouping expressions
group_exprs = []
@ -289,9 +288,9 @@ async def _eval_group(node, tc, workspace, collection, limit):
return result
async def _eval_aggregate_join(node, tc, workspace, collection, limit):
async def _eval_aggregate_join(node, tc, collection, limit):
"""Evaluate an AggregateJoin (aggregation functions after GROUP BY)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
result = []
for sol in solutions:
@ -310,7 +309,7 @@ async def _eval_aggregate_join(node, tc, workspace, collection, limit):
return result
async def _eval_graph(node, tc, workspace, collection, limit):
async def _eval_graph(node, tc, collection, limit):
"""Evaluate a Graph node (GRAPH clause)."""
term = node.term
@ -319,16 +318,16 @@ async def _eval_graph(node, tc, workspace, collection, limit):
# We'd need to pass graph to triples queries
# For now, evaluate inner pattern normally
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
return await evaluate(node.p, tc, workspace, collection, limit)
return await evaluate(node.p, tc, collection, limit)
elif isinstance(term, Variable):
# GRAPH ?g { ... } — variable graph
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
return await evaluate(node.p, tc, workspace, collection, limit)
return await evaluate(node.p, tc, collection, limit)
else:
return await evaluate(node.p, tc, workspace, collection, limit)
return await evaluate(node.p, tc, collection, limit)
async def _eval_values(node, tc, workspace, collection, limit):
async def _eval_values(node, tc, collection, limit):
"""Evaluate a VALUES clause (inline data)."""
variables = [str(v) for v in node.var]
solutions = []
@ -343,9 +342,9 @@ async def _eval_values(node, tc, workspace, collection, limit):
return solutions
async def _eval_to_multiset(node, tc, workspace, collection, limit):
async def _eval_to_multiset(node, tc, collection, limit):
"""Evaluate a ToMultiSet node (subquery)."""
return await evaluate(node.p, tc, workspace, collection, limit)
return await evaluate(node.p, tc, collection, limit)
# --- Aggregate computation ---
@ -487,7 +486,7 @@ def _resolve_term(tmpl, solution):
return rdflib_term_to_term(tmpl)
async def _query_pattern(tc, s, p, o, workspace, collection, limit):
async def _query_pattern(tc, s, p, o, collection, limit):
"""
Issue a streaming triple pattern query via TriplesClient.
@ -496,7 +495,6 @@ async def _query_pattern(tc, s, p, o, workspace, collection, limit):
results = await tc.query(
s=s, p=p, o=o,
limit=limit,
workspace=workspace,
collection=collection,
)
return results


@ -141,7 +141,6 @@ class Processor(FlowProcessor):
solutions = await evaluate(
parsed.algebra,
triples_client,
workspace=flow.workspace,
collection=request.collection or "default",
limit=request.limit or 10000,
)


@ -6,8 +6,8 @@ null. Output is a list of quads.
import asyncio
import logging
import json
from cassandra.query import SimpleStatement
from .... direct.cassandra_kg import (
@ -176,45 +176,42 @@ class Processor(TriplesQueryService):
self.cassandra_host = hosts
self.cassandra_username = username
self.cassandra_password = password
self.table = None
def ensure_connection(self, workspace):
"""Ensure we have a connection to the correct keyspace."""
if workspace != self.table:
KGClass = EntityCentricKnowledgeGraph
self._connections = {}
self._conn_lock = asyncio.Lock()
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
)
self.table = workspace
async def _get_connection(self, workspace):
async with self._conn_lock:
if workspace not in self._connections:
if self.cassandra_username and self.cassandra_password:
tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
)
self._connections[workspace] = tg
return self._connections[workspace]
async def query_triples(self, workspace, query):
try:
# ensure_connection may construct a fresh
# EntityCentricKnowledgeGraph which does sync schema
# setup against Cassandra. Push it to a worker thread
# so the event loop doesn't block on first-use per workspace.
await asyncio.to_thread(self.ensure_connection, workspace)
# Extract values from query
s_val = get_term_value(query.s)
p_val = get_term_value(query.p)
o_val = get_term_value(query.o)
g_val = query.g # Already a string or None
g_val = query.g
tg = await self._get_connection(workspace)
def get_object_metadata(row):
"""Extract term type metadata from result row"""
return (
getattr(row, 'otype', None),
getattr(row, 'dtype', None),
@ -223,33 +220,21 @@ class Processor(TriplesQueryService):
quads = []
# All self.tg.get_* calls below are sync wrappers around
# cassandra session.execute. Materialise inside a worker
# thread so iteration never triggers sync paging back on
# the event loop.
# Route to appropriate query method based on which fields are specified
if s_val is not None:
if p_val is not None:
if o_val is not None:
# SPO specified - find matching graphs
resp = await asyncio.to_thread(
lambda: list(self.tg.get_spo(
query.collection, s_val, p_val, o_val,
g=g_val, limit=query.limit,
))
resp = await tg.async_get_spo(
query.collection, s_val, p_val, o_val,
g=g_val, limit=query.limit,
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
else:
# SP specified
resp = await asyncio.to_thread(
lambda: list(self.tg.get_sp(
query.collection, s_val, p_val,
g=g_val, limit=query.limit,
))
resp = await tg.async_get_sp(
query.collection, s_val, p_val,
g=g_val, limit=query.limit,
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -257,24 +242,18 @@ class Processor(TriplesQueryService):
quads.append((s_val, p_val, t.o, g, term_type, datatype, language))
else:
if o_val is not None:
# SO specified
resp = await asyncio.to_thread(
lambda: list(self.tg.get_os(
query.collection, o_val, s_val,
g=g_val, limit=query.limit,
))
resp = await tg.async_get_os(
query.collection, o_val, s_val,
g=g_val, limit=query.limit,
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
else:
# S only
resp = await asyncio.to_thread(
lambda: list(self.tg.get_s(
query.collection, s_val,
g=g_val, limit=query.limit,
))
resp = await tg.async_get_s(
query.collection, s_val,
g=g_val, limit=query.limit,
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -283,24 +262,18 @@ class Processor(TriplesQueryService):
else:
if p_val is not None:
if o_val is not None:
# PO specified
resp = await asyncio.to_thread(
lambda: list(self.tg.get_po(
query.collection, p_val, o_val,
g=g_val, limit=query.limit,
))
resp = await tg.async_get_po(
query.collection, p_val, o_val,
g=g_val, limit=query.limit,
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
else:
# P only
resp = await asyncio.to_thread(
lambda: list(self.tg.get_p(
query.collection, p_val,
g=g_val, limit=query.limit,
))
resp = await tg.async_get_p(
query.collection, p_val,
g=g_val, limit=query.limit,
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -308,40 +281,26 @@ class Processor(TriplesQueryService):
quads.append((t.s, p_val, t.o, g, term_type, datatype, language))
else:
if o_val is not None:
# O only
resp = await asyncio.to_thread(
lambda: list(self.tg.get_o(
query.collection, o_val,
g=g_val, limit=query.limit,
))
resp = await tg.async_get_o(
query.collection, o_val,
g=g_val, limit=query.limit,
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
else:
# Nothing specified - get all
resp = await asyncio.to_thread(
lambda: list(self.tg.get_all(
query.collection, limit=query.limit,
))
resp = await tg.async_get_all(
query.collection, limit=query.limit,
)
for t in resp:
# Note: quads_by_collection uses 'd' for graph field
g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
# Filter by graph
# g_val=None means all graphs (no filter)
# g_val="" means default graph only
# otherwise filter to specific named graph
if g_val is not None:
if g != g_val:
continue
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, t.p, t.o, g, term_type, datatype, language))
# Convert to Triple objects (with g field)
# s and p are always IRIs in RDF
# Object uses term_type/datatype/language metadata from database
triples = [
Triple(
s=create_term(q[0], term_type='u'),
@ -365,51 +324,41 @@ class Processor(TriplesQueryService):
Uses Cassandra's paging to fetch results incrementally.
"""
try:
await asyncio.to_thread(self.ensure_connection, workspace)
batch_size = query.batch_size if query.batch_size > 0 else 20
limit = query.limit if query.limit > 0 else 10000
# Extract query pattern
s_val = get_term_value(query.s)
p_val = get_term_value(query.p)
o_val = get_term_value(query.o)
g_val = query.g
def get_object_metadata(row):
"""Extract term type metadata from result row"""
return (
getattr(row, 'otype', None),
getattr(row, 'dtype', None),
getattr(row, 'lang', None),
)
# For streaming, we need to execute with fetch_size
# Use the collection table for get_all queries (most common streaming case)
# Determine which query to use based on pattern
if s_val is None and p_val is None and o_val is None:
# Get all - use collection table with paging
cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"
tg = await self._get_connection(workspace)
cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {tg.collection_table} WHERE collection = %s"
params = [query.collection]
statement = SimpleStatement(cql, fetch_size=batch_size)
# async_execute only materialises the first page;
# this query needs all pages, so use sync execute
# in a worker thread where page iteration can block.
result_set = await asyncio.to_thread(
lambda: list(tg.session.execute(statement, params))
)
else:
# For specific patterns, fall back to non-streaming
# (these typically return small result sets anyway)
async for batch, is_final in self._fallback_stream(workspace, query, batch_size):
yield batch, is_final
return
# Materialise in a worker thread. We lose true streaming
# paging (the driver fetches all pages eagerly inside the
# thread) but the event loop stays responsive, and result
# sets at this layer are typically small enough that this
# is acceptable. If true async paging is needed later,
# revisit using ResponseFuture page callbacks.
statement = SimpleStatement(cql, fetch_size=batch_size)
result_set = await asyncio.to_thread(
lambda: list(self.tg.session.execute(statement, params))
)
batch = []
count = 0


@ -3,11 +3,13 @@
Accepts entity/vector pairs and writes them to a Qdrant store.
"""
import asyncio
import uuid
import logging
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
import logging
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
@ -35,13 +37,35 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def ensure_collection(self, collection_name, dim):
async with self._cache_lock:
if collection_name in self._known_collections:
return
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection_name
)
if not exists:
logger.info(
f"Lazily creating Qdrant collection {collection_name} "
f"with dimension {dim}"
)
await asyncio.to_thread(
self.qdrant.create_collection,
collection_name=collection_name,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
self._known_collections.add(collection_name)
async def store_document_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for workspace {workspace} "
@ -60,24 +84,15 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if not vec:
continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{workspace}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
await self.ensure_collection(collection, dim)
self.qdrant.upsert(
await asyncio.to_thread(
self.qdrant.upsert,
collection_name=collection,
points=[
PointStruct(
@ -87,7 +102,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
"chunk_id": chunk_id,
}
)
]
],
)
@staticmethod
@ -124,8 +139,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
try:
prefix = f"d_{workspace}_{collection}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
@ -135,7 +151,11 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
logger.info(f"No collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
await asyncio.to_thread(
self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")


@ -3,11 +3,13 @@
Accepts entity/vector pairs and writes them to a Qdrant store.
"""
import asyncio
import uuid
import logging
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
import logging
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
@ -50,13 +52,35 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def ensure_collection(self, collection_name, dim):
async with self._cache_lock:
if collection_name in self._known_collections:
return
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection_name
)
if not exists:
logger.info(
f"Lazily creating Qdrant collection {collection_name} "
f"with dimension {dim}"
)
await asyncio.to_thread(
self.qdrant.create_collection,
collection_name=collection_name,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
self._known_collections.add(collection_name)
async def store_graph_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for workspace {workspace} "
@ -75,22 +99,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if not vec:
continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{workspace}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
await self.ensure_collection(collection, dim)
payload = {
"entity": entity_value,
@ -98,7 +112,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
await asyncio.to_thread(
self.qdrant.upsert,
collection_name=collection,
points=[
PointStruct(
@ -106,7 +121,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
vector=vec,
payload=payload,
)
]
],
)
@staticmethod
@ -143,8 +158,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
try:
prefix = f"t_{workspace}_{collection}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
@ -154,7 +170,11 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
logger.info(f"No collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
await asyncio.to_thread(
self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")


@ -16,10 +16,10 @@ Payload structure:
- text: The text that was embedded (for debugging/display)
"""
import asyncio
import logging
import re
import uuid
from typing import Set, Tuple
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Distance, VectorParams
@ -63,11 +63,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Register config handler for collection management
self.register_config_handler(self.on_collection_config, types=["collection"])
# Cache of created Qdrant collections
self.created_collections: Set[str] = set()
# Qdrant client
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
@ -85,25 +83,28 @@ class Processor(CollectionConfigHandler, FlowProcessor):
safe_schema = self.sanitize_name(schema_name)
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
def ensure_collection(self, collection_name: str, dimension: int):
async def ensure_collection(self, collection_name: str, dimension: int):
"""Create Qdrant collection if it doesn't exist"""
if collection_name in self.created_collections:
return
if not self.qdrant.collection_exists(collection_name):
logger.info(
f"Creating Qdrant collection {collection_name} "
f"with dimension {dimension}"
async with self._cache_lock:
if collection_name in self._known_collections:
return
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection_name
)
self.qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=dimension,
distance=Distance.COSINE
if not exists:
logger.info(
f"Creating Qdrant collection {collection_name} "
f"with dimension {dimension}"
)
)
self.created_collections.add(collection_name)
await asyncio.to_thread(
self.qdrant.create_collection,
collection_name=collection_name,
vectors_config=VectorParams(
size=dimension,
distance=Distance.COSINE
),
)
self._known_collections.add(collection_name)
async def on_embeddings(self, msg, consumer, flow):
"""Process incoming RowEmbeddings and write to Qdrant"""
@ -143,15 +144,14 @@ class Processor(CollectionConfigHandler, FlowProcessor):
dimension = len(vector)
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
workspace, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
await self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
await asyncio.to_thread(
self.qdrant.upsert,
collection_name=qdrant_collection,
points=[
PointStruct(
@@ -163,7 +163,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"text": row_emb.text
}
)
]
],
)
embeddings_written += 1
@@ -181,8 +181,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
try:
prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
@@ -192,8 +193,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"No Qdrant collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
self.created_collections.discard(collection_name)
await asyncio.to_thread(
self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(
f"Deleted {len(matching_collections)} collection(s) "
@@ -217,8 +221,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
@@ -228,8 +233,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"No Qdrant collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
self.created_collections.discard(collection_name)
await asyncio.to_thread(
self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
except Exception as e:

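Aside: ensure_collection above holds the lock across the cache check, the existence probe, and the create, so two coroutines racing on the same new collection cannot both issue create_collection. A stripped-down sketch of that ordering (CollectionEnsurer is a hypothetical stand-in, not code from this PR):

    import asyncio
    from qdrant_client import QdrantClient
    from qdrant_client.models import Distance, VectorParams

    # Hypothetical sketch: check-then-create happens entirely under one lock,
    # with the blocking Qdrant calls still pushed to worker threads.
    class CollectionEnsurer:
        def __init__(self, client: QdrantClient):
            self.qdrant = client
            self._cache_lock = asyncio.Lock()
            self._known: set[str] = set()

        async def ensure(self, name: str, dimension: int) -> None:
            async with self._cache_lock:
                if name in self._known:
                    return
                exists = await asyncio.to_thread(
                    self.qdrant.collection_exists, name
                )
                if not exists:
                    await asyncio.to_thread(
                        self.qdrant.create_collection,
                        collection_name=name,
                        vectors_config=VectorParams(
                            size=dimension, distance=Distance.COSINE
                        ),
                    )
                self._known.add(name)

Holding the lock across the network probe serialises creation, but that path only runs once per new collection; steady-state writes hit the cached-name early return.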

@@ -82,7 +82,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Cache of known keyspaces and whether tables exist
self.known_keyspaces: Set[str] = set()
self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables
self.tables_initialized: Set[str] = set()
# Cache of registered (collection, schema_name) pairs
self.registered_partitions: Set[Tuple[str, str]] = set()
@@ -94,6 +94,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
self.cluster = None
self.session = None
# Protects connection setup and cache mutations
self._setup_lock = asyncio.Lock()
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
@@ -126,6 +129,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"for workspace {workspace}"
)
async with self._setup_lock:
return await self._apply_schema_config(workspace, config, version)
async def _apply_schema_config(self, workspace, config, version):
# Track which schemas changed in this workspace
old_schemas = self.schemas.get(workspace, {})
old_schema_names = set(old_schemas.keys())
@@ -391,16 +399,12 @@ class Processor(CollectionConfigHandler, FlowProcessor):
schema_name = obj.schema_name
source = getattr(obj.metadata, 'source', '') or ''
# Ensure tables exist (sync DDL — push to a worker thread
# so the event loop stays responsive when running in a
# processor group sharing the loop with siblings).
await asyncio.to_thread(self.ensure_tables, keyspace)
# Register partitions if first time seeing this (collection, schema_name)
await asyncio.to_thread(
self.register_partitions,
keyspace, collection, schema_name, workspace,
)
async with self._setup_lock:
await asyncio.to_thread(self.ensure_tables, keyspace)
await asyncio.to_thread(
self.register_partitions,
keyspace, collection, schema_name, workspace,
)
safe_keyspace = self.sanitize_name(keyspace)
@@ -461,35 +465,27 @@ class Processor(CollectionConfigHandler, FlowProcessor):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store"""
# Connect if not already connected (sync, push to thread)
await asyncio.to_thread(self.connect_cassandra)
# Ensure tables exist (sync DDL, push to thread)
await asyncio.to_thread(self.ensure_tables, workspace)
async with self._setup_lock:
await asyncio.to_thread(self.connect_cassandra)
await asyncio.to_thread(self.ensure_tables, workspace)
logger.info(f"Collection {collection} ready for workspace {workspace}")
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection using partition tracking"""
# Connect if not already connected
await asyncio.to_thread(self.connect_cassandra)
async with self._setup_lock:
await asyncio.to_thread(self.connect_cassandra)
if workspace not in self.known_keyspaces:
safe_ks = self.sanitize_name(workspace)
check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s"
result = await async_execute(self.session, check_cql, (safe_ks,))
if not result:
logger.info(f"Keyspace {safe_ks} does not exist, nothing to delete")
return
self.known_keyspaces.add(workspace)
safe_keyspace = self.sanitize_name(workspace)
# Check if keyspace exists
if workspace not in self.known_keyspaces:
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = await async_execute(
self.session, check_keyspace_cql, (safe_keyspace,)
)
if not result:
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(workspace)
# Discover all partitions for this collection
select_partitions_cql = f"""
SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
@@ -540,11 +536,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
raise
# Clear from local cache
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if col != collection
}
async with self._setup_lock:
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if col != collection
}
logger.info(
f"Deleted collection {collection}: {partitions_deleted} partitions "
@@ -553,8 +549,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination"""
# Connect if not already connected
await asyncio.to_thread(self.connect_cassandra)
async with self._setup_lock:
await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(workspace)
@@ -614,8 +610,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
)
raise
# Clear from local cache
self.registered_partitions.discard((collection, schema_name))
async with self._setup_lock:
self.registered_partitions.discard((collection, schema_name))
logger.info(
f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "

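Aside: the rows store changes serialise all connection and DDL work behind a single _setup_lock while still running the blocking cassandra-driver calls in worker threads. A minimal sketch of that shape, with placeholder blocking methods (not code from this PR):

    import asyncio

    # Hypothetical sketch: one asyncio.Lock serialises the setup path so
    # concurrent coroutines cannot race on connection state or table DDL,
    # while the blocking driver calls run via asyncio.to_thread.
    class RowStoreSketch:
        def __init__(self):
            self.session = None
            self._setup_lock = asyncio.Lock()

        def _connect(self):
            # placeholder for the blocking Cluster(...).connect() call
            ...

        def _ensure_tables(self, keyspace: str):
            # placeholder for blocking CREATE KEYSPACE / CREATE TABLE DDL
            ...

        async def prepare(self, keyspace: str) -> None:
            async with self._setup_lock:
                await asyncio.to_thread(self._connect)
                await asyncio.to_thread(self._ensure_tables, keyspace)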

@@ -4,12 +4,7 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph.
"""
import asyncio
import base64
import os
import argparse
import time
import logging
import json
from .... direct.cassandra_kg import (
EntityCentricKnowledgeGraph, DEFAULT_GRAPH
@@ -28,6 +23,8 @@ default_ident = "triples-write"
def serialize_triple(triple):
"""Serialize a Triple object to JSON for storage."""
import json
if triple is None:
return None
@@ -141,156 +138,84 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
self.cassandra_host = hosts
self.cassandra_username = username
self.cassandra_password = password
self.table = None
self.tg = None
self._connections = {}
self._conn_lock = asyncio.Lock()
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def _get_connection(self, workspace):
async with self._conn_lock:
if workspace not in self._connections:
if self.cassandra_username and self.cassandra_password:
tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
)
self._connections[workspace] = tg
return self._connections[workspace]
async def store_triples(self, workspace, message):
# The cassandra-driver work below — connection, schema
# setup, and per-triple inserts — is all synchronous.
# Wrap the whole batch in a worker thread so the event
# loop stays responsive for sibling processors when
# running in a processor group.
tg = await self._get_connection(workspace)
def _do_store():
for t in message.triples:
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
g_val = t.g if t.g is not None else DEFAULT_GRAPH
if self.table is None or self.table != workspace:
otype = get_term_otype(t.o)
dtype = get_term_dtype(t.o)
lang = get_term_lang(t.o)
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Exception: {e}", exc_info=True)
time.sleep(1)
raise e
self.table = workspace
for t in message.triples:
# Extract values from Term objects
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
# t.g is None for default graph, or a graph IRI
g_val = t.g if t.g is not None else DEFAULT_GRAPH
# Extract object type metadata for entity-centric storage
otype = get_term_otype(t.o)
dtype = get_term_dtype(t.o)
lang = get_term_lang(t.o)
self.tg.insert(
message.metadata.collection,
s_val,
p_val,
o_val,
g=g_val,
otype=otype,
dtype=dtype,
lang=lang,
)
await asyncio.to_thread(_do_store)
await tg.async_insert(
message.metadata.collection,
s_val,
p_val,
o_val,
g=g_val,
otype=otype,
dtype=dtype,
lang=lang,
)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create a collection in Cassandra triple store via config push"""
try:
tg = await self._get_connection(workspace)
def _do_create():
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
self.table = workspace
# Create collection using the built-in method
logger.info(f"Creating collection {collection} for workspace {workspace}")
if self.tg.collection_exists(collection):
exists = await tg.async_collection_exists(collection)
if exists:
logger.info(f"Collection {collection} already exists")
else:
self.tg.create_collection(collection)
await tg.async_create_collection(collection)
logger.info(f"Created collection {collection}")
try:
await asyncio.to_thread(_do_create)
except Exception as e:
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection from the unified triples table"""
try:
tg = await self._get_connection(workspace)
def _do_delete():
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
self.table = workspace
# Delete all triples for this collection using the built-in method
self.tg.delete_collection(collection)
await tg.async_delete_collection(collection)
logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
try:
await asyncio.to_thread(_do_delete)
except Exception as e:
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise

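Aside: _get_connection above replaces the old self.table/self.tg fields with a lock-guarded per-workspace dict. The reusable shape, roughly (ConnectionCache and factory are hypothetical illustrations, not the real EntityCentricKnowledgeGraph wiring):

    import asyncio

    # Hypothetical sketch: one asyncio.Lock guards a dict of per-workspace
    # connections; the blocking constructor runs in a worker thread, and
    # later callers reuse the cached object without reconnecting.
    class ConnectionCache:
        def __init__(self, factory):
            self._factory = factory  # blocking callable: workspace -> connection
            self._connections: dict[str, object] = {}
            self._lock = asyncio.Lock()

        async def get(self, workspace: str):
            async with self._lock:
                if workspace not in self._connections:
                    self._connections[workspace] = await asyncio.to_thread(
                        self._factory, workspace
                    )
                return self._connections[workspace]

Unlike threading.local, this works when many coroutines share one event-loop thread, and the lock prevents two coroutines from constructing the same workspace connection twice.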

@@ -1,6 +1,7 @@
from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings
from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
from .. schema import DocumentEmbeddings, ChunkEmbeddings
from cassandra.cluster import Cluster
@@ -217,6 +218,16 @@ class KnowledgeTableStore:
WHERE workspace = ? AND document_id = ?
""")
self.delete_document_embeddings_stmt = self.cassandra.prepare("""
DELETE FROM document_embeddings
WHERE workspace = ? AND document_id = ?
""")
self.list_de_cores_stmt = self.cassandra.prepare("""
SELECT DISTINCT workspace, document_id FROM document_embeddings
WHERE workspace = ?
""")
async def add_triples(self, workspace, m):
when = int(time.time() * 1000)
@@ -338,6 +349,50 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
try:
await async_execute(
self.cassandra,
self.delete_document_embeddings_stmt,
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
async def delete_document_embeddings(self, workspace, document_id):
logger.debug("Delete document embeddings...")
try:
await async_execute(
self.cassandra,
self.delete_document_embeddings_stmt,
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
async def list_de_cores(self, workspace):
logger.debug("List DE cores...")
try:
rows = await async_execute(
self.cassandra,
self.list_de_cores_stmt,
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
lst = [row[1] for row in rows]
logger.debug("Done")
return lst
async def get_triples(self, workspace, document_id, receiver):
logger.debug("Get triples...")
@@ -417,3 +472,42 @@ class KnowledgeTableStore:
logger.debug("Done")
async def get_document_embeddings(self, workspace, document_id, receiver):
logger.debug("Get DE...")
try:
rows = await async_execute(
self.cassandra,
self.get_document_embeddings_stmt,
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
for row in rows:
if row[3]:
chunks = [
ChunkEmbeddings(
chunk_id=ch[0],
vector=ch[1],
)
for ch in row[3]
]
else:
chunks = []
await receiver(
DocumentEmbeddings(
metadata = Metadata(
id = document_id,
collection = "default",
),
chunks = chunks
)
)
logger.debug("Done")
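Aside: get_document_embeddings streams results through a receiver callback rather than returning a list. A hypothetical caller that simply collects the streamed DocumentEmbeddings might look like this (store is assumed to be an already-connected KnowledgeTableStore):

    # Hypothetical usage sketch, not code from this PR.
    async def fetch_document_embeddings(store, workspace: str, document_id: str):
        results = []

        async def receiver(de):
            # get_document_embeddings awaits this callback once per row
            results.append(de)

        await store.get_document_embeddings(workspace, document_id, receiver)
        return results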