From a2dde9cafbdea5e6c0f48c5b6ef52c0b6b30c2b1 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 14 May 2026 16:00:54 +0100 Subject: [PATCH] Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916) Cassandra triples services were using synchronous EntityCentricKnowledgeGraph methods from async contexts, and connection state was managed with threading.local which is wrong for asyncio coroutines sharing a single thread. Qdrant services had no async wrapping at all, blocking the event loop on every network call. Rows services had unprotected shared state mutations across concurrent coroutines. - Add async methods to EntityCentricKnowledgeGraph (async_insert, async_get_s/p/o/sp/po/os/spo/all, async_collection_exists, async_create_collection, async_delete_collection) using the existing cassandra_async.async_execute bridge - Rewrite triples write + query services: replace threading.local with asyncio.Lock + dict cache for per-workspace connections, use async ECKG methods for all data operations, keep asyncio.to_thread only for one-time blocking ECKG construction - Wrap all Qdrant calls in asyncio.to_thread across all 6 services (doc/graph/row embeddings write + query), add asyncio.Lock + set cache for collection existence checks - Add asyncio.Lock to rows write + query services to protect shared state (schemas, sessions, config caches) from concurrent mutation - Update all affected tests to match new async patterns --- .../test_cassandra_config_end_to_end.py | 79 +++--- .../test_rows_cassandra_integration.py | 3 + .../test_rows_graphql_query_integration.py | 12 +- .../test_query/test_rows_cassandra_query.py | 7 +- .../test_triples_cassandra_query.py | 112 ++++----- .../test_null_embedding_protection.py | 12 + .../test_doc_embeddings_qdrant_storage.py | 4 +- .../test_row_embeddings_qdrant_storage.py | 14 +- .../test_rows_cassandra_storage.py | 3 + .../test_triples_cassandra_storage.py | 78 +++--- .../test_row_embeddings_query.py | 27 ++- 
.../trustgraph/direct/cassandra_kg.py | 226 +++++++++++++++++- .../query/doc_embeddings/qdrant/service.py | 42 +--- .../query/graph_embeddings/qdrant/service.py | 42 +--- .../query/row_embeddings/qdrant/service.py | 20 +- .../query/rows/cassandra/service.py | 60 ++--- .../query/triples/cassandra/service.py | 167 +++++-------- .../storage/doc_embeddings/qdrant/write.py | 58 +++-- .../storage/graph_embeddings/qdrant/write.py | 58 +++-- .../storage/row_embeddings/qdrant/write.py | 76 +++--- .../storage/rows/cassandra/write.py | 78 +++--- .../storage/triples/cassandra/write.py | 179 ++++---------- 22 files changed, 736 insertions(+), 621 deletions(-) diff --git a/tests/integration/test_cassandra_config_end_to_end.py b/tests/integration/test_cassandra_config_end_to_end.py index 514a5dbf..290e1348 100644 --- a/tests/integration/test_cassandra_config_end_to_end.py +++ b/tests/integration/test_cassandra_config_end_to_end.py @@ -63,26 +63,26 @@ class TestEndToEndConfigurationFlow: 'CASSANDRA_USERNAME': 'obj-user', 'CASSANDRA_PASSWORD': 'obj-pass' } - + mock_auth_instance = MagicMock() mock_auth_provider.return_value = mock_auth_instance mock_cluster_instance = MagicMock() mock_session = MagicMock() mock_cluster_instance.connect.return_value = mock_session mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) - + # Trigger Cassandra connection processor.connect_cassandra() - + # Verify auth provider was created with env vars mock_auth_provider.assert_called_once_with( username='obj-user', password='obj-pass' ) - + # Verify cluster was created with hosts from env and auth mock_cluster.assert_called_once() call_args = mock_cluster.call_args @@ -188,37 +188,34 @@ class TestConfigurationPriorityEndToEnd: ) @pytest.mark.asyncio - @patch('trustgraph.direct.cassandra_kg.Cluster') - async def test_no_config_defaults_end_to_end(self, mock_cluster): + 
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_no_config_defaults_end_to_end(self, mock_kg_class): """Test that defaults are used when no configuration provided end-to-end.""" - mock_cluster_instance = MagicMock() - mock_session = MagicMock() - mock_cluster_instance.connect.return_value = mock_session - mock_cluster.return_value = mock_cluster_instance - + from unittest.mock import AsyncMock + + mock_tg_instance = MagicMock() + mock_tg_instance.async_get_all = AsyncMock(return_value=[]) + mock_kg_class.return_value = mock_tg_instance + with patch.dict(os.environ, {}, clear=True): processor = TriplesQuery(taskgroup=MagicMock()) - + # Mock query to trigger TrustGraph creation mock_query = MagicMock() mock_query.collection = 'default_collection' mock_query.s = None mock_query.p = None mock_query.o = None + mock_query.g = None mock_query.limit = 100 - - # Mock the get_all method to return empty list - mock_tg_instance = MagicMock() - mock_tg_instance.get_all.return_value = [] - processor.tg = mock_tg_instance - + await processor.query_triples('default_user', mock_query) - + # Should use defaults - mock_cluster.assert_called_once() - call_args = mock_cluster.call_args - assert call_args.args[0] == ['cassandra'] # Default host - assert 'auth_provider' not in call_args.kwargs # No auth with default config + mock_kg_class.assert_called_once_with( + hosts=['cassandra'], + keyspace='default_user' + ) class TestNoBackwardCompatibilityEndToEnd: @@ -324,16 +321,16 @@ class TestMultipleHostsHandling: env_vars = { 'CASSANDRA_HOST': 'host1,host2,host3,host4,host5' } - + mock_cluster_instance = MagicMock() mock_session = MagicMock() mock_cluster_instance.connect.return_value = mock_session mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Verify all hosts were passed to Cluster 
mock_cluster.assert_called_once() call_args = mock_cluster.call_args @@ -392,27 +389,27 @@ class TestAuthenticationFlow: 'CASSANDRA_USERNAME': 'auth-user', 'CASSANDRA_PASSWORD': 'auth-secret' } - + mock_auth_instance = MagicMock() mock_auth_provider.return_value = mock_auth_instance mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Auth provider should be created mock_auth_provider.assert_called_once_with( username='auth-user', password='auth-secret' ) - + # Cluster should be created with auth provider call_args = mock_cluster.call_args assert 'auth_provider' in call_args.kwargs assert call_args.kwargs['auth_provider'] == mock_auth_instance - + @patch('trustgraph.storage.rows.cassandra.write.Cluster') @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster): @@ -421,21 +418,21 @@ class TestAuthenticationFlow: 'CASSANDRA_HOST': 'no-auth-host' # No username/password } - + mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Auth provider should not be created mock_auth_provider.assert_not_called() - + # Cluster should be created without auth provider call_args = mock_cluster.call_args assert 'auth_provider' not in call_args.kwargs - + @patch('trustgraph.storage.rows.cassandra.write.Cluster') @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster): @@ -446,15 +443,15 @@ class TestAuthenticationFlow: cassandra_username='partial-user' # No password ) - + mock_cluster_instance = MagicMock() mock_cluster.return_value = 
mock_cluster_instance - + processor.connect_cassandra() - + # Auth provider should not be created (needs both username AND password) mock_auth_provider.assert_not_called() - + # Cluster should be created without auth provider call_args = mock_cluster.call_args assert 'auth_provider' not in call_args.kwargs \ No newline at end of file diff --git a/tests/integration/test_rows_cassandra_integration.py b/tests/integration/test_rows_cassandra_integration.py index 1358d420..d668600c 100644 --- a/tests/integration/test_rows_cassandra_integration.py +++ b/tests/integration/test_rows_cassandra_integration.py @@ -101,6 +101,8 @@ class TestRowsCassandraIntegration: processor.session = None # Bind actual methods from the new unified table implementation + import asyncio + processor._setup_lock = asyncio.Lock() processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor) processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor) processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) @@ -108,6 +110,7 @@ class TestRowsCassandraIntegration: processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) processor.on_object = Processor.on_object.__get__(processor, Processor) processor.collection_exists = MagicMock(return_value=True) diff --git a/tests/integration/test_rows_graphql_query_integration.py b/tests/integration/test_rows_graphql_query_integration.py index 29b4464d..a455accd 100644 --- a/tests/integration/test_rows_graphql_query_integration.py +++ b/tests/integration/test_rows_graphql_query_integration.py @@ -184,7 +184,7 
@@ class TestObjectsGraphQLQueryIntegration: await processor.on_schema_config("default", sample_schema_config, version=1) # Connect to Cassandra - processor.connect_cassandra() + await processor.connect_cassandra() assert processor.session is not None # Create test keyspace and table @@ -219,7 +219,7 @@ class TestObjectsGraphQLQueryIntegration: """Test inserting data and querying via GraphQL""" # Load schema and connect await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Setup test data keyspace = "test_user" @@ -293,7 +293,7 @@ class TestObjectsGraphQLQueryIntegration: """Test GraphQL queries with filtering on indexed fields""" # Setup (reuse previous setup) await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() keyspace = "test_user" collection = "filter_test" @@ -387,7 +387,7 @@ class TestObjectsGraphQLQueryIntegration: """Test full message processing workflow""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Create mock message request = RowsQueryRequest( @@ -433,7 +433,7 @@ class TestObjectsGraphQLQueryIntegration: """Test handling multiple concurrent GraphQL queries""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Create multiple query tasks queries = [ @@ -519,7 +519,7 @@ class TestObjectsGraphQLQueryIntegration: """Test handling of large query result sets""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() keyspace = "large_test_user" collection = "large_collection" diff --git a/tests/unit/test_query/test_rows_cassandra_query.py 
b/tests/unit/test_query/test_rows_cassandra_query.py index bb6bbe84..b61500a4 100644 --- a/tests/unit/test_query/test_rows_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic: @pytest.mark.asyncio async def test_schema_config_parsing(self): """Test parsing of schema configuration""" + import asyncio processor = MagicMock() processor.schemas = {} processor.schema_builders = {} processor.graphql_schemas = {} processor.config_key = "schema" processor.query_cassandra = MagicMock() + processor._setup_lock = asyncio.Lock() + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test config @@ -335,7 +338,7 @@ class TestUnifiedTableQueries: """Test query execution with matching index""" processor = MagicMock() processor.session = MagicMock() - processor.connect_cassandra = MagicMock() + processor.connect_cassandra = AsyncMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) @@ -396,7 +399,7 @@ class TestUnifiedTableQueries: """Test query execution without matching index (scan mode)""" processor = MagicMock() processor.session = MagicMock() - processor.connect_cassandra = MagicMock() + processor.connect_cassandra = AsyncMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index 09681214..980fa904 100644 --- 
a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -2,8 +2,10 @@ Tests for Cassandra triples query service """ +import asyncio + import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, AsyncMock from trustgraph.query.triples.cassandra.service import Processor, create_term from trustgraph.schema import Term, IRI, LITERAL @@ -18,7 +20,7 @@ class TestCassandraQueryProcessor: return Processor( taskgroup=MagicMock(), id='test-cassandra-query', - graph_host='localhost' + cassandra_host='localhost' ) def test_create_term_with_http_uri(self, processor): @@ -85,7 +87,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor( taskgroup=MagicMock(), @@ -110,8 +112,8 @@ class TestCassandraQueryProcessor: keyspace='test_user' ) - # Verify get_spo was called with correct parameters - mock_tg_instance.get_spo.assert_called_once_with( + # Verify async_get_spo was called with correct parameters + mock_tg_instance.async_get_spo.assert_called_once_with( 'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100 ) @@ -130,23 +132,25 @@ class TestCassandraQueryProcessor: assert processor.cassandra_host == ['cassandra'] # Updated default assert processor.cassandra_username is None assert processor.cassandra_password is None - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_custom_params(self): """Test processor initialization with custom parameters""" taskgroup_mock = MagicMock() - + processor = Processor( taskgroup=taskgroup_mock, cassandra_host='cassandra.example.com', cassandra_username='queryuser', 
cassandra_password='querypass' ) - + assert processor.cassandra_host == ['cassandra.example.com'] assert processor.cassandra_username == 'queryuser' assert processor.cassandra_password == 'querypass' - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) @pytest.mark.asyncio @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @@ -164,7 +168,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_sp.return_value = [mock_result] + mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -178,7 +182,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) + mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) assert len(result) == 1 assert result[0].s.iri == 'test_subject' assert result[0].p.iri == 'test_predicate' @@ -200,7 +204,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_s.return_value = [mock_result] + mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -214,7 +218,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) + mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) assert len(result) == 1 assert result[0].s.iri == 'test_subject' assert result[0].p.iri == 'result_predicate' @@ -236,7 +240,7 @@ class TestCassandraQueryProcessor: 
mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_p.return_value = [mock_result] + mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -250,7 +254,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) + mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) assert len(result) == 1 assert result[0].s.iri == 'result_subject' assert result[0].p.iri == 'test_predicate' @@ -272,7 +276,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_o.return_value = [mock_result] + mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -286,7 +290,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) + mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) assert len(result) == 1 assert result[0].s.iri == 'result_subject' assert result[0].p.iri == 'result_predicate' @@ -305,11 +309,11 @@ class TestCassandraQueryProcessor: mock_result.s = 'all_subject' mock_result.p = 'all_predicate' mock_result.o = 'all_object' - mock_result.g = '' + mock_result.d = '' mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_all.return_value = [mock_result] + mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -323,7 +327,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - 
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) + mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000) assert len(result) == 1 assert result[0].s.iri == 'all_subject' assert result[0].p.iri == 'all_predicate' @@ -410,7 +414,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor( taskgroup=MagicMock(), @@ -451,7 +455,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -489,8 +493,8 @@ class TestCassandraQueryProcessor: mock_result.lang = None mock_result.p = 'p' mock_result.o = 'o' - mock_tg_instance1.get_s.return_value = [mock_result] - mock_tg_instance2.get_s.return_value = [mock_result] + mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result]) + mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -504,7 +508,6 @@ class TestCassandraQueryProcessor: ) await processor.query_triples('user1', query1) - assert processor.table == 'user1' # Second query with different table query2 = TriplesQueryRequest( @@ -516,10 +519,11 @@ class TestCassandraQueryProcessor: ) await processor.query_triples('user2', query2) - assert processor.table == 'user2' - # Verify TrustGraph was created twice + # Verify TrustGraph was created twice for different workspaces assert mock_kg_class.call_count == 2 + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1') + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2') @pytest.mark.asyncio 
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @@ -529,7 +533,7 @@ class TestCassandraQueryProcessor: mock_tg_instance = MagicMock() mock_kg_class.return_value = mock_tg_instance - mock_tg_instance.get_spo.side_effect = Exception("Query failed") + mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed")) processor = Processor(taskgroup=MagicMock()) @@ -566,7 +570,7 @@ class TestCassandraQueryProcessor: mock_result2.otype = None mock_result2.dtype = None mock_result2.lang = None - mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2] + mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2]) processor = Processor(taskgroup=MagicMock()) @@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_po.return_value = [mock_result] + mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('test_user', query) - # Verify get_po was called (should use optimized po_table) - mock_tg_instance.get_po.assert_called_once_with( + # Verify async_get_po was called (should use optimized po_table) + mock_tg_instance.async_get_po.assert_called_once_with( 'test_collection', 'test_predicate', 'test_object', g=None, limit=50 ) @@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_os.return_value = [mock_result] + mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('test_user', query) - # Verify get_os was called (should use optimized 
subject_table with clustering) - mock_tg_instance.get_os.assert_called_once_with( + # Verify async_get_os was called (should use optimized subject_table with clustering) + mock_tg_instance.async_get_os.assert_called_once_with( 'test_collection', 'test_object', 'test_subject', g=None, limit=25 ) @@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations: mock_kg_class.return_value = mock_tg_instance # Mock empty results for all queries - mock_tg_instance.get_all.return_value = [] - mock_tg_instance.get_s.return_value = [] - mock_tg_instance.get_p.return_value = [] - mock_tg_instance.get_o.return_value = [] - mock_tg_instance.get_sp.return_value = [] - mock_tg_instance.get_po.return_value = [] - mock_tg_instance.get_os.return_value = [] - mock_tg_instance.get_spo.return_value = [] + mock_tg_instance.async_get_all = AsyncMock(return_value=[]) + mock_tg_instance.async_get_s = AsyncMock(return_value=[]) + mock_tg_instance.async_get_p = AsyncMock(return_value=[]) + mock_tg_instance.async_get_o = AsyncMock(return_value=[]) + mock_tg_instance.async_get_sp = AsyncMock(return_value=[]) + mock_tg_instance.async_get_po = AsyncMock(return_value=[]) + mock_tg_instance.async_get_os = AsyncMock(return_value=[]) + mock_tg_instance.async_get_spo = AsyncMock(return_value=[]) processor = Processor(taskgroup=MagicMock()) # Test each query pattern test_patterns = [ # (s, p, o, expected_method) - (None, None, None, 'get_all'), # All triples - ('s1', None, None, 'get_s'), # Subject only - (None, 'p1', None, 'get_p'), # Predicate only - (None, None, 'o1', 'get_o'), # Object only - ('s1', 'p1', None, 'get_sp'), # Subject + Predicate - (None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION) - ('s1', None, 'o1', 'get_os'), # Object + Subject - ('s1', 'p1', 'o1', 'get_spo'), # All three + (None, None, None, 'async_get_all'), # All triples + ('s1', None, None, 'async_get_s'), # Subject only + (None, 'p1', None, 'async_get_p'), # Predicate only + (None, None, 'o1', 
'async_get_o'), # Object only + ('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate + (None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION) + ('s1', None, 'o1', 'async_get_os'), # Object + Subject + ('s1', 'p1', 'o1', 'async_get_spo'), # All three ] for s, p, o, expected_method in test_patterns: @@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.lang = None mock_results.append(mock_result) - mock_tg_instance.get_po.return_value = mock_results + mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results) processor = Processor(taskgroup=MagicMock()) @@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('large_dataset_user', query) - # Verify optimized get_po was used (no ALLOW FILTERING needed!) - mock_tg_instance.get_po.assert_called_once_with( + # Verify optimized async_get_po was used (no ALLOW FILTERING needed!) + mock_tg_instance.async_get_po.assert_called_once_with( 'massive_collection', 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type', 'http://example.com/Person', diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py index 2296e961..dbe06b40 100644 --- a/tests/unit/test_reliability/test_null_embedding_protection.py +++ b/tests/unit/test_reliability/test_null_embedding_protection.py @@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection: @pytest.mark.asyncio async def test_valid_embedding_upserted(self): + import asyncio from trustgraph.storage.doc_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "col1" @@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection: 
@pytest.mark.asyncio async def test_dimension_in_collection_name(self): """Collection name should include vector dimension.""" + import asyncio from trustgraph.storage.doc_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "docs" @@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection: @pytest.mark.asyncio async def test_valid_entity_and_vector_upserted(self): + import asyncio from trustgraph.storage.graph_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "col1" @@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection: @pytest.mark.asyncio async def test_lazy_collection_creation_on_new_dimension(self): + import asyncio from trustgraph.storage.graph_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = False proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "graphs" diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index ce6e6b3d..360ac3dc 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Assert expected_collection = 'd_cache_user_cache_collection_3' # 3 
dimensions - # Verify collection existence is checked on each write - mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) + # Second write uses cached collection state — no collection_exists check + mock_qdrant_instance.collection_exists.assert_not_called() # But upsert should still be called mock_qdrant_instance.upsert.assert_called_once() diff --git a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py index 8754f47c..44fdf516 100644 --- a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py @@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) - processor.ensure_collection("test_collection", 384) + await processor.ensure_collection("test_collection", 384) mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection") mock_qdrant_instance.create_collection.assert_called_once() # Verify the collection is cached - assert "test_collection" in processor.created_collections + assert "test_collection" in processor._known_collections @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_ensure_collection_skips_existing(self, mock_qdrant_client): @@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) - processor.ensure_collection("existing_collection", 384) + await processor.ensure_collection("existing_collection", 384) mock_qdrant_instance.collection_exists.assert_called_once() mock_qdrant_instance.create_collection.assert_not_called() @@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.created_collections.add("cached_collection") + processor._known_collections.add("cached_collection") - processor.ensure_collection("cached_collection", 384) + 
await processor.ensure_collection("cached_collection", 384) # Should not check or create - just return mock_qdrant_instance.collection_exists.assert_not_called() @@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.created_collections.add('rows_test_workspace_test_collection_schema1_384') + processor._known_collections.add('rows_test_workspace_test_collection_schema1_384') await processor.delete_collection('test_workspace', 'test_collection') @@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): assert mock_qdrant_instance.delete_collection.call_count == 2 # Verify the cached collection was removed - assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections + assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_delete_collection_schema(self, mock_qdrant_client): diff --git a/tests/unit/test_storage/test_rows_cassandra_storage.py b/tests/unit/test_storage/test_rows_cassandra_storage.py index 852f01a1..3e5664ea 100644 --- a/tests/unit/test_storage/test_rows_cassandra_storage.py +++ b/tests/unit/test_storage/test_rows_cassandra_storage.py @@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic: @pytest.mark.asyncio async def test_schema_config_parsing(self): """Test parsing of schema configurations""" + import asyncio processor = MagicMock() processor.schemas = {} processor.config_key = "schema" processor.registered_partitions = set() + processor._setup_lock = asyncio.Lock() + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test configuration diff --git a/tests/unit/test_storage/test_triples_cassandra_storage.py 
b/tests/unit/test_storage/test_triples_cassandra_storage.py index 04acbb16..394f0e54 100644 --- a/tests/unit/test_storage/test_triples_cassandra_storage.py +++ b/tests/unit/test_storage/test_triples_cassandra_storage.py @@ -2,6 +2,8 @@ Tests for Cassandra triples storage service """ +import asyncio + import pytest from unittest.mock import MagicMock, patch, AsyncMock @@ -24,12 +26,13 @@ class TestCassandraStorageProcessor: assert processor.cassandra_host == ['cassandra'] # Updated default assert processor.cassandra_username is None assert processor.cassandra_password is None - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_custom_params(self): """Test processor initialization with custom parameters (new cassandra_* names)""" taskgroup_mock = MagicMock() - + processor = Processor( taskgroup=taskgroup_mock, id='custom-storage', @@ -37,11 +40,12 @@ class TestCassandraStorageProcessor: cassandra_username='testuser', cassandra_password='testpass' ) - + assert processor.cassandra_host == ['cassandra.example.com'] assert processor.cassandra_username == 'testuser' assert processor.cassandra_password == 'testpass' - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_partial_auth(self): """Test processor initialization with only username (no password)""" @@ -92,6 +96,7 @@ class TestCassandraStorageProcessor: """Test table switching logic when authentication is provided""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor( @@ -114,7 +119,6 @@ class TestCassandraStorageProcessor: username='testuser', password='testpass' ) - assert processor.table == 'user1' @pytest.mark.asyncio 
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -122,6 +126,7 @@ class TestCassandraStorageProcessor: """Test table switching logic when no authentication is provided""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -138,7 +143,6 @@ class TestCassandraStorageProcessor: hosts=['cassandra'], # Updated default keyspace='user2' ) - assert processor.table == 'user2' @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -146,6 +150,7 @@ class TestCassandraStorageProcessor: """Test that TrustGraph is not recreated when table hasn't changed""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -169,6 +174,7 @@ class TestCassandraStorageProcessor: """Test that triples are properly inserted into Cassandra""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -208,12 +214,12 @@ class TestCassandraStorageProcessor: await processor.store_triples('user1', mock_message) # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters) - assert mock_tg_instance.insert.call_count == 2 - mock_tg_instance.insert.assert_any_call( + assert mock_tg_instance.async_insert.call_count == 2 + mock_tg_instance.async_insert.assert_any_call( 'collection1', 'subject1', 'predicate1', 'object1', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) - mock_tg_instance.insert.assert_any_call( + mock_tg_instance.async_insert.assert_any_call( 'collection1', 'subject2', 'predicate2', 'object2', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) @@ -224,6 
+230,7 @@ class TestCassandraStorageProcessor: """Test behavior when message has no triples""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -236,19 +243,17 @@ class TestCassandraStorageProcessor: await processor.store_triples('user1', mock_message) # Verify no triples were inserted - mock_tg_instance.insert.assert_not_called() + mock_tg_instance.async_insert.assert_not_called() @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') - @patch('trustgraph.storage.triples.cassandra.write.time.sleep') - async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class): + async def test_exception_handling_on_connection_failure(self, mock_kg_class): """Test exception handling during TrustGraph creation""" taskgroup_mock = MagicMock() mock_kg_class.side_effect = Exception("Connection failed") processor = Processor(taskgroup=taskgroup_mock) - # Create mock message mock_message = MagicMock() mock_message.metadata.collection = 'collection1' mock_message.triples = [] @@ -256,9 +261,6 @@ class TestCassandraStorageProcessor: with pytest.raises(Exception, match="Connection failed"): await processor.store_triples('user1', mock_message) - # Verify sleep was called before re-raising - mock_sleep.assert_called_once_with(1) - def test_add_args_method(self): """Test that add_args properly configures argument parser""" from argparse import ArgumentParser @@ -359,8 +361,6 @@ class TestCassandraStorageProcessor: mock_message1.triples = [] await processor.store_triples('user1', mock_message1) - assert processor.table == 'user1' - assert processor.tg == mock_tg_instance1 # Second message with different table mock_message2 = MagicMock() @@ -368,11 +368,11 @@ class TestCassandraStorageProcessor: mock_message2.triples = [] await processor.store_triples('user2', mock_message2) - 
assert processor.table == 'user2' - assert processor.tg == mock_tg_instance2 - # Verify TrustGraph was created twice for different tables + # Verify TrustGraph was created twice for different workspaces assert mock_kg_class.call_count == 2 + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1') + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2') @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -380,6 +380,7 @@ class TestCassandraStorageProcessor: """Test storing triples with special characters and unicode""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -405,7 +406,7 @@ class TestCassandraStorageProcessor: await processor.store_triples('test_workspace', mock_message) # Verify the triple was inserted with special characters preserved - mock_tg_instance.insert.assert_called_once_with( + mock_tg_instance.async_insert.assert_called_once_with( 'test_collection', 'subject with spaces & symbols', 'predicate:with/colons', @@ -418,29 +419,29 @@ class TestCassandraStorageProcessor: @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') - async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class): - """Test that table remains unchanged when TrustGraph creation fails""" + async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class): + """Test that a failed connection doesn't leave stale cached state""" taskgroup_mock = MagicMock() + mock_good_instance = MagicMock() processor = Processor(taskgroup=taskgroup_mock) - # Set an initial table - processor.table = ('old_user', 'old_collection') - - # Mock TrustGraph to raise exception - mock_kg_class.side_effect = Exception("Connection failed") - mock_message = MagicMock() - 
mock_message.metadata.collection = 'new_collection' + mock_message.metadata.collection = 'collection1' mock_message.triples = [] + # First call fails + mock_kg_class.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): - await processor.store_triples('new_user', mock_message) + await processor.store_triples('user1', mock_message) - # Table should remain unchanged since self.table = table happens after try/except - assert processor.table == ('old_user', 'old_collection') - # TrustGraph should be set to None though - assert processor.tg is None + # Second call succeeds — should retry connection, not use stale state + mock_kg_class.side_effect = None + mock_kg_class.return_value = mock_good_instance + await processor.store_triples('user1', mock_message) + + # Connection was attempted twice (failed + succeeded) + assert mock_kg_class.call_count == 2 class TestCassandraPerformanceOptimizations: @@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations: """Test that legacy mode still works with single table""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}): @@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations: """Test that optimized mode uses multi-table schema""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}): @@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations: """Test that all tables stay consistent during batch writes""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -517,7 +521,7 @@ class 
TestCassandraPerformanceOptimizations: await processor.store_triples('user1', mock_message) # Verify insert was called for the triple (implementation details tested in KnowledgeGraph) - mock_tg_instance.insert.assert_called_once_with( + mock_tg_instance.async_insert.assert_called_once_with( 'collection1', 'test_subject', 'test_predicate', 'test_object', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) diff --git a/tests/unit/test_structured_data/test_row_embeddings_query.py b/tests/unit/test_structured_data/test_row_embeddings_query.py index 51cf834f..f1297e1c 100644 --- a/tests/unit/test_structured_data/test_row_embeddings_query.py +++ b/tests/unit/test_structured_data/test_row_embeddings_query.py @@ -89,7 +89,8 @@ class TestSanitizeName: class TestFindCollection: - def test_finds_matching_collection(self): + @pytest.mark.asyncio + async def test_finds_matching_collection(self): proc = _make_processor() mock_coll = MagicMock() mock_coll.name = "rows_test_workspace_test_col_customers_384" @@ -98,11 +99,12 @@ class TestFindCollection: mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-workspace", "test-col", "customers") + result = await proc.find_collection("test-workspace", "test-col", "customers") assert result == "rows_test_workspace_test_col_customers_384" - def test_returns_none_when_no_match(self): + @pytest.mark.asyncio + async def test_returns_none_when_no_match(self): proc = _make_processor() mock_coll = MagicMock() mock_coll.name = "rows_other_workspace_other_col_schema_768" @@ -111,14 +113,15 @@ class TestFindCollection: mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-workspace", "test-col", "customers") + result = await proc.find_collection("test-workspace", "test-col", "customers") assert result is None - def test_returns_none_on_error(self): + @pytest.mark.asyncio + async 
def test_returns_none_on_error(self): proc = _make_processor() proc.qdrant.get_collections.side_effect = Exception("connection error") - result = proc.find_collection("workspace", "col", "schema") + result = await proc.find_collection("workspace", "col", "schema") assert result is None @@ -139,7 +142,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_no_collection_returns_empty(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value=None) + proc.find_collection = AsyncMock(return_value=None) request = _make_request() result = await proc.query_row_embeddings("test-workspace", request) @@ -148,7 +151,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_successful_query_returns_matches(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") points = [ _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95), @@ -172,7 +175,7 @@ class TestQueryRowEmbeddings: async def test_index_name_filter_applied(self): """When index_name is specified, a Qdrant filter should be used.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] @@ -188,7 +191,7 @@ class TestQueryRowEmbeddings: async def test_no_index_name_no_filter(self): """When index_name is empty, no filter should be applied.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] @@ -204,7 +207,7 @@ class TestQueryRowEmbeddings: async def test_missing_payload_fields_default(self): """Points with missing payload fields should use defaults.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + 
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") point = MagicMock() point.payload = {} # Empty payload @@ -225,7 +228,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_qdrant_error_propagates(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") proc.qdrant.query_points.side_effect = Exception("qdrant down") request = _make_request() diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py index 59d2a2a1..d7abd1a9 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -1,10 +1,14 @@ +import datetime +import os +import logging + from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider from cassandra.query import BatchStatement, SimpleStatement from ssl import SSLContext, PROTOCOL_TLSv1_2 -import os -import logging + +from ..tables.cassandra_async import async_execute # Global list to track clusters for cleanup _active_clusters = [] @@ -461,7 +465,6 @@ class KnowledgeGraph: def create_collection(self, collection): """Create collection by inserting metadata row""" try: - import datetime self.session.execute( f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", (collection, datetime.datetime.now()) @@ -954,7 +957,6 @@ class EntityCentricKnowledgeGraph: def create_collection(self, collection): """Create collection by inserting metadata row""" try: - import datetime self.session.execute( f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", (collection, datetime.datetime.now()) @@ -1045,6 +1047,222 @@ class EntityCentricKnowledgeGraph: logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads") + # 
======================================================================== + # Async methods — use cassandra driver's native async API via async_execute + # ======================================================================== + + async def async_insert(self, collection, s, p, o, g=None, otype=None, dtype="", lang=""): + if g is None: + g = DEFAULT_GRAPH + if otype is None: + if o.startswith("http://") or o.startswith("https://"): + otype = "u" + else: + otype = "l" + + batch = BatchStatement() + batch.add(self.insert_entity_stmt, (collection, s, 'S', p, otype, s, o, g, dtype, lang)) + batch.add(self.insert_entity_stmt, (collection, p, 'P', p, otype, s, o, g, dtype, lang)) + if otype == 'u' or otype == 't': + batch.add(self.insert_entity_stmt, (collection, o, 'O', p, otype, s, o, g, dtype, lang)) + if g != DEFAULT_GRAPH: + batch.add(self.insert_entity_stmt, (collection, g, 'G', p, otype, s, o, g, dtype, lang)) + batch.add(self.insert_collection_stmt, (collection, g, s, p, o, otype, dtype, lang)) + + await async_execute(self.session, batch) + + async def async_get_all(self, collection, limit=50): + return await async_execute( + self.session, self.get_collection_all_stmt, (collection, limit) + ) + + async def async_get_s(self, collection, s, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_stmt, (collection, s, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_p(self, collection, p, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_p_stmt, (collection, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, 
+ otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_o(self, collection, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_o_stmt, (collection, o, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_sp(self, collection, s, p, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_po(self, collection, p, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_o_p_stmt, (collection, o, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_os(self, collection, o, s, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_stmt, (collection, s, limit) + ) + results = [] + for row in rows: + if row.o != o: + continue + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=row.p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_spo(self, collection, s, p, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit) + ) + 
results = [] + for row in rows: + if row.o != o: + continue + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_g(self, collection, g, limit=50): + if g is None: + g = DEFAULT_GRAPH + return await async_execute( + self.session, self.get_collection_by_graph_stmt, (collection, g, limit) + ) + + async def async_collection_exists(self, collection): + try: + result = await async_execute( + self.session, + f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1", + (collection,) + ) + return bool(result) + except Exception as e: + logger.error(f"Error checking collection existence: {e}") + return False + + async def async_create_collection(self, collection): + await async_execute( + self.session, + f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", + (collection, datetime.datetime.now()) + ) + logger.info(f"Created collection metadata for {collection}") + + async def async_delete_collection(self, collection): + rows = await async_execute( + self.session, + f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s", + (collection,) + ) + + entities = set() + quads = [] + for row in rows: + d, s, p, o = row.d, row.s, row.p, row.o + otype = row.otype + dtype = row.dtype if hasattr(row, 'dtype') else '' + lang = row.lang if hasattr(row, 'lang') else '' + quads.append((d, s, p, o, otype, dtype, lang)) + entities.add(s) + entities.add(p) + if otype == 'u' or otype == 't': + entities.add(o) + if d != DEFAULT_GRAPH: + entities.add(d) + + batch = BatchStatement() + count = 0 + for entity in entities: + batch.add(self.delete_entity_partition_stmt, (collection, entity)) + count += 1 + if count % 50 == 0: + await async_execute(self.session, batch) + batch = BatchStatement() + if count % 50 != 0: 
+ await async_execute(self.session, batch) + + batch = BatchStatement() + count = 0 + for d, s, p, o, otype, dtype, lang in quads: + batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang)) + count += 1 + if count % 50 == 0: + await async_execute(self.session, batch) + batch = BatchStatement() + if count % 50 != 0: + await async_execute(self.session, batch) + + await async_execute( + self.session, + f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s", + (collection,) + ) + logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads") + def close(self): """Close connections""" if hasattr(self, 'session') and self.session: diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 1d59c835..f6770744 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -4,11 +4,10 @@ Document embeddings query service. Input is vector, output is an array of chunk_ids """ +import asyncio import logging from qdrant_client import QdrantClient -from qdrant_client.models import PointStruct -from qdrant_client.models import Distance, VectorParams from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... 
schema import Error @@ -38,32 +37,6 @@ class Processor(DocumentEmbeddingsQueryService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - self.last_collection = None - - def ensure_collection_exists(self, collection, dim): - """Ensure collection exists, create if it doesn't""" - if collection != self.last_collection: - if not self.qdrant.collection_exists(collection): - try: - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.COSINE - ), - ) - logger.info(f"Created collection: {collection}") - except Exception as e: - logger.error(f"Qdrant collection creation failed: {e}") - raise e - self.last_collection = collection - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) async def query_document_embeddings(self, workspace, msg): @@ -73,21 +46,24 @@ class Processor(DocumentEmbeddingsQueryService): if not vec: return [] - # Use dimension suffix in collection name dim = len(vec) collection = f"d_{workspace}_{msg.collection}_{dim}" - # Check if collection exists - return empty if not - if not self.collection_exists(collection): + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection + ) + if not exists: logger.info(f"Collection {collection} does not exist, returning empty results") return [] - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=collection, query=vec, limit=msg.limit, with_payload=True, - ).points + ) + search_result = result.points chunks = [] for r in search_result: diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index 
b8fb1361..167130c9 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -4,11 +4,10 @@ Graph embeddings query service. Input is vector, output is list of entities """ +import asyncio import logging from qdrant_client import QdrantClient -from qdrant_client.models import PointStruct -from qdrant_client.models import Distance, VectorParams from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL @@ -38,32 +37,6 @@ class Processor(GraphEmbeddingsQueryService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - self.last_collection = None - - def ensure_collection_exists(self, collection, dim): - """Ensure collection exists, create if it doesn't""" - if collection != self.last_collection: - if not self.qdrant.collection_exists(collection): - try: - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.COSINE - ), - ) - logger.info(f"Created collection: {collection}") - except Exception as e: - logger.error(f"Qdrant collection creation failed: {e}") - raise e - self.last_collection = collection - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): @@ -79,23 +52,26 @@ class Processor(GraphEmbeddingsQueryService): if not vec: return [] - # Use dimension suffix in collection name dim = len(vec) collection = f"t_{workspace}_{msg.collection}_{dim}" - # Check if collection exists - return empty if not - if not self.collection_exists(collection): + exists = await asyncio.to_thread( + 
self.qdrant.collection_exists, collection + ) + if not exists: logger.info(f"Collection {collection} does not exist") return [] # Heuristic hack, get (2*limit), so that we have more chance # of getting (limit) unique entities - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=collection, query=vec, limit=msg.limit * 2, with_payload=True, - ).points + ) + search_result = result.points entity_set = set() entities = [] diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index dd89a8d8..1534c044 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -6,6 +6,7 @@ Output is matching row index information (index_name, index_value) for use in subsequent Cassandra lookups. """ +import asyncio import logging import re from typing import Optional @@ -70,7 +71,7 @@ class Processor(FlowProcessor): safe_name = 'r_' + safe_name return safe_name.lower() - def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: + async def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: """Find the Qdrant collection for a given workspace/collection/schema""" prefix = ( f"rows_{self.sanitize_name(workspace)}_" @@ -78,14 +79,15 @@ class Processor(FlowProcessor): ) try: - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching = [ coll.name for coll in all_collections if coll.name.startswith(prefix) ] if matching: - # Return first match (there should typically be only one per dimension) return matching[0] except Exception as e: @@ -100,8 +102,7 @@ class Processor(FlowProcessor): if not vec: return [] - # Find the collection for this 
workspace/collection/schema - qdrant_collection = self.find_collection( + qdrant_collection = await self.find_collection( workspace, request.collection, request.schema_name ) @@ -113,7 +114,6 @@ class Processor(FlowProcessor): return [] try: - # Build optional filter for index_name query_filter = None if request.index_name: query_filter = Filter( @@ -125,16 +125,16 @@ class Processor(FlowProcessor): ] ) - # Query Qdrant - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=qdrant_collection, query=vec, limit=request.limit, with_payload=True, query_filter=query_filter, - ).points + ) + search_result = result.points - # Convert to RowIndexMatch objects matches = [] for point in search_result: payload = point.payload or {} diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 73cfcd83..7157daae 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -11,6 +11,7 @@ Queries against the unified 'rows' table with schema: - source: text """ +import asyncio import json import logging import re @@ -97,34 +98,38 @@ class Processor(FlowProcessor): # Cassandra session self.cluster = None self.session = None + self._setup_lock = asyncio.Lock() # Known keyspaces self.known_keyspaces: Set[str] = set() - def connect_cassandra(self): + async def connect_cassandra(self): """Connect to Cassandra cluster""" - if self.session: - return + async with self._setup_lock: + if self.session: + return - try: - if self.cassandra_username and self.cassandra_password: - auth_provider = PlainTextAuthProvider( - username=self.cassandra_username, - password=self.cassandra_password - ) - self.cluster = Cluster( - contact_points=self.cassandra_host, - auth_provider=auth_provider - ) - else: - self.cluster = Cluster(contact_points=self.cassandra_host) + try: + if 
self.cassandra_username and self.cassandra_password: + auth_provider = PlainTextAuthProvider( + username=self.cassandra_username, + password=self.cassandra_password + ) + cluster = Cluster( + contact_points=self.cassandra_host, + auth_provider=auth_provider + ) + else: + cluster = Cluster(contact_points=self.cassandra_host) - self.session = self.cluster.connect() - logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") + session = await asyncio.to_thread(cluster.connect) + self.cluster = cluster + self.session = session + logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") - except Exception as e: - logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) - raise + except Exception as e: + logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) + raise def sanitize_name(self, name: str) -> str: """Sanitize names for Cassandra compatibility""" @@ -140,14 +145,17 @@ class Processor(FlowProcessor): f"for workspace {workspace}" ) - # Replace existing schemas for this workspace + async with self._setup_lock: + await self._apply_schema_config(workspace, config) + + async def _apply_schema_config(self, workspace, config): + ws_schemas: Dict[str, RowSchema] = {} self.schemas[workspace] = ws_schemas builder = GraphQLSchemaBuilder() self.schema_builders[workspace] = builder - # Check if our config type exists if self.config_key not in config: logger.warning( f"No '{self.config_key}' type in configuration " @@ -156,16 +164,12 @@ class Processor(FlowProcessor): self.graphql_schemas[workspace] = None return - # Get the schemas dictionary for our type schemas_config = config[self.config_key] - # Process each schema in the schemas config for schema_name, schema_json in schemas_config.items(): try: - # Parse the JSON schema definition schema_def = json.loads(schema_json) - # Create Field objects fields = [] for field_def in schema_def.get("fields", []): field = SchemaField( @@ -180,7 +184,6 @@ class 
Processor(FlowProcessor): ) fields.append(field) - # Create RowSchema row_schema = RowSchema( name=schema_def.get("name", schema_name), description=schema_def.get("description", ""), @@ -202,7 +205,6 @@ class Processor(FlowProcessor): f"{len(ws_schemas)} schemas" ) - # Regenerate GraphQL schema for this workspace self.graphql_schemas[workspace] = builder.build(self.query_cassandra) def get_index_names(self, schema: RowSchema) -> List[str]: @@ -254,7 +256,7 @@ class Processor(FlowProcessor): For other queries, we need to scan and post-filter. """ # Connect if needed - self.connect_cassandra() + await self.connect_cassandra() safe_keyspace = self.sanitize_name(workspace) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index a9bdbbac..1fadaab3 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -6,8 +6,8 @@ null. Output is a list of quads. import asyncio import logging - import json + from cassandra.query import SimpleStatement from .... direct.cassandra_kg import ( @@ -17,6 +17,7 @@ from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK from .... base import TriplesQueryService from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config +from .... 
tables.cassandra_async import async_execute # Module logger logger = logging.getLogger(__name__) @@ -176,45 +177,42 @@ class Processor(TriplesQueryService): self.cassandra_host = hosts self.cassandra_username = username self.cassandra_password = password - self.table = None - def ensure_connection(self, workspace): - """Ensure we have a connection to the correct keyspace.""" - if workspace != self.table: - KGClass = EntityCentricKnowledgeGraph + self._connections = {} + self._conn_lock = asyncio.Lock() - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - self.table = workspace + async def _get_connection(self, workspace): + async with self._conn_lock: + if workspace not in self._connections: + if self.cassandra_username and self.cassandra_password: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + username=self.cassandra_username, + password=self.cassandra_password, + ) + else: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + ) + self._connections[workspace] = tg + return self._connections[workspace] async def query_triples(self, workspace, query): try: - # ensure_connection may construct a fresh - # EntityCentricKnowledgeGraph which does sync schema - # setup against Cassandra. Push it to a worker thread - # so the event loop doesn't block on first-use per workspace. 
- await asyncio.to_thread(self.ensure_connection, workspace) - - # Extract values from query s_val = get_term_value(query.s) p_val = get_term_value(query.p) o_val = get_term_value(query.o) - g_val = query.g # Already a string or None + g_val = query.g + + tg = await self._get_connection(workspace) def get_object_metadata(row): - """Extract term type metadata from result row""" return ( getattr(row, 'otype', None), getattr(row, 'dtype', None), @@ -223,33 +221,21 @@ class Processor(TriplesQueryService): quads = [] - # All self.tg.get_* calls below are sync wrappers around - # cassandra session.execute. Materialise inside a worker - # thread so iteration never triggers sync paging back on - # the event loop. - - # Route to appropriate query method based on which fields are specified if s_val is not None: if p_val is not None: if o_val is not None: - # SPO specified - find matching graphs - resp = await asyncio.to_thread( - lambda: list(self.tg.get_spo( - query.collection, s_val, p_val, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_spo( + query.collection, s_val, p_val, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((s_val, p_val, o_val, g, term_type, datatype, language)) else: - # SP specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_sp( - query.collection, s_val, p_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_sp( + query.collection, s_val, p_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -257,24 +243,18 @@ class Processor(TriplesQueryService): quads.append((s_val, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: - # SO specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_os( - query.collection, o_val, s_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_os( + 
query.collection, o_val, s_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((s_val, t.p, o_val, g, term_type, datatype, language)) else: - # S only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_s( - query.collection, s_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_s( + query.collection, s_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -283,24 +263,18 @@ class Processor(TriplesQueryService): else: if p_val is not None: if o_val is not None: - # PO specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_po( - query.collection, p_val, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_po( + query.collection, p_val, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((t.s, p_val, o_val, g, term_type, datatype, language)) else: - # P only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_p( - query.collection, p_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_p( + query.collection, p_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -308,40 +282,26 @@ class Processor(TriplesQueryService): quads.append((t.s, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: - # O only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_o( - query.collection, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_o( + query.collection, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((t.s, t.p, o_val, g, term_type, datatype, language)) else: - # Nothing specified - get all - resp = await 
asyncio.to_thread( - lambda: list(self.tg.get_all( - query.collection, limit=query.limit, - )) + resp = await tg.async_get_all( + query.collection, limit=query.limit, ) for t in resp: - # Note: quads_by_collection uses 'd' for graph field g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH - # Filter by graph - # g_val=None means all graphs (no filter) - # g_val="" means default graph only - # otherwise filter to specific named graph if g_val is not None: if g != g_val: continue term_type, datatype, language = get_object_metadata(t) quads.append((t.s, t.p, t.o, g, term_type, datatype, language)) - # Convert to Triple objects (with g field) - # s and p are always IRIs in RDF - # Object uses term_type/datatype/language metadata from database triples = [ Triple( s=create_term(q[0], term_type='u'), @@ -365,51 +325,36 @@ class Processor(TriplesQueryService): Uses Cassandra's paging to fetch results incrementally. """ try: - await asyncio.to_thread(self.ensure_connection, workspace) batch_size = query.batch_size if query.batch_size > 0 else 20 limit = query.limit if query.limit > 0 else 10000 - # Extract query pattern s_val = get_term_value(query.s) p_val = get_term_value(query.p) o_val = get_term_value(query.o) g_val = query.g def get_object_metadata(row): - """Extract term type metadata from result row""" return ( getattr(row, 'otype', None), getattr(row, 'dtype', None), getattr(row, 'lang', None), ) - # For streaming, we need to execute with fetch_size - # Use the collection table for get_all queries (most common streaming case) - - # Determine which query to use based on pattern if s_val is None and p_val is None and o_val is None: - # Get all - use collection table with paging - cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s" + + tg = await self._get_connection(workspace) + + cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {tg.collection_table} WHERE collection = %s" params = [query.collection] + statement = 
SimpleStatement(cql, fetch_size=batch_size) + result_set = await async_execute(tg.session, statement, params) + else: - # For specific patterns, fall back to non-streaming - # (these typically return small result sets anyway) async for batch, is_final in self._fallback_stream(workspace, query, batch_size): yield batch, is_final return - # Materialise in a worker thread. We lose true streaming - # paging (the driver fetches all pages eagerly inside the - # thread) but the event loop stays responsive, and result - # sets at this layer are typically small enough that this - # is acceptable. If true async paging is needed later, - # revisit using ResponseFuture page callbacks. - statement = SimpleStatement(cql, fetch_size=batch_size) - result_set = await asyncio.to_thread( - lambda: list(self.tg.session.execute(statement, params)) - ) - batch = [] count = 0 diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index fb7166b5..2bfef99c 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -3,11 +3,13 @@ Accepts entity/vector pairs and writes them to a Qdrant store. """ +import asyncio +import uuid +import logging + from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -import uuid -import logging from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler from .... 
base import AsyncProcessor, Consumer, Producer @@ -35,13 +37,35 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def ensure_collection(self, collection_name, dim): + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name + ) + if not exists: + logger.info( + f"Lazily creating Qdrant collection {collection_name} " + f"with dimension {dim}" + ) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) + async def store_document_embeddings(self, workspace, message): - # Validate collection exists in config before processing if not self.collection_exists(workspace, message.metadata.collection): logger.warning( f"Collection {message.metadata.collection} for workspace {workspace} " @@ -60,24 +84,15 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if not vec: continue - # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( f"d_{workspace}_{message.metadata.collection}_{dim}" ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) + await self.ensure_collection(collection, dim) - self.qdrant.upsert( + await asyncio.to_thread( + 
self.qdrant.upsert, collection_name=collection, points=[ PointStruct( @@ -87,7 +102,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): "chunk_id": chunk_id, } ) - ] + ], ) @staticmethod @@ -124,8 +139,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): try: prefix = f"d_{workspace}_{collection}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -135,7 +151,11 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): logger.info(f"No collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 391c2a04..13dcdba8 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -3,11 +3,13 @@ Accepts entity/vector pairs and writes them to a Qdrant store. """ +import asyncio +import uuid +import logging + from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -import uuid -import logging from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... 
base import AsyncProcessor, Consumer, Producer @@ -50,13 +52,35 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def ensure_collection(self, collection_name, dim): + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name + ) + if not exists: + logger.info( + f"Lazily creating Qdrant collection {collection_name} " + f"with dimension {dim}" + ) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) + async def store_graph_embeddings(self, workspace, message): - # Validate collection exists in config before processing if not self.collection_exists(workspace, message.metadata.collection): logger.warning( f"Collection {message.metadata.collection} for workspace {workspace} " @@ -75,22 +99,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if not vec: continue - # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( f"t_{workspace}_{message.metadata.collection}_{dim}" ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) + await self.ensure_collection(collection, dim) payload = { "entity": entity_value, @@ -98,7 +112,8 @@ class 
Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if entity.chunk_id: payload["chunk_id"] = entity.chunk_id - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=collection, points=[ PointStruct( @@ -106,7 +121,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): vector=vec, payload=payload, ) - ] + ], ) @staticmethod @@ -143,8 +158,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): try: prefix = f"t_{workspace}_{collection}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -154,7 +170,11 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): logger.info(f"No collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index 32d87871..a01629c5 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -16,10 +16,10 @@ Payload structure: - text: The text that was embedded (for debugging/display) """ +import asyncio import logging import re import uuid -from typing import Set, Tuple from qdrant_client import QdrantClient from qdrant_client.models import PointStruct, 
Distance, VectorParams @@ -63,11 +63,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Register config handler for collection management self.register_config_handler(self.on_collection_config, types=["collection"]) - # Cache of created Qdrant collections - self.created_collections: Set[str] = set() - - # Qdrant client self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() def sanitize_name(self, name: str) -> str: """Sanitize names for Qdrant collection naming""" @@ -85,25 +83,28 @@ class Processor(CollectionConfigHandler, FlowProcessor): safe_schema = self.sanitize_name(schema_name) return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}" - def ensure_collection(self, collection_name: str, dimension: int): + async def ensure_collection(self, collection_name: str, dimension: int): """Create Qdrant collection if it doesn't exist""" - if collection_name in self.created_collections: - return - - if not self.qdrant.collection_exists(collection_name): - logger.info( - f"Creating Qdrant collection {collection_name} " - f"with dimension {dimension}" + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name ) - self.qdrant.create_collection( - collection_name=collection_name, - vectors_config=VectorParams( - size=dimension, - distance=Distance.COSINE + if not exists: + logger.info( + f"Creating Qdrant collection {collection_name} " + f"with dimension {dimension}" ) - ) - - self.created_collections.add(collection_name) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dimension, + distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) async def on_embeddings(self, msg, consumer, flow): """Process incoming RowEmbeddings and write to Qdrant""" 
@@ -143,15 +144,14 @@ class Processor(CollectionConfigHandler, FlowProcessor): dimension = len(vector) - # Create/get collection name (lazily on first vector) if qdrant_collection is None: qdrant_collection = self.get_collection_name( workspace, collection, schema_name, dimension ) - self.ensure_collection(qdrant_collection, dimension) + await self.ensure_collection(qdrant_collection, dimension) - # Write to Qdrant - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=qdrant_collection, points=[ PointStruct( @@ -163,7 +163,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): "text": row_emb.text } ) - ] + ], ) embeddings_written += 1 @@ -181,8 +181,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): try: prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -192,8 +193,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"No Qdrant collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) - self.created_collections.discard(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info( f"Deleted {len(matching_collections)} collection(s) " @@ -217,8 +221,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" ) - # Get all collections and filter for matches - all_collections = 
self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -228,8 +233,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"No Qdrant collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) - self.created_collections.discard(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") except Exception as e: diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index a5dad748..65eeee06 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -82,7 +82,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Cache of known keyspaces and whether tables exist self.known_keyspaces: Set[str] = set() - self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables + self.tables_initialized: Set[str] = set() # Cache of registered (collection, schema_name) pairs self.registered_partitions: Set[Tuple[str, str]] = set() @@ -94,6 +94,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): self.cluster = None self.session = None + # Protects connection setup and cache mutations + self._setup_lock = asyncio.Lock() + def connect_cassandra(self): """Connect to Cassandra cluster""" if self.session: @@ -126,6 +129,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"for workspace {workspace}" ) + async with self._setup_lock: + return await self._apply_schema_config(workspace, config, version) + + async def 
_apply_schema_config(self, workspace, config, version): + # Track which schemas changed in this workspace old_schemas = self.schemas.get(workspace, {}) old_schema_names = set(old_schemas.keys()) @@ -391,16 +399,12 @@ class Processor(CollectionConfigHandler, FlowProcessor): schema_name = obj.schema_name source = getattr(obj.metadata, 'source', '') or '' - # Ensure tables exist (sync DDL — push to a worker thread - # so the event loop stays responsive when running in a - # processor group sharing the loop with siblings). - await asyncio.to_thread(self.ensure_tables, keyspace) - - # Register partitions if first time seeing this (collection, schema_name) - await asyncio.to_thread( - self.register_partitions, - keyspace, collection, schema_name, workspace, - ) + async with self._setup_lock: + await asyncio.to_thread(self.ensure_tables, keyspace) + await asyncio.to_thread( + self.register_partitions, + keyspace, collection, schema_name, workspace, + ) safe_keyspace = self.sanitize_name(keyspace) @@ -461,35 +465,27 @@ class Processor(CollectionConfigHandler, FlowProcessor): async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create/verify collection exists in Cassandra row store""" - # Connect if not already connected (sync, push to thread) - await asyncio.to_thread(self.connect_cassandra) - - # Ensure tables exist (sync DDL, push to thread) - await asyncio.to_thread(self.ensure_tables, workspace) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) + await asyncio.to_thread(self.ensure_tables, workspace) logger.info(f"Collection {collection} ready for workspace {workspace}") async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection using partition tracking""" - # Connect if not already connected - await asyncio.to_thread(self.connect_cassandra) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) + if workspace not in 
self.known_keyspaces: + safe_ks = self.sanitize_name(workspace) + check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s" + result = await async_execute(self.session, check_cql, (safe_ks,)) + if not result: + logger.info(f"Keyspace {safe_ks} does not exist, nothing to delete") + return + self.known_keyspaces.add(workspace) safe_keyspace = self.sanitize_name(workspace) - # Check if keyspace exists - if workspace not in self.known_keyspaces: - check_keyspace_cql = """ - SELECT keyspace_name FROM system_schema.keyspaces - WHERE keyspace_name = %s - """ - result = await async_execute( - self.session, check_keyspace_cql, (safe_keyspace,) - ) - if not result: - logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") - return - self.known_keyspaces.add(workspace) - # Discover all partitions for this collection select_partitions_cql = f""" SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions @@ -540,11 +536,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error(f"Failed to clean up row_partitions for {collection}: {e}") raise - # Clear from local cache - self.registered_partitions = { - (col, sch) for col, sch in self.registered_partitions - if col != collection - } + async with self._setup_lock: + self.registered_partitions = { + (col, sch) for col, sch in self.registered_partitions + if col != collection + } logger.info( f"Deleted collection {collection}: {partitions_deleted} partitions " @@ -553,8 +549,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str): """Delete all data for a specific collection + schema combination""" - # Connect if not already connected - await asyncio.to_thread(self.connect_cassandra) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) safe_keyspace = self.sanitize_name(workspace) @@ -614,8 +610,8 @@ class 
Processor(CollectionConfigHandler, FlowProcessor): ) raise - # Clear from local cache - self.registered_partitions.discard((collection, schema_name)) + async with self._setup_lock: + self.registered_partitions.discard((collection, schema_name)) logger.info( f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions " diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 0774153b..79d6c549 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -4,12 +4,7 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph. """ import asyncio -import base64 -import os -import argparse -import time import logging -import json from .... direct.cassandra_kg import ( EntityCentricKnowledgeGraph, DEFAULT_GRAPH @@ -28,6 +23,8 @@ default_ident = "triples-write" def serialize_triple(triple): """Serialize a Triple object to JSON for storage.""" + import json + if triple is None: return None @@ -141,156 +138,84 @@ class Processor(CollectionConfigHandler, TriplesStoreService): self.cassandra_host = hosts self.cassandra_username = username self.cassandra_password = password - self.table = None - self.tg = None + + self._connections = {} + self._conn_lock = asyncio.Lock() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def _get_connection(self, workspace): + async with self._conn_lock: + if workspace not in self._connections: + if self.cassandra_username and self.cassandra_password: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + username=self.cassandra_username, + password=self.cassandra_password, + ) + else: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + ) + 
self._connections[workspace] = tg + return self._connections[workspace] + async def store_triples(self, workspace, message): - # The cassandra-driver work below — connection, schema - # setup, and per-triple inserts — is all synchronous. - # Wrap the whole batch in a worker thread so the event - # loop stays responsive for sibling processors when - # running in a processor group. + tg = await self._get_connection(workspace) - def _do_store(): + for t in message.triples: + s_val = get_term_value(t.s) + p_val = get_term_value(t.p) + o_val = get_term_value(t.o) + g_val = t.g if t.g is not None else DEFAULT_GRAPH - if self.table is None or self.table != workspace: + otype = get_term_otype(t.o) + dtype = get_term_dtype(t.o) + lang = get_term_lang(t.o) - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Exception: {e}", exc_info=True) - time.sleep(1) - raise e - - self.table = workspace - - for t in message.triples: - # Extract values from Term objects - s_val = get_term_value(t.s) - p_val = get_term_value(t.p) - o_val = get_term_value(t.o) - # t.g is None for default graph, or a graph IRI - g_val = t.g if t.g is not None else DEFAULT_GRAPH - - # Extract object type metadata for entity-centric storage - otype = get_term_otype(t.o) - dtype = get_term_dtype(t.o) - lang = get_term_lang(t.o) - - self.tg.insert( - message.metadata.collection, - s_val, - p_val, - o_val, - g=g_val, - otype=otype, - dtype=dtype, - lang=lang, - ) - - await asyncio.to_thread(_do_store) + await tg.async_insert( + message.metadata.collection, + s_val, + p_val, + o_val, + g=g_val, + otype=otype, + dtype=dtype, 
+ lang=lang, + ) async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create a collection in Cassandra triple store via config push""" + try: + tg = await self._get_connection(workspace) - def _do_create(): - # Create or reuse connection for this workspace's keyspace - if self.table is None or self.table != workspace: - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") - raise - - self.table = workspace - - # Create collection using the built-in method logger.info(f"Creating collection {collection} for workspace {workspace}") - if self.tg.collection_exists(collection): + exists = await tg.async_collection_exists(collection) + if exists: logger.info(f"Collection {collection} already exists") else: - self.tg.create_collection(collection) + await tg.async_create_collection(collection) logger.info(f"Created collection {collection}") - try: - await asyncio.to_thread(_do_create) except Exception as e: logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection from the unified triples table""" + try: + tg = await self._get_connection(workspace) - def _do_delete(): - # Create or reuse connection for this workspace's keyspace - if self.table is None or self.table != workspace: - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and 
self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") - raise - - self.table = workspace - - # Delete all triples for this collection using the built-in method - self.tg.delete_collection(collection) + await tg.async_delete_collection(collection) logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}") - try: - await asyncio.to_thread(_do_delete) except Exception as e: logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise