mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-17 11:25:12 +02:00
Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916)
Cassandra triples services were using syncronous EntityCentricKnowledgeGraph methods from async contexts, and connection state was managed with threading.local which is wrong for asyncio coroutines sharing a single thread. Qdrant services had no async wrapping at all, blocking the event loop on every network call. Rows services had unprotected shared state mutations across concurrent coroutines. - Add async methods to EntityCentricKnowledgeGraph (async_insert, async_get_s/p/o/sp/po/os/spo/all, async_collection_exists, async_create_collection, async_delete_collection) using the existing cassandra_async.async_execute bridge - Rewrite triples write + query services: replace threading.local with asyncio.Lock + dict cache for per-workspace connections, use async ECKG methods for all data operations, keep asyncio.to_thread only for one-time blocking ECKG construction - Wrap all Qdrant calls in asyncio.to_thread across all 6 services (doc/graph/row embeddings write + query), add asyncio.Lock + set cache for collection existence checks - Add asyncio.Lock to rows write + query services to protect shared state (schemas, sessions, config caches) from concurrent mutation - Update all affected tests to match new async patterns
This commit is contained in:
parent
bb1109963c
commit
a2dde9cafb
22 changed files with 736 additions and 621 deletions
|
|
@ -63,26 +63,26 @@ class TestEndToEndConfigurationFlow:
|
|||
'CASSANDRA_USERNAME': 'obj-user',
|
||||
'CASSANDRA_PASSWORD': 'obj-pass'
|
||||
}
|
||||
|
||||
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_auth_provider.return_value = mock_auth_instance
|
||||
mock_cluster_instance = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_cluster_instance.connect.return_value = mock_session
|
||||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# Trigger Cassandra connection
|
||||
processor.connect_cassandra()
|
||||
|
||||
|
||||
# Verify auth provider was created with env vars
|
||||
mock_auth_provider.assert_called_once_with(
|
||||
username='obj-user',
|
||||
password='obj-pass'
|
||||
)
|
||||
|
||||
|
||||
# Verify cluster was created with hosts from env and auth
|
||||
mock_cluster.assert_called_once()
|
||||
call_args = mock_cluster.call_args
|
||||
|
|
@ -188,37 +188,34 @@ class TestConfigurationPriorityEndToEnd:
|
|||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.direct.cassandra_kg.Cluster')
|
||||
async def test_no_config_defaults_end_to_end(self, mock_cluster):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_no_config_defaults_end_to_end(self, mock_kg_class):
|
||||
"""Test that defaults are used when no configuration provided end-to-end."""
|
||||
mock_cluster_instance = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_cluster_instance.connect.return_value = mock_session
|
||||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_get_all = AsyncMock(return_value=[])
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
processor = TriplesQuery(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# Mock query to trigger TrustGraph creation
|
||||
mock_query = MagicMock()
|
||||
mock_query.collection = 'default_collection'
|
||||
mock_query.s = None
|
||||
mock_query.p = None
|
||||
mock_query.o = None
|
||||
mock_query.g = None
|
||||
mock_query.limit = 100
|
||||
|
||||
# Mock the get_all method to return empty list
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.get_all.return_value = []
|
||||
processor.tg = mock_tg_instance
|
||||
|
||||
|
||||
await processor.query_triples('default_user', mock_query)
|
||||
|
||||
|
||||
# Should use defaults
|
||||
mock_cluster.assert_called_once()
|
||||
call_args = mock_cluster.call_args
|
||||
assert call_args.args[0] == ['cassandra'] # Default host
|
||||
assert 'auth_provider' not in call_args.kwargs # No auth with default config
|
||||
mock_kg_class.assert_called_once_with(
|
||||
hosts=['cassandra'],
|
||||
keyspace='default_user'
|
||||
)
|
||||
|
||||
|
||||
class TestNoBackwardCompatibilityEndToEnd:
|
||||
|
|
@ -324,16 +321,16 @@ class TestMultipleHostsHandling:
|
|||
env_vars = {
|
||||
'CASSANDRA_HOST': 'host1,host2,host3,host4,host5'
|
||||
}
|
||||
|
||||
|
||||
mock_cluster_instance = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_cluster_instance.connect.return_value = mock_session
|
||||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
|
||||
# Verify all hosts were passed to Cluster
|
||||
mock_cluster.assert_called_once()
|
||||
call_args = mock_cluster.call_args
|
||||
|
|
@ -392,27 +389,27 @@ class TestAuthenticationFlow:
|
|||
'CASSANDRA_USERNAME': 'auth-user',
|
||||
'CASSANDRA_PASSWORD': 'auth-secret'
|
||||
}
|
||||
|
||||
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_auth_provider.return_value = mock_auth_instance
|
||||
mock_cluster_instance = MagicMock()
|
||||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
|
||||
# Auth provider should be created
|
||||
mock_auth_provider.assert_called_once_with(
|
||||
username='auth-user',
|
||||
password='auth-secret'
|
||||
)
|
||||
|
||||
|
||||
# Cluster should be created with auth provider
|
||||
call_args = mock_cluster.call_args
|
||||
assert 'auth_provider' in call_args.kwargs
|
||||
assert call_args.kwargs['auth_provider'] == mock_auth_instance
|
||||
|
||||
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster):
|
||||
|
|
@ -421,21 +418,21 @@ class TestAuthenticationFlow:
|
|||
'CASSANDRA_HOST': 'no-auth-host'
|
||||
# No username/password
|
||||
}
|
||||
|
||||
|
||||
mock_cluster_instance = MagicMock()
|
||||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
|
||||
# Auth provider should not be created
|
||||
mock_auth_provider.assert_not_called()
|
||||
|
||||
|
||||
# Cluster should be created without auth provider
|
||||
call_args = mock_cluster.call_args
|
||||
assert 'auth_provider' not in call_args.kwargs
|
||||
|
||||
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster):
|
||||
|
|
@ -446,15 +443,15 @@ class TestAuthenticationFlow:
|
|||
cassandra_username='partial-user'
|
||||
# No password
|
||||
)
|
||||
|
||||
|
||||
mock_cluster_instance = MagicMock()
|
||||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
|
||||
processor.connect_cassandra()
|
||||
|
||||
|
||||
# Auth provider should not be created (needs both username AND password)
|
||||
mock_auth_provider.assert_not_called()
|
||||
|
||||
|
||||
# Cluster should be created without auth provider
|
||||
call_args = mock_cluster.call_args
|
||||
assert 'auth_provider' not in call_args.kwargs
|
||||
|
|
@ -101,6 +101,8 @@ class TestRowsCassandraIntegration:
|
|||
processor.session = None
|
||||
|
||||
# Bind actual methods from the new unified table implementation
|
||||
import asyncio
|
||||
processor._setup_lock = asyncio.Lock()
|
||||
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
|
||||
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
|
||||
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
|
||||
|
|
@ -108,6 +110,7 @@ class TestRowsCassandraIntegration:
|
|||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.collection_exists = MagicMock(return_value=True)
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Connect to Cassandra
|
||||
processor.connect_cassandra()
|
||||
await processor.connect_cassandra()
|
||||
assert processor.session is not None
|
||||
|
||||
# Create test keyspace and table
|
||||
|
|
@ -219,7 +219,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
"""Test inserting data and querying via GraphQL"""
|
||||
# Load schema and connect
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
await processor.connect_cassandra()
|
||||
|
||||
# Setup test data
|
||||
keyspace = "test_user"
|
||||
|
|
@ -293,7 +293,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
"""Test GraphQL queries with filtering on indexed fields"""
|
||||
# Setup (reuse previous setup)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
await processor.connect_cassandra()
|
||||
|
||||
keyspace = "test_user"
|
||||
collection = "filter_test"
|
||||
|
|
@ -387,7 +387,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
"""Test full message processing workflow"""
|
||||
# Setup
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
await processor.connect_cassandra()
|
||||
|
||||
# Create mock message
|
||||
request = RowsQueryRequest(
|
||||
|
|
@ -433,7 +433,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
"""Test handling multiple concurrent GraphQL queries"""
|
||||
# Setup
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
await processor.connect_cassandra()
|
||||
|
||||
# Create multiple query tasks
|
||||
queries = [
|
||||
|
|
@ -519,7 +519,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
"""Test handling of large query result sets"""
|
||||
# Setup
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
await processor.connect_cassandra()
|
||||
|
||||
keyspace = "large_test_user"
|
||||
collection = "large_collection"
|
||||
|
|
|
|||
|
|
@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic:
|
|||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configuration"""
|
||||
import asyncio
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.schema_builders = {}
|
||||
processor.graphql_schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.query_cassandra = MagicMock()
|
||||
processor._setup_lock = asyncio.Lock()
|
||||
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test config
|
||||
|
|
@ -335,7 +338,7 @@ class TestUnifiedTableQueries:
|
|||
"""Test query execution with matching index"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.connect_cassandra = AsyncMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
|
|
@ -396,7 +399,7 @@ class TestUnifiedTableQueries:
|
|||
"""Test query execution without matching index (scan mode)"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.connect_cassandra = AsyncMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@
|
|||
Tests for Cassandra triples query service
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.query.triples.cassandra.service import Processor, create_term
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
|
@ -18,7 +20,7 @@ class TestCassandraQueryProcessor:
|
|||
return Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
graph_host='localhost'
|
||||
cassandra_host='localhost'
|
||||
)
|
||||
|
||||
def test_create_term_with_http_uri(self, processor):
|
||||
|
|
@ -85,7 +87,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
|
|
@ -110,8 +112,8 @@ class TestCassandraQueryProcessor:
|
|||
keyspace='test_user'
|
||||
)
|
||||
|
||||
# Verify get_spo was called with correct parameters
|
||||
mock_tg_instance.get_spo.assert_called_once_with(
|
||||
# Verify async_get_spo was called with correct parameters
|
||||
mock_tg_instance.async_get_spo.assert_called_once_with(
|
||||
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
|
||||
)
|
||||
|
||||
|
|
@ -130,23 +132,25 @@ class TestCassandraQueryProcessor:
|
|||
assert processor.cassandra_host == ['cassandra'] # Updated default
|
||||
assert processor.cassandra_username is None
|
||||
assert processor.cassandra_password is None
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
cassandra_host='cassandra.example.com',
|
||||
cassandra_username='queryuser',
|
||||
cassandra_password='querypass'
|
||||
)
|
||||
|
||||
|
||||
assert processor.cassandra_host == ['cassandra.example.com']
|
||||
assert processor.cassandra_username == 'queryuser'
|
||||
assert processor.cassandra_password == 'querypass'
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -164,7 +168,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_sp.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -178,7 +182,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
|
|
@ -200,7 +204,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_s.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -214,7 +218,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
|
|
@ -236,7 +240,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_p.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -250,7 +254,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
|
|
@ -272,7 +276,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_o.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -286,7 +290,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
|
|
@ -305,11 +309,11 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.s = 'all_subject'
|
||||
mock_result.p = 'all_predicate'
|
||||
mock_result.o = 'all_object'
|
||||
mock_result.g = ''
|
||||
mock_result.d = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_all.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -323,7 +327,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'all_subject'
|
||||
assert result[0].p.iri == 'all_predicate'
|
||||
|
|
@ -410,7 +414,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
|
|
@ -451,7 +455,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -489,8 +493,8 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.lang = None
|
||||
mock_result.p = 'p'
|
||||
mock_result.o = 'o'
|
||||
mock_tg_instance1.get_s.return_value = [mock_result]
|
||||
mock_tg_instance2.get_s.return_value = [mock_result]
|
||||
mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result])
|
||||
mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -504,7 +508,6 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
await processor.query_triples('user1', query1)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
|
|
@ -516,10 +519,11 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
await processor.query_triples('user2', query2)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
# Verify TrustGraph was created twice for different workspaces
|
||||
assert mock_kg_class.call_count == 2
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -529,7 +533,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
|
||||
mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed"))
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -566,7 +570,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result2.otype = None
|
||||
mock_result2.dtype = None
|
||||
mock_result2.lang = None
|
||||
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
|
||||
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_po.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
# Verify async_get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.async_get_po.assert_called_once_with(
|
||||
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
|
||||
)
|
||||
|
||||
|
|
@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_os.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.get_os.assert_called_once_with(
|
||||
# Verify async_get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.async_get_os.assert_called_once_with(
|
||||
'test_collection', 'test_object', 'test_subject', g=None, limit=25
|
||||
)
|
||||
|
||||
|
|
@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
# Mock empty results for all queries
|
||||
mock_tg_instance.get_all.return_value = []
|
||||
mock_tg_instance.get_s.return_value = []
|
||||
mock_tg_instance.get_p.return_value = []
|
||||
mock_tg_instance.get_o.return_value = []
|
||||
mock_tg_instance.get_sp.return_value = []
|
||||
mock_tg_instance.get_po.return_value = []
|
||||
mock_tg_instance.get_os.return_value = []
|
||||
mock_tg_instance.get_spo.return_value = []
|
||||
mock_tg_instance.async_get_all = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_s = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_p = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_o = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_sp = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_po = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_os = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Test each query pattern
|
||||
test_patterns = [
|
||||
# (s, p, o, expected_method)
|
||||
(None, None, None, 'get_all'), # All triples
|
||||
('s1', None, None, 'get_s'), # Subject only
|
||||
(None, 'p1', None, 'get_p'), # Predicate only
|
||||
(None, None, 'o1', 'get_o'), # Object only
|
||||
('s1', 'p1', None, 'get_sp'), # Subject + Predicate
|
||||
(None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
|
||||
('s1', None, 'o1', 'get_os'), # Object + Subject
|
||||
('s1', 'p1', 'o1', 'get_spo'), # All three
|
||||
(None, None, None, 'async_get_all'), # All triples
|
||||
('s1', None, None, 'async_get_s'), # Subject only
|
||||
(None, 'p1', None, 'async_get_p'), # Predicate only
|
||||
(None, None, 'o1', 'async_get_o'), # Object only
|
||||
('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate
|
||||
(None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
|
||||
('s1', None, 'o1', 'async_get_os'), # Object + Subject
|
||||
('s1', 'p1', 'o1', 'async_get_spo'), # All three
|
||||
]
|
||||
|
||||
for s, p, o, expected_method in test_patterns:
|
||||
|
|
@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_result.lang = None
|
||||
mock_results.append(mock_result)
|
||||
|
||||
mock_tg_instance.get_po.return_value = mock_results
|
||||
mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results)
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
result = await processor.query_triples('large_dataset_user', query)
|
||||
|
||||
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
# Verify optimized async_get_po was used (no ALLOW FILTERING needed!)
|
||||
mock_tg_instance.async_get_po.assert_called_once_with(
|
||||
'massive_collection',
|
||||
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
|
||||
'http://example.com/Person',
|
||||
|
|
|
|||
|
|
@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_embedding_upserted(self):
|
||||
import asyncio
|
||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = True
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "col1"
|
||||
|
|
@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection:
|
|||
@pytest.mark.asyncio
|
||||
async def test_dimension_in_collection_name(self):
|
||||
"""Collection name should include vector dimension."""
|
||||
import asyncio
|
||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = True
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "docs"
|
||||
|
|
@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_entity_and_vector_upserted(self):
|
||||
import asyncio
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = True
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "col1"
|
||||
|
|
@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lazy_collection_creation_on_new_dimension(self):
|
||||
import asyncio
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = False
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "graphs"
|
||||
|
|
|
|||
|
|
@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
|
||||
|
||||
# Verify collection existence is checked on each write
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
# Second write uses cached collection state — no collection_exists check
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
# But upsert should still be called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = Processor(**config)
|
||||
|
||||
processor.ensure_collection("test_collection", 384)
|
||||
await processor.ensure_collection("test_collection", 384)
|
||||
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify the collection is cached
|
||||
assert "test_collection" in processor.created_collections
|
||||
assert "test_collection" in processor._known_collections
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
|
||||
|
|
@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = Processor(**config)
|
||||
|
||||
processor.ensure_collection("existing_collection", 384)
|
||||
await processor.ensure_collection("existing_collection", 384)
|
||||
|
||||
mock_qdrant_instance.collection_exists.assert_called_once()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
|
@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.created_collections.add("cached_collection")
|
||||
processor._known_collections.add("cached_collection")
|
||||
|
||||
processor.ensure_collection("cached_collection", 384)
|
||||
await processor.ensure_collection("cached_collection", 384)
|
||||
|
||||
# Should not check or create - just return
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
|
@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
|
||||
processor._known_collections.add('rows_test_workspace_test_collection_schema1_384')
|
||||
|
||||
await processor.delete_collection('test_workspace', 'test_collection')
|
||||
|
||||
|
|
@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
assert mock_qdrant_instance.delete_collection.call_count == 2
|
||||
|
||||
# Verify the cached collection was removed
|
||||
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
|
||||
assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_delete_collection_schema(self, mock_qdrant_client):
|
||||
|
|
|
|||
|
|
@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic:
|
|||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configurations"""
|
||||
import asyncio
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.registered_partitions = set()
|
||||
processor._setup_lock = asyncio.Lock()
|
||||
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test configuration
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
Tests for Cassandra triples storage service
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
|
|
@ -24,12 +26,13 @@ class TestCassandraStorageProcessor:
|
|||
assert processor.cassandra_host == ['cassandra'] # Updated default
|
||||
assert processor.cassandra_username is None
|
||||
assert processor.cassandra_password is None
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters (new cassandra_* names)"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
id='custom-storage',
|
||||
|
|
@ -37,11 +40,12 @@ class TestCassandraStorageProcessor:
|
|||
cassandra_username='testuser',
|
||||
cassandra_password='testpass'
|
||||
)
|
||||
|
||||
|
||||
assert processor.cassandra_host == ['cassandra.example.com']
|
||||
assert processor.cassandra_username == 'testuser'
|
||||
assert processor.cassandra_password == 'testpass'
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
def test_processor_initialization_with_partial_auth(self):
|
||||
"""Test processor initialization with only username (no password)"""
|
||||
|
|
@ -92,6 +96,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test table switching logic when authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(
|
||||
|
|
@ -114,7 +119,6 @@ class TestCassandraStorageProcessor:
|
|||
username='testuser',
|
||||
password='testpass'
|
||||
)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -122,6 +126,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test table switching logic when no authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -138,7 +143,6 @@ class TestCassandraStorageProcessor:
|
|||
hosts=['cassandra'], # Updated default
|
||||
keyspace='user2'
|
||||
)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -146,6 +150,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test that TrustGraph is not recreated when table hasn't changed"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -169,6 +174,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test that triples are properly inserted into Cassandra"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -208,12 +214,12 @@ class TestCassandraStorageProcessor:
|
|||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
|
||||
assert mock_tg_instance.insert.call_count == 2
|
||||
mock_tg_instance.insert.assert_any_call(
|
||||
assert mock_tg_instance.async_insert.call_count == 2
|
||||
mock_tg_instance.async_insert.assert_any_call(
|
||||
'collection1', 'subject1', 'predicate1', 'object1',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
mock_tg_instance.insert.assert_any_call(
|
||||
mock_tg_instance.async_insert.assert_any_call(
|
||||
'collection1', 'subject2', 'predicate2', 'object2',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
|
|
@ -224,6 +230,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test behavior when message has no triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -236,19 +243,17 @@ class TestCassandraStorageProcessor:
|
|||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify no triples were inserted
|
||||
mock_tg_instance.insert.assert_not_called()
|
||||
mock_tg_instance.async_insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
|
||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
|
||||
async def test_exception_handling_on_connection_failure(self, mock_kg_class):
|
||||
"""Test exception handling during TrustGraph creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
|
@ -256,9 +261,6 @@ class TestCassandraStorageProcessor:
|
|||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify sleep was called before re-raising
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
|
|
@ -359,8 +361,6 @@ class TestCassandraStorageProcessor:
|
|||
mock_message1.triples = []
|
||||
|
||||
await processor.store_triples('user1', mock_message1)
|
||||
assert processor.table == 'user1'
|
||||
assert processor.tg == mock_tg_instance1
|
||||
|
||||
# Second message with different table
|
||||
mock_message2 = MagicMock()
|
||||
|
|
@ -368,11 +368,11 @@ class TestCassandraStorageProcessor:
|
|||
mock_message2.triples = []
|
||||
|
||||
await processor.store_triples('user2', mock_message2)
|
||||
assert processor.table == 'user2'
|
||||
assert processor.tg == mock_tg_instance2
|
||||
|
||||
# Verify TrustGraph was created twice for different tables
|
||||
# Verify TrustGraph was created twice for different workspaces
|
||||
assert mock_kg_class.call_count == 2
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -380,6 +380,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test storing triples with special characters and unicode"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -405,7 +406,7 @@ class TestCassandraStorageProcessor:
|
|||
await processor.store_triples('test_workspace', mock_message)
|
||||
|
||||
# Verify the triple was inserted with special characters preserved
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
mock_tg_instance.async_insert.assert_called_once_with(
|
||||
'test_collection',
|
||||
'subject with spaces & symbols',
|
||||
'predicate:with/colons',
|
||||
|
|
@ -418,29 +419,29 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
|
||||
"""Test that table remains unchanged when TrustGraph creation fails"""
|
||||
async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class):
|
||||
"""Test that a failed connection doesn't leave stale cached state"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_good_instance = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Set an initial table
|
||||
processor.table = ('old_user', 'old_collection')
|
||||
|
||||
# Mock TrustGraph to raise exception
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
# First call fails
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples('new_user', mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Table should remain unchanged since self.table = table happens after try/except
|
||||
assert processor.table == ('old_user', 'old_collection')
|
||||
# TrustGraph should be set to None though
|
||||
assert processor.tg is None
|
||||
# Second call succeeds — should retry connection, not use stale state
|
||||
mock_kg_class.side_effect = None
|
||||
mock_kg_class.return_value = mock_good_instance
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Connection was attempted twice (failed + succeeded)
|
||||
assert mock_kg_class.call_count == 2
|
||||
|
||||
|
||||
class TestCassandraPerformanceOptimizations:
|
||||
|
|
@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test that legacy mode still works with single table"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
|
||||
|
|
@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test that optimized mode uses multi-table schema"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
|
||||
|
|
@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test that all tables stay consistent during batch writes"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
mock_tg_instance.async_insert.assert_called_once_with(
|
||||
'collection1', 'test_subject', 'test_predicate', 'test_object',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
|
|
|
|||
|
|
@ -89,7 +89,8 @@ class TestSanitizeName:
|
|||
|
||||
class TestFindCollection:
|
||||
|
||||
def test_finds_matching_collection(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_finds_matching_collection(self):
|
||||
proc = _make_processor()
|
||||
mock_coll = MagicMock()
|
||||
mock_coll.name = "rows_test_workspace_test_col_customers_384"
|
||||
|
|
@ -98,11 +99,12 @@ class TestFindCollection:
|
|||
mock_collections.collections = [mock_coll]
|
||||
proc.qdrant.get_collections.return_value = mock_collections
|
||||
|
||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
||||
result = await proc.find_collection("test-workspace", "test-col", "customers")
|
||||
|
||||
assert result == "rows_test_workspace_test_col_customers_384"
|
||||
|
||||
def test_returns_none_when_no_match(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_match(self):
|
||||
proc = _make_processor()
|
||||
mock_coll = MagicMock()
|
||||
mock_coll.name = "rows_other_workspace_other_col_schema_768"
|
||||
|
|
@ -111,14 +113,15 @@ class TestFindCollection:
|
|||
mock_collections.collections = [mock_coll]
|
||||
proc.qdrant.get_collections.return_value = mock_collections
|
||||
|
||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
||||
result = await proc.find_collection("test-workspace", "test-col", "customers")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_on_error(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_error(self):
|
||||
proc = _make_processor()
|
||||
proc.qdrant.get_collections.side_effect = Exception("connection error")
|
||||
|
||||
result = proc.find_collection("workspace", "col", "schema")
|
||||
result = await proc.find_collection("workspace", "col", "schema")
|
||||
assert result is None
|
||||
|
||||
|
||||
|
|
@ -139,7 +142,7 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_no_collection_returns_empty(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value=None)
|
||||
proc.find_collection = AsyncMock(return_value=None)
|
||||
request = _make_request()
|
||||
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
|
|
@ -148,7 +151,7 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_successful_query_returns_matches(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
points = [
|
||||
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
|
||||
|
|
@ -172,7 +175,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_index_name_filter_applied(self):
|
||||
"""When index_name is specified, a Qdrant filter should be used."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.points = []
|
||||
|
|
@ -188,7 +191,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_no_index_name_no_filter(self):
|
||||
"""When index_name is empty, no filter should be applied."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.points = []
|
||||
|
|
@ -204,7 +207,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_missing_payload_fields_default(self):
|
||||
"""Points with missing payload fields should use defaults."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
point = MagicMock()
|
||||
point.payload = {} # Empty payload
|
||||
|
|
@ -225,7 +228,7 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_qdrant_error_propagates(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
proc.qdrant.query_points.side_effect = Exception("qdrant down")
|
||||
|
||||
request = _make_request()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue