diff --git a/tests/unit/test_base/test_cassandra_config.py b/tests/unit/test_base/test_cassandra_config.py index a291434d..fe8a8379 100644 --- a/tests/unit/test_base/test_cassandra_config.py +++ b/tests/unit/test_base/test_cassandra_config.py @@ -409,4 +409,57 @@ class TestEdgeCases: assert hosts == ['mixed-host'] assert username is None # Stays None - assert password == 'mixed-pass' \ No newline at end of file + assert password == 'mixed-pass' + + +class TestReplicationFactorParamPath: + + def test_explicit_kwarg(self): + with patch.dict(os.environ, {}, clear=True): + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=3, + ) + assert rf == 3 + + def test_kwarg_overrides_env(self): + with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True): + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=3, + ) + assert rf == 3 + + def test_env_fallback_when_kwarg_none(self): + with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True): + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=None, + ) + assert rf == 5 + + def test_default_when_no_kwarg_no_env(self): + with patch.dict(os.environ, {}, clear=True): + _, _, _, _, rf = resolve_cassandra_config() + assert rf == 1 + + def test_params_dict_path(self): + with patch.dict(os.environ, {}, clear=True): + params = {'cassandra_replication_factor': 3} + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=params.get('cassandra_replication_factor'), + ) + assert rf == 3 + + def test_params_dict_overrides_env(self): + with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True): + params = {'cassandra_replication_factor': 3} + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=params.get('cassandra_replication_factor'), + ) + assert rf == 3 + + def test_params_dict_missing_falls_to_env(self): + with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True): + params = {} + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=params.get('cassandra_replication_factor'), + ) + assert rf == 5 \ No newline at end of file diff --git a/tests/unit/test_base/test_qdrant_config.py b/tests/unit/test_base/test_qdrant_config.py new file mode 100644 index 00000000..dbbe4214 --- /dev/null +++ b/tests/unit/test_base/test_qdrant_config.py @@ -0,0 +1,136 @@ + +import os +import pytest +from unittest.mock import patch + +from trustgraph.base.qdrant_config import ( + get_qdrant_defaults, + resolve_qdrant_config, +) + + +class TestGetQdrantDefaults: + + def test_defaults_with_no_env_vars(self): + with patch.dict(os.environ, {}, clear=True): + defaults = get_qdrant_defaults() + assert defaults['url'] == 'http://localhost:6333' + assert defaults['api_key'] is None + assert defaults['replication_factor'] == 1 + assert defaults['shard_number'] == 1 + + def test_defaults_from_env(self): + env = { + 'QDRANT_URL': 'http://qdrant:6333', + 'QDRANT_API_KEY': 'secret', + 'QDRANT_REPLICATION_FACTOR': '3', + 'QDRANT_SHARD_NUMBER': '5', + } + with patch.dict(os.environ, env, clear=True): + defaults = get_qdrant_defaults() + assert defaults['url'] == 'http://qdrant:6333' + assert defaults['api_key'] == 'secret' + assert defaults['replication_factor'] == 3 + assert defaults['shard_number'] == 5 + + +class TestResolveQdrantConfig: + + def test_defaults(self): + with patch.dict(os.environ, {}, clear=True): + url, api_key, rf, sn = resolve_qdrant_config() + assert url == 'http://localhost:6333' + assert api_key is None + assert rf == 1 + assert sn == 1 + + def test_explicit_kwargs(self): + with patch.dict(os.environ, {}, clear=True): + url, api_key, rf, sn = resolve_qdrant_config( + url='http://custom:6333', + api_key='key', + replication_factor=3, + shard_number=5, + ) + assert url == 'http://custom:6333' + assert api_key == 'key' + assert rf == 3 + assert sn == 5 + + def test_kwargs_override_env(self): + env = { + 'QDRANT_URL': 'http://env:6333', + 'QDRANT_REPLICATION_FACTOR': '10', + 'QDRANT_SHARD_NUMBER': '10', + } + with patch.dict(os.environ, env, clear=True): + url, _, rf, sn = resolve_qdrant_config( + url='http://explicit:6333', + replication_factor=3, + shard_number=5, + ) + assert url == 'http://explicit:6333' + assert rf == 3 + assert sn == 5 + + def test_env_fallback_when_kwargs_none(self): + env = { + 'QDRANT_URL': 'http://env:6333', + 'QDRANT_REPLICATION_FACTOR': '3', + 'QDRANT_SHARD_NUMBER': '5', + } + with patch.dict(os.environ, env, clear=True): + url, _, rf, sn = resolve_qdrant_config() + assert url == 'http://env:6333' + assert rf == 3 + assert sn == 5 + + def test_params_dict_path(self): + with patch.dict(os.environ, {}, clear=True): + params = { + 'store_uri': 'http://params:6333', + 'api_key': 'pkey', + 'qdrant_replication_factor': 3, + 'qdrant_shard_number': 5, + } + url, api_key, rf, sn = resolve_qdrant_config( + url=params.get('store_uri'), + api_key=params.get('api_key'), + replication_factor=params.get('qdrant_replication_factor'), + shard_number=params.get('qdrant_shard_number'), + ) + assert url == 'http://params:6333' + assert api_key == 'pkey' + assert rf == 3 + assert sn == 5 + + def test_params_dict_overrides_env(self): + env = { + 'QDRANT_REPLICATION_FACTOR': '10', + 'QDRANT_SHARD_NUMBER': '10', + } + with patch.dict(os.environ, env, clear=True): + params = { + 'qdrant_replication_factor': 3, + 'qdrant_shard_number': 5, + } + _, _, rf, sn = resolve_qdrant_config( + replication_factor=params.get('qdrant_replication_factor'), + shard_number=params.get('qdrant_shard_number'), + ) + assert rf == 3 + assert sn == 5 + + def test_params_dict_missing_falls_to_env(self): + env = { + 'QDRANT_REPLICATION_FACTOR': '3', + 'QDRANT_SHARD_NUMBER': '5', + } + with patch.dict(os.environ, env, clear=True): + params = {} + _, _, rf, sn = resolve_qdrant_config( + replication_factor=params.get('qdrant_replication_factor'), + shard_number=params.get('qdrant_shard_number'), + ) + assert rf == 3 + assert sn == 5 diff --git a/trustgraph-base/trustgraph/base/cassandra_config.py b/trustgraph-base/trustgraph/base/cassandra_config.py index 78505c68..b2e36fbd 100644 --- a/trustgraph-base/trustgraph/base/cassandra_config.py +++ b/trustgraph-base/trustgraph/base/cassandra_config.py @@ -103,35 +103,19 @@ def resolve_cassandra_config( host: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None, - default_keyspace: Optional[str] = None + default_keyspace: Optional[str] = None, + replication_factor: Optional[int] = None, ) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]: - """ - Resolve Cassandra configuration from various sources. - - Can accept either argparse args object or explicit parameters. - Converts host string to list format for Cassandra driver. - - Args: - args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace, cassandra_replication_factor - host: Optional explicit host parameter (overrides args) - username: Optional explicit username parameter (overrides args) - password: Optional explicit password parameter (overrides args) - default_keyspace: Optional default keyspace if not specified elsewhere - - Returns: - tuple: (hosts_list, username, password, keyspace, replication_factor) - """ - # If args provided, extract values keyspace = None - replication_factor = 1 if args is not None: host = host or getattr(args, 'cassandra_host', None) username = username or getattr(args, 'cassandra_username', None) password = password or getattr(args, 'cassandra_password', None) keyspace = getattr(args, 'cassandra_keyspace', None) - replication_factor = getattr(args, 'cassandra_replication_factor', 1) + replication_factor = replication_factor or getattr( + args, 'cassandra_replication_factor', None + ) - # Apply defaults if still None defaults = get_cassandra_defaults() host = host or defaults['host'] username = username or defaults['username'] diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py index c5fac198..725f1106 100644 --- a/trustgraph-flow/trustgraph/config/service/service.py +++ b/trustgraph-flow/trustgraph/config/service/service.py @@ -83,7 +83,8 @@ class Processor(AsyncProcessor): host=cassandra_host, username=cassandra_username, password=cassandra_password, - default_keyspace="config" + default_keyspace="config", + replication_factor=params.get("cassandra_replication_factor"), ) # Store resolved configuration diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index a8f52efd..5c50c207 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -61,7 +61,8 @@ class Processor(WorkspaceProcessor): host=cassandra_host, username=cassandra_username, password=cassandra_password, - default_keyspace="knowledge" + default_keyspace="knowledge", + replication_factor=params.get("cassandra_replication_factor"), ) self.cassandra_host = hosts diff --git a/trustgraph-flow/trustgraph/iam/service/service.py b/trustgraph-flow/trustgraph/iam/service/service.py index 8ce22757..b2f3976d 100644 --- a/trustgraph-flow/trustgraph/iam/service/service.py +++ b/trustgraph-flow/trustgraph/iam/service/service.py @@ -101,6 +101,7 @@ class Processor(AsyncProcessor): username=cassandra_username, password=cassandra_password, default_keyspace="iam", + replication_factor=params.get("cassandra_replication_factor"), ) self.cassandra_host = hosts diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index ee5e9c1b..4d3efbfb 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -146,7 +146,8 @@ class Processor(WorkspaceProcessor): host=cassandra_host, username=cassandra_username, password=cassandra_password, - default_keyspace="librarian" + default_keyspace="librarian", + replication_factor=params.get("cassandra_replication_factor"), ) # Store resolved configuration diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index b98ab7e5..de25a139 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -27,7 +27,8 @@ class Processor(DocumentEmbeddingsQueryService): api_key = params.get("api_key") url, api_key, _, _ = resolve_qdrant_config( - url=store_uri, api_key=api_key, + url=store_uri, + api_key=api_key, ) super(Processor, self).__init__( diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index c212fa86..08d88849 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -30,6 +30,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): url, api_key, replication_factor, shard_number = resolve_qdrant_config( url=store_uri, api_key=api_key, + replication_factor=params.get("qdrant_replication_factor"), + shard_number=params.get("qdrant_shard_number"), ) super(Processor, self).__init__( diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index ab04e42e..b6072bdc 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -44,6 +44,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): url, api_key, replication_factor, shard_number = resolve_qdrant_config( url=store_uri, api_key=api_key, + replication_factor=params.get("qdrant_replication_factor"), + shard_number=params.get("qdrant_shard_number"), ) super(Processor, self).__init__( diff --git a/trustgraph-flow/trustgraph/storage/knowledge/store.py b/trustgraph-flow/trustgraph/storage/knowledge/store.py index 162a4057..f6e12a85 100644 --- a/trustgraph-flow/trustgraph/storage/knowledge/store.py +++ b/trustgraph-flow/trustgraph/storage/knowledge/store.py @@ -27,7 +27,8 @@ class Processor(FlowProcessor): host=params.get("cassandra_host"), username=params.get("cassandra_username"), password=params.get("cassandra_password"), - default_keyspace='knowledge' + default_keyspace='knowledge', + replication_factor=params.get("cassandra_replication_factor"), ) super(Processor, self).__init__( diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index 9071dbc1..4c65edb1 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -46,6 +46,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): url, api_key, replication_factor, shard_number = resolve_qdrant_config( url=store_uri, api_key=api_key, + replication_factor=params.get("qdrant_replication_factor"), + shard_number=params.get("qdrant_shard_number"), ) super(Processor, self).__init__( diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index 12345e46..e5506723 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -50,7 +50,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): hosts, username, password, keyspace, replication_factor = resolve_cassandra_config( host=cassandra_host, username=cassandra_username, - password=cassandra_password + password=cassandra_password, + replication_factor=params.get("cassandra_replication_factor"), ) # Store resolved configuration with proper names