feat: data store replication configuration and TLS upgrade

- Add centralised qdrant_config.py helper with env-var fallback for
  QDRANT_URL, QDRANT_API_KEY, QDRANT_REPLICATION_FACTOR, QDRANT_SHARD_NUMBER
- Update all 6 Qdrant processors to use the helper; writers pass
  replication_factor and shard_number to create_collection
- Fix hardcoded Cassandra replication_factor=1 in cassandra_kg.py,
  write.py, and sparql_cassandra.py to respect CASSANDRA_REPLICATION_FACTOR
- Upgrade Cassandra TLS from deprecated PROTOCOL_TLSv1_2 to
  ssl.create_default_context() across all connectors
This commit is contained in:
Cyber MacGeddon 2026-06-04 11:47:23 +01:00
parent acf182c265
commit 80cffd71dc
15 changed files with 182 additions and 129 deletions

View file

@ -259,6 +259,8 @@ class TestGraphEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
proc.replication_factor = 1
proc.shard_number = 1
msg = MagicMock()
msg.metadata.collection = "graphs"

View file

@ -0,0 +1,87 @@
import os
import argparse
from typing import Optional, Any, Tuple
def _env_int(name: str, fallback: int) -> int:
    """Return environment variable *name* parsed as an int, or *fallback*.

    An unset, empty, or whitespace-only variable is treated as "not
    configured" and yields *fallback*.

    Raises:
        ValueError: if the variable is set to a non-integer value; the
            message names the variable so the misconfiguration is easy
            to locate.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return fallback
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from None


def get_qdrant_defaults() -> dict:
    """Return the Qdrant settings derived from the environment.

    Keys and their sources:
        url: QDRANT_URL, else 'http://localhost:6333'.
        api_key: QDRANT_API_KEY, else None.
        replication_factor: QDRANT_REPLICATION_FACTOR as int, else 1.
        shard_number: QDRANT_SHARD_NUMBER as int, else 1.

    Raises:
        ValueError: if either numeric variable is set to a non-integer.
    """
    return {
        'url': os.getenv('QDRANT_URL', 'http://localhost:6333'),
        'api_key': os.getenv('QDRANT_API_KEY'),
        'replication_factor': _env_int('QDRANT_REPLICATION_FACTOR', 1),
        'shard_number': _env_int('QDRANT_SHARD_NUMBER', 1),
    }
def add_qdrant_args(parser: argparse.ArgumentParser) -> None:
    """Register the shared Qdrant command-line options on *parser*.

    Adds ``-t/--store-uri``, ``-k/--api-key``,
    ``--qdrant-replication-factor`` and ``--qdrant-shard-number``. Every
    default is taken from ``get_qdrant_defaults()``, and each help string
    notes when the corresponding environment variable is set. The API key
    value itself is never written into the help text.

    Args:
        parser: the argument parser to extend in place.
    """
    defaults = get_qdrant_defaults()

    # argparse %-formats help strings, so a literal '%' in the URL (for
    # example a URL-encoded credential) must be doubled or --help fails.
    shown_url = defaults['url'].replace('%', '%%')
    url_help = f"Qdrant URL (default: {shown_url})"
    if 'QDRANT_URL' in os.environ:
        url_help += " [from QDRANT_URL]"

    api_key_help = "Qdrant API key"
    if defaults['api_key']:
        api_key_help += " (default: <set>)"
    if 'QDRANT_API_KEY' in os.environ:
        api_key_help += " [from QDRANT_API_KEY]"

    replication_help = f"Qdrant collection replication factor (default: {defaults['replication_factor']})"
    if 'QDRANT_REPLICATION_FACTOR' in os.environ:
        replication_help += " [from QDRANT_REPLICATION_FACTOR]"

    shard_help = f"Qdrant collection shard number (default: {defaults['shard_number']})"
    if 'QDRANT_SHARD_NUMBER' in os.environ:
        shard_help += " [from QDRANT_SHARD_NUMBER]"

    # The short forms -t / -k are kept so command lines written for the
    # Qdrant processors before they adopted this helper keep working.
    parser.add_argument(
        '-t', '--store-uri',
        default=defaults['url'],
        help=url_help,
    )
    parser.add_argument(
        '-k', '--api-key',
        default=defaults['api_key'],
        help=api_key_help,
    )
    parser.add_argument(
        '--qdrant-replication-factor',
        type=int,
        default=defaults['replication_factor'],
        help=replication_help,
    )
    parser.add_argument(
        '--qdrant-shard-number',
        type=int,
        default=defaults['shard_number'],
        help=shard_help,
    )
def resolve_qdrant_config(
    args: Optional[Any] = None,
    url: Optional[str] = None,
    api_key: Optional[str] = None,
    replication_factor: Optional[int] = None,
    shard_number: Optional[int] = None,
) -> Tuple[str, Optional[str], int, int]:
    """Resolve the effective Qdrant connection and collection settings.

    Each setting is taken from the first truthy source, in this order:
    the explicit keyword argument, the matching attribute on *args* (when
    *args* is given), then the value from ``get_qdrant_defaults()``.

    Args:
        args: optional namespace-like object; the attributes read are
            ``store_uri``, ``api_key``, ``qdrant_replication_factor`` and
            ``qdrant_shard_number``. Missing attributes are ignored.
        url: explicit Qdrant URL.
        api_key: explicit API key.
        replication_factor: explicit collection replication factor.
        shard_number: explicit collection shard number.

    Returns:
        ``(url, api_key, replication_factor, shard_number)``.
    """
    # One row per setting: (attribute on args, key in the defaults dict),
    # in the same order as the returned tuple.
    sources = (
        ('store_uri', 'url'),
        ('api_key', 'api_key'),
        ('qdrant_replication_factor', 'replication_factor'),
        ('qdrant_shard_number', 'shard_number'),
    )
    chosen = [url, api_key, replication_factor, shard_number]

    # First fallback: attributes on the parsed-arguments object.
    if args is not None:
        for slot, (attr_name, _) in enumerate(sources):
            if not chosen[slot]:
                chosen[slot] = getattr(args, attr_name, None)

    # Final fallback: environment-derived defaults.
    env_defaults = get_qdrant_defaults()
    for slot, (_, default_key) in enumerate(sources):
        if not chosen[slot]:
            chosen[slot] = env_defaults[default_key]

    return tuple(chosen)

View file

@ -6,7 +6,7 @@ import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement
from ssl import SSLContext, PROTOCOL_TLSv1_2
import ssl
from ..tables.cassandra_async import async_execute
@ -41,13 +41,15 @@ class KnowledgeGraph:
def __init__(
self, hosts=None,
keyspace="trustgraph", username=None, password=None
keyspace="trustgraph", username=None, password=None,
replication_factor=1,
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.replication_factor = replication_factor
self.username = username
# 7-table schema for quads with full query pattern support
@ -68,7 +70,7 @@ class KnowledgeGraph:
self.collection_metadata_table = "collection_metadata"
if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
@ -92,7 +94,7 @@ class KnowledgeGraph:
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
'replication_factor' : {self.replication_factor}
}};
""")
@ -539,13 +541,15 @@ class EntityCentricKnowledgeGraph:
def __init__(
self, hosts=None,
keyspace="trustgraph", username=None, password=None
keyspace="trustgraph", username=None, password=None,
replication_factor=1,
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.replication_factor = replication_factor
self.username = username
# 2-table entity-centric schema
@ -556,7 +560,7 @@ class EntityCentricKnowledgeGraph:
self.collection_metadata_table = "collection_metadata"
if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
@ -580,7 +584,7 @@ class EntityCentricKnowledgeGraph:
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
'replication_factor' : {self.replication_factor}
}};
""")

View file

@ -12,31 +12,32 @@ from qdrant_client import QdrantClient
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "doc-embeddings-query"
default_store_uri = 'http://localhost:6333'
class Processor(DocumentEmbeddingsQueryService):
def __init__(self, **params):
store_uri = params.get("store_uri", default_store_uri)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
#optional api key
api_key = params.get("api_key", None)
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)
super(Processor, self).__init__(
**params | {
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)
async def query_document_embeddings(self, workspace, msg):
@ -85,18 +86,7 @@ class Processor(DocumentEmbeddingsQueryService):
def add_args(parser):
DocumentEmbeddingsQueryService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help=f'API key for qdrant (default: None)'
)
add_qdrant_args(parser)
def run():

View file

@ -12,31 +12,32 @@ from qdrant_client import QdrantClient
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "graph-embeddings-query"
default_store_uri = 'http://localhost:6333'
class Processor(GraphEmbeddingsQueryService):
def __init__(self, **params):
store_uri = params.get("store_uri", default_store_uri)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
#optional api key
api_key = params.get("api_key", None)
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)
super(Processor, self).__init__(
**params | {
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@ -104,18 +105,7 @@ class Processor(GraphEmbeddingsQueryService):
def add_args(parser):
GraphEmbeddingsQueryService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help=f'API key for qdrant (default: None)'
)
add_qdrant_args(parser)
def run():

View file

@ -116,7 +116,7 @@ class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
# Create keyspace
self.session.execute(f"""
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': {self.cassandra_config.get('replication_factor', 1)}}}
""")
# Create triples table optimized for SPARQL queries

View file

@ -19,12 +19,12 @@ from .... schema import (
RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-query"
default_store_uri = 'http://localhost:6333'
default_concurrency = 10
@ -35,13 +35,17 @@ class Processor(FlowProcessor):
id = params.get("id", default_ident)
concurrency = params.get("concurrency", default_concurrency)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)
@ -62,7 +66,7 @@ class Processor(FlowProcessor):
)
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
@ -192,21 +196,9 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='API key for Qdrant (default: None)'
)
add_qdrant_args(parser)
parser.add_argument(
'-c', '--concurrency',

View file

@ -14,29 +14,34 @@ from qdrant_client.models import Distance, VectorParams
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "doc-embeddings-write"
default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
def __init__(self, **params):
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)
super(Processor, self).__init__(
**params | {
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)
self.replication_factor = replication_factor
self.shard_number = shard_number
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@ -61,6 +66,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
replication_factor=self.replication_factor,
shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@ -109,18 +116,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
def add_args(parser):
DocumentEmbeddingsStoreService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help=f'Qdrant API key (default: None)'
)
add_qdrant_args(parser)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""

View file

@ -14,6 +14,7 @@ from qdrant_client.models import Distance, VectorParams
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
from .... schema import IRI, LITERAL
# Module logger
@ -29,29 +30,32 @@ def get_term_value(term):
elif term.type == LITERAL:
return term.value
else:
# For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "graph-embeddings-write"
default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
def __init__(self, **params):
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)
super(Processor, self).__init__(
**params | {
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)
self.replication_factor = replication_factor
self.shard_number = shard_number
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@ -76,6 +80,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
replication_factor=self.replication_factor,
shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@ -128,18 +134,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
def add_args(parser):
GraphEmbeddingsStoreService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help=f'Qdrant API key'
)
add_qdrant_args(parser)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""

View file

@ -27,12 +27,12 @@ from qdrant_client.models import PointStruct, Distance, VectorParams
from .... schema import RowEmbeddings
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-write"
default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, FlowProcessor):
@ -41,13 +41,17 @@ class Processor(CollectionConfigHandler, FlowProcessor):
id = params.get("id", default_ident)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)
@ -63,7 +67,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Register config handler for collection management
self.register_config_handler(self.on_collection_config, types=["collection"])
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)
self.replication_factor = replication_factor
self.shard_number = shard_number
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@ -103,6 +109,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
size=dimension,
distance=Distance.COSINE
),
replication_factor=self.replication_factor,
shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@ -249,21 +257,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='Qdrant API key (default: None)'
)
add_qdrant_args(parser)
def run():

View file

@ -47,7 +47,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace, _ = resolve_cassandra_config(
hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
@ -57,6 +57,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
self.replication_factor = replication_factor
# Config key for schemas
self.config_key = params.get("config_type", "schema")
@ -232,7 +233,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
WITH REPLICATION = {{
'class': 'SimpleStrategy',
'replication_factor': 1
'replication_factor': {self.replication_factor}
}}
"""

View file

@ -4,7 +4,7 @@ from .. schema import Metadata, GraphEmbeddings
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from ssl import SSLContext, PROTOCOL_TLSv1_2
import ssl
import uuid
import time
@ -33,7 +33,7 @@ class ConfigTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)

View file

@ -15,7 +15,7 @@ import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from ssl import SSLContext, PROTOCOL_TLSv1_2
import ssl
from . cassandra_async import async_execute
@ -39,7 +39,7 @@ class IamTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(",")]
if cassandra_username and cassandra_password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password,
)

View file

@ -23,7 +23,7 @@ def tuple_to_term(value, is_uri):
else:
return Term(type=LITERAL, value=value)
from cassandra.auth import PlainTextAuthProvider
from ssl import SSLContext, PROTOCOL_TLSv1_2
import ssl
import uuid
import time
@ -50,7 +50,7 @@ class KnowledgeTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)

View file

@ -24,7 +24,7 @@ from .. exceptions import RequestError
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement
from ssl import SSLContext, PROTOCOL_TLSv1_2
import ssl
import uuid
import time
@ -53,7 +53,7 @@ class LibraryTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)