mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-16 19:05:14 +02:00
Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916)
Cassandra triples services were using synchronous EntityCentricKnowledgeGraph methods from async contexts, and connection state was managed with threading.local which is wrong for asyncio coroutines sharing a single thread. Qdrant services had no async wrapping at all, blocking the event loop on every network call. Rows services had unprotected shared state mutations across concurrent coroutines. - Add async methods to EntityCentricKnowledgeGraph (async_insert, async_get_s/p/o/sp/po/os/spo/all, async_collection_exists, async_create_collection, async_delete_collection) using the existing cassandra_async.async_execute bridge - Rewrite triples write + query services: replace threading.local with asyncio.Lock + dict cache for per-workspace connections, use async ECKG methods for all data operations, keep asyncio.to_thread only for one-time blocking ECKG construction - Wrap all Qdrant calls in asyncio.to_thread across all 6 services (doc/graph/row embeddings write + query), add asyncio.Lock + set cache for collection existence checks - Add asyncio.Lock to rows write + query services to protect shared state (schemas, sessions, config caches) from concurrent mutation - Update all affected tests to match new async patterns
This commit is contained in:
parent
bb1109963c
commit
a2dde9cafb
22 changed files with 736 additions and 621 deletions
|
|
@ -188,13 +188,14 @@ class TestConfigurationPriorityEndToEnd:
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.direct.cassandra_kg.Cluster')
|
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||||
async def test_no_config_defaults_end_to_end(self, mock_cluster):
|
async def test_no_config_defaults_end_to_end(self, mock_kg_class):
|
||||||
"""Test that defaults are used when no configuration provided end-to-end."""
|
"""Test that defaults are used when no configuration provided end-to-end."""
|
||||||
mock_cluster_instance = MagicMock()
|
from unittest.mock import AsyncMock
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_cluster_instance.connect.return_value = mock_session
|
mock_tg_instance = MagicMock()
|
||||||
mock_cluster.return_value = mock_cluster_instance
|
mock_tg_instance.async_get_all = AsyncMock(return_value=[])
|
||||||
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
with patch.dict(os.environ, {}, clear=True):
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
processor = TriplesQuery(taskgroup=MagicMock())
|
processor = TriplesQuery(taskgroup=MagicMock())
|
||||||
|
|
@ -205,20 +206,16 @@ class TestConfigurationPriorityEndToEnd:
|
||||||
mock_query.s = None
|
mock_query.s = None
|
||||||
mock_query.p = None
|
mock_query.p = None
|
||||||
mock_query.o = None
|
mock_query.o = None
|
||||||
|
mock_query.g = None
|
||||||
mock_query.limit = 100
|
mock_query.limit = 100
|
||||||
|
|
||||||
# Mock the get_all method to return empty list
|
|
||||||
mock_tg_instance = MagicMock()
|
|
||||||
mock_tg_instance.get_all.return_value = []
|
|
||||||
processor.tg = mock_tg_instance
|
|
||||||
|
|
||||||
await processor.query_triples('default_user', mock_query)
|
await processor.query_triples('default_user', mock_query)
|
||||||
|
|
||||||
# Should use defaults
|
# Should use defaults
|
||||||
mock_cluster.assert_called_once()
|
mock_kg_class.assert_called_once_with(
|
||||||
call_args = mock_cluster.call_args
|
hosts=['cassandra'],
|
||||||
assert call_args.args[0] == ['cassandra'] # Default host
|
keyspace='default_user'
|
||||||
assert 'auth_provider' not in call_args.kwargs # No auth with default config
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestNoBackwardCompatibilityEndToEnd:
|
class TestNoBackwardCompatibilityEndToEnd:
|
||||||
|
|
|
||||||
|
|
@ -101,6 +101,8 @@ class TestRowsCassandraIntegration:
|
||||||
processor.session = None
|
processor.session = None
|
||||||
|
|
||||||
# Bind actual methods from the new unified table implementation
|
# Bind actual methods from the new unified table implementation
|
||||||
|
import asyncio
|
||||||
|
processor._setup_lock = asyncio.Lock()
|
||||||
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
|
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
|
||||||
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
|
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
|
||||||
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
|
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
|
||||||
|
|
@ -108,6 +110,7 @@ class TestRowsCassandraIntegration:
|
||||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||||
|
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
|
||||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||||
processor.collection_exists = MagicMock(return_value=True)
|
processor.collection_exists = MagicMock(return_value=True)
|
||||||
|
|
|
||||||
|
|
@ -184,7 +184,7 @@ class TestObjectsGraphQLQueryIntegration:
|
||||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||||
|
|
||||||
# Connect to Cassandra
|
# Connect to Cassandra
|
||||||
processor.connect_cassandra()
|
await processor.connect_cassandra()
|
||||||
assert processor.session is not None
|
assert processor.session is not None
|
||||||
|
|
||||||
# Create test keyspace and table
|
# Create test keyspace and table
|
||||||
|
|
@ -219,7 +219,7 @@ class TestObjectsGraphQLQueryIntegration:
|
||||||
"""Test inserting data and querying via GraphQL"""
|
"""Test inserting data and querying via GraphQL"""
|
||||||
# Load schema and connect
|
# Load schema and connect
|
||||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||||
processor.connect_cassandra()
|
await processor.connect_cassandra()
|
||||||
|
|
||||||
# Setup test data
|
# Setup test data
|
||||||
keyspace = "test_user"
|
keyspace = "test_user"
|
||||||
|
|
@ -293,7 +293,7 @@ class TestObjectsGraphQLQueryIntegration:
|
||||||
"""Test GraphQL queries with filtering on indexed fields"""
|
"""Test GraphQL queries with filtering on indexed fields"""
|
||||||
# Setup (reuse previous setup)
|
# Setup (reuse previous setup)
|
||||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||||
processor.connect_cassandra()
|
await processor.connect_cassandra()
|
||||||
|
|
||||||
keyspace = "test_user"
|
keyspace = "test_user"
|
||||||
collection = "filter_test"
|
collection = "filter_test"
|
||||||
|
|
@ -387,7 +387,7 @@ class TestObjectsGraphQLQueryIntegration:
|
||||||
"""Test full message processing workflow"""
|
"""Test full message processing workflow"""
|
||||||
# Setup
|
# Setup
|
||||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||||
processor.connect_cassandra()
|
await processor.connect_cassandra()
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
request = RowsQueryRequest(
|
request = RowsQueryRequest(
|
||||||
|
|
@ -433,7 +433,7 @@ class TestObjectsGraphQLQueryIntegration:
|
||||||
"""Test handling multiple concurrent GraphQL queries"""
|
"""Test handling multiple concurrent GraphQL queries"""
|
||||||
# Setup
|
# Setup
|
||||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||||
processor.connect_cassandra()
|
await processor.connect_cassandra()
|
||||||
|
|
||||||
# Create multiple query tasks
|
# Create multiple query tasks
|
||||||
queries = [
|
queries = [
|
||||||
|
|
@ -519,7 +519,7 @@ class TestObjectsGraphQLQueryIntegration:
|
||||||
"""Test handling of large query result sets"""
|
"""Test handling of large query result sets"""
|
||||||
# Setup
|
# Setup
|
||||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||||
processor.connect_cassandra()
|
await processor.connect_cassandra()
|
||||||
|
|
||||||
keyspace = "large_test_user"
|
keyspace = "large_test_user"
|
||||||
collection = "large_collection"
|
collection = "large_collection"
|
||||||
|
|
|
||||||
|
|
@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_schema_config_parsing(self):
|
async def test_schema_config_parsing(self):
|
||||||
"""Test parsing of schema configuration"""
|
"""Test parsing of schema configuration"""
|
||||||
|
import asyncio
|
||||||
processor = MagicMock()
|
processor = MagicMock()
|
||||||
processor.schemas = {}
|
processor.schemas = {}
|
||||||
processor.schema_builders = {}
|
processor.schema_builders = {}
|
||||||
processor.graphql_schemas = {}
|
processor.graphql_schemas = {}
|
||||||
processor.config_key = "schema"
|
processor.config_key = "schema"
|
||||||
processor.query_cassandra = MagicMock()
|
processor.query_cassandra = MagicMock()
|
||||||
|
processor._setup_lock = asyncio.Lock()
|
||||||
|
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
|
||||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||||
|
|
||||||
# Create test config
|
# Create test config
|
||||||
|
|
@ -335,7 +338,7 @@ class TestUnifiedTableQueries:
|
||||||
"""Test query execution with matching index"""
|
"""Test query execution with matching index"""
|
||||||
processor = MagicMock()
|
processor = MagicMock()
|
||||||
processor.session = MagicMock()
|
processor.session = MagicMock()
|
||||||
processor.connect_cassandra = MagicMock()
|
processor.connect_cassandra = AsyncMock()
|
||||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||||
|
|
@ -396,7 +399,7 @@ class TestUnifiedTableQueries:
|
||||||
"""Test query execution without matching index (scan mode)"""
|
"""Test query execution without matching index (scan mode)"""
|
||||||
processor = MagicMock()
|
processor = MagicMock()
|
||||||
processor.session = MagicMock()
|
processor.session = MagicMock()
|
||||||
processor.connect_cassandra = MagicMock()
|
processor.connect_cassandra = AsyncMock()
|
||||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,10 @@
|
||||||
Tests for Cassandra triples query service
|
Tests for Cassandra triples query service
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
|
||||||
from trustgraph.query.triples.cassandra.service import Processor, create_term
|
from trustgraph.query.triples.cassandra.service import Processor, create_term
|
||||||
from trustgraph.schema import Term, IRI, LITERAL
|
from trustgraph.schema import Term, IRI, LITERAL
|
||||||
|
|
@ -18,7 +20,7 @@ class TestCassandraQueryProcessor:
|
||||||
return Processor(
|
return Processor(
|
||||||
taskgroup=MagicMock(),
|
taskgroup=MagicMock(),
|
||||||
id='test-cassandra-query',
|
id='test-cassandra-query',
|
||||||
graph_host='localhost'
|
cassandra_host='localhost'
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_create_term_with_http_uri(self, processor):
|
def test_create_term_with_http_uri(self, processor):
|
||||||
|
|
@ -85,7 +87,7 @@ class TestCassandraQueryProcessor:
|
||||||
mock_result.dtype = None
|
mock_result.dtype = None
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_result.o = 'test_object'
|
mock_result.o = 'test_object'
|
||||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(
|
processor = Processor(
|
||||||
taskgroup=MagicMock(),
|
taskgroup=MagicMock(),
|
||||||
|
|
@ -110,8 +112,8 @@ class TestCassandraQueryProcessor:
|
||||||
keyspace='test_user'
|
keyspace='test_user'
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify get_spo was called with correct parameters
|
# Verify async_get_spo was called with correct parameters
|
||||||
mock_tg_instance.get_spo.assert_called_once_with(
|
mock_tg_instance.async_get_spo.assert_called_once_with(
|
||||||
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
|
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -130,7 +132,8 @@ class TestCassandraQueryProcessor:
|
||||||
assert processor.cassandra_host == ['cassandra'] # Updated default
|
assert processor.cassandra_host == ['cassandra'] # Updated default
|
||||||
assert processor.cassandra_username is None
|
assert processor.cassandra_username is None
|
||||||
assert processor.cassandra_password is None
|
assert processor.cassandra_password is None
|
||||||
assert processor.table is None
|
assert processor._connections == {}
|
||||||
|
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||||
|
|
||||||
def test_processor_initialization_with_custom_params(self):
|
def test_processor_initialization_with_custom_params(self):
|
||||||
"""Test processor initialization with custom parameters"""
|
"""Test processor initialization with custom parameters"""
|
||||||
|
|
@ -146,7 +149,8 @@ class TestCassandraQueryProcessor:
|
||||||
assert processor.cassandra_host == ['cassandra.example.com']
|
assert processor.cassandra_host == ['cassandra.example.com']
|
||||||
assert processor.cassandra_username == 'queryuser'
|
assert processor.cassandra_username == 'queryuser'
|
||||||
assert processor.cassandra_password == 'querypass'
|
assert processor.cassandra_password == 'querypass'
|
||||||
assert processor.table is None
|
assert processor._connections == {}
|
||||||
|
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||||
|
|
@ -164,7 +168,7 @@ class TestCassandraQueryProcessor:
|
||||||
mock_result.otype = None
|
mock_result.otype = None
|
||||||
mock_result.dtype = None
|
mock_result.dtype = None
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_tg_instance.get_sp.return_value = [mock_result]
|
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -178,7 +182,7 @@ class TestCassandraQueryProcessor:
|
||||||
|
|
||||||
result = await processor.query_triples('test_user', query)
|
result = await processor.query_triples('test_user', query)
|
||||||
|
|
||||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].s.iri == 'test_subject'
|
assert result[0].s.iri == 'test_subject'
|
||||||
assert result[0].p.iri == 'test_predicate'
|
assert result[0].p.iri == 'test_predicate'
|
||||||
|
|
@ -200,7 +204,7 @@ class TestCassandraQueryProcessor:
|
||||||
mock_result.otype = None
|
mock_result.otype = None
|
||||||
mock_result.dtype = None
|
mock_result.dtype = None
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_tg_instance.get_s.return_value = [mock_result]
|
mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -214,7 +218,7 @@ class TestCassandraQueryProcessor:
|
||||||
|
|
||||||
result = await processor.query_triples('test_user', query)
|
result = await processor.query_triples('test_user', query)
|
||||||
|
|
||||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].s.iri == 'test_subject'
|
assert result[0].s.iri == 'test_subject'
|
||||||
assert result[0].p.iri == 'result_predicate'
|
assert result[0].p.iri == 'result_predicate'
|
||||||
|
|
@ -236,7 +240,7 @@ class TestCassandraQueryProcessor:
|
||||||
mock_result.otype = None
|
mock_result.otype = None
|
||||||
mock_result.dtype = None
|
mock_result.dtype = None
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_tg_instance.get_p.return_value = [mock_result]
|
mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -250,7 +254,7 @@ class TestCassandraQueryProcessor:
|
||||||
|
|
||||||
result = await processor.query_triples('test_user', query)
|
result = await processor.query_triples('test_user', query)
|
||||||
|
|
||||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].s.iri == 'result_subject'
|
assert result[0].s.iri == 'result_subject'
|
||||||
assert result[0].p.iri == 'test_predicate'
|
assert result[0].p.iri == 'test_predicate'
|
||||||
|
|
@ -272,7 +276,7 @@ class TestCassandraQueryProcessor:
|
||||||
mock_result.otype = None
|
mock_result.otype = None
|
||||||
mock_result.dtype = None
|
mock_result.dtype = None
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_tg_instance.get_o.return_value = [mock_result]
|
mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -286,7 +290,7 @@ class TestCassandraQueryProcessor:
|
||||||
|
|
||||||
result = await processor.query_triples('test_user', query)
|
result = await processor.query_triples('test_user', query)
|
||||||
|
|
||||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].s.iri == 'result_subject'
|
assert result[0].s.iri == 'result_subject'
|
||||||
assert result[0].p.iri == 'result_predicate'
|
assert result[0].p.iri == 'result_predicate'
|
||||||
|
|
@ -305,11 +309,11 @@ class TestCassandraQueryProcessor:
|
||||||
mock_result.s = 'all_subject'
|
mock_result.s = 'all_subject'
|
||||||
mock_result.p = 'all_predicate'
|
mock_result.p = 'all_predicate'
|
||||||
mock_result.o = 'all_object'
|
mock_result.o = 'all_object'
|
||||||
mock_result.g = ''
|
mock_result.d = ''
|
||||||
mock_result.otype = None
|
mock_result.otype = None
|
||||||
mock_result.dtype = None
|
mock_result.dtype = None
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_tg_instance.get_all.return_value = [mock_result]
|
mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -323,7 +327,7 @@ class TestCassandraQueryProcessor:
|
||||||
|
|
||||||
result = await processor.query_triples('test_user', query)
|
result = await processor.query_triples('test_user', query)
|
||||||
|
|
||||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000)
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].s.iri == 'all_subject'
|
assert result[0].s.iri == 'all_subject'
|
||||||
assert result[0].p.iri == 'all_predicate'
|
assert result[0].p.iri == 'all_predicate'
|
||||||
|
|
@ -410,7 +414,7 @@ class TestCassandraQueryProcessor:
|
||||||
mock_result.dtype = None
|
mock_result.dtype = None
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_result.o = 'test_object'
|
mock_result.o = 'test_object'
|
||||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(
|
processor = Processor(
|
||||||
taskgroup=MagicMock(),
|
taskgroup=MagicMock(),
|
||||||
|
|
@ -451,7 +455,7 @@ class TestCassandraQueryProcessor:
|
||||||
mock_result.dtype = None
|
mock_result.dtype = None
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_result.o = 'test_object'
|
mock_result.o = 'test_object'
|
||||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -489,8 +493,8 @@ class TestCassandraQueryProcessor:
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_result.p = 'p'
|
mock_result.p = 'p'
|
||||||
mock_result.o = 'o'
|
mock_result.o = 'o'
|
||||||
mock_tg_instance1.get_s.return_value = [mock_result]
|
mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result])
|
||||||
mock_tg_instance2.get_s.return_value = [mock_result]
|
mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -504,7 +508,6 @@ class TestCassandraQueryProcessor:
|
||||||
)
|
)
|
||||||
|
|
||||||
await processor.query_triples('user1', query1)
|
await processor.query_triples('user1', query1)
|
||||||
assert processor.table == 'user1'
|
|
||||||
|
|
||||||
# Second query with different table
|
# Second query with different table
|
||||||
query2 = TriplesQueryRequest(
|
query2 = TriplesQueryRequest(
|
||||||
|
|
@ -516,10 +519,11 @@ class TestCassandraQueryProcessor:
|
||||||
)
|
)
|
||||||
|
|
||||||
await processor.query_triples('user2', query2)
|
await processor.query_triples('user2', query2)
|
||||||
assert processor.table == 'user2'
|
|
||||||
|
|
||||||
# Verify TrustGraph was created twice
|
# Verify TrustGraph was created twice for different workspaces
|
||||||
assert mock_kg_class.call_count == 2
|
assert mock_kg_class.call_count == 2
|
||||||
|
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
|
||||||
|
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||||
|
|
@ -529,7 +533,7 @@ class TestCassandraQueryProcessor:
|
||||||
|
|
||||||
mock_tg_instance = MagicMock()
|
mock_tg_instance = MagicMock()
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
|
mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed"))
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -566,7 +570,7 @@ class TestCassandraQueryProcessor:
|
||||||
mock_result2.otype = None
|
mock_result2.otype = None
|
||||||
mock_result2.dtype = None
|
mock_result2.dtype = None
|
||||||
mock_result2.lang = None
|
mock_result2.lang = None
|
||||||
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
|
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
||||||
mock_result.otype = None
|
mock_result.otype = None
|
||||||
mock_result.dtype = None
|
mock_result.dtype = None
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_tg_instance.get_po.return_value = [mock_result]
|
mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
||||||
|
|
||||||
result = await processor.query_triples('test_user', query)
|
result = await processor.query_triples('test_user', query)
|
||||||
|
|
||||||
# Verify get_po was called (should use optimized po_table)
|
# Verify async_get_po was called (should use optimized po_table)
|
||||||
mock_tg_instance.get_po.assert_called_once_with(
|
mock_tg_instance.async_get_po.assert_called_once_with(
|
||||||
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
|
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
||||||
mock_result.otype = None
|
mock_result.otype = None
|
||||||
mock_result.dtype = None
|
mock_result.dtype = None
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_tg_instance.get_os.return_value = [mock_result]
|
mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
||||||
|
|
||||||
result = await processor.query_triples('test_user', query)
|
result = await processor.query_triples('test_user', query)
|
||||||
|
|
||||||
# Verify get_os was called (should use optimized subject_table with clustering)
|
# Verify async_get_os was called (should use optimized subject_table with clustering)
|
||||||
mock_tg_instance.get_os.assert_called_once_with(
|
mock_tg_instance.async_get_os.assert_called_once_with(
|
||||||
'test_collection', 'test_object', 'test_subject', g=None, limit=25
|
'test_collection', 'test_object', 'test_subject', g=None, limit=25
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations:
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
# Mock empty results for all queries
|
# Mock empty results for all queries
|
||||||
mock_tg_instance.get_all.return_value = []
|
mock_tg_instance.async_get_all = AsyncMock(return_value=[])
|
||||||
mock_tg_instance.get_s.return_value = []
|
mock_tg_instance.async_get_s = AsyncMock(return_value=[])
|
||||||
mock_tg_instance.get_p.return_value = []
|
mock_tg_instance.async_get_p = AsyncMock(return_value=[])
|
||||||
mock_tg_instance.get_o.return_value = []
|
mock_tg_instance.async_get_o = AsyncMock(return_value=[])
|
||||||
mock_tg_instance.get_sp.return_value = []
|
mock_tg_instance.async_get_sp = AsyncMock(return_value=[])
|
||||||
mock_tg_instance.get_po.return_value = []
|
mock_tg_instance.async_get_po = AsyncMock(return_value=[])
|
||||||
mock_tg_instance.get_os.return_value = []
|
mock_tg_instance.async_get_os = AsyncMock(return_value=[])
|
||||||
mock_tg_instance.get_spo.return_value = []
|
mock_tg_instance.async_get_spo = AsyncMock(return_value=[])
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
# Test each query pattern
|
# Test each query pattern
|
||||||
test_patterns = [
|
test_patterns = [
|
||||||
# (s, p, o, expected_method)
|
# (s, p, o, expected_method)
|
||||||
(None, None, None, 'get_all'), # All triples
|
(None, None, None, 'async_get_all'), # All triples
|
||||||
('s1', None, None, 'get_s'), # Subject only
|
('s1', None, None, 'async_get_s'), # Subject only
|
||||||
(None, 'p1', None, 'get_p'), # Predicate only
|
(None, 'p1', None, 'async_get_p'), # Predicate only
|
||||||
(None, None, 'o1', 'get_o'), # Object only
|
(None, None, 'o1', 'async_get_o'), # Object only
|
||||||
('s1', 'p1', None, 'get_sp'), # Subject + Predicate
|
('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate
|
||||||
(None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
|
(None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
|
||||||
('s1', None, 'o1', 'get_os'), # Object + Subject
|
('s1', None, 'o1', 'async_get_os'), # Object + Subject
|
||||||
('s1', 'p1', 'o1', 'get_spo'), # All three
|
('s1', 'p1', 'o1', 'async_get_spo'), # All three
|
||||||
]
|
]
|
||||||
|
|
||||||
for s, p, o, expected_method in test_patterns:
|
for s, p, o, expected_method in test_patterns:
|
||||||
|
|
@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
||||||
mock_result.lang = None
|
mock_result.lang = None
|
||||||
mock_results.append(mock_result)
|
mock_results.append(mock_result)
|
||||||
|
|
||||||
mock_tg_instance.get_po.return_value = mock_results
|
mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results)
|
||||||
|
|
||||||
processor = Processor(taskgroup=MagicMock())
|
processor = Processor(taskgroup=MagicMock())
|
||||||
|
|
||||||
|
|
@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
||||||
|
|
||||||
result = await processor.query_triples('large_dataset_user', query)
|
result = await processor.query_triples('large_dataset_user', query)
|
||||||
|
|
||||||
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
|
# Verify optimized async_get_po was used (no ALLOW FILTERING needed!)
|
||||||
mock_tg_instance.get_po.assert_called_once_with(
|
mock_tg_instance.async_get_po.assert_called_once_with(
|
||||||
'massive_collection',
|
'massive_collection',
|
||||||
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
|
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
|
||||||
'http://example.com/Person',
|
'http://example.com/Person',
|
||||||
|
|
|
||||||
|
|
@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_embedding_upserted(self):
|
async def test_valid_embedding_upserted(self):
|
||||||
|
import asyncio
|
||||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||||
|
|
||||||
proc = Processor.__new__(Processor)
|
proc = Processor.__new__(Processor)
|
||||||
proc.qdrant = MagicMock()
|
proc.qdrant = MagicMock()
|
||||||
proc.qdrant.collection_exists.return_value = True
|
proc.qdrant.collection_exists.return_value = True
|
||||||
proc.collection_exists = MagicMock(return_value=True)
|
proc.collection_exists = MagicMock(return_value=True)
|
||||||
|
proc._cache_lock = asyncio.Lock()
|
||||||
|
proc._known_collections = set()
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.metadata.collection = "col1"
|
msg.metadata.collection = "col1"
|
||||||
|
|
@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_dimension_in_collection_name(self):
|
async def test_dimension_in_collection_name(self):
|
||||||
"""Collection name should include vector dimension."""
|
"""Collection name should include vector dimension."""
|
||||||
|
import asyncio
|
||||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||||
|
|
||||||
proc = Processor.__new__(Processor)
|
proc = Processor.__new__(Processor)
|
||||||
proc.qdrant = MagicMock()
|
proc.qdrant = MagicMock()
|
||||||
proc.qdrant.collection_exists.return_value = True
|
proc.qdrant.collection_exists.return_value = True
|
||||||
proc.collection_exists = MagicMock(return_value=True)
|
proc.collection_exists = MagicMock(return_value=True)
|
||||||
|
proc._cache_lock = asyncio.Lock()
|
||||||
|
proc._known_collections = set()
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.metadata.collection = "docs"
|
msg.metadata.collection = "docs"
|
||||||
|
|
@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_entity_and_vector_upserted(self):
|
async def test_valid_entity_and_vector_upserted(self):
|
||||||
|
import asyncio
|
||||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||||
|
|
||||||
proc = Processor.__new__(Processor)
|
proc = Processor.__new__(Processor)
|
||||||
proc.qdrant = MagicMock()
|
proc.qdrant = MagicMock()
|
||||||
proc.qdrant.collection_exists.return_value = True
|
proc.qdrant.collection_exists.return_value = True
|
||||||
proc.collection_exists = MagicMock(return_value=True)
|
proc.collection_exists = MagicMock(return_value=True)
|
||||||
|
proc._cache_lock = asyncio.Lock()
|
||||||
|
proc._known_collections = set()
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.metadata.collection = "col1"
|
msg.metadata.collection = "col1"
|
||||||
|
|
@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_lazy_collection_creation_on_new_dimension(self):
|
async def test_lazy_collection_creation_on_new_dimension(self):
|
||||||
|
import asyncio
|
||||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||||
|
|
||||||
proc = Processor.__new__(Processor)
|
proc = Processor.__new__(Processor)
|
||||||
proc.qdrant = MagicMock()
|
proc.qdrant = MagicMock()
|
||||||
proc.qdrant.collection_exists.return_value = False
|
proc.qdrant.collection_exists.return_value = False
|
||||||
proc.collection_exists = MagicMock(return_value=True)
|
proc.collection_exists = MagicMock(return_value=True)
|
||||||
|
proc._cache_lock = asyncio.Lock()
|
||||||
|
proc._known_collections = set()
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.metadata.collection = "graphs"
|
msg.metadata.collection = "graphs"
|
||||||
|
|
|
||||||
|
|
@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
# Assert
|
# Assert
|
||||||
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
|
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
|
||||||
|
|
||||||
# Verify collection existence is checked on each write
|
# Second write uses cached collection state — no collection_exists check
|
||||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||||
|
|
||||||
# But upsert should still be called
|
# But upsert should still be called
|
||||||
mock_qdrant_instance.upsert.assert_called_once()
|
mock_qdrant_instance.upsert.assert_called_once()
|
||||||
|
|
|
||||||
|
|
@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
processor.ensure_collection("test_collection", 384)
|
await processor.ensure_collection("test_collection", 384)
|
||||||
|
|
||||||
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
|
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
|
||||||
mock_qdrant_instance.create_collection.assert_called_once()
|
mock_qdrant_instance.create_collection.assert_called_once()
|
||||||
|
|
||||||
# Verify the collection is cached
|
# Verify the collection is cached
|
||||||
assert "test_collection" in processor.created_collections
|
assert "test_collection" in processor._known_collections
|
||||||
|
|
||||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||||
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
|
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
|
||||||
|
|
@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
processor.ensure_collection("existing_collection", 384)
|
await processor.ensure_collection("existing_collection", 384)
|
||||||
|
|
||||||
mock_qdrant_instance.collection_exists.assert_called_once()
|
mock_qdrant_instance.collection_exists.assert_called_once()
|
||||||
mock_qdrant_instance.create_collection.assert_not_called()
|
mock_qdrant_instance.create_collection.assert_not_called()
|
||||||
|
|
@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
processor.created_collections.add("cached_collection")
|
processor._known_collections.add("cached_collection")
|
||||||
|
|
||||||
processor.ensure_collection("cached_collection", 384)
|
await processor.ensure_collection("cached_collection", 384)
|
||||||
|
|
||||||
# Should not check or create - just return
|
# Should not check or create - just return
|
||||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||||
|
|
@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
|
processor._known_collections.add('rows_test_workspace_test_collection_schema1_384')
|
||||||
|
|
||||||
await processor.delete_collection('test_workspace', 'test_collection')
|
await processor.delete_collection('test_workspace', 'test_collection')
|
||||||
|
|
||||||
|
|
@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
assert mock_qdrant_instance.delete_collection.call_count == 2
|
assert mock_qdrant_instance.delete_collection.call_count == 2
|
||||||
|
|
||||||
# Verify the cached collection was removed
|
# Verify the cached collection was removed
|
||||||
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
|
assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections
|
||||||
|
|
||||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||||
async def test_delete_collection_schema(self, mock_qdrant_client):
|
async def test_delete_collection_schema(self, mock_qdrant_client):
|
||||||
|
|
|
||||||
|
|
@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_schema_config_parsing(self):
|
async def test_schema_config_parsing(self):
|
||||||
"""Test parsing of schema configurations"""
|
"""Test parsing of schema configurations"""
|
||||||
|
import asyncio
|
||||||
processor = MagicMock()
|
processor = MagicMock()
|
||||||
processor.schemas = {}
|
processor.schemas = {}
|
||||||
processor.config_key = "schema"
|
processor.config_key = "schema"
|
||||||
processor.registered_partitions = set()
|
processor.registered_partitions = set()
|
||||||
|
processor._setup_lock = asyncio.Lock()
|
||||||
|
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
|
||||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||||
|
|
||||||
# Create test configuration
|
# Create test configuration
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@
|
||||||
Tests for Cassandra triples storage service
|
Tests for Cassandra triples storage service
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock, patch, AsyncMock
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
|
||||||
|
|
@ -24,7 +26,8 @@ class TestCassandraStorageProcessor:
|
||||||
assert processor.cassandra_host == ['cassandra'] # Updated default
|
assert processor.cassandra_host == ['cassandra'] # Updated default
|
||||||
assert processor.cassandra_username is None
|
assert processor.cassandra_username is None
|
||||||
assert processor.cassandra_password is None
|
assert processor.cassandra_password is None
|
||||||
assert processor.table is None
|
assert processor._connections == {}
|
||||||
|
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||||
|
|
||||||
def test_processor_initialization_with_custom_params(self):
|
def test_processor_initialization_with_custom_params(self):
|
||||||
"""Test processor initialization with custom parameters (new cassandra_* names)"""
|
"""Test processor initialization with custom parameters (new cassandra_* names)"""
|
||||||
|
|
@ -41,7 +44,8 @@ class TestCassandraStorageProcessor:
|
||||||
assert processor.cassandra_host == ['cassandra.example.com']
|
assert processor.cassandra_host == ['cassandra.example.com']
|
||||||
assert processor.cassandra_username == 'testuser'
|
assert processor.cassandra_username == 'testuser'
|
||||||
assert processor.cassandra_password == 'testpass'
|
assert processor.cassandra_password == 'testpass'
|
||||||
assert processor.table is None
|
assert processor._connections == {}
|
||||||
|
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||||
|
|
||||||
def test_processor_initialization_with_partial_auth(self):
|
def test_processor_initialization_with_partial_auth(self):
|
||||||
"""Test processor initialization with only username (no password)"""
|
"""Test processor initialization with only username (no password)"""
|
||||||
|
|
@ -92,6 +96,7 @@ class TestCassandraStorageProcessor:
|
||||||
"""Test table switching logic when authentication is provided"""
|
"""Test table switching logic when authentication is provided"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
mock_tg_instance = MagicMock()
|
mock_tg_instance = MagicMock()
|
||||||
|
mock_tg_instance.async_insert = AsyncMock()
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
processor = Processor(
|
processor = Processor(
|
||||||
|
|
@ -114,7 +119,6 @@ class TestCassandraStorageProcessor:
|
||||||
username='testuser',
|
username='testuser',
|
||||||
password='testpass'
|
password='testpass'
|
||||||
)
|
)
|
||||||
assert processor.table == 'user1'
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||||
|
|
@ -122,6 +126,7 @@ class TestCassandraStorageProcessor:
|
||||||
"""Test table switching logic when no authentication is provided"""
|
"""Test table switching logic when no authentication is provided"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
mock_tg_instance = MagicMock()
|
mock_tg_instance = MagicMock()
|
||||||
|
mock_tg_instance.async_insert = AsyncMock()
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
processor = Processor(taskgroup=taskgroup_mock)
|
processor = Processor(taskgroup=taskgroup_mock)
|
||||||
|
|
@ -138,7 +143,6 @@ class TestCassandraStorageProcessor:
|
||||||
hosts=['cassandra'], # Updated default
|
hosts=['cassandra'], # Updated default
|
||||||
keyspace='user2'
|
keyspace='user2'
|
||||||
)
|
)
|
||||||
assert processor.table == 'user2'
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||||
|
|
@ -146,6 +150,7 @@ class TestCassandraStorageProcessor:
|
||||||
"""Test that TrustGraph is not recreated when table hasn't changed"""
|
"""Test that TrustGraph is not recreated when table hasn't changed"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
mock_tg_instance = MagicMock()
|
mock_tg_instance = MagicMock()
|
||||||
|
mock_tg_instance.async_insert = AsyncMock()
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
processor = Processor(taskgroup=taskgroup_mock)
|
processor = Processor(taskgroup=taskgroup_mock)
|
||||||
|
|
@ -169,6 +174,7 @@ class TestCassandraStorageProcessor:
|
||||||
"""Test that triples are properly inserted into Cassandra"""
|
"""Test that triples are properly inserted into Cassandra"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
mock_tg_instance = MagicMock()
|
mock_tg_instance = MagicMock()
|
||||||
|
mock_tg_instance.async_insert = AsyncMock()
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
processor = Processor(taskgroup=taskgroup_mock)
|
processor = Processor(taskgroup=taskgroup_mock)
|
||||||
|
|
@ -208,12 +214,12 @@ class TestCassandraStorageProcessor:
|
||||||
await processor.store_triples('user1', mock_message)
|
await processor.store_triples('user1', mock_message)
|
||||||
|
|
||||||
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
|
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
|
||||||
assert mock_tg_instance.insert.call_count == 2
|
assert mock_tg_instance.async_insert.call_count == 2
|
||||||
mock_tg_instance.insert.assert_any_call(
|
mock_tg_instance.async_insert.assert_any_call(
|
||||||
'collection1', 'subject1', 'predicate1', 'object1',
|
'collection1', 'subject1', 'predicate1', 'object1',
|
||||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||||
)
|
)
|
||||||
mock_tg_instance.insert.assert_any_call(
|
mock_tg_instance.async_insert.assert_any_call(
|
||||||
'collection1', 'subject2', 'predicate2', 'object2',
|
'collection1', 'subject2', 'predicate2', 'object2',
|
||||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||||
)
|
)
|
||||||
|
|
@ -224,6 +230,7 @@ class TestCassandraStorageProcessor:
|
||||||
"""Test behavior when message has no triples"""
|
"""Test behavior when message has no triples"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
mock_tg_instance = MagicMock()
|
mock_tg_instance = MagicMock()
|
||||||
|
mock_tg_instance.async_insert = AsyncMock()
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
processor = Processor(taskgroup=taskgroup_mock)
|
processor = Processor(taskgroup=taskgroup_mock)
|
||||||
|
|
@ -236,19 +243,17 @@ class TestCassandraStorageProcessor:
|
||||||
await processor.store_triples('user1', mock_message)
|
await processor.store_triples('user1', mock_message)
|
||||||
|
|
||||||
# Verify no triples were inserted
|
# Verify no triples were inserted
|
||||||
mock_tg_instance.insert.assert_not_called()
|
mock_tg_instance.async_insert.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||||
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
|
async def test_exception_handling_on_connection_failure(self, mock_kg_class):
|
||||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
|
|
||||||
"""Test exception handling during TrustGraph creation"""
|
"""Test exception handling during TrustGraph creation"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
mock_kg_class.side_effect = Exception("Connection failed")
|
mock_kg_class.side_effect = Exception("Connection failed")
|
||||||
|
|
||||||
processor = Processor(taskgroup=taskgroup_mock)
|
processor = Processor(taskgroup=taskgroup_mock)
|
||||||
|
|
||||||
# Create mock message
|
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.metadata.collection = 'collection1'
|
mock_message.metadata.collection = 'collection1'
|
||||||
mock_message.triples = []
|
mock_message.triples = []
|
||||||
|
|
@ -256,9 +261,6 @@ class TestCassandraStorageProcessor:
|
||||||
with pytest.raises(Exception, match="Connection failed"):
|
with pytest.raises(Exception, match="Connection failed"):
|
||||||
await processor.store_triples('user1', mock_message)
|
await processor.store_triples('user1', mock_message)
|
||||||
|
|
||||||
# Verify sleep was called before re-raising
|
|
||||||
mock_sleep.assert_called_once_with(1)
|
|
||||||
|
|
||||||
def test_add_args_method(self):
|
def test_add_args_method(self):
|
||||||
"""Test that add_args properly configures argument parser"""
|
"""Test that add_args properly configures argument parser"""
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
@ -359,8 +361,6 @@ class TestCassandraStorageProcessor:
|
||||||
mock_message1.triples = []
|
mock_message1.triples = []
|
||||||
|
|
||||||
await processor.store_triples('user1', mock_message1)
|
await processor.store_triples('user1', mock_message1)
|
||||||
assert processor.table == 'user1'
|
|
||||||
assert processor.tg == mock_tg_instance1
|
|
||||||
|
|
||||||
# Second message with different table
|
# Second message with different table
|
||||||
mock_message2 = MagicMock()
|
mock_message2 = MagicMock()
|
||||||
|
|
@ -368,11 +368,11 @@ class TestCassandraStorageProcessor:
|
||||||
mock_message2.triples = []
|
mock_message2.triples = []
|
||||||
|
|
||||||
await processor.store_triples('user2', mock_message2)
|
await processor.store_triples('user2', mock_message2)
|
||||||
assert processor.table == 'user2'
|
|
||||||
assert processor.tg == mock_tg_instance2
|
|
||||||
|
|
||||||
# Verify TrustGraph was created twice for different tables
|
# Verify TrustGraph was created twice for different workspaces
|
||||||
assert mock_kg_class.call_count == 2
|
assert mock_kg_class.call_count == 2
|
||||||
|
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
|
||||||
|
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||||
|
|
@ -380,6 +380,7 @@ class TestCassandraStorageProcessor:
|
||||||
"""Test storing triples with special characters and unicode"""
|
"""Test storing triples with special characters and unicode"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
mock_tg_instance = MagicMock()
|
mock_tg_instance = MagicMock()
|
||||||
|
mock_tg_instance.async_insert = AsyncMock()
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
processor = Processor(taskgroup=taskgroup_mock)
|
processor = Processor(taskgroup=taskgroup_mock)
|
||||||
|
|
@ -405,7 +406,7 @@ class TestCassandraStorageProcessor:
|
||||||
await processor.store_triples('test_workspace', mock_message)
|
await processor.store_triples('test_workspace', mock_message)
|
||||||
|
|
||||||
# Verify the triple was inserted with special characters preserved
|
# Verify the triple was inserted with special characters preserved
|
||||||
mock_tg_instance.insert.assert_called_once_with(
|
mock_tg_instance.async_insert.assert_called_once_with(
|
||||||
'test_collection',
|
'test_collection',
|
||||||
'subject with spaces & symbols',
|
'subject with spaces & symbols',
|
||||||
'predicate:with/colons',
|
'predicate:with/colons',
|
||||||
|
|
@ -418,29 +419,29 @@ class TestCassandraStorageProcessor:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
|
async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class):
|
||||||
"""Test that table remains unchanged when TrustGraph creation fails"""
|
"""Test that a failed connection doesn't leave stale cached state"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
|
mock_good_instance = MagicMock()
|
||||||
|
|
||||||
processor = Processor(taskgroup=taskgroup_mock)
|
processor = Processor(taskgroup=taskgroup_mock)
|
||||||
|
|
||||||
# Set an initial table
|
|
||||||
processor.table = ('old_user', 'old_collection')
|
|
||||||
|
|
||||||
# Mock TrustGraph to raise exception
|
|
||||||
mock_kg_class.side_effect = Exception("Connection failed")
|
|
||||||
|
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.metadata.collection = 'new_collection'
|
mock_message.metadata.collection = 'collection1'
|
||||||
mock_message.triples = []
|
mock_message.triples = []
|
||||||
|
|
||||||
|
# First call fails
|
||||||
|
mock_kg_class.side_effect = Exception("Connection failed")
|
||||||
with pytest.raises(Exception, match="Connection failed"):
|
with pytest.raises(Exception, match="Connection failed"):
|
||||||
await processor.store_triples('new_user', mock_message)
|
await processor.store_triples('user1', mock_message)
|
||||||
|
|
||||||
# Table should remain unchanged since self.table = table happens after try/except
|
# Second call succeeds — should retry connection, not use stale state
|
||||||
assert processor.table == ('old_user', 'old_collection')
|
mock_kg_class.side_effect = None
|
||||||
# TrustGraph should be set to None though
|
mock_kg_class.return_value = mock_good_instance
|
||||||
assert processor.tg is None
|
await processor.store_triples('user1', mock_message)
|
||||||
|
|
||||||
|
# Connection was attempted twice (failed + succeeded)
|
||||||
|
assert mock_kg_class.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
class TestCassandraPerformanceOptimizations:
|
class TestCassandraPerformanceOptimizations:
|
||||||
|
|
@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations:
|
||||||
"""Test that legacy mode still works with single table"""
|
"""Test that legacy mode still works with single table"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
mock_tg_instance = MagicMock()
|
mock_tg_instance = MagicMock()
|
||||||
|
mock_tg_instance.async_insert = AsyncMock()
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
|
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
|
||||||
|
|
@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations:
|
||||||
"""Test that optimized mode uses multi-table schema"""
|
"""Test that optimized mode uses multi-table schema"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
mock_tg_instance = MagicMock()
|
mock_tg_instance = MagicMock()
|
||||||
|
mock_tg_instance.async_insert = AsyncMock()
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
|
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
|
||||||
|
|
@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations:
|
||||||
"""Test that all tables stay consistent during batch writes"""
|
"""Test that all tables stay consistent during batch writes"""
|
||||||
taskgroup_mock = MagicMock()
|
taskgroup_mock = MagicMock()
|
||||||
mock_tg_instance = MagicMock()
|
mock_tg_instance = MagicMock()
|
||||||
|
mock_tg_instance.async_insert = AsyncMock()
|
||||||
mock_kg_class.return_value = mock_tg_instance
|
mock_kg_class.return_value = mock_tg_instance
|
||||||
|
|
||||||
processor = Processor(taskgroup=taskgroup_mock)
|
processor = Processor(taskgroup=taskgroup_mock)
|
||||||
|
|
@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations:
|
||||||
await processor.store_triples('user1', mock_message)
|
await processor.store_triples('user1', mock_message)
|
||||||
|
|
||||||
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
|
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
|
||||||
mock_tg_instance.insert.assert_called_once_with(
|
mock_tg_instance.async_insert.assert_called_once_with(
|
||||||
'collection1', 'test_subject', 'test_predicate', 'test_object',
|
'collection1', 'test_subject', 'test_predicate', 'test_object',
|
||||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,8 @@ class TestSanitizeName:
|
||||||
|
|
||||||
class TestFindCollection:
|
class TestFindCollection:
|
||||||
|
|
||||||
def test_finds_matching_collection(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_finds_matching_collection(self):
|
||||||
proc = _make_processor()
|
proc = _make_processor()
|
||||||
mock_coll = MagicMock()
|
mock_coll = MagicMock()
|
||||||
mock_coll.name = "rows_test_workspace_test_col_customers_384"
|
mock_coll.name = "rows_test_workspace_test_col_customers_384"
|
||||||
|
|
@ -98,11 +99,12 @@ class TestFindCollection:
|
||||||
mock_collections.collections = [mock_coll]
|
mock_collections.collections = [mock_coll]
|
||||||
proc.qdrant.get_collections.return_value = mock_collections
|
proc.qdrant.get_collections.return_value = mock_collections
|
||||||
|
|
||||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
result = await proc.find_collection("test-workspace", "test-col", "customers")
|
||||||
|
|
||||||
assert result == "rows_test_workspace_test_col_customers_384"
|
assert result == "rows_test_workspace_test_col_customers_384"
|
||||||
|
|
||||||
def test_returns_none_when_no_match(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_no_match(self):
|
||||||
proc = _make_processor()
|
proc = _make_processor()
|
||||||
mock_coll = MagicMock()
|
mock_coll = MagicMock()
|
||||||
mock_coll.name = "rows_other_workspace_other_col_schema_768"
|
mock_coll.name = "rows_other_workspace_other_col_schema_768"
|
||||||
|
|
@ -111,14 +113,15 @@ class TestFindCollection:
|
||||||
mock_collections.collections = [mock_coll]
|
mock_collections.collections = [mock_coll]
|
||||||
proc.qdrant.get_collections.return_value = mock_collections
|
proc.qdrant.get_collections.return_value = mock_collections
|
||||||
|
|
||||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
result = await proc.find_collection("test-workspace", "test-col", "customers")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
def test_returns_none_on_error(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_on_error(self):
|
||||||
proc = _make_processor()
|
proc = _make_processor()
|
||||||
proc.qdrant.get_collections.side_effect = Exception("connection error")
|
proc.qdrant.get_collections.side_effect = Exception("connection error")
|
||||||
|
|
||||||
result = proc.find_collection("workspace", "col", "schema")
|
result = await proc.find_collection("workspace", "col", "schema")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -139,7 +142,7 @@ class TestQueryRowEmbeddings:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_collection_returns_empty(self):
|
async def test_no_collection_returns_empty(self):
|
||||||
proc = _make_processor()
|
proc = _make_processor()
|
||||||
proc.find_collection = MagicMock(return_value=None)
|
proc.find_collection = AsyncMock(return_value=None)
|
||||||
request = _make_request()
|
request = _make_request()
|
||||||
|
|
||||||
result = await proc.query_row_embeddings("test-workspace", request)
|
result = await proc.query_row_embeddings("test-workspace", request)
|
||||||
|
|
@ -148,7 +151,7 @@ class TestQueryRowEmbeddings:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_successful_query_returns_matches(self):
|
async def test_successful_query_returns_matches(self):
|
||||||
proc = _make_processor()
|
proc = _make_processor()
|
||||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||||
|
|
||||||
points = [
|
points = [
|
||||||
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
|
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
|
||||||
|
|
@ -172,7 +175,7 @@ class TestQueryRowEmbeddings:
|
||||||
async def test_index_name_filter_applied(self):
|
async def test_index_name_filter_applied(self):
|
||||||
"""When index_name is specified, a Qdrant filter should be used."""
|
"""When index_name is specified, a Qdrant filter should be used."""
|
||||||
proc = _make_processor()
|
proc = _make_processor()
|
||||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||||
|
|
||||||
mock_result = MagicMock()
|
mock_result = MagicMock()
|
||||||
mock_result.points = []
|
mock_result.points = []
|
||||||
|
|
@ -188,7 +191,7 @@ class TestQueryRowEmbeddings:
|
||||||
async def test_no_index_name_no_filter(self):
|
async def test_no_index_name_no_filter(self):
|
||||||
"""When index_name is empty, no filter should be applied."""
|
"""When index_name is empty, no filter should be applied."""
|
||||||
proc = _make_processor()
|
proc = _make_processor()
|
||||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||||
|
|
||||||
mock_result = MagicMock()
|
mock_result = MagicMock()
|
||||||
mock_result.points = []
|
mock_result.points = []
|
||||||
|
|
@ -204,7 +207,7 @@ class TestQueryRowEmbeddings:
|
||||||
async def test_missing_payload_fields_default(self):
|
async def test_missing_payload_fields_default(self):
|
||||||
"""Points with missing payload fields should use defaults."""
|
"""Points with missing payload fields should use defaults."""
|
||||||
proc = _make_processor()
|
proc = _make_processor()
|
||||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||||
|
|
||||||
point = MagicMock()
|
point = MagicMock()
|
||||||
point.payload = {} # Empty payload
|
point.payload = {} # Empty payload
|
||||||
|
|
@ -225,7 +228,7 @@ class TestQueryRowEmbeddings:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_qdrant_error_propagates(self):
|
async def test_qdrant_error_propagates(self):
|
||||||
proc = _make_processor()
|
proc = _make_processor()
|
||||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||||
proc.qdrant.query_points.side_effect = Exception("qdrant down")
|
proc.qdrant.query_points.side_effect = Exception("qdrant down")
|
||||||
|
|
||||||
request = _make_request()
|
request = _make_request()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,14 @@
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from cassandra.query import BatchStatement, SimpleStatement
|
from cassandra.query import BatchStatement, SimpleStatement
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
||||||
import os
|
|
||||||
import logging
|
from ..tables.cassandra_async import async_execute
|
||||||
|
|
||||||
# Global list to track clusters for cleanup
|
# Global list to track clusters for cleanup
|
||||||
_active_clusters = []
|
_active_clusters = []
|
||||||
|
|
@ -461,7 +465,6 @@ class KnowledgeGraph:
|
||||||
def create_collection(self, collection):
|
def create_collection(self, collection):
|
||||||
"""Create collection by inserting metadata row"""
|
"""Create collection by inserting metadata row"""
|
||||||
try:
|
try:
|
||||||
import datetime
|
|
||||||
self.session.execute(
|
self.session.execute(
|
||||||
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
|
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
|
||||||
(collection, datetime.datetime.now())
|
(collection, datetime.datetime.now())
|
||||||
|
|
@ -954,7 +957,6 @@ class EntityCentricKnowledgeGraph:
|
||||||
def create_collection(self, collection):
|
def create_collection(self, collection):
|
||||||
"""Create collection by inserting metadata row"""
|
"""Create collection by inserting metadata row"""
|
||||||
try:
|
try:
|
||||||
import datetime
|
|
||||||
self.session.execute(
|
self.session.execute(
|
||||||
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
|
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
|
||||||
(collection, datetime.datetime.now())
|
(collection, datetime.datetime.now())
|
||||||
|
|
@ -1045,6 +1047,222 @@ class EntityCentricKnowledgeGraph:
|
||||||
|
|
||||||
logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")
|
logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# Async methods — use cassandra driver's native async API via async_execute
|
||||||
|
# ========================================================================
|
||||||
|
|
||||||
|
async def async_insert(self, collection, s, p, o, g=None, otype=None, dtype="", lang=""):
|
||||||
|
if g is None:
|
||||||
|
g = DEFAULT_GRAPH
|
||||||
|
if otype is None:
|
||||||
|
if o.startswith("http://") or o.startswith("https://"):
|
||||||
|
otype = "u"
|
||||||
|
else:
|
||||||
|
otype = "l"
|
||||||
|
|
||||||
|
batch = BatchStatement()
|
||||||
|
batch.add(self.insert_entity_stmt, (collection, s, 'S', p, otype, s, o, g, dtype, lang))
|
||||||
|
batch.add(self.insert_entity_stmt, (collection, p, 'P', p, otype, s, o, g, dtype, lang))
|
||||||
|
if otype == 'u' or otype == 't':
|
||||||
|
batch.add(self.insert_entity_stmt, (collection, o, 'O', p, otype, s, o, g, dtype, lang))
|
||||||
|
if g != DEFAULT_GRAPH:
|
||||||
|
batch.add(self.insert_entity_stmt, (collection, g, 'G', p, otype, s, o, g, dtype, lang))
|
||||||
|
batch.add(self.insert_collection_stmt, (collection, g, s, p, o, otype, dtype, lang))
|
||||||
|
|
||||||
|
await async_execute(self.session, batch)
|
||||||
|
|
||||||
|
async def async_get_all(self, collection, limit=50):
|
||||||
|
return await async_execute(
|
||||||
|
self.session, self.get_collection_all_stmt, (collection, limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_get_s(self, collection, s, g=None, limit=10):
|
||||||
|
rows = await async_execute(
|
||||||
|
self.session, self.get_entity_as_s_stmt, (collection, s, limit)
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||||
|
if g is not None and d != g:
|
||||||
|
continue
|
||||||
|
results.append(QuadResult(
|
||||||
|
s=row.s, p=row.p, o=row.o, g=d,
|
||||||
|
otype=row.otype, dtype=row.dtype, lang=row.lang
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def async_get_p(self, collection, p, g=None, limit=10):
|
||||||
|
rows = await async_execute(
|
||||||
|
self.session, self.get_entity_as_p_stmt, (collection, p, limit)
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||||
|
if g is not None and d != g:
|
||||||
|
continue
|
||||||
|
results.append(QuadResult(
|
||||||
|
s=row.s, p=row.p, o=row.o, g=d,
|
||||||
|
otype=row.otype, dtype=row.dtype, lang=row.lang
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def async_get_o(self, collection, o, g=None, limit=10):
|
||||||
|
rows = await async_execute(
|
||||||
|
self.session, self.get_entity_as_o_stmt, (collection, o, limit)
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||||
|
if g is not None and d != g:
|
||||||
|
continue
|
||||||
|
results.append(QuadResult(
|
||||||
|
s=row.s, p=row.p, o=row.o, g=d,
|
||||||
|
otype=row.otype, dtype=row.dtype, lang=row.lang
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def async_get_sp(self, collection, s, p, g=None, limit=10):
|
||||||
|
rows = await async_execute(
|
||||||
|
self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit)
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||||
|
if g is not None and d != g:
|
||||||
|
continue
|
||||||
|
results.append(QuadResult(
|
||||||
|
s=s, p=p, o=row.o, g=d,
|
||||||
|
otype=row.otype, dtype=row.dtype, lang=row.lang
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def async_get_po(self, collection, p, o, g=None, limit=10):
|
||||||
|
rows = await async_execute(
|
||||||
|
self.session, self.get_entity_as_o_p_stmt, (collection, o, p, limit)
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||||
|
if g is not None and d != g:
|
||||||
|
continue
|
||||||
|
results.append(QuadResult(
|
||||||
|
s=row.s, p=p, o=o, g=d,
|
||||||
|
otype=row.otype, dtype=row.dtype, lang=row.lang
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def async_get_os(self, collection, o, s, g=None, limit=10):
|
||||||
|
rows = await async_execute(
|
||||||
|
self.session, self.get_entity_as_s_stmt, (collection, s, limit)
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
if row.o != o:
|
||||||
|
continue
|
||||||
|
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||||
|
if g is not None and d != g:
|
||||||
|
continue
|
||||||
|
results.append(QuadResult(
|
||||||
|
s=s, p=row.p, o=o, g=d,
|
||||||
|
otype=row.otype, dtype=row.dtype, lang=row.lang
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def async_get_spo(self, collection, s, p, o, g=None, limit=10):
|
||||||
|
rows = await async_execute(
|
||||||
|
self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit)
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
if row.o != o:
|
||||||
|
continue
|
||||||
|
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||||
|
if g is not None and d != g:
|
||||||
|
continue
|
||||||
|
results.append(QuadResult(
|
||||||
|
s=s, p=p, o=o, g=d,
|
||||||
|
otype=row.otype, dtype=row.dtype, lang=row.lang
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def async_get_g(self, collection, g, limit=50):
|
||||||
|
if g is None:
|
||||||
|
g = DEFAULT_GRAPH
|
||||||
|
return await async_execute(
|
||||||
|
self.session, self.get_collection_by_graph_stmt, (collection, g, limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_collection_exists(self, collection):
|
||||||
|
try:
|
||||||
|
result = await async_execute(
|
||||||
|
self.session,
|
||||||
|
f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1",
|
||||||
|
(collection,)
|
||||||
|
)
|
||||||
|
return bool(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking collection existence: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def async_create_collection(self, collection):
|
||||||
|
await async_execute(
|
||||||
|
self.session,
|
||||||
|
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
|
||||||
|
(collection, datetime.datetime.now())
|
||||||
|
)
|
||||||
|
logger.info(f"Created collection metadata for {collection}")
|
||||||
|
|
||||||
|
async def async_delete_collection(self, collection):
|
||||||
|
rows = await async_execute(
|
||||||
|
self.session,
|
||||||
|
f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s",
|
||||||
|
(collection,)
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = set()
|
||||||
|
quads = []
|
||||||
|
for row in rows:
|
||||||
|
d, s, p, o = row.d, row.s, row.p, row.o
|
||||||
|
otype = row.otype
|
||||||
|
dtype = row.dtype if hasattr(row, 'dtype') else ''
|
||||||
|
lang = row.lang if hasattr(row, 'lang') else ''
|
||||||
|
quads.append((d, s, p, o, otype, dtype, lang))
|
||||||
|
entities.add(s)
|
||||||
|
entities.add(p)
|
||||||
|
if otype == 'u' or otype == 't':
|
||||||
|
entities.add(o)
|
||||||
|
if d != DEFAULT_GRAPH:
|
||||||
|
entities.add(d)
|
||||||
|
|
||||||
|
batch = BatchStatement()
|
||||||
|
count = 0
|
||||||
|
for entity in entities:
|
||||||
|
batch.add(self.delete_entity_partition_stmt, (collection, entity))
|
||||||
|
count += 1
|
||||||
|
if count % 50 == 0:
|
||||||
|
await async_execute(self.session, batch)
|
||||||
|
batch = BatchStatement()
|
||||||
|
if count % 50 != 0:
|
||||||
|
await async_execute(self.session, batch)
|
||||||
|
|
||||||
|
batch = BatchStatement()
|
||||||
|
count = 0
|
||||||
|
for d, s, p, o, otype, dtype, lang in quads:
|
||||||
|
batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang))
|
||||||
|
count += 1
|
||||||
|
if count % 50 == 0:
|
||||||
|
await async_execute(self.session, batch)
|
||||||
|
batch = BatchStatement()
|
||||||
|
if count % 50 != 0:
|
||||||
|
await async_execute(self.session, batch)
|
||||||
|
|
||||||
|
await async_execute(
|
||||||
|
self.session,
|
||||||
|
f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s",
|
||||||
|
(collection,)
|
||||||
|
)
|
||||||
|
logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close connections"""
|
"""Close connections"""
|
||||||
if hasattr(self, 'session') and self.session:
|
if hasattr(self, 'session') and self.session:
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,10 @@ Document embeddings query service. Input is vector, output is an array
|
||||||
of chunk_ids
|
of chunk_ids
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from qdrant_client.models import PointStruct
|
|
||||||
from qdrant_client.models import Distance, VectorParams
|
|
||||||
|
|
||||||
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
||||||
from .... schema import Error
|
from .... schema import Error
|
||||||
|
|
@ -38,32 +37,6 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||||
self.last_collection = None
|
|
||||||
|
|
||||||
def ensure_collection_exists(self, collection, dim):
|
|
||||||
"""Ensure collection exists, create if it doesn't"""
|
|
||||||
if collection != self.last_collection:
|
|
||||||
if not self.qdrant.collection_exists(collection):
|
|
||||||
try:
|
|
||||||
self.qdrant.create_collection(
|
|
||||||
collection_name=collection,
|
|
||||||
vectors_config=VectorParams(
|
|
||||||
size=dim, distance=Distance.COSINE
|
|
||||||
),
|
|
||||||
)
|
|
||||||
logger.info(f"Created collection: {collection}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Qdrant collection creation failed: {e}")
|
|
||||||
raise e
|
|
||||||
self.last_collection = collection
|
|
||||||
|
|
||||||
def collection_exists(self, collection):
|
|
||||||
"""Check if collection exists (no implicit creation)"""
|
|
||||||
return self.qdrant.collection_exists(collection)
|
|
||||||
|
|
||||||
def collection_exists(self, collection):
|
|
||||||
"""Check if collection exists (no implicit creation)"""
|
|
||||||
return self.qdrant.collection_exists(collection)
|
|
||||||
|
|
||||||
async def query_document_embeddings(self, workspace, msg):
|
async def query_document_embeddings(self, workspace, msg):
|
||||||
|
|
||||||
|
|
@ -73,21 +46,24 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
if not vec:
|
if not vec:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Use dimension suffix in collection name
|
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
collection = f"d_{workspace}_{msg.collection}_{dim}"
|
collection = f"d_{workspace}_{msg.collection}_{dim}"
|
||||||
|
|
||||||
# Check if collection exists - return empty if not
|
exists = await asyncio.to_thread(
|
||||||
if not self.collection_exists(collection):
|
self.qdrant.collection_exists, collection
|
||||||
|
)
|
||||||
|
if not exists:
|
||||||
logger.info(f"Collection {collection} does not exist, returning empty results")
|
logger.info(f"Collection {collection} does not exist, returning empty results")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
search_result = self.qdrant.query_points(
|
result = await asyncio.to_thread(
|
||||||
|
self.qdrant.query_points,
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
query=vec,
|
query=vec,
|
||||||
limit=msg.limit,
|
limit=msg.limit,
|
||||||
with_payload=True,
|
with_payload=True,
|
||||||
).points
|
)
|
||||||
|
search_result = result.points
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
for r in search_result:
|
for r in search_result:
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,10 @@ Graph embeddings query service. Input is vector, output is list of
|
||||||
entities
|
entities
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from qdrant_client.models import PointStruct
|
|
||||||
from qdrant_client.models import Distance, VectorParams
|
|
||||||
|
|
||||||
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||||
from .... schema import Error, Term, IRI, LITERAL
|
from .... schema import Error, Term, IRI, LITERAL
|
||||||
|
|
@ -38,32 +37,6 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||||
self.last_collection = None
|
|
||||||
|
|
||||||
def ensure_collection_exists(self, collection, dim):
|
|
||||||
"""Ensure collection exists, create if it doesn't"""
|
|
||||||
if collection != self.last_collection:
|
|
||||||
if not self.qdrant.collection_exists(collection):
|
|
||||||
try:
|
|
||||||
self.qdrant.create_collection(
|
|
||||||
collection_name=collection,
|
|
||||||
vectors_config=VectorParams(
|
|
||||||
size=dim, distance=Distance.COSINE
|
|
||||||
),
|
|
||||||
)
|
|
||||||
logger.info(f"Created collection: {collection}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Qdrant collection creation failed: {e}")
|
|
||||||
raise e
|
|
||||||
self.last_collection = collection
|
|
||||||
|
|
||||||
def collection_exists(self, collection):
|
|
||||||
"""Check if collection exists (no implicit creation)"""
|
|
||||||
return self.qdrant.collection_exists(collection)
|
|
||||||
|
|
||||||
def collection_exists(self, collection):
|
|
||||||
"""Check if collection exists (no implicit creation)"""
|
|
||||||
return self.qdrant.collection_exists(collection)
|
|
||||||
|
|
||||||
def create_value(self, ent):
|
def create_value(self, ent):
|
||||||
if ent.startswith("http://") or ent.startswith("https://"):
|
if ent.startswith("http://") or ent.startswith("https://"):
|
||||||
|
|
@ -79,23 +52,26 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
if not vec:
|
if not vec:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Use dimension suffix in collection name
|
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
collection = f"t_{workspace}_{msg.collection}_{dim}"
|
collection = f"t_{workspace}_{msg.collection}_{dim}"
|
||||||
|
|
||||||
# Check if collection exists - return empty if not
|
exists = await asyncio.to_thread(
|
||||||
if not self.collection_exists(collection):
|
self.qdrant.collection_exists, collection
|
||||||
|
)
|
||||||
|
if not exists:
|
||||||
logger.info(f"Collection {collection} does not exist")
|
logger.info(f"Collection {collection} does not exist")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Heuristic hack, get (2*limit), so that we have more chance
|
# Heuristic hack, get (2*limit), so that we have more chance
|
||||||
# of getting (limit) unique entities
|
# of getting (limit) unique entities
|
||||||
search_result = self.qdrant.query_points(
|
result = await asyncio.to_thread(
|
||||||
|
self.qdrant.query_points,
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
query=vec,
|
query=vec,
|
||||||
limit=msg.limit * 2,
|
limit=msg.limit * 2,
|
||||||
with_payload=True,
|
with_payload=True,
|
||||||
).points
|
)
|
||||||
|
search_result = result.points
|
||||||
|
|
||||||
entity_set = set()
|
entity_set = set()
|
||||||
entities = []
|
entities = []
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ Output is matching row index information (index_name, index_value) for
|
||||||
use in subsequent Cassandra lookups.
|
use in subsequent Cassandra lookups.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
@ -70,7 +71,7 @@ class Processor(FlowProcessor):
|
||||||
safe_name = 'r_' + safe_name
|
safe_name = 'r_' + safe_name
|
||||||
return safe_name.lower()
|
return safe_name.lower()
|
||||||
|
|
||||||
def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
|
async def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
|
||||||
"""Find the Qdrant collection for a given workspace/collection/schema"""
|
"""Find the Qdrant collection for a given workspace/collection/schema"""
|
||||||
prefix = (
|
prefix = (
|
||||||
f"rows_{self.sanitize_name(workspace)}_"
|
f"rows_{self.sanitize_name(workspace)}_"
|
||||||
|
|
@ -78,14 +79,15 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
all_collections = self.qdrant.get_collections().collections
|
all_collections = await asyncio.to_thread(
|
||||||
|
lambda: self.qdrant.get_collections().collections
|
||||||
|
)
|
||||||
matching = [
|
matching = [
|
||||||
coll.name for coll in all_collections
|
coll.name for coll in all_collections
|
||||||
if coll.name.startswith(prefix)
|
if coll.name.startswith(prefix)
|
||||||
]
|
]
|
||||||
|
|
||||||
if matching:
|
if matching:
|
||||||
# Return first match (there should typically be only one per dimension)
|
|
||||||
return matching[0]
|
return matching[0]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -100,8 +102,7 @@ class Processor(FlowProcessor):
|
||||||
if not vec:
|
if not vec:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Find the collection for this workspace/collection/schema
|
qdrant_collection = await self.find_collection(
|
||||||
qdrant_collection = self.find_collection(
|
|
||||||
workspace, request.collection, request.schema_name
|
workspace, request.collection, request.schema_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -113,7 +114,6 @@ class Processor(FlowProcessor):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build optional filter for index_name
|
|
||||||
query_filter = None
|
query_filter = None
|
||||||
if request.index_name:
|
if request.index_name:
|
||||||
query_filter = Filter(
|
query_filter = Filter(
|
||||||
|
|
@ -125,16 +125,16 @@ class Processor(FlowProcessor):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query Qdrant
|
result = await asyncio.to_thread(
|
||||||
search_result = self.qdrant.query_points(
|
self.qdrant.query_points,
|
||||||
collection_name=qdrant_collection,
|
collection_name=qdrant_collection,
|
||||||
query=vec,
|
query=vec,
|
||||||
limit=request.limit,
|
limit=request.limit,
|
||||||
with_payload=True,
|
with_payload=True,
|
||||||
query_filter=query_filter,
|
query_filter=query_filter,
|
||||||
).points
|
)
|
||||||
|
search_result = result.points
|
||||||
|
|
||||||
# Convert to RowIndexMatch objects
|
|
||||||
matches = []
|
matches = []
|
||||||
for point in search_result:
|
for point in search_result:
|
||||||
payload = point.payload or {}
|
payload = point.payload or {}
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ Queries against the unified 'rows' table with schema:
|
||||||
- source: text
|
- source: text
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
@ -97,34 +98,38 @@ class Processor(FlowProcessor):
|
||||||
# Cassandra session
|
# Cassandra session
|
||||||
self.cluster = None
|
self.cluster = None
|
||||||
self.session = None
|
self.session = None
|
||||||
|
self._setup_lock = asyncio.Lock()
|
||||||
|
|
||||||
# Known keyspaces
|
# Known keyspaces
|
||||||
self.known_keyspaces: Set[str] = set()
|
self.known_keyspaces: Set[str] = set()
|
||||||
|
|
||||||
def connect_cassandra(self):
|
async def connect_cassandra(self):
|
||||||
"""Connect to Cassandra cluster"""
|
"""Connect to Cassandra cluster"""
|
||||||
if self.session:
|
async with self._setup_lock:
|
||||||
return
|
if self.session:
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.cassandra_username and self.cassandra_password:
|
if self.cassandra_username and self.cassandra_password:
|
||||||
auth_provider = PlainTextAuthProvider(
|
auth_provider = PlainTextAuthProvider(
|
||||||
username=self.cassandra_username,
|
username=self.cassandra_username,
|
||||||
password=self.cassandra_password
|
password=self.cassandra_password
|
||||||
)
|
)
|
||||||
self.cluster = Cluster(
|
cluster = Cluster(
|
||||||
contact_points=self.cassandra_host,
|
contact_points=self.cassandra_host,
|
||||||
auth_provider=auth_provider
|
auth_provider=auth_provider
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.cluster = Cluster(contact_points=self.cassandra_host)
|
cluster = Cluster(contact_points=self.cassandra_host)
|
||||||
|
|
||||||
self.session = self.cluster.connect()
|
session = await asyncio.to_thread(cluster.connect)
|
||||||
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
|
self.cluster = cluster
|
||||||
|
self.session = session
|
||||||
|
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
|
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def sanitize_name(self, name: str) -> str:
|
def sanitize_name(self, name: str) -> str:
|
||||||
"""Sanitize names for Cassandra compatibility"""
|
"""Sanitize names for Cassandra compatibility"""
|
||||||
|
|
@ -140,14 +145,17 @@ class Processor(FlowProcessor):
|
||||||
f"for workspace {workspace}"
|
f"for workspace {workspace}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replace existing schemas for this workspace
|
async with self._setup_lock:
|
||||||
|
await self._apply_schema_config(workspace, config)
|
||||||
|
|
||||||
|
async def _apply_schema_config(self, workspace, config):
|
||||||
|
|
||||||
ws_schemas: Dict[str, RowSchema] = {}
|
ws_schemas: Dict[str, RowSchema] = {}
|
||||||
self.schemas[workspace] = ws_schemas
|
self.schemas[workspace] = ws_schemas
|
||||||
|
|
||||||
builder = GraphQLSchemaBuilder()
|
builder = GraphQLSchemaBuilder()
|
||||||
self.schema_builders[workspace] = builder
|
self.schema_builders[workspace] = builder
|
||||||
|
|
||||||
# Check if our config type exists
|
|
||||||
if self.config_key not in config:
|
if self.config_key not in config:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"No '{self.config_key}' type in configuration "
|
f"No '{self.config_key}' type in configuration "
|
||||||
|
|
@ -156,16 +164,12 @@ class Processor(FlowProcessor):
|
||||||
self.graphql_schemas[workspace] = None
|
self.graphql_schemas[workspace] = None
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get the schemas dictionary for our type
|
|
||||||
schemas_config = config[self.config_key]
|
schemas_config = config[self.config_key]
|
||||||
|
|
||||||
# Process each schema in the schemas config
|
|
||||||
for schema_name, schema_json in schemas_config.items():
|
for schema_name, schema_json in schemas_config.items():
|
||||||
try:
|
try:
|
||||||
# Parse the JSON schema definition
|
|
||||||
schema_def = json.loads(schema_json)
|
schema_def = json.loads(schema_json)
|
||||||
|
|
||||||
# Create Field objects
|
|
||||||
fields = []
|
fields = []
|
||||||
for field_def in schema_def.get("fields", []):
|
for field_def in schema_def.get("fields", []):
|
||||||
field = SchemaField(
|
field = SchemaField(
|
||||||
|
|
@ -180,7 +184,6 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
fields.append(field)
|
fields.append(field)
|
||||||
|
|
||||||
# Create RowSchema
|
|
||||||
row_schema = RowSchema(
|
row_schema = RowSchema(
|
||||||
name=schema_def.get("name", schema_name),
|
name=schema_def.get("name", schema_name),
|
||||||
description=schema_def.get("description", ""),
|
description=schema_def.get("description", ""),
|
||||||
|
|
@ -202,7 +205,6 @@ class Processor(FlowProcessor):
|
||||||
f"{len(ws_schemas)} schemas"
|
f"{len(ws_schemas)} schemas"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Regenerate GraphQL schema for this workspace
|
|
||||||
self.graphql_schemas[workspace] = builder.build(self.query_cassandra)
|
self.graphql_schemas[workspace] = builder.build(self.query_cassandra)
|
||||||
|
|
||||||
def get_index_names(self, schema: RowSchema) -> List[str]:
|
def get_index_names(self, schema: RowSchema) -> List[str]:
|
||||||
|
|
@ -254,7 +256,7 @@ class Processor(FlowProcessor):
|
||||||
For other queries, we need to scan and post-filter.
|
For other queries, we need to scan and post-filter.
|
||||||
"""
|
"""
|
||||||
# Connect if needed
|
# Connect if needed
|
||||||
self.connect_cassandra()
|
await self.connect_cassandra()
|
||||||
|
|
||||||
safe_keyspace = self.sanitize_name(workspace)
|
safe_keyspace = self.sanitize_name(workspace)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ null. Output is a list of quads.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from cassandra.query import SimpleStatement
|
from cassandra.query import SimpleStatement
|
||||||
|
|
||||||
from .... direct.cassandra_kg import (
|
from .... direct.cassandra_kg import (
|
||||||
|
|
@ -17,6 +17,7 @@ from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||||
from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK
|
from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK
|
||||||
from .... base import TriplesQueryService
|
from .... base import TriplesQueryService
|
||||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||||
|
from .... tables.cassandra_async import async_execute
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -176,45 +177,42 @@ class Processor(TriplesQueryService):
|
||||||
self.cassandra_host = hosts
|
self.cassandra_host = hosts
|
||||||
self.cassandra_username = username
|
self.cassandra_username = username
|
||||||
self.cassandra_password = password
|
self.cassandra_password = password
|
||||||
self.table = None
|
|
||||||
|
|
||||||
def ensure_connection(self, workspace):
|
self._connections = {}
|
||||||
"""Ensure we have a connection to the correct keyspace."""
|
self._conn_lock = asyncio.Lock()
|
||||||
if workspace != self.table:
|
|
||||||
KGClass = EntityCentricKnowledgeGraph
|
|
||||||
|
|
||||||
if self.cassandra_username and self.cassandra_password:
|
async def _get_connection(self, workspace):
|
||||||
self.tg = KGClass(
|
async with self._conn_lock:
|
||||||
hosts=self.cassandra_host,
|
if workspace not in self._connections:
|
||||||
keyspace=workspace,
|
if self.cassandra_username and self.cassandra_password:
|
||||||
username=self.cassandra_username,
|
tg = await asyncio.to_thread(
|
||||||
password=self.cassandra_password
|
EntityCentricKnowledgeGraph,
|
||||||
)
|
hosts=self.cassandra_host,
|
||||||
else:
|
keyspace=workspace,
|
||||||
self.tg = KGClass(
|
username=self.cassandra_username,
|
||||||
hosts=self.cassandra_host,
|
password=self.cassandra_password,
|
||||||
keyspace=workspace,
|
)
|
||||||
)
|
else:
|
||||||
self.table = workspace
|
tg = await asyncio.to_thread(
|
||||||
|
EntityCentricKnowledgeGraph,
|
||||||
|
hosts=self.cassandra_host,
|
||||||
|
keyspace=workspace,
|
||||||
|
)
|
||||||
|
self._connections[workspace] = tg
|
||||||
|
return self._connections[workspace]
|
||||||
|
|
||||||
async def query_triples(self, workspace, query):
|
async def query_triples(self, workspace, query):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# ensure_connection may construct a fresh
|
|
||||||
# EntityCentricKnowledgeGraph which does sync schema
|
|
||||||
# setup against Cassandra. Push it to a worker thread
|
|
||||||
# so the event loop doesn't block on first-use per workspace.
|
|
||||||
await asyncio.to_thread(self.ensure_connection, workspace)
|
|
||||||
|
|
||||||
# Extract values from query
|
|
||||||
s_val = get_term_value(query.s)
|
s_val = get_term_value(query.s)
|
||||||
p_val = get_term_value(query.p)
|
p_val = get_term_value(query.p)
|
||||||
o_val = get_term_value(query.o)
|
o_val = get_term_value(query.o)
|
||||||
g_val = query.g # Already a string or None
|
g_val = query.g
|
||||||
|
|
||||||
|
tg = await self._get_connection(workspace)
|
||||||
|
|
||||||
def get_object_metadata(row):
|
def get_object_metadata(row):
|
||||||
"""Extract term type metadata from result row"""
|
|
||||||
return (
|
return (
|
||||||
getattr(row, 'otype', None),
|
getattr(row, 'otype', None),
|
||||||
getattr(row, 'dtype', None),
|
getattr(row, 'dtype', None),
|
||||||
|
|
@ -223,33 +221,21 @@ class Processor(TriplesQueryService):
|
||||||
|
|
||||||
quads = []
|
quads = []
|
||||||
|
|
||||||
# All self.tg.get_* calls below are sync wrappers around
|
|
||||||
# cassandra session.execute. Materialise inside a worker
|
|
||||||
# thread so iteration never triggers sync paging back on
|
|
||||||
# the event loop.
|
|
||||||
|
|
||||||
# Route to appropriate query method based on which fields are specified
|
|
||||||
if s_val is not None:
|
if s_val is not None:
|
||||||
if p_val is not None:
|
if p_val is not None:
|
||||||
if o_val is not None:
|
if o_val is not None:
|
||||||
# SPO specified - find matching graphs
|
resp = await tg.async_get_spo(
|
||||||
resp = await asyncio.to_thread(
|
query.collection, s_val, p_val, o_val,
|
||||||
lambda: list(self.tg.get_spo(
|
g=g_val, limit=query.limit,
|
||||||
query.collection, s_val, p_val, o_val,
|
|
||||||
g=g_val, limit=query.limit,
|
|
||||||
))
|
|
||||||
)
|
)
|
||||||
for t in resp:
|
for t in resp:
|
||||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||||
term_type, datatype, language = get_object_metadata(t)
|
term_type, datatype, language = get_object_metadata(t)
|
||||||
quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
|
quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
|
||||||
else:
|
else:
|
||||||
# SP specified
|
resp = await tg.async_get_sp(
|
||||||
resp = await asyncio.to_thread(
|
query.collection, s_val, p_val,
|
||||||
lambda: list(self.tg.get_sp(
|
g=g_val, limit=query.limit,
|
||||||
query.collection, s_val, p_val,
|
|
||||||
g=g_val, limit=query.limit,
|
|
||||||
))
|
|
||||||
)
|
)
|
||||||
for t in resp:
|
for t in resp:
|
||||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||||
|
|
@ -257,24 +243,18 @@ class Processor(TriplesQueryService):
|
||||||
quads.append((s_val, p_val, t.o, g, term_type, datatype, language))
|
quads.append((s_val, p_val, t.o, g, term_type, datatype, language))
|
||||||
else:
|
else:
|
||||||
if o_val is not None:
|
if o_val is not None:
|
||||||
# SO specified
|
resp = await tg.async_get_os(
|
||||||
resp = await asyncio.to_thread(
|
query.collection, o_val, s_val,
|
||||||
lambda: list(self.tg.get_os(
|
g=g_val, limit=query.limit,
|
||||||
query.collection, o_val, s_val,
|
|
||||||
g=g_val, limit=query.limit,
|
|
||||||
))
|
|
||||||
)
|
)
|
||||||
for t in resp:
|
for t in resp:
|
||||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||||
term_type, datatype, language = get_object_metadata(t)
|
term_type, datatype, language = get_object_metadata(t)
|
||||||
quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
|
quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
|
||||||
else:
|
else:
|
||||||
# S only
|
resp = await tg.async_get_s(
|
||||||
resp = await asyncio.to_thread(
|
query.collection, s_val,
|
||||||
lambda: list(self.tg.get_s(
|
g=g_val, limit=query.limit,
|
||||||
query.collection, s_val,
|
|
||||||
g=g_val, limit=query.limit,
|
|
||||||
))
|
|
||||||
)
|
)
|
||||||
for t in resp:
|
for t in resp:
|
||||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||||
|
|
@ -283,24 +263,18 @@ class Processor(TriplesQueryService):
|
||||||
else:
|
else:
|
||||||
if p_val is not None:
|
if p_val is not None:
|
||||||
if o_val is not None:
|
if o_val is not None:
|
||||||
# PO specified
|
resp = await tg.async_get_po(
|
||||||
resp = await asyncio.to_thread(
|
query.collection, p_val, o_val,
|
||||||
lambda: list(self.tg.get_po(
|
g=g_val, limit=query.limit,
|
||||||
query.collection, p_val, o_val,
|
|
||||||
g=g_val, limit=query.limit,
|
|
||||||
))
|
|
||||||
)
|
)
|
||||||
for t in resp:
|
for t in resp:
|
||||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||||
term_type, datatype, language = get_object_metadata(t)
|
term_type, datatype, language = get_object_metadata(t)
|
||||||
quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
|
quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
|
||||||
else:
|
else:
|
||||||
# P only
|
resp = await tg.async_get_p(
|
||||||
resp = await asyncio.to_thread(
|
query.collection, p_val,
|
||||||
lambda: list(self.tg.get_p(
|
g=g_val, limit=query.limit,
|
||||||
query.collection, p_val,
|
|
||||||
g=g_val, limit=query.limit,
|
|
||||||
))
|
|
||||||
)
|
)
|
||||||
for t in resp:
|
for t in resp:
|
||||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||||
|
|
@ -308,40 +282,26 @@ class Processor(TriplesQueryService):
|
||||||
quads.append((t.s, p_val, t.o, g, term_type, datatype, language))
|
quads.append((t.s, p_val, t.o, g, term_type, datatype, language))
|
||||||
else:
|
else:
|
||||||
if o_val is not None:
|
if o_val is not None:
|
||||||
# O only
|
resp = await tg.async_get_o(
|
||||||
resp = await asyncio.to_thread(
|
query.collection, o_val,
|
||||||
lambda: list(self.tg.get_o(
|
g=g_val, limit=query.limit,
|
||||||
query.collection, o_val,
|
|
||||||
g=g_val, limit=query.limit,
|
|
||||||
))
|
|
||||||
)
|
)
|
||||||
for t in resp:
|
for t in resp:
|
||||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||||
term_type, datatype, language = get_object_metadata(t)
|
term_type, datatype, language = get_object_metadata(t)
|
||||||
quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
|
quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
|
||||||
else:
|
else:
|
||||||
# Nothing specified - get all
|
resp = await tg.async_get_all(
|
||||||
resp = await asyncio.to_thread(
|
query.collection, limit=query.limit,
|
||||||
lambda: list(self.tg.get_all(
|
|
||||||
query.collection, limit=query.limit,
|
|
||||||
))
|
|
||||||
)
|
)
|
||||||
for t in resp:
|
for t in resp:
|
||||||
# Note: quads_by_collection uses 'd' for graph field
|
|
||||||
g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
|
g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
|
||||||
# Filter by graph
|
|
||||||
# g_val=None means all graphs (no filter)
|
|
||||||
# g_val="" means default graph only
|
|
||||||
# otherwise filter to specific named graph
|
|
||||||
if g_val is not None:
|
if g_val is not None:
|
||||||
if g != g_val:
|
if g != g_val:
|
||||||
continue
|
continue
|
||||||
term_type, datatype, language = get_object_metadata(t)
|
term_type, datatype, language = get_object_metadata(t)
|
||||||
quads.append((t.s, t.p, t.o, g, term_type, datatype, language))
|
quads.append((t.s, t.p, t.o, g, term_type, datatype, language))
|
||||||
|
|
||||||
# Convert to Triple objects (with g field)
|
|
||||||
# s and p are always IRIs in RDF
|
|
||||||
# Object uses term_type/datatype/language metadata from database
|
|
||||||
triples = [
|
triples = [
|
||||||
Triple(
|
Triple(
|
||||||
s=create_term(q[0], term_type='u'),
|
s=create_term(q[0], term_type='u'),
|
||||||
|
|
@ -365,51 +325,36 @@ class Processor(TriplesQueryService):
|
||||||
Uses Cassandra's paging to fetch results incrementally.
|
Uses Cassandra's paging to fetch results incrementally.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(self.ensure_connection, workspace)
|
|
||||||
|
|
||||||
batch_size = query.batch_size if query.batch_size > 0 else 20
|
batch_size = query.batch_size if query.batch_size > 0 else 20
|
||||||
limit = query.limit if query.limit > 0 else 10000
|
limit = query.limit if query.limit > 0 else 10000
|
||||||
|
|
||||||
# Extract query pattern
|
|
||||||
s_val = get_term_value(query.s)
|
s_val = get_term_value(query.s)
|
||||||
p_val = get_term_value(query.p)
|
p_val = get_term_value(query.p)
|
||||||
o_val = get_term_value(query.o)
|
o_val = get_term_value(query.o)
|
||||||
g_val = query.g
|
g_val = query.g
|
||||||
|
|
||||||
def get_object_metadata(row):
|
def get_object_metadata(row):
|
||||||
"""Extract term type metadata from result row"""
|
|
||||||
return (
|
return (
|
||||||
getattr(row, 'otype', None),
|
getattr(row, 'otype', None),
|
||||||
getattr(row, 'dtype', None),
|
getattr(row, 'dtype', None),
|
||||||
getattr(row, 'lang', None),
|
getattr(row, 'lang', None),
|
||||||
)
|
)
|
||||||
|
|
||||||
# For streaming, we need to execute with fetch_size
|
|
||||||
# Use the collection table for get_all queries (most common streaming case)
|
|
||||||
|
|
||||||
# Determine which query to use based on pattern
|
|
||||||
if s_val is None and p_val is None and o_val is None:
|
if s_val is None and p_val is None and o_val is None:
|
||||||
# Get all - use collection table with paging
|
|
||||||
cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"
|
tg = await self._get_connection(workspace)
|
||||||
|
|
||||||
|
cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {tg.collection_table} WHERE collection = %s"
|
||||||
params = [query.collection]
|
params = [query.collection]
|
||||||
|
statement = SimpleStatement(cql, fetch_size=batch_size)
|
||||||
|
result_set = await async_execute(tg.session, statement, params)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# For specific patterns, fall back to non-streaming
|
|
||||||
# (these typically return small result sets anyway)
|
|
||||||
async for batch, is_final in self._fallback_stream(workspace, query, batch_size):
|
async for batch, is_final in self._fallback_stream(workspace, query, batch_size):
|
||||||
yield batch, is_final
|
yield batch, is_final
|
||||||
return
|
return
|
||||||
|
|
||||||
# Materialise in a worker thread. We lose true streaming
|
|
||||||
# paging (the driver fetches all pages eagerly inside the
|
|
||||||
# thread) but the event loop stays responsive, and result
|
|
||||||
# sets at this layer are typically small enough that this
|
|
||||||
# is acceptable. If true async paging is needed later,
|
|
||||||
# revisit using ResponseFuture page callbacks.
|
|
||||||
statement = SimpleStatement(cql, fetch_size=batch_size)
|
|
||||||
result_set = await asyncio.to_thread(
|
|
||||||
lambda: list(self.tg.session.execute(statement, params))
|
|
||||||
)
|
|
||||||
|
|
||||||
batch = []
|
batch = []
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,13 @@
|
||||||
Accepts entity/vector pairs and writes them to a Qdrant store.
|
Accepts entity/vector pairs and writes them to a Qdrant store.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from qdrant_client.models import PointStruct
|
from qdrant_client.models import PointStruct
|
||||||
from qdrant_client.models import Distance, VectorParams
|
from qdrant_client.models import Distance, VectorParams
|
||||||
import uuid
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
|
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
|
||||||
from .... base import AsyncProcessor, Consumer, Producer
|
from .... base import AsyncProcessor, Consumer, Producer
|
||||||
|
|
@ -35,13 +37,35 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||||
|
self._cache_lock = asyncio.Lock()
|
||||||
|
self._known_collections: set[str] = set()
|
||||||
|
|
||||||
# Register for config push notifications
|
# Register for config push notifications
|
||||||
self.register_config_handler(self.on_collection_config, types=["collection"])
|
self.register_config_handler(self.on_collection_config, types=["collection"])
|
||||||
|
|
||||||
|
async def ensure_collection(self, collection_name, dim):
|
||||||
|
async with self._cache_lock:
|
||||||
|
if collection_name in self._known_collections:
|
||||||
|
return
|
||||||
|
exists = await asyncio.to_thread(
|
||||||
|
self.qdrant.collection_exists, collection_name
|
||||||
|
)
|
||||||
|
if not exists:
|
||||||
|
logger.info(
|
||||||
|
f"Lazily creating Qdrant collection {collection_name} "
|
||||||
|
f"with dimension {dim}"
|
||||||
|
)
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.qdrant.create_collection,
|
||||||
|
collection_name=collection_name,
|
||||||
|
vectors_config=VectorParams(
|
||||||
|
size=dim, distance=Distance.COSINE
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._known_collections.add(collection_name)
|
||||||
|
|
||||||
async def store_document_embeddings(self, workspace, message):
|
async def store_document_embeddings(self, workspace, message):
|
||||||
|
|
||||||
# Validate collection exists in config before processing
|
|
||||||
if not self.collection_exists(workspace, message.metadata.collection):
|
if not self.collection_exists(workspace, message.metadata.collection):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Collection {message.metadata.collection} for workspace {workspace} "
|
f"Collection {message.metadata.collection} for workspace {workspace} "
|
||||||
|
|
@ -60,24 +84,15 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
if not vec:
|
if not vec:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Create collection name with dimension suffix for lazy creation
|
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
collection = (
|
collection = (
|
||||||
f"d_{workspace}_{message.metadata.collection}_{dim}"
|
f"d_{workspace}_{message.metadata.collection}_{dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
await self.ensure_collection(collection, dim)
|
||||||
if not self.qdrant.collection_exists(collection):
|
|
||||||
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
|
||||||
self.qdrant.create_collection(
|
|
||||||
collection_name=collection,
|
|
||||||
vectors_config=VectorParams(
|
|
||||||
size=dim,
|
|
||||||
distance=Distance.COSINE
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.qdrant.upsert(
|
await asyncio.to_thread(
|
||||||
|
self.qdrant.upsert,
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
points=[
|
points=[
|
||||||
PointStruct(
|
PointStruct(
|
||||||
|
|
@ -87,7 +102,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
"chunk_id": chunk_id,
|
"chunk_id": chunk_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -124,8 +139,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
try:
|
try:
|
||||||
prefix = f"d_{workspace}_{collection}_"
|
prefix = f"d_{workspace}_{collection}_"
|
||||||
|
|
||||||
# Get all collections and filter for matches
|
all_collections = await asyncio.to_thread(
|
||||||
all_collections = self.qdrant.get_collections().collections
|
lambda: self.qdrant.get_collections().collections
|
||||||
|
)
|
||||||
matching_collections = [
|
matching_collections = [
|
||||||
coll.name for coll in all_collections
|
coll.name for coll in all_collections
|
||||||
if coll.name.startswith(prefix)
|
if coll.name.startswith(prefix)
|
||||||
|
|
@ -135,7 +151,11 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
logger.info(f"No collections found matching prefix {prefix}")
|
logger.info(f"No collections found matching prefix {prefix}")
|
||||||
else:
|
else:
|
||||||
for collection_name in matching_collections:
|
for collection_name in matching_collections:
|
||||||
self.qdrant.delete_collection(collection_name)
|
await asyncio.to_thread(
|
||||||
|
self.qdrant.delete_collection, collection_name
|
||||||
|
)
|
||||||
|
async with self._cache_lock:
|
||||||
|
self._known_collections.discard(collection_name)
|
||||||
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
||||||
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
|
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,13 @@
|
||||||
Accepts entity/vector pairs and writes them to a Qdrant store.
|
Accepts entity/vector pairs and writes them to a Qdrant store.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from qdrant_client.models import PointStruct
|
from qdrant_client.models import PointStruct
|
||||||
from qdrant_client.models import Distance, VectorParams
|
from qdrant_client.models import Distance, VectorParams
|
||||||
import uuid
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
|
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
|
||||||
from .... base import AsyncProcessor, Consumer, Producer
|
from .... base import AsyncProcessor, Consumer, Producer
|
||||||
|
|
@ -50,13 +52,35 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||||
|
self._cache_lock = asyncio.Lock()
|
||||||
|
self._known_collections: set[str] = set()
|
||||||
|
|
||||||
# Register for config push notifications
|
# Register for config push notifications
|
||||||
self.register_config_handler(self.on_collection_config, types=["collection"])
|
self.register_config_handler(self.on_collection_config, types=["collection"])
|
||||||
|
|
||||||
|
async def ensure_collection(self, collection_name, dim):
|
||||||
|
async with self._cache_lock:
|
||||||
|
if collection_name in self._known_collections:
|
||||||
|
return
|
||||||
|
exists = await asyncio.to_thread(
|
||||||
|
self.qdrant.collection_exists, collection_name
|
||||||
|
)
|
||||||
|
if not exists:
|
||||||
|
logger.info(
|
||||||
|
f"Lazily creating Qdrant collection {collection_name} "
|
||||||
|
f"with dimension {dim}"
|
||||||
|
)
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.qdrant.create_collection,
|
||||||
|
collection_name=collection_name,
|
||||||
|
vectors_config=VectorParams(
|
||||||
|
size=dim, distance=Distance.COSINE
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._known_collections.add(collection_name)
|
||||||
|
|
||||||
async def store_graph_embeddings(self, workspace, message):
|
async def store_graph_embeddings(self, workspace, message):
|
||||||
|
|
||||||
# Validate collection exists in config before processing
|
|
||||||
if not self.collection_exists(workspace, message.metadata.collection):
|
if not self.collection_exists(workspace, message.metadata.collection):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Collection {message.metadata.collection} for workspace {workspace} "
|
f"Collection {message.metadata.collection} for workspace {workspace} "
|
||||||
|
|
@ -75,22 +99,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
if not vec:
|
if not vec:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Create collection name with dimension suffix for lazy creation
|
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
collection = (
|
collection = (
|
||||||
f"t_{workspace}_{message.metadata.collection}_{dim}"
|
f"t_{workspace}_{message.metadata.collection}_{dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
await self.ensure_collection(collection, dim)
|
||||||
if not self.qdrant.collection_exists(collection):
|
|
||||||
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
|
||||||
self.qdrant.create_collection(
|
|
||||||
collection_name=collection,
|
|
||||||
vectors_config=VectorParams(
|
|
||||||
size=dim,
|
|
||||||
distance=Distance.COSINE
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"entity": entity_value,
|
"entity": entity_value,
|
||||||
|
|
@ -98,7 +112,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
if entity.chunk_id:
|
if entity.chunk_id:
|
||||||
payload["chunk_id"] = entity.chunk_id
|
payload["chunk_id"] = entity.chunk_id
|
||||||
|
|
||||||
self.qdrant.upsert(
|
await asyncio.to_thread(
|
||||||
|
self.qdrant.upsert,
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
points=[
|
points=[
|
||||||
PointStruct(
|
PointStruct(
|
||||||
|
|
@ -106,7 +121,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
vector=vec,
|
vector=vec,
|
||||||
payload=payload,
|
payload=payload,
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -143,8 +158,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
try:
|
try:
|
||||||
prefix = f"t_{workspace}_{collection}_"
|
prefix = f"t_{workspace}_{collection}_"
|
||||||
|
|
||||||
# Get all collections and filter for matches
|
all_collections = await asyncio.to_thread(
|
||||||
all_collections = self.qdrant.get_collections().collections
|
lambda: self.qdrant.get_collections().collections
|
||||||
|
)
|
||||||
matching_collections = [
|
matching_collections = [
|
||||||
coll.name for coll in all_collections
|
coll.name for coll in all_collections
|
||||||
if coll.name.startswith(prefix)
|
if coll.name.startswith(prefix)
|
||||||
|
|
@ -154,7 +170,11 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
logger.info(f"No collections found matching prefix {prefix}")
|
logger.info(f"No collections found matching prefix {prefix}")
|
||||||
else:
|
else:
|
||||||
for collection_name in matching_collections:
|
for collection_name in matching_collections:
|
||||||
self.qdrant.delete_collection(collection_name)
|
await asyncio.to_thread(
|
||||||
|
self.qdrant.delete_collection, collection_name
|
||||||
|
)
|
||||||
|
async with self._cache_lock:
|
||||||
|
self._known_collections.discard(collection_name)
|
||||||
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
||||||
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
|
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,10 +16,10 @@ Payload structure:
|
||||||
- text: The text that was embedded (for debugging/display)
|
- text: The text that was embedded (for debugging/display)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Set, Tuple
|
|
||||||
|
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from qdrant_client.models import PointStruct, Distance, VectorParams
|
from qdrant_client.models import PointStruct, Distance, VectorParams
|
||||||
|
|
@ -63,11 +63,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
# Register config handler for collection management
|
# Register config handler for collection management
|
||||||
self.register_config_handler(self.on_collection_config, types=["collection"])
|
self.register_config_handler(self.on_collection_config, types=["collection"])
|
||||||
|
|
||||||
# Cache of created Qdrant collections
|
|
||||||
self.created_collections: Set[str] = set()
|
|
||||||
|
|
||||||
# Qdrant client
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||||
|
self._cache_lock = asyncio.Lock()
|
||||||
|
self._known_collections: set[str] = set()
|
||||||
|
|
||||||
def sanitize_name(self, name: str) -> str:
|
def sanitize_name(self, name: str) -> str:
|
||||||
"""Sanitize names for Qdrant collection naming"""
|
"""Sanitize names for Qdrant collection naming"""
|
||||||
|
|
@ -85,25 +83,28 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
safe_schema = self.sanitize_name(schema_name)
|
safe_schema = self.sanitize_name(schema_name)
|
||||||
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
|
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
|
||||||
|
|
||||||
def ensure_collection(self, collection_name: str, dimension: int):
|
async def ensure_collection(self, collection_name: str, dimension: int):
|
||||||
"""Create Qdrant collection if it doesn't exist"""
|
"""Create Qdrant collection if it doesn't exist"""
|
||||||
if collection_name in self.created_collections:
|
async with self._cache_lock:
|
||||||
return
|
if collection_name in self._known_collections:
|
||||||
|
return
|
||||||
if not self.qdrant.collection_exists(collection_name):
|
exists = await asyncio.to_thread(
|
||||||
logger.info(
|
self.qdrant.collection_exists, collection_name
|
||||||
f"Creating Qdrant collection {collection_name} "
|
|
||||||
f"with dimension {dimension}"
|
|
||||||
)
|
)
|
||||||
self.qdrant.create_collection(
|
if not exists:
|
||||||
collection_name=collection_name,
|
logger.info(
|
||||||
vectors_config=VectorParams(
|
f"Creating Qdrant collection {collection_name} "
|
||||||
size=dimension,
|
f"with dimension {dimension}"
|
||||||
distance=Distance.COSINE
|
|
||||||
)
|
)
|
||||||
)
|
await asyncio.to_thread(
|
||||||
|
self.qdrant.create_collection,
|
||||||
self.created_collections.add(collection_name)
|
collection_name=collection_name,
|
||||||
|
vectors_config=VectorParams(
|
||||||
|
size=dimension,
|
||||||
|
distance=Distance.COSINE
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._known_collections.add(collection_name)
|
||||||
|
|
||||||
async def on_embeddings(self, msg, consumer, flow):
|
async def on_embeddings(self, msg, consumer, flow):
|
||||||
"""Process incoming RowEmbeddings and write to Qdrant"""
|
"""Process incoming RowEmbeddings and write to Qdrant"""
|
||||||
|
|
@ -143,15 +144,14 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
||||||
dimension = len(vector)
|
dimension = len(vector)
|
||||||
|
|
||||||
# Create/get collection name (lazily on first vector)
|
|
||||||
if qdrant_collection is None:
|
if qdrant_collection is None:
|
||||||
qdrant_collection = self.get_collection_name(
|
qdrant_collection = self.get_collection_name(
|
||||||
workspace, collection, schema_name, dimension
|
workspace, collection, schema_name, dimension
|
||||||
)
|
)
|
||||||
self.ensure_collection(qdrant_collection, dimension)
|
await self.ensure_collection(qdrant_collection, dimension)
|
||||||
|
|
||||||
# Write to Qdrant
|
await asyncio.to_thread(
|
||||||
self.qdrant.upsert(
|
self.qdrant.upsert,
|
||||||
collection_name=qdrant_collection,
|
collection_name=qdrant_collection,
|
||||||
points=[
|
points=[
|
||||||
PointStruct(
|
PointStruct(
|
||||||
|
|
@ -163,7 +163,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
"text": row_emb.text
|
"text": row_emb.text
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
embeddings_written += 1
|
embeddings_written += 1
|
||||||
|
|
||||||
|
|
@ -181,8 +181,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
try:
|
try:
|
||||||
prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"
|
prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"
|
||||||
|
|
||||||
# Get all collections and filter for matches
|
all_collections = await asyncio.to_thread(
|
||||||
all_collections = self.qdrant.get_collections().collections
|
lambda: self.qdrant.get_collections().collections
|
||||||
|
)
|
||||||
matching_collections = [
|
matching_collections = [
|
||||||
coll.name for coll in all_collections
|
coll.name for coll in all_collections
|
||||||
if coll.name.startswith(prefix)
|
if coll.name.startswith(prefix)
|
||||||
|
|
@ -192,8 +193,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
logger.info(f"No Qdrant collections found matching prefix {prefix}")
|
logger.info(f"No Qdrant collections found matching prefix {prefix}")
|
||||||
else:
|
else:
|
||||||
for collection_name in matching_collections:
|
for collection_name in matching_collections:
|
||||||
self.qdrant.delete_collection(collection_name)
|
await asyncio.to_thread(
|
||||||
self.created_collections.discard(collection_name)
|
self.qdrant.delete_collection, collection_name
|
||||||
|
)
|
||||||
|
async with self._cache_lock:
|
||||||
|
self._known_collections.discard(collection_name)
|
||||||
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Deleted {len(matching_collections)} collection(s) "
|
f"Deleted {len(matching_collections)} collection(s) "
|
||||||
|
|
@ -217,8 +221,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
|
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all collections and filter for matches
|
all_collections = await asyncio.to_thread(
|
||||||
all_collections = self.qdrant.get_collections().collections
|
lambda: self.qdrant.get_collections().collections
|
||||||
|
)
|
||||||
matching_collections = [
|
matching_collections = [
|
||||||
coll.name for coll in all_collections
|
coll.name for coll in all_collections
|
||||||
if coll.name.startswith(prefix)
|
if coll.name.startswith(prefix)
|
||||||
|
|
@ -228,8 +233,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
logger.info(f"No Qdrant collections found matching prefix {prefix}")
|
logger.info(f"No Qdrant collections found matching prefix {prefix}")
|
||||||
else:
|
else:
|
||||||
for collection_name in matching_collections:
|
for collection_name in matching_collections:
|
||||||
self.qdrant.delete_collection(collection_name)
|
await asyncio.to_thread(
|
||||||
self.created_collections.discard(collection_name)
|
self.qdrant.delete_collection, collection_name
|
||||||
|
)
|
||||||
|
async with self._cache_lock:
|
||||||
|
self._known_collections.discard(collection_name)
|
||||||
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
||||||
# Cache of known keyspaces and whether tables exist
|
# Cache of known keyspaces and whether tables exist
|
||||||
self.known_keyspaces: Set[str] = set()
|
self.known_keyspaces: Set[str] = set()
|
||||||
self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables
|
self.tables_initialized: Set[str] = set()
|
||||||
|
|
||||||
# Cache of registered (collection, schema_name) pairs
|
# Cache of registered (collection, schema_name) pairs
|
||||||
self.registered_partitions: Set[Tuple[str, str]] = set()
|
self.registered_partitions: Set[Tuple[str, str]] = set()
|
||||||
|
|
@ -94,6 +94,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
self.cluster = None
|
self.cluster = None
|
||||||
self.session = None
|
self.session = None
|
||||||
|
|
||||||
|
# Protects connection setup and cache mutations
|
||||||
|
self._setup_lock = asyncio.Lock()
|
||||||
|
|
||||||
def connect_cassandra(self):
|
def connect_cassandra(self):
|
||||||
"""Connect to Cassandra cluster"""
|
"""Connect to Cassandra cluster"""
|
||||||
if self.session:
|
if self.session:
|
||||||
|
|
@ -126,6 +129,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
f"for workspace {workspace}"
|
f"for workspace {workspace}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async with self._setup_lock:
|
||||||
|
return await self._apply_schema_config(workspace, config, version)
|
||||||
|
|
||||||
|
async def _apply_schema_config(self, workspace, config, version):
|
||||||
|
|
||||||
# Track which schemas changed in this workspace
|
# Track which schemas changed in this workspace
|
||||||
old_schemas = self.schemas.get(workspace, {})
|
old_schemas = self.schemas.get(workspace, {})
|
||||||
old_schema_names = set(old_schemas.keys())
|
old_schema_names = set(old_schemas.keys())
|
||||||
|
|
@ -391,16 +399,12 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
schema_name = obj.schema_name
|
schema_name = obj.schema_name
|
||||||
source = getattr(obj.metadata, 'source', '') or ''
|
source = getattr(obj.metadata, 'source', '') or ''
|
||||||
|
|
||||||
# Ensure tables exist (sync DDL — push to a worker thread
|
async with self._setup_lock:
|
||||||
# so the event loop stays responsive when running in a
|
await asyncio.to_thread(self.ensure_tables, keyspace)
|
||||||
# processor group sharing the loop with siblings).
|
await asyncio.to_thread(
|
||||||
await asyncio.to_thread(self.ensure_tables, keyspace)
|
self.register_partitions,
|
||||||
|
keyspace, collection, schema_name, workspace,
|
||||||
# Register partitions if first time seeing this (collection, schema_name)
|
)
|
||||||
await asyncio.to_thread(
|
|
||||||
self.register_partitions,
|
|
||||||
keyspace, collection, schema_name, workspace,
|
|
||||||
)
|
|
||||||
|
|
||||||
safe_keyspace = self.sanitize_name(keyspace)
|
safe_keyspace = self.sanitize_name(keyspace)
|
||||||
|
|
||||||
|
|
@ -461,35 +465,27 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
||||||
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
||||||
"""Create/verify collection exists in Cassandra row store"""
|
"""Create/verify collection exists in Cassandra row store"""
|
||||||
# Connect if not already connected (sync, push to thread)
|
async with self._setup_lock:
|
||||||
await asyncio.to_thread(self.connect_cassandra)
|
await asyncio.to_thread(self.connect_cassandra)
|
||||||
|
await asyncio.to_thread(self.ensure_tables, workspace)
|
||||||
# Ensure tables exist (sync DDL, push to thread)
|
|
||||||
await asyncio.to_thread(self.ensure_tables, workspace)
|
|
||||||
|
|
||||||
logger.info(f"Collection {collection} ready for workspace {workspace}")
|
logger.info(f"Collection {collection} ready for workspace {workspace}")
|
||||||
|
|
||||||
async def delete_collection(self, workspace: str, collection: str):
|
async def delete_collection(self, workspace: str, collection: str):
|
||||||
"""Delete all data for a specific collection using partition tracking"""
|
"""Delete all data for a specific collection using partition tracking"""
|
||||||
# Connect if not already connected
|
async with self._setup_lock:
|
||||||
await asyncio.to_thread(self.connect_cassandra)
|
await asyncio.to_thread(self.connect_cassandra)
|
||||||
|
if workspace not in self.known_keyspaces:
|
||||||
|
safe_ks = self.sanitize_name(workspace)
|
||||||
|
check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s"
|
||||||
|
result = await async_execute(self.session, check_cql, (safe_ks,))
|
||||||
|
if not result:
|
||||||
|
logger.info(f"Keyspace {safe_ks} does not exist, nothing to delete")
|
||||||
|
return
|
||||||
|
self.known_keyspaces.add(workspace)
|
||||||
|
|
||||||
safe_keyspace = self.sanitize_name(workspace)
|
safe_keyspace = self.sanitize_name(workspace)
|
||||||
|
|
||||||
# Check if keyspace exists
|
|
||||||
if workspace not in self.known_keyspaces:
|
|
||||||
check_keyspace_cql = """
|
|
||||||
SELECT keyspace_name FROM system_schema.keyspaces
|
|
||||||
WHERE keyspace_name = %s
|
|
||||||
"""
|
|
||||||
result = await async_execute(
|
|
||||||
self.session, check_keyspace_cql, (safe_keyspace,)
|
|
||||||
)
|
|
||||||
if not result:
|
|
||||||
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
|
|
||||||
return
|
|
||||||
self.known_keyspaces.add(workspace)
|
|
||||||
|
|
||||||
# Discover all partitions for this collection
|
# Discover all partitions for this collection
|
||||||
select_partitions_cql = f"""
|
select_partitions_cql = f"""
|
||||||
SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
|
SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
|
||||||
|
|
@ -540,11 +536,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
|
logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Clear from local cache
|
async with self._setup_lock:
|
||||||
self.registered_partitions = {
|
self.registered_partitions = {
|
||||||
(col, sch) for col, sch in self.registered_partitions
|
(col, sch) for col, sch in self.registered_partitions
|
||||||
if col != collection
|
if col != collection
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Deleted collection {collection}: {partitions_deleted} partitions "
|
f"Deleted collection {collection}: {partitions_deleted} partitions "
|
||||||
|
|
@ -553,8 +549,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
||||||
async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str):
|
async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str):
|
||||||
"""Delete all data for a specific collection + schema combination"""
|
"""Delete all data for a specific collection + schema combination"""
|
||||||
# Connect if not already connected
|
async with self._setup_lock:
|
||||||
await asyncio.to_thread(self.connect_cassandra)
|
await asyncio.to_thread(self.connect_cassandra)
|
||||||
|
|
||||||
safe_keyspace = self.sanitize_name(workspace)
|
safe_keyspace = self.sanitize_name(workspace)
|
||||||
|
|
||||||
|
|
@ -614,8 +610,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Clear from local cache
|
async with self._setup_lock:
|
||||||
self.registered_partitions.discard((collection, schema_name))
|
self.registered_partitions.discard((collection, schema_name))
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "
|
f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,7 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
|
||||||
import os
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
import json
|
|
||||||
|
|
||||||
from .... direct.cassandra_kg import (
|
from .... direct.cassandra_kg import (
|
||||||
EntityCentricKnowledgeGraph, DEFAULT_GRAPH
|
EntityCentricKnowledgeGraph, DEFAULT_GRAPH
|
||||||
|
|
@ -28,6 +23,8 @@ default_ident = "triples-write"
|
||||||
|
|
||||||
def serialize_triple(triple):
|
def serialize_triple(triple):
|
||||||
"""Serialize a Triple object to JSON for storage."""
|
"""Serialize a Triple object to JSON for storage."""
|
||||||
|
import json
|
||||||
|
|
||||||
if triple is None:
|
if triple is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -141,156 +138,84 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
|
||||||
self.cassandra_host = hosts
|
self.cassandra_host = hosts
|
||||||
self.cassandra_username = username
|
self.cassandra_username = username
|
||||||
self.cassandra_password = password
|
self.cassandra_password = password
|
||||||
self.table = None
|
|
||||||
self.tg = None
|
self._connections = {}
|
||||||
|
self._conn_lock = asyncio.Lock()
|
||||||
|
|
||||||
# Register for config push notifications
|
# Register for config push notifications
|
||||||
self.register_config_handler(self.on_collection_config, types=["collection"])
|
self.register_config_handler(self.on_collection_config, types=["collection"])
|
||||||
|
|
||||||
|
async def _get_connection(self, workspace):
|
||||||
|
async with self._conn_lock:
|
||||||
|
if workspace not in self._connections:
|
||||||
|
if self.cassandra_username and self.cassandra_password:
|
||||||
|
tg = await asyncio.to_thread(
|
||||||
|
EntityCentricKnowledgeGraph,
|
||||||
|
hosts=self.cassandra_host,
|
||||||
|
keyspace=workspace,
|
||||||
|
username=self.cassandra_username,
|
||||||
|
password=self.cassandra_password,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tg = await asyncio.to_thread(
|
||||||
|
EntityCentricKnowledgeGraph,
|
||||||
|
hosts=self.cassandra_host,
|
||||||
|
keyspace=workspace,
|
||||||
|
)
|
||||||
|
self._connections[workspace] = tg
|
||||||
|
return self._connections[workspace]
|
||||||
|
|
||||||
async def store_triples(self, workspace, message):
|
async def store_triples(self, workspace, message):
|
||||||
|
|
||||||
# The cassandra-driver work below — connection, schema
|
tg = await self._get_connection(workspace)
|
||||||
# setup, and per-triple inserts — is all synchronous.
|
|
||||||
# Wrap the whole batch in a worker thread so the event
|
|
||||||
# loop stays responsive for sibling processors when
|
|
||||||
# running in a processor group.
|
|
||||||
|
|
||||||
def _do_store():
|
for t in message.triples:
|
||||||
|
s_val = get_term_value(t.s)
|
||||||
|
p_val = get_term_value(t.p)
|
||||||
|
o_val = get_term_value(t.o)
|
||||||
|
g_val = t.g if t.g is not None else DEFAULT_GRAPH
|
||||||
|
|
||||||
if self.table is None or self.table != workspace:
|
otype = get_term_otype(t.o)
|
||||||
|
dtype = get_term_dtype(t.o)
|
||||||
|
lang = get_term_lang(t.o)
|
||||||
|
|
||||||
self.tg = None
|
await tg.async_insert(
|
||||||
|
message.metadata.collection,
|
||||||
# Use factory function to select implementation
|
s_val,
|
||||||
KGClass = EntityCentricKnowledgeGraph
|
p_val,
|
||||||
|
o_val,
|
||||||
try:
|
g=g_val,
|
||||||
if self.cassandra_username and self.cassandra_password:
|
otype=otype,
|
||||||
self.tg = KGClass(
|
dtype=dtype,
|
||||||
hosts=self.cassandra_host,
|
lang=lang,
|
||||||
keyspace=workspace,
|
)
|
||||||
username=self.cassandra_username,
|
|
||||||
password=self.cassandra_password,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.tg = KGClass(
|
|
||||||
hosts=self.cassandra_host,
|
|
||||||
keyspace=workspace,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception: {e}", exc_info=True)
|
|
||||||
time.sleep(1)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
self.table = workspace
|
|
||||||
|
|
||||||
for t in message.triples:
|
|
||||||
# Extract values from Term objects
|
|
||||||
s_val = get_term_value(t.s)
|
|
||||||
p_val = get_term_value(t.p)
|
|
||||||
o_val = get_term_value(t.o)
|
|
||||||
# t.g is None for default graph, or a graph IRI
|
|
||||||
g_val = t.g if t.g is not None else DEFAULT_GRAPH
|
|
||||||
|
|
||||||
# Extract object type metadata for entity-centric storage
|
|
||||||
otype = get_term_otype(t.o)
|
|
||||||
dtype = get_term_dtype(t.o)
|
|
||||||
lang = get_term_lang(t.o)
|
|
||||||
|
|
||||||
self.tg.insert(
|
|
||||||
message.metadata.collection,
|
|
||||||
s_val,
|
|
||||||
p_val,
|
|
||||||
o_val,
|
|
||||||
g=g_val,
|
|
||||||
otype=otype,
|
|
||||||
dtype=dtype,
|
|
||||||
lang=lang,
|
|
||||||
)
|
|
||||||
|
|
||||||
await asyncio.to_thread(_do_store)
|
|
||||||
|
|
||||||
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
||||||
"""Create a collection in Cassandra triple store via config push"""
|
"""Create a collection in Cassandra triple store via config push"""
|
||||||
|
try:
|
||||||
|
tg = await self._get_connection(workspace)
|
||||||
|
|
||||||
def _do_create():
|
|
||||||
# Create or reuse connection for this workspace's keyspace
|
|
||||||
if self.table is None or self.table != workspace:
|
|
||||||
self.tg = None
|
|
||||||
|
|
||||||
# Use factory function to select implementation
|
|
||||||
KGClass = EntityCentricKnowledgeGraph
|
|
||||||
|
|
||||||
try:
|
|
||||||
if self.cassandra_username and self.cassandra_password:
|
|
||||||
self.tg = KGClass(
|
|
||||||
hosts=self.cassandra_host,
|
|
||||||
keyspace=workspace,
|
|
||||||
username=self.cassandra_username,
|
|
||||||
password=self.cassandra_password,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.tg = KGClass(
|
|
||||||
hosts=self.cassandra_host,
|
|
||||||
keyspace=workspace,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
self.table = workspace
|
|
||||||
|
|
||||||
# Create collection using the built-in method
|
|
||||||
logger.info(f"Creating collection {collection} for workspace {workspace}")
|
logger.info(f"Creating collection {collection} for workspace {workspace}")
|
||||||
|
|
||||||
if self.tg.collection_exists(collection):
|
exists = await tg.async_collection_exists(collection)
|
||||||
|
if exists:
|
||||||
logger.info(f"Collection {collection} already exists")
|
logger.info(f"Collection {collection} already exists")
|
||||||
else:
|
else:
|
||||||
self.tg.create_collection(collection)
|
await tg.async_create_collection(collection)
|
||||||
logger.info(f"Created collection {collection}")
|
logger.info(f"Created collection {collection}")
|
||||||
|
|
||||||
try:
|
|
||||||
await asyncio.to_thread(_do_create)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
|
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def delete_collection(self, workspace: str, collection: str):
|
async def delete_collection(self, workspace: str, collection: str):
|
||||||
"""Delete all data for a specific collection from the unified triples table"""
|
"""Delete all data for a specific collection from the unified triples table"""
|
||||||
|
try:
|
||||||
|
tg = await self._get_connection(workspace)
|
||||||
|
|
||||||
def _do_delete():
|
await tg.async_delete_collection(collection)
|
||||||
# Create or reuse connection for this workspace's keyspace
|
|
||||||
if self.table is None or self.table != workspace:
|
|
||||||
self.tg = None
|
|
||||||
|
|
||||||
# Use factory function to select implementation
|
|
||||||
KGClass = EntityCentricKnowledgeGraph
|
|
||||||
|
|
||||||
try:
|
|
||||||
if self.cassandra_username and self.cassandra_password:
|
|
||||||
self.tg = KGClass(
|
|
||||||
hosts=self.cassandra_host,
|
|
||||||
keyspace=workspace,
|
|
||||||
username=self.cassandra_username,
|
|
||||||
password=self.cassandra_password,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.tg = KGClass(
|
|
||||||
hosts=self.cassandra_host,
|
|
||||||
keyspace=workspace,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
self.table = workspace
|
|
||||||
|
|
||||||
# Delete all triples for this collection using the built-in method
|
|
||||||
self.tg.delete_collection(collection)
|
|
||||||
logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
|
logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
|
||||||
|
|
||||||
try:
|
|
||||||
await asyncio.to_thread(_do_delete)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
|
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue