Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916)

Cassandra triples services were using synchronous EntityCentricKnowledgeGraph
methods from async contexts, and connection state was managed with
threading.local which is wrong for asyncio coroutines sharing a single
thread. Qdrant services had no async wrapping at all, blocking the event
loop on every network call. Rows services had unprotected shared state
mutations across concurrent coroutines.

- Add async methods to EntityCentricKnowledgeGraph (async_insert,
  async_get_s/p/o/sp/po/os/spo/all, async_collection_exists,
  async_create_collection, async_delete_collection) using the existing
  cassandra_async.async_execute bridge
- Rewrite triples write + query services: replace threading.local with
  asyncio.Lock + dict cache for per-workspace connections, use async
  ECKG methods for all data operations, keep asyncio.to_thread only for
  one-time blocking ECKG construction
- Wrap all Qdrant calls in asyncio.to_thread across all 6 services
  (doc/graph/row embeddings write + query), add asyncio.Lock + set cache
  for collection existence checks
- Add asyncio.Lock to rows write + query services to protect shared
  state (schemas, sessions, config caches) from concurrent mutation
- Update all affected tests to match new async patterns
This commit is contained in:
cybermaggedon 2026-05-14 16:00:54 +01:00 committed by GitHub
parent bb1109963c
commit a2dde9cafb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 736 additions and 621 deletions

View file

@ -188,13 +188,14 @@ class TestConfigurationPriorityEndToEnd:
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('trustgraph.direct.cassandra_kg.Cluster') @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_no_config_defaults_end_to_end(self, mock_cluster): async def test_no_config_defaults_end_to_end(self, mock_kg_class):
"""Test that defaults are used when no configuration provided end-to-end.""" """Test that defaults are used when no configuration provided end-to-end."""
mock_cluster_instance = MagicMock() from unittest.mock import AsyncMock
mock_session = MagicMock()
mock_cluster_instance.connect.return_value = mock_session mock_tg_instance = MagicMock()
mock_cluster.return_value = mock_cluster_instance mock_tg_instance.async_get_all = AsyncMock(return_value=[])
mock_kg_class.return_value = mock_tg_instance
with patch.dict(os.environ, {}, clear=True): with patch.dict(os.environ, {}, clear=True):
processor = TriplesQuery(taskgroup=MagicMock()) processor = TriplesQuery(taskgroup=MagicMock())
@ -205,20 +206,16 @@ class TestConfigurationPriorityEndToEnd:
mock_query.s = None mock_query.s = None
mock_query.p = None mock_query.p = None
mock_query.o = None mock_query.o = None
mock_query.g = None
mock_query.limit = 100 mock_query.limit = 100
# Mock the get_all method to return empty list
mock_tg_instance = MagicMock()
mock_tg_instance.get_all.return_value = []
processor.tg = mock_tg_instance
await processor.query_triples('default_user', mock_query) await processor.query_triples('default_user', mock_query)
# Should use defaults # Should use defaults
mock_cluster.assert_called_once() mock_kg_class.assert_called_once_with(
call_args = mock_cluster.call_args hosts=['cassandra'],
assert call_args.args[0] == ['cassandra'] # Default host keyspace='default_user'
assert 'auth_provider' not in call_args.kwargs # No auth with default config )
class TestNoBackwardCompatibilityEndToEnd: class TestNoBackwardCompatibilityEndToEnd:

View file

@ -101,6 +101,8 @@ class TestRowsCassandraIntegration:
processor.session = None processor.session = None
# Bind actual methods from the new unified table implementation # Bind actual methods from the new unified table implementation
import asyncio
processor._setup_lock = asyncio.Lock()
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor) processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor) processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
@ -108,6 +110,7 @@ class TestRowsCassandraIntegration:
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
processor.on_object = Processor.on_object.__get__(processor, Processor) processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.collection_exists = MagicMock(return_value=True) processor.collection_exists = MagicMock(return_value=True)

View file

@ -184,7 +184,7 @@ class TestObjectsGraphQLQueryIntegration:
await processor.on_schema_config("default", sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
# Connect to Cassandra # Connect to Cassandra
processor.connect_cassandra() await processor.connect_cassandra()
assert processor.session is not None assert processor.session is not None
# Create test keyspace and table # Create test keyspace and table
@ -219,7 +219,7 @@ class TestObjectsGraphQLQueryIntegration:
"""Test inserting data and querying via GraphQL""" """Test inserting data and querying via GraphQL"""
# Load schema and connect # Load schema and connect
await processor.on_schema_config("default", sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra() await processor.connect_cassandra()
# Setup test data # Setup test data
keyspace = "test_user" keyspace = "test_user"
@ -293,7 +293,7 @@ class TestObjectsGraphQLQueryIntegration:
"""Test GraphQL queries with filtering on indexed fields""" """Test GraphQL queries with filtering on indexed fields"""
# Setup (reuse previous setup) # Setup (reuse previous setup)
await processor.on_schema_config("default", sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra() await processor.connect_cassandra()
keyspace = "test_user" keyspace = "test_user"
collection = "filter_test" collection = "filter_test"
@ -387,7 +387,7 @@ class TestObjectsGraphQLQueryIntegration:
"""Test full message processing workflow""" """Test full message processing workflow"""
# Setup # Setup
await processor.on_schema_config("default", sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra() await processor.connect_cassandra()
# Create mock message # Create mock message
request = RowsQueryRequest( request = RowsQueryRequest(
@ -433,7 +433,7 @@ class TestObjectsGraphQLQueryIntegration:
"""Test handling multiple concurrent GraphQL queries""" """Test handling multiple concurrent GraphQL queries"""
# Setup # Setup
await processor.on_schema_config("default", sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra() await processor.connect_cassandra()
# Create multiple query tasks # Create multiple query tasks
queries = [ queries = [
@ -519,7 +519,7 @@ class TestObjectsGraphQLQueryIntegration:
"""Test handling of large query result sets""" """Test handling of large query result sets"""
# Setup # Setup
await processor.on_schema_config("default", sample_schema_config, version=1) await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra() await processor.connect_cassandra()
keyspace = "large_test_user" keyspace = "large_test_user"
collection = "large_collection" collection = "large_collection"

View file

@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_schema_config_parsing(self): async def test_schema_config_parsing(self):
"""Test parsing of schema configuration""" """Test parsing of schema configuration"""
import asyncio
processor = MagicMock() processor = MagicMock()
processor.schemas = {} processor.schemas = {}
processor.schema_builders = {} processor.schema_builders = {}
processor.graphql_schemas = {} processor.graphql_schemas = {}
processor.config_key = "schema" processor.config_key = "schema"
processor.query_cassandra = MagicMock() processor.query_cassandra = MagicMock()
processor._setup_lock = asyncio.Lock()
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config # Create test config
@ -335,7 +338,7 @@ class TestUnifiedTableQueries:
"""Test query execution with matching index""" """Test query execution with matching index"""
processor = MagicMock() processor = MagicMock()
processor.session = MagicMock() processor.session = MagicMock()
processor.connect_cassandra = MagicMock() processor.connect_cassandra = AsyncMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
@ -396,7 +399,7 @@ class TestUnifiedTableQueries:
"""Test query execution without matching index (scan mode)""" """Test query execution without matching index (scan mode)"""
processor = MagicMock() processor = MagicMock()
processor.session = MagicMock() processor.session = MagicMock()
processor.connect_cassandra = MagicMock() processor.connect_cassandra = AsyncMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)

View file

@ -2,8 +2,10 @@
Tests for Cassandra triples query service Tests for Cassandra triples query service
""" """
import asyncio
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.query.triples.cassandra.service import Processor, create_term from trustgraph.query.triples.cassandra.service import Processor, create_term
from trustgraph.schema import Term, IRI, LITERAL from trustgraph.schema import Term, IRI, LITERAL
@ -18,7 +20,7 @@ class TestCassandraQueryProcessor:
return Processor( return Processor(
taskgroup=MagicMock(), taskgroup=MagicMock(),
id='test-cassandra-query', id='test-cassandra-query',
graph_host='localhost' cassandra_host='localhost'
) )
def test_create_term_with_http_uri(self, processor): def test_create_term_with_http_uri(self, processor):
@ -85,7 +87,7 @@ class TestCassandraQueryProcessor:
mock_result.dtype = None mock_result.dtype = None
mock_result.lang = None mock_result.lang = None
mock_result.o = 'test_object' mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result] mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
processor = Processor( processor = Processor(
taskgroup=MagicMock(), taskgroup=MagicMock(),
@ -110,8 +112,8 @@ class TestCassandraQueryProcessor:
keyspace='test_user' keyspace='test_user'
) )
# Verify get_spo was called with correct parameters # Verify async_get_spo was called with correct parameters
mock_tg_instance.get_spo.assert_called_once_with( mock_tg_instance.async_get_spo.assert_called_once_with(
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100 'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
) )
@ -130,7 +132,8 @@ class TestCassandraQueryProcessor:
assert processor.cassandra_host == ['cassandra'] # Updated default assert processor.cassandra_host == ['cassandra'] # Updated default
assert processor.cassandra_username is None assert processor.cassandra_username is None
assert processor.cassandra_password is None assert processor.cassandra_password is None
assert processor.table is None assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
def test_processor_initialization_with_custom_params(self): def test_processor_initialization_with_custom_params(self):
"""Test processor initialization with custom parameters""" """Test processor initialization with custom parameters"""
@ -146,7 +149,8 @@ class TestCassandraQueryProcessor:
assert processor.cassandra_host == ['cassandra.example.com'] assert processor.cassandra_host == ['cassandra.example.com']
assert processor.cassandra_username == 'queryuser' assert processor.cassandra_username == 'queryuser'
assert processor.cassandra_password == 'querypass' assert processor.cassandra_password == 'querypass'
assert processor.table is None assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
@ -164,7 +168,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None mock_result.otype = None
mock_result.dtype = None mock_result.dtype = None
mock_result.lang = None mock_result.lang = None
mock_tg_instance.get_sp.return_value = [mock_result] mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -178,7 +182,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query) result = await processor.query_triples('test_user', query)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
assert len(result) == 1 assert len(result) == 1
assert result[0].s.iri == 'test_subject' assert result[0].s.iri == 'test_subject'
assert result[0].p.iri == 'test_predicate' assert result[0].p.iri == 'test_predicate'
@ -200,7 +204,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None mock_result.otype = None
mock_result.dtype = None mock_result.dtype = None
mock_result.lang = None mock_result.lang = None
mock_tg_instance.get_s.return_value = [mock_result] mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -214,7 +218,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query) result = await processor.query_triples('test_user', query)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
assert len(result) == 1 assert len(result) == 1
assert result[0].s.iri == 'test_subject' assert result[0].s.iri == 'test_subject'
assert result[0].p.iri == 'result_predicate' assert result[0].p.iri == 'result_predicate'
@ -236,7 +240,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None mock_result.otype = None
mock_result.dtype = None mock_result.dtype = None
mock_result.lang = None mock_result.lang = None
mock_tg_instance.get_p.return_value = [mock_result] mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -250,7 +254,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query) result = await processor.query_triples('test_user', query)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
assert len(result) == 1 assert len(result) == 1
assert result[0].s.iri == 'result_subject' assert result[0].s.iri == 'result_subject'
assert result[0].p.iri == 'test_predicate' assert result[0].p.iri == 'test_predicate'
@ -272,7 +276,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None mock_result.otype = None
mock_result.dtype = None mock_result.dtype = None
mock_result.lang = None mock_result.lang = None
mock_tg_instance.get_o.return_value = [mock_result] mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -286,7 +290,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query) result = await processor.query_triples('test_user', query)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
assert len(result) == 1 assert len(result) == 1
assert result[0].s.iri == 'result_subject' assert result[0].s.iri == 'result_subject'
assert result[0].p.iri == 'result_predicate' assert result[0].p.iri == 'result_predicate'
@ -305,11 +309,11 @@ class TestCassandraQueryProcessor:
mock_result.s = 'all_subject' mock_result.s = 'all_subject'
mock_result.p = 'all_predicate' mock_result.p = 'all_predicate'
mock_result.o = 'all_object' mock_result.o = 'all_object'
mock_result.g = '' mock_result.d = ''
mock_result.otype = None mock_result.otype = None
mock_result.dtype = None mock_result.dtype = None
mock_result.lang = None mock_result.lang = None
mock_tg_instance.get_all.return_value = [mock_result] mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -323,7 +327,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query) result = await processor.query_triples('test_user', query)
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000)
assert len(result) == 1 assert len(result) == 1
assert result[0].s.iri == 'all_subject' assert result[0].s.iri == 'all_subject'
assert result[0].p.iri == 'all_predicate' assert result[0].p.iri == 'all_predicate'
@ -410,7 +414,7 @@ class TestCassandraQueryProcessor:
mock_result.dtype = None mock_result.dtype = None
mock_result.lang = None mock_result.lang = None
mock_result.o = 'test_object' mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result] mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
processor = Processor( processor = Processor(
taskgroup=MagicMock(), taskgroup=MagicMock(),
@ -451,7 +455,7 @@ class TestCassandraQueryProcessor:
mock_result.dtype = None mock_result.dtype = None
mock_result.lang = None mock_result.lang = None
mock_result.o = 'test_object' mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result] mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -489,8 +493,8 @@ class TestCassandraQueryProcessor:
mock_result.lang = None mock_result.lang = None
mock_result.p = 'p' mock_result.p = 'p'
mock_result.o = 'o' mock_result.o = 'o'
mock_tg_instance1.get_s.return_value = [mock_result] mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result])
mock_tg_instance2.get_s.return_value = [mock_result] mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -504,7 +508,6 @@ class TestCassandraQueryProcessor:
) )
await processor.query_triples('user1', query1) await processor.query_triples('user1', query1)
assert processor.table == 'user1'
# Second query with different table # Second query with different table
query2 = TriplesQueryRequest( query2 = TriplesQueryRequest(
@ -516,10 +519,11 @@ class TestCassandraQueryProcessor:
) )
await processor.query_triples('user2', query2) await processor.query_triples('user2', query2)
assert processor.table == 'user2'
# Verify TrustGraph was created twice # Verify TrustGraph was created twice for different workspaces
assert mock_kg_class.call_count == 2 assert mock_kg_class.call_count == 2
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
@ -529,7 +533,7 @@ class TestCassandraQueryProcessor:
mock_tg_instance = MagicMock() mock_tg_instance = MagicMock()
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
mock_tg_instance.get_spo.side_effect = Exception("Query failed") mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed"))
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -566,7 +570,7 @@ class TestCassandraQueryProcessor:
mock_result2.otype = None mock_result2.otype = None
mock_result2.dtype = None mock_result2.dtype = None
mock_result2.lang = None mock_result2.lang = None
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2] mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations:
mock_result.otype = None mock_result.otype = None
mock_result.dtype = None mock_result.dtype = None
mock_result.lang = None mock_result.lang = None
mock_tg_instance.get_po.return_value = [mock_result] mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations:
result = await processor.query_triples('test_user', query) result = await processor.query_triples('test_user', query)
# Verify get_po was called (should use optimized po_table) # Verify async_get_po was called (should use optimized po_table)
mock_tg_instance.get_po.assert_called_once_with( mock_tg_instance.async_get_po.assert_called_once_with(
'test_collection', 'test_predicate', 'test_object', g=None, limit=50 'test_collection', 'test_predicate', 'test_object', g=None, limit=50
) )
@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations:
mock_result.otype = None mock_result.otype = None
mock_result.dtype = None mock_result.dtype = None
mock_result.lang = None mock_result.lang = None
mock_tg_instance.get_os.return_value = [mock_result] mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations:
result = await processor.query_triples('test_user', query) result = await processor.query_triples('test_user', query)
# Verify get_os was called (should use optimized subject_table with clustering) # Verify async_get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.get_os.assert_called_once_with( mock_tg_instance.async_get_os.assert_called_once_with(
'test_collection', 'test_object', 'test_subject', g=None, limit=25 'test_collection', 'test_object', 'test_subject', g=None, limit=25
) )
@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations:
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
# Mock empty results for all queries # Mock empty results for all queries
mock_tg_instance.get_all.return_value = [] mock_tg_instance.async_get_all = AsyncMock(return_value=[])
mock_tg_instance.get_s.return_value = [] mock_tg_instance.async_get_s = AsyncMock(return_value=[])
mock_tg_instance.get_p.return_value = [] mock_tg_instance.async_get_p = AsyncMock(return_value=[])
mock_tg_instance.get_o.return_value = [] mock_tg_instance.async_get_o = AsyncMock(return_value=[])
mock_tg_instance.get_sp.return_value = [] mock_tg_instance.async_get_sp = AsyncMock(return_value=[])
mock_tg_instance.get_po.return_value = [] mock_tg_instance.async_get_po = AsyncMock(return_value=[])
mock_tg_instance.get_os.return_value = [] mock_tg_instance.async_get_os = AsyncMock(return_value=[])
mock_tg_instance.get_spo.return_value = [] mock_tg_instance.async_get_spo = AsyncMock(return_value=[])
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
# Test each query pattern # Test each query pattern
test_patterns = [ test_patterns = [
# (s, p, o, expected_method) # (s, p, o, expected_method)
(None, None, None, 'get_all'), # All triples (None, None, None, 'async_get_all'), # All triples
('s1', None, None, 'get_s'), # Subject only ('s1', None, None, 'async_get_s'), # Subject only
(None, 'p1', None, 'get_p'), # Predicate only (None, 'p1', None, 'async_get_p'), # Predicate only
(None, None, 'o1', 'get_o'), # Object only (None, None, 'o1', 'async_get_o'), # Object only
('s1', 'p1', None, 'get_sp'), # Subject + Predicate ('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate
(None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION) (None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
('s1', None, 'o1', 'get_os'), # Object + Subject ('s1', None, 'o1', 'async_get_os'), # Object + Subject
('s1', 'p1', 'o1', 'get_spo'), # All three ('s1', 'p1', 'o1', 'async_get_spo'), # All three
] ]
for s, p, o, expected_method in test_patterns: for s, p, o, expected_method in test_patterns:
@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations:
mock_result.lang = None mock_result.lang = None
mock_results.append(mock_result) mock_results.append(mock_result)
mock_tg_instance.get_po.return_value = mock_results mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results)
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations:
result = await processor.query_triples('large_dataset_user', query) result = await processor.query_triples('large_dataset_user', query)
# Verify optimized get_po was used (no ALLOW FILTERING needed!) # Verify optimized async_get_po was used (no ALLOW FILTERING needed!)
mock_tg_instance.get_po.assert_called_once_with( mock_tg_instance.async_get_po.assert_called_once_with(
'massive_collection', 'massive_collection',
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type', 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
'http://example.com/Person', 'http://example.com/Person',

View file

@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_valid_embedding_upserted(self): async def test_valid_embedding_upserted(self):
import asyncio
from trustgraph.storage.doc_embeddings.qdrant.write import Processor from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor) proc = Processor.__new__(Processor)
proc.qdrant = MagicMock() proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True) proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock() msg = MagicMock()
msg.metadata.collection = "col1" msg.metadata.collection = "col1"
@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dimension_in_collection_name(self): async def test_dimension_in_collection_name(self):
"""Collection name should include vector dimension.""" """Collection name should include vector dimension."""
import asyncio
from trustgraph.storage.doc_embeddings.qdrant.write import Processor from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor) proc = Processor.__new__(Processor)
proc.qdrant = MagicMock() proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True) proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock() msg = MagicMock()
msg.metadata.collection = "docs" msg.metadata.collection = "docs"
@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_valid_entity_and_vector_upserted(self): async def test_valid_entity_and_vector_upserted(self):
import asyncio
from trustgraph.storage.graph_embeddings.qdrant.write import Processor from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor) proc = Processor.__new__(Processor)
proc.qdrant = MagicMock() proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True) proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock() msg = MagicMock()
msg.metadata.collection = "col1" msg.metadata.collection = "col1"
@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_lazy_collection_creation_on_new_dimension(self): async def test_lazy_collection_creation_on_new_dimension(self):
import asyncio
from trustgraph.storage.graph_embeddings.qdrant.write import Processor from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor) proc = Processor.__new__(Processor)
proc.qdrant = MagicMock() proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = False proc.qdrant.collection_exists.return_value = False
proc.collection_exists = MagicMock(return_value=True) proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock() msg = MagicMock()
msg.metadata.collection = "graphs" msg.metadata.collection = "graphs"

View file

@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Assert # Assert
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
# Verify collection existence is checked on each write # Second write uses cached collection state — no collection_exists check
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) mock_qdrant_instance.collection_exists.assert_not_called()
# But upsert should still be called # But upsert should still be called
mock_qdrant_instance.upsert.assert_called_once() mock_qdrant_instance.upsert.assert_called_once()

View file

@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
processor.ensure_collection("test_collection", 384) await processor.ensure_collection("test_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection") mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
mock_qdrant_instance.create_collection.assert_called_once() mock_qdrant_instance.create_collection.assert_called_once()
# Verify the collection is cached # Verify the collection is cached
assert "test_collection" in processor.created_collections assert "test_collection" in processor._known_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_skips_existing(self, mock_qdrant_client): async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
processor.ensure_collection("existing_collection", 384) await processor.ensure_collection("existing_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once() mock_qdrant_instance.collection_exists.assert_called_once()
mock_qdrant_instance.create_collection.assert_not_called() mock_qdrant_instance.create_collection.assert_not_called()
@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
} }
processor = Processor(**config) processor = Processor(**config)
processor.created_collections.add("cached_collection") processor._known_collections.add("cached_collection")
processor.ensure_collection("cached_collection", 384) await processor.ensure_collection("cached_collection", 384)
# Should not check or create - just return # Should not check or create - just return
mock_qdrant_instance.collection_exists.assert_not_called() mock_qdrant_instance.collection_exists.assert_not_called()
@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
} }
processor = Processor(**config) processor = Processor(**config)
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384') processor._known_collections.add('rows_test_workspace_test_collection_schema1_384')
await processor.delete_collection('test_workspace', 'test_collection') await processor.delete_collection('test_workspace', 'test_collection')
@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
assert mock_qdrant_instance.delete_collection.call_count == 2 assert mock_qdrant_instance.delete_collection.call_count == 2
# Verify the cached collection was removed # Verify the cached collection was removed
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client): async def test_delete_collection_schema(self, mock_qdrant_client):

View file

@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_schema_config_parsing(self): async def test_schema_config_parsing(self):
"""Test parsing of schema configurations""" """Test parsing of schema configurations"""
import asyncio
processor = MagicMock() processor = MagicMock()
processor.schemas = {} processor.schemas = {}
processor.config_key = "schema" processor.config_key = "schema"
processor.registered_partitions = set() processor.registered_partitions = set()
processor._setup_lock = asyncio.Lock()
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test configuration # Create test configuration

View file

@ -2,6 +2,8 @@
Tests for Cassandra triples storage service Tests for Cassandra triples storage service
""" """
import asyncio
import pytest import pytest
from unittest.mock import MagicMock, patch, AsyncMock from unittest.mock import MagicMock, patch, AsyncMock
@ -24,7 +26,8 @@ class TestCassandraStorageProcessor:
assert processor.cassandra_host == ['cassandra'] # Updated default assert processor.cassandra_host == ['cassandra'] # Updated default
assert processor.cassandra_username is None assert processor.cassandra_username is None
assert processor.cassandra_password is None assert processor.cassandra_password is None
assert processor.table is None assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
def test_processor_initialization_with_custom_params(self): def test_processor_initialization_with_custom_params(self):
"""Test processor initialization with custom parameters (new cassandra_* names)""" """Test processor initialization with custom parameters (new cassandra_* names)"""
@ -41,7 +44,8 @@ class TestCassandraStorageProcessor:
assert processor.cassandra_host == ['cassandra.example.com'] assert processor.cassandra_host == ['cassandra.example.com']
assert processor.cassandra_username == 'testuser' assert processor.cassandra_username == 'testuser'
assert processor.cassandra_password == 'testpass' assert processor.cassandra_password == 'testpass'
assert processor.table is None assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
def test_processor_initialization_with_partial_auth(self): def test_processor_initialization_with_partial_auth(self):
"""Test processor initialization with only username (no password)""" """Test processor initialization with only username (no password)"""
@ -92,6 +96,7 @@ class TestCassandraStorageProcessor:
"""Test table switching logic when authentication is provided""" """Test table switching logic when authentication is provided"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock() mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
processor = Processor( processor = Processor(
@ -114,7 +119,6 @@ class TestCassandraStorageProcessor:
username='testuser', username='testuser',
password='testpass' password='testpass'
) )
assert processor.table == 'user1'
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@ -122,6 +126,7 @@ class TestCassandraStorageProcessor:
"""Test table switching logic when no authentication is provided""" """Test table switching logic when no authentication is provided"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock() mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
@ -138,7 +143,6 @@ class TestCassandraStorageProcessor:
hosts=['cassandra'], # Updated default hosts=['cassandra'], # Updated default
keyspace='user2' keyspace='user2'
) )
assert processor.table == 'user2'
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@ -146,6 +150,7 @@ class TestCassandraStorageProcessor:
"""Test that TrustGraph is not recreated when table hasn't changed""" """Test that TrustGraph is not recreated when table hasn't changed"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock() mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
@ -169,6 +174,7 @@ class TestCassandraStorageProcessor:
"""Test that triples are properly inserted into Cassandra""" """Test that triples are properly inserted into Cassandra"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock() mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
@ -208,12 +214,12 @@ class TestCassandraStorageProcessor:
await processor.store_triples('user1', mock_message) await processor.store_triples('user1', mock_message)
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters) # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
assert mock_tg_instance.insert.call_count == 2 assert mock_tg_instance.async_insert.call_count == 2
mock_tg_instance.insert.assert_any_call( mock_tg_instance.async_insert.assert_any_call(
'collection1', 'subject1', 'predicate1', 'object1', 'collection1', 'subject1', 'predicate1', 'object1',
g=DEFAULT_GRAPH, otype='l', dtype='', lang='' g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
) )
mock_tg_instance.insert.assert_any_call( mock_tg_instance.async_insert.assert_any_call(
'collection1', 'subject2', 'predicate2', 'object2', 'collection1', 'subject2', 'predicate2', 'object2',
g=DEFAULT_GRAPH, otype='l', dtype='', lang='' g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
) )
@ -224,6 +230,7 @@ class TestCassandraStorageProcessor:
"""Test behavior when message has no triples""" """Test behavior when message has no triples"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock() mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
@ -236,19 +243,17 @@ class TestCassandraStorageProcessor:
await processor.store_triples('user1', mock_message) await processor.store_triples('user1', mock_message)
# Verify no triples were inserted # Verify no triples were inserted
mock_tg_instance.insert.assert_not_called() mock_tg_instance.async_insert.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@patch('trustgraph.storage.triples.cassandra.write.time.sleep') async def test_exception_handling_on_connection_failure(self, mock_kg_class):
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
"""Test exception handling during TrustGraph creation""" """Test exception handling during TrustGraph creation"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_kg_class.side_effect = Exception("Connection failed") mock_kg_class.side_effect = Exception("Connection failed")
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.collection = 'collection1' mock_message.metadata.collection = 'collection1'
mock_message.triples = [] mock_message.triples = []
@ -256,9 +261,6 @@ class TestCassandraStorageProcessor:
with pytest.raises(Exception, match="Connection failed"): with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples('user1', mock_message) await processor.store_triples('user1', mock_message)
# Verify sleep was called before re-raising
mock_sleep.assert_called_once_with(1)
def test_add_args_method(self): def test_add_args_method(self):
"""Test that add_args properly configures argument parser""" """Test that add_args properly configures argument parser"""
from argparse import ArgumentParser from argparse import ArgumentParser
@ -359,8 +361,6 @@ class TestCassandraStorageProcessor:
mock_message1.triples = [] mock_message1.triples = []
await processor.store_triples('user1', mock_message1) await processor.store_triples('user1', mock_message1)
assert processor.table == 'user1'
assert processor.tg == mock_tg_instance1
# Second message with different table # Second message with different table
mock_message2 = MagicMock() mock_message2 = MagicMock()
@ -368,11 +368,11 @@ class TestCassandraStorageProcessor:
mock_message2.triples = [] mock_message2.triples = []
await processor.store_triples('user2', mock_message2) await processor.store_triples('user2', mock_message2)
assert processor.table == 'user2'
assert processor.tg == mock_tg_instance2
# Verify TrustGraph was created twice for different tables # Verify TrustGraph was created twice for different workspaces
assert mock_kg_class.call_count == 2 assert mock_kg_class.call_count == 2
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@ -380,6 +380,7 @@ class TestCassandraStorageProcessor:
"""Test storing triples with special characters and unicode""" """Test storing triples with special characters and unicode"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock() mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
@ -405,7 +406,7 @@ class TestCassandraStorageProcessor:
await processor.store_triples('test_workspace', mock_message) await processor.store_triples('test_workspace', mock_message)
# Verify the triple was inserted with special characters preserved # Verify the triple was inserted with special characters preserved
mock_tg_instance.insert.assert_called_once_with( mock_tg_instance.async_insert.assert_called_once_with(
'test_collection', 'test_collection',
'subject with spaces & symbols', 'subject with spaces & symbols',
'predicate:with/colons', 'predicate:with/colons',
@ -418,29 +419,29 @@ class TestCassandraStorageProcessor:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class): async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class):
"""Test that table remains unchanged when TrustGraph creation fails""" """Test that a failed connection doesn't leave stale cached state"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_good_instance = MagicMock()
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
# Set an initial table
processor.table = ('old_user', 'old_collection')
# Mock TrustGraph to raise exception
mock_kg_class.side_effect = Exception("Connection failed")
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.collection = 'new_collection' mock_message.metadata.collection = 'collection1'
mock_message.triples = [] mock_message.triples = []
# First call fails
mock_kg_class.side_effect = Exception("Connection failed")
with pytest.raises(Exception, match="Connection failed"): with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples('new_user', mock_message) await processor.store_triples('user1', mock_message)
# Table should remain unchanged since self.table = table happens after try/except # Second call succeeds — should retry connection, not use stale state
assert processor.table == ('old_user', 'old_collection') mock_kg_class.side_effect = None
# TrustGraph should be set to None though mock_kg_class.return_value = mock_good_instance
assert processor.tg is None await processor.store_triples('user1', mock_message)
# Connection was attempted twice (failed + succeeded)
assert mock_kg_class.call_count == 2
class TestCassandraPerformanceOptimizations: class TestCassandraPerformanceOptimizations:
@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations:
"""Test that legacy mode still works with single table""" """Test that legacy mode still works with single table"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock() mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}): with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations:
"""Test that optimized mode uses multi-table schema""" """Test that optimized mode uses multi-table schema"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock() mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}): with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations:
"""Test that all tables stay consistent during batch writes""" """Test that all tables stay consistent during batch writes"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock() mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations:
await processor.store_triples('user1', mock_message) await processor.store_triples('user1', mock_message)
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph) # Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
mock_tg_instance.insert.assert_called_once_with( mock_tg_instance.async_insert.assert_called_once_with(
'collection1', 'test_subject', 'test_predicate', 'test_object', 'collection1', 'test_subject', 'test_predicate', 'test_object',
g=DEFAULT_GRAPH, otype='l', dtype='', lang='' g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
) )

View file

@ -89,7 +89,8 @@ class TestSanitizeName:
class TestFindCollection: class TestFindCollection:
def test_finds_matching_collection(self): @pytest.mark.asyncio
async def test_finds_matching_collection(self):
proc = _make_processor() proc = _make_processor()
mock_coll = MagicMock() mock_coll = MagicMock()
mock_coll.name = "rows_test_workspace_test_col_customers_384" mock_coll.name = "rows_test_workspace_test_col_customers_384"
@ -98,11 +99,12 @@ class TestFindCollection:
mock_collections.collections = [mock_coll] mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-workspace", "test-col", "customers") result = await proc.find_collection("test-workspace", "test-col", "customers")
assert result == "rows_test_workspace_test_col_customers_384" assert result == "rows_test_workspace_test_col_customers_384"
def test_returns_none_when_no_match(self): @pytest.mark.asyncio
async def test_returns_none_when_no_match(self):
proc = _make_processor() proc = _make_processor()
mock_coll = MagicMock() mock_coll = MagicMock()
mock_coll.name = "rows_other_workspace_other_col_schema_768" mock_coll.name = "rows_other_workspace_other_col_schema_768"
@ -111,14 +113,15 @@ class TestFindCollection:
mock_collections.collections = [mock_coll] mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-workspace", "test-col", "customers") result = await proc.find_collection("test-workspace", "test-col", "customers")
assert result is None assert result is None
def test_returns_none_on_error(self): @pytest.mark.asyncio
async def test_returns_none_on_error(self):
proc = _make_processor() proc = _make_processor()
proc.qdrant.get_collections.side_effect = Exception("connection error") proc.qdrant.get_collections.side_effect = Exception("connection error")
result = proc.find_collection("workspace", "col", "schema") result = await proc.find_collection("workspace", "col", "schema")
assert result is None assert result is None
@ -139,7 +142,7 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_collection_returns_empty(self): async def test_no_collection_returns_empty(self):
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value=None) proc.find_collection = AsyncMock(return_value=None)
request = _make_request() request = _make_request()
result = await proc.query_row_embeddings("test-workspace", request) result = await proc.query_row_embeddings("test-workspace", request)
@ -148,7 +151,7 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_successful_query_returns_matches(self): async def test_successful_query_returns_matches(self):
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384") proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
points = [ points = [
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95), _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
@ -172,7 +175,7 @@ class TestQueryRowEmbeddings:
async def test_index_name_filter_applied(self): async def test_index_name_filter_applied(self):
"""When index_name is specified, a Qdrant filter should be used.""" """When index_name is specified, a Qdrant filter should be used."""
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384") proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
mock_result = MagicMock() mock_result = MagicMock()
mock_result.points = [] mock_result.points = []
@ -188,7 +191,7 @@ class TestQueryRowEmbeddings:
async def test_no_index_name_no_filter(self): async def test_no_index_name_no_filter(self):
"""When index_name is empty, no filter should be applied.""" """When index_name is empty, no filter should be applied."""
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384") proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
mock_result = MagicMock() mock_result = MagicMock()
mock_result.points = [] mock_result.points = []
@ -204,7 +207,7 @@ class TestQueryRowEmbeddings:
async def test_missing_payload_fields_default(self): async def test_missing_payload_fields_default(self):
"""Points with missing payload fields should use defaults.""" """Points with missing payload fields should use defaults."""
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384") proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
point = MagicMock() point = MagicMock()
point.payload = {} # Empty payload point.payload = {} # Empty payload
@ -225,7 +228,7 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_qdrant_error_propagates(self): async def test_qdrant_error_propagates(self):
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384") proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
proc.qdrant.query_points.side_effect = Exception("qdrant down") proc.qdrant.query_points.side_effect = Exception("qdrant down")
request = _make_request() request = _make_request()

View file

@ -1,10 +1,14 @@
import datetime
import os
import logging
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement from cassandra.query import BatchStatement, SimpleStatement
from ssl import SSLContext, PROTOCOL_TLSv1_2 from ssl import SSLContext, PROTOCOL_TLSv1_2
import os
import logging from ..tables.cassandra_async import async_execute
# Global list to track clusters for cleanup # Global list to track clusters for cleanup
_active_clusters = [] _active_clusters = []
@ -461,7 +465,6 @@ class KnowledgeGraph:
def create_collection(self, collection): def create_collection(self, collection):
"""Create collection by inserting metadata row""" """Create collection by inserting metadata row"""
try: try:
import datetime
self.session.execute( self.session.execute(
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
(collection, datetime.datetime.now()) (collection, datetime.datetime.now())
@ -954,7 +957,6 @@ class EntityCentricKnowledgeGraph:
def create_collection(self, collection): def create_collection(self, collection):
"""Create collection by inserting metadata row""" """Create collection by inserting metadata row"""
try: try:
import datetime
self.session.execute( self.session.execute(
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
(collection, datetime.datetime.now()) (collection, datetime.datetime.now())
@ -1045,6 +1047,222 @@ class EntityCentricKnowledgeGraph:
logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads") logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")
# ========================================================================
# Async methods — use cassandra driver's native async API via async_execute
# ========================================================================
async def async_insert(self, collection, s, p, o, g=None, otype=None, dtype="", lang=""):
if g is None:
g = DEFAULT_GRAPH
if otype is None:
if o.startswith("http://") or o.startswith("https://"):
otype = "u"
else:
otype = "l"
batch = BatchStatement()
batch.add(self.insert_entity_stmt, (collection, s, 'S', p, otype, s, o, g, dtype, lang))
batch.add(self.insert_entity_stmt, (collection, p, 'P', p, otype, s, o, g, dtype, lang))
if otype == 'u' or otype == 't':
batch.add(self.insert_entity_stmt, (collection, o, 'O', p, otype, s, o, g, dtype, lang))
if g != DEFAULT_GRAPH:
batch.add(self.insert_entity_stmt, (collection, g, 'G', p, otype, s, o, g, dtype, lang))
batch.add(self.insert_collection_stmt, (collection, g, s, p, o, otype, dtype, lang))
await async_execute(self.session, batch)
async def async_get_all(self, collection, limit=50):
    """Asynchronously fetch up to *limit* quads from *collection*."""
    params = (collection, limit)
    return await async_execute(self.session, self.get_collection_all_stmt, params)
async def async_get_s(self, collection, s, g=None, limit=10):
    """Asynchronously fetch up to *limit* quads whose subject is *s*.

    When *g* is given, only quads in that graph are returned; rows that
    lack a graph column are treated as belonging to the default graph.
    """
    rows = await async_execute(
        self.session, self.get_entity_as_s_stmt, (collection, s, limit)
    )
    tagged = [(row, getattr(row, 'd', DEFAULT_GRAPH)) for row in rows]
    return [
        QuadResult(
            s=row.s, p=row.p, o=row.o, g=graph,
            otype=row.otype, dtype=row.dtype, lang=row.lang,
        )
        for row, graph in tagged
        if g is None or graph == g
    ]
async def async_get_p(self, collection, p, g=None, limit=10):
    """Asynchronously fetch up to *limit* quads whose predicate is *p*.

    Optionally restricted to graph *g*; rows without a graph column are
    assigned the default graph.
    """
    rows = await async_execute(
        self.session, self.get_entity_as_p_stmt, (collection, p, limit)
    )
    matches = []
    for rec in rows:
        graph = getattr(rec, 'd', DEFAULT_GRAPH)
        if g is not None and graph != g:
            continue
        matches.append(QuadResult(
            s=rec.s, p=rec.p, o=rec.o, g=graph,
            otype=rec.otype, dtype=rec.dtype, lang=rec.lang,
        ))
    return matches
async def async_get_o(self, collection, o, g=None, limit=10):
    """Asynchronously fetch up to *limit* quads whose object is *o*.

    Optionally restricted to graph *g*; rows without a graph column are
    assigned the default graph.
    """
    rows = await async_execute(
        self.session, self.get_entity_as_o_stmt, (collection, o, limit)
    )
    found = []
    for rec in rows:
        graph = getattr(rec, 'd', DEFAULT_GRAPH)
        if g is None or graph == g:
            found.append(QuadResult(
                s=rec.s, p=rec.p, o=rec.o, g=graph,
                otype=rec.otype, dtype=rec.dtype, lang=rec.lang,
            ))
    return found
async def async_get_sp(self, collection, s, p, g=None, limit=10):
    """Asynchronously fetch up to *limit* quads matching subject *s* and predicate *p*."""
    rows = await async_execute(
        self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit)
    )
    out = []
    for row in rows:
        graph = getattr(row, 'd', DEFAULT_GRAPH)
        if g is not None and graph != g:
            continue
        # s and p are fixed by the query arguments; only o varies per row.
        out.append(QuadResult(
            s=s, p=p, o=row.o, g=graph,
            otype=row.otype, dtype=row.dtype, lang=row.lang,
        ))
    return out
async def async_get_po(self, collection, p, o, g=None, limit=10):
    """Asynchronously fetch up to *limit* quads matching predicate *p* and object *o*."""
    # Note: the o/p index statement takes (collection, o, p, limit) in that order.
    rows = await async_execute(
        self.session, self.get_entity_as_o_p_stmt, (collection, o, p, limit)
    )
    hits = []
    for row in rows:
        graph = getattr(row, 'd', DEFAULT_GRAPH)
        if g is None or graph == g:
            # p and o are fixed by the query arguments; only s varies per row.
            hits.append(QuadResult(
                s=row.s, p=p, o=o, g=graph,
                otype=row.otype, dtype=row.dtype, lang=row.lang,
            ))
    return hits
async def async_get_os(self, collection, o, s, g=None, limit=10):
    """Asynchronously fetch quads matching object *o* and subject *s*.

    Queries by subject, then filters by object (and optionally graph)
    client-side.

    NOTE(review): *limit* applies to the subject query before the object
    filter, so fewer than *limit* matching quads may be returned even when
    more exist — confirm this matches the synchronous get_os semantics.
    """
    rows = await async_execute(
        self.session, self.get_entity_as_s_stmt, (collection, s, limit)
    )
    results = []
    for row in rows:
        if row.o != o:
            continue
        # Rows without a graph column belong to the default graph.
        d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
        if g is not None and d != g:
            continue
        results.append(QuadResult(
            s=s, p=row.p, o=o, g=d,
            otype=row.otype, dtype=row.dtype, lang=row.lang
        ))
    return results
async def async_get_spo(self, collection, s, p, o, g=None, limit=10):
    """Asynchronously fetch quads matching subject *s*, predicate *p* and object *o*.

    Queries by (subject, predicate), then filters by object (and optionally
    graph) client-side.

    NOTE(review): *limit* applies to the s/p query before the object filter,
    so fewer than *limit* matching quads may be returned even when more
    exist — confirm this matches the synchronous get_spo semantics.
    """
    rows = await async_execute(
        self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit)
    )
    results = []
    for row in rows:
        if row.o != o:
            continue
        # Rows without a graph column belong to the default graph.
        d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
        if g is not None and d != g:
            continue
        results.append(QuadResult(
            s=s, p=p, o=o, g=d,
            otype=row.otype, dtype=row.dtype, lang=row.lang
        ))
    return results
async def async_get_g(self, collection, g, limit=50):
    """Asynchronously fetch up to *limit* quads in graph *g* of *collection*.

    A *g* of None selects the default graph.
    """
    graph = DEFAULT_GRAPH if g is None else g
    return await async_execute(
        self.session, self.get_collection_by_graph_stmt,
        (collection, graph, limit),
    )
async def async_collection_exists(self, collection):
    """Asynchronously check whether *collection* has a metadata row.

    Returns False both when the row is absent and when the lookup fails
    (best-effort: errors are logged, not raised).
    """
    query = (
        f"SELECT collection FROM {self.collection_metadata_table} "
        "WHERE collection = %s LIMIT 1"
    )
    try:
        rows = await async_execute(self.session, query, (collection,))
        return bool(rows)
    except Exception as err:
        logger.error(f"Error checking collection existence: {err}")
        return False
async def async_create_collection(self, collection):
    """Asynchronously record a metadata row marking *collection* as created."""
    stmt = (
        f"INSERT INTO {self.collection_metadata_table} "
        "(collection, created_at) VALUES (%s, %s)"
    )
    await async_execute(self.session, stmt, (collection, datetime.datetime.now()))
    logger.info(f"Created collection metadata for {collection}")
async def async_delete_collection(self, collection):
rows = await async_execute(
self.session,
f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s",
(collection,)
)
entities = set()
quads = []
for row in rows:
d, s, p, o = row.d, row.s, row.p, row.o
otype = row.otype
dtype = row.dtype if hasattr(row, 'dtype') else ''
lang = row.lang if hasattr(row, 'lang') else ''
quads.append((d, s, p, o, otype, dtype, lang))
entities.add(s)
entities.add(p)
if otype == 'u' or otype == 't':
entities.add(o)
if d != DEFAULT_GRAPH:
entities.add(d)
batch = BatchStatement()
count = 0
for entity in entities:
batch.add(self.delete_entity_partition_stmt, (collection, entity))
count += 1
if count % 50 == 0:
await async_execute(self.session, batch)
batch = BatchStatement()
if count % 50 != 0:
await async_execute(self.session, batch)
batch = BatchStatement()
count = 0
for d, s, p, o, otype, dtype, lang in quads:
batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang))
count += 1
if count % 50 == 0:
await async_execute(self.session, batch)
batch = BatchStatement()
if count % 50 != 0:
await async_execute(self.session, batch)
await async_execute(
self.session,
f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s",
(collection,)
)
logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")
def close(self): def close(self):
"""Close connections""" """Close connections"""
if hasattr(self, 'session') and self.session: if hasattr(self, 'session') and self.session:

View file

@ -4,11 +4,10 @@ Document embeddings query service. Input is vector, output is an array
of chunk_ids of chunk_ids
""" """
import asyncio
import logging import logging
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error from .... schema import Error
@ -38,32 +37,6 @@ class Processor(DocumentEmbeddingsQueryService):
) )
self.qdrant = QdrantClient(url=store_uri, api_key=api_key) self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.last_collection = None
def ensure_collection_exists(self, collection, dim):
"""Ensure collection exists, create if it doesn't"""
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
try:
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
logger.info(f"Created collection: {collection}")
except Exception as e:
logger.error(f"Qdrant collection creation failed: {e}")
raise e
self.last_collection = collection
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
async def query_document_embeddings(self, workspace, msg): async def query_document_embeddings(self, workspace, msg):
@ -73,21 +46,24 @@ class Processor(DocumentEmbeddingsQueryService):
if not vec: if not vec:
return [] return []
# Use dimension suffix in collection name
dim = len(vec) dim = len(vec)
collection = f"d_{workspace}_{msg.collection}_{dim}" collection = f"d_{workspace}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not exists = await asyncio.to_thread(
if not self.collection_exists(collection): self.qdrant.collection_exists, collection
)
if not exists:
logger.info(f"Collection {collection} does not exist, returning empty results") logger.info(f"Collection {collection} does not exist, returning empty results")
return [] return []
search_result = self.qdrant.query_points( result = await asyncio.to_thread(
self.qdrant.query_points,
collection_name=collection, collection_name=collection,
query=vec, query=vec,
limit=msg.limit, limit=msg.limit,
with_payload=True, with_payload=True,
).points )
search_result = result.points
chunks = [] chunks = []
for r in search_result: for r in search_result:

View file

@ -4,11 +4,10 @@ Graph embeddings query service. Input is vector, output is list of
entities entities
""" """
import asyncio
import logging import logging
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL from .... schema import Error, Term, IRI, LITERAL
@ -38,32 +37,6 @@ class Processor(GraphEmbeddingsQueryService):
) )
self.qdrant = QdrantClient(url=store_uri, api_key=api_key) self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.last_collection = None
def ensure_collection_exists(self, collection, dim):
"""Ensure collection exists, create if it doesn't"""
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
try:
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
logger.info(f"Created collection: {collection}")
except Exception as e:
logger.error(f"Qdrant collection creation failed: {e}")
raise e
self.last_collection = collection
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
def create_value(self, ent): def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"): if ent.startswith("http://") or ent.startswith("https://"):
@ -79,23 +52,26 @@ class Processor(GraphEmbeddingsQueryService):
if not vec: if not vec:
return [] return []
# Use dimension suffix in collection name
dim = len(vec) dim = len(vec)
collection = f"t_{workspace}_{msg.collection}_{dim}" collection = f"t_{workspace}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not exists = await asyncio.to_thread(
if not self.collection_exists(collection): self.qdrant.collection_exists, collection
)
if not exists:
logger.info(f"Collection {collection} does not exist") logger.info(f"Collection {collection} does not exist")
return [] return []
# Heuristic hack, get (2*limit), so that we have more chance # Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities # of getting (limit) unique entities
search_result = self.qdrant.query_points( result = await asyncio.to_thread(
self.qdrant.query_points,
collection_name=collection, collection_name=collection,
query=vec, query=vec,
limit=msg.limit * 2, limit=msg.limit * 2,
with_payload=True, with_payload=True,
).points )
search_result = result.points
entity_set = set() entity_set = set()
entities = [] entities = []

View file

@ -6,6 +6,7 @@ Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups. use in subsequent Cassandra lookups.
""" """
import asyncio
import logging import logging
import re import re
from typing import Optional from typing import Optional
@ -70,7 +71,7 @@ class Processor(FlowProcessor):
safe_name = 'r_' + safe_name safe_name = 'r_' + safe_name
return safe_name.lower() return safe_name.lower()
def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: async def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given workspace/collection/schema""" """Find the Qdrant collection for a given workspace/collection/schema"""
prefix = ( prefix = (
f"rows_{self.sanitize_name(workspace)}_" f"rows_{self.sanitize_name(workspace)}_"
@ -78,14 +79,15 @@ class Processor(FlowProcessor):
) )
try: try:
all_collections = self.qdrant.get_collections().collections all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching = [ matching = [
coll.name for coll in all_collections coll.name for coll in all_collections
if coll.name.startswith(prefix) if coll.name.startswith(prefix)
] ]
if matching: if matching:
# Return first match (there should typically be only one per dimension)
return matching[0] return matching[0]
except Exception as e: except Exception as e:
@ -100,8 +102,7 @@ class Processor(FlowProcessor):
if not vec: if not vec:
return [] return []
# Find the collection for this workspace/collection/schema qdrant_collection = await self.find_collection(
qdrant_collection = self.find_collection(
workspace, request.collection, request.schema_name workspace, request.collection, request.schema_name
) )
@ -113,7 +114,6 @@ class Processor(FlowProcessor):
return [] return []
try: try:
# Build optional filter for index_name
query_filter = None query_filter = None
if request.index_name: if request.index_name:
query_filter = Filter( query_filter = Filter(
@ -125,16 +125,16 @@ class Processor(FlowProcessor):
] ]
) )
# Query Qdrant result = await asyncio.to_thread(
search_result = self.qdrant.query_points( self.qdrant.query_points,
collection_name=qdrant_collection, collection_name=qdrant_collection,
query=vec, query=vec,
limit=request.limit, limit=request.limit,
with_payload=True, with_payload=True,
query_filter=query_filter, query_filter=query_filter,
).points )
search_result = result.points
# Convert to RowIndexMatch objects
matches = [] matches = []
for point in search_result: for point in search_result:
payload = point.payload or {} payload = point.payload or {}

View file

@ -11,6 +11,7 @@ Queries against the unified 'rows' table with schema:
- source: text - source: text
""" """
import asyncio
import json import json
import logging import logging
import re import re
@ -97,34 +98,38 @@ class Processor(FlowProcessor):
# Cassandra session # Cassandra session
self.cluster = None self.cluster = None
self.session = None self.session = None
self._setup_lock = asyncio.Lock()
# Known keyspaces # Known keyspaces
self.known_keyspaces: Set[str] = set() self.known_keyspaces: Set[str] = set()
def connect_cassandra(self): async def connect_cassandra(self):
"""Connect to Cassandra cluster""" """Connect to Cassandra cluster"""
if self.session: async with self._setup_lock:
return if self.session:
return
try: try:
if self.cassandra_username and self.cassandra_password: if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider( auth_provider = PlainTextAuthProvider(
username=self.cassandra_username, username=self.cassandra_username,
password=self.cassandra_password password=self.cassandra_password
) )
self.cluster = Cluster( cluster = Cluster(
contact_points=self.cassandra_host, contact_points=self.cassandra_host,
auth_provider=auth_provider auth_provider=auth_provider
) )
else: else:
self.cluster = Cluster(contact_points=self.cassandra_host) cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect() session = await asyncio.to_thread(cluster.connect)
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") self.cluster = cluster
self.session = session
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise raise
def sanitize_name(self, name: str) -> str: def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility""" """Sanitize names for Cassandra compatibility"""
@ -140,14 +145,17 @@ class Processor(FlowProcessor):
f"for workspace {workspace}" f"for workspace {workspace}"
) )
# Replace existing schemas for this workspace async with self._setup_lock:
await self._apply_schema_config(workspace, config)
async def _apply_schema_config(self, workspace, config):
ws_schemas: Dict[str, RowSchema] = {} ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas self.schemas[workspace] = ws_schemas
builder = GraphQLSchemaBuilder() builder = GraphQLSchemaBuilder()
self.schema_builders[workspace] = builder self.schema_builders[workspace] = builder
# Check if our config type exists
if self.config_key not in config: if self.config_key not in config:
logger.warning( logger.warning(
f"No '{self.config_key}' type in configuration " f"No '{self.config_key}' type in configuration "
@ -156,16 +164,12 @@ class Processor(FlowProcessor):
self.graphql_schemas[workspace] = None self.graphql_schemas[workspace] = None
return return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key] schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items(): for schema_name, schema_json in schemas_config.items():
try: try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json) schema_def = json.loads(schema_json)
# Create Field objects
fields = [] fields = []
for field_def in schema_def.get("fields", []): for field_def in schema_def.get("fields", []):
field = SchemaField( field = SchemaField(
@ -180,7 +184,6 @@ class Processor(FlowProcessor):
) )
fields.append(field) fields.append(field)
# Create RowSchema
row_schema = RowSchema( row_schema = RowSchema(
name=schema_def.get("name", schema_name), name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""), description=schema_def.get("description", ""),
@ -202,7 +205,6 @@ class Processor(FlowProcessor):
f"{len(ws_schemas)} schemas" f"{len(ws_schemas)} schemas"
) )
# Regenerate GraphQL schema for this workspace
self.graphql_schemas[workspace] = builder.build(self.query_cassandra) self.graphql_schemas[workspace] = builder.build(self.query_cassandra)
def get_index_names(self, schema: RowSchema) -> List[str]: def get_index_names(self, schema: RowSchema) -> List[str]:
@ -254,7 +256,7 @@ class Processor(FlowProcessor):
For other queries, we need to scan and post-filter. For other queries, we need to scan and post-filter.
""" """
# Connect if needed # Connect if needed
self.connect_cassandra() await self.connect_cassandra()
safe_keyspace = self.sanitize_name(workspace) safe_keyspace = self.sanitize_name(workspace)

View file

@ -6,8 +6,8 @@ null. Output is a list of quads.
import asyncio import asyncio
import logging import logging
import json import json
from cassandra.query import SimpleStatement from cassandra.query import SimpleStatement
from .... direct.cassandra_kg import ( from .... direct.cassandra_kg import (
@ -17,6 +17,7 @@ from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK
from .... base import TriplesQueryService from .... base import TriplesQueryService
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... tables.cassandra_async import async_execute
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -176,45 +177,42 @@ class Processor(TriplesQueryService):
self.cassandra_host = hosts self.cassandra_host = hosts
self.cassandra_username = username self.cassandra_username = username
self.cassandra_password = password self.cassandra_password = password
self.table = None
def ensure_connection(self, workspace): self._connections = {}
"""Ensure we have a connection to the correct keyspace.""" self._conn_lock = asyncio.Lock()
if workspace != self.table:
KGClass = EntityCentricKnowledgeGraph
if self.cassandra_username and self.cassandra_password: async def _get_connection(self, workspace):
self.tg = KGClass( async with self._conn_lock:
hosts=self.cassandra_host, if workspace not in self._connections:
keyspace=workspace, if self.cassandra_username and self.cassandra_password:
username=self.cassandra_username, tg = await asyncio.to_thread(
password=self.cassandra_password EntityCentricKnowledgeGraph,
) hosts=self.cassandra_host,
else: keyspace=workspace,
self.tg = KGClass( username=self.cassandra_username,
hosts=self.cassandra_host, password=self.cassandra_password,
keyspace=workspace, )
) else:
self.table = workspace tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
)
self._connections[workspace] = tg
return self._connections[workspace]
async def query_triples(self, workspace, query): async def query_triples(self, workspace, query):
try: try:
# ensure_connection may construct a fresh
# EntityCentricKnowledgeGraph which does sync schema
# setup against Cassandra. Push it to a worker thread
# so the event loop doesn't block on first-use per workspace.
await asyncio.to_thread(self.ensure_connection, workspace)
# Extract values from query
s_val = get_term_value(query.s) s_val = get_term_value(query.s)
p_val = get_term_value(query.p) p_val = get_term_value(query.p)
o_val = get_term_value(query.o) o_val = get_term_value(query.o)
g_val = query.g # Already a string or None g_val = query.g
tg = await self._get_connection(workspace)
def get_object_metadata(row): def get_object_metadata(row):
"""Extract term type metadata from result row"""
return ( return (
getattr(row, 'otype', None), getattr(row, 'otype', None),
getattr(row, 'dtype', None), getattr(row, 'dtype', None),
@ -223,33 +221,21 @@ class Processor(TriplesQueryService):
quads = [] quads = []
# All self.tg.get_* calls below are sync wrappers around
# cassandra session.execute. Materialise inside a worker
# thread so iteration never triggers sync paging back on
# the event loop.
# Route to appropriate query method based on which fields are specified
if s_val is not None: if s_val is not None:
if p_val is not None: if p_val is not None:
if o_val is not None: if o_val is not None:
# SPO specified - find matching graphs resp = await tg.async_get_spo(
resp = await asyncio.to_thread( query.collection, s_val, p_val, o_val,
lambda: list(self.tg.get_spo( g=g_val, limit=query.limit,
query.collection, s_val, p_val, o_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t) term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, p_val, o_val, g, term_type, datatype, language)) quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
else: else:
# SP specified resp = await tg.async_get_sp(
resp = await asyncio.to_thread( query.collection, s_val, p_val,
lambda: list(self.tg.get_sp( g=g_val, limit=query.limit,
query.collection, s_val, p_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -257,24 +243,18 @@ class Processor(TriplesQueryService):
quads.append((s_val, p_val, t.o, g, term_type, datatype, language)) quads.append((s_val, p_val, t.o, g, term_type, datatype, language))
else: else:
if o_val is not None: if o_val is not None:
# SO specified resp = await tg.async_get_os(
resp = await asyncio.to_thread( query.collection, o_val, s_val,
lambda: list(self.tg.get_os( g=g_val, limit=query.limit,
query.collection, o_val, s_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t) term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, t.p, o_val, g, term_type, datatype, language)) quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
else: else:
# S only resp = await tg.async_get_s(
resp = await asyncio.to_thread( query.collection, s_val,
lambda: list(self.tg.get_s( g=g_val, limit=query.limit,
query.collection, s_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -283,24 +263,18 @@ class Processor(TriplesQueryService):
else: else:
if p_val is not None: if p_val is not None:
if o_val is not None: if o_val is not None:
# PO specified resp = await tg.async_get_po(
resp = await asyncio.to_thread( query.collection, p_val, o_val,
lambda: list(self.tg.get_po( g=g_val, limit=query.limit,
query.collection, p_val, o_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t) term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, p_val, o_val, g, term_type, datatype, language)) quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
else: else:
# P only resp = await tg.async_get_p(
resp = await asyncio.to_thread( query.collection, p_val,
lambda: list(self.tg.get_p( g=g_val, limit=query.limit,
query.collection, p_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@ -308,40 +282,26 @@ class Processor(TriplesQueryService):
quads.append((t.s, p_val, t.o, g, term_type, datatype, language)) quads.append((t.s, p_val, t.o, g, term_type, datatype, language))
else: else:
if o_val is not None: if o_val is not None:
# O only resp = await tg.async_get_o(
resp = await asyncio.to_thread( query.collection, o_val,
lambda: list(self.tg.get_o( g=g_val, limit=query.limit,
query.collection, o_val,
g=g_val, limit=query.limit,
))
) )
for t in resp: for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t) term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, t.p, o_val, g, term_type, datatype, language)) quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
else: else:
# Nothing specified - get all resp = await tg.async_get_all(
resp = await asyncio.to_thread( query.collection, limit=query.limit,
lambda: list(self.tg.get_all(
query.collection, limit=query.limit,
))
) )
for t in resp: for t in resp:
# Note: quads_by_collection uses 'd' for graph field
g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
# Filter by graph
# g_val=None means all graphs (no filter)
# g_val="" means default graph only
# otherwise filter to specific named graph
if g_val is not None: if g_val is not None:
if g != g_val: if g != g_val:
continue continue
term_type, datatype, language = get_object_metadata(t) term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, t.p, t.o, g, term_type, datatype, language)) quads.append((t.s, t.p, t.o, g, term_type, datatype, language))
# Convert to Triple objects (with g field)
# s and p are always IRIs in RDF
# Object uses term_type/datatype/language metadata from database
triples = [ triples = [
Triple( Triple(
s=create_term(q[0], term_type='u'), s=create_term(q[0], term_type='u'),
@ -365,51 +325,36 @@ class Processor(TriplesQueryService):
Uses Cassandra's paging to fetch results incrementally. Uses Cassandra's paging to fetch results incrementally.
""" """
try: try:
await asyncio.to_thread(self.ensure_connection, workspace)
batch_size = query.batch_size if query.batch_size > 0 else 20 batch_size = query.batch_size if query.batch_size > 0 else 20
limit = query.limit if query.limit > 0 else 10000 limit = query.limit if query.limit > 0 else 10000
# Extract query pattern
s_val = get_term_value(query.s) s_val = get_term_value(query.s)
p_val = get_term_value(query.p) p_val = get_term_value(query.p)
o_val = get_term_value(query.o) o_val = get_term_value(query.o)
g_val = query.g g_val = query.g
def get_object_metadata(row): def get_object_metadata(row):
"""Extract term type metadata from result row"""
return ( return (
getattr(row, 'otype', None), getattr(row, 'otype', None),
getattr(row, 'dtype', None), getattr(row, 'dtype', None),
getattr(row, 'lang', None), getattr(row, 'lang', None),
) )
# For streaming, we need to execute with fetch_size
# Use the collection table for get_all queries (most common streaming case)
# Determine which query to use based on pattern
if s_val is None and p_val is None and o_val is None: if s_val is None and p_val is None and o_val is None:
# Get all - use collection table with paging
cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s" tg = await self._get_connection(workspace)
cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {tg.collection_table} WHERE collection = %s"
params = [query.collection] params = [query.collection]
statement = SimpleStatement(cql, fetch_size=batch_size)
result_set = await async_execute(tg.session, statement, params)
else: else:
# For specific patterns, fall back to non-streaming
# (these typically return small result sets anyway)
async for batch, is_final in self._fallback_stream(workspace, query, batch_size): async for batch, is_final in self._fallback_stream(workspace, query, batch_size):
yield batch, is_final yield batch, is_final
return return
# Materialise in a worker thread. We lose true streaming
# paging (the driver fetches all pages eagerly inside the
# thread) but the event loop stays responsive, and result
# sets at this layer are typically small enough that this
# is acceptable. If true async paging is needed later,
# revisit using ResponseFuture page callbacks.
statement = SimpleStatement(cql, fetch_size=batch_size)
result_set = await asyncio.to_thread(
lambda: list(self.tg.session.execute(statement, params))
)
batch = [] batch = []
count = 0 count = 0

View file

@ -3,11 +3,13 @@
Accepts entity/vector pairs and writes them to a Qdrant store. Accepts entity/vector pairs and writes them to a Qdrant store.
""" """
import asyncio
import uuid
import logging
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams from qdrant_client.models import Distance, VectorParams
import uuid
import logging
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer from .... base import AsyncProcessor, Consumer, Producer
@ -35,13 +37,35 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
) )
self.qdrant = QdrantClient(url=store_uri, api_key=api_key) self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
# Register for config push notifications # Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"]) self.register_config_handler(self.on_collection_config, types=["collection"])
async def ensure_collection(self, collection_name, dim):
async with self._cache_lock:
if collection_name in self._known_collections:
return
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection_name
)
if not exists:
logger.info(
f"Lazily creating Qdrant collection {collection_name} "
f"with dimension {dim}"
)
await asyncio.to_thread(
self.qdrant.create_collection,
collection_name=collection_name,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
self._known_collections.add(collection_name)
async def store_document_embeddings(self, workspace, message): async def store_document_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(workspace, message.metadata.collection): if not self.collection_exists(workspace, message.metadata.collection):
logger.warning( logger.warning(
f"Collection {message.metadata.collection} for workspace {workspace} " f"Collection {message.metadata.collection} for workspace {workspace} "
@ -60,24 +84,15 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if not vec: if not vec:
continue continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec) dim = len(vec)
collection = ( collection = (
f"d_{workspace}_{message.metadata.collection}_{dim}" f"d_{workspace}_{message.metadata.collection}_{dim}"
) )
# Lazily create collection if it doesn't exist (but only if authorized in config) await self.ensure_collection(collection, dim)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert( await asyncio.to_thread(
self.qdrant.upsert,
collection_name=collection, collection_name=collection,
points=[ points=[
PointStruct( PointStruct(
@ -87,7 +102,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
"chunk_id": chunk_id, "chunk_id": chunk_id,
} }
) )
] ],
) )
@staticmethod @staticmethod
@ -124,8 +139,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
try: try:
prefix = f"d_{workspace}_{collection}_" prefix = f"d_{workspace}_{collection}_"
# Get all collections and filter for matches all_collections = await asyncio.to_thread(
all_collections = self.qdrant.get_collections().collections lambda: self.qdrant.get_collections().collections
)
matching_collections = [ matching_collections = [
coll.name for coll in all_collections coll.name for coll in all_collections
if coll.name.startswith(prefix) if coll.name.startswith(prefix)
@ -135,7 +151,11 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
logger.info(f"No collections found matching prefix {prefix}") logger.info(f"No collections found matching prefix {prefix}")
else: else:
for collection_name in matching_collections: for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name) await asyncio.to_thread(
self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")

View file

@ -3,11 +3,13 @@
Accepts entity/vector pairs and writes them to a Qdrant store. Accepts entity/vector pairs and writes them to a Qdrant store.
""" """
import asyncio
import uuid
import logging
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams from qdrant_client.models import Distance, VectorParams
import uuid
import logging
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer from .... base import AsyncProcessor, Consumer, Producer
@ -50,13 +52,35 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
) )
self.qdrant = QdrantClient(url=store_uri, api_key=api_key) self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
# Register for config push notifications # Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"]) self.register_config_handler(self.on_collection_config, types=["collection"])
async def ensure_collection(self, collection_name, dim):
async with self._cache_lock:
if collection_name in self._known_collections:
return
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection_name
)
if not exists:
logger.info(
f"Lazily creating Qdrant collection {collection_name} "
f"with dimension {dim}"
)
await asyncio.to_thread(
self.qdrant.create_collection,
collection_name=collection_name,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
self._known_collections.add(collection_name)
async def store_graph_embeddings(self, workspace, message): async def store_graph_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(workspace, message.metadata.collection): if not self.collection_exists(workspace, message.metadata.collection):
logger.warning( logger.warning(
f"Collection {message.metadata.collection} for workspace {workspace} " f"Collection {message.metadata.collection} for workspace {workspace} "
@ -75,22 +99,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if not vec: if not vec:
continue continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec) dim = len(vec)
collection = ( collection = (
f"t_{workspace}_{message.metadata.collection}_{dim}" f"t_{workspace}_{message.metadata.collection}_{dim}"
) )
# Lazily create collection if it doesn't exist (but only if authorized in config) await self.ensure_collection(collection, dim)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
payload = { payload = {
"entity": entity_value, "entity": entity_value,
@ -98,7 +112,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity.chunk_id: if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert( await asyncio.to_thread(
self.qdrant.upsert,
collection_name=collection, collection_name=collection,
points=[ points=[
PointStruct( PointStruct(
@ -106,7 +121,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
vector=vec, vector=vec,
payload=payload, payload=payload,
) )
] ],
) )
@staticmethod @staticmethod
@ -143,8 +158,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
try: try:
prefix = f"t_{workspace}_{collection}_" prefix = f"t_{workspace}_{collection}_"
# Get all collections and filter for matches all_collections = await asyncio.to_thread(
all_collections = self.qdrant.get_collections().collections lambda: self.qdrant.get_collections().collections
)
matching_collections = [ matching_collections = [
coll.name for coll in all_collections coll.name for coll in all_collections
if coll.name.startswith(prefix) if coll.name.startswith(prefix)
@ -154,7 +170,11 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
logger.info(f"No collections found matching prefix {prefix}") logger.info(f"No collections found matching prefix {prefix}")
else: else:
for collection_name in matching_collections: for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name) await asyncio.to_thread(
self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")

View file

@ -16,10 +16,10 @@ Payload structure:
- text: The text that was embedded (for debugging/display) - text: The text that was embedded (for debugging/display)
""" """
import asyncio
import logging import logging
import re import re
import uuid import uuid
from typing import Set, Tuple
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Distance, VectorParams from qdrant_client.models import PointStruct, Distance, VectorParams
@ -63,11 +63,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Register config handler for collection management # Register config handler for collection management
self.register_config_handler(self.on_collection_config, types=["collection"]) self.register_config_handler(self.on_collection_config, types=["collection"])
# Cache of created Qdrant collections
self.created_collections: Set[str] = set()
# Qdrant client
self.qdrant = QdrantClient(url=store_uri, api_key=api_key) self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
def sanitize_name(self, name: str) -> str: def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming""" """Sanitize names for Qdrant collection naming"""
@ -85,25 +83,28 @@ class Processor(CollectionConfigHandler, FlowProcessor):
safe_schema = self.sanitize_name(schema_name) safe_schema = self.sanitize_name(schema_name)
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}" return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
def ensure_collection(self, collection_name: str, dimension: int): async def ensure_collection(self, collection_name: str, dimension: int):
"""Create Qdrant collection if it doesn't exist""" """Create Qdrant collection if it doesn't exist"""
if collection_name in self.created_collections: async with self._cache_lock:
return if collection_name in self._known_collections:
return
if not self.qdrant.collection_exists(collection_name): exists = await asyncio.to_thread(
logger.info( self.qdrant.collection_exists, collection_name
f"Creating Qdrant collection {collection_name} "
f"with dimension {dimension}"
) )
self.qdrant.create_collection( if not exists:
collection_name=collection_name, logger.info(
vectors_config=VectorParams( f"Creating Qdrant collection {collection_name} "
size=dimension, f"with dimension {dimension}"
distance=Distance.COSINE
) )
) await asyncio.to_thread(
self.qdrant.create_collection,
self.created_collections.add(collection_name) collection_name=collection_name,
vectors_config=VectorParams(
size=dimension,
distance=Distance.COSINE
),
)
self._known_collections.add(collection_name)
async def on_embeddings(self, msg, consumer, flow): async def on_embeddings(self, msg, consumer, flow):
"""Process incoming RowEmbeddings and write to Qdrant""" """Process incoming RowEmbeddings and write to Qdrant"""
@ -143,15 +144,14 @@ class Processor(CollectionConfigHandler, FlowProcessor):
dimension = len(vector) dimension = len(vector)
# Create/get collection name (lazily on first vector)
if qdrant_collection is None: if qdrant_collection is None:
qdrant_collection = self.get_collection_name( qdrant_collection = self.get_collection_name(
workspace, collection, schema_name, dimension workspace, collection, schema_name, dimension
) )
self.ensure_collection(qdrant_collection, dimension) await self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant await asyncio.to_thread(
self.qdrant.upsert( self.qdrant.upsert,
collection_name=qdrant_collection, collection_name=qdrant_collection,
points=[ points=[
PointStruct( PointStruct(
@ -163,7 +163,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"text": row_emb.text "text": row_emb.text
} }
) )
] ],
) )
embeddings_written += 1 embeddings_written += 1
@ -181,8 +181,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
try: try:
prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_" prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"
# Get all collections and filter for matches all_collections = await asyncio.to_thread(
all_collections = self.qdrant.get_collections().collections lambda: self.qdrant.get_collections().collections
)
matching_collections = [ matching_collections = [
coll.name for coll in all_collections coll.name for coll in all_collections
if coll.name.startswith(prefix) if coll.name.startswith(prefix)
@ -192,8 +193,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"No Qdrant collections found matching prefix {prefix}") logger.info(f"No Qdrant collections found matching prefix {prefix}")
else: else:
for collection_name in matching_collections: for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name) await asyncio.to_thread(
self.created_collections.discard(collection_name) self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info( logger.info(
f"Deleted {len(matching_collections)} collection(s) " f"Deleted {len(matching_collections)} collection(s) "
@ -217,8 +221,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
) )
# Get all collections and filter for matches all_collections = await asyncio.to_thread(
all_collections = self.qdrant.get_collections().collections lambda: self.qdrant.get_collections().collections
)
matching_collections = [ matching_collections = [
coll.name for coll in all_collections coll.name for coll in all_collections
if coll.name.startswith(prefix) if coll.name.startswith(prefix)
@ -228,8 +233,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"No Qdrant collections found matching prefix {prefix}") logger.info(f"No Qdrant collections found matching prefix {prefix}")
else: else:
for collection_name in matching_collections: for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name) await asyncio.to_thread(
self.created_collections.discard(collection_name) self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted Qdrant collection: {collection_name}")
except Exception as e: except Exception as e:

View file

@ -82,7 +82,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Cache of known keyspaces and whether tables exist # Cache of known keyspaces and whether tables exist
self.known_keyspaces: Set[str] = set() self.known_keyspaces: Set[str] = set()
self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables self.tables_initialized: Set[str] = set()
# Cache of registered (collection, schema_name) pairs # Cache of registered (collection, schema_name) pairs
self.registered_partitions: Set[Tuple[str, str]] = set() self.registered_partitions: Set[Tuple[str, str]] = set()
@ -94,6 +94,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
self.cluster = None self.cluster = None
self.session = None self.session = None
# Protects connection setup and cache mutations
self._setup_lock = asyncio.Lock()
def connect_cassandra(self): def connect_cassandra(self):
"""Connect to Cassandra cluster""" """Connect to Cassandra cluster"""
if self.session: if self.session:
@ -126,6 +129,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"for workspace {workspace}" f"for workspace {workspace}"
) )
async with self._setup_lock:
return await self._apply_schema_config(workspace, config, version)
async def _apply_schema_config(self, workspace, config, version):
# Track which schemas changed in this workspace # Track which schemas changed in this workspace
old_schemas = self.schemas.get(workspace, {}) old_schemas = self.schemas.get(workspace, {})
old_schema_names = set(old_schemas.keys()) old_schema_names = set(old_schemas.keys())
@ -391,16 +399,12 @@ class Processor(CollectionConfigHandler, FlowProcessor):
schema_name = obj.schema_name schema_name = obj.schema_name
source = getattr(obj.metadata, 'source', '') or '' source = getattr(obj.metadata, 'source', '') or ''
# Ensure tables exist (sync DDL — push to a worker thread async with self._setup_lock:
# so the event loop stays responsive when running in a await asyncio.to_thread(self.ensure_tables, keyspace)
# processor group sharing the loop with siblings). await asyncio.to_thread(
await asyncio.to_thread(self.ensure_tables, keyspace) self.register_partitions,
keyspace, collection, schema_name, workspace,
# Register partitions if first time seeing this (collection, schema_name) )
await asyncio.to_thread(
self.register_partitions,
keyspace, collection, schema_name, workspace,
)
safe_keyspace = self.sanitize_name(keyspace) safe_keyspace = self.sanitize_name(keyspace)
@ -461,35 +465,27 @@ class Processor(CollectionConfigHandler, FlowProcessor):
async def create_collection(self, workspace: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store""" """Create/verify collection exists in Cassandra row store"""
# Connect if not already connected (sync, push to thread) async with self._setup_lock:
await asyncio.to_thread(self.connect_cassandra) await asyncio.to_thread(self.connect_cassandra)
await asyncio.to_thread(self.ensure_tables, workspace)
# Ensure tables exist (sync DDL, push to thread)
await asyncio.to_thread(self.ensure_tables, workspace)
logger.info(f"Collection {collection} ready for workspace {workspace}") logger.info(f"Collection {collection} ready for workspace {workspace}")
async def delete_collection(self, workspace: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection using partition tracking""" """Delete all data for a specific collection using partition tracking"""
# Connect if not already connected async with self._setup_lock:
await asyncio.to_thread(self.connect_cassandra) await asyncio.to_thread(self.connect_cassandra)
if workspace not in self.known_keyspaces:
safe_ks = self.sanitize_name(workspace)
check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s"
result = await async_execute(self.session, check_cql, (safe_ks,))
if not result:
logger.info(f"Keyspace {safe_ks} does not exist, nothing to delete")
return
self.known_keyspaces.add(workspace)
safe_keyspace = self.sanitize_name(workspace) safe_keyspace = self.sanitize_name(workspace)
# Check if keyspace exists
if workspace not in self.known_keyspaces:
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = await async_execute(
self.session, check_keyspace_cql, (safe_keyspace,)
)
if not result:
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(workspace)
# Discover all partitions for this collection # Discover all partitions for this collection
select_partitions_cql = f""" select_partitions_cql = f"""
SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
@ -540,11 +536,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.error(f"Failed to clean up row_partitions for {collection}: {e}") logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
raise raise
# Clear from local cache async with self._setup_lock:
self.registered_partitions = { self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions (col, sch) for col, sch in self.registered_partitions
if col != collection if col != collection
} }
logger.info( logger.info(
f"Deleted collection {collection}: {partitions_deleted} partitions " f"Deleted collection {collection}: {partitions_deleted} partitions "
@ -553,8 +549,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str): async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination""" """Delete all data for a specific collection + schema combination"""
# Connect if not already connected async with self._setup_lock:
await asyncio.to_thread(self.connect_cassandra) await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(workspace) safe_keyspace = self.sanitize_name(workspace)
@ -614,8 +610,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
) )
raise raise
# Clear from local cache async with self._setup_lock:
self.registered_partitions.discard((collection, schema_name)) self.registered_partitions.discard((collection, schema_name))
logger.info( logger.info(
f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions " f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "

View file

@ -4,12 +4,7 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph.
""" """
import asyncio import asyncio
import base64
import os
import argparse
import time
import logging import logging
import json
from .... direct.cassandra_kg import ( from .... direct.cassandra_kg import (
EntityCentricKnowledgeGraph, DEFAULT_GRAPH EntityCentricKnowledgeGraph, DEFAULT_GRAPH
@ -28,6 +23,8 @@ default_ident = "triples-write"
def serialize_triple(triple): def serialize_triple(triple):
"""Serialize a Triple object to JSON for storage.""" """Serialize a Triple object to JSON for storage."""
import json
if triple is None: if triple is None:
return None return None
@ -141,156 +138,84 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
self.cassandra_host = hosts self.cassandra_host = hosts
self.cassandra_username = username self.cassandra_username = username
self.cassandra_password = password self.cassandra_password = password
self.table = None
self.tg = None self._connections = {}
self._conn_lock = asyncio.Lock()
# Register for config push notifications # Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"]) self.register_config_handler(self.on_collection_config, types=["collection"])
async def _get_connection(self, workspace):
async with self._conn_lock:
if workspace not in self._connections:
if self.cassandra_username and self.cassandra_password:
tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
)
self._connections[workspace] = tg
return self._connections[workspace]
async def store_triples(self, workspace, message): async def store_triples(self, workspace, message):
# The cassandra-driver work below — connection, schema tg = await self._get_connection(workspace)
# setup, and per-triple inserts — is all synchronous.
# Wrap the whole batch in a worker thread so the event
# loop stays responsive for sibling processors when
# running in a processor group.
def _do_store(): for t in message.triples:
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
g_val = t.g if t.g is not None else DEFAULT_GRAPH
if self.table is None or self.table != workspace: otype = get_term_otype(t.o)
dtype = get_term_dtype(t.o)
lang = get_term_lang(t.o)
self.tg = None await tg.async_insert(
message.metadata.collection,
# Use factory function to select implementation s_val,
KGClass = EntityCentricKnowledgeGraph p_val,
o_val,
try: g=g_val,
if self.cassandra_username and self.cassandra_password: otype=otype,
self.tg = KGClass( dtype=dtype,
hosts=self.cassandra_host, lang=lang,
keyspace=workspace, )
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Exception: {e}", exc_info=True)
time.sleep(1)
raise e
self.table = workspace
for t in message.triples:
# Extract values from Term objects
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
# t.g is None for default graph, or a graph IRI
g_val = t.g if t.g is not None else DEFAULT_GRAPH
# Extract object type metadata for entity-centric storage
otype = get_term_otype(t.o)
dtype = get_term_dtype(t.o)
lang = get_term_lang(t.o)
self.tg.insert(
message.metadata.collection,
s_val,
p_val,
o_val,
g=g_val,
otype=otype,
dtype=dtype,
lang=lang,
)
await asyncio.to_thread(_do_store)
async def create_collection(self, workspace: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create a collection in Cassandra triple store via config push""" """Create a collection in Cassandra triple store via config push"""
try:
tg = await self._get_connection(workspace)
def _do_create():
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
self.table = workspace
# Create collection using the built-in method
logger.info(f"Creating collection {collection} for workspace {workspace}") logger.info(f"Creating collection {collection} for workspace {workspace}")
if self.tg.collection_exists(collection): exists = await tg.async_collection_exists(collection)
if exists:
logger.info(f"Collection {collection} already exists") logger.info(f"Collection {collection} already exists")
else: else:
self.tg.create_collection(collection) await tg.async_create_collection(collection)
logger.info(f"Created collection {collection}") logger.info(f"Created collection {collection}")
try:
await asyncio.to_thread(_do_create)
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, workspace: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection from the unified triples table""" """Delete all data for a specific collection from the unified triples table"""
try:
tg = await self._get_connection(workspace)
def _do_delete(): await tg.async_delete_collection(collection)
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
self.table = workspace
# Delete all triples for this collection using the built-in method
self.tg.delete_collection(collection)
logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}") logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
try:
await asyncio.to_thread(_do_delete)
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise raise