mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-23 21:58:06 +02:00
fix: wire replication params through YAML/params path for Cassandra and Qdrant
resolve_cassandra_config did not accept replication_factor as a kwarg, so cassandra_replication_factor from YAML params was silently ignored by all 6 callers. Add the kwarg and pass it from every caller. Same fix for Qdrant: 3 writers now pass qdrant_replication_factor and qdrant_shard_number from params. Add tests covering the params path for both helpers.
This commit is contained in:
parent
4913f8c2eb
commit
0f7cfa2170
13 changed files with 214 additions and 28 deletions
|
|
@ -409,4 +409,57 @@ class TestEdgeCases:
|
||||||
|
|
||||||
assert hosts == ['mixed-host']
|
assert hosts == ['mixed-host']
|
||||||
assert username is None # Stays None
|
assert username is None # Stays None
|
||||||
assert password == 'mixed-pass'
|
assert password == 'mixed-pass'
|
||||||
|
|
||||||
|
|
||||||
|
class TestReplicationFactorParamPath:
|
||||||
|
|
||||||
|
def test_explicit_kwarg(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
_, _, _, _, rf = resolve_cassandra_config(
|
||||||
|
replication_factor=3,
|
||||||
|
)
|
||||||
|
assert rf == 3
|
||||||
|
|
||||||
|
def test_kwarg_overrides_env(self):
|
||||||
|
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
|
||||||
|
_, _, _, _, rf = resolve_cassandra_config(
|
||||||
|
replication_factor=3,
|
||||||
|
)
|
||||||
|
assert rf == 3
|
||||||
|
|
||||||
|
def test_env_fallback_when_kwarg_none(self):
|
||||||
|
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
|
||||||
|
_, _, _, _, rf = resolve_cassandra_config(
|
||||||
|
replication_factor=None,
|
||||||
|
)
|
||||||
|
assert rf == 5
|
||||||
|
|
||||||
|
def test_default_when_no_kwarg_no_env(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
_, _, _, _, rf = resolve_cassandra_config()
|
||||||
|
assert rf == 1
|
||||||
|
|
||||||
|
def test_params_dict_path(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
params = {'cassandra_replication_factor': 3}
|
||||||
|
_, _, _, _, rf = resolve_cassandra_config(
|
||||||
|
replication_factor=params.get('cassandra_replication_factor'),
|
||||||
|
)
|
||||||
|
assert rf == 3
|
||||||
|
|
||||||
|
def test_params_dict_overrides_env(self):
|
||||||
|
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
|
||||||
|
params = {'cassandra_replication_factor': 3}
|
||||||
|
_, _, _, _, rf = resolve_cassandra_config(
|
||||||
|
replication_factor=params.get('cassandra_replication_factor'),
|
||||||
|
)
|
||||||
|
assert rf == 3
|
||||||
|
|
||||||
|
def test_params_dict_missing_falls_to_env(self):
|
||||||
|
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
|
||||||
|
params = {}
|
||||||
|
_, _, _, _, rf = resolve_cassandra_config(
|
||||||
|
replication_factor=params.get('cassandra_replication_factor'),
|
||||||
|
)
|
||||||
|
assert rf == 5
|
||||||
136
tests/unit/test_base/test_qdrant_config.py
Normal file
136
tests/unit/test_base/test_qdrant_config.py
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from trustgraph.base.qdrant_config import (
|
||||||
|
get_qdrant_defaults,
|
||||||
|
resolve_qdrant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetQdrantDefaults:
|
||||||
|
|
||||||
|
def test_defaults_with_no_env_vars(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
defaults = get_qdrant_defaults()
|
||||||
|
assert defaults['url'] == 'http://localhost:6333'
|
||||||
|
assert defaults['api_key'] is None
|
||||||
|
assert defaults['replication_factor'] == 1
|
||||||
|
assert defaults['shard_number'] == 1
|
||||||
|
|
||||||
|
def test_defaults_from_env(self):
|
||||||
|
env = {
|
||||||
|
'QDRANT_URL': 'http://qdrant:6333',
|
||||||
|
'QDRANT_API_KEY': 'secret',
|
||||||
|
'QDRANT_REPLICATION_FACTOR': '3',
|
||||||
|
'QDRANT_SHARD_NUMBER': '5',
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=True):
|
||||||
|
defaults = get_qdrant_defaults()
|
||||||
|
assert defaults['url'] == 'http://qdrant:6333'
|
||||||
|
assert defaults['api_key'] == 'secret'
|
||||||
|
assert defaults['replication_factor'] == 3
|
||||||
|
assert defaults['shard_number'] == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveQdrantConfig:
|
||||||
|
|
||||||
|
def test_defaults(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
url, api_key, rf, sn = resolve_qdrant_config()
|
||||||
|
assert url == 'http://localhost:6333'
|
||||||
|
assert api_key is None
|
||||||
|
assert rf == 1
|
||||||
|
assert sn == 1
|
||||||
|
|
||||||
|
def test_explicit_kwargs(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
url, api_key, rf, sn = resolve_qdrant_config(
|
||||||
|
url='http://custom:6333',
|
||||||
|
api_key='key',
|
||||||
|
replication_factor=3,
|
||||||
|
shard_number=5,
|
||||||
|
)
|
||||||
|
assert url == 'http://custom:6333'
|
||||||
|
assert api_key == 'key'
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
||||||
|
def test_kwargs_override_env(self):
|
||||||
|
env = {
|
||||||
|
'QDRANT_URL': 'http://env:6333',
|
||||||
|
'QDRANT_REPLICATION_FACTOR': '10',
|
||||||
|
'QDRANT_SHARD_NUMBER': '10',
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=True):
|
||||||
|
url, _, rf, sn = resolve_qdrant_config(
|
||||||
|
url='http://explicit:6333',
|
||||||
|
replication_factor=3,
|
||||||
|
shard_number=5,
|
||||||
|
)
|
||||||
|
assert url == 'http://explicit:6333'
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
||||||
|
def test_env_fallback_when_kwargs_none(self):
|
||||||
|
env = {
|
||||||
|
'QDRANT_URL': 'http://env:6333',
|
||||||
|
'QDRANT_REPLICATION_FACTOR': '3',
|
||||||
|
'QDRANT_SHARD_NUMBER': '5',
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=True):
|
||||||
|
url, _, rf, sn = resolve_qdrant_config()
|
||||||
|
assert url == 'http://env:6333'
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
||||||
|
def test_params_dict_path(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
params = {
|
||||||
|
'store_uri': 'http://params:6333',
|
||||||
|
'api_key': 'pkey',
|
||||||
|
'qdrant_replication_factor': 3,
|
||||||
|
'qdrant_shard_number': 5,
|
||||||
|
}
|
||||||
|
url, api_key, rf, sn = resolve_qdrant_config(
|
||||||
|
url=params.get('store_uri'),
|
||||||
|
api_key=params.get('api_key'),
|
||||||
|
replication_factor=params.get('qdrant_replication_factor'),
|
||||||
|
shard_number=params.get('qdrant_shard_number'),
|
||||||
|
)
|
||||||
|
assert url == 'http://params:6333'
|
||||||
|
assert api_key == 'pkey'
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
||||||
|
def test_params_dict_overrides_env(self):
|
||||||
|
env = {
|
||||||
|
'QDRANT_REPLICATION_FACTOR': '10',
|
||||||
|
'QDRANT_SHARD_NUMBER': '10',
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=True):
|
||||||
|
params = {
|
||||||
|
'qdrant_replication_factor': 3,
|
||||||
|
'qdrant_shard_number': 5,
|
||||||
|
}
|
||||||
|
_, _, rf, sn = resolve_qdrant_config(
|
||||||
|
replication_factor=params.get('qdrant_replication_factor'),
|
||||||
|
shard_number=params.get('qdrant_shard_number'),
|
||||||
|
)
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
||||||
|
def test_params_dict_missing_falls_to_env(self):
|
||||||
|
env = {
|
||||||
|
'QDRANT_REPLICATION_FACTOR': '3',
|
||||||
|
'QDRANT_SHARD_NUMBER': '5',
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=True):
|
||||||
|
params = {}
|
||||||
|
_, _, rf, sn = resolve_qdrant_config(
|
||||||
|
replication_factor=params.get('qdrant_replication_factor'),
|
||||||
|
shard_number=params.get('qdrant_shard_number'),
|
||||||
|
)
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
@ -103,35 +103,19 @@ def resolve_cassandra_config(
|
||||||
host: Optional[str] = None,
|
host: Optional[str] = None,
|
||||||
username: Optional[str] = None,
|
username: Optional[str] = None,
|
||||||
password: Optional[str] = None,
|
password: Optional[str] = None,
|
||||||
default_keyspace: Optional[str] = None
|
default_keyspace: Optional[str] = None,
|
||||||
|
replication_factor: Optional[int] = None,
|
||||||
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
|
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
|
||||||
"""
|
|
||||||
Resolve Cassandra configuration from various sources.
|
|
||||||
|
|
||||||
Can accept either argparse args object or explicit parameters.
|
|
||||||
Converts host string to list format for Cassandra driver.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace, cassandra_replication_factor
|
|
||||||
host: Optional explicit host parameter (overrides args)
|
|
||||||
username: Optional explicit username parameter (overrides args)
|
|
||||||
password: Optional explicit password parameter (overrides args)
|
|
||||||
default_keyspace: Optional default keyspace if not specified elsewhere
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (hosts_list, username, password, keyspace, replication_factor)
|
|
||||||
"""
|
|
||||||
# If args provided, extract values
|
|
||||||
keyspace = None
|
keyspace = None
|
||||||
replication_factor = 1
|
|
||||||
if args is not None:
|
if args is not None:
|
||||||
host = host or getattr(args, 'cassandra_host', None)
|
host = host or getattr(args, 'cassandra_host', None)
|
||||||
username = username or getattr(args, 'cassandra_username', None)
|
username = username or getattr(args, 'cassandra_username', None)
|
||||||
password = password or getattr(args, 'cassandra_password', None)
|
password = password or getattr(args, 'cassandra_password', None)
|
||||||
keyspace = getattr(args, 'cassandra_keyspace', None)
|
keyspace = getattr(args, 'cassandra_keyspace', None)
|
||||||
replication_factor = getattr(args, 'cassandra_replication_factor', 1)
|
replication_factor = replication_factor or getattr(
|
||||||
|
args, 'cassandra_replication_factor', None
|
||||||
|
)
|
||||||
|
|
||||||
# Apply defaults if still None
|
|
||||||
defaults = get_cassandra_defaults()
|
defaults = get_cassandra_defaults()
|
||||||
host = host or defaults['host']
|
host = host or defaults['host']
|
||||||
username = username or defaults['username']
|
username = username or defaults['username']
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,8 @@ class Processor(AsyncProcessor):
|
||||||
host=cassandra_host,
|
host=cassandra_host,
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password,
|
password=cassandra_password,
|
||||||
default_keyspace="config"
|
default_keyspace="config",
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store resolved configuration
|
# Store resolved configuration
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,8 @@ class Processor(WorkspaceProcessor):
|
||||||
host=cassandra_host,
|
host=cassandra_host,
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password,
|
password=cassandra_password,
|
||||||
default_keyspace="knowledge"
|
default_keyspace="knowledge",
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cassandra_host = hosts
|
self.cassandra_host = hosts
|
||||||
|
|
|
||||||
|
|
@ -101,6 +101,7 @@ class Processor(AsyncProcessor):
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password,
|
password=cassandra_password,
|
||||||
default_keyspace="iam",
|
default_keyspace="iam",
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cassandra_host = hosts
|
self.cassandra_host = hosts
|
||||||
|
|
|
||||||
|
|
@ -146,7 +146,8 @@ class Processor(WorkspaceProcessor):
|
||||||
host=cassandra_host,
|
host=cassandra_host,
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password,
|
password=cassandra_password,
|
||||||
default_keyspace="librarian"
|
default_keyspace="librarian",
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store resolved configuration
|
# Store resolved configuration
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,8 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
api_key = params.get("api_key")
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
url, api_key, _, _ = resolve_qdrant_config(
|
url, api_key, _, _ = resolve_qdrant_config(
|
||||||
url=store_uri, api_key=api_key,
|
url=store_uri,
|
||||||
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
|
|
||||||
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
||||||
url=store_uri, api_key=api_key,
|
url=store_uri, api_key=api_key,
|
||||||
|
replication_factor=params.get("qdrant_replication_factor"),
|
||||||
|
shard_number=params.get("qdrant_shard_number"),
|
||||||
)
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
|
|
||||||
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
||||||
url=store_uri, api_key=api_key,
|
url=store_uri, api_key=api_key,
|
||||||
|
replication_factor=params.get("qdrant_replication_factor"),
|
||||||
|
shard_number=params.get("qdrant_shard_number"),
|
||||||
)
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,8 @@ class Processor(FlowProcessor):
|
||||||
host=params.get("cassandra_host"),
|
host=params.get("cassandra_host"),
|
||||||
username=params.get("cassandra_username"),
|
username=params.get("cassandra_username"),
|
||||||
password=params.get("cassandra_password"),
|
password=params.get("cassandra_password"),
|
||||||
default_keyspace='knowledge'
|
default_keyspace='knowledge',
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
||||||
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
||||||
url=store_uri, api_key=api_key,
|
url=store_uri, api_key=api_key,
|
||||||
|
replication_factor=params.get("qdrant_replication_factor"),
|
||||||
|
shard_number=params.get("qdrant_shard_number"),
|
||||||
)
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
|
hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
|
||||||
host=cassandra_host,
|
host=cassandra_host,
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password
|
password=cassandra_password,
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store resolved configuration with proper names
|
# Store resolved configuration with proper names
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue