mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-09 15:22:38 +02:00
Structured data 2 (#645)
* Structured data refactor - multi-index tables, remove need for manual mods to the Cassandra tables * Tech spec updated to track implementation
This commit is contained in:
parent
5ffad92345
commit
1809c1f56d
87 changed files with 5233 additions and 3235 deletions
380
tests/unit/test_embeddings/test_row_embeddings_processor.py
Normal file
380
tests/unit/test_embeddings/test_row_embeddings_processor.py
Normal file
|
|
@ -0,0 +1,380 @@
|
|||
"""
|
||||
Unit tests for trustgraph.embeddings.row_embeddings.embeddings
|
||||
Tests the Stage 1 processor that computes embeddings for row index fields.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
|
||||
class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
    """Unit tests for the Stage 1 row-embeddings Processor.

    Covers: construction defaults, index-name discovery from a RowSchema,
    index-value / embedding-text construction, schema-config loading, and
    the on_message pipeline (drop paths and the happy path).

    Fix vs. original: removed an unused ``import json`` inside
    ``test_on_message_processes_embeddings``.
    """

    async def test_processor_initialization(self):
        """Test basic processor initialization"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-row-embeddings'
        }

        processor = Processor(**config)

        # Fresh processor starts with no registered schemas and the
        # documented default batch size.
        assert hasattr(processor, 'schemas')
        assert processor.schemas == {}
        assert processor.batch_size == 10  # default

    async def test_processor_initialization_with_custom_batch_size(self):
        """Test processor initialization with custom batch size"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-row-embeddings',
            'batch_size': 25
        }

        processor = Processor(**config)

        assert processor.batch_size == 25

    async def test_get_index_names_single_index(self):
        """Test getting index names with single indexed field"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        from trustgraph.schema import RowSchema, Field

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        schema = RowSchema(
            name='customers',
            description='Customer records',
            fields=[
                Field(name='id', type='text', primary=True),
                Field(name='name', type='text', indexed=True),
                Field(name='email', type='text', indexed=False),
            ]
        )

        index_names = processor.get_index_names(schema)

        # Should include primary key and indexed field, but not
        # explicitly non-indexed fields.
        assert 'id' in index_names
        assert 'name' in index_names
        assert 'email' not in index_names

    async def test_get_index_names_no_indexes(self):
        """Test getting index names when no fields are indexed"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        from trustgraph.schema import RowSchema, Field

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        schema = RowSchema(
            name='logs',
            description='Log records',
            fields=[
                Field(name='timestamp', type='text'),
                Field(name='message', type='text'),
            ]
        )

        index_names = processor.get_index_names(schema)

        assert index_names == []

    async def test_build_index_value_single_field(self):
        """Test building index value for single field"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        value_map = {
            'id': 'CUST001',
            'name': 'John Doe',
            'email': 'john@example.com'
        }

        result = processor.build_index_value(value_map, 'name')

        assert result == ['John Doe']

    async def test_build_index_value_composite_index(self):
        """Test building index value for composite index"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        value_map = {
            'first_name': 'John',
            'last_name': 'Doe',
            'city': 'New York'
        }

        # Composite index spec is a comma-separated field list.
        result = processor.build_index_value(value_map, 'first_name, last_name')

        assert result == ['John', 'Doe']

    async def test_build_index_value_missing_field(self):
        """Test building index value when field is missing"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        value_map = {
            'name': 'John Doe'
        }

        result = processor.build_index_value(value_map, 'missing_field')

        # Missing fields degrade to an empty-string placeholder.
        assert result == ['']

    async def test_build_text_for_embedding_single_value(self):
        """Test building text representation for single value"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        result = processor.build_text_for_embedding(['John Doe'])

        assert result == 'John Doe'

    async def test_build_text_for_embedding_multiple_values(self):
        """Test building text representation for multiple values"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        result = processor.build_text_for_embedding(['John', 'Doe', 'NYC'])

        # Multiple index values are space-joined into one embedding text.
        assert result == 'John Doe NYC'

    async def test_on_schema_config_loads_schemas(self):
        """Test that schema configuration is loaded correctly"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        import json

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor',
            'config_type': 'schema'
        }

        processor = Processor(**config)

        schema_def = {
            'name': 'customers',
            'description': 'Customer records',
            'fields': [
                {'name': 'id', 'type': 'text', 'primary_key': True},
                {'name': 'name', 'type': 'text', 'indexed': True},
                {'name': 'email', 'type': 'text'}
            ]
        }

        # Config payload maps schema name -> JSON-encoded definition.
        config_data = {
            'schema': {
                'customers': json.dumps(schema_def)
            }
        }

        await processor.on_schema_config(config_data, 1)

        assert 'customers' in processor.schemas
        assert processor.schemas['customers'].name == 'customers'
        assert len(processor.schemas['customers'].fields) == 3

    async def test_on_schema_config_handles_missing_type(self):
        """Test that missing schema type is handled gracefully"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor',
            'config_type': 'schema'
        }

        processor = Processor(**config)

        # No 'schema' key present — handler must not raise.
        config_data = {
            'other_type': {}
        }

        await processor.on_schema_config(config_data, 1)

        assert processor.schemas == {}

    async def test_on_message_drops_unknown_collection(self):
        """Test that messages for unknown collections are dropped"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        from trustgraph.schema import ExtractedObject

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        # No collections registered

        metadata = MagicMock()
        metadata.user = 'unknown_user'
        metadata.collection = 'unknown_collection'
        metadata.id = 'doc-123'

        obj = ExtractedObject(
            metadata=metadata,
            schema_name='customers',
            values=[{'id': '123', 'name': 'Test'}]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = obj

        mock_flow = MagicMock()

        await processor.on_message(mock_msg, MagicMock(), mock_flow)

        # Flow should not be called for output
        mock_flow.assert_not_called()

    async def test_on_message_drops_unknown_schema(self):
        """Test that messages for unknown schemas are dropped"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        from trustgraph.schema import ExtractedObject

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        processor.known_collections[('test_user', 'test_collection')] = {}
        # No schemas registered

        metadata = MagicMock()
        metadata.user = 'test_user'
        metadata.collection = 'test_collection'
        metadata.id = 'doc-123'

        obj = ExtractedObject(
            metadata=metadata,
            schema_name='unknown_schema',
            values=[{'id': '123', 'name': 'Test'}]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = obj

        mock_flow = MagicMock()

        await processor.on_message(mock_msg, MagicMock(), mock_flow)

        # Flow should not be called for output
        mock_flow.assert_not_called()

    async def test_on_message_processes_embeddings(self):
        """Test processing a message and computing embeddings"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        from trustgraph.schema import ExtractedObject, RowSchema, Field

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor',
            'config_type': 'schema'
        }

        processor = Processor(**config)
        processor.known_collections[('test_user', 'test_collection')] = {}

        # Set up schema: 'id' (primary) and 'name' (indexed) both produce
        # embedding texts.
        processor.schemas['customers'] = RowSchema(
            name='customers',
            description='Customer records',
            fields=[
                Field(name='id', type='text', primary=True),
                Field(name='name', type='text', indexed=True),
            ]
        )

        metadata = MagicMock()
        metadata.user = 'test_user'
        metadata.collection = 'test_collection'
        metadata.id = 'doc-123'

        obj = ExtractedObject(
            metadata=metadata,
            schema_name='customers',
            values=[
                {'id': 'CUST001', 'name': 'John Doe'},
                {'id': 'CUST002', 'name': 'Jane Smith'}
            ]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = obj

        # Mock the flow: the processor looks up named services through the
        # flow callable, so dispatch by name here.
        mock_embeddings_request = AsyncMock()
        mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]

        mock_output = AsyncMock()

        def flow_factory(name):
            if name == 'embeddings-request':
                return mock_embeddings_request
            elif name == 'output':
                return mock_output
            return MagicMock()

        mock_flow = MagicMock(side_effect=flow_factory)

        await processor.on_message(mock_msg, MagicMock(), mock_flow)

        # Should have called embed for each unique text:
        # 4 values: CUST001, John Doe, CUST002, Jane Smith
        assert mock_embeddings_request.embed.call_count == 4

        # Should have sent output
        mock_output.send.assert_called()
|
||||
# Allow running this test module directly (e.g. `python test_row_embeddings_processor.py`)
# by delegating to pytest's collector for just this file.
if __name__ == '__main__':
    pytest.main([__file__])
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Unit tests for objects import dispatcher.
|
||||
Unit tests for rows import dispatcher.
|
||||
|
||||
Tests the business logic of objects import dispatcher
|
||||
Tests the business logic of rows import dispatcher
|
||||
while mocking the Publisher and websocket components.
|
||||
"""
|
||||
|
||||
|
|
@ -11,7 +11,7 @@ import asyncio
|
|||
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||
from aiohttp import web
|
||||
|
||||
from trustgraph.gateway.dispatch.objects_import import ObjectsImport
|
||||
from trustgraph.gateway.dispatch.rows_import import RowsImport
|
||||
from trustgraph.schema import Metadata, ExtractedObject
|
||||
|
||||
|
||||
|
|
@ -92,16 +92,16 @@ def minimal_objects_message():
|
|||
}
|
||||
|
||||
|
||||
class TestObjectsImportInitialization:
|
||||
"""Test ObjectsImport initialization."""
|
||||
class TestRowsImportInitialization:
|
||||
"""Test RowsImport initialization."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that ObjectsImport creates Publisher with correct parameters."""
|
||||
"""Test that RowsImport creates Publisher with correct parameters."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -116,28 +116,28 @@ class TestObjectsImportInitialization:
|
|||
)
|
||||
|
||||
# Verify instance variables are set correctly
|
||||
assert objects_import.ws == mock_websocket
|
||||
assert objects_import.running == mock_running
|
||||
assert objects_import.publisher == mock_publisher_instance
|
||||
assert rows_import.ws == mock_websocket
|
||||
assert rows_import.running == mock_running
|
||||
assert rows_import.publisher == mock_publisher_instance
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that ObjectsImport stores all required references."""
|
||||
objects_import = ObjectsImport(
|
||||
"""Test that RowsImport stores all required references."""
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="objects-queue"
|
||||
)
|
||||
|
||||
assert objects_import.ws is mock_websocket
|
||||
assert objects_import.running is mock_running
|
||||
assert rows_import.ws is mock_websocket
|
||||
assert rows_import.running is mock_running
|
||||
|
||||
|
||||
class TestObjectsImportLifecycle:
|
||||
"""Test ObjectsImport lifecycle methods."""
|
||||
class TestRowsImportLifecycle:
|
||||
"""Test RowsImport lifecycle methods."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that start() calls publisher.start()."""
|
||||
|
|
@ -145,18 +145,18 @@ class TestObjectsImportLifecycle:
|
|||
mock_publisher_instance.start = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.start()
|
||||
await rows_import.start()
|
||||
|
||||
mock_publisher_instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that destroy() properly stops publisher and closes websocket."""
|
||||
|
|
@ -164,21 +164,21 @@ class TestObjectsImportLifecycle:
|
|||
mock_publisher_instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.destroy()
|
||||
await rows_import.destroy()
|
||||
|
||||
# Verify sequence of operations
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_publisher_instance.stop.assert_called_once()
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
|
||||
"""Test that destroy() handles None websocket gracefully."""
|
||||
|
|
@ -186,7 +186,7 @@ class TestObjectsImportLifecycle:
|
|||
mock_publisher_instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=None, # None websocket
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -194,16 +194,16 @@ class TestObjectsImportLifecycle:
|
|||
)
|
||||
|
||||
# Should not raise exception
|
||||
await objects_import.destroy()
|
||||
await rows_import.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_publisher_instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestObjectsImportMessageProcessing:
|
||||
"""Test ObjectsImport message processing."""
|
||||
class TestRowsImportMessageProcessing:
|
||||
"""Test RowsImport message processing."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
||||
"""Test that receive() processes complete message correctly."""
|
||||
|
|
@ -211,7 +211,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -222,7 +222,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -246,7 +246,7 @@ class TestObjectsImportMessageProcessing:
|
|||
assert sent_object.metadata.collection == "testcollection"
|
||||
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
|
||||
"""Test that receive() handles message with minimal required fields."""
|
||||
|
|
@ -254,7 +254,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -265,7 +265,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = minimal_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -279,7 +279,7 @@ class TestObjectsImportMessageProcessing:
|
|||
assert sent_object.source_span == "" # Default value
|
||||
assert len(sent_object.metadata.metadata) == 0 # Default empty list
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() uses appropriate default values for optional fields."""
|
||||
|
|
@ -287,7 +287,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -309,7 +309,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = message_data
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Get the sent object and verify defaults
|
||||
sent_object = mock_publisher_instance.send.call_args[0][1]
|
||||
|
|
@ -317,11 +317,11 @@ class TestObjectsImportMessageProcessing:
|
|||
assert sent_object.source_span == ""
|
||||
|
||||
|
||||
class TestObjectsImportRunMethod:
|
||||
"""Test ObjectsImport run method."""
|
||||
class TestRowsImportRunMethod:
|
||||
"""Test RowsImport run method."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that run() loops while running.get() returns True."""
|
||||
|
|
@ -331,14 +331,14 @@ class TestObjectsImportRunMethod:
|
|||
# Set up running state to return True twice, then False
|
||||
mock_running.get.side_effect = [True, True, False]
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.run()
|
||||
await rows_import.run()
|
||||
|
||||
# Verify sleep was called twice (for the two True iterations)
|
||||
assert mock_sleep.call_count == 2
|
||||
|
|
@ -348,10 +348,10 @@ class TestObjectsImportRunMethod:
|
|||
mock_websocket.close.assert_called_once()
|
||||
|
||||
# Verify websocket was set to None
|
||||
assert objects_import.ws is None
|
||||
assert rows_import.ws is None
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
|
||||
"""Test that run() handles None websocket gracefully."""
|
||||
|
|
@ -360,7 +360,7 @@ class TestObjectsImportRunMethod:
|
|||
|
||||
mock_running.get.return_value = False # Exit immediately
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=None, # None websocket
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -368,14 +368,14 @@ class TestObjectsImportRunMethod:
|
|||
)
|
||||
|
||||
# Should not raise exception
|
||||
await objects_import.run()
|
||||
await rows_import.run()
|
||||
|
||||
# Verify websocket remains None
|
||||
assert objects_import.ws is None
|
||||
assert rows_import.ws is None
|
||||
|
||||
|
||||
class TestObjectsImportBatchProcessing:
|
||||
"""Test ObjectsImport batch processing functionality."""
|
||||
class TestRowsImportBatchProcessing:
|
||||
"""Test RowsImport batch processing functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def batch_objects_message(self):
|
||||
|
|
@ -415,7 +415,7 @@ class TestObjectsImportBatchProcessing:
|
|||
"source_span": "Multiple people found in document"
|
||||
}
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
|
||||
"""Test that receive() processes batch message correctly."""
|
||||
|
|
@ -423,7 +423,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -434,7 +434,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = batch_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -465,7 +465,7 @@ class TestObjectsImportBatchProcessing:
|
|||
assert sent_object.confidence == 0.85
|
||||
assert sent_object.source_span == "Multiple people found in document"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() handles empty batch correctly."""
|
||||
|
|
@ -473,7 +473,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -494,7 +494,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = empty_batch_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Should still send the message
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -502,10 +502,10 @@ class TestObjectsImportBatchProcessing:
|
|||
assert len(sent_object.values) == 0
|
||||
|
||||
|
||||
class TestObjectsImportErrorHandling:
|
||||
"""Test error handling in ObjectsImport."""
|
||||
class TestRowsImportErrorHandling:
|
||||
"""Test error handling in RowsImport."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
||||
"""Test that receive() propagates publisher send errors."""
|
||||
|
|
@ -513,7 +513,7 @@ class TestObjectsImportErrorHandling:
|
|||
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -524,15 +524,15 @@ class TestObjectsImportErrorHandling:
|
|||
mock_msg.json.return_value = sample_objects_message
|
||||
|
||||
with pytest.raises(Exception, match="Publisher error"):
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() handles malformed JSON appropriately."""
|
||||
mock_publisher_class.return_value = Mock()
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -543,4 +543,4 @@ class TestObjectsImportErrorHandling:
|
|||
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
|
@ -76,7 +76,7 @@ def cities_schema():
|
|||
def validator():
|
||||
"""Create a mock processor with just the validation method"""
|
||||
from unittest.mock import MagicMock
|
||||
from trustgraph.extract.kg.objects.processor import Processor
|
||||
from trustgraph.extract.kg.rows.processor import Processor
|
||||
|
||||
# Create a mock processor
|
||||
mock_processor = MagicMock()
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ class TestFlowClient:
|
|||
expected_methods = [
|
||||
'text_completion', 'agent', 'graph_rag', 'document_rag',
|
||||
'graph_embeddings_query', 'embeddings', 'prompt',
|
||||
'triples_query', 'objects_query'
|
||||
'triples_query', 'rows_query'
|
||||
]
|
||||
|
||||
for method in expected_methods:
|
||||
|
|
@ -216,7 +216,7 @@ class TestSocketClient:
|
|||
expected_methods = [
|
||||
'agent', 'text_completion', 'graph_rag', 'document_rag',
|
||||
'prompt', 'graph_embeddings_query', 'embeddings',
|
||||
'triples_query', 'objects_query', 'mcp_tool'
|
||||
'triples_query', 'rows_query', 'mcp_tool'
|
||||
]
|
||||
|
||||
for method in expected_methods:
|
||||
|
|
@ -243,7 +243,7 @@ class TestBulkClient:
|
|||
'import_graph_embeddings',
|
||||
'import_document_embeddings',
|
||||
'import_entity_contexts',
|
||||
'import_objects'
|
||||
'import_rows'
|
||||
]
|
||||
|
||||
for method in import_methods:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
"""
|
||||
Unit tests for Cassandra Objects GraphQL Query Processor
|
||||
Unit tests for Cassandra Rows GraphQL Query Processor (Unified Table Implementation)
|
||||
|
||||
Tests the business logic of the GraphQL query processor including:
|
||||
- GraphQL schema generation from RowSchema
|
||||
- Query execution and validation
|
||||
- CQL translation logic
|
||||
- Schema configuration handling
|
||||
- Query execution using unified rows table
|
||||
- Name sanitization
|
||||
- GraphQL query execution
|
||||
- Message processing logic
|
||||
"""
|
||||
|
||||
|
|
@ -12,119 +13,91 @@ import pytest
|
|||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
import strawberry
|
||||
from strawberry import Schema
|
||||
|
||||
from trustgraph.query.objects.cassandra.service import Processor
|
||||
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
||||
from trustgraph.query.rows.cassandra.service import Processor
|
||||
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||
from trustgraph.schema import RowSchema, Field
|
||||
|
||||
|
||||
class TestObjectsGraphQLQueryLogic:
|
||||
"""Test business logic without external dependencies"""
|
||||
|
||||
def test_get_python_type_mapping(self):
|
||||
"""Test schema field type conversion to Python types"""
|
||||
processor = MagicMock()
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
|
||||
# Basic type mappings
|
||||
assert processor.get_python_type("string") == str
|
||||
assert processor.get_python_type("integer") == int
|
||||
assert processor.get_python_type("float") == float
|
||||
assert processor.get_python_type("boolean") == bool
|
||||
assert processor.get_python_type("timestamp") == str
|
||||
assert processor.get_python_type("date") == str
|
||||
assert processor.get_python_type("time") == str
|
||||
assert processor.get_python_type("uuid") == str
|
||||
|
||||
# Unknown type defaults to str
|
||||
assert processor.get_python_type("unknown_type") == str
|
||||
|
||||
def test_create_graphql_type_basic_fields(self):
|
||||
"""Test GraphQL type creation for basic field types"""
|
||||
processor = MagicMock()
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Create test schema
|
||||
schema = RowSchema(
|
||||
name="test_table",
|
||||
description="Test table",
|
||||
fields=[
|
||||
Field(
|
||||
name="id",
|
||||
type="string",
|
||||
primary=True,
|
||||
required=True,
|
||||
description="Primary key"
|
||||
),
|
||||
Field(
|
||||
name="name",
|
||||
type="string",
|
||||
required=True,
|
||||
description="Name field"
|
||||
),
|
||||
Field(
|
||||
name="age",
|
||||
type="integer",
|
||||
required=False,
|
||||
description="Optional age"
|
||||
),
|
||||
Field(
|
||||
name="active",
|
||||
type="boolean",
|
||||
required=False,
|
||||
description="Status flag"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Create GraphQL type
|
||||
graphql_type = processor.create_graphql_type("test_table", schema)
|
||||
|
||||
# Verify type was created
|
||||
assert graphql_type is not None
|
||||
assert hasattr(graphql_type, '__name__')
|
||||
assert "TestTable" in graphql_type.__name__ or "test_table" in graphql_type.__name__.lower()
|
||||
class TestRowsGraphQLQueryLogic:
|
||||
"""Test business logic for unified table query implementation"""
|
||||
|
||||
def test_sanitize_name_cassandra_compatibility(self):
|
||||
"""Test name sanitization for Cassandra field names"""
|
||||
processor = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
||||
# Test field name sanitization (matches storage processor)
|
||||
|
||||
# Test field name sanitization (uses r_ prefix like storage processor)
|
||||
assert processor.sanitize_name("simple_field") == "simple_field"
|
||||
assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes"
|
||||
assert processor.sanitize_name("field.with.dots") == "field_with_dots"
|
||||
assert processor.sanitize_name("123_field") == "o_123_field"
|
||||
assert processor.sanitize_name("123_field") == "r_123_field"
|
||||
assert processor.sanitize_name("field with spaces") == "field_with_spaces"
|
||||
assert processor.sanitize_name("special!@#chars") == "special___chars"
|
||||
assert processor.sanitize_name("UPPERCASE") == "uppercase"
|
||||
assert processor.sanitize_name("CamelCase") == "camelcase"
|
||||
|
||||
def test_sanitize_table_name(self):
|
||||
"""Test table name sanitization (always gets o_ prefix)"""
|
||||
def test_get_index_names(self):
|
||||
"""Test extraction of index names from schema"""
|
||||
processor = MagicMock()
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
|
||||
# Table names always get o_ prefix
|
||||
assert processor.sanitize_table("simple_table") == "o_simple_table"
|
||||
assert processor.sanitize_table("Table-Name") == "o_table_name"
|
||||
assert processor.sanitize_table("123table") == "o_123table"
|
||||
assert processor.sanitize_table("") == "o_"
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="name", type="string"), # Not indexed
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
|
||||
index_names = processor.get_index_names(schema)
|
||||
|
||||
assert "id" in index_names
|
||||
assert "category" in index_names
|
||||
assert "status" in index_names
|
||||
assert "name" not in index_names
|
||||
assert len(index_names) == 3
|
||||
|
||||
def test_find_matching_index_exact_match(self):
|
||||
"""Test finding matching index for exact match query"""
|
||||
processor = MagicMock()
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="name", type="string") # Not indexed
|
||||
]
|
||||
)
|
||||
|
||||
# Filter on indexed field should return match
|
||||
filters = {"category": "electronics"}
|
||||
result = processor.find_matching_index(schema, filters)
|
||||
assert result is not None
|
||||
assert result[0] == "category"
|
||||
assert result[1] == ["electronics"]
|
||||
|
||||
# Filter on non-indexed field should return None
|
||||
filters = {"name": "test"}
|
||||
result = processor.find_matching_index(schema, filters)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configuration"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.graphql_types = {}
|
||||
processor.graphql_schema = None
|
||||
processor.config_key = "schema" # Set the config key
|
||||
processor.generate_graphql_schema = AsyncMock()
|
||||
processor.config_key = "schema"
|
||||
processor.schema_builder = MagicMock()
|
||||
processor.schema_builder.clear = MagicMock()
|
||||
processor.schema_builder.add_schema = MagicMock()
|
||||
processor.schema_builder.build = MagicMock(return_value=MagicMock())
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Create test config
|
||||
schema_config = {
|
||||
"schema": {
|
||||
|
|
@ -154,96 +127,29 @@ class TestObjectsGraphQLQueryLogic:
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Process config
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer" in processor.schemas
|
||||
schema = processor.schemas["customer"]
|
||||
assert schema.name == "customer"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
|
||||
# Verify fields
|
||||
id_field = next(f for f in schema.fields if f.name == "id")
|
||||
assert id_field.primary is True
|
||||
# The field should have been created correctly from JSON
|
||||
# Let's test what we can verify - that the field has the right attributes
|
||||
assert hasattr(id_field, 'required') # Has the required attribute
|
||||
assert hasattr(id_field, 'primary') # Has the primary attribute
|
||||
|
||||
|
||||
email_field = next(f for f in schema.fields if f.name == "email")
|
||||
assert email_field.indexed is True
|
||||
|
||||
|
||||
status_field = next(f for f in schema.fields if f.name == "status")
|
||||
assert status_field.enum_values == ["active", "inactive"]
|
||||
|
||||
# Verify GraphQL schema regeneration was called
|
||||
processor.generate_graphql_schema.assert_called_once()
|
||||
|
||||
def test_cql_query_building_basic(self):
|
||||
"""Test basic CQL query construction"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.parse_filter_key = Processor.parse_filter_key.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to capture the query
|
||||
mock_result = []
|
||||
processor.session.execute.return_value = mock_result
|
||||
|
||||
# Create test schema
|
||||
schema = RowSchema(
|
||||
name="test_table",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string", indexed=True),
|
||||
Field(name="status", type="string")
|
||||
]
|
||||
)
|
||||
|
||||
# Test query building
|
||||
asyncio = pytest.importorskip("asyncio")
|
||||
|
||||
async def run_test():
|
||||
await processor.query_cassandra(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
schema_name="test_table",
|
||||
row_schema=schema,
|
||||
filters={"name": "John", "invalid_filter": "ignored"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Run the async test
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(run_test())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Verify Cassandra connection and query execution
|
||||
processor.connect_cassandra.assert_called_once()
|
||||
processor.session.execute.assert_called_once()
|
||||
|
||||
# Verify the query structure (can't easily test exact query without complex mocking)
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0] # First positional argument is the query
|
||||
params = call_args[0][1] # Second positional argument is parameters
|
||||
|
||||
# Basic query structure checks
|
||||
assert "SELECT * FROM test_user.o_test_table" in query
|
||||
assert "WHERE" in query
|
||||
assert "collection = %s" in query
|
||||
assert "LIMIT 10" in query
|
||||
|
||||
# Parameters should include collection and name filter
|
||||
assert "test_collection" in params
|
||||
assert "John" in params
|
||||
# Verify schema builder was called
|
||||
processor.schema_builder.add_schema.assert_called_once()
|
||||
processor.schema_builder.build.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graphql_context_handling(self):
|
||||
|
|
@ -251,13 +157,13 @@ class TestObjectsGraphQLQueryLogic:
|
|||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Mock schema execution
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
|
||||
mock_result.errors = None
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
|
|
@ -265,17 +171,17 @@ class TestObjectsGraphQLQueryLogic:
|
|||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
# Verify schema.execute was called with correct context
|
||||
processor.graphql_schema.execute.assert_called_once()
|
||||
call_args = processor.graphql_schema.execute.call_args
|
||||
|
||||
|
||||
# Verify context was passed
|
||||
context = call_args[1]['context_value'] # keyword argument
|
||||
context = call_args[1]['context_value']
|
||||
assert context["processor"] == processor
|
||||
assert context["user"] == "test_user"
|
||||
assert context["collection"] == "test_collection"
|
||||
|
||||
|
||||
# Verify result structure
|
||||
assert "data" in result
|
||||
assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]}
|
||||
|
|
@ -286,104 +192,79 @@ class TestObjectsGraphQLQueryLogic:
|
|||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Create a simple object to simulate GraphQL error instead of MagicMock
|
||||
|
||||
# Create a simple object to simulate GraphQL error
|
||||
class MockError:
|
||||
def __init__(self, message, path, extensions):
|
||||
self.message = message
|
||||
self.path = path
|
||||
self.extensions = extensions
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
mock_error = MockError(
|
||||
message="Field 'invalid_field' doesn't exist",
|
||||
path=["customers", "0", "invalid_field"],
|
||||
extensions={"code": "FIELD_NOT_FOUND"}
|
||||
)
|
||||
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = None
|
||||
mock_result.errors = [mock_error]
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
query='{ customers { invalid_field } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
# Verify error handling
|
||||
assert "errors" in result
|
||||
assert len(result["errors"]) == 1
|
||||
|
||||
|
||||
error = result["errors"][0]
|
||||
assert error["message"] == "Field 'invalid_field' doesn't exist"
|
||||
assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path
|
||||
assert error["path"] == ["customers", "0", "invalid_field"]
|
||||
assert error["extensions"] == {"code": "FIELD_NOT_FOUND"}
|
||||
|
||||
def test_schema_generation_basic_structure(self):
|
||||
"""Test basic GraphQL schema generation structure"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"customer": RowSchema(
|
||||
name="customer",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.graphql_types = {}
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Test individual type creation (avoiding the full schema generation which has annotation issues)
|
||||
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
|
||||
processor.graphql_types["customer"] = graphql_type
|
||||
|
||||
# Verify type was created
|
||||
assert len(processor.graphql_types) == 1
|
||||
assert "customer" in processor.graphql_types
|
||||
assert processor.graphql_types["customer"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_processing_success(self):
|
||||
"""Test successful message processing flow"""
|
||||
processor = MagicMock()
|
||||
processor.execute_graphql_query = AsyncMock()
|
||||
processor.on_message = Processor.on_message.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Mock successful query result
|
||||
processor.execute_graphql_query.return_value = {
|
||||
"data": {"customers": [{"id": "1", "name": "John"}]},
|
||||
"errors": [],
|
||||
"extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String())
|
||||
"extensions": {}
|
||||
}
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = ObjectsQueryRequest(
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
collection="test_collection",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None
|
||||
)
|
||||
mock_msg.value.return_value = mock_request
|
||||
mock_msg.properties.return_value = {"id": "test-123"}
|
||||
|
||||
|
||||
# Mock flow
|
||||
mock_flow = MagicMock()
|
||||
mock_response_flow = AsyncMock()
|
||||
mock_flow.return_value = mock_response_flow
|
||||
|
||||
|
||||
# Process message
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
|
||||
# Verify query was executed
|
||||
processor.execute_graphql_query.assert_called_once_with(
|
||||
query='{ customers { id name } }',
|
||||
|
|
@ -392,13 +273,13 @@ class TestObjectsGraphQLQueryLogic:
|
|||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
# Verify response was sent
|
||||
mock_response_flow.send.assert_called_once()
|
||||
response_call = mock_response_flow.send.call_args[0][0]
|
||||
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response_call, ObjectsQueryResponse)
|
||||
assert isinstance(response_call, RowsQueryResponse)
|
||||
assert response_call.error is None
|
||||
assert '"customers"' in response_call.data # JSON encoded
|
||||
assert len(response_call.errors) == 0
|
||||
|
|
@ -409,13 +290,13 @@ class TestObjectsGraphQLQueryLogic:
|
|||
processor = MagicMock()
|
||||
processor.execute_graphql_query = AsyncMock()
|
||||
processor.on_message = Processor.on_message.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Mock query execution error
|
||||
processor.execute_graphql_query.side_effect = RuntimeError("No schema available")
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = ObjectsQueryRequest(
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ invalid_query }',
|
||||
|
|
@ -424,67 +305,225 @@ class TestObjectsGraphQLQueryLogic:
|
|||
)
|
||||
mock_msg.value.return_value = mock_request
|
||||
mock_msg.properties.return_value = {"id": "test-456"}
|
||||
|
||||
|
||||
# Mock flow
|
||||
mock_flow = MagicMock()
|
||||
mock_response_flow = AsyncMock()
|
||||
mock_flow.return_value = mock_response_flow
|
||||
|
||||
|
||||
# Process message
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
|
||||
# Verify error response was sent
|
||||
mock_response_flow.send.assert_called_once()
|
||||
response_call = mock_response_flow.send.call_args[0][0]
|
||||
|
||||
|
||||
# Verify error response structure
|
||||
assert isinstance(response_call, ObjectsQueryResponse)
|
||||
assert isinstance(response_call, RowsQueryResponse)
|
||||
assert response_call.error is not None
|
||||
assert response_call.error.type == "objects-query-error"
|
||||
assert response_call.error.type == "rows-query-error"
|
||||
assert "No schema available" in response_call.error.message
|
||||
assert response_call.data is None
|
||||
|
||||
|
||||
class TestCQLQueryGeneration:
|
||||
"""Test CQL query generation logic in isolation"""
|
||||
|
||||
def test_partition_key_inclusion(self):
|
||||
"""Test that collection is always included in queries"""
|
||||
class TestUnifiedTableQueries:
|
||||
"""Test queries against the unified rows table"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_index_match(self):
|
||||
"""Test query execution with matching index"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
|
||||
# Mock the query building (simplified version)
|
||||
keyspace = processor.sanitize_name("test_user")
|
||||
table = processor.sanitize_table("test_table")
|
||||
|
||||
query = f"SELECT * FROM {keyspace}.{table}"
|
||||
where_clauses = ["collection = %s"]
|
||||
|
||||
assert "collection = %s" in where_clauses
|
||||
assert keyspace == "test_user"
|
||||
assert table == "o_test_table"
|
||||
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to return test data
|
||||
mock_row = MagicMock()
|
||||
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
|
||||
processor.session.execute.return_value = [mock_row]
|
||||
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
|
||||
# Query with filter on indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
filters={"category": "electronics"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Verify Cassandra was connected and queried
|
||||
processor.connect_cassandra.assert_called_once()
|
||||
processor.session.execute.assert_called_once()
|
||||
|
||||
# Verify query structure - should query unified rows table
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
params = call_args[0][1]
|
||||
|
||||
assert "SELECT data, source FROM test_user.rows" in query
|
||||
assert "collection = %s" in query
|
||||
assert "schema_name = %s" in query
|
||||
assert "index_name = %s" in query
|
||||
assert "index_value = %s" in query
|
||||
|
||||
assert params[0] == "test_collection"
|
||||
assert params[1] == "products"
|
||||
assert params[2] == "category"
|
||||
assert params[3] == ["electronics"]
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "123"
|
||||
assert results[0]["category"] == "electronics"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_without_index_match(self):
|
||||
"""Test query execution without matching index (scan mode)"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to return test data
|
||||
mock_row1 = MagicMock()
|
||||
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
|
||||
mock_row2 = MagicMock()
|
||||
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
|
||||
processor.session.execute.return_value = [mock_row1, mock_row2]
|
||||
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string"), # Not indexed
|
||||
Field(name="price", type="string") # Not indexed
|
||||
]
|
||||
)
|
||||
|
||||
# Query with filter on non-indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
filters={"name": "Product A"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Query should use ALLOW FILTERING for scan
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
|
||||
assert "ALLOW FILTERING" in query
|
||||
|
||||
# Should post-filter results
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "Product A"
|
||||
|
||||
|
||||
class TestFilterMatching:
|
||||
"""Test filter matching logic"""
|
||||
|
||||
def test_matches_filters_exact_match(self):
|
||||
"""Test exact match filter"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
|
||||
|
||||
row = {"status": "active", "name": "test"}
|
||||
assert processor._matches_filters(row, {"status": "active"}, schema) is True
|
||||
assert processor._matches_filters(row, {"status": "inactive"}, schema) is False
|
||||
|
||||
def test_matches_filters_comparison_operators(self):
|
||||
"""Test comparison operators in filters"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="price", type="float")])
|
||||
|
||||
row = {"price": "100.0"}
|
||||
|
||||
# Greater than
|
||||
assert processor._matches_filters(row, {"price_gt": 50}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_gt": 150}, schema) is False
|
||||
|
||||
# Less than
|
||||
assert processor._matches_filters(row, {"price_lt": 150}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_lt": 50}, schema) is False
|
||||
|
||||
# Greater than or equal
|
||||
assert processor._matches_filters(row, {"price_gte": 100}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_gte": 101}, schema) is False
|
||||
|
||||
# Less than or equal
|
||||
assert processor._matches_filters(row, {"price_lte": 100}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_lte": 99}, schema) is False
|
||||
|
||||
def test_matches_filters_contains(self):
|
||||
"""Test contains filter"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="description", type="string")])
|
||||
|
||||
row = {"description": "A great product for everyone"}
|
||||
|
||||
assert processor._matches_filters(row, {"description_contains": "great"}, schema) is True
|
||||
assert processor._matches_filters(row, {"description_contains": "terrible"}, schema) is False
|
||||
|
||||
def test_matches_filters_in_list(self):
|
||||
"""Test in-list filter"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
|
||||
|
||||
row = {"status": "active"}
|
||||
|
||||
assert processor._matches_filters(row, {"status_in": ["active", "pending"]}, schema) is True
|
||||
assert processor._matches_filters(row, {"status_in": ["inactive", "deleted"]}, schema) is False
|
||||
|
||||
|
||||
class TestIndexedFieldFiltering:
|
||||
"""Test that only indexed or primary key fields can be directly filtered"""
|
||||
|
||||
def test_indexed_field_filtering(self):
|
||||
"""Test that only indexed or primary key fields can be filtered"""
|
||||
# Create schema with mixed field types
|
||||
schema = RowSchema(
|
||||
name="test",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="indexed_field", type="string", indexed=True),
|
||||
Field(name="indexed_field", type="string", indexed=True),
|
||||
Field(name="normal_field", type="string", indexed=False),
|
||||
Field(name="another_field", type="string")
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
filters = {
|
||||
"id": "test123", # Primary key - should be included
|
||||
"indexed_field": "value", # Indexed - should be included
|
||||
"normal_field": "ignored", # Not indexed - should be ignored
|
||||
"another_field": "also_ignored" # Not indexed - should be ignored
|
||||
}
|
||||
|
||||
|
||||
# Simulate the filtering logic from the processor
|
||||
valid_filters = []
|
||||
for field_name, value in filters.items():
|
||||
|
|
@ -492,7 +531,7 @@ class TestCQLQueryGeneration:
|
|||
schema_field = next((f for f in schema.fields if f.name == field_name), None)
|
||||
if schema_field and (schema_field.indexed or schema_field.primary):
|
||||
valid_filters.append((field_name, value))
|
||||
|
||||
|
||||
# Only id and indexed_field should be included
|
||||
assert len(valid_filters) == 2
|
||||
field_names = [f[0] for f in valid_filters]
|
||||
|
|
@ -500,52 +539,3 @@ class TestCQLQueryGeneration:
|
|||
assert "indexed_field" in field_names
|
||||
assert "normal_field" not in field_names
|
||||
assert "another_field" not in field_names
|
||||
|
||||
|
||||
class TestGraphQLSchemaGeneration:
|
||||
"""Test GraphQL schema generation in detail"""
|
||||
|
||||
def test_field_type_annotations(self):
|
||||
"""Test that GraphQL types have correct field annotations"""
|
||||
processor = MagicMock()
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Create schema with various field types
|
||||
schema = RowSchema(
|
||||
name="test",
|
||||
fields=[
|
||||
Field(name="id", type="string", required=True, primary=True),
|
||||
Field(name="count", type="integer", required=True),
|
||||
Field(name="price", type="float", required=False),
|
||||
Field(name="active", type="boolean", required=False),
|
||||
Field(name="optional_text", type="string", required=False)
|
||||
]
|
||||
)
|
||||
|
||||
# Create GraphQL type
|
||||
graphql_type = processor.create_graphql_type("test", schema)
|
||||
|
||||
# Verify type was created successfully
|
||||
assert graphql_type is not None
|
||||
|
||||
def test_basic_type_creation(self):
|
||||
"""Test that GraphQL types are created correctly"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"customer": RowSchema(
|
||||
name="customer",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
}
|
||||
processor.graphql_types = {}
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Create GraphQL type directly
|
||||
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
|
||||
processor.graphql_types["customer"] = graphql_type
|
||||
|
||||
# Verify customer type was created
|
||||
assert "customer" in processor.graphql_types
|
||||
assert processor.graphql_types["customer"] is not None
|
||||
|
|
@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
from trustgraph.schema import (
|
||||
StructuredQueryRequest, StructuredQueryResponse,
|
||||
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
|
||||
ObjectsQueryRequest, ObjectsQueryResponse,
|
||||
RowsQueryRequest, RowsQueryResponse,
|
||||
Error, GraphQLError
|
||||
)
|
||||
from trustgraph.retrieval.structured_query.service import Processor
|
||||
|
|
@ -68,7 +68,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
|
||||
# Mock objects query service response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
|
||||
errors=None,
|
||||
|
|
@ -86,7 +86,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -108,7 +108,7 @@ class TestStructuredQueryProcessor:
|
|||
# Verify objects query service was called correctly
|
||||
mock_objects_client.request.assert_called_once()
|
||||
objects_call_args = mock_objects_client.request.call_args[0][0]
|
||||
assert isinstance(objects_call_args, ObjectsQueryRequest)
|
||||
assert isinstance(objects_call_args, RowsQueryRequest)
|
||||
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
|
||||
assert objects_call_args.variables == {"state": "NY"}
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
|
|
@ -224,7 +224,7 @@ class TestStructuredQueryProcessor:
|
|||
assert response.error is not None
|
||||
assert "empty GraphQL query" in response.error.message
|
||||
|
||||
async def test_objects_query_service_error(self, processor):
|
||||
async def test_rows_query_service_error(self, processor):
|
||||
"""Test handling of objects query service errors"""
|
||||
# Arrange
|
||||
request = StructuredQueryRequest(
|
||||
|
|
@ -250,7 +250,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
|
||||
# Mock objects query service error
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=Error(type="graphql-execution-error", message="Table 'customers' not found"),
|
||||
data=None,
|
||||
errors=None,
|
||||
|
|
@ -267,7 +267,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -284,7 +284,7 @@ class TestStructuredQueryProcessor:
|
|||
response = response_call[0][0]
|
||||
|
||||
assert response.error is not None
|
||||
assert "Objects query service error" in response.error.message
|
||||
assert "Rows query service error" in response.error.message
|
||||
assert "Table 'customers' not found" in response.error.message
|
||||
|
||||
async def test_graphql_errors_handling(self, processor):
|
||||
|
|
@ -321,7 +321,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
]
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=None,
|
||||
errors=graphql_errors,
|
||||
|
|
@ -338,7 +338,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -400,7 +400,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
|
||||
# Mock objects response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": [{"id": "1", "name": "Alice", "orders": [{"id": "100", "total": 150.0}]}]}',
|
||||
errors=None
|
||||
|
|
@ -416,7 +416,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -464,7 +464,7 @@ class TestStructuredQueryProcessor:
|
|||
confidence=0.9
|
||||
)
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=None, # Null data
|
||||
errors=None,
|
||||
|
|
@ -481,7 +481,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
|
||||
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
|
||||
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
|
||||
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
|
||||
from trustgraph.storage.knowledge.store import Processor as KgStore
|
||||
|
||||
|
|
@ -81,10 +81,10 @@ class TestTriplesWriterConfiguration:
|
|||
assert processor.cassandra_password is None
|
||||
|
||||
|
||||
class TestObjectsWriterConfiguration:
|
||||
class TestRowsWriterConfiguration:
|
||||
"""Test Cassandra configuration in objects writer processor."""
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
def test_environment_variable_configuration(self, mock_cluster):
|
||||
"""Test processor picks up configuration from environment variables."""
|
||||
env_vars = {
|
||||
|
|
@ -97,13 +97,13 @@ class TestObjectsWriterConfiguration:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
|
||||
assert processor.cassandra_host == ['obj-env-host1', 'obj-env-host2']
|
||||
assert processor.cassandra_username == 'obj-env-user'
|
||||
assert processor.cassandra_password == 'obj-env-pass'
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
def test_cassandra_connection_with_hosts_list(self, mock_cluster):
|
||||
"""Test that Cassandra connection uses hosts list correctly."""
|
||||
env_vars = {
|
||||
|
|
@ -118,7 +118,7 @@ class TestObjectsWriterConfiguration:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Verify cluster was called with hosts list
|
||||
|
|
@ -129,8 +129,8 @@ class TestObjectsWriterConfiguration:
|
|||
assert 'contact_points' in call_args.kwargs
|
||||
assert call_args.kwargs['contact_points'] == ['conn-host1', 'conn-host2', 'conn-host3']
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_authentication_configuration(self, mock_auth_provider, mock_cluster):
|
||||
"""Test authentication is configured when credentials are provided."""
|
||||
env_vars = {
|
||||
|
|
@ -145,7 +145,7 @@ class TestObjectsWriterConfiguration:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Verify auth provider was created with correct credentials
|
||||
|
|
@ -302,10 +302,10 @@ class TestCommandLineArgumentHandling:
|
|||
def test_objects_writer_add_args(self):
|
||||
"""Test that objects writer adds standard Cassandra arguments."""
|
||||
import argparse
|
||||
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
|
||||
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
ObjectsWriter.add_args(parser)
|
||||
RowsWriter.add_args(parser)
|
||||
|
||||
# Parse empty args to check that arguments exist
|
||||
args = parser.parse_args([])
|
||||
|
|
|
|||
|
|
@ -1,533 +0,0 @@
|
|||
"""
|
||||
Unit tests for Cassandra Object Storage Processor
|
||||
|
||||
Tests the business logic of the object storage processor including:
|
||||
- Schema configuration handling
|
||||
- Type conversions
|
||||
- Name sanitization
|
||||
- Table structure generation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
from trustgraph.storage.objects.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
class TestObjectsCassandraStorageLogic:
|
||||
"""Test business logic without FlowProcessor dependencies"""
|
||||
|
||||
def test_sanitize_name(self):
|
||||
"""Test name sanitization for Cassandra compatibility"""
|
||||
processor = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
||||
# Test various name patterns (back to original logic)
|
||||
assert processor.sanitize_name("simple_name") == "simple_name"
|
||||
assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
|
||||
assert processor.sanitize_name("name.with.dots") == "name_with_dots"
|
||||
assert processor.sanitize_name("123_starts_with_number") == "o_123_starts_with_number"
|
||||
assert processor.sanitize_name("name with spaces") == "name_with_spaces"
|
||||
assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
|
||||
|
||||
def test_get_cassandra_type(self):
|
||||
"""Test field type conversion to Cassandra types"""
|
||||
processor = MagicMock()
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
|
||||
# Basic type mappings
|
||||
assert processor.get_cassandra_type("string") == "text"
|
||||
assert processor.get_cassandra_type("boolean") == "boolean"
|
||||
assert processor.get_cassandra_type("timestamp") == "timestamp"
|
||||
assert processor.get_cassandra_type("uuid") == "uuid"
|
||||
|
||||
# Integer types with size hints
|
||||
assert processor.get_cassandra_type("integer", size=2) == "int"
|
||||
assert processor.get_cassandra_type("integer", size=8) == "bigint"
|
||||
|
||||
# Float types with size hints
|
||||
assert processor.get_cassandra_type("float", size=2) == "float"
|
||||
assert processor.get_cassandra_type("float", size=8) == "double"
|
||||
|
||||
# Unknown type defaults to text
|
||||
assert processor.get_cassandra_type("unknown_type") == "text"
|
||||
|
||||
def test_convert_value(self):
|
||||
"""Test value conversion for different field types"""
|
||||
processor = MagicMock()
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
|
||||
# Integer conversions
|
||||
assert processor.convert_value("123", "integer") == 123
|
||||
assert processor.convert_value(123.5, "integer") == 123
|
||||
assert processor.convert_value(None, "integer") is None
|
||||
|
||||
# Float conversions
|
||||
assert processor.convert_value("123.45", "float") == 123.45
|
||||
assert processor.convert_value(123, "float") == 123.0
|
||||
|
||||
# Boolean conversions
|
||||
assert processor.convert_value("true", "boolean") is True
|
||||
assert processor.convert_value("false", "boolean") is False
|
||||
assert processor.convert_value("1", "boolean") is True
|
||||
assert processor.convert_value("0", "boolean") is False
|
||||
assert processor.convert_value("yes", "boolean") is True
|
||||
assert processor.convert_value("no", "boolean") is False
|
||||
|
||||
# String conversions
|
||||
assert processor.convert_value(123, "string") == "123"
|
||||
assert processor.convert_value(True, "string") == "True"
|
||||
|
||||
def test_table_creation_cql_generation(self):
|
||||
"""Test CQL generation for table creation"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.known_tables = {}
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
def mock_ensure_keyspace(keyspace):
|
||||
processor.known_keyspaces.add(keyspace)
|
||||
processor.known_tables[keyspace] = set()
|
||||
processor.ensure_keyspace = mock_ensure_keyspace
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
|
||||
# Create test schema
|
||||
schema = RowSchema(
|
||||
name="customer_records",
|
||||
description="Test customer schema",
|
||||
fields=[
|
||||
Field(
|
||||
name="customer_id",
|
||||
type="string",
|
||||
size=50,
|
||||
primary=True,
|
||||
required=True,
|
||||
indexed=False
|
||||
),
|
||||
Field(
|
||||
name="email",
|
||||
type="string",
|
||||
size=100,
|
||||
required=True,
|
||||
indexed=True
|
||||
),
|
||||
Field(
|
||||
name="age",
|
||||
type="integer",
|
||||
size=4,
|
||||
required=False,
|
||||
indexed=False
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Call ensure_table
|
||||
processor.ensure_table("test_user", "customer_records", schema)
|
||||
|
||||
# Verify keyspace was ensured (check that it was added to known_keyspaces)
|
||||
assert "test_user" in processor.known_keyspaces
|
||||
|
||||
# Check the CQL that was executed (first call should be table creation)
|
||||
all_calls = processor.session.execute.call_args_list
|
||||
table_creation_cql = all_calls[0][0][0] # First call
|
||||
|
||||
# Verify table structure (keyspace uses sanitize_name, table uses sanitize_table)
|
||||
assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in table_creation_cql
|
||||
assert "collection text" in table_creation_cql
|
||||
assert "customer_id text" in table_creation_cql
|
||||
assert "email text" in table_creation_cql
|
||||
assert "age int" in table_creation_cql
|
||||
assert "PRIMARY KEY ((collection, customer_id))" in table_creation_cql
|
||||
|
||||
def test_table_creation_without_primary_key(self):
|
||||
"""Test table creation when no primary key is defined"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.known_tables = {}
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
def mock_ensure_keyspace(keyspace):
|
||||
processor.known_keyspaces.add(keyspace)
|
||||
processor.known_tables[keyspace] = set()
|
||||
processor.ensure_keyspace = mock_ensure_keyspace
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
|
||||
# Create schema without primary key
|
||||
schema = RowSchema(
|
||||
name="events",
|
||||
description="Event log",
|
||||
fields=[
|
||||
Field(name="event_type", type="string", size=50),
|
||||
Field(name="timestamp", type="timestamp", size=0)
|
||||
]
|
||||
)
|
||||
|
||||
# Call ensure_table
|
||||
processor.ensure_table("test_user", "events", schema)
|
||||
|
||||
# Check the CQL includes synthetic_id (field names don't get o_ prefix)
|
||||
executed_cql = processor.session.execute.call_args[0][0]
|
||||
assert "synthetic_id uuid" in executed_cql
|
||||
assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configurations"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test configuration
|
||||
config = {
|
||||
"schema": {
|
||||
"customer_records": json.dumps({
|
||||
"name": "customer_records",
|
||||
"description": "Customer data",
|
||||
"fields": [
|
||||
{
|
||||
"name": "id",
|
||||
"type": "string",
|
||||
"primary_key": True,
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "balance",
|
||||
"type": "float",
|
||||
"size": 8
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
# Process configuration
|
||||
await processor.on_schema_config(config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer_records" in processor.schemas
|
||||
schema = processor.schemas["customer_records"]
|
||||
assert schema.name == "customer_records"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
# Check field properties
|
||||
id_field = schema.fields[0]
|
||||
assert id_field.name == "id"
|
||||
assert id_field.type == "string"
|
||||
assert id_field.primary is True
|
||||
# Note: Field.required always returns False due to Pulsar schema limitations
|
||||
# The actual required value is tracked during schema parsing
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_processing_logic(self):
|
||||
"""Test the logic for processing ExtractedObject"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="value", type="integer", size=4)
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.ensure_table = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
|
||||
# Create test object
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="test-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="test_schema",
|
||||
values=[{"id": "123", "value": "456"}],
|
||||
confidence=0.9,
|
||||
source_span="test source"
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
# Process object
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify table was ensured
|
||||
processor.ensure_table.assert_called_once_with("test_user", "test_schema", processor.schemas["test_schema"])
|
||||
|
||||
# Verify insert was executed (keyspace normal, table with o_ prefix)
|
||||
processor.session.execute.assert_called_once()
|
||||
insert_cql = processor.session.execute.call_args[0][0]
|
||||
values = processor.session.execute.call_args[0][1]
|
||||
|
||||
assert "INSERT INTO test_user.o_test_schema" in insert_cql
|
||||
assert "collection" in insert_cql
|
||||
assert values[0] == "test_collection" # collection value
|
||||
assert values[1] == "123" # id value (from values[0])
|
||||
assert values[2] == 456 # converted integer value (from values[0])
|
||||
|
||||
def test_secondary_index_creation(self):
|
||||
"""Test that secondary indexes are created for indexed fields"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
def mock_ensure_keyspace(keyspace):
|
||||
processor.known_keyspaces.add(keyspace)
|
||||
if keyspace not in processor.known_tables:
|
||||
processor.known_tables[keyspace] = set()
|
||||
processor.ensure_keyspace = mock_ensure_keyspace
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
|
||||
# Create schema with indexed field
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
description="Product catalog",
|
||||
fields=[
|
||||
Field(name="product_id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=30, indexed=True),
|
||||
Field(name="price", type="float", size=8, indexed=True)
|
||||
]
|
||||
)
|
||||
|
||||
# Call ensure_table
|
||||
processor.ensure_table("test_user", "products", schema)
|
||||
|
||||
# Should have 3 calls: create table + 2 indexes
|
||||
assert processor.session.execute.call_count == 3
|
||||
|
||||
# Check index creation calls (table has o_ prefix, fields don't)
|
||||
calls = processor.session.execute.call_args_list
|
||||
index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]]
|
||||
assert len(index_calls) == 2
|
||||
assert any("o_products_category_idx" in call for call in index_calls)
|
||||
assert any("o_products_price_idx" in call for call in index_calls)
|
||||
|
||||
|
||||
class TestObjectsCassandraStorageBatchLogic:
|
||||
"""Test batch processing logic in Cassandra storage"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_object_processing_logic(self):
|
||||
"""Test processing of batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"batch_schema": RowSchema(
|
||||
name="batch_schema",
|
||||
description="Test batch schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="name", type="string", size=100),
|
||||
Field(name="value", type="integer", size=4)
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.ensure_table = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
# Create batch object with multiple values
|
||||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="batch_schema",
|
||||
values=[
|
||||
{"id": "001", "name": "First", "value": "100"},
|
||||
{"id": "002", "name": "Second", "value": "200"},
|
||||
{"id": "003", "name": "Third", "value": "300"}
|
||||
],
|
||||
confidence=0.95,
|
||||
source_span="batch source"
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
# Process batch object
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify table was ensured once
|
||||
processor.ensure_table.assert_called_once_with("test_user", "batch_schema", processor.schemas["batch_schema"])
|
||||
|
||||
# Verify 3 separate insert calls (one per batch item)
|
||||
assert processor.session.execute.call_count == 3
|
||||
|
||||
# Check each insert call
|
||||
calls = processor.session.execute.call_args_list
|
||||
for i, call in enumerate(calls):
|
||||
insert_cql = call[0][0]
|
||||
values = call[0][1]
|
||||
|
||||
assert "INSERT INTO test_user.o_batch_schema" in insert_cql
|
||||
assert "collection" in insert_cql
|
||||
|
||||
# Check values for each batch item
|
||||
assert values[0] == "batch_collection" # collection
|
||||
assert values[1] == f"00{i+1}" # id from batch item i
|
||||
assert values[2] == f"First" if i == 0 else f"Second" if i == 1 else f"Third" # name
|
||||
assert values[3] == (i+1) * 100 # converted integer value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_batch_processing_logic(self):
|
||||
"""Test processing of empty batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"empty_schema": RowSchema(
|
||||
name="empty_schema",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
}
|
||||
processor.ensure_table = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
|
||||
# Create empty batch object
|
||||
empty_batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="empty-001",
|
||||
user="test_user",
|
||||
collection="empty_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="empty_schema",
|
||||
values=[], # Empty batch
|
||||
confidence=1.0,
|
||||
source_span="empty source"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = empty_batch_obj
|
||||
|
||||
# Process empty batch object
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify table was ensured
|
||||
processor.ensure_table.assert_called_once()
|
||||
|
||||
# Verify no insert calls for empty batch
|
||||
processor.session.execute.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_item_batch_processing_logic(self):
|
||||
"""Test processing of single-item batch (backward compatibility)"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"single_schema": RowSchema(
|
||||
name="single_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="data", type="string", size=100)
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.ensure_table = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
|
||||
# Create single-item batch object (backward compatibility case)
|
||||
single_batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="single-001",
|
||||
user="test_user",
|
||||
collection="single_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="single_schema",
|
||||
values=[{"id": "single-1", "data": "single data"}], # Array with one item
|
||||
confidence=0.8,
|
||||
source_span="single source"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = single_batch_obj
|
||||
|
||||
# Process single-item batch object
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify table was ensured
|
||||
processor.ensure_table.assert_called_once()
|
||||
|
||||
# Verify exactly one insert call
|
||||
processor.session.execute.assert_called_once()
|
||||
|
||||
insert_cql = processor.session.execute.call_args[0][0]
|
||||
values = processor.session.execute.call_args[0][1]
|
||||
|
||||
assert "INSERT INTO test_user.o_single_schema" in insert_cql
|
||||
assert values[0] == "single_collection" # collection
|
||||
assert values[1] == "single-1" # id value
|
||||
assert values[2] == "single data" # data value
|
||||
|
||||
def test_batch_value_conversion_logic(self):
|
||||
"""Test value conversion works correctly for batch items"""
|
||||
processor = MagicMock()
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
|
||||
# Test various conversion scenarios that would occur in batch processing
|
||||
test_cases = [
|
||||
# Integer conversions for batch items
|
||||
("123", "integer", 123),
|
||||
("456", "integer", 456),
|
||||
("789", "integer", 789),
|
||||
# Float conversions for batch items
|
||||
("12.5", "float", 12.5),
|
||||
("34.7", "float", 34.7),
|
||||
# Boolean conversions for batch items
|
||||
("true", "boolean", True),
|
||||
("false", "boolean", False),
|
||||
("1", "boolean", True),
|
||||
("0", "boolean", False),
|
||||
# String conversions for batch items
|
||||
(123, "string", "123"),
|
||||
(45.6, "string", "45.6"),
|
||||
]
|
||||
|
||||
for input_val, field_type, expected_output in test_cases:
|
||||
result = processor.convert_value(input_val, field_type)
|
||||
assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}"
|
||||
435
tests/unit/test_storage/test_row_embeddings_qdrant_storage.py
Normal file
435
tests/unit/test_storage/test_row_embeddings_qdrant_storage.py
Normal file
|
|
@ -0,0 +1,435 @@
|
|||
"""
|
||||
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
|
||||
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
|
||||
class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
    """Test Qdrant row embeddings storage functionality.

    The QdrantClient is patched at the module under test so no real
    Qdrant server is required; each test drives the Processor through
    its public surface (init, sanitize_name, collection management,
    on_embeddings, delete_collection*).
    """

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_processor_initialization_basic(self, mock_qdrant_client):
        """Test basic Qdrant processor initialization"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'store_uri': 'http://localhost:6333',
            'api_key': 'test-api-key',
            'taskgroup': AsyncMock(),
            'id': 'test-qdrant-processor'
        }

        processor = Processor(**config)

        mock_qdrant_client.assert_called_once_with(
            url='http://localhost:6333', api_key='test-api-key'
        )
        assert hasattr(processor, 'qdrant')
        assert processor.qdrant == mock_qdrant_instance

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
        """Test processor initialization with default values"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-qdrant-processor'
        }

        processor = Processor(**config)

        # Defaults: local Qdrant URL, no API key.
        mock_qdrant_client.assert_called_once_with(
            url='http://localhost:6333', api_key=None
        )

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_sanitize_name(self, mock_qdrant_client):
        """Test name sanitization for Qdrant collections"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_client.return_value = MagicMock()

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Test basic sanitization
        assert processor.sanitize_name("simple") == "simple"
        assert processor.sanitize_name("with-dash") == "with_dash"
        assert processor.sanitize_name("with.dot") == "with_dot"
        assert processor.sanitize_name("UPPERCASE") == "uppercase"

        # Test numeric prefix handling
        assert processor.sanitize_name("123start") == "r_123start"
        assert processor.sanitize_name("_underscore") == "r__underscore"

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_get_collection_name(self, mock_qdrant_client):
        """Test Qdrant collection name generation"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_client.return_value = MagicMock()

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        collection_name = processor.get_collection_name(
            user="test_user",
            collection="test_collection",
            schema_name="customer_data",
            dimension=384
        )

        # Pattern: rows_<user>_<collection>_<schema>_<dimension>
        assert collection_name == "rows_test_user_test_collection_customer_data_384"

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_ensure_collection_creates_new(self, mock_qdrant_client):
        """Test that ensure_collection creates a new collection when needed"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()
        mock_qdrant_instance.collection_exists.return_value = False
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        processor.ensure_collection("test_collection", 384)

        mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
        mock_qdrant_instance.create_collection.assert_called_once()

        # Verify the collection is cached
        assert "test_collection" in processor.created_collections

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
        """Test that ensure_collection skips creation when collection exists"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()
        mock_qdrant_instance.collection_exists.return_value = True
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        processor.ensure_collection("existing_collection", 384)

        mock_qdrant_instance.collection_exists.assert_called_once()
        mock_qdrant_instance.create_collection.assert_not_called()

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_ensure_collection_uses_cache(self, mock_qdrant_client):
        """Test that ensure_collection uses cache for previously created collections"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        processor.created_collections.add("cached_collection")

        processor.ensure_collection("cached_collection", 384)

        # Should not check or create - just return
        mock_qdrant_instance.collection_exists.assert_not_called()
        mock_qdrant_instance.create_collection.assert_not_called()

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
    async def test_on_embeddings_basic(self, mock_uuid, mock_qdrant_client):
        """Test processing basic row embeddings message"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor
        from trustgraph.schema import RowEmbeddings, RowIndexEmbedding, Metadata

        mock_qdrant_instance = MagicMock()
        mock_qdrant_instance.collection_exists.return_value = True
        mock_qdrant_client.return_value = mock_qdrant_instance
        mock_uuid.uuid4.return_value = 'test-uuid-123'

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        processor.known_collections[('test_user', 'test_collection')] = {}

        # Create embeddings message
        metadata = MagicMock()
        metadata.user = 'test_user'
        metadata.collection = 'test_collection'
        metadata.id = 'doc-123'

        embedding = RowIndexEmbedding(
            index_name='customer_id',
            index_value=['CUST001'],
            text='CUST001',
            vectors=[[0.1, 0.2, 0.3]]
        )

        embeddings_msg = RowEmbeddings(
            metadata=metadata,
            schema_name='customers',
            embeddings=[embedding]
        )

        # Mock message wrapper
        mock_msg = MagicMock()
        mock_msg.value.return_value = embeddings_msg

        await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())

        # Verify upsert was called
        mock_qdrant_instance.upsert.assert_called_once()

        # Verify upsert parameters (dimension 3 comes from the vector length)
        upsert_call_args = mock_qdrant_instance.upsert.call_args
        assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3'

        point = upsert_call_args[1]['points'][0]
        assert point.vector == [0.1, 0.2, 0.3]
        assert point.payload['index_name'] == 'customer_id'
        assert point.payload['index_value'] == ['CUST001']
        assert point.payload['text'] == 'CUST001'

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
    async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
        """Test processing embeddings with multiple vectors"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor
        from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

        mock_qdrant_instance = MagicMock()
        mock_qdrant_instance.collection_exists.return_value = True
        mock_qdrant_client.return_value = mock_qdrant_instance
        mock_uuid.uuid4.return_value = 'test-uuid'

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        processor.known_collections[('test_user', 'test_collection')] = {}

        metadata = MagicMock()
        metadata.user = 'test_user'
        metadata.collection = 'test_collection'
        metadata.id = 'doc-123'

        # Embedding with multiple vectors
        embedding = RowIndexEmbedding(
            index_name='name',
            index_value=['John Doe'],
            text='John Doe',
            vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
        )

        embeddings_msg = RowEmbeddings(
            metadata=metadata,
            schema_name='people',
            embeddings=[embedding]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = embeddings_msg

        await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())

        # Should be called 3 times (once per vector)
        assert mock_qdrant_instance.upsert.call_count == 3

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
        """Test that embeddings with no vectors are skipped"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor
        from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

        mock_qdrant_instance = MagicMock()
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        processor.known_collections[('test_user', 'test_collection')] = {}

        metadata = MagicMock()
        metadata.user = 'test_user'
        metadata.collection = 'test_collection'
        metadata.id = 'doc-123'

        # Embedding with no vectors
        embedding = RowIndexEmbedding(
            index_name='id',
            index_value=['123'],
            text='123',
            vectors=[]  # Empty vectors
        )

        embeddings_msg = RowEmbeddings(
            metadata=metadata,
            schema_name='items',
            embeddings=[embedding]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = embeddings_msg

        await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())

        # Should not call upsert for empty vectors
        mock_qdrant_instance.upsert.assert_not_called()

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_on_embeddings_drops_unknown_collection(self, mock_qdrant_client):
        """Test that messages for unknown collections are dropped"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor
        from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

        mock_qdrant_instance = MagicMock()
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        # No collections registered

        metadata = MagicMock()
        metadata.user = 'unknown_user'
        metadata.collection = 'unknown_collection'
        metadata.id = 'doc-123'

        embedding = RowIndexEmbedding(
            index_name='id',
            index_value=['123'],
            text='123',
            vectors=[[0.1, 0.2]]
        )

        embeddings_msg = RowEmbeddings(
            metadata=metadata,
            schema_name='items',
            embeddings=[embedding]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = embeddings_msg

        await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())

        # Should not call upsert for unknown collection
        mock_qdrant_instance.upsert.assert_not_called()

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_delete_collection(self, mock_qdrant_client):
        """Test deleting all collections for a user/collection"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()

        # Mock collections list: two match the user/collection prefix, one does not.
        mock_coll1 = MagicMock()
        mock_coll1.name = 'rows_test_user_test_collection_schema1_384'
        mock_coll2 = MagicMock()
        mock_coll2.name = 'rows_test_user_test_collection_schema2_384'
        mock_coll3 = MagicMock()
        mock_coll3.name = 'rows_other_user_other_collection_schema_384'

        mock_collections = MagicMock()
        mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
        mock_qdrant_instance.get_collections.return_value = mock_collections

        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        processor.created_collections.add('rows_test_user_test_collection_schema1_384')

        await processor.delete_collection('test_user', 'test_collection')

        # Should delete only the matching collections
        assert mock_qdrant_instance.delete_collection.call_count == 2

        # Verify the cached collection was removed
        assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_delete_collection_schema(self, mock_qdrant_client):
        """Test deleting collections for a specific schema"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()

        mock_coll1 = MagicMock()
        mock_coll1.name = 'rows_test_user_test_collection_customers_384'
        mock_coll2 = MagicMock()
        mock_coll2.name = 'rows_test_user_test_collection_orders_384'

        mock_collections = MagicMock()
        mock_collections.collections = [mock_coll1, mock_coll2]
        mock_qdrant_instance.get_collections.return_value = mock_collections

        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        await processor.delete_collection_schema(
            'test_user', 'test_collection', 'customers'
        )

        # Should only delete the customers schema collection
        mock_qdrant_instance.delete_collection.assert_called_once()
        call_args = mock_qdrant_instance.delete_collection.call_args[0]
        assert call_args[0] == 'rows_test_user_test_collection_customers_384'
# Allow running this test module directly: delegate to pytest.
if __name__ == '__main__':
    pytest.main([__file__])
474
tests/unit/test_storage/test_rows_cassandra_storage.py
Normal file
474
tests/unit/test_storage/test_rows_cassandra_storage.py
Normal file
|
|
@ -0,0 +1,474 @@
|
|||
"""
|
||||
Unit tests for Cassandra Row Storage Processor (Unified Table Implementation)
|
||||
|
||||
Tests the business logic of the row storage processor including:
|
||||
- Schema configuration handling
|
||||
- Name sanitization
|
||||
- Unified table structure
|
||||
- Index management
|
||||
- Row storage with multi-index support
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
from trustgraph.storage.rows.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
class TestRowsCassandraStorageLogic:
    """Test business logic for unified table implementation.

    Processor methods under test are bound onto a MagicMock stand-in
    (via descriptor __get__) so each helper can be exercised without
    running the Processor's constructor or a real Cassandra session.
    """

    def test_sanitize_name(self):
        """Test name sanitization for Cassandra compatibility"""
        processor = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)

        # Test various name patterns
        assert processor.sanitize_name("simple_name") == "simple_name"
        assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
        assert processor.sanitize_name("name.with.dots") == "name_with_dots"
        assert processor.sanitize_name("123_starts_with_number") == "r_123_starts_with_number"
        assert processor.sanitize_name("name with spaces") == "name_with_spaces"
        assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
        assert processor.sanitize_name("UPPERCASE") == "uppercase"
        assert processor.sanitize_name("CamelCase") == "camelcase"
        assert processor.sanitize_name("_underscore_start") == "r__underscore_start"

    def test_get_index_names(self):
        """Test extraction of index names from schema"""
        processor = MagicMock()
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)

        # Schema with primary and indexed fields
        schema = RowSchema(
            name="test_schema",
            description="Test",
            fields=[
                Field(name="id", type="string", primary=True),
                Field(name="category", type="string", indexed=True),
                Field(name="name", type="string"),  # Not indexed
                Field(name="status", type="string", indexed=True)
            ]
        )

        index_names = processor.get_index_names(schema)

        # Should include primary key and indexed fields
        assert "id" in index_names
        assert "category" in index_names
        assert "status" in index_names
        assert "name" not in index_names  # Not indexed
        assert len(index_names) == 3

    def test_get_index_names_no_indexes(self):
        """Test schema with no indexed fields"""
        processor = MagicMock()
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)

        schema = RowSchema(
            name="no_index_schema",
            fields=[
                Field(name="data1", type="string"),
                Field(name="data2", type="string")
            ]
        )

        index_names = processor.get_index_names(schema)
        assert len(index_names) == 0

    def test_build_index_value(self):
        """Test building index values from row data"""
        processor = MagicMock()
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)

        value_map = {"id": "123", "category": "electronics", "name": "Widget"}

        # Single field index
        result = processor.build_index_value(value_map, "id")
        assert result == ["123"]

        result = processor.build_index_value(value_map, "category")
        assert result == ["electronics"]

        # Missing field returns empty string
        result = processor.build_index_value(value_map, "missing")
        assert result == [""]

    def test_build_index_value_composite(self):
        """Test building composite index values"""
        processor = MagicMock()
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)

        value_map = {"region": "us-west", "category": "electronics", "id": "123"}

        # Composite index (comma-separated field names)
        result = processor.build_index_value(value_map, "region,category")
        assert result == ["us-west", "electronics"]

    @pytest.mark.asyncio
    async def test_schema_config_parsing(self):
        """Test parsing of schema configurations"""
        processor = MagicMock()
        processor.schemas = {}
        processor.config_key = "schema"
        processor.registered_partitions = set()
        processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)

        # Create test configuration
        config = {
            "schema": {
                "customer_records": json.dumps({
                    "name": "customer_records",
                    "description": "Customer data",
                    "fields": [
                        {
                            "name": "id",
                            "type": "string",
                            "primary_key": True,
                            "required": True
                        },
                        {
                            "name": "name",
                            "type": "string",
                            "required": True
                        },
                        {
                            "name": "category",
                            "type": "string",
                            "indexed": True
                        }
                    ]
                })
            }
        }

        # Process configuration
        await processor.on_schema_config(config, version=1)

        # Verify schema was loaded
        assert "customer_records" in processor.schemas
        schema = processor.schemas["customer_records"]
        assert schema.name == "customer_records"
        assert len(schema.fields) == 3

        # Check field properties ("primary_key" in JSON maps to Field.primary)
        id_field = schema.fields[0]
        assert id_field.name == "id"
        assert id_field.type == "string"
        assert id_field.primary is True

    @pytest.mark.asyncio
    async def test_object_processing_stores_data_map(self):
        """Test that row processing stores data as map<text, text>"""
        processor = MagicMock()
        processor.schemas = {
            "test_schema": RowSchema(
                name="test_schema",
                description="Test",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="value", type="string", size=100)
                ]
            )
        }
        processor.tables_initialized = {"test_user"}
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
        processor.ensure_tables = MagicMock()
        processor.register_partitions = MagicMock()
        processor.collection_exists = MagicMock(return_value=True)
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        # Create test object
        test_obj = ExtractedObject(
            metadata=Metadata(
                id="test-001",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            schema_name="test_schema",
            values=[{"id": "123", "value": "test_data"}],
            confidence=0.9,
            source_span="test source"
        )

        # Create mock message
        msg = MagicMock()
        msg.value.return_value = test_obj

        # Process object
        await processor.on_object(msg, None, None)

        # Verify insert was executed
        processor.session.execute.assert_called()
        insert_call = processor.session.execute.call_args
        insert_cql = insert_call[0][0]
        values = insert_call[0][1]

        # Verify using unified rows table
        assert "INSERT INTO test_user.rows" in insert_cql

        # Values should be: (collection, schema_name, index_name, index_value, data, source)
        assert values[0] == "test_collection"  # collection
        assert values[1] == "test_schema"  # schema_name
        assert values[2] == "id"  # index_name (primary key field)
        assert values[3] == ["123"]  # index_value as list
        assert values[4] == {"id": "123", "value": "test_data"}  # data map
        assert values[5] == ""  # source

    @pytest.mark.asyncio
    async def test_object_processing_multiple_indexes(self):
        """Test that row is written once per indexed field"""
        processor = MagicMock()
        processor.schemas = {
            "multi_index_schema": RowSchema(
                name="multi_index_schema",
                fields=[
                    Field(name="id", type="string", primary=True),
                    Field(name="category", type="string", indexed=True),
                    Field(name="status", type="string", indexed=True)
                ]
            )
        }
        processor.tables_initialized = {"test_user"}
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
        processor.ensure_tables = MagicMock()
        processor.register_partitions = MagicMock()
        processor.collection_exists = MagicMock(return_value=True)
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        test_obj = ExtractedObject(
            metadata=Metadata(
                id="test-001",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            schema_name="multi_index_schema",
            values=[{"id": "123", "category": "electronics", "status": "active"}],
            confidence=0.9,
            source_span=""
        )

        msg = MagicMock()
        msg.value.return_value = test_obj

        await processor.on_object(msg, None, None)

        # Should have 3 inserts (one per indexed field: id, category, status)
        assert processor.session.execute.call_count == 3

        # Check that different index_names were used
        index_names_used = set()
        for call in processor.session.execute.call_args_list:
            values = call[0][1]
            index_names_used.add(values[2])  # index_name is 3rd value

        assert index_names_used == {"id", "category", "status"}
class TestRowsCassandraStorageBatchLogic:
    """Test batch processing logic for unified table implementation."""

    @pytest.mark.asyncio
    async def test_batch_object_processing(self):
        """Test processing of batch ExtractedObjects"""
        processor = MagicMock()
        processor.schemas = {
            "batch_schema": RowSchema(
                name="batch_schema",
                fields=[
                    Field(name="id", type="string", primary=True),
                    Field(name="name", type="string")
                ]
            )
        }
        processor.tables_initialized = {"test_user"}
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
        processor.ensure_tables = MagicMock()
        processor.register_partitions = MagicMock()
        processor.collection_exists = MagicMock(return_value=True)
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        # Create batch object with multiple values
        batch_obj = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_collection",
                metadata=[]
            ),
            schema_name="batch_schema",
            values=[
                {"id": "001", "name": "First"},
                {"id": "002", "name": "Second"},
                {"id": "003", "name": "Third"}
            ],
            confidence=0.95,
            source_span=""
        )

        msg = MagicMock()
        msg.value.return_value = batch_obj

        await processor.on_object(msg, None, None)

        # Should have 3 inserts (one per row, one index per row since only primary key)
        assert processor.session.execute.call_count == 3

        # Check each insert has different id
        ids_inserted = set()
        for call in processor.session.execute.call_args_list:
            values = call[0][1]
            ids_inserted.add(tuple(values[3]))  # index_value is 4th value

        assert ids_inserted == {("001",), ("002",), ("003",)}

    @pytest.mark.asyncio
    async def test_empty_batch_processing(self):
        """Test processing of empty batch ExtractedObjects"""
        processor = MagicMock()
        processor.schemas = {
            "empty_schema": RowSchema(
                name="empty_schema",
                fields=[Field(name="id", type="string", primary=True)]
            )
        }
        processor.tables_initialized = {"test_user"}
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
        processor.ensure_tables = MagicMock()
        processor.register_partitions = MagicMock()
        processor.collection_exists = MagicMock(return_value=True)
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        # Create empty batch object
        empty_batch_obj = ExtractedObject(
            metadata=Metadata(
                id="empty-001",
                user="test_user",
                collection="empty_collection",
                metadata=[]
            ),
            schema_name="empty_schema",
            values=[],  # Empty batch
            confidence=1.0,
            source_span=""
        )

        msg = MagicMock()
        msg.value.return_value = empty_batch_obj

        await processor.on_object(msg, None, None)

        # Verify no insert calls for empty batch
        processor.session.execute.assert_not_called()
class TestUnifiedTableStructure:
    """Test the unified rows table structure."""

    def test_ensure_tables_creates_unified_structure(self):
        """Test that ensure_tables creates the unified rows table"""
        processor = MagicMock()
        processor.known_keyspaces = {"test_user"}
        processor.tables_initialized = set()
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.ensure_keyspace = MagicMock()
        processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)

        processor.ensure_tables("test_user")

        # Should have 2 calls: create rows table + create row_partitions table
        assert processor.session.execute.call_count == 2

        # Check rows table creation
        rows_cql = processor.session.execute.call_args_list[0][0][0]
        assert "CREATE TABLE IF NOT EXISTS test_user.rows" in rows_cql
        assert "collection text" in rows_cql
        assert "schema_name text" in rows_cql
        assert "index_name text" in rows_cql
        assert "index_value frozen<list<text>>" in rows_cql
        assert "data map<text, text>" in rows_cql
        assert "source text" in rows_cql
        assert "PRIMARY KEY ((collection, schema_name, index_name), index_value)" in rows_cql

        # Check row_partitions table creation
        partitions_cql = processor.session.execute.call_args_list[1][0][0]
        assert "CREATE TABLE IF NOT EXISTS test_user.row_partitions" in partitions_cql
        assert "PRIMARY KEY ((collection), schema_name, index_name)" in partitions_cql

        # Verify keyspace added to initialized set
        assert "test_user" in processor.tables_initialized

    def test_ensure_tables_idempotent(self):
        """Test that ensure_tables is idempotent"""
        processor = MagicMock()
        processor.tables_initialized = {"test_user"}  # Already initialized
        processor.session = MagicMock()
        processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)

        processor.ensure_tables("test_user")

        # Should not execute any CQL since already initialized
        processor.session.execute.assert_not_called()
class TestPartitionRegistration:
    """Test partition registration for tracking what's stored."""

    def test_register_partitions(self):
        """Test registering partitions for a collection/schema pair"""
        processor = MagicMock()
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.schemas = {
            "test_schema": RowSchema(
                name="test_schema",
                fields=[
                    Field(name="id", type="string", primary=True),
                    Field(name="category", type="string", indexed=True)
                ]
            )
        }
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
        processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)

        processor.register_partitions("test_user", "test_collection", "test_schema")

        # Should have 2 inserts (one per index: id, category)
        assert processor.session.execute.call_count == 2

        # Verify cache was updated
        assert ("test_collection", "test_schema") in processor.registered_partitions

    def test_register_partitions_idempotent(self):
        """Test that partition registration is idempotent"""
        processor = MagicMock()
        processor.registered_partitions = {("test_collection", "test_schema")}  # Already registered
        processor.session = MagicMock()
        processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)

        processor.register_partitions("test_user", "test_collection", "test_schema")

        # Should not execute any CQL since already registered
        processor.session.execute.assert_not_called()
|
@ -48,7 +48,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert hasattr(processor, 'client')
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert len(processor.safety_settings) == 4 # 4 safety categories
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -208,7 +208,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.default_model == 'gemini-1.5-pro'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key', vertexai=False)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -237,7 +237,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -427,7 +427,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Assert
|
||||
# Verify Google AI Studio client was called with correct API key
|
||||
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='gai-test-key', vertexai=False)
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.client == mock_genai_client
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue