diff --git a/tests/contract/test_objects_cassandra_contracts.py b/tests/contract/test_objects_cassandra_contracts.py index 85f6aedc..3966a3fc 100644 --- a/tests/contract/test_objects_cassandra_contracts.py +++ b/tests/contract/test_objects_cassandra_contracts.py @@ -30,11 +30,11 @@ class TestObjectsCassandraContracts: test_object = ExtractedObject( metadata=test_metadata, schema_name="customer_records", - values={ + values=[{ "customer_id": "CUST123", "name": "Test Customer", "email": "test@example.com" - }, + }], confidence=0.95, source_span="Customer data from document..." ) @@ -54,7 +54,7 @@ class TestObjectsCassandraContracts: # Verify types assert isinstance(test_object.schema_name, str) - assert isinstance(test_object.values, dict) + assert isinstance(test_object.values, list) assert isinstance(test_object.confidence, float) assert isinstance(test_object.source_span, str) @@ -200,7 +200,7 @@ class TestObjectsCassandraContracts: metadata=[] ), schema_name="test_schema", - values={"field1": "value1", "field2": "123"}, + values=[{"field1": "value1", "field2": "123"}], confidence=0.85, source_span="Test span" ) @@ -292,7 +292,7 @@ class TestObjectsCassandraContracts: metadata=[{"key": "value"}] ), schema_name="table789", # -> table name - values={"field": "value"}, + values=[{"field": "value"}], confidence=0.9, source_span="Source" ) @@ -303,4 +303,215 @@ class TestObjectsCassandraContracts: # - metadata.collection -> Part of primary key assert test_obj.metadata.user # Required for keyspace assert test_obj.schema_name # Required for table - assert test_obj.metadata.collection # Required for partition key \ No newline at end of file + assert test_obj.metadata.collection # Required for partition key + + +@pytest.mark.contract +class TestObjectsCassandraContractsBatch: + """Contract tests for Cassandra object storage batch processing""" + + def test_extracted_object_batch_input_contract(self): + """Test that batched ExtractedObject schema matches expected input 
format""" + # Create test object with multiple values in batch + test_metadata = Metadata( + id="batch-doc-001", + user="test_user", + collection="test_collection", + metadata=[] + ) + + batch_object = ExtractedObject( + metadata=test_metadata, + schema_name="customer_records", + values=[ + { + "customer_id": "CUST123", + "name": "Test Customer 1", + "email": "test1@example.com" + }, + { + "customer_id": "CUST124", + "name": "Test Customer 2", + "email": "test2@example.com" + }, + { + "customer_id": "CUST125", + "name": "Test Customer 3", + "email": "test3@example.com" + } + ], + confidence=0.88, + source_span="Multiple customer data from document..." + ) + + # Verify batch structure + assert hasattr(batch_object, 'values') + assert isinstance(batch_object.values, list) + assert len(batch_object.values) == 3 + + # Verify each batch item is a dict + for i, batch_item in enumerate(batch_object.values): + assert isinstance(batch_item, dict) + assert "customer_id" in batch_item + assert "name" in batch_item + assert "email" in batch_item + assert batch_item["customer_id"] == f"CUST12{3+i}" + assert f"Test Customer {i+1}" in batch_item["name"] + + def test_extracted_object_empty_batch_contract(self): + """Test empty batch ExtractedObject contract""" + test_metadata = Metadata( + id="empty-batch-001", + user="test_user", + collection="test_collection", + metadata=[] + ) + + empty_batch_object = ExtractedObject( + metadata=test_metadata, + schema_name="empty_schema", + values=[], # Empty batch + confidence=1.0, + source_span="No objects found in document" + ) + + # Verify empty batch structure + assert hasattr(empty_batch_object, 'values') + assert isinstance(empty_batch_object.values, list) + assert len(empty_batch_object.values) == 0 + assert empty_batch_object.confidence == 1.0 + + def test_extracted_object_single_item_batch_contract(self): + """Test single-item batch (backward compatibility) contract""" + test_metadata = Metadata( + id="single-batch-001", + 
user="test_user", + collection="test_collection", + metadata=[] + ) + + single_batch_object = ExtractedObject( + metadata=test_metadata, + schema_name="customer_records", + values=[{ # Array with single item for backward compatibility + "customer_id": "CUST999", + "name": "Single Customer", + "email": "single@example.com" + }], + confidence=0.95, + source_span="Single customer data from document..." + ) + + # Verify single-item batch structure + assert isinstance(single_batch_object.values, list) + assert len(single_batch_object.values) == 1 + assert isinstance(single_batch_object.values[0], dict) + assert single_batch_object.values[0]["customer_id"] == "CUST999" + + def test_extracted_object_batch_serialization_contract(self): + """Test that batched ExtractedObject can be serialized/deserialized correctly""" + # Create batch object + original = ExtractedObject( + metadata=Metadata( + id="batch-serial-001", + user="test_user", + collection="test_coll", + metadata=[] + ), + schema_name="test_schema", + values=[ + {"field1": "value1", "field2": "123"}, + {"field1": "value2", "field2": "456"}, + {"field1": "value3", "field2": "789"} + ], + confidence=0.92, + source_span="Batch test span" + ) + + # Test serialization using schema + schema = AvroSchema(ExtractedObject) + + # Encode and decode + encoded = schema.encode(original) + decoded = schema.decode(encoded) + + # Verify round-trip for batch + assert decoded.metadata.id == original.metadata.id + assert decoded.metadata.user == original.metadata.user + assert decoded.metadata.collection == original.metadata.collection + assert decoded.schema_name == original.schema_name + assert len(decoded.values) == len(original.values) + assert len(decoded.values) == 3 + + # Verify each batch item + for i in range(3): + assert decoded.values[i] == original.values[i] + assert decoded.values[i]["field1"] == f"value{i+1}" + assert decoded.values[i]["field2"] == f"{123 + i*333}" + + assert decoded.confidence == original.confidence + 
assert decoded.source_span == original.source_span + + def test_batch_processing_field_validation_contract(self): + """Test that batch processing validates field consistency""" + # All batch items should have consistent field structure + # This is a contract that the application should enforce + + # Valid batch - all items have same fields + valid_batch_values = [ + {"id": "1", "name": "Item 1", "value": "100"}, + {"id": "2", "name": "Item 2", "value": "200"}, + {"id": "3", "name": "Item 3", "value": "300"} + ] + + # Each item has the same field structure + field_sets = [set(item.keys()) for item in valid_batch_values] + assert all(fields == field_sets[0] for fields in field_sets), "All batch items should have consistent fields" + + # Invalid batch - inconsistent fields (this would be caught by application logic) + invalid_batch_values = [ + {"id": "1", "name": "Item 1", "value": "100"}, + {"id": "2", "name": "Item 2"}, # Missing 'value' field + {"id": "3", "name": "Item 3", "value": "300", "extra": "field"} # Extra field + ] + + # Demonstrate the inconsistency + invalid_field_sets = [set(item.keys()) for item in invalid_batch_values] + assert not all(fields == invalid_field_sets[0] for fields in invalid_field_sets), "Invalid batch should have inconsistent fields" + + def test_batch_storage_partition_key_contract(self): + """Test that batch objects maintain partition key consistency""" + # In Cassandra storage, all objects in a batch should: + # 1. Belong to the same collection (partition key component) + # 2. Have unique primary keys within the batch + # 3. 
Be stored in the same keyspace (user) + + test_metadata = Metadata( + id="partition-test-001", + user="consistent_user", # Same keyspace + collection="consistent_collection", # Same partition + metadata=[] + ) + + batch_object = ExtractedObject( + metadata=test_metadata, + schema_name="partition_test", + values=[ + {"id": "pk1", "data": "data1"}, # Unique primary key + {"id": "pk2", "data": "data2"}, # Unique primary key + {"id": "pk3", "data": "data3"} # Unique primary key + ], + confidence=0.95, + source_span="Partition consistency test" + ) + + # Verify consistency contract + assert batch_object.metadata.user # Must have user for keyspace + assert batch_object.metadata.collection # Must have collection for partition key + + # Verify unique primary keys in batch + primary_keys = [item["id"] for item in batch_object.values] + assert len(primary_keys) == len(set(primary_keys)), "Primary keys must be unique within batch" + + # All batch items will be stored in same keyspace and partition + # This is enforced by the metadata.user and metadata.collection being shared \ No newline at end of file diff --git a/tests/contract/test_structured_data_contracts.py b/tests/contract/test_structured_data_contracts.py index af8e70df..91707d4d 100644 --- a/tests/contract/test_structured_data_contracts.py +++ b/tests/contract/test_structured_data_contracts.py @@ -128,18 +128,77 @@ class TestStructuredDataSchemaContracts: obj = ExtractedObject( metadata=metadata, schema_name="customer_records", - values={"id": "123", "name": "John Doe", "email": "john@example.com"}, + values=[{"id": "123", "name": "John Doe", "email": "john@example.com"}], confidence=0.95, source_span="John Doe (john@example.com) customer ID 123" ) # Assert assert obj.schema_name == "customer_records" - assert obj.values["name"] == "John Doe" + assert obj.values[0]["name"] == "John Doe" assert obj.confidence == 0.95 assert len(obj.source_span) > 0 assert obj.metadata.id == "extracted-obj-001" + def 
test_extracted_object_batch_contract(self): + """Test ExtractedObject schema contract for batched values""" + # Arrange + metadata = Metadata( + id="extracted-batch-001", + user="test_user", + collection="test_collection", + metadata=[] + ) + + # Act - create object with multiple values + obj = ExtractedObject( + metadata=metadata, + schema_name="customer_records", + values=[ + {"id": "123", "name": "John Doe", "email": "john@example.com"}, + {"id": "124", "name": "Jane Smith", "email": "jane@example.com"}, + {"id": "125", "name": "Bob Johnson", "email": "bob@example.com"} + ], + confidence=0.85, + source_span="Multiple customers found in document" + ) + + # Assert + assert obj.schema_name == "customer_records" + assert len(obj.values) == 3 + assert obj.values[0]["name"] == "John Doe" + assert obj.values[1]["name"] == "Jane Smith" + assert obj.values[2]["name"] == "Bob Johnson" + assert obj.values[0]["id"] == "123" + assert obj.values[1]["id"] == "124" + assert obj.values[2]["id"] == "125" + assert obj.confidence == 0.85 + assert "Multiple customers" in obj.source_span + + def test_extracted_object_empty_batch_contract(self): + """Test ExtractedObject schema contract for empty values array""" + # Arrange + metadata = Metadata( + id="extracted-empty-001", + user="test_user", + collection="test_collection", + metadata=[] + ) + + # Act - create object with empty values array + obj = ExtractedObject( + metadata=metadata, + schema_name="empty_schema", + values=[], + confidence=1.0, + source_span="No objects found" + ) + + # Assert + assert obj.schema_name == "empty_schema" + assert len(obj.values) == 0 + assert obj.confidence == 1.0 + @pytest.mark.contract class TestStructuredQueryServiceContracts: @@ -273,7 +332,7 @@ class TestStructuredDataSerializationContracts: object_data = { "metadata": metadata, "schema_name": "test_schema", - "values": {"field1": "value1"}, + "values": [{"field1": "value1"}], "confidence": 0.8, "source_span": "test span" } @@ -314,4 +373,38 @@ 
class TestStructuredDataSerializationContracts: "data": '{"customers": [{"id": "1", "name": "John"}]}', "errors": [] } - assert serialize_deserialize_test(StructuredQueryResponse, response_data) \ No newline at end of file + assert serialize_deserialize_test(StructuredQueryResponse, response_data) + + def test_extracted_object_batch_serialization(self): + """Test ExtractedObject batch serialization contract""" + # Arrange + metadata = Metadata(id="test", user="user", collection="col", metadata=[]) + batch_object_data = { + "metadata": metadata, + "schema_name": "test_schema", + "values": [ + {"field1": "value1", "field2": "value2"}, + {"field1": "value3", "field2": "value4"}, + {"field1": "value5", "field2": "value6"} + ], + "confidence": 0.9, + "source_span": "batch test span" + } + + # Act & Assert + assert serialize_deserialize_test(ExtractedObject, batch_object_data) + + def test_extracted_object_empty_batch_serialization(self): + """Test ExtractedObject empty batch serialization contract""" + # Arrange + metadata = Metadata(id="test", user="user", collection="col", metadata=[]) + empty_batch_data = { + "metadata": metadata, + "schema_name": "test_schema", + "values": [], + "confidence": 1.0, + "source_span": "empty batch" + } + + # Act & Assert + assert serialize_deserialize_test(ExtractedObject, empty_batch_data) \ No newline at end of file diff --git a/tests/integration/test_object_extraction_integration.py b/tests/integration/test_object_extraction_integration.py index b54b559a..7b2245ce 100644 --- a/tests/integration/test_object_extraction_integration.py +++ b/tests/integration/test_object_extraction_integration.py @@ -270,9 +270,9 @@ class TestObjectExtractionServiceIntegration: assert len(customer_calls) == 1 customer_obj = customer_calls[0] - assert customer_obj.values["customer_id"] == "CUST001" - assert customer_obj.values["name"] == "John Smith" - assert customer_obj.values["email"] == "john.smith@email.com" + assert 
customer_obj.values[0]["customer_id"] == "CUST001" + assert customer_obj.values[0]["name"] == "John Smith" + assert customer_obj.values[0]["email"] == "john.smith@email.com" assert customer_obj.confidence > 0.5 @pytest.mark.asyncio @@ -335,10 +335,10 @@ class TestObjectExtractionServiceIntegration: assert len(product_calls) == 1 product_obj = product_calls[0] - assert product_obj.values["product_id"] == "PROD001" - assert product_obj.values["name"] == "Gaming Laptop" - assert product_obj.values["price"] == "1299.99" - assert product_obj.values["category"] == "electronics" + assert product_obj.values[0]["product_id"] == "PROD001" + assert product_obj.values[0]["name"] == "Gaming Laptop" + assert product_obj.values[0]["price"] == "1299.99" + assert product_obj.values[0]["category"] == "electronics" @pytest.mark.asyncio async def test_concurrent_extraction_integration(self, integration_config, mock_integrated_flow): diff --git a/tests/integration/test_objects_cassandra_integration.py b/tests/integration/test_objects_cassandra_integration.py index ff161d04..4ce86f74 100644 --- a/tests/integration/test_objects_cassandra_integration.py +++ b/tests/integration/test_objects_cassandra_integration.py @@ -95,12 +95,12 @@ class TestObjectsCassandraIntegration: metadata=[] ), schema_name="customer_records", - values={ + values=[{ "customer_id": "CUST001", "name": "John Doe", "email": "john@example.com", "age": "30" - }, + }], confidence=0.95, source_span="Customer: John Doe..." ) @@ -183,7 +183,7 @@ class TestObjectsCassandraIntegration: product_obj = ExtractedObject( metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]), schema_name="products", - values={"product_id": "P001", "name": "Widget", "price": "19.99"}, + values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}], confidence=0.9, source_span="Product..." 
) @@ -191,7 +191,7 @@ class TestObjectsCassandraIntegration: order_obj = ExtractedObject( metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]), schema_name="orders", - values={"order_id": "O001", "customer_id": "C001", "total": "59.97"}, + values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}], confidence=0.85, source_span="Order..." ) @@ -229,7 +229,7 @@ class TestObjectsCassandraIntegration: test_obj = ExtractedObject( metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), schema_name="test_schema", - values={"id": "123"}, # missing required_field + values=[{"id": "123"}], # missing required_field confidence=0.8, source_span="Test" ) @@ -265,7 +265,7 @@ class TestObjectsCassandraIntegration: test_obj = ExtractedObject( metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]), schema_name="events", - values={"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}, + values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}], confidence=1.0, source_span="Event" ) @@ -334,7 +334,7 @@ class TestObjectsCassandraIntegration: test_obj = ExtractedObject( metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), schema_name="test", - values={"id": "123"}, + values=[{"id": "123"}], confidence=0.9, source_span="Test" ) @@ -364,7 +364,7 @@ class TestObjectsCassandraIntegration: obj = ExtractedObject( metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]), schema_name="data", - values={"id": f"ID-{coll}"}, + values=[{"id": f"ID-{coll}"}], confidence=0.9, source_span="Data" ) @@ -381,4 +381,170 @@ class TestObjectsCassandraIntegration: # Check each insert has the correct collection for i, call in enumerate(insert_calls): values = call[0][1] - assert collections[i] in values \ No newline at end of file + assert collections[i] in values + + @pytest.mark.asyncio + async def test_batch_object_processing(self, processor_with_mocks): + """Test 
processing objects with batched values""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): + # Configure schema + config = { + "schema": { + "batch_customers": json.dumps({ + "name": "batch_customers", + "description": "Customer batch data", + "fields": [ + {"name": "customer_id", "type": "string", "primary_key": True}, + {"name": "name", "type": "string", "required": True}, + {"name": "email", "type": "string", "indexed": True} + ] + }) + } + } + + await processor.on_schema_config(config, version=1) + + # Process batch object with multiple values + batch_obj = ExtractedObject( + metadata=Metadata( + id="batch-001", + user="test_user", + collection="batch_import", + metadata=[] + ), + schema_name="batch_customers", + values=[ + { + "customer_id": "CUST001", + "name": "John Doe", + "email": "john@example.com" + }, + { + "customer_id": "CUST002", + "name": "Jane Smith", + "email": "jane@example.com" + }, + { + "customer_id": "CUST003", + "name": "Bob Johnson", + "email": "bob@example.com" + } + ], + confidence=0.92, + source_span="Multiple customers extracted from document" + ) + + msg = MagicMock() + msg.value.return_value = batch_obj + + await processor.on_object(msg, None, None) + + # Verify table creation + table_calls = [call for call in mock_session.execute.call_args_list + if "CREATE TABLE" in str(call)] + assert len(table_calls) == 1 + assert "o_batch_customers" in str(table_calls[0]) + + # Verify multiple inserts for batch values + insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call)] + # Should have 3 separate inserts for the 3 objects in the batch + assert len(insert_calls) == 3 + + # Check each insert has correct data + for i, call in enumerate(insert_calls): + values = call[0][1] + assert "batch_import" in values # collection + assert f"CUST00{i+1}" in values # customer_id + if i == 0: + assert 
"John Doe" in values + assert "john@example.com" in values + elif i == 1: + assert "Jane Smith" in values + assert "jane@example.com" in values + elif i == 2: + assert "Bob Johnson" in values + assert "bob@example.com" in values + + @pytest.mark.asyncio + async def test_empty_batch_processing(self, processor_with_mocks): + """Test processing objects with empty values array""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): + processor.schemas["empty_test"] = RowSchema( + name="empty_test", + fields=[Field(name="id", type="string", size=50, primary=True)] + ) + + # Process empty batch object + empty_obj = ExtractedObject( + metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]), + schema_name="empty_test", + values=[], # Empty batch + confidence=1.0, + source_span="No objects found" + ) + + msg = MagicMock() + msg.value.return_value = empty_obj + + await processor.on_object(msg, None, None) + + # Should still create table + table_calls = [call for call in mock_session.execute.call_args_list + if "CREATE TABLE" in str(call)] + assert len(table_calls) == 1 + + # Should not create any insert statements for empty batch + insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call)] + assert len(insert_calls) == 0 + + @pytest.mark.asyncio + async def test_mixed_single_and_batch_objects(self, processor_with_mocks): + """Test processing mix of single and batch objects""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): + processor.schemas["mixed_test"] = RowSchema( + name="mixed_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="data", type="string", size=100) + ] + ) + + # Single object (backward compatibility) + single_obj = ExtractedObject( + 
metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]), + schema_name="mixed_test", + values=[{"id": "single-1", "data": "single data"}], # Array with single item + confidence=0.9, + source_span="Single object" + ) + + # Batch object + batch_obj = ExtractedObject( + metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]), + schema_name="mixed_test", + values=[ + {"id": "batch-1", "data": "batch data 1"}, + {"id": "batch-2", "data": "batch data 2"} + ], + confidence=0.85, + source_span="Batch objects" + ) + + # Process both + for obj in [single_obj, batch_obj]: + msg = MagicMock() + msg.value.return_value = obj + await processor.on_object(msg, None, None) + + # Should have 3 total inserts (1 + 2) + insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call)] + assert len(insert_calls) == 3 \ No newline at end of file diff --git a/tests/unit/test_gateway/test_objects_import_dispatcher.py b/tests/unit/test_gateway/test_objects_import_dispatcher.py index 24ba7c0f..ed9e8faa 100644 --- a/tests/unit/test_gateway/test_objects_import_dispatcher.py +++ b/tests/unit/test_gateway/test_objects_import_dispatcher.py @@ -66,11 +66,11 @@ def sample_objects_message(): "collection": "testcollection" }, "schema_name": "person", - "values": { + "values": [{ "name": "John Doe", "age": "30", "city": "New York" - }, + }], "confidence": 0.95, "source_span": "John Doe, age 30, lives in New York" } @@ -86,9 +86,9 @@ def minimal_objects_message(): "collection": "testcollection" }, "schema_name": "simple_schema", - "values": { + "values": [{ "field1": "value1" - } + }] } @@ -235,8 +235,8 @@ class TestObjectsImportMessageProcessing: sent_object = call_args[0][1] assert isinstance(sent_object, ExtractedObject) assert sent_object.schema_name == "person" - assert sent_object.values["name"] == "John Doe" - assert sent_object.values["age"] == "30" + assert sent_object.values[0]["name"] == "John Doe" + assert 
sent_object.values[0]["age"] == "30" assert sent_object.confidence == 0.95 assert sent_object.source_span == "John Doe, age 30, lives in New York" @@ -274,7 +274,7 @@ class TestObjectsImportMessageProcessing: sent_object = mock_publisher_instance.send.call_args[0][1] assert isinstance(sent_object, ExtractedObject) assert sent_object.schema_name == "simple_schema" - assert sent_object.values["field1"] == "value1" + assert sent_object.values[0]["field1"] == "value1" assert sent_object.confidence == 1.0 # Default value assert sent_object.source_span == "" # Default value assert len(sent_object.metadata.metadata) == 0 # Default empty list @@ -302,7 +302,7 @@ class TestObjectsImportMessageProcessing: "collection": "testcollection" }, "schema_name": "test_schema", - "values": {"key": "value"} + "values": [{"key": "value"}] # No confidence or source_span } @@ -374,6 +374,134 @@ class TestObjectsImportRunMethod: assert objects_import.ws is None +class TestObjectsImportBatchProcessing: + """Test ObjectsImport batch processing functionality.""" + + @pytest.fixture + def batch_objects_message(self): + """Sample batch objects message data.""" + return { + "metadata": { + "id": "batch-001", + "metadata": [ + { + "s": {"v": "batch-001", "e": False}, + "p": {"v": "source", "e": False}, + "o": {"v": "test", "e": False} + } + ], + "user": "testuser", + "collection": "testcollection" + }, + "schema_name": "person", + "values": [ + { + "name": "John Doe", + "age": "30", + "city": "New York" + }, + { + "name": "Jane Smith", + "age": "25", + "city": "Boston" + }, + { + "name": "Bob Johnson", + "age": "45", + "city": "Chicago" + } + ], + "confidence": 0.85, + "source_span": "Multiple people found in document" + } + + @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @pytest.mark.asyncio + async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, batch_objects_message): + """Test that receive() 
processes batch message correctly.""" + mock_publisher_instance = Mock() + mock_publisher_instance.send = AsyncMock() + mock_publisher_class.return_value = mock_publisher_instance + + objects_import = ObjectsImport( + ws=mock_websocket, + running=mock_running, + pulsar_client=mock_pulsar_client, + queue="test-queue" + ) + + # Create mock message + mock_msg = Mock() + mock_msg.json.return_value = batch_objects_message + + await objects_import.receive(mock_msg) + + # Verify publisher.send was called + mock_publisher_instance.send.assert_called_once() + + # Get the call arguments + call_args = mock_publisher_instance.send.call_args + assert call_args[0][0] is None # First argument should be None + + # Check the ExtractedObject that was sent + sent_object = call_args[0][1] + assert isinstance(sent_object, ExtractedObject) + assert sent_object.schema_name == "person" + + # Check that all batch values are present + assert len(sent_object.values) == 3 + assert sent_object.values[0]["name"] == "John Doe" + assert sent_object.values[0]["age"] == "30" + assert sent_object.values[0]["city"] == "New York" + + assert sent_object.values[1]["name"] == "Jane Smith" + assert sent_object.values[1]["age"] == "25" + assert sent_object.values[1]["city"] == "Boston" + + assert sent_object.values[2]["name"] == "Bob Johnson" + assert sent_object.values[2]["age"] == "45" + assert sent_object.values[2]["city"] == "Chicago" + + assert sent_object.confidence == 0.85 + assert sent_object.source_span == "Multiple people found in document" + + @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @pytest.mark.asyncio + async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running): + """Test that receive() handles empty batch correctly.""" + mock_publisher_instance = Mock() + mock_publisher_instance.send = AsyncMock() + mock_publisher_class.return_value = mock_publisher_instance + + objects_import = ObjectsImport( + 
ws=mock_websocket, + running=mock_running, + pulsar_client=mock_pulsar_client, + queue="test-queue" + ) + + # Message with empty values array + empty_batch_message = { + "metadata": { + "id": "empty-batch-001", + "user": "testuser", + "collection": "testcollection" + }, + "schema_name": "empty_schema", + "values": [] + } + + mock_msg = Mock() + mock_msg.json.return_value = empty_batch_message + + await objects_import.receive(mock_msg) + + # Should still send the message + mock_publisher_instance.send.assert_called_once() + sent_object = mock_publisher_instance.send.call_args[0][1] + assert len(sent_object.values) == 0 + + class TestObjectsImportErrorHandling: """Test error handling in ObjectsImport.""" diff --git a/tests/unit/test_knowledge_graph/test_object_extraction_logic.py b/tests/unit/test_knowledge_graph/test_object_extraction_logic.py index 3a1ff3ae..525f595d 100644 --- a/tests/unit/test_knowledge_graph/test_object_extraction_logic.py +++ b/tests/unit/test_knowledge_graph/test_object_extraction_logic.py @@ -317,12 +317,12 @@ class TestObjectExtractionBusinessLogic: metadata=[] ) - values = { + values = [{ "customer_id": "CUST001", "name": "John Doe", "email": "john@example.com", "status": "active" - } + }] # Act extracted_obj = ExtractedObject( @@ -335,7 +335,7 @@ class TestObjectExtractionBusinessLogic: # Assert assert extracted_obj.schema_name == "customer_records" - assert extracted_obj.values["customer_id"] == "CUST001" + assert extracted_obj.values[0]["customer_id"] == "CUST001" assert extracted_obj.confidence == 0.95 assert "John Doe" in extracted_obj.source_span assert extracted_obj.metadata.user == "test_user" diff --git a/tests/unit/test_storage/test_objects_cassandra_storage.py b/tests/unit/test_storage/test_objects_cassandra_storage.py index 7a928e51..2b250c35 100644 --- a/tests/unit/test_storage/test_objects_cassandra_storage.py +++ b/tests/unit/test_storage/test_objects_cassandra_storage.py @@ -261,7 +261,7 @@ class 
TestObjectsCassandraStorageLogic: metadata=[] ), schema_name="test_schema", - values={"id": "123", "value": "456"}, + values=[{"id": "123", "value": "456"}], confidence=0.9, source_span="test source" ) @@ -284,8 +284,8 @@ class TestObjectsCassandraStorageLogic: assert "INSERT INTO test_user.o_test_schema" in insert_cql assert "collection" in insert_cql assert values[0] == "test_collection" # collection value - assert values[1] == "123" # id value - assert values[2] == 456 # converted integer value + assert values[1] == "123" # id value (from values[0]) + assert values[2] == 456 # converted integer value (from values[0]) def test_secondary_index_creation(self): """Test that secondary indexes are created for indexed fields""" @@ -325,4 +325,201 @@ class TestObjectsCassandraStorageLogic: index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]] assert len(index_calls) == 2 assert any("o_products_category_idx" in call for call in index_calls) - assert any("o_products_price_idx" in call for call in index_calls) \ No newline at end of file + assert any("o_products_price_idx" in call for call in index_calls) + + +class TestObjectsCassandraStorageBatchLogic: + """Test batch processing logic in Cassandra storage""" + + @pytest.mark.asyncio + async def test_batch_object_processing_logic(self): + """Test processing of batch ExtractedObjects""" + processor = MagicMock() + processor.schemas = { + "batch_schema": RowSchema( + name="batch_schema", + description="Test batch schema", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="name", type="string", size=100), + Field(name="value", type="integer", size=4) + ] + ) + } + processor.ensure_table = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) + processor.convert_value = Processor.convert_value.__get__(processor, Processor) + processor.session = 
MagicMock()
+        processor.on_object = Processor.on_object.__get__(processor, Processor)
+
+        # Create batch object with multiple values
+        batch_obj = ExtractedObject(
+            metadata=Metadata(
+                id="batch-001",
+                user="test_user",
+                collection="batch_collection",
+                metadata=[]
+            ),
+            schema_name="batch_schema",
+            values=[
+                {"id": "001", "name": "First", "value": "100"},
+                {"id": "002", "name": "Second", "value": "200"},
+                {"id": "003", "name": "Third", "value": "300"}
+            ],
+            confidence=0.95,
+            source_span="batch source"
+        )
+
+        # Create mock message
+        msg = MagicMock()
+        msg.value.return_value = batch_obj
+
+        # Process batch object
+        await processor.on_object(msg, None, None)
+
+        # Verify table was ensured once
+        processor.ensure_table.assert_called_once_with("test_user", "batch_schema", processor.schemas["batch_schema"])
+
+        # Verify 3 separate insert calls (one per batch item)
+        assert processor.session.execute.call_count == 3
+
+        # Check each insert call
+        calls = processor.session.execute.call_args_list
+        for i, call in enumerate(calls):
+            insert_cql = call[0][0]
+            values = call[0][1]
+
+            assert "INSERT INTO test_user.o_batch_schema" in insert_cql
+            assert "collection" in insert_cql
+
+            # Check values for each batch item
+            assert values[0] == "batch_collection"  # collection
+            assert values[1] == f"00{i+1}"  # id from batch item i
+            # Parenthesized: without parentheses the conditional expression binds
+            # looser than ==, making this assert vacuous for i > 0.
+            assert values[2] == ("First" if i == 0 else "Second" if i == 1 else "Third")  # name
+            assert values[3] == (i+1) * 100  # converted integer value
+
+    @pytest.mark.asyncio
+    async def test_empty_batch_processing_logic(self):
+        """Test processing of empty batch ExtractedObjects"""
+        processor = MagicMock()
+        processor.schemas = {
+            "empty_schema": RowSchema(
+                name="empty_schema",
+                fields=[Field(name="id", type="string", size=50, primary=True)]
+            )
+        }
+        processor.ensure_table = MagicMock()
+        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
+        processor.sanitize_table = Processor.sanitize_table.__get__(processor,
Processor) + processor.convert_value = Processor.convert_value.__get__(processor, Processor) + processor.session = MagicMock() + processor.on_object = Processor.on_object.__get__(processor, Processor) + + # Create empty batch object + empty_batch_obj = ExtractedObject( + metadata=Metadata( + id="empty-001", + user="test_user", + collection="empty_collection", + metadata=[] + ), + schema_name="empty_schema", + values=[], # Empty batch + confidence=1.0, + source_span="empty source" + ) + + msg = MagicMock() + msg.value.return_value = empty_batch_obj + + # Process empty batch object + await processor.on_object(msg, None, None) + + # Verify table was ensured + processor.ensure_table.assert_called_once() + + # Verify no insert calls for empty batch + processor.session.execute.assert_not_called() + + @pytest.mark.asyncio + async def test_single_item_batch_processing_logic(self): + """Test processing of single-item batch (backward compatibility)""" + processor = MagicMock() + processor.schemas = { + "single_schema": RowSchema( + name="single_schema", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="data", type="string", size=100) + ] + ) + } + processor.ensure_table = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) + processor.convert_value = Processor.convert_value.__get__(processor, Processor) + processor.session = MagicMock() + processor.on_object = Processor.on_object.__get__(processor, Processor) + + # Create single-item batch object (backward compatibility case) + single_batch_obj = ExtractedObject( + metadata=Metadata( + id="single-001", + user="test_user", + collection="single_collection", + metadata=[] + ), + schema_name="single_schema", + values=[{"id": "single-1", "data": "single data"}], # Array with one item + confidence=0.8, + source_span="single source" + ) + + msg = MagicMock() + 
msg.value.return_value = single_batch_obj + + # Process single-item batch object + await processor.on_object(msg, None, None) + + # Verify table was ensured + processor.ensure_table.assert_called_once() + + # Verify exactly one insert call + processor.session.execute.assert_called_once() + + insert_cql = processor.session.execute.call_args[0][0] + values = processor.session.execute.call_args[0][1] + + assert "INSERT INTO test_user.o_single_schema" in insert_cql + assert values[0] == "single_collection" # collection + assert values[1] == "single-1" # id value + assert values[2] == "single data" # data value + + def test_batch_value_conversion_logic(self): + """Test value conversion works correctly for batch items""" + processor = MagicMock() + processor.convert_value = Processor.convert_value.__get__(processor, Processor) + + # Test various conversion scenarios that would occur in batch processing + test_cases = [ + # Integer conversions for batch items + ("123", "integer", 123), + ("456", "integer", 456), + ("789", "integer", 789), + # Float conversions for batch items + ("12.5", "float", 12.5), + ("34.7", "float", 34.7), + # Boolean conversions for batch items + ("true", "boolean", True), + ("false", "boolean", False), + ("1", "boolean", True), + ("0", "boolean", False), + # String conversions for batch items + (123, "string", "123"), + (45.6, "string", "45.6"), + ] + + for input_val, field_type, expected_output in test_cases: + result = processor.convert_value(input_val, field_type) + assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}" \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/knowledge/object.py b/trustgraph-base/trustgraph/schema/knowledge/object.py index 1929edc0..537eb95e 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/object.py +++ b/trustgraph-base/trustgraph/schema/knowledge/object.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, String, Map, 
Double +from pulsar.schema import Record, String, Map, Double, Array from ..core.metadata import Metadata from ..core.topic import topic @@ -10,7 +10,7 @@ from ..core.topic import topic class ExtractedObject(Record): metadata = Metadata() schema_name = String() # Which schema this object belongs to - values = Map(String()) # Field name -> value + values = Array(Map(String())) # Array of objects, each object is field name -> value confidence = Double() source_span = String() # Text span where object was found diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py index 3c88e346..5363dcc5 100644 --- a/trustgraph-cli/trustgraph/cli/load_structured_data.py +++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py @@ -804,7 +804,18 @@ def load_structured_data( print(f"Target schema: {schema_name}") print(f"Sample record:") if processed_records: - print(json.dumps(processed_records[0], indent=2)) + # Show what the batched format will look like + sample_batch = processed_records[:min(3, len(processed_records))] + batch_values = [record["values"] for record in sample_batch] + first_record = processed_records[0] + batched_sample = { + "metadata": first_record["metadata"], + "schema_name": first_record["schema_name"], + "values": batch_values, + "confidence": first_record["confidence"], + "source_span": first_record["source_span"] + } + print(json.dumps(batched_sample, indent=2)) return # Import to TrustGraph using objects import endpoint via WebSocket @@ -828,10 +839,26 @@ def load_structured_data( async with connect(objects_url) as ws: imported_count = 0 - for record in processed_records: - # Send individual ExtractedObject records - await ws.send(json.dumps(record)) - imported_count += 1 + # Process records in batches + for i in range(0, len(processed_records), batch_size): + batch_records = processed_records[i:i + batch_size] + + # Extract values from each record in the batch + batch_values = 
[record["values"] for record in batch_records] + + # Create batched ExtractedObject message using first record as template + first_record = batch_records[0] + batched_record = { + "metadata": first_record["metadata"], + "schema_name": first_record["schema_name"], + "values": batch_values, # Array of value dictionaries + "confidence": first_record["confidence"], + "source_span": first_record["source_span"] + } + + # Send batched ExtractedObject + await ws.send(json.dumps(batched_record)) + imported_count += len(batch_records) if imported_count % 100 == 0: logger.info(f"Imported {imported_count}/{len(processed_records)} records...") diff --git a/trustgraph-flow/trustgraph/extract/kg/objects/processor.py b/trustgraph-flow/trustgraph/extract/kg/objects/processor.py index 2d4f5255..e925a349 100644 --- a/trustgraph-flow/trustgraph/extract/kg/objects/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/objects/processor.py @@ -256,31 +256,34 @@ class Processor(FlowProcessor): flow ) - # Emit each extracted object - for obj in objects: + # Emit extracted objects as a batch if any were found + if objects: # Calculate confidence (could be enhanced with actual confidence from prompt) confidence = 0.8 # Default confidence - # Convert all values to strings for Pulsar compatibility - string_values = convert_values_to_strings(obj) + # Convert all objects' values to strings for Pulsar compatibility + batch_values = [] + for obj in objects: + string_values = convert_values_to_strings(obj) + batch_values.append(string_values) - # Create ExtractedObject + # Create ExtractedObject with batched values extracted = ExtractedObject( metadata=Metadata( - id=f"{v.metadata.id}:{schema_name}:{hash(str(obj))}", + id=f"{v.metadata.id}:{schema_name}", metadata=[], user=v.metadata.user, collection=v.metadata.collection, ), schema_name=schema_name, - values=string_values, + values=batch_values, # Array of objects confidence=confidence, source_span=chunk_text[:100] # First 100 chars as source 
reference ) await flow("output").send(extracted) - logger.debug(f"Emitted extracted object for schema {schema_name}") + logger.debug(f"Emitted batch of {len(objects)} objects for schema {schema_name}") except Exception as e: logger.error(f"Object extraction exception: {e}", exc_info=True) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py index f7891f96..bc0c1b85 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py @@ -44,6 +44,12 @@ class ObjectsImport: data = msg.json() + # Handle both single object and array of objects for backward compatibility + values_data = data["values"] + if not isinstance(values_data, list): + # Single object - wrap in array + values_data = [values_data] + elt = ExtractedObject( metadata=Metadata( id=data["metadata"]["id"], @@ -52,7 +58,7 @@ class ObjectsImport: collection=data["metadata"]["collection"], ), schema_name=data["schema_name"], - values=data["values"], + values=values_data, confidence=data.get("confidence", 1.0), source_span=data.get("source_span", ""), ) diff --git a/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py b/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py index 8cc75318..269171e4 100644 --- a/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py @@ -311,7 +311,7 @@ class Processor(FlowProcessor): """Process incoming ExtractedObject and store in Cassandra""" obj = msg.value() - logger.info(f"Storing object for schema {obj.schema_name} from {obj.metadata.id}") + logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}") # Get schema definition schema = self.schemas.get(obj.schema_name) @@ -328,59 +328,67 @@ class Processor(FlowProcessor): safe_keyspace = self.sanitize_name(keyspace) safe_table = 
self.sanitize_table(table_name) - # Build column names and values - columns = ["collection"] - values = [obj.metadata.collection] - placeholders = ["%s"] - - # Check if we need a synthetic ID - has_primary_key = any(field.primary for field in schema.fields) - if not has_primary_key: - import uuid - columns.append("synthetic_id") - values.append(uuid.uuid4()) - placeholders.append("%s") - - # Process fields - for field in schema.fields: - safe_field_name = self.sanitize_name(field.name) - raw_value = obj.values.get(field.name) + # Process each object in the batch + for obj_index, value_map in enumerate(obj.values): + # Build column names and values for this object + columns = ["collection"] + values = [obj.metadata.collection] + placeholders = ["%s"] - # Handle required fields - if field.required and raw_value is None: - logger.warning(f"Required field {field.name} is missing in object") - # Continue anyway - Cassandra doesn't enforce NOT NULL + # Check if we need a synthetic ID + has_primary_key = any(field.primary for field in schema.fields) + if not has_primary_key: + import uuid + columns.append("synthetic_id") + values.append(uuid.uuid4()) + placeholders.append("%s") - # Check if primary key field is NULL - if field.primary and raw_value is None: - logger.error(f"Primary key field {field.name} cannot be NULL - skipping object") - return + # Process fields for this object + skip_object = False + for field in schema.fields: + safe_field_name = self.sanitize_name(field.name) + raw_value = value_map.get(field.name) + + # Handle required fields + if field.required and raw_value is None: + logger.warning(f"Required field {field.name} is missing in object {obj_index}") + # Continue anyway - Cassandra doesn't enforce NOT NULL + + # Check if primary key field is NULL + if field.primary and raw_value is None: + logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}") + skip_object = True + break + + # Convert value to appropriate type 
+ converted_value = self.convert_value(raw_value, field.type) + + columns.append(safe_field_name) + values.append(converted_value) + placeholders.append("%s") - # Convert value to appropriate type - converted_value = self.convert_value(raw_value, field.type) + # Skip this object if primary key validation failed + if skip_object: + continue - columns.append(safe_field_name) - values.append(converted_value) - placeholders.append("%s") - - # Build and execute insert query - insert_cql = f""" - INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)}) - VALUES ({', '.join(placeholders)}) - """ - - # Debug: Show data being inserted - logger.debug(f"Storing {obj.schema_name}: {dict(zip(columns, values))}") - - if len(columns) != len(values) or len(columns) != len(placeholders): - raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}") - - try: - # Convert to tuple - Cassandra driver requires tuple for parameters - self.session.execute(insert_cql, tuple(values)) - except Exception as e: - logger.error(f"Failed to insert object: {e}", exc_info=True) - raise + # Build and execute insert query for this object + insert_cql = f""" + INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)}) + VALUES ({', '.join(placeholders)}) + """ + + # Debug: Show data being inserted + logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}") + + if len(columns) != len(values) or len(columns) != len(placeholders): + raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}") + + try: + # Convert to tuple - Cassandra driver requires tuple for parameters + self.session.execute(insert_cql, tuple(values)) + except Exception as e: + logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True) + raise def close(self): """Clean up Cassandra connections"""