trustgraph/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py
cybermaggedon a2dde9cafb
Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916)
Cassandra triples services were using synchronous EntityCentricKnowledgeGraph
methods from async contexts, and connection state was managed with
threading.local which is wrong for asyncio coroutines sharing a single
thread. Qdrant services had no async wrapping at all, blocking the event
loop on every network call. Rows services had unprotected shared state
mutations across concurrent coroutines.

- Add async methods to EntityCentricKnowledgeGraph (async_insert,
  async_get_s/p/o/sp/po/os/spo/all, async_collection_exists,
  async_create_collection, async_delete_collection) using the existing
  cassandra_async.async_execute bridge
- Rewrite triples write + query services: replace threading.local with
  asyncio.Lock + dict cache for per-workspace connections, use async
  ECKG methods for all data operations, keep asyncio.to_thread only for
  one-time blocking ECKG construction
- Wrap all Qdrant calls in asyncio.to_thread across all 6 services
  (doc/graph/row embeddings write + query), add asyncio.Lock + set cache
  for collection existence checks
- Add asyncio.Lock to rows write + query services to protect shared
  state (schemas, sessions, config caches) from concurrent mutation
- Update all affected tests to match new async patterns
2026-05-14 16:00:54 +01:00

272 lines
9.5 KiB
Python

"""
Row embeddings writer for Qdrant (Stage 2).
Consumes RowEmbeddings messages (which already contain computed vectors)
and writes them to Qdrant. One Qdrant collection per (workspace, collection, schema_name, dimension) combination.
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (row-embeddings): Compute embeddings
Stage 2 (this processor): Store embeddings
Collection naming: rows_{workspace}_{collection}_{schema_name}_{dimension}
Payload structure:
- index_name: The indexed field(s) this embedding represents
- index_value: The original list of values (for Cassandra lookup)
- text: The text that was embedded (for debugging/display)
"""
import asyncio
import logging
import re
import uuid
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Distance, VectorParams
from .... schema import RowEmbeddings
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
# Defaults applied when the caller supplies no processor id / Qdrant URL
default_ident = "row-embeddings-write"
default_store_uri = "http://localhost:6333"

# Logger named after this module
logger = logging.getLogger(__name__)
class Processor(CollectionConfigHandler, FlowProcessor):
    """Stage-2 row-embeddings writer: stores pre-computed vectors in Qdrant.

    Consumes ``RowEmbeddings`` messages on the ``input`` consumer and upserts
    one Qdrant point per embedding. Target collections are named
    ``rows_{workspace}_{collection}_{schema_name}_{dimension}`` and are
    created lazily on first write. Every blocking Qdrant client call is run
    through ``asyncio.to_thread``.
    """

    def __init__(self, **params):
        """Configure the processor, its input consumer and the Qdrant client.

        Recognised keys in ``params`` (all optional):
            id: Processor identifier (default ``default_ident``).
            store_uri: Qdrant URL (default ``default_store_uri``).
            api_key: Qdrant API key (default ``None``).

        The resolved values are merged back into ``params`` before they are
        forwarded to the base-class constructor.
        """
        # Named ``ident`` rather than ``id`` to avoid shadowing the builtin.
        ident = params.get("id", default_ident)
        store_uri = params.get("store_uri", default_store_uri)
        api_key = params.get("api_key", None)

        super().__init__(
            **params | {
                "id": ident,
                "store_uri": store_uri,
                "api_key": api_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="input",
                schema=RowEmbeddings,
                handler=self.on_embeddings
            )
        )

        # Register config handler for collection management
        self.register_config_handler(
            self.on_collection_config, types=["collection"]
        )

        self.qdrant = QdrantClient(url=store_uri, api_key=api_key)

        # Guards ``_known_collections``, the set of Qdrant collection names
        # this process has already seen existing (or has created).
        self._cache_lock = asyncio.Lock()
        self._known_collections: set[str] = set()

    def sanitize_name(self, name: str) -> str:
        """Sanitize names for Qdrant collection naming.

        Every character outside ``[a-zA-Z0-9_]`` becomes ``_``; a non-empty
        result that does not start with a letter is prefixed with ``r_``;
        the result is lower-cased. An empty input yields an empty string.
        """
        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        if safe_name and not safe_name[0].isalpha():
            safe_name = 'r_' + safe_name
        return safe_name.lower()

    def get_collection_name(
        self, workspace: str, collection: str, schema_name: str, dimension: int
    ) -> str:
        """Generate the Qdrant collection name.

        Returns ``rows_{workspace}_{collection}_{schema_name}_{dimension}``
        with the three name components passed through ``sanitize_name``.
        """
        safe_workspace = self.sanitize_name(workspace)
        safe_collection = self.sanitize_name(collection)
        safe_schema = self.sanitize_name(schema_name)
        return (
            f"rows_{safe_workspace}_{safe_collection}_"
            f"{safe_schema}_{dimension}"
        )

    async def ensure_collection(self, collection_name: str, dimension: int):
        """Create the Qdrant collection if it doesn't exist.

        The existence check and the creation both happen while holding
        ``_cache_lock``, so coroutines of this processor calling this
        concurrently are serialised. Names already in ``_known_collections``
        return without contacting Qdrant.

        Args:
            collection_name: Qdrant collection name.
            dimension: Vector size used if the collection must be created
                (cosine distance).
        """
        async with self._cache_lock:
            if collection_name in self._known_collections:
                return

            exists = await asyncio.to_thread(
                self.qdrant.collection_exists, collection_name
            )
            if not exists:
                logger.info(
                    f"Creating Qdrant collection {collection_name} "
                    f"with dimension {dimension}"
                )
                await asyncio.to_thread(
                    self.qdrant.create_collection,
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=dimension,
                        distance=Distance.COSINE
                    ),
                )

            self._known_collections.add(collection_name)

    async def on_embeddings(self, msg, consumer, flow):
        """Process an incoming RowEmbeddings message and write it to Qdrant.

        The message is dropped with a warning when ``collection_exists``
        reports the workspace/collection as unknown. Entries without a
        vector are skipped. Every remaining entry is upserted as one point
        with a fresh random UUID and a payload of ``index_name``,
        ``index_value`` and ``text``.

        Args:
            msg: Message whose ``value()`` is the RowEmbeddings object.
            consumer: Not used by this handler.
            flow: Flow context; only ``flow.workspace`` is read.
        """
        embeddings = msg.value()

        logger.info(
            f"Writing {len(embeddings.embeddings)} embeddings for schema "
            f"{embeddings.schema_name} from {embeddings.metadata.id}"
        )

        workspace = flow.workspace

        # Validate collection exists in config before processing
        if not self.collection_exists(
            workspace, embeddings.metadata.collection
        ):
            logger.warning(
                f"Collection {embeddings.metadata.collection} for workspace "
                f"{workspace} does not exist in config. "
                f"Dropping message."
            )
            return

        collection = embeddings.metadata.collection
        schema_name = embeddings.schema_name

        embeddings_written = 0

        # dimension -> Qdrant collection already ensured for this message.
        # The collection name embeds the vector dimension, so the target is
        # resolved per dimension rather than fixed by the first vector seen;
        # otherwise a vector of a different length would be sent to a
        # collection configured for another size.
        targets: dict[int, str] = {}

        for row_emb in embeddings.embeddings:
            vector = row_emb.vector
            if not vector:
                logger.warning(
                    f"No vector for index {row_emb.index_name} - skipping"
                )
                continue

            dimension = len(vector)
            qdrant_collection = targets.get(dimension)
            if qdrant_collection is None:
                qdrant_collection = self.get_collection_name(
                    workspace, collection, schema_name, dimension
                )
                await self.ensure_collection(qdrant_collection, dimension)
                targets[dimension] = qdrant_collection

            await asyncio.to_thread(
                self.qdrant.upsert,
                collection_name=qdrant_collection,
                points=[
                    PointStruct(
                        id=str(uuid.uuid4()),
                        vector=vector,
                        payload={
                            "index_name": row_emb.index_name,
                            "index_value": row_emb.index_value,
                            "text": row_emb.text
                        }
                    )
                ],
            )
            embeddings_written += 1

        logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")

    async def create_collection(self, workspace: str, collection: str, metadata: dict):
        """Collection creation via config push - collections created lazily on first write.

        Only logs the request; no Qdrant call is made here.
        """
        logger.info(
            f"Row embeddings collection create request for {workspace}/{collection} - "
            f"will be created lazily on first write"
        )

    async def _collections_with_prefix(self, prefix: str) -> list[str]:
        """Return the names of all Qdrant collections starting with ``prefix``."""
        response = await asyncio.to_thread(self.qdrant.get_collections)
        return [
            coll.name for coll in response.collections
            if coll.name.startswith(prefix)
        ]

    async def _drop_collections(self, collection_names) -> None:
        """Delete each named Qdrant collection and evict it from the cache."""
        for collection_name in collection_names:
            await asyncio.to_thread(
                self.qdrant.delete_collection, collection_name
            )
            async with self._cache_lock:
                self._known_collections.discard(collection_name)
            logger.info(f"Deleted Qdrant collection: {collection_name}")

    async def delete_collection(self, workspace: str, collection: str):
        """Delete all Qdrant collections for a given workspace/collection.

        Matches every Qdrant collection whose name starts with
        ``rows_{workspace}_{collection}_`` (sanitized components).

        Raises:
            Exception: Any failure is logged with traceback and re-raised.
        """
        try:
            # NOTE(review): '_' is both the separator and a legal character
            # in sanitized names, so this prefix also matches collections of
            # e.g. collection "<collection>_x". That cannot be disambiguated
            # without changing the naming scheme - confirm names cannot
            # collide in practice.
            prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"

            matching_collections = await self._collections_with_prefix(prefix)

            if not matching_collections:
                logger.info(f"No Qdrant collections found matching prefix {prefix}")
            else:
                await self._drop_collections(matching_collections)
                logger.info(
                    f"Deleted {len(matching_collections)} collection(s) "
                    f"for {workspace}/{collection}"
                )

        except Exception as e:
            logger.error(
                f"Failed to delete collection {workspace}/{collection}: {e}",
                exc_info=True
            )
            raise

    async def delete_collection_schema(
        self, workspace: str, collection: str, schema_name: str
    ):
        """Delete Qdrant collections for a specific workspace/collection/schema.

        Only collections named exactly
        ``rows_{workspace}_{collection}_{schema_name}_{dimension}`` are
        removed, one per stored vector dimension.

        Raises:
            Exception: Any failure is logged with traceback and re-raised.
        """
        try:
            prefix = (
                f"rows_{self.sanitize_name(workspace)}_"
                f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
            )

            candidates = await self._collections_with_prefix(prefix)

            # After the prefix only the numeric dimension may remain. A bare
            # prefix test would also hit other schemas whose sanitized name
            # merely starts with this one (e.g. "user" vs "user_profile").
            matching_collections = [
                name for name in candidates
                if name[len(prefix):].isdigit()
            ]

            if not matching_collections:
                logger.info(f"No Qdrant collections found matching prefix {prefix}")
            else:
                await self._drop_collections(matching_collections)

        except Exception as e:
            logger.error(
                f"Failed to delete collection {workspace}/{collection}/{schema_name}: {e}",
                exc_info=True
            )
            raise

    @staticmethod
    def add_args(parser):
        """Add command-line arguments.

        Registers the ``FlowProcessor`` arguments, then ``-t/--store-uri``
        and ``-k/--api-key`` for the Qdrant connection.
        """
        FlowProcessor.add_args(parser)

        parser.add_argument(
            '-t', '--store-uri',
            default=default_store_uri,
            help=f'Qdrant URI (default: {default_store_uri})'
        )

        parser.add_argument(
            '-k', '--api-key',
            default=None,
            help='Qdrant API key (default: None)'
        )
def run():
    """Entry point for the row-embeddings-write-qdrant command.

    Launches ``Processor`` with the default identifier and this module's
    docstring (presumably used as the command description - confirm against
    ``launch``).
    """
    Processor.launch(default_ident, __doc__)