Structured data 2 (#645)

* Structured data refactor - multi-index tables, remove need for manual mods to the Cassandra tables

* Tech spec updated to track implementation
This commit is contained in:
cybermaggedon 2026-02-23 15:56:29 +00:00 committed by GitHub
parent 5ffad92345
commit 1809c1f56d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
87 changed files with 5233 additions and 3235 deletions

View file

@ -1,8 +1,8 @@
"""
Contract tests for Cassandra Object Storage
Contract tests for Cassandra Row Storage
These tests verify the message contracts and schema compatibility
for the objects storage processor.
for the rows storage processor.
"""
import pytest
@ -10,12 +10,12 @@ import json
from pulsar.schema import AvroSchema
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.storage.rows.cassandra.write import Processor
@pytest.mark.contract
class TestObjectsCassandraContracts:
"""Contract tests for Cassandra object storage messages"""
class TestRowsCassandraContracts:
"""Contract tests for Cassandra row storage messages"""
def test_extracted_object_input_contract(self):
"""Test that ExtractedObject schema matches expected input format"""
@ -145,50 +145,6 @@ class TestObjectsCassandraContracts:
assert required_field_keys.issubset(field.keys())
assert set(field.keys()).issubset(required_field_keys | optional_field_keys)
def test_cassandra_type_mapping_contract(self):
    """Test that all supported field types have Cassandra mappings.

    Each supported field type is passed to ``get_cassandra_type`` and the
    result must be one of the Cassandra types acceptable for that field
    type.  Integer and float fields may map to a wider type depending on
    size, so they accept two candidates each, but an integer field must
    never map to a floating-point type or vice versa.
    """
    # __new__ creates the instance without running Processor.__init__
    processor = Processor.__new__(Processor)

    # Field type -> set of acceptable Cassandra types
    supported_types = {
        "string": {"text"},
        "integer": {"int", "bigint"},    # int or bigint based on size
        "float": {"float", "double"},    # float or double based on size
        "boolean": {"boolean"},
        "timestamp": {"timestamp"},
        "date": {"date"},
        "time": {"time"},
        "uuid": {"uuid"},
    }

    for field_type, allowed_types in supported_types.items():
        cassandra_type = processor.get_cassandra_type(field_type)
        assert cassandra_type in allowed_types, (
            f"Field type {field_type} mapped to {cassandra_type}, "
            f"expected one of {sorted(allowed_types)}"
        )
def test_value_conversion_contract(self):
    """Check that convert_value yields the expected value and Python type.

    Every case feeds a raw input and a field type to ``convert_value`` and
    verifies both the converted value and its type.  A ``None`` input for a
    string field is expected to stay ``None``.
    """
    # __new__ creates the instance without running Processor.__init__
    converter = Processor.__new__(Processor)

    # Each entry: (raw input, field type, converted value, converted type)
    conversions = (
        ("123", "integer", 123, int),
        ("123.45", "float", 123.45, float),
        ("true", "boolean", True, bool),
        ("false", "boolean", False, bool),
        ("test string", "string", "test string", str),
        (None, "string", None, type(None)),
    )

    for raw, kind, wanted, wanted_type in conversions:
        converted = converter.convert_value(raw, kind)
        assert converted == wanted
        assert converted is None or isinstance(converted, wanted_type)
@pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
def test_extracted_object_serialization_contract(self):
"""Test that ExtractedObject can be serialized/deserialized correctly"""
@ -222,43 +178,31 @@ class TestObjectsCassandraContracts:
assert decoded.confidence == original.confidence
assert decoded.source_span == original.source_span
def test_cassandra_table_naming_contract(self):
def test_cassandra_name_sanitization_contract(self):
"""Test Cassandra naming conventions and constraints"""
processor = Processor.__new__(Processor)
# Test table naming (always gets o_ prefix)
table_test_names = [
("simple_name", "o_simple_name"),
("Name-With-Dashes", "o_name_with_dashes"),
("name.with.dots", "o_name_with_dots"),
("123_numbers", "o_123_numbers"),
("special!@#chars", "o_special___chars"), # 3 special chars become 3 underscores
("UPPERCASE", "o_uppercase"),
("CamelCase", "o_camelcase"),
("", "o_"), # Edge case - empty string becomes o_
]
for input_name, expected_name in table_test_names:
result = processor.sanitize_table(input_name)
assert result == expected_name
# Verify result is valid Cassandra identifier (starts with letter)
assert result.startswith('o_')
assert result.replace('o_', '').replace('_', '').isalnum() or result == 'o_'
# Test regular name sanitization (only adds o_ prefix if starts with number)
# Test name sanitization for Cassandra identifiers
# - Non-alphanumeric chars (except underscore) become underscores
# - Names starting with non-letter get 'r_' prefix
# - All names converted to lowercase
name_test_cases = [
("simple_name", "simple_name"),
("Name-With-Dashes", "name_with_dashes"),
("name.with.dots", "name_with_dots"),
("123_numbers", "o_123_numbers"), # Only this gets o_ prefix
("123_numbers", "r_123_numbers"), # Gets r_ prefix (starts with number)
("special!@#chars", "special___chars"), # 3 special chars become 3 underscores
("UPPERCASE", "uppercase"),
("CamelCase", "camelcase"),
("_underscore_start", "r__underscore_start"), # Gets r_ prefix (starts with underscore)
]
for input_name, expected_name in name_test_cases:
result = processor.sanitize_name(input_name)
assert result == expected_name
assert result == expected_name, f"Expected {expected_name} but got {result} for input {input_name}"
# Verify result is valid Cassandra identifier (starts with letter)
if result: # Skip empty string case
assert result[0].isalpha(), f"Result {result} should start with a letter"
def test_primary_key_structure_contract(self):
"""Test that primary key structure follows Cassandra best practices"""
@ -308,8 +252,8 @@ class TestObjectsCassandraContracts:
@pytest.mark.contract
class TestObjectsCassandraContractsBatch:
"""Contract tests for Cassandra object storage batch processing"""
class TestRowsCassandraContractsBatch:
"""Contract tests for Cassandra row storage batch processing"""
def test_extracted_object_batch_input_contract(self):
"""Test that batched ExtractedObject schema matches expected input format"""

View file

@ -1,26 +1,26 @@
"""
Contract tests for Objects GraphQL Query Service
Contract tests for Rows GraphQL Query Service
These tests verify the message contracts and schema compatibility
for the objects GraphQL query processor.
for the rows GraphQL query processor.
"""
import pytest
import json
from pulsar.schema import AvroSchema
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from trustgraph.query.rows.cassandra.service import Processor
@pytest.mark.contract
class TestObjectsGraphQLQueryContracts:
class TestRowsGraphQLQueryContracts:
"""Contract tests for GraphQL query service messages"""
def test_objects_query_request_contract(self):
"""Test ObjectsQueryRequest schema structure and required fields"""
def test_rows_query_request_contract(self):
"""Test RowsQueryRequest schema structure and required fields"""
# Create test request with all required fields
test_request = ObjectsQueryRequest(
test_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ customers { id name email } }',
@ -49,10 +49,10 @@ class TestObjectsGraphQLQueryContracts:
assert test_request.variables["status"] == "active"
assert test_request.operation_name == "GetCustomers"
def test_objects_query_request_minimal(self):
"""Test ObjectsQueryRequest with minimal required fields"""
def test_rows_query_request_minimal(self):
"""Test RowsQueryRequest with minimal required fields"""
# Create request with only essential fields
minimal_request = ObjectsQueryRequest(
minimal_request = RowsQueryRequest(
user="user",
collection="collection",
query='{ test }',
@ -91,10 +91,10 @@ class TestObjectsGraphQLQueryContracts:
assert test_error.path == ["customers", "0", "nonexistent"]
assert test_error.extensions["code"] == "FIELD_ERROR"
def test_objects_query_response_success_contract(self):
"""Test ObjectsQueryResponse schema for successful queries"""
def test_rows_query_response_success_contract(self):
"""Test RowsQueryResponse schema for successful queries"""
# Create successful response
success_response = ObjectsQueryResponse(
success_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
errors=[],
@ -119,11 +119,11 @@ class TestObjectsGraphQLQueryContracts:
assert len(parsed_data["customers"]) == 1
assert parsed_data["customers"][0]["id"] == "1"
def test_objects_query_response_error_contract(self):
"""Test ObjectsQueryResponse schema for error cases"""
def test_rows_query_response_error_contract(self):
"""Test RowsQueryResponse schema for error cases"""
# Create GraphQL errors - work around Pulsar Array(Record) validation bug
# by creating a response without the problematic errors array first
error_response = ObjectsQueryResponse(
error_response = RowsQueryResponse(
error=None, # System error is None - these are GraphQL errors
data=None, # No data due to errors
errors=[], # Empty errors array to avoid Pulsar bug
@ -160,14 +160,14 @@ class TestObjectsGraphQLQueryContracts:
assert validation_error.path == ["customers", "email"]
assert validation_error.extensions["details"] == "Invalid email format"
def test_objects_query_response_system_error_contract(self):
"""Test ObjectsQueryResponse schema for system errors"""
def test_rows_query_response_system_error_contract(self):
"""Test RowsQueryResponse schema for system errors"""
from trustgraph.schema import Error
# Create system error response
system_error_response = ObjectsQueryResponse(
system_error_response = RowsQueryResponse(
error=Error(
type="objects-query-error",
type="rows-query-error",
message="Failed to connect to Cassandra cluster"
),
data=None,
@ -177,7 +177,7 @@ class TestObjectsGraphQLQueryContracts:
# Verify system error structure
assert system_error_response.error is not None
assert system_error_response.error.type == "objects-query-error"
assert system_error_response.error.type == "rows-query-error"
assert "Cassandra" in system_error_response.error.message
assert system_error_response.data is None
assert len(system_error_response.errors) == 0
@ -186,7 +186,7 @@ class TestObjectsGraphQLQueryContracts:
def test_request_response_serialization_contract(self):
"""Test that request/response can be serialized/deserialized correctly"""
# Create original request
original_request = ObjectsQueryRequest(
original_request = RowsQueryRequest(
user="serialization_test",
collection="test_data",
query='{ orders(limit: 5) { id total customer { name } } }',
@ -195,7 +195,7 @@ class TestObjectsGraphQLQueryContracts:
)
# Test request serialization using Pulsar schema
request_schema = AvroSchema(ObjectsQueryRequest)
request_schema = AvroSchema(RowsQueryRequest)
# Encode and decode request
encoded_request = request_schema.encode(original_request)
@ -209,7 +209,7 @@ class TestObjectsGraphQLQueryContracts:
assert decoded_request.operation_name == original_request.operation_name
# Create original response - work around Pulsar Array(Record) bug
original_response = ObjectsQueryResponse(
original_response = RowsQueryResponse(
error=None,
data='{"orders": []}',
errors=[], # Empty to avoid Pulsar validation bug
@ -224,7 +224,7 @@ class TestObjectsGraphQLQueryContracts:
)
# Test response serialization
response_schema = AvroSchema(ObjectsQueryResponse)
response_schema = AvroSchema(RowsQueryResponse)
# Encode and decode response
encoded_response = response_schema.encode(original_response)
@ -244,7 +244,7 @@ class TestObjectsGraphQLQueryContracts:
def test_graphql_query_format_contract(self):
"""Test supported GraphQL query formats"""
# Test basic query
basic_query = ObjectsQueryRequest(
basic_query = RowsQueryRequest(
user="test", collection="test", query='{ customers { id } }',
variables={}, operation_name=""
)
@ -253,7 +253,7 @@ class TestObjectsGraphQLQueryContracts:
assert basic_query.query.strip().endswith('}')
# Test query with variables
parameterized_query = ObjectsQueryRequest(
parameterized_query = RowsQueryRequest(
user="test", collection="test",
query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }',
variables={"status": "active", "limit": "10"},
@ -265,7 +265,7 @@ class TestObjectsGraphQLQueryContracts:
assert parameterized_query.operation_name == "GetCustomers"
# Test complex nested query
nested_query = ObjectsQueryRequest(
nested_query = RowsQueryRequest(
user="test", collection="test",
query='''
{
@ -296,7 +296,7 @@ class TestObjectsGraphQLQueryContracts:
# Note: Current schema uses Map(String()) which only supports string values
# This test verifies the current contract, though ideally we'd support all JSON types
variables_test = ObjectsQueryRequest(
variables_test = RowsQueryRequest(
user="test", collection="test", query='{ test }',
variables={
"string_var": "test_value",
@ -319,7 +319,7 @@ class TestObjectsGraphQLQueryContracts:
def test_cassandra_context_fields_contract(self):
"""Test that request contains necessary fields for Cassandra operations"""
# Verify request has fields needed for Cassandra keyspace/table targeting
request = ObjectsQueryRequest(
request = RowsQueryRequest(
user="keyspace_name", # Maps to Cassandra keyspace
collection="partition_collection", # Used in partition key
query='{ objects { id } }',
@ -338,7 +338,7 @@ class TestObjectsGraphQLQueryContracts:
def test_graphql_extensions_contract(self):
"""Test GraphQL extensions field format and usage"""
# Extensions should support query metadata
response_with_extensions = ObjectsQueryResponse(
response_with_extensions = RowsQueryResponse(
error=None,
data='{"test": "data"}',
errors=[],
@ -404,7 +404,7 @@ class TestObjectsGraphQLQueryContracts:
'''
# Request to execute specific operation
multi_op_request = ObjectsQueryRequest(
multi_op_request = RowsQueryRequest(
user="test", collection="test",
query=multi_op_query,
variables={},
@ -417,7 +417,7 @@ class TestObjectsGraphQLQueryContracts:
assert "GetOrders" in multi_op_request.query
# Test single operation (operation_name optional)
single_op_request = ObjectsQueryRequest(
single_op_request = RowsQueryRequest(
user="test", collection="test",
query='{ customers { id } }',
variables={}, operation_name=""

View file

@ -12,7 +12,7 @@ from argparse import ArgumentParser
# Import processors that use Cassandra configuration
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
from trustgraph.storage.knowledge.store import Processor as KgStore
@ -55,8 +55,8 @@ class TestEndToEndConfigurationFlow:
assert call_args.args[0] == ['integration-host1', 'integration-host2', 'integration-host3']
assert 'auth_provider' in call_args.kwargs # Should have auth since credentials provided
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_objects_writer_env_to_cluster_connection(self, mock_auth_provider, mock_cluster):
"""Test complete flow from environment variables to Cassandra Cluster connection."""
env_vars = {
@ -73,7 +73,7 @@ class TestEndToEndConfigurationFlow:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
# Trigger Cassandra connection
processor.connect_cassandra()
@ -320,7 +320,7 @@ class TestNoBackwardCompatibilityEndToEnd:
class TestMultipleHostsHandling:
"""Test multiple Cassandra hosts handling end-to-end."""
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_multiple_hosts_passed_to_cluster(self, mock_cluster):
"""Test that multiple hosts are correctly passed to Cassandra cluster."""
env_vars = {
@ -333,7 +333,7 @@ class TestMultipleHostsHandling:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify all hosts were passed to Cluster
@ -386,8 +386,8 @@ class TestMultipleHostsHandling:
class TestAuthenticationFlow:
"""Test authentication configuration flow end-to-end."""
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_authentication_enabled_when_both_credentials_provided(self, mock_auth_provider, mock_cluster):
"""Test that authentication is enabled when both username and password are provided."""
env_vars = {
@ -402,7 +402,7 @@ class TestAuthenticationFlow:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Auth provider should be created
@ -416,8 +416,8 @@ class TestAuthenticationFlow:
assert 'auth_provider' in call_args.kwargs
assert call_args.kwargs['auth_provider'] == mock_auth_instance
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster):
"""Test that authentication is not used when credentials are missing."""
env_vars = {
@ -429,7 +429,7 @@ class TestAuthenticationFlow:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Auth provider should not be created
@ -439,11 +439,11 @@ class TestAuthenticationFlow:
call_args = mock_cluster.call_args
assert 'auth_provider' not in call_args.kwargs
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster):
"""Test that authentication is not used when only username is provided."""
processor = ObjectsWriter(
processor = RowsWriter(
taskgroup=MagicMock(),
cassandra_host='partial-auth-host',
cassandra_username='partial-user'

View file

@ -11,7 +11,7 @@ import json
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.objects.processor import Processor
from trustgraph.extract.kg.rows.processor import Processor
from trustgraph.schema import (
Chunk, ExtractedObject, Metadata, RowSchema, Field,
PromptRequest, PromptResponse
@ -220,7 +220,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
@ -288,7 +288,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
@ -353,7 +353,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
@ -447,7 +447,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Mock flow with failing prompt service
@ -496,7 +496,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration

View file

@ -1,608 +0,0 @@
"""
Integration tests for Cassandra Object Storage
These tests verify the end-to-end functionality of storing ExtractedObjects
in Cassandra, including table creation, data insertion, and error handling.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
import uuid
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
@pytest.mark.integration
class TestObjectsCassandraIntegration:
"""Integration tests for Cassandra object storage"""
@pytest.fixture
def mock_cassandra_session(self):
    """Provide a MagicMock Cassandra session that remembers created keyspaces.

    ``session.execute`` is driven by a side effect that records the name in
    every ``CREATE KEYSPACE IF NOT EXISTS <name>`` statement.  For queries
    mentioning ``system_schema.keyspaces`` the returned result's ``one()``
    yields a MagicMock when the first argument after the query is a
    recorded keyspace, and ``None`` otherwise.  Any other query gets a
    result whose ``one()`` yields ``None``.
    """
    import re

    fake_session = MagicMock()
    seen_keyspaces = set()

    def fake_execute(query, *args, **kwargs):
        text = str(query)

        # Remember any keyspace this statement creates
        if "CREATE KEYSPACE" in text:
            found = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', text)
            if found:
                seen_keyspaces.add(found.group(1))

        # Only an existence check for an already-created keyspace gets a row
        keyspace_exists = (
            "system_schema.keyspaces" in text
            and args
            and args[0] in seen_keyspaces
        )

        outcome = MagicMock()
        outcome.one.return_value = MagicMock() if keyspace_exists else None
        return outcome

    fake_session.execute = MagicMock(side_effect=fake_execute)
    return fake_session
@pytest.fixture
def mock_cassandra_cluster(self, mock_cassandra_session):
    """Provide a MagicMock cluster whose ``connect()`` returns the mock session."""
    fake_cluster = MagicMock()
    fake_cluster.shutdown = MagicMock()
    fake_cluster.connect.return_value = mock_cassandra_session
    return fake_cluster
@pytest.fixture
def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
    """Build a MagicMock processor with real ``Processor`` methods bound to it.

    The processor itself is a MagicMock carrying plain attribute state; the
    methods under test are taken from ``Processor`` and bound to the mock so
    they run their real logic against that state.

    Returns a ``(processor, cluster, session)`` tuple.
    """
    stub = MagicMock()

    # Plain attribute state set on the mock processor
    initial_state = {
        "graph_host": "localhost",
        "graph_username": None,
        "graph_password": None,
        "config_key": "schema",
        "schemas": {},
        "known_keyspaces": set(),
        "known_tables": {},
        "cluster": None,
        "session": None,
    }
    for attr_name, attr_value in initial_state.items():
        setattr(stub, attr_name, attr_value)

    # Real Processor methods, bound to the mock instance
    real_methods = (
        "connect_cassandra",
        "ensure_keyspace",
        "ensure_table",
        "sanitize_name",
        "sanitize_table",
        "get_cassandra_type",
        "convert_value",
        "on_schema_config",
        "on_object",
        "create_collection",
    )
    for method_name in real_methods:
        bound = getattr(Processor, method_name).__get__(stub, Processor)
        setattr(stub, method_name, bound)

    return stub, mock_cassandra_cluster, mock_cassandra_session
@pytest.mark.asyncio
async def test_end_to_end_object_storage(self, processor_with_mocks):
    """Walk the full path from schema configuration to a stored object.

    Loads one schema, creates the collection, processes a single
    ExtractedObject and then inspects the statements sent to the mocked
    session: one keyspace, one table, one index and one insert.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    def executed(fragment):
        # All session.execute calls whose repr contains the given fragment
        return [
            recorded for recorded in mock_session.execute.call_args_list
            if fragment in str(recorded)
        ]

    # Replace Cluster so connecting yields the mock cluster
    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):

        # Step 1: Configure schema
        schema_definition = {
            "name": "customer_records",
            "description": "Customer information",
            "fields": [
                {"name": "customer_id", "type": "string", "primary_key": True},
                {"name": "name", "type": "string", "required": True},
                {"name": "email", "type": "string", "indexed": True},
                {"name": "age", "type": "integer"}
            ]
        }
        config = {
            "schema": {
                "customer_records": json.dumps(schema_definition)
            }
        }

        await processor.on_schema_config(config, version=1)
        assert "customer_records" in processor.schemas

        # Step 1.5: Create the collection first (simulate tg-set-collection)
        await processor.create_collection("test_user", "import_2024", {})

        # Step 2: Process an ExtractedObject
        object_metadata = Metadata(
            id="doc-001",
            user="test_user",
            collection="import_2024",
            metadata=[]
        )
        customer_row = {
            "customer_id": "CUST001",
            "name": "John Doe",
            "email": "john@example.com",
            "age": "30"
        }
        test_obj = ExtractedObject(
            metadata=object_metadata,
            schema_name="customer_records",
            values=[customer_row],
            confidence=0.95,
            source_span="Customer: John Doe..."
        )

        message = MagicMock()
        message.value.return_value = test_obj

        await processor.on_object(message, None, None)

        # A connection must have been opened
        assert mock_cluster.connect.called

        # Keyspace creation
        keyspace_calls = executed("CREATE KEYSPACE")
        assert len(keyspace_calls) == 1
        assert "test_user" in str(keyspace_calls[0])

        # Table creation
        table_calls = executed("CREATE TABLE")
        assert len(table_calls) == 1
        table_statement = str(table_calls[0])
        assert "o_customer_records" in table_statement  # Table gets o_ prefix
        assert "collection text" in table_statement
        assert "PRIMARY KEY ((collection, customer_id))" in table_statement

        # Index creation
        index_calls = executed("CREATE INDEX")
        assert len(index_calls) == 1
        assert "email" in str(index_calls[0])

        # Data insertion
        insert_calls = executed("INSERT INTO")
        assert len(insert_calls) == 1
        insert_call = insert_calls[0]
        assert "test_user.o_customer_records" in str(insert_call)  # Table gets o_ prefix

        # Inserted values are the second positional argument to execute
        values = insert_call[0][1]
        assert "import_2024" in values  # collection
        assert "CUST001" in values  # customer_id
        assert "John Doe" in values  # name
        assert "john@example.com" in values  # email
        assert 30 in values  # age (converted to int)
@pytest.mark.asyncio
async def test_multi_schema_handling(self, processor_with_mocks):
    """Check that objects of two different schemas land in two tables."""
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):

        # Two schemas, each with its own primary key
        products_schema = {
            "name": "products",
            "fields": [
                {"name": "product_id", "type": "string", "primary_key": True},
                {"name": "name", "type": "string"},
                {"name": "price", "type": "float"}
            ]
        }
        orders_schema = {
            "name": "orders",
            "fields": [
                {"name": "order_id", "type": "string", "primary_key": True},
                {"name": "customer_id", "type": "string"},
                {"name": "total", "type": "float"}
            ]
        }
        config = {
            "schema": {
                "products": json.dumps(products_schema),
                "orders": json.dumps(orders_schema)
            }
        }

        await processor.on_schema_config(config, version=1)
        assert len(processor.schemas) == 2

        # Collections must exist before objects are processed
        await processor.create_collection("shop", "catalog", {})
        await processor.create_collection("shop", "sales", {})

        # One object per schema
        product_obj = ExtractedObject(
            metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
            schema_name="products",
            values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
            confidence=0.9,
            source_span="Product..."
        )
        order_obj = ExtractedObject(
            metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
            schema_name="orders",
            values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
            confidence=0.85,
            source_span="Order..."
        )

        # Feed both objects through the processor
        for extracted in (product_obj, order_obj):
            message = MagicMock()
            message.value.return_value = extracted
            await processor.on_object(message, None, None)

        # Each schema should have produced its own CREATE TABLE
        table_calls = [
            recorded for recorded in mock_session.execute.call_args_list
            if "CREATE TABLE" in str(recorded)
        ]
        assert len(table_calls) == 2

        table_statements = [str(recorded) for recorded in table_calls]
        assert any("o_products" in stmt for stmt in table_statements)  # Tables get o_ prefix
        assert any("o_orders" in stmt for stmt in table_statements)  # Tables get o_ prefix
@pytest.mark.asyncio
async def test_missing_required_fields(self, processor_with_mocks):
    """Check that an object lacking a required field is still inserted."""
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):

        # Schema where both fields are marked required
        schema_fields = [
            Field(name="id", type="string", size=50, primary=True, required=True),
            Field(name="required_field", type="string", size=100, required=True)
        ]
        processor.schemas["test_schema"] = RowSchema(
            name="test_schema",
            description="Test",
            fields=schema_fields
        )

        # Collection must exist before the object is processed
        await processor.create_collection("test", "test", {})

        # Object that omits required_field
        incomplete_obj = ExtractedObject(
            metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
            schema_name="test_schema",
            values=[{"id": "123"}],
            confidence=0.8,
            source_span="Test"
        )

        message = MagicMock()
        message.value.return_value = incomplete_obj

        # Should still process (Cassandra doesn't enforce NOT NULL)
        await processor.on_object(message, None, None)

        # Exactly one insert must have been attempted
        insert_calls = [
            recorded for recorded in mock_session.execute.call_args_list
            if "INSERT INTO" in str(recorded)
        ]
        assert len(insert_calls) == 1
@pytest.mark.asyncio
async def test_schema_without_primary_key(self, processor_with_mocks):
    """Check that a schema with no primary key gets a synthetic UUID key.

    The CREATE TABLE statement must declare ``synthetic_id uuid`` and the
    inserted values must contain a ``uuid.UUID`` instance.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    def executed(fragment):
        # All session.execute calls whose repr contains the given fragment
        return [
            recorded for recorded in mock_session.execute.call_args_list
            if fragment in str(recorded)
        ]

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):

        # Schema in which no field is flagged as primary
        event_fields = [
            Field(name="event_type", type="string", size=50),
            Field(name="timestamp", type="timestamp", size=0)
        ]
        processor.schemas["events"] = RowSchema(
            name="events",
            description="Event log",
            fields=event_fields
        )

        # Collection must exist before the object is processed
        await processor.create_collection("logger", "app_events", {})

        event_obj = ExtractedObject(
            metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
            schema_name="events",
            values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}],
            confidence=1.0,
            source_span="Event"
        )

        message = MagicMock()
        message.value.return_value = event_obj

        await processor.on_object(message, None, None)

        # The table definition must include the synthetic key column
        table_calls = executed("CREATE TABLE")
        assert len(table_calls) == 1
        assert "synthetic_id uuid" in str(table_calls[0])

        # The insert must carry a generated UUID among its values
        insert_calls = executed("INSERT INTO")
        assert len(insert_calls) == 1
        inserted_values = insert_calls[0][0][1]
        assert any(isinstance(item, uuid.UUID) for item in inserted_values)
@pytest.mark.asyncio
async def test_authentication_handling(self, processor_with_mocks):
    """Configured credentials are passed to the auth provider and the cluster."""
    processor, mock_cluster, mock_session = processor_with_mocks
    processor.cassandra_username = "cassandra_user"
    processor.cassandra_password = "cassandra_pass"

    with patch('trustgraph.storage.objects.cassandra.write.Cluster') as mock_cluster_class, \
            patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') as mock_auth:
        mock_cluster_class.return_value = mock_cluster

        # Connecting is what triggers construction of the auth provider.
        processor.connect_cassandra()

        mock_auth.assert_called_once_with(
            username="cassandra_user",
            password="cassandra_pass",
        )
        mock_cluster_class.assert_called_once()

        # Index 1 of call_args holds the keyword arguments.
        cluster_kwargs = mock_cluster_class.call_args[1]
        assert 'auth_provider' in cluster_kwargs
@pytest.mark.asyncio
async def test_error_handling_during_insert(self, processor_with_mocks):
    """An exception raised by the INSERT propagates out of ``on_object``."""
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        id_field = Field(name="id", type="string", size=50, primary=True)
        processor.schemas["test"] = RowSchema(name="test", fields=[id_field])

        # Script the three session.execute calls in order; only the last fails.
        keyspace_found = MagicMock()
        keyspace_found.one.return_value = MagicMock()  # Keyspace exists
        mock_session.execute.side_effect = [
            keyspace_found,                   # keyspace existence check succeeds
            None,                             # table creation succeeds
            Exception("Connection timeout"),  # insert fails
        ]

        failing_obj = ExtractedObject(
            metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
            schema_name="test",
            values=[{"id": "123"}],
            confidence=0.9,
            source_span="Test",
        )
        message = MagicMock()
        message.value.return_value = failing_obj

        # The failure must surface to the caller rather than be swallowed.
        with pytest.raises(Exception, match="Connection timeout"):
            await processor.on_object(message, None, None)
@pytest.mark.asyncio
async def test_collection_partitioning(self, processor_with_mocks):
    """Each stored object carries its own collection name in the INSERT values."""
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        id_field = Field(name="id", type="string", size=50, primary=True)
        processor.schemas["data"] = RowSchema(name="data", fields=[id_field])

        collections = ["import_jan", "import_feb", "import_mar"]

        # All collections are created up front, before any object is stored.
        for name in collections:
            await processor.create_collection("analytics", name, {})

        # One object per collection, in list order.
        for name in collections:
            message = MagicMock()
            message.value.return_value = ExtractedObject(
                metadata=Metadata(id=f"{name}-1", user="analytics", collection=name, metadata=[]),
                schema_name="data",
                values=[{"id": f"ID-{name}"}],
                confidence=0.9,
                source_span="Data",
            )
            await processor.on_object(message, None, None)

        insert_calls = [
            recorded for recorded in mock_session.execute.call_args_list
            if "INSERT INTO" in str(recorded)
        ]
        assert len(insert_calls) == 3

        # Inserts are compared positionally with the order objects were sent.
        for expected_collection, recorded in zip(collections, insert_calls):
            bound_values = recorded[0][1]
            assert expected_collection in bound_values
@pytest.mark.asyncio
async def test_batch_object_processing(self, processor_with_mocks):
    """A three-value batch produces one table and three ordered inserts."""
    processor, mock_cluster, mock_session = processor_with_mocks

    # (customer_id, name, email) for each row of the batch, in send order.
    expected_rows = [
        ("CUST001", "John Doe", "john@example.com"),
        ("CUST002", "Jane Smith", "jane@example.com"),
        ("CUST003", "Bob Johnson", "bob@example.com"),
    ]

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        # Register the schema through the config handler.
        schema_definition = {
            "name": "batch_customers",
            "description": "Customer batch data",
            "fields": [
                {"name": "customer_id", "type": "string", "primary_key": True},
                {"name": "name", "type": "string", "required": True},
                {"name": "email", "type": "string", "indexed": True},
            ],
        }
        config = {"schema": {"batch_customers": json.dumps(schema_definition)}}
        await processor.on_schema_config(config, version=1)

        batch_obj = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_import",
                metadata=[],
            ),
            schema_name="batch_customers",
            values=[
                {"customer_id": customer_id, "name": name, "email": email}
                for customer_id, name, email in expected_rows
            ],
            confidence=0.92,
            source_span="Multiple customers extracted from document",
        )

        # The collection has to exist before objects are written to it.
        await processor.create_collection("test_user", "batch_import", {})

        message = MagicMock()
        message.value.return_value = batch_obj
        await processor.on_object(message, None, None)

        executed = mock_session.execute.call_args_list

        # One table for the schema.
        table_calls = [c for c in executed if "CREATE TABLE" in str(c)]
        assert len(table_calls) == 1
        assert "o_batch_customers" in str(table_calls[0])

        # One insert per value in the batch.
        insert_calls = [c for c in executed if "INSERT INTO" in str(c)]
        assert len(insert_calls) == 3

        # Every insert carries the collection plus that row's own fields.
        for (customer_id, name, email), recorded in zip(expected_rows, insert_calls):
            bound_values = recorded[0][1]
            assert "batch_import" in bound_values
            assert customer_id in bound_values
            assert name in bound_values
            assert email in bound_values
@pytest.mark.asyncio
async def test_empty_batch_processing(self, processor_with_mocks):
    """An empty ``values`` list still creates the table but inserts nothing."""
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        id_field = Field(name="id", type="string", size=50, primary=True)
        processor.schemas["empty_test"] = RowSchema(name="empty_test", fields=[id_field])

        # The collection has to exist before objects are written to it.
        await processor.create_collection("test", "empty", {})

        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
            schema_name="empty_test",
            values=[],  # Empty batch
            confidence=1.0,
            source_span="No objects found",
        )
        await processor.on_object(message, None, None)

        executed = mock_session.execute.call_args_list

        # Table creation happens regardless of the batch being empty.
        created = [c for c in executed if "CREATE TABLE" in str(c)]
        assert len(created) == 1

        # With no values there is nothing to insert.
        inserted = [c for c in executed if "INSERT INTO" in str(c)]
        assert len(inserted) == 0
@pytest.mark.asyncio
async def test_mixed_single_and_batch_objects(self, processor_with_mocks):
    """A one-value object followed by a two-value batch yields three inserts."""
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        processor.schemas["mixed_test"] = RowSchema(
            name="mixed_test",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="data", type="string", size=100),
            ],
        )

        # The collection has to exist before objects are written to it.
        await processor.create_collection("test", "mixed", {})

        # "Single" object: still an array, just with one item.
        single_obj = ExtractedObject(
            metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
            schema_name="mixed_test",
            values=[{"id": "single-1", "data": "single data"}],
            confidence=0.9,
            source_span="Single object",
        )
        batch_obj = ExtractedObject(
            metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
            schema_name="mixed_test",
            values=[
                {"id": "batch-1", "data": "batch data 1"},
                {"id": "batch-2", "data": "batch data 2"},
            ],
            confidence=0.85,
            source_span="Batch objects",
        )

        for pending in (single_obj, batch_obj):
            message = MagicMock()
            message.value.return_value = pending
            await processor.on_object(message, None, None)

        # 1 insert from the single object + 2 from the batch.
        inserted = [
            recorded for recorded in mock_session.execute.call_args_list
            if "INSERT INTO" in str(recorded)
        ]
        assert len(inserted) == 3

View file

@ -0,0 +1,492 @@
"""
Integration tests for Cassandra Row Storage (Unified Table Implementation)
These tests verify the end-to-end functionality of storing ExtractedObjects
in the unified Cassandra rows table, including table creation, data insertion,
and error handling.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
@pytest.mark.integration
class TestRowsCassandraIntegration:
    """Integration tests for Cassandra row storage with unified table.

    Each test binds the real ``Processor`` methods onto a ``MagicMock``
    processor, runs them against a mocked Cassandra cluster/session, and
    then inspects the statements recorded on ``session.execute``.
    """

    # ------------------------------------------------------------------
    # Helpers shared by the tests below
    # ------------------------------------------------------------------

    @staticmethod
    def _executed(mock_session, *fragments):
        """Return recorded ``execute`` calls whose repr contains every fragment."""
        return [
            call for call in mock_session.execute.call_args_list
            if all(fragment in str(call) for fragment in fragments)
        ]

    @classmethod
    def _rows_inserts(cls, mock_session):
        """Return INSERTs into the ``.rows`` table, excluding ``row_partitions``."""
        return [
            call for call in cls._executed(mock_session, "INSERT INTO", ".rows")
            if "row_partitions" not in str(call)
        ]

    @staticmethod
    def _as_message(obj):
        """Wrap *obj* in a mock message whose ``value()`` returns it."""
        msg = MagicMock()
        msg.value.return_value = obj
        return msg

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_cassandra_session(self):
        """Mock Cassandra session that remembers which keyspaces were created."""
        import re

        session = MagicMock()

        # Keyspace names seen in CREATE KEYSPACE statements so far.
        created_keyspaces = set()
        create_keyspace_re = re.compile(r'CREATE KEYSPACE IF NOT EXISTS (\w+)')

        def execute_mock(query, *args, **kwargs):
            result = MagicMock()
            query_str = str(query)

            # Track keyspace creation
            if "CREATE KEYSPACE" in query_str:
                match = create_keyspace_re.search(query_str)
                if match:
                    created_keyspaces.add(match.group(1))

            # Only keyspace existence checks can report a row; every other
            # statement yields ``one() -> None``.
            # NOTE(review): args[0] is compared directly against keyspace
            # names; if the processor passes a parameter sequence here this
            # never matches -- confirm against the processor's call.
            keyspace_exists = (
                "system_schema.keyspaces" in query_str
                and bool(args)
                and args[0] in created_keyspaces
            )
            result.one.return_value = MagicMock() if keyspace_exists else None
            return result

        session.execute = MagicMock(side_effect=execute_mock)
        return session

    @pytest.fixture
    def mock_cassandra_cluster(self, mock_cassandra_session):
        """Mock Cassandra cluster whose ``connect()`` returns the mock session."""
        cluster = MagicMock()
        cluster.connect.return_value = mock_cassandra_session
        cluster.shutdown = MagicMock()
        return cluster

    @pytest.fixture
    def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
        """Create processor with mocked Cassandra dependencies.

        Returns a ``(processor, cluster, session)`` tuple. The processor is a
        ``MagicMock`` carrying plain state attributes plus the real
        ``Processor`` methods bound to it.
        """
        processor = MagicMock()
        processor.cassandra_host = ["localhost"]
        processor.cassandra_username = None
        processor.cassandra_password = None
        processor.config_key = "schema"
        processor.schemas = {}
        processor.known_keyspaces = set()
        processor.tables_initialized = set()
        processor.registered_partitions = set()
        processor.cluster = None
        processor.session = None

        # Bind actual methods from the new unified table implementation
        real_methods = (
            "connect_cassandra",
            "ensure_keyspace",
            "ensure_tables",
            "sanitize_name",
            "get_index_names",
            "build_index_value",
            "register_partitions",
            "on_schema_config",
            "on_object",
        )
        for method_name in real_methods:
            unbound = getattr(Processor, method_name)
            setattr(processor, method_name, unbound.__get__(processor, Processor))

        processor.collection_exists = MagicMock(return_value=True)

        return processor, mock_cassandra_cluster, mock_cassandra_session

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_end_to_end_object_storage(self, processor_with_mocks):
        """Test complete flow from schema config to object storage"""
        processor, mock_cluster, mock_session = processor_with_mocks

        with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
            # Step 1: Configure schema
            config = {
                "schema": {
                    "customer_records": json.dumps({
                        "name": "customer_records",
                        "description": "Customer information",
                        "fields": [
                            {"name": "customer_id", "type": "string", "primary_key": True},
                            {"name": "name", "type": "string", "required": True},
                            {"name": "email", "type": "string", "indexed": True},
                            {"name": "age", "type": "integer"}
                        ]
                    })
                }
            }
            await processor.on_schema_config(config, version=1)
            assert "customer_records" in processor.schemas

            # Step 2: Process an ExtractedObject
            test_obj = ExtractedObject(
                metadata=Metadata(
                    id="doc-001",
                    user="test_user",
                    collection="import_2024",
                    metadata=[]
                ),
                schema_name="customer_records",
                values=[{
                    "customer_id": "CUST001",
                    "name": "John Doe",
                    "email": "john@example.com",
                    "age": "30"
                }],
                confidence=0.95,
                source_span="Customer: John Doe..."
            )
            await processor.on_object(self._as_message(test_obj), None, None)

            # Verify Cassandra interactions
            assert mock_cluster.connect.called

            # Verify keyspace creation
            keyspace_calls = self._executed(mock_session, "CREATE KEYSPACE")
            assert len(keyspace_calls) == 1
            assert "test_user" in str(keyspace_calls[0])

            # Verify unified table creation (rows table, not per-schema table)
            table_calls = self._executed(mock_session, "CREATE TABLE")
            assert len(table_calls) == 2  # rows table + row_partitions table
            assert any("rows" in str(call) for call in table_calls)
            assert any("row_partitions" in str(call) for call in table_calls)

            # Verify the rows table has correct structure
            rows_table_call = [call for call in table_calls if ".rows" in str(call)][0]
            rows_table_text = str(rows_table_call)
            assert "collection text" in rows_table_text
            assert "schema_name text" in rows_table_text
            assert "index_name text" in rows_table_text
            assert "data map<text, text>" in rows_table_text

            # Should have 2 data inserts: one for customer_id (primary),
            # one for email (indexed)
            assert len(self._rows_inserts(mock_session)) == 2

    @pytest.mark.asyncio
    async def test_multi_schema_handling(self, processor_with_mocks):
        """Test handling multiple schemas stored in unified table"""
        processor, mock_cluster, mock_session = processor_with_mocks

        with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
            # Configure multiple schemas
            config = {
                "schema": {
                    "products": json.dumps({
                        "name": "products",
                        "fields": [
                            {"name": "product_id", "type": "string", "primary_key": True},
                            {"name": "name", "type": "string"},
                            {"name": "price", "type": "float"}
                        ]
                    }),
                    "orders": json.dumps({
                        "name": "orders",
                        "fields": [
                            {"name": "order_id", "type": "string", "primary_key": True},
                            {"name": "customer_id", "type": "string"},
                            {"name": "total", "type": "float"}
                        ]
                    })
                }
            }
            await processor.on_schema_config(config, version=1)
            assert len(processor.schemas) == 2

            # Process objects for different schemas
            product_obj = ExtractedObject(
                metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
                schema_name="products",
                values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
                confidence=0.9,
                source_span="Product..."
            )
            order_obj = ExtractedObject(
                metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
                schema_name="orders",
                values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
                confidence=0.85,
                source_span="Order..."
            )

            # Process both objects
            for obj in (product_obj, order_obj):
                await processor.on_object(self._as_message(obj), None, None)

            # Should only create 2 tables: rows + row_partitions
            # (not per-schema tables)
            table_calls = self._executed(mock_session, "CREATE TABLE")
            assert len(table_calls) == 2

            # Verify data inserts go to unified rows table
            rows_insert_calls = self._rows_inserts(mock_session)
            assert len(rows_insert_calls) > 0

            # Both schemas must have landed in the shared table. The bind
            # values are ordered (collection, schema_name, index_name,
            # index_value, data, source), so position 1 is the schema name.
            stored_schemas = {call[0][1][1] for call in rows_insert_calls}
            assert stored_schemas == {"products", "orders"}

    @pytest.mark.asyncio
    async def test_multi_index_storage(self, processor_with_mocks):
        """Test that rows are stored with multiple indexes"""
        processor, mock_cluster, mock_session = processor_with_mocks

        with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
            # Schema with multiple indexed fields
            processor.schemas["indexed_data"] = RowSchema(
                name="indexed_data",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="category", type="string", size=50, indexed=True),
                    Field(name="status", type="string", size=50, indexed=True),
                    Field(name="description", type="string", size=200)  # Not indexed
                ]
            )

            test_obj = ExtractedObject(
                metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
                schema_name="indexed_data",
                values=[{
                    "id": "123",
                    "category": "electronics",
                    "status": "active",
                    "description": "A product"
                }],
                confidence=0.9,
                source_span="Test"
            )
            await processor.on_object(self._as_message(test_obj), None, None)

            # Should have 3 data inserts (one per indexed field:
            # id, category, status)
            rows_insert_calls = self._rows_inserts(mock_session)
            assert len(rows_insert_calls) == 3

            # Verify different index names were used
            # (index_name is the 3rd bind parameter)
            index_names = {call[0][1][2] for call in rows_insert_calls}
            assert index_names == {"id", "category", "status"}

    @pytest.mark.asyncio
    async def test_authentication_handling(self, processor_with_mocks):
        """Test Cassandra authentication"""
        processor, mock_cluster, mock_session = processor_with_mocks
        processor.cassandra_username = "cassandra_user"
        processor.cassandra_password = "cassandra_pass"

        with patch('trustgraph.storage.rows.cassandra.write.Cluster') as mock_cluster_class:
            with patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') as mock_auth:
                mock_cluster_class.return_value = mock_cluster

                # Trigger connection
                processor.connect_cassandra()

                # Verify authentication was configured
                mock_auth.assert_called_once_with(
                    username="cassandra_user",
                    password="cassandra_pass"
                )
                mock_cluster_class.assert_called_once()
                call_kwargs = mock_cluster_class.call_args[1]
                assert 'auth_provider' in call_kwargs

    @pytest.mark.asyncio
    async def test_batch_object_processing(self, processor_with_mocks):
        """Test processing objects with batched values"""
        processor, mock_cluster, mock_session = processor_with_mocks

        with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
            # Configure schema
            config = {
                "schema": {
                    "batch_customers": json.dumps({
                        "name": "batch_customers",
                        "description": "Customer batch data",
                        "fields": [
                            {"name": "customer_id", "type": "string", "primary_key": True},
                            {"name": "name", "type": "string", "required": True},
                            {"name": "email", "type": "string", "indexed": True}
                        ]
                    })
                }
            }
            await processor.on_schema_config(config, version=1)

            # Process batch object with multiple values
            batch_obj = ExtractedObject(
                metadata=Metadata(
                    id="batch-001",
                    user="test_user",
                    collection="batch_import",
                    metadata=[]
                ),
                schema_name="batch_customers",
                values=[
                    {
                        "customer_id": "CUST001",
                        "name": "John Doe",
                        "email": "john@example.com"
                    },
                    {
                        "customer_id": "CUST002",
                        "name": "Jane Smith",
                        "email": "jane@example.com"
                    },
                    {
                        "customer_id": "CUST003",
                        "name": "Bob Johnson",
                        "email": "bob@example.com"
                    }
                ],
                confidence=0.92,
                source_span="Multiple customers extracted from document"
            )
            await processor.on_object(self._as_message(batch_obj), None, None)

            # Verify unified table creation
            table_calls = self._executed(mock_session, "CREATE TABLE")
            assert len(table_calls) == 2  # rows + row_partitions

            # Each row in batch gets 2 data inserts
            # (customer_id primary + email indexed):
            # 3 rows * 2 indexes = 6 data inserts
            assert len(self._rows_inserts(mock_session)) == 6

    @pytest.mark.asyncio
    async def test_empty_batch_processing(self, processor_with_mocks):
        """Test processing objects with empty values array"""
        processor, mock_cluster, mock_session = processor_with_mocks

        with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
            processor.schemas["empty_test"] = RowSchema(
                name="empty_test",
                fields=[Field(name="id", type="string", size=50, primary=True)]
            )

            # Process empty batch object
            empty_obj = ExtractedObject(
                metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
                schema_name="empty_test",
                values=[],  # Empty batch
                confidence=1.0,
                source_span="No objects found"
            )
            await processor.on_object(self._as_message(empty_obj), None, None)

            # Should not create any data insert statements for empty batch
            # (partition registration may still happen)
            assert len(self._rows_inserts(mock_session)) == 0

    @pytest.mark.asyncio
    async def test_data_stored_as_map(self, processor_with_mocks):
        """Test that data is stored as map<text, text>"""
        processor, mock_cluster, mock_session = processor_with_mocks

        with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
            processor.schemas["map_test"] = RowSchema(
                name="map_test",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="name", type="string", size=100),
                    Field(name="count", type="integer", size=0)
                ]
            )

            test_obj = ExtractedObject(
                metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
                schema_name="map_test",
                values=[{"id": "123", "name": "Test Item", "count": "42"}],
                confidence=0.9,
                source_span="Test"
            )
            await processor.on_object(self._as_message(test_obj), None, None)

            # Verify insert uses map for data
            rows_insert_calls = self._rows_inserts(mock_session)
            assert len(rows_insert_calls) >= 1

            # Check that data is passed as a dict (will be map in Cassandra).
            # Values are: (collection, schema_name, index_name, index_value,
            # data, source), so values[4] should be the data map.
            values = rows_insert_calls[0][0][1]
            data_map = values[4]
            assert isinstance(data_map, dict)
            assert data_map["id"] == "123"
            assert data_map["name"] == "Test Item"
            assert data_map["count"] == "42"

    @pytest.mark.asyncio
    async def test_partition_registration(self, processor_with_mocks):
        """Test that partitions are registered for efficient querying"""
        processor, mock_cluster, mock_session = processor_with_mocks

        with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
            processor.schemas["partition_test"] = RowSchema(
                name="partition_test",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="category", type="string", size=50, indexed=True)
                ]
            )

            test_obj = ExtractedObject(
                metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]),
                schema_name="partition_test",
                values=[{"id": "123", "category": "test"}],
                confidence=0.9,
                source_span="Test"
            )
            await processor.on_object(self._as_message(test_obj), None, None)

            # Verify partition registration:
            # one registration per index (id, category)
            partition_inserts = self._executed(
                mock_session, "INSERT INTO", "row_partitions"
            )
            assert len(partition_inserts) == 2

            # Verify cache was updated
            assert ("my_collection", "partition_test") in processor.registered_partitions

View file

@ -1,5 +1,5 @@
"""
Integration tests for Objects GraphQL Query Service
Integration tests for Rows GraphQL Query Service
These tests verify end-to-end functionality including:
- Real Cassandra database operations
@ -24,8 +24,8 @@ except Exception:
DOCKER_AVAILABLE = False
CassandraContainer = None
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.query.rows.cassandra.service import Processor
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from trustgraph.schema import RowSchema, Field, ExtractedObject, Metadata
@ -390,7 +390,7 @@ class TestObjectsGraphQLQueryIntegration:
processor.connect_cassandra()
# Create mock message
request = ObjectsQueryRequest(
request = RowsQueryRequest(
user="msg_test_user",
collection="msg_test_collection",
query='{ customer_objects { customer_id name } }',
@ -415,7 +415,7 @@ class TestObjectsGraphQLQueryIntegration:
# Verify response structure
sent_response = mock_response_producer.send.call_args[0][0]
assert isinstance(sent_response, ObjectsQueryResponse)
assert isinstance(sent_response, RowsQueryResponse)
# Should have no system error (even if no data)
assert sent_response.error is None

View file

@ -2,7 +2,7 @@
Integration tests for Structured Query Service
These tests verify the end-to-end functionality of the structured query service,
testing orchestration between nlp-query and objects-query services.
testing orchestration between nlp-query and rows-query services.
Following the TEST_STRATEGY.md approach for integration testing.
"""
@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.schema import (
StructuredQueryRequest, StructuredQueryResponse,
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
ObjectsQueryRequest, ObjectsQueryResponse,
RowsQueryRequest, RowsQueryResponse,
Error, GraphQLError
)
from trustgraph.retrieval.structured_query.service import Processor
@ -81,7 +81,7 @@ class TestStructuredQueryServiceIntegration:
)
# Mock Objects Query Service Response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "123", "name": "Alice Johnson", "email": "alice@example.com", "orders": [{"id": "456", "total": 750.0, "date": "2024-01-15"}]}]}',
errors=None,
@ -99,7 +99,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -121,7 +121,7 @@ class TestStructuredQueryServiceIntegration:
# Verify Objects service call
mock_objects_client.request.assert_called_once()
objects_call_args = mock_objects_client.request.call_args[0][0]
assert isinstance(objects_call_args, ObjectsQueryRequest)
assert isinstance(objects_call_args, RowsQueryRequest)
assert "customers" in objects_call_args.query
assert "orders" in objects_call_args.query
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
@ -220,7 +220,7 @@ class TestStructuredQueryServiceIntegration:
)
# Mock Objects service failure
objects_error_response = ObjectsQueryResponse(
objects_error_response = RowsQueryResponse(
error=Error(type="graphql-schema-error", message="Table 'nonexistent_table' does not exist in schema"),
data=None,
errors=None,
@ -237,7 +237,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -255,7 +255,7 @@ class TestStructuredQueryServiceIntegration:
assert response.error is not None
assert response.error.type == "structured-query-error"
assert "Objects query service error" in response.error.message
assert "Rows query service error" in response.error.message
assert "nonexistent_table" in response.error.message
@pytest.mark.asyncio
@ -298,7 +298,7 @@ class TestStructuredQueryServiceIntegration:
)
]
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=None, # No data when validation fails
errors=validation_errors,
@ -315,7 +315,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -422,7 +422,7 @@ class TestStructuredQueryServiceIntegration:
]
}
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=json.dumps(complex_data),
errors=None,
@ -443,7 +443,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -503,7 +503,7 @@ class TestStructuredQueryServiceIntegration:
)
# Mock empty Objects response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": []}', # Empty result set
errors=None,
@ -520,7 +520,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -577,7 +577,7 @@ class TestStructuredQueryServiceIntegration:
confidence=0.9
)
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=f'{{"test_{i}": [{{"id": "{i}"}}]}}',
errors=None,
@ -599,7 +599,7 @@ class TestStructuredQueryServiceIntegration:
if service_name == "nlp-query-request":
service_call_count += 1
return nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
service_call_count += 1
return objects_client
elif service_name == "response":
@ -700,7 +700,7 @@ class TestStructuredQueryServiceIntegration:
)
# Mock Objects response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"orders": [{"id": "123", "total": 125.50, "date": "2024-01-15"}]}',
errors=None,
@ -717,7 +717,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response

View file

@ -0,0 +1,380 @@
"""
Unit tests for trustgraph.embeddings.row_embeddings.embeddings
Tests the Stage 1 processor that computes embeddings for row index fields.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
"""Test row embeddings processor functionality"""
async def test_processor_initialization(self):
    """A freshly built processor has no schemas and the default batch size."""
    from trustgraph.embeddings.row_embeddings.embeddings import Processor

    processor = Processor(taskgroup=AsyncMock(), id='test-row-embeddings')

    assert hasattr(processor, 'schemas')
    assert processor.schemas == {}
    # 10 is the batch size used when none is supplied.
    assert processor.batch_size == 10
async def test_processor_initialization_with_custom_batch_size(self):
    """An explicit ``batch_size`` argument overrides the default."""
    from trustgraph.embeddings.row_embeddings.embeddings import Processor

    processor = Processor(
        taskgroup=AsyncMock(),
        id='test-row-embeddings',
        batch_size=25,
    )

    assert processor.batch_size == 25
async def test_get_index_names_single_index(self):
    """Primary and indexed fields are reported; an unindexed field is not."""
    from trustgraph.embeddings.row_embeddings.embeddings import Processor
    from trustgraph.schema import RowSchema, Field

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    customer_fields = [
        Field(name='id', type='text', primary=True),
        Field(name='name', type='text', indexed=True),
        Field(name='email', type='text', indexed=False),
    ]
    schema = RowSchema(
        name='customers',
        description='Customer records',
        fields=customer_fields,
    )

    reported = processor.get_index_names(schema)

    # Should include primary key and indexed field, but not the plain one.
    assert 'id' in reported
    assert 'name' in reported
    assert 'email' not in reported
async def test_get_index_names_no_indexes(self):
"""Test getting index names when no fields are indexed"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import RowSchema, Field
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
schema = RowSchema(
name='logs',
description='Log records',
fields=[
Field(name='timestamp', type='text'),
Field(name='message', type='text'),
]
)
index_names = processor.get_index_names(schema)
assert index_names == []
async def test_build_index_value_single_field(self):
"""Test building index value for single field"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
value_map = {
'id': 'CUST001',
'name': 'John Doe',
'email': 'john@example.com'
}
result = processor.build_index_value(value_map, 'name')
assert result == ['John Doe']
async def test_build_index_value_composite_index(self):
"""Test building index value for composite index"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
value_map = {
'first_name': 'John',
'last_name': 'Doe',
'city': 'New York'
}
result = processor.build_index_value(value_map, 'first_name, last_name')
assert result == ['John', 'Doe']
async def test_build_index_value_missing_field(self):
"""Test building index value when field is missing"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
value_map = {
'name': 'John Doe'
}
result = processor.build_index_value(value_map, 'missing_field')
assert result == ['']
async def test_build_text_for_embedding_single_value(self):
"""Test building text representation for single value"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
result = processor.build_text_for_embedding(['John Doe'])
assert result == 'John Doe'
async def test_build_text_for_embedding_multiple_values(self):
"""Test building text representation for multiple values"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
result = processor.build_text_for_embedding(['John', 'Doe', 'NYC'])
assert result == 'John Doe NYC'
async def test_on_schema_config_loads_schemas(self):
"""Test that schema configuration is loaded correctly"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
import json
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor',
'config_type': 'schema'
}
processor = Processor(**config)
schema_def = {
'name': 'customers',
'description': 'Customer records',
'fields': [
{'name': 'id', 'type': 'text', 'primary_key': True},
{'name': 'name', 'type': 'text', 'indexed': True},
{'name': 'email', 'type': 'text'}
]
}
config_data = {
'schema': {
'customers': json.dumps(schema_def)
}
}
await processor.on_schema_config(config_data, 1)
assert 'customers' in processor.schemas
assert processor.schemas['customers'].name == 'customers'
assert len(processor.schemas['customers'].fields) == 3
async def test_on_schema_config_handles_missing_type(self):
"""Test that missing schema type is handled gracefully"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor',
'config_type': 'schema'
}
processor = Processor(**config)
config_data = {
'other_type': {}
}
await processor.on_schema_config(config_data, 1)
assert processor.schemas == {}
async def test_on_message_drops_unknown_collection(self):
"""Test that messages for unknown collections are dropped"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import ExtractedObject
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# No collections registered
metadata = MagicMock()
metadata.user = 'unknown_user'
metadata.collection = 'unknown_collection'
metadata.id = 'doc-123'
obj = ExtractedObject(
metadata=metadata,
schema_name='customers',
values=[{'id': '123', 'name': 'Test'}]
)
mock_msg = MagicMock()
mock_msg.value.return_value = obj
mock_flow = MagicMock()
await processor.on_message(mock_msg, MagicMock(), mock_flow)
# Flow should not be called for output
mock_flow.assert_not_called()
async def test_on_message_drops_unknown_schema(self):
"""Test that messages for unknown schemas are dropped"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import ExtractedObject
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
# No schemas registered
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
obj = ExtractedObject(
metadata=metadata,
schema_name='unknown_schema',
values=[{'id': '123', 'name': 'Test'}]
)
mock_msg = MagicMock()
mock_msg.value.return_value = obj
mock_flow = MagicMock()
await processor.on_message(mock_msg, MagicMock(), mock_flow)
# Flow should not be called for output
mock_flow.assert_not_called()
async def test_on_message_processes_embeddings(self):
"""Test processing a message and computing embeddings"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import ExtractedObject, RowSchema, Field
import json
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor',
'config_type': 'schema'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
# Set up schema
processor.schemas['customers'] = RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
obj = ExtractedObject(
metadata=metadata,
schema_name='customers',
values=[
{'id': 'CUST001', 'name': 'John Doe'},
{'id': 'CUST002', 'name': 'Jane Smith'}
]
)
mock_msg = MagicMock()
mock_msg.value.return_value = obj
# Mock the flow
mock_embeddings_request = AsyncMock()
mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]
mock_output = AsyncMock()
def flow_factory(name):
if name == 'embeddings-request':
return mock_embeddings_request
elif name == 'output':
return mock_output
return MagicMock()
mock_flow = MagicMock(side_effect=flow_factory)
await processor.on_message(mock_msg, MagicMock(), mock_flow)
# Should have called embed for each unique text
# 4 values: CUST001, John Doe, CUST002, Jane Smith
assert mock_embeddings_request.embed.call_count == 4
# Should have sent output
mock_output.send.assert_called()
if __name__ == '__main__':
    # pytest.main() returns the exit code instead of exiting the process;
    # propagate it so a direct run of this file reports failures to the
    # caller rather than always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))

View file

@ -1,7 +1,7 @@
"""
Unit tests for objects import dispatcher.
Unit tests for rows import dispatcher.
Tests the business logic of objects import dispatcher
Tests the business logic of rows import dispatcher
while mocking the Publisher and websocket components.
"""
@ -11,7 +11,7 @@ import asyncio
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from aiohttp import web
from trustgraph.gateway.dispatch.objects_import import ObjectsImport
from trustgraph.gateway.dispatch.rows_import import RowsImport
from trustgraph.schema import Metadata, ExtractedObject
@ -92,16 +92,16 @@ def minimal_objects_message():
}
class TestObjectsImportInitialization:
"""Test ObjectsImport initialization."""
class TestRowsImportInitialization:
"""Test RowsImport initialization."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that ObjectsImport creates Publisher with correct parameters."""
"""Test that RowsImport creates Publisher with correct parameters."""
mock_publisher_instance = Mock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -116,28 +116,28 @@ class TestObjectsImportInitialization:
)
# Verify instance variables are set correctly
assert objects_import.ws == mock_websocket
assert objects_import.running == mock_running
assert objects_import.publisher == mock_publisher_instance
assert rows_import.ws == mock_websocket
assert rows_import.running == mock_running
assert rows_import.publisher == mock_publisher_instance
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that ObjectsImport stores all required references."""
objects_import = ObjectsImport(
"""Test that RowsImport stores all required references."""
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="objects-queue"
)
assert objects_import.ws is mock_websocket
assert objects_import.running is mock_running
assert rows_import.ws is mock_websocket
assert rows_import.running is mock_running
class TestObjectsImportLifecycle:
"""Test ObjectsImport lifecycle methods."""
class TestRowsImportLifecycle:
"""Test RowsImport lifecycle methods."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that start() calls publisher.start()."""
@ -145,18 +145,18 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.start = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.start()
await rows_import.start()
mock_publisher_instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that destroy() properly stops publisher and closes websocket."""
@ -164,21 +164,21 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.stop = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.destroy()
await rows_import.destroy()
# Verify sequence of operations
mock_running.stop.assert_called_once()
mock_publisher_instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
"""Test that destroy() handles None websocket gracefully."""
@ -186,7 +186,7 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.stop = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=None, # None websocket
running=mock_running,
backend=mock_backend,
@ -194,16 +194,16 @@ class TestObjectsImportLifecycle:
)
# Should not raise exception
await objects_import.destroy()
await rows_import.destroy()
mock_running.stop.assert_called_once()
mock_publisher_instance.stop.assert_called_once()
class TestObjectsImportMessageProcessing:
"""Test ObjectsImport message processing."""
class TestRowsImportMessageProcessing:
"""Test RowsImport message processing."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
"""Test that receive() processes complete message correctly."""
@ -211,7 +211,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -222,7 +222,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = sample_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -246,7 +246,7 @@ class TestObjectsImportMessageProcessing:
assert sent_object.metadata.collection == "testcollection"
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
"""Test that receive() handles message with minimal required fields."""
@ -254,7 +254,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -265,7 +265,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = minimal_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -279,7 +279,7 @@ class TestObjectsImportMessageProcessing:
assert sent_object.source_span == "" # Default value
assert len(sent_object.metadata.metadata) == 0 # Default empty list
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() uses appropriate default values for optional fields."""
@ -287,7 +287,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -309,7 +309,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = message_data
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Get the sent object and verify defaults
sent_object = mock_publisher_instance.send.call_args[0][1]
@ -317,11 +317,11 @@ class TestObjectsImportMessageProcessing:
assert sent_object.source_span == ""
class TestObjectsImportRunMethod:
"""Test ObjectsImport run method."""
class TestRowsImportRunMethod:
"""Test RowsImport run method."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
@pytest.mark.asyncio
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that run() loops while running.get() returns True."""
@ -331,14 +331,14 @@ class TestObjectsImportRunMethod:
# Set up running state to return True twice, then False
mock_running.get.side_effect = [True, True, False]
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.run()
await rows_import.run()
# Verify sleep was called twice (for the two True iterations)
assert mock_sleep.call_count == 2
@ -348,10 +348,10 @@ class TestObjectsImportRunMethod:
mock_websocket.close.assert_called_once()
# Verify websocket was set to None
assert objects_import.ws is None
assert rows_import.ws is None
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
@pytest.mark.asyncio
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
"""Test that run() handles None websocket gracefully."""
@ -360,7 +360,7 @@ class TestObjectsImportRunMethod:
mock_running.get.return_value = False # Exit immediately
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=None, # None websocket
running=mock_running,
backend=mock_backend,
@ -368,14 +368,14 @@ class TestObjectsImportRunMethod:
)
# Should not raise exception
await objects_import.run()
await rows_import.run()
# Verify websocket remains None
assert objects_import.ws is None
assert rows_import.ws is None
class TestObjectsImportBatchProcessing:
"""Test ObjectsImport batch processing functionality."""
class TestRowsImportBatchProcessing:
"""Test RowsImport batch processing functionality."""
@pytest.fixture
def batch_objects_message(self):
@ -415,7 +415,7 @@ class TestObjectsImportBatchProcessing:
"source_span": "Multiple people found in document"
}
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
"""Test that receive() processes batch message correctly."""
@ -423,7 +423,7 @@ class TestObjectsImportBatchProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -434,7 +434,7 @@ class TestObjectsImportBatchProcessing:
mock_msg = Mock()
mock_msg.json.return_value = batch_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -465,7 +465,7 @@ class TestObjectsImportBatchProcessing:
assert sent_object.confidence == 0.85
assert sent_object.source_span == "Multiple people found in document"
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() handles empty batch correctly."""
@ -473,7 +473,7 @@ class TestObjectsImportBatchProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -494,7 +494,7 @@ class TestObjectsImportBatchProcessing:
mock_msg = Mock()
mock_msg.json.return_value = empty_batch_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Should still send the message
mock_publisher_instance.send.assert_called_once()
@ -502,10 +502,10 @@ class TestObjectsImportBatchProcessing:
assert len(sent_object.values) == 0
class TestObjectsImportErrorHandling:
"""Test error handling in ObjectsImport."""
class TestRowsImportErrorHandling:
"""Test error handling in RowsImport."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
"""Test that receive() propagates publisher send errors."""
@ -513,7 +513,7 @@ class TestObjectsImportErrorHandling:
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -524,15 +524,15 @@ class TestObjectsImportErrorHandling:
mock_msg.json.return_value = sample_objects_message
with pytest.raises(Exception, match="Publisher error"):
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() handles malformed JSON appropriately."""
mock_publisher_class.return_value = Mock()
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -543,4 +543,4 @@ class TestObjectsImportErrorHandling:
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
with pytest.raises(json.JSONDecodeError):
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)

View file

@ -76,7 +76,7 @@ def cities_schema():
def validator():
"""Create a mock processor with just the validation method"""
from unittest.mock import MagicMock
from trustgraph.extract.kg.objects.processor import Processor
from trustgraph.extract.kg.rows.processor import Processor
# Create a mock processor
mock_processor = MagicMock()

View file

@ -167,7 +167,7 @@ class TestFlowClient:
expected_methods = [
'text_completion', 'agent', 'graph_rag', 'document_rag',
'graph_embeddings_query', 'embeddings', 'prompt',
'triples_query', 'objects_query'
'triples_query', 'rows_query'
]
for method in expected_methods:
@ -216,7 +216,7 @@ class TestSocketClient:
expected_methods = [
'agent', 'text_completion', 'graph_rag', 'document_rag',
'prompt', 'graph_embeddings_query', 'embeddings',
'triples_query', 'objects_query', 'mcp_tool'
'triples_query', 'rows_query', 'mcp_tool'
]
for method in expected_methods:
@ -243,7 +243,7 @@ class TestBulkClient:
'import_graph_embeddings',
'import_document_embeddings',
'import_entity_contexts',
'import_objects'
'import_rows'
]
for method in import_methods:

View file

@ -1,10 +1,11 @@
"""
Unit tests for Cassandra Objects GraphQL Query Processor
Unit tests for Cassandra Rows GraphQL Query Processor (Unified Table Implementation)
Tests the business logic of the GraphQL query processor including:
- GraphQL schema generation from RowSchema
- Query execution and validation
- CQL translation logic
- Schema configuration handling
- Query execution using unified rows table
- Name sanitization
- GraphQL query execution
- Message processing logic
"""
@ -12,119 +13,91 @@ import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
import strawberry
from strawberry import Schema
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.query.rows.cassandra.service import Processor
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from trustgraph.schema import RowSchema, Field
class TestObjectsGraphQLQueryLogic:
"""Test business logic without external dependencies"""
def test_get_python_type_mapping(self):
"""Test schema field type conversion to Python types"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
# Basic type mappings
assert processor.get_python_type("string") == str
assert processor.get_python_type("integer") == int
assert processor.get_python_type("float") == float
assert processor.get_python_type("boolean") == bool
assert processor.get_python_type("timestamp") == str
assert processor.get_python_type("date") == str
assert processor.get_python_type("time") == str
assert processor.get_python_type("uuid") == str
# Unknown type defaults to str
assert processor.get_python_type("unknown_type") == str
def test_create_graphql_type_basic_fields(self):
"""Test GraphQL type creation for basic field types"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Create test schema
schema = RowSchema(
name="test_table",
description="Test table",
fields=[
Field(
name="id",
type="string",
primary=True,
required=True,
description="Primary key"
),
Field(
name="name",
type="string",
required=True,
description="Name field"
),
Field(
name="age",
type="integer",
required=False,
description="Optional age"
),
Field(
name="active",
type="boolean",
required=False,
description="Status flag"
)
]
)
# Create GraphQL type
graphql_type = processor.create_graphql_type("test_table", schema)
# Verify type was created
assert graphql_type is not None
assert hasattr(graphql_type, '__name__')
assert "TestTable" in graphql_type.__name__ or "test_table" in graphql_type.__name__.lower()
class TestRowsGraphQLQueryLogic:
"""Test business logic for unified table query implementation"""
def test_sanitize_name_cassandra_compatibility(self):
    """Test name sanitization for Cassandra field names"""
    processor = MagicMock()
    # Bind the real method onto the mock so only sanitize_name is exercised
    processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)

    # Test field name sanitization (uses r_ prefix like storage processor)
    assert processor.sanitize_name("simple_field") == "simple_field"
    assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes"
    assert processor.sanitize_name("field.with.dots") == "field_with_dots"
    # A name starting with a digit is expected to gain the r_ prefix
    assert processor.sanitize_name("123_field") == "r_123_field"
    assert processor.sanitize_name("field with spaces") == "field_with_spaces"
    assert processor.sanitize_name("special!@#chars") == "special___chars"
    assert processor.sanitize_name("UPPERCASE") == "uppercase"
    assert processor.sanitize_name("CamelCase") == "camelcase"
def test_sanitize_table_name(self):
"""Test table name sanitization (always gets o_ prefix)"""
def test_get_index_names(self):
    """Test extraction of index names from schema"""
    processor = MagicMock()
    # Bind the real method onto the mock so only get_index_names is exercised
    processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)

    schema = RowSchema(
        name="test_schema",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="category", type="string", indexed=True),
            Field(name="name", type="string"),  # Not indexed
            Field(name="status", type="string", indexed=True)
        ]
    )

    index_names = processor.get_index_names(schema)

    # The primary key and both indexed fields are expected; the plain
    # field is not
    assert "id" in index_names
    assert "category" in index_names
    assert "status" in index_names
    assert "name" not in index_names
    assert len(index_names) == 3
def test_find_matching_index_exact_match(self):
    """Check index matching for a filter on an indexed vs. a plain field."""
    stub = MagicMock()
    # Attach the real implementations to the stub, in the same order the
    # lookups were originally made
    for method_name in ("get_index_names", "find_matching_index"):
        bound = getattr(Processor, method_name).__get__(stub, Processor)
        setattr(stub, method_name, bound)

    field_specs = [
        Field(name="id", type="string", primary=True),
        Field(name="category", type="string", indexed=True),
        Field(name="name", type="string"),  # Not indexed
    ]
    row_schema = RowSchema(name="test_schema", fields=field_specs)

    # Filter on indexed field should return match
    indexed_match = stub.find_matching_index(
        row_schema, {"category": "electronics"}
    )
    assert indexed_match is not None
    assert indexed_match[0] == "category"
    assert indexed_match[1] == ["electronics"]

    # Filter on non-indexed field should return None
    plain_match = stub.find_matching_index(row_schema, {"name": "test"})
    assert plain_match is None
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configuration"""
processor = MagicMock()
processor.schemas = {}
processor.graphql_types = {}
processor.graphql_schema = None
processor.config_key = "schema" # Set the config key
processor.generate_graphql_schema = AsyncMock()
processor.config_key = "schema"
processor.schema_builder = MagicMock()
processor.schema_builder.clear = MagicMock()
processor.schema_builder.add_schema = MagicMock()
processor.schema_builder.build = MagicMock(return_value=MagicMock())
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
schema_config = {
"schema": {
@ -154,96 +127,29 @@ class TestObjectsGraphQLQueryLogic:
})
}
}
# Process config
await processor.on_schema_config(schema_config, version=1)
# Verify schema was loaded
assert "customer" in processor.schemas
schema = processor.schemas["customer"]
assert schema.name == "customer"
assert len(schema.fields) == 3
# Verify fields
id_field = next(f for f in schema.fields if f.name == "id")
assert id_field.primary is True
# The field should have been created correctly from JSON
# Let's test what we can verify - that the field has the right attributes
assert hasattr(id_field, 'required') # Has the required attribute
assert hasattr(id_field, 'primary') # Has the primary attribute
email_field = next(f for f in schema.fields if f.name == "email")
assert email_field.indexed is True
status_field = next(f for f in schema.fields if f.name == "status")
assert status_field.enum_values == ["active", "inactive"]
# Verify GraphQL schema regeneration was called
processor.generate_graphql_schema.assert_called_once()
def test_cql_query_building_basic(self):
    """Check the CQL text and parameters that query_cassandra hands to the session.

    The real ``sanitize_name``, ``sanitize_table``, ``parse_filter_key`` and
    ``query_cassandra`` implementations are bound onto a ``MagicMock``
    processor; only the Cassandra session and ``connect_cassandra`` are faked.
    """
    # asyncio is part of the standard library, so a plain import is enough;
    # pytest.importorskip is meant for optional dependencies.
    import asyncio

    processor = MagicMock()
    processor.session = MagicMock()
    processor.connect_cassandra = MagicMock()
    for method_name in (
        "sanitize_name",
        "sanitize_table",
        "parse_filter_key",
        "query_cassandra",
    ):
        unbound = getattr(Processor, method_name)
        setattr(processor, method_name, unbound.__get__(processor, Processor))

    # The session returns no rows; only the issued query is of interest here.
    processor.session.execute.return_value = []

    schema = RowSchema(
        name="test_table",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="name", type="string", indexed=True),
            Field(name="status", type="string"),
        ],
    )

    async def run_query():
        # "invalid_filter" names no field of the schema above.
        await processor.query_cassandra(
            user="test_user",
            collection="test_collection",
            schema_name="test_table",
            row_schema=schema,
            filters={"name": "John", "invalid_filter": "ignored"},
            limit=10,
        )

    # asyncio.run creates a fresh loop, runs the coroutine and closes the loop
    # again, so no closed loop is left installed as the current event loop
    # for whatever test runs next.
    asyncio.run(run_query())

    # Exactly one connection attempt and one statement are expected.
    processor.connect_cassandra.assert_called_once()
    processor.session.execute.assert_called_once()

    # execute() is called positionally: (query_text, parameters).
    positional = processor.session.execute.call_args[0]
    query = positional[0]
    params = positional[1]

    # Basic query structure checks
    assert "SELECT * FROM test_user.o_test_table" in query
    assert "WHERE" in query
    assert "collection = %s" in query
    assert "LIMIT 10" in query

    # Parameters should include the collection and the name filter value
    assert "test_collection" in params
    assert "John" in params
# Verify schema builder was called
processor.schema_builder.add_schema.assert_called_once()
processor.schema_builder.build.assert_called_once()
@pytest.mark.asyncio
async def test_graphql_context_handling(self):
@ -251,13 +157,13 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.graphql_schema = AsyncMock()
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Mock schema execution
mock_result = MagicMock()
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
mock_result.errors = None
processor.graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
query='{ customers { id name } }',
variables={},
@ -265,17 +171,17 @@ class TestObjectsGraphQLQueryLogic:
user="test_user",
collection="test_collection"
)
# Verify schema.execute was called with correct context
processor.graphql_schema.execute.assert_called_once()
call_args = processor.graphql_schema.execute.call_args
# Verify context was passed
context = call_args[1]['context_value'] # keyword argument
context = call_args[1]['context_value']
assert context["processor"] == processor
assert context["user"] == "test_user"
assert context["collection"] == "test_collection"
# Verify result structure
assert "data" in result
assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]}
@ -286,104 +192,79 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.graphql_schema = AsyncMock()
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Create a simple object to simulate GraphQL error instead of MagicMock
# Create a simple object to simulate GraphQL error
class MockError:
def __init__(self, message, path, extensions):
self.message = message
self.path = path
self.extensions = extensions
def __str__(self):
return self.message
mock_error = MockError(
message="Field 'invalid_field' doesn't exist",
path=["customers", "0", "invalid_field"],
extensions={"code": "FIELD_NOT_FOUND"}
)
mock_result = MagicMock()
mock_result.data = None
mock_result.errors = [mock_error]
processor.graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
query='{ customers { invalid_field } }',
variables={},
operation_name=None,
user="test_user",
user="test_user",
collection="test_collection"
)
# Verify error handling
assert "errors" in result
assert len(result["errors"]) == 1
error = result["errors"][0]
assert error["message"] == "Field 'invalid_field' doesn't exist"
assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path
assert error["path"] == ["customers", "0", "invalid_field"]
assert error["extensions"] == {"code": "FIELD_NOT_FOUND"}
def test_schema_generation_basic_structure(self):
    """Create a single GraphQL type from a two-field schema and register it"""
    customer_schema = RowSchema(
        name="customer",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="name", type="string"),
        ],
    )

    processor = MagicMock()
    processor.schemas = {"customer": customer_schema}
    processor.graphql_types = {}
    for method_name in ("get_python_type", "create_graphql_type"):
        unbound = getattr(Processor, method_name)
        setattr(processor, method_name, unbound.__get__(processor, Processor))

    # Only the per-type builder is exercised; the original author notes that
    # full schema generation has annotation issues under test.
    processor.graphql_types["customer"] = processor.create_graphql_type(
        "customer", processor.schemas["customer"]
    )

    # Exactly one type is registered, under the schema name, and it is not None
    registered = processor.graphql_types
    assert len(registered) == 1
    assert "customer" in registered
    assert registered["customer"] is not None
@pytest.mark.asyncio
async def test_message_processing_success(self):
"""Test successful message processing flow"""
processor = MagicMock()
processor.execute_graphql_query = AsyncMock()
processor.on_message = Processor.on_message.__get__(processor, Processor)
# Mock successful query result
processor.execute_graphql_query.return_value = {
"data": {"customers": [{"id": "1", "name": "John"}]},
"errors": [],
"extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String())
"extensions": {}
}
# Create mock message
mock_msg = MagicMock()
mock_request = ObjectsQueryRequest(
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
collection="test_collection",
query='{ customers { id name } }',
variables={},
operation_name=None
)
mock_msg.value.return_value = mock_request
mock_msg.properties.return_value = {"id": "test-123"}
# Mock flow
mock_flow = MagicMock()
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
# Process message
await processor.on_message(mock_msg, None, mock_flow)
# Verify query was executed
processor.execute_graphql_query.assert_called_once_with(
query='{ customers { id name } }',
@ -392,13 +273,13 @@ class TestObjectsGraphQLQueryLogic:
user="test_user",
collection="test_collection"
)
# Verify response was sent
mock_response_flow.send.assert_called_once()
response_call = mock_response_flow.send.call_args[0][0]
# Verify response structure
assert isinstance(response_call, ObjectsQueryResponse)
assert isinstance(response_call, RowsQueryResponse)
assert response_call.error is None
assert '"customers"' in response_call.data # JSON encoded
assert len(response_call.errors) == 0
@ -409,13 +290,13 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.execute_graphql_query = AsyncMock()
processor.on_message = Processor.on_message.__get__(processor, Processor)
# Mock query execution error
processor.execute_graphql_query.side_effect = RuntimeError("No schema available")
# Create mock message
mock_msg = MagicMock()
mock_request = ObjectsQueryRequest(
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ invalid_query }',
@ -424,67 +305,225 @@ class TestObjectsGraphQLQueryLogic:
)
mock_msg.value.return_value = mock_request
mock_msg.properties.return_value = {"id": "test-456"}
# Mock flow
mock_flow = MagicMock()
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
# Process message
await processor.on_message(mock_msg, None, mock_flow)
# Verify error response was sent
mock_response_flow.send.assert_called_once()
response_call = mock_response_flow.send.call_args[0][0]
# Verify error response structure
assert isinstance(response_call, ObjectsQueryResponse)
assert isinstance(response_call, RowsQueryResponse)
assert response_call.error is not None
assert response_call.error.type == "objects-query-error"
assert response_call.error.type == "rows-query-error"
assert "No schema available" in response_call.error.message
assert response_call.data is None
class TestCQLQueryGeneration:
"""Test CQL query generation logic in isolation"""
def test_partition_key_inclusion(self):
"""Test that collection is always included in queries"""
class TestUnifiedTableQueries:
"""Test queries against the unified rows table"""
@pytest.mark.asyncio
async def test_query_with_index_match(self):
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
# Mock the query building (simplified version)
keyspace = processor.sanitize_name("test_user")
table = processor.sanitize_table("test_table")
query = f"SELECT * FROM {keyspace}.{table}"
where_clauses = ["collection = %s"]
assert "collection = %s" in where_clauses
assert keyspace == "test_user"
assert table == "o_test_table"
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to return test data
mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
processor.session.execute.return_value = [mock_row]
schema = RowSchema(
name="products",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="name", type="string")
]
)
# Query with filter on indexed field
results = await processor.query_cassandra(
user="test_user",
collection="test_collection",
schema_name="products",
row_schema=schema,
filters={"category": "electronics"},
limit=10
)
# Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once()
processor.session.execute.assert_called_once()
# Verify query structure - should query unified rows table
call_args = processor.session.execute.call_args
query = call_args[0][0]
params = call_args[0][1]
assert "SELECT data, source FROM test_user.rows" in query
assert "collection = %s" in query
assert "schema_name = %s" in query
assert "index_name = %s" in query
assert "index_value = %s" in query
assert params[0] == "test_collection"
assert params[1] == "products"
assert params[2] == "category"
assert params[3] == ["electronics"]
# Verify results
assert len(results) == 1
assert results[0]["id"] == "123"
assert results[0]["category"] == "electronics"
@pytest.mark.asyncio
async def test_query_without_index_match(self):
    """Filter on a non-indexed field: expect a scan query plus post-filtering"""
    processor = MagicMock()
    processor.session = MagicMock()
    processor.connect_cassandra = MagicMock()
    for method_name in (
        "sanitize_name",
        "get_index_names",
        "find_matching_index",
        "_matches_filters",
        "query_cassandra",
    ):
        unbound = getattr(Processor, method_name)
        setattr(processor, method_name, unbound.__get__(processor, Processor))

    # Two stored rows; each mock exposes its payload through a .data attribute
    stored_rows = []
    for payload in (
        {"id": "1", "name": "Product A", "price": "100"},
        {"id": "2", "name": "Product B", "price": "200"},
    ):
        row = MagicMock()
        row.data = payload
        stored_rows.append(row)
    processor.session.execute.return_value = stored_rows

    # Only the primary key is flagged; "name" and "price" carry no index
    schema = RowSchema(
        name="products",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="name", type="string"),
            Field(name="price", type="string"),
        ],
    )

    results = await processor.query_cassandra(
        user="test_user",
        collection="test_collection",
        schema_name="products",
        row_schema=schema,
        filters={"name": "Product A"},
        limit=10,
    )

    # The statement handed to the session must be a scan
    query = processor.session.execute.call_args[0][0]
    assert "ALLOW FILTERING" in query

    # Of the two rows returned by the session, only the matching one survives
    assert len(results) == 1
    assert results[0]["name"] == "Product A"
class TestFilterMatching:
    """Exercise Processor._matches_filters against plain dict rows"""

    @staticmethod
    def _make_processor():
        """Return a MagicMock with the real _matches_filters bound onto it"""
        processor = MagicMock()
        processor._matches_filters = Processor._matches_filters.__get__(
            processor, Processor
        )
        return processor

    def test_matches_filters_exact_match(self):
        """A bare field name as filter key means equality"""
        processor = self._make_processor()
        schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
        row = {"status": "active", "name": "test"}

        cases = (
            ({"status": "active"}, True),
            ({"status": "inactive"}, False),
        )
        for filters, expected in cases:
            assert processor._matches_filters(row, filters, schema) is expected

    def test_matches_filters_comparison_operators(self):
        """_gt / _lt / _gte / _lte suffixes on a float field stored as a string"""
        processor = self._make_processor()
        schema = RowSchema(name="test", fields=[Field(name="price", type="float")])
        row = {"price": "100.0"}

        cases = (
            # Greater than
            ({"price_gt": 50}, True),
            ({"price_gt": 150}, False),
            # Less than
            ({"price_lt": 150}, True),
            ({"price_lt": 50}, False),
            # Greater than or equal
            ({"price_gte": 100}, True),
            ({"price_gte": 101}, False),
            # Less than or equal
            ({"price_lte": 100}, True),
            ({"price_lte": 99}, False),
        )
        for filters, expected in cases:
            assert processor._matches_filters(row, filters, schema) is expected

    def test_matches_filters_contains(self):
        """The _contains suffix matches a substring of the field value"""
        processor = self._make_processor()
        schema = RowSchema(
            name="test", fields=[Field(name="description", type="string")]
        )
        row = {"description": "A great product for everyone"}

        cases = (
            ({"description_contains": "great"}, True),
            ({"description_contains": "terrible"}, False),
        )
        for filters, expected in cases:
            assert processor._matches_filters(row, filters, schema) is expected

    def test_matches_filters_in_list(self):
        """The _in suffix matches when the field value is one of the listed values"""
        processor = self._make_processor()
        schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
        row = {"status": "active"}

        cases = (
            ({"status_in": ["active", "pending"]}, True),
            ({"status_in": ["inactive", "deleted"]}, False),
        )
        for filters, expected in cases:
            assert processor._matches_filters(row, filters, schema) is expected
class TestIndexedFieldFiltering:
"""Test that only indexed or primary key fields can be directly filtered"""
def test_indexed_field_filtering(self):
"""Test that only indexed or primary key fields can be filtered"""
# Create schema with mixed field types
schema = RowSchema(
name="test",
fields=[
Field(name="id", type="string", primary=True),
Field(name="indexed_field", type="string", indexed=True),
Field(name="indexed_field", type="string", indexed=True),
Field(name="normal_field", type="string", indexed=False),
Field(name="another_field", type="string")
]
)
filters = {
"id": "test123", # Primary key - should be included
"indexed_field": "value", # Indexed - should be included
"normal_field": "ignored", # Not indexed - should be ignored
"another_field": "also_ignored" # Not indexed - should be ignored
}
# Simulate the filtering logic from the processor
valid_filters = []
for field_name, value in filters.items():
@ -492,7 +531,7 @@ class TestCQLQueryGeneration:
schema_field = next((f for f in schema.fields if f.name == field_name), None)
if schema_field and (schema_field.indexed or schema_field.primary):
valid_filters.append((field_name, value))
# Only id and indexed_field should be included
assert len(valid_filters) == 2
field_names = [f[0] for f in valid_filters]
@ -500,52 +539,3 @@ class TestCQLQueryGeneration:
assert "indexed_field" in field_names
assert "normal_field" not in field_names
assert "another_field" not in field_names
class TestGraphQLSchemaGeneration:
    """Smoke tests for Processor.create_graphql_type"""

    def test_field_type_annotations(self):
        """A schema mixing string, integer, float and boolean fields yields a type"""
        processor = MagicMock()
        for method_name in ("get_python_type", "create_graphql_type"):
            unbound = getattr(Processor, method_name)
            setattr(processor, method_name, unbound.__get__(processor, Processor))

        # Required and optional fields of every type used by this test
        fields = [
            Field(name="id", type="string", required=True, primary=True),
            Field(name="count", type="integer", required=True),
            Field(name="price", type="float", required=False),
            Field(name="active", type="boolean", required=False),
            Field(name="optional_text", type="string", required=False),
        ]
        schema = RowSchema(name="test", fields=fields)

        graphql_type = processor.create_graphql_type("test", schema)

        # Only successful creation is checked, not the generated annotations
        assert graphql_type is not None

    def test_basic_type_creation(self):
        """A one-field schema can be turned into a type and stored by name"""
        customer_schema = RowSchema(
            name="customer",
            fields=[Field(name="id", type="string", primary=True)],
        )

        processor = MagicMock()
        processor.schemas = {"customer": customer_schema}
        processor.graphql_types = {}
        for method_name in ("get_python_type", "create_graphql_type"):
            unbound = getattr(Processor, method_name)
            setattr(processor, method_name, unbound.__get__(processor, Processor))

        # Build the type directly and register it by hand
        processor.graphql_types["customer"] = processor.create_graphql_type(
            "customer", processor.schemas["customer"]
        )

        assert "customer" in processor.graphql_types
        assert processor.graphql_types["customer"] is not None

View file

@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
StructuredQueryRequest, StructuredQueryResponse,
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
ObjectsQueryRequest, ObjectsQueryResponse,
RowsQueryRequest, RowsQueryResponse,
Error, GraphQLError
)
from trustgraph.retrieval.structured_query.service import Processor
@ -68,7 +68,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects query service response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
errors=None,
@ -86,7 +86,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -108,7 +108,7 @@ class TestStructuredQueryProcessor:
# Verify objects query service was called correctly
mock_objects_client.request.assert_called_once()
objects_call_args = mock_objects_client.request.call_args[0][0]
assert isinstance(objects_call_args, ObjectsQueryRequest)
assert isinstance(objects_call_args, RowsQueryRequest)
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
assert objects_call_args.variables == {"state": "NY"}
assert objects_call_args.user == "trustgraph"
@ -224,7 +224,7 @@ class TestStructuredQueryProcessor:
assert response.error is not None
assert "empty GraphQL query" in response.error.message
async def test_objects_query_service_error(self, processor):
async def test_rows_query_service_error(self, processor):
"""Test handling of objects query service errors"""
# Arrange
request = StructuredQueryRequest(
@ -250,7 +250,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects query service error
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=Error(type="graphql-execution-error", message="Table 'customers' not found"),
data=None,
errors=None,
@ -267,7 +267,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -284,7 +284,7 @@ class TestStructuredQueryProcessor:
response = response_call[0][0]
assert response.error is not None
assert "Objects query service error" in response.error.message
assert "Rows query service error" in response.error.message
assert "Table 'customers' not found" in response.error.message
async def test_graphql_errors_handling(self, processor):
@ -321,7 +321,7 @@ class TestStructuredQueryProcessor:
)
]
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=None,
errors=graphql_errors,
@ -338,7 +338,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -400,7 +400,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "Alice", "orders": [{"id": "100", "total": 150.0}]}]}',
errors=None
@ -416,7 +416,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -464,7 +464,7 @@ class TestStructuredQueryProcessor:
confidence=0.9
)
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=None, # Null data
errors=None,
@ -481,7 +481,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response

View file

@ -10,7 +10,7 @@ import pytest
from unittest.mock import Mock, patch, MagicMock
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
from trustgraph.storage.knowledge.store import Processor as KgStore
@ -81,10 +81,10 @@ class TestTriplesWriterConfiguration:
assert processor.cassandra_password is None
class TestObjectsWriterConfiguration:
class TestRowsWriterConfiguration:
"""Test Cassandra configuration in objects writer processor."""
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_environment_variable_configuration(self, mock_cluster):
"""Test processor picks up configuration from environment variables."""
env_vars = {
@ -97,13 +97,13 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
assert processor.cassandra_host == ['obj-env-host1', 'obj-env-host2']
assert processor.cassandra_username == 'obj-env-user'
assert processor.cassandra_password == 'obj-env-pass'
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_cassandra_connection_with_hosts_list(self, mock_cluster):
"""Test that Cassandra connection uses hosts list correctly."""
env_vars = {
@ -118,7 +118,7 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify cluster was called with hosts list
@ -129,8 +129,8 @@ class TestObjectsWriterConfiguration:
assert 'contact_points' in call_args.kwargs
assert call_args.kwargs['contact_points'] == ['conn-host1', 'conn-host2', 'conn-host3']
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_authentication_configuration(self, mock_auth_provider, mock_cluster):
"""Test authentication is configured when credentials are provided."""
env_vars = {
@ -145,7 +145,7 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify auth provider was created with correct credentials
@ -302,10 +302,10 @@ class TestCommandLineArgumentHandling:
def test_objects_writer_add_args(self):
"""Test that objects writer adds standard Cassandra arguments."""
import argparse
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
parser = argparse.ArgumentParser()
ObjectsWriter.add_args(parser)
RowsWriter.add_args(parser)
# Parse empty args to check that arguments exist
args = parser.parse_args([])

View file

@ -1,533 +0,0 @@
"""
Unit tests for Cassandra Object Storage Processor
Tests the business logic of the object storage processor including:
- Schema configuration handling
- Type conversions
- Name sanitization
- Table structure generation
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class TestObjectsCassandraStorageLogic:
"""Test business logic without FlowProcessor dependencies"""
def test_sanitize_name(self):
    """sanitize_name maps arbitrary names to the expected Cassandra-safe form"""
    processor = MagicMock()
    processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)

    # raw name -> expected sanitized name
    expectations = (
        ("simple_name", "simple_name"),
        ("Name-With-Dashes", "name_with_dashes"),
        ("name.with.dots", "name_with_dots"),
        ("123_starts_with_number", "o_123_starts_with_number"),
        ("name with spaces", "name_with_spaces"),
        ("special!@#$%^chars", "special______chars"),
    )
    for raw, expected in expectations:
        assert processor.sanitize_name(raw) == expected
def test_get_cassandra_type(self):
    """get_cassandra_type maps schema field types (and size hints) to CQL types"""
    processor = MagicMock()
    processor.get_cassandra_type = Processor.get_cassandra_type.__get__(
        processor, Processor
    )
    to_cql = processor.get_cassandra_type

    # Types looked up without a size hint; the last entry checks the
    # fallback for an unrecognised type.
    unsized = (
        ("string", "text"),
        ("boolean", "boolean"),
        ("timestamp", "timestamp"),
        ("uuid", "uuid"),
        ("unknown_type", "text"),
    )
    for field_type, expected in unsized:
        assert to_cql(field_type) == expected

    # Integer and float types widen with the size hint
    sized = (
        ("integer", 2, "int"),
        ("integer", 8, "bigint"),
        ("float", 2, "float"),
        ("float", 8, "double"),
    )
    for field_type, size, expected in sized:
        assert to_cql(field_type, size=size) == expected
def test_convert_value(self):
    """convert_value coerces raw values according to the declared field type"""
    processor = MagicMock()
    processor.convert_value = Processor.convert_value.__get__(processor, Processor)
    convert = processor.convert_value

    # Integer conversions; None passes through unchanged
    assert convert("123", "integer") == 123
    assert convert(123.5, "integer") == 123
    assert convert(None, "integer") is None

    # Float conversions
    assert convert("123.45", "float") == 123.45
    assert convert(123, "float") == 123.0

    # Boolean conversions must yield the actual True / False singletons
    boolean_cases = (
        ("true", True),
        ("false", False),
        ("1", True),
        ("0", False),
        ("yes", True),
        ("no", False),
    )
    for raw, expected in boolean_cases:
        assert convert(raw, "boolean") is expected

    # String conversions
    assert convert(123, "string") == "123"
    assert convert(True, "string") == "True"
def test_table_creation_cql_generation(self):
    """Inspect the CREATE TABLE statement ensure_table issues for a keyed schema"""
    processor = MagicMock()
    processor.schemas = {}
    processor.known_keyspaces = set()
    processor.known_tables = {}
    processor.session = MagicMock()
    for method_name in ("sanitize_name", "sanitize_table", "get_cassandra_type"):
        unbound = getattr(Processor, method_name)
        setattr(processor, method_name, unbound.__get__(processor, Processor))

    def record_keyspace(keyspace):
        # Stand-in for ensure_keyspace: only updates the bookkeeping sets
        processor.known_keyspaces.add(keyspace)
        processor.known_tables[keyspace] = set()

    processor.ensure_keyspace = record_keyspace
    processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)

    # Schema under test: a primary key, an indexed column and a plain column
    customer_id = Field(
        name="customer_id",
        type="string",
        size=50,
        primary=True,
        required=True,
        indexed=False,
    )
    email = Field(
        name="email",
        type="string",
        size=100,
        required=True,
        indexed=True,
    )
    age = Field(
        name="age",
        type="integer",
        size=4,
        required=False,
        indexed=False,
    )
    schema = RowSchema(
        name="customer_records",
        description="Test customer schema",
        fields=[customer_id, email, age],
    )

    processor.ensure_table("test_user", "customer_records", schema)

    # The keyspace stand-in must have been invoked for the user keyspace
    assert "test_user" in processor.known_keyspaces

    # The first statement sent to the session is the table creation
    first_call = processor.session.execute.call_args_list[0]
    table_creation_cql = first_call[0][0]

    # The table name carries the o_ prefix; the keyspace name does not
    expected_fragments = (
        "CREATE TABLE IF NOT EXISTS test_user.o_customer_records",
        "collection text",
        "customer_id text",
        "email text",
        "age int",
        "PRIMARY KEY ((collection, customer_id))",
    )
    for fragment in expected_fragments:
        assert fragment in table_creation_cql
def test_table_creation_without_primary_key(self):
    """A schema with no primary field gets a synthetic_id key column"""
    processor = MagicMock()
    processor.schemas = {}
    processor.known_keyspaces = set()
    processor.known_tables = {}
    processor.session = MagicMock()
    for method_name in ("sanitize_name", "sanitize_table", "get_cassandra_type"):
        unbound = getattr(Processor, method_name)
        setattr(processor, method_name, unbound.__get__(processor, Processor))

    def record_keyspace(keyspace):
        # Stand-in for ensure_keyspace: only updates the bookkeeping sets
        processor.known_keyspaces.add(keyspace)
        processor.known_tables[keyspace] = set()

    processor.ensure_keyspace = record_keyspace
    processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)

    # Neither field is flagged primary
    schema = RowSchema(
        name="events",
        description="Event log",
        fields=[
            Field(name="event_type", type="string", size=50),
            Field(name="timestamp", type="timestamp", size=0),
        ],
    )

    processor.ensure_table("test_user", "events", schema)

    # The most recent statement must declare and key on synthetic_id
    executed_cql = processor.session.execute.call_args[0][0]
    for fragment in (
        "synthetic_id uuid",
        "PRIMARY KEY ((collection, synthetic_id))",
    ):
        assert fragment in executed_cql
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
    """on_schema_config turns a JSON schema definition into a stored schema"""
    processor = MagicMock()
    processor.schemas = {}
    processor.config_key = "schema"
    processor.on_schema_config = Processor.on_schema_config.__get__(
        processor, Processor
    )

    # The definition is delivered as a JSON string under the "schema" key
    schema_definition = {
        "name": "customer_records",
        "description": "Customer data",
        "fields": [
            {
                "name": "id",
                "type": "string",
                "primary_key": True,
                "required": True,
            },
            {
                "name": "name",
                "type": "string",
                "required": True,
            },
            {
                "name": "balance",
                "type": "float",
                "size": 8,
            },
        ],
    }
    config = {"schema": {"customer_records": json.dumps(schema_definition)}}

    await processor.on_schema_config(config, version=1)

    # The schema is registered under its name with all three fields
    assert "customer_records" in processor.schemas
    schema = processor.schemas["customer_records"]
    assert schema.name == "customer_records"
    assert len(schema.fields) == 3

    # The first field is the primary key declared via "primary_key" above
    id_field = schema.fields[0]
    assert id_field.name == "id"
    assert id_field.type == "string"
    assert id_field.primary is True
    # "required" is deliberately not asserted: per the original author's note,
    # Field.required always returns False due to Pulsar schema limitations,
    # and the actual required value is tracked during schema parsing.
@pytest.mark.asyncio
async def test_object_processing_logic(self):
    """on_object ensures the table and inserts one converted row."""
    proc = MagicMock()
    proc.session = MagicMock()
    proc.ensure_table = MagicMock()
    # Pre-populated so the keyspace/table validation path is skipped.
    proc.known_keyspaces = {"test_user"}
    proc.known_tables = {"test_user": set()}
    proc.schemas = {
        "test_schema": RowSchema(
            name="test_schema",
            description="Test",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="value", type="integer", size=4),
            ],
        )
    }

    # Attach the real Processor implementations to the mock instance.
    for method_name in (
        "sanitize_name",
        "sanitize_table",
        "convert_value",
        "on_object",
    ):
        bound = getattr(Processor, method_name).__get__(proc, Processor)
        setattr(proc, method_name, bound)

    # Single-row object; "value" arrives as a string and must become an int.
    extracted = ExtractedObject(
        metadata=Metadata(
            id="test-001",
            user="test_user",
            collection="test_collection",
            metadata=[],
        ),
        schema_name="test_schema",
        values=[{"id": "123", "value": "456"}],
        confidence=0.9,
        source_span="test source",
    )

    message = MagicMock()
    message.value.return_value = extracted

    await proc.on_object(message, None, None)

    # The table is ensured exactly once for this user/schema pair.
    proc.ensure_table.assert_called_once_with(
        "test_user", "test_schema", proc.schemas["test_schema"]
    )

    # Exactly one INSERT, targeting the o_-prefixed table name.
    proc.session.execute.assert_called_once()
    positional = proc.session.execute.call_args[0]
    cql, row = positional[0], positional[1]
    assert "INSERT INTO test_user.o_test_schema" in cql
    assert "collection" in cql
    assert row[0] == "test_collection"  # collection value
    assert row[1] == "123"  # id value
    assert row[2] == 456  # integer converted from "456"
def test_secondary_index_creation(self):
    """ensure_table issues one CREATE INDEX per field flagged as indexed."""
    proc = MagicMock()
    proc.session = MagicMock()
    proc.schemas = {}
    # Pre-populated so the keyspace validation path is skipped.
    proc.known_keyspaces = {"test_user"}
    proc.known_tables = {"test_user": set()}

    # Attach the real Processor implementations to the mock instance.
    for method_name in (
        "sanitize_name",
        "sanitize_table",
        "get_cassandra_type",
        "ensure_table",
    ):
        bound = getattr(Processor, method_name).__get__(proc, Processor)
        setattr(proc, method_name, bound)

    def fake_ensure_keyspace(keyspace):
        # Mark the keyspace as known, keeping any table set already present.
        proc.known_keyspaces.add(keyspace)
        proc.known_tables.setdefault(keyspace, set())

    proc.ensure_keyspace = fake_ensure_keyspace

    # One primary key plus two indexed fields.
    catalog_schema = RowSchema(
        name="products",
        description="Product catalog",
        fields=[
            Field(name="product_id", type="string", size=50, primary=True),
            Field(name="category", type="string", size=30, indexed=True),
            Field(name="price", type="float", size=8, indexed=True),
        ],
    )

    proc.ensure_table("test_user", "products", catalog_schema)

    # Expect three statements: the table creation and two index creations.
    assert proc.session.execute.call_count == 3

    # Index names use the o_-prefixed table name and the bare field name.
    statements = [c[0][0] for c in proc.session.execute.call_args_list]
    index_statements = [s for s in statements if "CREATE INDEX" in s]
    assert len(index_statements) == 2
    assert any("o_products_category_idx" in s for s in index_statements)
    assert any("o_products_price_idx" in s for s in index_statements)
class TestObjectsCassandraStorageBatchLogic:
"""Test batch processing logic in Cassandra storage"""
@pytest.mark.asyncio
async def test_batch_object_processing_logic(self):
    """Test processing of batch ExtractedObjects.

    A three-row batch must ensure the table once and issue one INSERT per
    row, with every column of every row checked against its expected value.
    """
    processor = MagicMock()
    processor.schemas = {
        "batch_schema": RowSchema(
            name="batch_schema",
            description="Test batch schema",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="name", type="string", size=100),
                Field(name="value", type="integer", size=4)
            ]
        )
    }
    processor.known_keyspaces = {"test_user"}  # Pre-populate to skip validation query
    processor.ensure_table = MagicMock()
    processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
    processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
    processor.convert_value = Processor.convert_value.__get__(processor, Processor)
    processor.session = MagicMock()
    processor.on_object = Processor.on_object.__get__(processor, Processor)

    # Create batch object with multiple values
    batch_obj = ExtractedObject(
        metadata=Metadata(
            id="batch-001",
            user="test_user",
            collection="batch_collection",
            metadata=[]
        ),
        schema_name="batch_schema",
        values=[
            {"id": "001", "name": "First", "value": "100"},
            {"id": "002", "name": "Second", "value": "200"},
            {"id": "003", "name": "Third", "value": "300"}
        ],
        confidence=0.95,
        source_span="batch source"
    )

    # Create mock message
    msg = MagicMock()
    msg.value.return_value = batch_obj

    # Process batch object
    await processor.on_object(msg, None, None)

    # Verify table was ensured once
    processor.ensure_table.assert_called_once_with(
        "test_user", "batch_schema", processor.schemas["batch_schema"]
    )

    # Verify 3 separate insert calls (one per batch item)
    assert processor.session.execute.call_count == 3

    # Expected (id, name, value) per batch item, in submission order.
    # The previous name check was written as
    #   assert values[2] == "First" if i == 0 else "Second" if ...
    # which parses as `(values[2] == "First") if i == 0 else "Second"...`,
    # so for every row after the first it asserted a non-empty string and
    # could never fail. Comparing against an explicit table fixes that.
    expected_rows = [
        ("001", "First", 100),
        ("002", "Second", 200),
        ("003", "Third", 300),
    ]

    # Check each insert call
    calls = processor.session.execute.call_args_list
    for call, (expected_id, expected_name, expected_value) in zip(calls, expected_rows):
        insert_cql = call[0][0]
        values = call[0][1]

        assert "INSERT INTO test_user.o_batch_schema" in insert_cql
        assert "collection" in insert_cql

        # Check values for each batch item
        assert values[0] == "batch_collection"  # collection
        assert values[1] == expected_id  # id from batch item
        assert values[2] == expected_name  # name from batch item
        assert values[3] == expected_value  # converted integer value
@pytest.mark.asyncio
async def test_empty_batch_processing_logic(self):
    """An object with no values ensures the table but inserts nothing."""
    proc = MagicMock()
    proc.session = MagicMock()
    proc.ensure_table = MagicMock()
    # Pre-populated so the keyspace/table validation path is skipped.
    proc.known_keyspaces = {"test_user"}
    proc.known_tables = {"test_user": set()}
    proc.schemas = {
        "empty_schema": RowSchema(
            name="empty_schema",
            fields=[Field(name="id", type="string", size=50, primary=True)],
        )
    }

    # Attach the real Processor implementations to the mock instance.
    for method_name in (
        "sanitize_name",
        "sanitize_table",
        "convert_value",
        "on_object",
    ):
        bound = getattr(Processor, method_name).__get__(proc, Processor)
        setattr(proc, method_name, bound)

    # The values list is intentionally empty.
    no_rows = ExtractedObject(
        metadata=Metadata(
            id="empty-001",
            user="test_user",
            collection="empty_collection",
            metadata=[],
        ),
        schema_name="empty_schema",
        values=[],
        confidence=1.0,
        source_span="empty source",
    )

    message = MagicMock()
    message.value.return_value = no_rows

    await proc.on_object(message, None, None)

    # The table is still ensured, but no statement reaches the session.
    proc.ensure_table.assert_called_once()
    proc.session.execute.assert_not_called()
@pytest.mark.asyncio
async def test_single_item_batch_processing_logic(self):
    """A one-element values list yields one INSERT (backward compatibility)."""
    proc = MagicMock()
    proc.session = MagicMock()
    proc.ensure_table = MagicMock()
    # Pre-populated so the keyspace/table validation path is skipped.
    proc.known_keyspaces = {"test_user"}
    proc.known_tables = {"test_user": set()}
    proc.schemas = {
        "single_schema": RowSchema(
            name="single_schema",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="data", type="string", size=100),
            ],
        )
    }

    # Attach the real Processor implementations to the mock instance.
    for method_name in (
        "sanitize_name",
        "sanitize_table",
        "convert_value",
        "on_object",
    ):
        bound = getattr(Processor, method_name).__get__(proc, Processor)
        setattr(proc, method_name, bound)

    # Array with exactly one item.
    one_row = ExtractedObject(
        metadata=Metadata(
            id="single-001",
            user="test_user",
            collection="single_collection",
            metadata=[],
        ),
        schema_name="single_schema",
        values=[{"id": "single-1", "data": "single data"}],
        confidence=0.8,
        source_span="single source",
    )

    message = MagicMock()
    message.value.return_value = one_row

    await proc.on_object(message, None, None)

    proc.ensure_table.assert_called_once()

    # Exactly one INSERT carrying collection, id and data in that order.
    proc.session.execute.assert_called_once()
    positional = proc.session.execute.call_args[0]
    cql, row = positional[0], positional[1]
    assert "INSERT INTO test_user.o_single_schema" in cql
    assert row[0] == "single_collection"  # collection
    assert row[1] == "single-1"  # id value
    assert row[2] == "single data"  # data value
def test_batch_value_conversion_logic(self):
    """convert_value maps raw batch values onto each declared field type."""
    proc = MagicMock()
    proc.convert_value = Processor.convert_value.__get__(proc, Processor)

    # (raw input, expected output) pairs grouped by the target field type.
    cases_by_type = {
        "integer": [("123", 123), ("456", 456), ("789", 789)],
        "float": [("12.5", 12.5), ("34.7", 34.7)],
        "boolean": [("true", True), ("false", False), ("1", True), ("0", False)],
        "string": [(123, "123"), (45.6, "45.6")],
    }

    for field_type, pairs in cases_by_type.items():
        for input_val, expected_output in pairs:
            result = proc.convert_value(input_val, field_type)
            assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}"

View file

@ -0,0 +1,435 @@
"""
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
"""Test Qdrant row embeddings storage functionality"""
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_processor_initialization_basic(self, mock_qdrant_client):
    """Explicit store_uri and api_key are forwarded to QdrantClient."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    client = MagicMock()
    mock_qdrant_client.return_value = client

    processor = Processor(
        store_uri='http://localhost:6333',
        api_key='test-api-key',
        taskgroup=AsyncMock(),
        id='test-qdrant-processor',
    )

    # The client is built once with the supplied connection settings ...
    mock_qdrant_client.assert_called_once_with(
        url='http://localhost:6333', api_key='test-api-key'
    )
    # ... and exposed on the processor as `qdrant`.
    assert hasattr(processor, 'qdrant')
    assert processor.qdrant == client
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
    """Without connection settings the client gets the default URL and no key."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    mock_qdrant_client.return_value = MagicMock()

    # Only the mandatory arguments are supplied.
    processor = Processor(taskgroup=AsyncMock(), id='test-qdrant-processor')

    mock_qdrant_client.assert_called_once_with(
        url='http://localhost:6333', api_key=None
    )
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_sanitize_name(self, mock_qdrant_client):
    """sanitize_name normalises names for use in Qdrant collection names."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    mock_qdrant_client.return_value = MagicMock()
    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    expectations = [
        # Basic sanitization: separators become underscores, case is lowered.
        ("simple", "simple"),
        ("with-dash", "with_dash"),
        ("with.dot", "with_dot"),
        ("UPPERCASE", "uppercase"),
        # Names starting with a digit or an underscore gain an "r_" prefix.
        ("123start", "r_123start"),
        ("_underscore", "r__underscore"),
    ]
    for raw, cleaned in expectations:
        assert processor.sanitize_name(raw) == cleaned
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_get_collection_name(self, mock_qdrant_client):
    """Collection names combine user, collection, schema and dimension."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    mock_qdrant_client.return_value = MagicMock()
    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    name_parts = {
        "user": "test_user",
        "collection": "test_collection",
        "schema_name": "customer_data",
        "dimension": 384,
    }
    generated = processor.get_collection_name(**name_parts)

    assert generated == "rows_test_user_test_collection_customer_data_384"
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_creates_new(self, mock_qdrant_client):
    """A collection Qdrant does not report as existing is created and cached."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    client = MagicMock()
    client.collection_exists.return_value = False
    mock_qdrant_client.return_value = client

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')
    processor.ensure_collection("test_collection", 384)

    # Existence is checked once, then the collection is created once.
    client.collection_exists.assert_called_once_with("test_collection")
    client.create_collection.assert_called_once()

    # The name is remembered in the created-collections cache.
    assert "test_collection" in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
    """No create call is made when Qdrant reports the collection exists."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    client = MagicMock()
    client.collection_exists.return_value = True
    mock_qdrant_client.return_value = client

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')
    processor.ensure_collection("existing_collection", 384)

    # Existence is checked, but nothing is created.
    client.collection_exists.assert_called_once()
    client.create_collection.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_uses_cache(self, mock_qdrant_client):
    """A name already in created_collections triggers no Qdrant calls."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    client = MagicMock()
    mock_qdrant_client.return_value = client

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    # Seed the cache before asking for the same collection.
    processor.created_collections.add("cached_collection")
    processor.ensure_collection("cached_collection", 384)

    # Neither the existence check nor the creation call may happen.
    client.collection_exists.assert_not_called()
    client.create_collection.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_basic(self, mock_uuid, mock_qdrant_client):
    """A single-vector embedding is upserted with its index payload."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor
    from trustgraph.schema import RowEmbeddings, RowIndexEmbedding, Metadata

    client = MagicMock()
    client.collection_exists.return_value = True
    mock_qdrant_client.return_value = client
    mock_uuid.uuid4.return_value = 'test-uuid-123'

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')
    # Register the user/collection pair the message refers to.
    processor.known_collections[('test_user', 'test_collection')] = {}

    # Stand-in metadata exposing only user, collection and id.
    meta = MagicMock()
    meta.user = 'test_user'
    meta.collection = 'test_collection'
    meta.id = 'doc-123'

    row_embedding = RowIndexEmbedding(
        index_name='customer_id',
        index_value=['CUST001'],
        text='CUST001',
        vectors=[[0.1, 0.2, 0.3]],
    )
    payload_msg = RowEmbeddings(
        metadata=meta,
        schema_name='customers',
        embeddings=[row_embedding],
    )

    # Wrapper whose value() hands back the embeddings message.
    wrapper = MagicMock()
    wrapper.value.return_value = payload_msg

    await processor.on_embeddings(wrapper, MagicMock(), MagicMock())

    client.upsert.assert_called_once()

    # The collection name ends with the vector length (3 components here).
    upsert_kwargs = client.upsert.call_args[1]
    assert upsert_kwargs['collection_name'] == 'rows_test_user_test_collection_customers_3'

    stored_point = upsert_kwargs['points'][0]
    assert stored_point.vector == [0.1, 0.2, 0.3]
    assert stored_point.payload['index_name'] == 'customer_id'
    assert stored_point.payload['index_value'] == ['CUST001']
    assert stored_point.payload['text'] == 'CUST001'
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
    """An embedding carrying several vectors produces one upsert per vector."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor
    from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

    client = MagicMock()
    client.collection_exists.return_value = True
    mock_qdrant_client.return_value = client
    mock_uuid.uuid4.return_value = 'test-uuid'

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')
    # Register the user/collection pair the message refers to.
    processor.known_collections[('test_user', 'test_collection')] = {}

    meta = MagicMock()
    meta.user = 'test_user'
    meta.collection = 'test_collection'
    meta.id = 'doc-123'

    # Three vectors attached to a single index entry.
    three_vectors = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
    row_embedding = RowIndexEmbedding(
        index_name='name',
        index_value=['John Doe'],
        text='John Doe',
        vectors=three_vectors,
    )
    payload_msg = RowEmbeddings(
        metadata=meta,
        schema_name='people',
        embeddings=[row_embedding],
    )

    wrapper = MagicMock()
    wrapper.value.return_value = payload_msg

    await processor.on_embeddings(wrapper, MagicMock(), MagicMock())

    # One upsert call for each of the three vectors.
    assert client.upsert.call_count == 3
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
    """An embedding with an empty vectors list results in no upsert."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor
    from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

    client = MagicMock()
    mock_qdrant_client.return_value = client

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')
    # Register the user/collection pair the message refers to.
    processor.known_collections[('test_user', 'test_collection')] = {}

    meta = MagicMock()
    meta.user = 'test_user'
    meta.collection = 'test_collection'
    meta.id = 'doc-123'

    # The vectors list is intentionally empty.
    vectorless = RowIndexEmbedding(
        index_name='id',
        index_value=['123'],
        text='123',
        vectors=[],
    )
    payload_msg = RowEmbeddings(
        metadata=meta,
        schema_name='items',
        embeddings=[vectorless],
    )

    wrapper = MagicMock()
    wrapper.value.return_value = payload_msg

    await processor.on_embeddings(wrapper, MagicMock(), MagicMock())

    # Nothing to store, so Qdrant must not be written to.
    client.upsert.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_drops_unknown_collection(self, mock_qdrant_client):
    """Embeddings for an unregistered user/collection pair are not stored."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor
    from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

    client = MagicMock()
    mock_qdrant_client.return_value = client

    # known_collections is deliberately left untouched: nothing is registered.
    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    meta = MagicMock()
    meta.user = 'unknown_user'
    meta.collection = 'unknown_collection'
    meta.id = 'doc-123'

    row_embedding = RowIndexEmbedding(
        index_name='id',
        index_value=['123'],
        text='123',
        vectors=[[0.1, 0.2]],
    )
    payload_msg = RowEmbeddings(
        metadata=meta,
        schema_name='items',
        embeddings=[row_embedding],
    )

    wrapper = MagicMock()
    wrapper.value.return_value = payload_msg

    await processor.on_embeddings(wrapper, MagicMock(), MagicMock())

    # The message is dropped without writing to Qdrant.
    client.upsert.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection(self, mock_qdrant_client):
    """Test deleting all collections for a user/collection.

    Qdrant reports three collections, two of which belong to
    test_user/test_collection. Only those two may be deleted, and the
    cached entry for a deleted collection must be dropped.
    """
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    mock_qdrant_instance = MagicMock()

    # Mock collections list. `name` is assigned after construction because
    # MagicMock(name=...) names the mock itself rather than setting an
    # attribute called `name`.
    mock_coll1 = MagicMock()
    mock_coll1.name = 'rows_test_user_test_collection_schema1_384'
    mock_coll2 = MagicMock()
    mock_coll2.name = 'rows_test_user_test_collection_schema2_384'
    mock_coll3 = MagicMock()
    mock_coll3.name = 'rows_other_user_other_collection_schema_384'

    mock_collections = MagicMock()
    mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
    mock_qdrant_instance.get_collections.return_value = mock_collections

    mock_qdrant_client.return_value = mock_qdrant_instance

    config = {
        'taskgroup': AsyncMock(),
        'id': 'test-processor'
    }

    processor = Processor(**config)
    processor.created_collections.add('rows_test_user_test_collection_schema1_384')

    await processor.delete_collection('test_user', 'test_collection')

    # Should delete only the matching collections
    assert mock_qdrant_instance.delete_collection.call_count == 2

    # A bare call count would also pass if the wrong collections were
    # removed, so verify WHICH collections were deleted. The name is read
    # from the first positional argument, falling back to the
    # `collection_name` keyword, so either calling convention is accepted.
    deleted_names = set()
    for call in mock_qdrant_instance.delete_collection.call_args_list:
        positional, keyword = call[0], call[1]
        deleted_names.add(
            positional[0] if positional else keyword.get('collection_name')
        )
    assert deleted_names == {
        'rows_test_user_test_collection_schema1_384',
        'rows_test_user_test_collection_schema2_384',
    }

    # Verify the cached collection was removed
    assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client):
    """Only the collection belonging to the named schema is deleted."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    def named_collection(collection_name):
        # `name` must be set after construction; MagicMock(name=...) would
        # name the mock itself instead of creating the attribute.
        entry = MagicMock()
        entry.name = collection_name
        return entry

    listing = MagicMock()
    listing.collections = [
        named_collection('rows_test_user_test_collection_customers_384'),
        named_collection('rows_test_user_test_collection_orders_384'),
    ]

    client = MagicMock()
    client.get_collections.return_value = listing
    mock_qdrant_client.return_value = client

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    await processor.delete_collection_schema(
        'test_user', 'test_collection', 'customers'
    )

    # A single delete, aimed at the customers collection.
    client.delete_collection.assert_called_once()
    positional = client.delete_collection.call_args[0]
    assert positional[0] == 'rows_test_user_test_collection_customers_384'
# Allow this test module to be executed directly: hand this file to pytest.
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,474 @@
"""
Unit tests for Cassandra Row Storage Processor (Unified Table Implementation)
Tests the business logic of the row storage processor including:
- Schema configuration handling
- Name sanitization
- Unified table structure
- Index management
- Row storage with multi-index support
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class TestRowsCassandraStorageLogic:
"""Test business logic for unified table implementation"""
def test_sanitize_name(self):
    """sanitize_name rewrites arbitrary names into the expected safe form."""
    proc = MagicMock()
    proc.sanitize_name = Processor.sanitize_name.__get__(proc, Processor)

    # (raw name, expected sanitized name), checked in this order.
    expectations = [
        ("simple_name", "simple_name"),
        ("Name-With-Dashes", "name_with_dashes"),
        ("name.with.dots", "name_with_dots"),
        ("123_starts_with_number", "r_123_starts_with_number"),
        ("name with spaces", "name_with_spaces"),
        ("special!@#$%^chars", "special______chars"),
        ("UPPERCASE", "uppercase"),
        ("CamelCase", "camelcase"),
        ("_underscore_start", "r__underscore_start"),
    ]
    for raw, cleaned in expectations:
        assert proc.sanitize_name(raw) == cleaned
def test_get_index_names(self):
    """get_index_names returns the primary key plus every indexed field."""
    proc = MagicMock()
    proc.get_index_names = Processor.get_index_names.__get__(proc, Processor)

    # One primary field, two indexed fields and one plain field.
    mixed_schema = RowSchema(
        name="test_schema",
        description="Test",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="category", type="string", indexed=True),
            Field(name="name", type="string"),
            Field(name="status", type="string", indexed=True),
        ],
    )

    found = proc.get_index_names(mixed_schema)

    # Primary and indexed fields are reported ...
    for expected_name in ("id", "category", "status"):
        assert expected_name in found
    # ... while the plain field is left out.
    assert "name" not in found
    assert len(found) == 3
def test_get_index_names_no_indexes(self):
    """A schema with no primary or indexed fields yields no index names."""
    proc = MagicMock()
    proc.get_index_names = Processor.get_index_names.__get__(proc, Processor)

    plain_fields = [
        Field(name="data1", type="string"),
        Field(name="data2", type="string"),
    ]
    plain_schema = RowSchema(name="no_index_schema", fields=plain_fields)

    found = proc.get_index_names(plain_schema)

    assert len(found) == 0
def test_build_index_value(self):
    """build_index_value wraps a single field's value in a one-item list."""
    proc = MagicMock()
    proc.build_index_value = Processor.build_index_value.__get__(proc, Processor)

    row = {"id": "123", "category": "electronics", "name": "Widget"}

    # (index name, expected value list); a missing field maps to "".
    expectations = [
        ("id", ["123"]),
        ("category", ["electronics"]),
        ("missing", [""]),
    ]
    for index_name, expected in expectations:
        built = proc.build_index_value(row, index_name)
        assert built == expected
def test_build_index_value_composite(self):
    """A comma-separated index name yields one value per named field."""
    proc = MagicMock()
    proc.build_index_value = Processor.build_index_value.__get__(proc, Processor)

    row = {"region": "us-west", "category": "electronics", "id": "123"}

    # Values come back in the order the fields appear in the index name.
    built = proc.build_index_value(row, "region,category")
    assert built == ["us-west", "electronics"]
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
    """on_schema_config turns JSON schema definitions into schema objects."""
    proc = MagicMock()
    proc.config_key = "schema"
    proc.schemas = {}
    proc.registered_partitions = set()
    proc.on_schema_config = Processor.on_schema_config.__get__(proc, Processor)

    # JSON-encoded definition, keyed by schema name under the "schema" section.
    field_specs = [
        {
            "name": "id",
            "type": "string",
            "primary_key": True,
            "required": True
        },
        {
            "name": "name",
            "type": "string",
            "required": True
        },
        {
            "name": "category",
            "type": "string",
            "indexed": True
        },
    ]
    encoded_schema = json.dumps({
        "name": "customer_records",
        "description": "Customer data",
        "fields": field_specs,
    })
    config = {"schema": {"customer_records": encoded_schema}}

    await proc.on_schema_config(config, version=1)

    # The schema must now be registered under its name.
    assert "customer_records" in proc.schemas
    parsed = proc.schemas["customer_records"]
    assert parsed.name == "customer_records"
    assert len(parsed.fields) == 3

    # The first field carries the primary-key flag.
    first_field = parsed.fields[0]
    assert first_field.name == "id"
    assert first_field.type == "string"
    assert first_field.primary is True
@pytest.mark.asyncio
async def test_object_processing_stores_data_map(self):
    """on_object inserts the row into the unified rows table as a data map."""
    proc = MagicMock()
    proc.session = MagicMock()
    proc.tables_initialized = {"test_user"}
    proc.registered_partitions = set()
    proc.schemas = {
        "test_schema": RowSchema(
            name="test_schema",
            description="Test",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="value", type="string", size=100),
            ],
        )
    }

    # Collaborators replaced by mocks.
    proc.ensure_tables = MagicMock()
    proc.register_partitions = MagicMock()
    proc.collection_exists = MagicMock(return_value=True)

    # Attach the real Processor implementations to the mock instance.
    for method_name in (
        "sanitize_name",
        "get_index_names",
        "build_index_value",
        "on_object",
    ):
        bound = getattr(Processor, method_name).__get__(proc, Processor)
        setattr(proc, method_name, bound)

    extracted = ExtractedObject(
        metadata=Metadata(
            id="test-001",
            user="test_user",
            collection="test_collection",
            metadata=[],
        ),
        schema_name="test_schema",
        values=[{"id": "123", "value": "test_data"}],
        confidence=0.9,
        source_span="test source",
    )

    message = MagicMock()
    message.value.return_value = extracted

    await proc.on_object(message, None, None)

    # Inspect the most recent statement sent to the session.
    proc.session.execute.assert_called()
    positional = proc.session.execute.call_args[0]
    cql, row = positional[0], positional[1]

    # The insert targets the unified rows table in the user's keyspace.
    assert "INSERT INTO test_user.rows" in cql

    # Bound values: (collection, schema_name, index_name, index_value, data, source)
    assert row[0] == "test_collection"  # collection
    assert row[1] == "test_schema"  # schema_name
    assert row[2] == "id"  # index_name (primary key field)
    assert row[3] == ["123"]  # index_value as list
    assert row[4] == {"id": "123", "value": "test_data"}  # data map
    assert row[5] == ""  # source
@pytest.mark.asyncio
async def test_object_processing_multiple_indexes(self):
    """A row is inserted once for each index name of its schema."""
    proc = MagicMock()
    proc.session = MagicMock()
    proc.tables_initialized = {"test_user"}
    proc.registered_partitions = set()
    proc.schemas = {
        "multi_index_schema": RowSchema(
            name="multi_index_schema",
            fields=[
                Field(name="id", type="string", primary=True),
                Field(name="category", type="string", indexed=True),
                Field(name="status", type="string", indexed=True),
            ],
        )
    }

    # Collaborators replaced by mocks.
    proc.ensure_tables = MagicMock()
    proc.register_partitions = MagicMock()
    proc.collection_exists = MagicMock(return_value=True)

    # Attach the real Processor implementations to the mock instance.
    for method_name in (
        "sanitize_name",
        "get_index_names",
        "build_index_value",
        "on_object",
    ):
        bound = getattr(Processor, method_name).__get__(proc, Processor)
        setattr(proc, method_name, bound)

    extracted = ExtractedObject(
        metadata=Metadata(
            id="test-001",
            user="test_user",
            collection="test_collection",
            metadata=[],
        ),
        schema_name="multi_index_schema",
        values=[{"id": "123", "category": "electronics", "status": "active"}],
        confidence=0.9,
        source_span="",
    )

    message = MagicMock()
    message.value.return_value = extracted

    await proc.on_object(message, None, None)

    # Three inserts: one each for id, category and status.
    assert proc.session.execute.call_count == 3

    # index_name is the third bound value of every insert.
    seen_index_names = {
        executed[0][1][2] for executed in proc.session.execute.call_args_list
    }
    assert seen_index_names == {"id", "category", "status"}
class TestRowsCassandraStorageBatchLogic:
    """Batch handling of ExtractedObjects with the unified table implementation"""

    @staticmethod
    def _build_processor(schemas):
        """Return a mock processor with the real on_object path bound to it.

        ensure_tables and register_partitions are stubbed, and
        collection_exists is stubbed to return True; sanitize_name,
        get_index_names, build_index_value and on_object are the real
        Processor methods bound to the mock.
        """
        processor = MagicMock()
        processor.schemas = schemas
        processor.tables_initialized = {"test_user"}
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.ensure_tables = MagicMock()
        processor.register_partitions = MagicMock()
        processor.collection_exists = MagicMock(return_value=True)
        for method_name in (
            "sanitize_name",
            "get_index_names",
            "build_index_value",
            "on_object",
        ):
            real_method = getattr(Processor, method_name)
            setattr(processor, method_name, real_method.__get__(processor, Processor))
        return processor

    @pytest.mark.asyncio
    async def test_batch_object_processing(self):
        """Check that each row of a multi-value ExtractedObject is inserted."""
        processor = self._build_processor({
            "batch_schema": RowSchema(
                name="batch_schema",
                fields=[
                    Field(name="id", type="string", primary=True),
                    Field(name="name", type="string"),
                ],
            )
        })

        # One message carrying three rows
        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_collection",
                metadata=[],
            ),
            schema_name="batch_schema",
            values=[
                {"id": "001", "name": "First"},
                {"id": "002", "name": "Second"},
                {"id": "003", "name": "Third"},
            ],
            confidence=0.95,
            source_span="",
        )

        await processor.on_object(message, None, None)

        # Three inserts expected: one per row, as the primary key is the
        # only index in this schema
        execute = processor.session.execute
        assert execute.call_count == 3

        # index_value is the 4th bound value; each insert carries its own id
        ids_inserted = {
            tuple(recorded[0][1][3]) for recorded in execute.call_args_list
        }
        assert ids_inserted == {("001",), ("002",), ("003",)}

    @pytest.mark.asyncio
    async def test_empty_batch_processing(self):
        """Check that an ExtractedObject with no values triggers no insert."""
        processor = self._build_processor({
            "empty_schema": RowSchema(
                name="empty_schema",
                fields=[Field(name="id", type="string", primary=True)],
            )
        })

        # Message whose values list is empty
        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(
                id="empty-001",
                user="test_user",
                collection="empty_collection",
                metadata=[],
            ),
            schema_name="empty_schema",
            values=[],
            confidence=1.0,
            source_span="",
        )

        await processor.on_object(message, None, None)

        # Nothing to write, so the session must stay untouched
        processor.session.execute.assert_not_called()
class TestUnifiedTableStructure:
    """Structure of the tables created by ensure_tables"""

    def test_ensure_tables_creates_unified_structure(self):
        """Check the CQL that ensure_tables issues for a fresh keyspace."""
        processor = MagicMock()
        processor.session = MagicMock()
        processor.known_keyspaces = {"test_user"}
        processor.tables_initialized = set()
        processor.ensure_keyspace = MagicMock()
        for method_name in ("sanitize_name", "ensure_tables"):
            real_method = getattr(Processor, method_name)
            setattr(processor, method_name, real_method.__get__(processor, Processor))

        processor.ensure_tables("test_user")

        # Two statements expected: rows table, then row_partitions table
        recorded_calls = processor.session.execute.call_args_list
        assert processor.session.execute.call_count == 2

        # First statement: the unified rows table and its columns
        rows_cql = recorded_calls[0][0][0]
        for fragment in (
            "CREATE TABLE IF NOT EXISTS test_user.rows",
            "collection text",
            "schema_name text",
            "index_name text",
            "index_value frozen<list<text>>",
            "data map<text, text>",
            "source text",
            "PRIMARY KEY ((collection, schema_name, index_name), index_value)",
        ):
            assert fragment in rows_cql

        # Second statement: the row_partitions table
        partitions_cql = recorded_calls[1][0][0]
        for fragment in (
            "CREATE TABLE IF NOT EXISTS test_user.row_partitions",
            "PRIMARY KEY ((collection), schema_name, index_name)",
        ):
            assert fragment in partitions_cql

        # The keyspace must now be recorded as initialized
        assert "test_user" in processor.tables_initialized

    def test_ensure_tables_idempotent(self):
        """Check that ensure_tables does nothing for an initialized keyspace."""
        processor = MagicMock()
        processor.session = MagicMock()
        # Keyspace is marked as initialized before the call
        processor.tables_initialized = {"test_user"}
        processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)

        processor.ensure_tables("test_user")

        # No CQL may be executed on the second pass
        processor.session.execute.assert_not_called()
class TestPartitionRegistration:
    """Partition registration, which records what has been stored"""

    def test_register_partitions(self):
        """Check registration of a collection/schema pair with two indexes."""
        schema = RowSchema(
            name="test_schema",
            fields=[
                Field(name="id", type="string", primary=True),
                Field(name="category", type="string", indexed=True),
            ],
        )

        processor = MagicMock()
        processor.session = MagicMock()
        processor.registered_partitions = set()
        processor.schemas = {"test_schema": schema}
        for method_name in (
            "sanitize_name",
            "get_index_names",
            "register_partitions",
        ):
            real_method = getattr(Processor, method_name)
            setattr(processor, method_name, real_method.__get__(processor, Processor))

        processor.register_partitions("test_user", "test_collection", "test_schema")

        # Two inserts expected, one per index: id, category
        assert processor.session.execute.call_count == 2

        # The pair must now be present in the registration cache
        registered_pair = ("test_collection", "test_schema")
        assert registered_pair in processor.registered_partitions

    def test_register_partitions_idempotent(self):
        """Check that an already registered pair triggers no CQL."""
        processor = MagicMock()
        processor.session = MagicMock()
        # The pair is in the cache before the call
        processor.registered_partitions = {("test_collection", "test_schema")}
        processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)

        processor.register_partitions("test_user", "test_collection", "test_schema")

        # Cached registration means the session is never used
        processor.session.execute.assert_not_called()

View file

@ -48,7 +48,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert hasattr(processor, 'client')
assert hasattr(processor, 'safety_settings')
assert len(processor.safety_settings) == 4 # 4 safety categories
mock_genai_class.assert_called_once_with(api_key='test-api-key')
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -208,7 +208,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert processor.default_model == 'gemini-1.5-pro'
assert processor.temperature == 0.7
assert processor.max_output == 4096
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
mock_genai_class.assert_called_once_with(api_key='custom-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -237,7 +237,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 8192 # default_max_output
mock_genai_class.assert_called_once_with(api_key='test-api-key')
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -427,7 +427,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
# Assert
# Verify Google AI Studio client was called with correct API key
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
mock_genai_class.assert_called_once_with(api_key='gai-test-key', vertexai=False)
# Verify processor has the client
assert processor.client == mock_genai_client