mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-23 21:58:06 +02:00
feat: data store replication configuration and TLS upgrade
- Add centralised qdrant_config.py helper with env-var fallback for QDRANT_URL, QDRANT_API_KEY, QDRANT_REPLICATION_FACTOR, QDRANT_SHARD_NUMBER - Update all 6 Qdrant processors to use the helper; writers pass replication_factor and shard_number to create_collection - Fix hardcoded Cassandra replication_factor=1 in cassandra_kg.py, write.py, and sparql_cassandra.py to respect CASSANDRA_REPLICATION_FACTOR - Upgrade Cassandra TLS from deprecated PROTOCOL_TLSv1_2 to ssl.create_default_context() across all connectors
This commit is contained in:
parent
acf182c265
commit
80cffd71dc
15 changed files with 182 additions and 129 deletions
|
|
@ -259,6 +259,8 @@ class TestGraphEmbeddingsNullProtection:
|
||||||
proc.collection_exists = MagicMock(return_value=True)
|
proc.collection_exists = MagicMock(return_value=True)
|
||||||
proc._cache_lock = asyncio.Lock()
|
proc._cache_lock = asyncio.Lock()
|
||||||
proc._known_collections = set()
|
proc._known_collections = set()
|
||||||
|
proc.replication_factor = 1
|
||||||
|
proc.shard_number = 1
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.metadata.collection = "graphs"
|
msg.metadata.collection = "graphs"
|
||||||
|
|
|
||||||
87
trustgraph-base/trustgraph/base/qdrant_config.py
Normal file
87
trustgraph-base/trustgraph/base/qdrant_config.py
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from typing import Optional, Any, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def get_qdrant_defaults() -> dict:
|
||||||
|
return {
|
||||||
|
'url': os.getenv('QDRANT_URL', 'http://localhost:6333'),
|
||||||
|
'api_key': os.getenv('QDRANT_API_KEY'),
|
||||||
|
'replication_factor': int(os.getenv('QDRANT_REPLICATION_FACTOR', '1')),
|
||||||
|
'shard_number': int(os.getenv('QDRANT_SHARD_NUMBER', '1')),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def add_qdrant_args(parser: argparse.ArgumentParser) -> None:
|
||||||
|
defaults = get_qdrant_defaults()
|
||||||
|
|
||||||
|
url_help = f"Qdrant URL (default: {defaults['url']})"
|
||||||
|
if 'QDRANT_URL' in os.environ:
|
||||||
|
url_help += " [from QDRANT_URL]"
|
||||||
|
|
||||||
|
api_key_help = "Qdrant API key"
|
||||||
|
if defaults['api_key']:
|
||||||
|
api_key_help += " (default: <set>)"
|
||||||
|
if 'QDRANT_API_KEY' in os.environ:
|
||||||
|
api_key_help += " [from QDRANT_API_KEY]"
|
||||||
|
|
||||||
|
replication_help = f"Qdrant collection replication factor (default: {defaults['replication_factor']})"
|
||||||
|
if 'QDRANT_REPLICATION_FACTOR' in os.environ:
|
||||||
|
replication_help += " [from QDRANT_REPLICATION_FACTOR]"
|
||||||
|
|
||||||
|
shard_help = f"Qdrant collection shard number (default: {defaults['shard_number']})"
|
||||||
|
if 'QDRANT_SHARD_NUMBER' in os.environ:
|
||||||
|
shard_help += " [from QDRANT_SHARD_NUMBER]"
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--store-uri',
|
||||||
|
default=defaults['url'],
|
||||||
|
help=url_help,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--api-key',
|
||||||
|
default=defaults['api_key'],
|
||||||
|
help=api_key_help,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--qdrant-replication-factor',
|
||||||
|
type=int,
|
||||||
|
default=defaults['replication_factor'],
|
||||||
|
help=replication_help,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--qdrant-shard-number',
|
||||||
|
type=int,
|
||||||
|
default=defaults['shard_number'],
|
||||||
|
help=shard_help,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_qdrant_config(
|
||||||
|
args: Optional[Any] = None,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
replication_factor: Optional[int] = None,
|
||||||
|
shard_number: Optional[int] = None,
|
||||||
|
) -> Tuple[str, Optional[str], int, int]:
|
||||||
|
if args is not None:
|
||||||
|
url = url or getattr(args, 'store_uri', None)
|
||||||
|
api_key = api_key or getattr(args, 'api_key', None)
|
||||||
|
replication_factor = replication_factor or getattr(
|
||||||
|
args, 'qdrant_replication_factor', None
|
||||||
|
)
|
||||||
|
shard_number = shard_number or getattr(
|
||||||
|
args, 'qdrant_shard_number', None
|
||||||
|
)
|
||||||
|
|
||||||
|
defaults = get_qdrant_defaults()
|
||||||
|
url = url or defaults['url']
|
||||||
|
api_key = api_key or defaults['api_key']
|
||||||
|
replication_factor = replication_factor or defaults['replication_factor']
|
||||||
|
shard_number = shard_number or defaults['shard_number']
|
||||||
|
|
||||||
|
return url, api_key, replication_factor, shard_number
|
||||||
|
|
@ -6,7 +6,7 @@ import logging
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from cassandra.query import BatchStatement, SimpleStatement
|
from cassandra.query import BatchStatement, SimpleStatement
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
import ssl
|
||||||
|
|
||||||
from ..tables.cassandra_async import async_execute
|
from ..tables.cassandra_async import async_execute
|
||||||
|
|
||||||
|
|
@ -41,13 +41,15 @@ class KnowledgeGraph:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hosts=None,
|
self, hosts=None,
|
||||||
keyspace="trustgraph", username=None, password=None
|
keyspace="trustgraph", username=None, password=None,
|
||||||
|
replication_factor=1,
|
||||||
):
|
):
|
||||||
|
|
||||||
if hosts is None:
|
if hosts is None:
|
||||||
hosts = ["localhost"]
|
hosts = ["localhost"]
|
||||||
|
|
||||||
self.keyspace = keyspace
|
self.keyspace = keyspace
|
||||||
|
self.replication_factor = replication_factor
|
||||||
self.username = username
|
self.username = username
|
||||||
|
|
||||||
# 7-table schema for quads with full query pattern support
|
# 7-table schema for quads with full query pattern support
|
||||||
|
|
@ -68,7 +70,7 @@ class KnowledgeGraph:
|
||||||
self.collection_metadata_table = "collection_metadata"
|
self.collection_metadata_table = "collection_metadata"
|
||||||
|
|
||||||
if username and password:
|
if username and password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
||||||
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
||||||
else:
|
else:
|
||||||
|
|
@ -92,7 +94,7 @@ class KnowledgeGraph:
|
||||||
create keyspace if not exists {self.keyspace}
|
create keyspace if not exists {self.keyspace}
|
||||||
with replication = {{
|
with replication = {{
|
||||||
'class' : 'SimpleStrategy',
|
'class' : 'SimpleStrategy',
|
||||||
'replication_factor' : 1
|
'replication_factor' : {self.replication_factor}
|
||||||
}};
|
}};
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
@ -539,13 +541,15 @@ class EntityCentricKnowledgeGraph:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hosts=None,
|
self, hosts=None,
|
||||||
keyspace="trustgraph", username=None, password=None
|
keyspace="trustgraph", username=None, password=None,
|
||||||
|
replication_factor=1,
|
||||||
):
|
):
|
||||||
|
|
||||||
if hosts is None:
|
if hosts is None:
|
||||||
hosts = ["localhost"]
|
hosts = ["localhost"]
|
||||||
|
|
||||||
self.keyspace = keyspace
|
self.keyspace = keyspace
|
||||||
|
self.replication_factor = replication_factor
|
||||||
self.username = username
|
self.username = username
|
||||||
|
|
||||||
# 2-table entity-centric schema
|
# 2-table entity-centric schema
|
||||||
|
|
@ -556,7 +560,7 @@ class EntityCentricKnowledgeGraph:
|
||||||
self.collection_metadata_table = "collection_metadata"
|
self.collection_metadata_table = "collection_metadata"
|
||||||
|
|
||||||
if username and password:
|
if username and password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
||||||
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
||||||
else:
|
else:
|
||||||
|
|
@ -580,7 +584,7 @@ class EntityCentricKnowledgeGraph:
|
||||||
create keyspace if not exists {self.keyspace}
|
create keyspace if not exists {self.keyspace}
|
||||||
with replication = {{
|
with replication = {{
|
||||||
'class' : 'SimpleStrategy',
|
'class' : 'SimpleStrategy',
|
||||||
'replication_factor' : 1
|
'replication_factor' : {self.replication_factor}
|
||||||
}};
|
}};
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,31 +12,32 @@ from qdrant_client import QdrantClient
|
||||||
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
||||||
from .... schema import Error
|
from .... schema import Error
|
||||||
from .... base import DocumentEmbeddingsQueryService
|
from .... base import DocumentEmbeddingsQueryService
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "doc-embeddings-query"
|
default_ident = "doc-embeddings-query"
|
||||||
|
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
|
|
||||||
class Processor(DocumentEmbeddingsQueryService):
|
class Processor(DocumentEmbeddingsQueryService):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
#optional api key
|
url, api_key, _, _ = resolve_qdrant_config(
|
||||||
api_key = params.get("api_key", None)
|
url=store_uri, api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
|
||||||
async def query_document_embeddings(self, workspace, msg):
|
async def query_document_embeddings(self, workspace, msg):
|
||||||
|
|
||||||
|
|
@ -85,18 +86,7 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
DocumentEmbeddingsQueryService.add_args(parser)
|
DocumentEmbeddingsQueryService.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant store URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help=f'API key for qdrant (default: None)'
|
|
||||||
)
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,31 +12,32 @@ from qdrant_client import QdrantClient
|
||||||
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||||
from .... schema import Error, Term, IRI, LITERAL
|
from .... schema import Error, Term, IRI, LITERAL
|
||||||
from .... base import GraphEmbeddingsQueryService
|
from .... base import GraphEmbeddingsQueryService
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "graph-embeddings-query"
|
default_ident = "graph-embeddings-query"
|
||||||
|
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
|
|
||||||
class Processor(GraphEmbeddingsQueryService):
|
class Processor(GraphEmbeddingsQueryService):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
#optional api key
|
url, api_key, _, _ = resolve_qdrant_config(
|
||||||
api_key = params.get("api_key", None)
|
url=store_uri, api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
|
||||||
def create_value(self, ent):
|
def create_value(self, ent):
|
||||||
if ent.startswith("http://") or ent.startswith("https://"):
|
if ent.startswith("http://") or ent.startswith("https://"):
|
||||||
|
|
@ -104,18 +105,7 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
GraphEmbeddingsQueryService.add_args(parser)
|
GraphEmbeddingsQueryService.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant store URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help=f'API key for qdrant (default: None)'
|
|
||||||
)
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
|
||||||
# Create keyspace
|
# Create keyspace
|
||||||
self.session.execute(f"""
|
self.session.execute(f"""
|
||||||
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
|
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
|
||||||
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
|
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': {self.cassandra_config.get('replication_factor', 1)}}}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Create triples table optimized for SPARQL queries
|
# Create triples table optimized for SPARQL queries
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,12 @@ from .... schema import (
|
||||||
RowIndexMatch, Error
|
RowIndexMatch, Error
|
||||||
)
|
)
|
||||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "row-embeddings-query"
|
default_ident = "row-embeddings-query"
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
default_concurrency = 10
|
default_concurrency = 10
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -35,13 +35,17 @@ class Processor(FlowProcessor):
|
||||||
id = params.get("id", default_ident)
|
id = params.get("id", default_ident)
|
||||||
concurrency = params.get("concurrency", default_concurrency)
|
concurrency = params.get("concurrency", default_concurrency)
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
api_key = params.get("api_key", None)
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
|
url, api_key, _, _ = resolve_qdrant_config(
|
||||||
|
url=store_uri, api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"id": id,
|
"id": id,
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -62,7 +66,7 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
|
||||||
def sanitize_name(self, name: str) -> str:
|
def sanitize_name(self, name: str) -> str:
|
||||||
"""Sanitize names for Qdrant collection naming"""
|
"""Sanitize names for Qdrant collection naming"""
|
||||||
|
|
@ -192,21 +196,9 @@ class Processor(FlowProcessor):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
"""Add command-line arguments"""
|
|
||||||
|
|
||||||
FlowProcessor.add_args(parser)
|
FlowProcessor.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant store URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help='API key for Qdrant (default: None)'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-c', '--concurrency',
|
'-c', '--concurrency',
|
||||||
|
|
|
||||||
|
|
@ -14,29 +14,34 @@ from qdrant_client.models import Distance, VectorParams
|
||||||
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
|
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
|
||||||
from .... base import AsyncProcessor, Consumer, Producer
|
from .... base import AsyncProcessor, Consumer, Producer
|
||||||
from .... base import ConsumerMetrics, ProducerMetrics
|
from .... base import ConsumerMetrics, ProducerMetrics
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "doc-embeddings-write"
|
default_ident = "doc-embeddings-write"
|
||||||
|
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
|
|
||||||
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
api_key = params.get("api_key", None)
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
|
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
||||||
|
url=store_uri, api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
self.replication_factor = replication_factor
|
||||||
|
self.shard_number = shard_number
|
||||||
self._cache_lock = asyncio.Lock()
|
self._cache_lock = asyncio.Lock()
|
||||||
self._known_collections: set[str] = set()
|
self._known_collections: set[str] = set()
|
||||||
|
|
||||||
|
|
@ -61,6 +66,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
vectors_config=VectorParams(
|
vectors_config=VectorParams(
|
||||||
size=dim, distance=Distance.COSINE
|
size=dim, distance=Distance.COSINE
|
||||||
),
|
),
|
||||||
|
replication_factor=self.replication_factor,
|
||||||
|
shard_number=self.shard_number,
|
||||||
)
|
)
|
||||||
self._known_collections.add(collection_name)
|
self._known_collections.add(collection_name)
|
||||||
|
|
||||||
|
|
@ -109,18 +116,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
DocumentEmbeddingsStoreService.add_args(parser)
|
DocumentEmbeddingsStoreService.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help=f'Qdrant API key (default: None)'
|
|
||||||
)
|
|
||||||
|
|
||||||
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from qdrant_client.models import Distance, VectorParams
|
||||||
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
|
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
|
||||||
from .... base import AsyncProcessor, Consumer, Producer
|
from .... base import AsyncProcessor, Consumer, Producer
|
||||||
from .... base import ConsumerMetrics, ProducerMetrics
|
from .... base import ConsumerMetrics, ProducerMetrics
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
from .... schema import IRI, LITERAL
|
from .... schema import IRI, LITERAL
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
|
|
@ -29,29 +30,32 @@ def get_term_value(term):
|
||||||
elif term.type == LITERAL:
|
elif term.type == LITERAL:
|
||||||
return term.value
|
return term.value
|
||||||
else:
|
else:
|
||||||
# For blank nodes or other types, use id or value
|
|
||||||
return term.id or term.value
|
return term.id or term.value
|
||||||
|
|
||||||
|
|
||||||
default_ident = "graph-embeddings-write"
|
default_ident = "graph-embeddings-write"
|
||||||
|
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
|
|
||||||
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
api_key = params.get("api_key", None)
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
|
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
||||||
|
url=store_uri, api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
self.replication_factor = replication_factor
|
||||||
|
self.shard_number = shard_number
|
||||||
self._cache_lock = asyncio.Lock()
|
self._cache_lock = asyncio.Lock()
|
||||||
self._known_collections: set[str] = set()
|
self._known_collections: set[str] = set()
|
||||||
|
|
||||||
|
|
@ -76,6 +80,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
vectors_config=VectorParams(
|
vectors_config=VectorParams(
|
||||||
size=dim, distance=Distance.COSINE
|
size=dim, distance=Distance.COSINE
|
||||||
),
|
),
|
||||||
|
replication_factor=self.replication_factor,
|
||||||
|
shard_number=self.shard_number,
|
||||||
)
|
)
|
||||||
self._known_collections.add(collection_name)
|
self._known_collections.add(collection_name)
|
||||||
|
|
||||||
|
|
@ -128,18 +134,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
GraphEmbeddingsStoreService.add_args(parser)
|
GraphEmbeddingsStoreService.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant store URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help=f'Qdrant API key'
|
|
||||||
)
|
|
||||||
|
|
||||||
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -27,12 +27,12 @@ from qdrant_client.models import PointStruct, Distance, VectorParams
|
||||||
from .... schema import RowEmbeddings
|
from .... schema import RowEmbeddings
|
||||||
from .... base import FlowProcessor, ConsumerSpec
|
from .... base import FlowProcessor, ConsumerSpec
|
||||||
from .... base import CollectionConfigHandler
|
from .... base import CollectionConfigHandler
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "row-embeddings-write"
|
default_ident = "row-embeddings-write"
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
|
|
||||||
|
|
||||||
class Processor(CollectionConfigHandler, FlowProcessor):
|
class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
@ -41,13 +41,17 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
||||||
id = params.get("id", default_ident)
|
id = params.get("id", default_ident)
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
api_key = params.get("api_key", None)
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
|
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
||||||
|
url=store_uri, api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"id": id,
|
"id": id,
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -63,7 +67,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
# Register config handler for collection management
|
# Register config handler for collection management
|
||||||
self.register_config_handler(self.on_collection_config, types=["collection"])
|
self.register_config_handler(self.on_collection_config, types=["collection"])
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
self.replication_factor = replication_factor
|
||||||
|
self.shard_number = shard_number
|
||||||
self._cache_lock = asyncio.Lock()
|
self._cache_lock = asyncio.Lock()
|
||||||
self._known_collections: set[str] = set()
|
self._known_collections: set[str] = set()
|
||||||
|
|
||||||
|
|
@ -103,6 +109,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
size=dimension,
|
size=dimension,
|
||||||
distance=Distance.COSINE
|
distance=Distance.COSINE
|
||||||
),
|
),
|
||||||
|
replication_factor=self.replication_factor,
|
||||||
|
shard_number=self.shard_number,
|
||||||
)
|
)
|
||||||
self._known_collections.add(collection_name)
|
self._known_collections.add(collection_name)
|
||||||
|
|
||||||
|
|
@ -249,21 +257,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
"""Add command-line arguments"""
|
|
||||||
|
|
||||||
FlowProcessor.add_args(parser)
|
FlowProcessor.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help='Qdrant API key (default: None)'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
cassandra_password = params.get("cassandra_password")
|
cassandra_password = params.get("cassandra_password")
|
||||||
|
|
||||||
# Resolve configuration with environment variable fallback
|
# Resolve configuration with environment variable fallback
|
||||||
hosts, username, password, keyspace, _ = resolve_cassandra_config(
|
hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
|
||||||
host=cassandra_host,
|
host=cassandra_host,
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password
|
password=cassandra_password
|
||||||
|
|
@ -57,6 +57,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
self.cassandra_host = hosts # Store as list
|
self.cassandra_host = hosts # Store as list
|
||||||
self.cassandra_username = username
|
self.cassandra_username = username
|
||||||
self.cassandra_password = password
|
self.cassandra_password = password
|
||||||
|
self.replication_factor = replication_factor
|
||||||
|
|
||||||
# Config key for schemas
|
# Config key for schemas
|
||||||
self.config_key = params.get("config_type", "schema")
|
self.config_key = params.get("config_type", "schema")
|
||||||
|
|
@ -232,7 +233,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
|
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
|
||||||
WITH REPLICATION = {{
|
WITH REPLICATION = {{
|
||||||
'class': 'SimpleStrategy',
|
'class': 'SimpleStrategy',
|
||||||
'replication_factor': 1
|
'replication_factor': {self.replication_factor}
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from .. schema import Metadata, GraphEmbeddings
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
import ssl
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
|
@ -33,7 +33,7 @@ class ConfigTableStore:
|
||||||
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
||||||
|
|
||||||
if cassandra_username and cassandra_password:
|
if cassandra_username and cassandra_password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(
|
auth_provider = PlainTextAuthProvider(
|
||||||
username=cassandra_username, password=cassandra_password
|
username=cassandra_username, password=cassandra_password
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ import logging
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
import ssl
|
||||||
|
|
||||||
from . cassandra_async import async_execute
|
from . cassandra_async import async_execute
|
||||||
|
|
||||||
|
|
@ -39,7 +39,7 @@ class IamTableStore:
|
||||||
cassandra_host = [h.strip() for h in cassandra_host.split(",")]
|
cassandra_host = [h.strip() for h in cassandra_host.split(",")]
|
||||||
|
|
||||||
if cassandra_username and cassandra_password:
|
if cassandra_username and cassandra_password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(
|
auth_provider = PlainTextAuthProvider(
|
||||||
username=cassandra_username, password=cassandra_password,
|
username=cassandra_username, password=cassandra_password,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ def tuple_to_term(value, is_uri):
|
||||||
else:
|
else:
|
||||||
return Term(type=LITERAL, value=value)
|
return Term(type=LITERAL, value=value)
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
import ssl
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
|
@ -50,7 +50,7 @@ class KnowledgeTableStore:
|
||||||
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
||||||
|
|
||||||
if cassandra_username and cassandra_password:
|
if cassandra_username and cassandra_password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(
|
auth_provider = PlainTextAuthProvider(
|
||||||
username=cassandra_username, password=cassandra_password
|
username=cassandra_username, password=cassandra_password
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ from .. exceptions import RequestError
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from cassandra.query import BatchStatement
|
from cassandra.query import BatchStatement
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
import ssl
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
|
@ -53,7 +53,7 @@ class LibraryTableStore:
|
||||||
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
||||||
|
|
||||||
if cassandra_username and cassandra_password:
|
if cassandra_username and cassandra_password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(
|
auth_provider = PlainTextAuthProvider(
|
||||||
username=cassandra_username, password=cassandra_password
|
username=cassandra_username, password=cassandra_password
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue