Object batching (#499)

* Object batching

* Update tests
cybermaggedon 2025-09-05 15:59:06 +01:00 committed by GitHub
parent ebca467ed8
commit 0b7620bc04
12 changed files with 946 additions and 107 deletions
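
In outline: `ExtractedObject.values` changes from a single field map to an array of field maps, so one message can carry many extracted objects. A minimal sketch of the payload change (field names taken from the diff below; the records themselves are illustrative):

    # Before: one object per message
    old_message = {
        "schema_name": "customer_records",
        "values": {"customer_id": "CUST123", "name": "Test Customer"},
    }

    # After: a batch of objects per message
    new_message = {
        "schema_name": "customer_records",
        "values": [
            {"customer_id": "CUST123", "name": "Test Customer 1"},
            {"customer_id": "CUST124", "name": "Test Customer 2"},
        ],
    }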

@@ -30,11 +30,11 @@ class TestObjectsCassandraContracts:
        test_object = ExtractedObject(
            metadata=test_metadata,
            schema_name="customer_records",
-           values={
+           values=[{
                "customer_id": "CUST123",
                "name": "Test Customer",
                "email": "test@example.com"
-           },
+           }],
            confidence=0.95,
            source_span="Customer data from document..."
        )
@@ -54,7 +54,7 @@ class TestObjectsCassandraContracts:
        # Verify types
        assert isinstance(test_object.schema_name, str)
-       assert isinstance(test_object.values, dict)
+       assert isinstance(test_object.values, list)
        assert isinstance(test_object.confidence, float)
        assert isinstance(test_object.source_span, str)
@@ -200,7 +200,7 @@ class TestObjectsCassandraContracts:
                metadata=[]
            ),
            schema_name="test_schema",
-           values={"field1": "value1", "field2": "123"},
+           values=[{"field1": "value1", "field2": "123"}],
            confidence=0.85,
            source_span="Test span"
        )
@@ -292,7 +292,7 @@ class TestObjectsCassandraContracts:
                metadata=[{"key": "value"}]
            ),
            schema_name="table789",  # -> table name
-           values={"field": "value"},
+           values=[{"field": "value"}],
            confidence=0.9,
            source_span="Source"
        )
@@ -304,3 +304,214 @@ class TestObjectsCassandraContracts:
        assert test_obj.metadata.user  # Required for keyspace
        assert test_obj.schema_name  # Required for table
        assert test_obj.metadata.collection  # Required for partition key
@pytest.mark.contract
class TestObjectsCassandraContractsBatch:
"""Contract tests for Cassandra object storage batch processing"""
def test_extracted_object_batch_input_contract(self):
"""Test that batched ExtractedObject schema matches expected input format"""
# Create test object with multiple values in batch
test_metadata = Metadata(
id="batch-doc-001",
user="test_user",
collection="test_collection",
metadata=[]
)
batch_object = ExtractedObject(
metadata=test_metadata,
schema_name="customer_records",
values=[
{
"customer_id": "CUST123",
"name": "Test Customer 1",
"email": "test1@example.com"
},
{
"customer_id": "CUST124",
"name": "Test Customer 2",
"email": "test2@example.com"
},
{
"customer_id": "CUST125",
"name": "Test Customer 3",
"email": "test3@example.com"
}
],
confidence=0.88,
source_span="Multiple customer data from document..."
)
# Verify batch structure
assert hasattr(batch_object, 'values')
assert isinstance(batch_object.values, list)
assert len(batch_object.values) == 3
# Verify each batch item is a dict
for i, batch_item in enumerate(batch_object.values):
assert isinstance(batch_item, dict)
assert "customer_id" in batch_item
assert "name" in batch_item
assert "email" in batch_item
assert batch_item["customer_id"] == f"CUST12{3+i}"
assert f"Test Customer {i+1}" in batch_item["name"]
def test_extracted_object_empty_batch_contract(self):
"""Test empty batch ExtractedObject contract"""
test_metadata = Metadata(
id="empty-batch-001",
user="test_user",
collection="test_collection",
metadata=[]
)
empty_batch_object = ExtractedObject(
metadata=test_metadata,
schema_name="empty_schema",
values=[], # Empty batch
confidence=1.0,
source_span="No objects found in document"
)
# Verify empty batch structure
assert hasattr(empty_batch_object, 'values')
assert isinstance(empty_batch_object.values, list)
assert len(empty_batch_object.values) == 0
assert empty_batch_object.confidence == 1.0
def test_extracted_object_single_item_batch_contract(self):
"""Test single-item batch (backward compatibility) contract"""
test_metadata = Metadata(
id="single-batch-001",
user="test_user",
collection="test_collection",
metadata=[]
)
single_batch_object = ExtractedObject(
metadata=test_metadata,
schema_name="customer_records",
values=[{ # Array with single item for backward compatibility
"customer_id": "CUST999",
"name": "Single Customer",
"email": "single@example.com"
}],
confidence=0.95,
source_span="Single customer data from document..."
)
# Verify single-item batch structure
assert isinstance(single_batch_object.values, list)
assert len(single_batch_object.values) == 1
assert isinstance(single_batch_object.values[0], dict)
assert single_batch_object.values[0]["customer_id"] == "CUST999"
def test_extracted_object_batch_serialization_contract(self):
"""Test that batched ExtractedObject can be serialized/deserialized correctly"""
# Create batch object
original = ExtractedObject(
metadata=Metadata(
id="batch-serial-001",
user="test_user",
collection="test_coll",
metadata=[]
),
schema_name="test_schema",
values=[
{"field1": "value1", "field2": "123"},
{"field1": "value2", "field2": "456"},
{"field1": "value3", "field2": "789"}
],
confidence=0.92,
source_span="Batch test span"
)
# Test serialization using schema
schema = AvroSchema(ExtractedObject)
# Encode and decode
encoded = schema.encode(original)
decoded = schema.decode(encoded)
# Verify round-trip for batch
assert decoded.metadata.id == original.metadata.id
assert decoded.metadata.user == original.metadata.user
assert decoded.metadata.collection == original.metadata.collection
assert decoded.schema_name == original.schema_name
assert len(decoded.values) == len(original.values)
assert len(decoded.values) == 3
# Verify each batch item
for i in range(3):
assert decoded.values[i] == original.values[i]
assert decoded.values[i]["field1"] == f"value{i+1}"
assert decoded.values[i]["field2"] == f"{123 + i*333}"
assert decoded.confidence == original.confidence
assert decoded.source_span == original.source_span
def test_batch_processing_field_validation_contract(self):
"""Test that batch processing validates field consistency"""
# All batch items should have consistent field structure
# This is a contract that the application should enforce
# Valid batch - all items have same fields
valid_batch_values = [
{"id": "1", "name": "Item 1", "value": "100"},
{"id": "2", "name": "Item 2", "value": "200"},
{"id": "3", "name": "Item 3", "value": "300"}
]
# Each item has the same field structure
field_sets = [set(item.keys()) for item in valid_batch_values]
assert all(fields == field_sets[0] for fields in field_sets), "All batch items should have consistent fields"
# Invalid batch - inconsistent fields (this would be caught by application logic)
invalid_batch_values = [
{"id": "1", "name": "Item 1", "value": "100"},
{"id": "2", "name": "Item 2"}, # Missing 'value' field
{"id": "3", "name": "Item 3", "value": "300", "extra": "field"} # Extra field
]
# Demonstrate the inconsistency
invalid_field_sets = [set(item.keys()) for item in invalid_batch_values]
assert not all(fields == invalid_field_sets[0] for fields in invalid_field_sets), "Invalid batch should have inconsistent fields"
def test_batch_storage_partition_key_contract(self):
"""Test that batch objects maintain partition key consistency"""
# In Cassandra storage, all objects in a batch should:
# 1. Belong to the same collection (partition key component)
# 2. Have unique primary keys within the batch
# 3. Be stored in the same keyspace (user)
test_metadata = Metadata(
id="partition-test-001",
user="consistent_user", # Same keyspace
collection="consistent_collection", # Same partition
metadata=[]
)
batch_object = ExtractedObject(
metadata=test_metadata,
schema_name="partition_test",
values=[
{"id": "pk1", "data": "data1"}, # Unique primary key
{"id": "pk2", "data": "data2"}, # Unique primary key
{"id": "pk3", "data": "data3"} # Unique primary key
],
confidence=0.95,
source_span="Partition consistency test"
)
# Verify consistency contract
assert batch_object.metadata.user # Must have user for keyspace
assert batch_object.metadata.collection # Must have collection for partition key
# Verify unique primary keys in batch
primary_keys = [item["id"] for item in batch_object.values]
assert len(primary_keys) == len(set(primary_keys)), "Primary keys must be unique within batch"
# All batch items will be stored in same keyspace and partition
# This is enforced by the metadata.user and metadata.collection being shared
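
The two contracts above (consistent field sets, unique primary keys within a batch) are asserted inline; a hypothetical standalone checker, not part of the codebase, could enforce them before storage:

    def validate_batch(values: list[dict]) -> None:
        # Hypothetical helper: all items must share one field set,
        # and "id" values must be unique within the batch.
        if not values:
            return
        field_set = set(values[0].keys())
        for item in values[1:]:
            if set(item.keys()) != field_set:
                raise ValueError("inconsistent fields in batch")
        ids = [item["id"] for item in values if "id" in item]
        if len(ids) != len(set(ids)):
            raise ValueError("duplicate primary keys in batch")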

@@ -128,18 +128,77 @@ class TestStructuredDataSchemaContracts:
        obj = ExtractedObject(
            metadata=metadata,
            schema_name="customer_records",
-           values={"id": "123", "name": "John Doe", "email": "john@example.com"},
+           values=[{"id": "123", "name": "John Doe", "email": "john@example.com"}],
            confidence=0.95,
            source_span="John Doe (john@example.com) customer ID 123"
        )
        # Assert
        assert obj.schema_name == "customer_records"
-       assert obj.values["name"] == "John Doe"
+       assert obj.values[0]["name"] == "John Doe"
        assert obj.confidence == 0.95
        assert len(obj.source_span) > 0
        assert obj.metadata.id == "extracted-obj-001"
def test_extracted_object_batch_contract(self):
"""Test ExtractedObject schema contract for batched values"""
# Arrange
metadata = Metadata(
id="extracted-batch-001",
user="test_user",
collection="test_collection",
metadata=[]
)
# Act - create object with multiple values
obj = ExtractedObject(
metadata=metadata,
schema_name="customer_records",
values=[
{"id": "123", "name": "John Doe", "email": "john@example.com"},
{"id": "124", "name": "Jane Smith", "email": "jane@example.com"},
{"id": "125", "name": "Bob Johnson", "email": "bob@example.com"}
],
confidence=0.85,
source_span="Multiple customers found in document"
)
# Assert
assert obj.schema_name == "customer_records"
assert len(obj.values) == 3
assert obj.values[0]["name"] == "John Doe"
assert obj.values[1]["name"] == "Jane Smith"
assert obj.values[2]["name"] == "Bob Johnson"
assert obj.values[0]["id"] == "123"
assert obj.values[1]["id"] == "124"
assert obj.values[2]["id"] == "125"
assert obj.confidence == 0.85
assert "Multiple customers" in obj.source_span
def test_extracted_object_empty_batch_contract(self):
"""Test ExtractedObject schema contract for empty values array"""
# Arrange
metadata = Metadata(
id="extracted-empty-001",
user="test_user",
collection="test_collection",
metadata=[]
)
# Act - create object with empty values array
obj = ExtractedObject(
metadata=metadata,
schema_name="empty_schema",
values=[],
confidence=1.0,
source_span="No objects found"
)
# Assert
assert obj.schema_name == "empty_schema"
assert len(obj.values) == 0
assert obj.confidence == 1.0
@pytest.mark.contract
class TestStructuredQueryServiceContracts:
@@ -273,7 +332,7 @@ class TestStructuredDataSerializationContracts:
        object_data = {
            "metadata": metadata,
            "schema_name": "test_schema",
-           "values": {"field1": "value1"},
+           "values": [{"field1": "value1"}],
            "confidence": 0.8,
            "source_span": "test span"
        }
@@ -315,3 +374,37 @@ class TestStructuredDataSerializationContracts:
            "errors": []
        }
        assert serialize_deserialize_test(StructuredQueryResponse, response_data)
def test_extracted_object_batch_serialization(self):
"""Test ExtractedObject batch serialization contract"""
# Arrange
metadata = Metadata(id="test", user="user", collection="col", metadata=[])
batch_object_data = {
"metadata": metadata,
"schema_name": "test_schema",
"values": [
{"field1": "value1", "field2": "value2"},
{"field1": "value3", "field2": "value4"},
{"field1": "value5", "field2": "value6"}
],
"confidence": 0.9,
"source_span": "batch test span"
}
# Act & Assert
assert serialize_deserialize_test(ExtractedObject, batch_object_data)
def test_extracted_object_empty_batch_serialization(self):
"""Test ExtractedObject empty batch serialization contract"""
# Arrange
metadata = Metadata(id="test", user="user", collection="col", metadata=[])
empty_batch_data = {
"metadata": metadata,
"schema_name": "test_schema",
"values": [],
"confidence": 1.0,
"source_span": "empty batch"
}
# Act & Assert
assert serialize_deserialize_test(ExtractedObject, empty_batch_data)
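
These tests lean on a serialize_deserialize_test helper defined elsewhere in the suite; a plausible sketch of what it does, using the same AvroSchema round-trip as the contract test above (hypothetical, not the real helper):

    from pulsar.schema import AvroSchema

    def serialize_deserialize_test(record_type, data):
        # Hypothetical: encode the record, decode it, and check the
        # round-trip is lossless by re-encoding the decoded copy.
        schema = AvroSchema(record_type)
        original = record_type(**data)
        decoded = schema.decode(schema.encode(original))
        return schema.encode(decoded) == schema.encode(original)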

@@ -270,9 +270,9 @@ class TestObjectExtractionServiceIntegration:
        assert len(customer_calls) == 1
        customer_obj = customer_calls[0]
-       assert customer_obj.values["customer_id"] == "CUST001"
-       assert customer_obj.values["name"] == "John Smith"
-       assert customer_obj.values["email"] == "john.smith@email.com"
+       assert customer_obj.values[0]["customer_id"] == "CUST001"
+       assert customer_obj.values[0]["name"] == "John Smith"
+       assert customer_obj.values[0]["email"] == "john.smith@email.com"
        assert customer_obj.confidence > 0.5

    @pytest.mark.asyncio
@@ -335,10 +335,10 @@ class TestObjectExtractionServiceIntegration:
        assert len(product_calls) == 1
        product_obj = product_calls[0]
-       assert product_obj.values["product_id"] == "PROD001"
-       assert product_obj.values["name"] == "Gaming Laptop"
-       assert product_obj.values["price"] == "1299.99"
-       assert product_obj.values["category"] == "electronics"
+       assert product_obj.values[0]["product_id"] == "PROD001"
+       assert product_obj.values[0]["name"] == "Gaming Laptop"
+       assert product_obj.values[0]["price"] == "1299.99"
+       assert product_obj.values[0]["category"] == "electronics"

    @pytest.mark.asyncio
    async def test_concurrent_extraction_integration(self, integration_config, mock_integrated_flow):

@@ -95,12 +95,12 @@ class TestObjectsCassandraIntegration:
                metadata=[]
            ),
            schema_name="customer_records",
-           values={
+           values=[{
                "customer_id": "CUST001",
                "name": "John Doe",
                "email": "john@example.com",
                "age": "30"
-           },
+           }],
            confidence=0.95,
            source_span="Customer: John Doe..."
        )
@@ -183,7 +183,7 @@ class TestObjectsCassandraIntegration:
        product_obj = ExtractedObject(
            metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
            schema_name="products",
-           values={"product_id": "P001", "name": "Widget", "price": "19.99"},
+           values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
            confidence=0.9,
            source_span="Product..."
        )
@@ -191,7 +191,7 @@ class TestObjectsCassandraIntegration:
        order_obj = ExtractedObject(
            metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
            schema_name="orders",
-           values={"order_id": "O001", "customer_id": "C001", "total": "59.97"},
+           values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
            confidence=0.85,
            source_span="Order..."
        )
@@ -229,7 +229,7 @@ class TestObjectsCassandraIntegration:
        test_obj = ExtractedObject(
            metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
            schema_name="test_schema",
-           values={"id": "123"},  # missing required_field
+           values=[{"id": "123"}],  # missing required_field
            confidence=0.8,
            source_span="Test"
        )
@@ -265,7 +265,7 @@ class TestObjectsCassandraIntegration:
        test_obj = ExtractedObject(
            metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
            schema_name="events",
-           values={"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"},
+           values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}],
            confidence=1.0,
            source_span="Event"
        )
@@ -334,7 +334,7 @@ class TestObjectsCassandraIntegration:
        test_obj = ExtractedObject(
            metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
            schema_name="test",
-           values={"id": "123"},
+           values=[{"id": "123"}],
            confidence=0.9,
            source_span="Test"
        )
@@ -364,7 +364,7 @@ class TestObjectsCassandraIntegration:
        obj = ExtractedObject(
            metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]),
            schema_name="data",
-           values={"id": f"ID-{coll}"},
+           values=[{"id": f"ID-{coll}"}],
            confidence=0.9,
            source_span="Data"
        )
@@ -382,3 +382,169 @@ class TestObjectsCassandraIntegration:
        for i, call in enumerate(insert_calls):
            values = call[0][1]
            assert collections[i] in values
@pytest.mark.asyncio
async def test_batch_object_processing(self, processor_with_mocks):
"""Test processing objects with batched values"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Configure schema
config = {
"schema": {
"batch_customers": json.dumps({
"name": "batch_customers",
"description": "Customer batch data",
"fields": [
{"name": "customer_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string", "required": True},
{"name": "email", "type": "string", "indexed": True}
]
})
}
}
await processor.on_schema_config(config, version=1)
# Process batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_import",
metadata=[]
),
schema_name="batch_customers",
values=[
{
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com"
},
{
"customer_id": "CUST002",
"name": "Jane Smith",
"email": "jane@example.com"
},
{
"customer_id": "CUST003",
"name": "Bob Johnson",
"email": "bob@example.com"
}
],
confidence=0.92,
source_span="Multiple customers extracted from document"
)
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
# Verify table creation
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 1
assert "o_batch_customers" in str(table_calls[0])
# Verify multiple inserts for batch values
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
# Should have 3 separate inserts for the 3 objects in the batch
assert len(insert_calls) == 3
# Check each insert has correct data
for i, call in enumerate(insert_calls):
values = call[0][1]
assert "batch_import" in values # collection
assert f"CUST00{i+1}" in values # customer_id
if i == 0:
assert "John Doe" in values
assert "john@example.com" in values
elif i == 1:
assert "Jane Smith" in values
assert "jane@example.com" in values
elif i == 2:
assert "Bob Johnson" in values
assert "bob@example.com" in values
@pytest.mark.asyncio
async def test_empty_batch_processing(self, processor_with_mocks):
"""Test processing objects with empty values array"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["empty_test"] = RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
# Process empty batch object
empty_obj = ExtractedObject(
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
schema_name="empty_test",
values=[], # Empty batch
confidence=1.0,
source_span="No objects found"
)
msg = MagicMock()
msg.value.return_value = empty_obj
await processor.on_object(msg, None, None)
# Should still create table
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 1
# Should not create any insert statements for empty batch
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 0
@pytest.mark.asyncio
async def test_mixed_single_and_batch_objects(self, processor_with_mocks):
"""Test processing mix of single and batch objects"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["mixed_test"] = RowSchema(
name="mixed_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="data", type="string", size=100)
]
)
# Single object (backward compatibility)
single_obj = ExtractedObject(
metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
schema_name="mixed_test",
values=[{"id": "single-1", "data": "single data"}], # Array with single item
confidence=0.9,
source_span="Single object"
)
# Batch object
batch_obj = ExtractedObject(
metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
schema_name="mixed_test",
values=[
{"id": "batch-1", "data": "batch data 1"},
{"id": "batch-2", "data": "batch data 2"}
],
confidence=0.85,
source_span="Batch objects"
)
# Process both
for obj in [single_obj, batch_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
# Should have 3 total inserts (1 + 2)
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 3

@@ -66,11 +66,11 @@ def sample_objects_message():
            "collection": "testcollection"
        },
        "schema_name": "person",
-       "values": {
+       "values": [{
            "name": "John Doe",
            "age": "30",
            "city": "New York"
-       },
+       }],
        "confidence": 0.95,
        "source_span": "John Doe, age 30, lives in New York"
    }
@@ -86,9 +86,9 @@ def minimal_objects_message():
            "collection": "testcollection"
        },
        "schema_name": "simple_schema",
-       "values": {
+       "values": [{
            "field1": "value1"
-       }
+       }]
    }
@@ -235,8 +235,8 @@ class TestObjectsImportMessageProcessing:
        sent_object = call_args[0][1]
        assert isinstance(sent_object, ExtractedObject)
        assert sent_object.schema_name == "person"
-       assert sent_object.values["name"] == "John Doe"
-       assert sent_object.values["age"] == "30"
+       assert sent_object.values[0]["name"] == "John Doe"
+       assert sent_object.values[0]["age"] == "30"
        assert sent_object.confidence == 0.95
        assert sent_object.source_span == "John Doe, age 30, lives in New York"
@@ -274,7 +274,7 @@ class TestObjectsImportMessageProcessing:
        sent_object = mock_publisher_instance.send.call_args[0][1]
        assert isinstance(sent_object, ExtractedObject)
        assert sent_object.schema_name == "simple_schema"
-       assert sent_object.values["field1"] == "value1"
+       assert sent_object.values[0]["field1"] == "value1"
        assert sent_object.confidence == 1.0  # Default value
        assert sent_object.source_span == ""  # Default value
        assert len(sent_object.metadata.metadata) == 0  # Default empty list
@@ -302,7 +302,7 @@ class TestObjectsImportMessageProcessing:
                "collection": "testcollection"
            },
            "schema_name": "test_schema",
-           "values": {"key": "value"}
+           "values": [{"key": "value"}]
            # No confidence or source_span
        }
@@ -374,6 +374,134 @@ class TestObjectsImportRunMethod:
        assert objects_import.ws is None
class TestObjectsImportBatchProcessing:
"""Test ObjectsImport batch processing functionality."""
@pytest.fixture
def batch_objects_message(self):
"""Sample batch objects message data."""
return {
"metadata": {
"id": "batch-001",
"metadata": [
{
"s": {"v": "batch-001", "e": False},
"p": {"v": "source", "e": False},
"o": {"v": "test", "e": False}
}
],
"user": "testuser",
"collection": "testcollection"
},
"schema_name": "person",
"values": [
{
"name": "John Doe",
"age": "30",
"city": "New York"
},
{
"name": "Jane Smith",
"age": "25",
"city": "Boston"
},
{
"name": "Bob Johnson",
"age": "45",
"city": "Chicago"
}
],
"confidence": 0.85,
"source_span": "Multiple people found in document"
}
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, batch_objects_message):
"""Test that receive() processes batch message correctly."""
mock_publisher_instance = Mock()
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
# Create mock message
mock_msg = Mock()
mock_msg.json.return_value = batch_objects_message
await objects_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
# Get the call arguments
call_args = mock_publisher_instance.send.call_args
assert call_args[0][0] is None # First argument should be None
# Check the ExtractedObject that was sent
sent_object = call_args[0][1]
assert isinstance(sent_object, ExtractedObject)
assert sent_object.schema_name == "person"
# Check that all batch values are present
assert len(sent_object.values) == 3
assert sent_object.values[0]["name"] == "John Doe"
assert sent_object.values[0]["age"] == "30"
assert sent_object.values[0]["city"] == "New York"
assert sent_object.values[1]["name"] == "Jane Smith"
assert sent_object.values[1]["age"] == "25"
assert sent_object.values[1]["city"] == "Boston"
assert sent_object.values[2]["name"] == "Bob Johnson"
assert sent_object.values[2]["age"] == "45"
assert sent_object.values[2]["city"] == "Chicago"
assert sent_object.confidence == 0.85
assert sent_object.source_span == "Multiple people found in document"
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
"""Test that receive() handles empty batch correctly."""
mock_publisher_instance = Mock()
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
ws=mock_websocket,
running=mock_running,
pulsar_client=mock_pulsar_client,
queue="test-queue"
)
# Message with empty values array
empty_batch_message = {
"metadata": {
"id": "empty-batch-001",
"user": "testuser",
"collection": "testcollection"
},
"schema_name": "empty_schema",
"values": []
}
mock_msg = Mock()
mock_msg.json.return_value = empty_batch_message
await objects_import.receive(mock_msg)
# Should still send the message
mock_publisher_instance.send.assert_called_once()
sent_object = mock_publisher_instance.send.call_args[0][1]
assert len(sent_object.values) == 0
class TestObjectsImportErrorHandling:
    """Test error handling in ObjectsImport."""

@@ -317,12 +317,12 @@ class TestObjectExtractionBusinessLogic:
            metadata=[]
        )
-       values = {
+       values = [{
            "customer_id": "CUST001",
            "name": "John Doe",
            "email": "john@example.com",
            "status": "active"
-       }
+       }]

        # Act
        extracted_obj = ExtractedObject(
@@ -335,7 +335,7 @@ class TestObjectExtractionBusinessLogic:
        # Assert
        assert extracted_obj.schema_name == "customer_records"
-       assert extracted_obj.values["customer_id"] == "CUST001"
+       assert extracted_obj.values[0]["customer_id"] == "CUST001"
        assert extracted_obj.confidence == 0.95
        assert "John Doe" in extracted_obj.source_span
        assert extracted_obj.metadata.user == "test_user"

@@ -261,7 +261,7 @@ class TestObjectsCassandraStorageLogic:
                metadata=[]
            ),
            schema_name="test_schema",
-           values={"id": "123", "value": "456"},
+           values=[{"id": "123", "value": "456"}],
            confidence=0.9,
            source_span="test source"
        )
@@ -284,8 +284,8 @@ class TestObjectsCassandraStorageLogic:
        assert "INSERT INTO test_user.o_test_schema" in insert_cql
        assert "collection" in insert_cql
        assert values[0] == "test_collection"  # collection value
-       assert values[1] == "123"  # id value
-       assert values[2] == 456  # converted integer value
+       assert values[1] == "123"  # id value (from values[0])
+       assert values[2] == 456  # converted integer value (from values[0])

    def test_secondary_index_creation(self):
        """Test that secondary indexes are created for indexed fields"""
@@ -326,3 +326,200 @@ class TestObjectsCassandraStorageLogic:
        assert len(index_calls) == 2
        assert any("o_products_category_idx" in call for call in index_calls)
        assert any("o_products_price_idx" in call for call in index_calls)
class TestObjectsCassandraStorageBatchLogic:
"""Test batch processing logic in Cassandra storage"""
@pytest.mark.asyncio
async def test_batch_object_processing_logic(self):
"""Test processing of batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"batch_schema": RowSchema(
name="batch_schema",
description="Test batch schema",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="value", type="integer", size=4)
]
)
}
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_collection",
metadata=[]
),
schema_name="batch_schema",
values=[
{"id": "001", "name": "First", "value": "100"},
{"id": "002", "name": "Second", "value": "200"},
{"id": "003", "name": "Third", "value": "300"}
],
confidence=0.95,
source_span="batch source"
)
# Create mock message
msg = MagicMock()
msg.value.return_value = batch_obj
# Process batch object
await processor.on_object(msg, None, None)
# Verify table was ensured once
processor.ensure_table.assert_called_once_with("test_user", "batch_schema", processor.schemas["batch_schema"])
# Verify 3 separate insert calls (one per batch item)
assert processor.session.execute.call_count == 3
# Check each insert call
calls = processor.session.execute.call_args_list
for i, call in enumerate(calls):
insert_cql = call[0][0]
values = call[0][1]
assert "INSERT INTO test_user.o_batch_schema" in insert_cql
assert "collection" in insert_cql
# Check values for each batch item
assert values[0] == "batch_collection" # collection
assert values[1] == f"00{i+1}" # id from batch item i
assert values[2] == ("First" if i == 0 else "Second" if i == 1 else "Third")  # name
assert values[3] == (i+1) * 100 # converted integer value
@pytest.mark.asyncio
async def test_empty_batch_processing_logic(self):
"""Test processing of empty batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
}
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create empty batch object
empty_batch_obj = ExtractedObject(
metadata=Metadata(
id="empty-001",
user="test_user",
collection="empty_collection",
metadata=[]
),
schema_name="empty_schema",
values=[], # Empty batch
confidence=1.0,
source_span="empty source"
)
msg = MagicMock()
msg.value.return_value = empty_batch_obj
# Process empty batch object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once()
# Verify no insert calls for empty batch
processor.session.execute.assert_not_called()
@pytest.mark.asyncio
async def test_single_item_batch_processing_logic(self):
"""Test processing of single-item batch (backward compatibility)"""
processor = MagicMock()
processor.schemas = {
"single_schema": RowSchema(
name="single_schema",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="data", type="string", size=100)
]
)
}
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create single-item batch object (backward compatibility case)
single_batch_obj = ExtractedObject(
metadata=Metadata(
id="single-001",
user="test_user",
collection="single_collection",
metadata=[]
),
schema_name="single_schema",
values=[{"id": "single-1", "data": "single data"}], # Array with one item
confidence=0.8,
source_span="single source"
)
msg = MagicMock()
msg.value.return_value = single_batch_obj
# Process single-item batch object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once()
# Verify exactly one insert call
processor.session.execute.assert_called_once()
insert_cql = processor.session.execute.call_args[0][0]
values = processor.session.execute.call_args[0][1]
assert "INSERT INTO test_user.o_single_schema" in insert_cql
assert values[0] == "single_collection" # collection
assert values[1] == "single-1" # id value
assert values[2] == "single data" # data value
def test_batch_value_conversion_logic(self):
"""Test value conversion works correctly for batch items"""
processor = MagicMock()
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
# Test various conversion scenarios that would occur in batch processing
test_cases = [
# Integer conversions for batch items
("123", "integer", 123),
("456", "integer", 456),
("789", "integer", 789),
# Float conversions for batch items
("12.5", "float", 12.5),
("34.7", "float", 34.7),
# Boolean conversions for batch items
("true", "boolean", True),
("false", "boolean", False),
("1", "boolean", True),
("0", "boolean", False),
# String conversions for batch items
(123, "string", "123"),
(45.6, "string", "45.6"),
]
for input_val, field_type, expected_output in test_cases:
result = processor.convert_value(input_val, field_type)
assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}"
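
The table above pins down the conversion contract; a minimal implementation consistent with these cases might look like this (hypothetical sketch — the real Processor.convert_value is not shown in this diff):

    def convert_value(value, field_type):
        # Hypothetical sketch matching the test cases above.
        if field_type == "integer":
            return int(value)
        if field_type == "float":
            return float(value)
        if field_type == "boolean":
            return str(value).strip().lower() in ("true", "1", "yes")
        if field_type == "string":
            return str(value)
        return value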

@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Map, Double
+from pulsar.schema import Record, String, Map, Double, Array
from ..core.metadata import Metadata
from ..core.topic import topic
@@ -10,7 +10,7 @@ from ..core.topic import topic
class ExtractedObject(Record):
    metadata = Metadata()
    schema_name = String()  # Which schema this object belongs to
-   values = Map(String())  # Field name -> value
+   values = Array(Map(String()))  # Array of objects, each object is field name -> value
    confidence = Double()
    source_span = String()  # Text span where object was found
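
With the Array(Map(String())) type, a producer builds one record per schema per chunk; for example (constructor arguments as used throughout the tests above; the data is illustrative):

    extracted = ExtractedObject(
        metadata=Metadata(id="doc-1", user="u", collection="c", metadata=[]),
        schema_name="customer_records",
        values=[
            {"customer_id": "CUST123", "name": "Test Customer 1"},
            {"customer_id": "CUST124", "name": "Test Customer 2"},
        ],  # every value must already be a string for Map(String())
        confidence=0.9,
        source_span="...",
    )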

@@ -804,7 +804,18 @@ def load_structured_data(
    print(f"Target schema: {schema_name}")
    print(f"Sample record:")
    if processed_records:
-       print(json.dumps(processed_records[0], indent=2))
+       # Show what the batched format will look like
+       sample_batch = processed_records[:min(3, len(processed_records))]
+       batch_values = [record["values"] for record in sample_batch]
+       first_record = processed_records[0]
+       batched_sample = {
+           "metadata": first_record["metadata"],
+           "schema_name": first_record["schema_name"],
+           "values": batch_values,
+           "confidence": first_record["confidence"],
+           "source_span": first_record["source_span"]
+       }
+       print(json.dumps(batched_sample, indent=2))
    return

    # Import to TrustGraph using objects import endpoint via WebSocket
@@ -828,10 +839,26 @@ def load_structured_data(
    async with connect(objects_url) as ws:
        imported_count = 0
-       for record in processed_records:
-           # Send individual ExtractedObject records
-           await ws.send(json.dumps(record))
-           imported_count += 1
+       # Process records in batches
+       for i in range(0, len(processed_records), batch_size):
+           batch_records = processed_records[i:i + batch_size]
+
+           # Extract values from each record in the batch
+           batch_values = [record["values"] for record in batch_records]
+
+           # Create batched ExtractedObject message using first record as template
+           first_record = batch_records[0]
+           batched_record = {
+               "metadata": first_record["metadata"],
+               "schema_name": first_record["schema_name"],
+               "values": batch_values,  # Array of value dictionaries
+               "confidence": first_record["confidence"],
+               "source_span": first_record["source_span"]
+           }
+
+           # Send batched ExtractedObject
+           await ws.send(json.dumps(batched_record))
+           imported_count += len(batch_records)

            if imported_count % 100 == 0:
                logger.info(f"Imported {imported_count}/{len(processed_records)} records...")

@@ -256,31 +256,34 @@ class Processor(FlowProcessor):
            flow
        )

-       # Emit each extracted object
-       for obj in objects:
-           # Calculate confidence (could be enhanced with actual confidence from prompt)
-           confidence = 0.8  # Default confidence
-
-           # Convert all values to strings for Pulsar compatibility
-           string_values = convert_values_to_strings(obj)
-
-           # Create ExtractedObject
-           extracted = ExtractedObject(
-               metadata=Metadata(
-                   id=f"{v.metadata.id}:{schema_name}:{hash(str(obj))}",
-                   metadata=[],
-                   user=v.metadata.user,
-                   collection=v.metadata.collection,
-               ),
-               schema_name=schema_name,
-               values=string_values,
-               confidence=confidence,
-               source_span=chunk_text[:100]  # First 100 chars as source reference
-           )
-
-           await flow("output").send(extracted)
-           logger.debug(f"Emitted extracted object for schema {schema_name}")
+       # Emit extracted objects as a batch if any were found
+       if objects:
+           # Calculate confidence (could be enhanced with actual confidence from prompt)
+           confidence = 0.8  # Default confidence
+
+           # Convert all objects' values to strings for Pulsar compatibility
+           batch_values = []
+           for obj in objects:
+               string_values = convert_values_to_strings(obj)
+               batch_values.append(string_values)
+
+           # Create ExtractedObject with batched values
+           extracted = ExtractedObject(
+               metadata=Metadata(
+                   id=f"{v.metadata.id}:{schema_name}",
+                   metadata=[],
+                   user=v.metadata.user,
+                   collection=v.metadata.collection,
+               ),
+               schema_name=schema_name,
+               values=batch_values,  # Array of objects
+               confidence=confidence,
+               source_span=chunk_text[:100]  # First 100 chars as source reference
+           )
+
+           await flow("output").send(extracted)
+           logger.debug(f"Emitted batch of {len(objects)} objects for schema {schema_name}")

    except Exception as e:
        logger.error(f"Object extraction exception: {e}", exc_info=True)

@@ -44,6 +44,12 @@ class ObjectsImport:
        data = msg.json()

+       # Handle both single object and array of objects for backward compatibility
+       values_data = data["values"]
+       if not isinstance(values_data, list):
+           # Single object - wrap in array
+           values_data = [values_data]
+
        elt = ExtractedObject(
            metadata=Metadata(
                id=data["metadata"]["id"],
@@ -52,7 +58,7 @@ class ObjectsImport:
                collection=data["metadata"]["collection"],
            ),
            schema_name=data["schema_name"],
-           values=data["values"],
+           values=values_data,
            confidence=data.get("confidence", 1.0),
            source_span=data.get("source_span", ""),
        )
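
On the wire, the gateway therefore accepts both payload shapes; the legacy form is wrapped into a single-item batch. Examples mirroring the test fixtures above:

    legacy = {
        "metadata": {"id": "m1", "user": "testuser", "collection": "testcollection"},
        "schema_name": "person",
        "values": {"name": "John Doe"},  # wrapped to a one-item list on receive
    }
    batched = {
        "metadata": {"id": "m2", "user": "testuser", "collection": "testcollection"},
        "schema_name": "person",
        "values": [{"name": "John Doe"}, {"name": "Jane Smith"}],  # passed through
    }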

@@ -311,7 +311,7 @@ class Processor(FlowProcessor):
        """Process incoming ExtractedObject and store in Cassandra"""
        obj = msg.value()
-       logger.info(f"Storing object for schema {obj.schema_name} from {obj.metadata.id}")
+       logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")

        # Get schema definition
        schema = self.schemas.get(obj.schema_name)
@@ -328,59 +328,67 @@ class Processor(FlowProcessor):
        safe_keyspace = self.sanitize_name(keyspace)
        safe_table = self.sanitize_table(table_name)

-       # Build column names and values
-       columns = ["collection"]
-       values = [obj.metadata.collection]
-       placeholders = ["%s"]
-
-       # Check if we need a synthetic ID
-       has_primary_key = any(field.primary for field in schema.fields)
-       if not has_primary_key:
-           import uuid
-           columns.append("synthetic_id")
-           values.append(uuid.uuid4())
-           placeholders.append("%s")
-
-       # Process fields
-       for field in schema.fields:
-           safe_field_name = self.sanitize_name(field.name)
-           raw_value = obj.values.get(field.name)
-
-           # Handle required fields
-           if field.required and raw_value is None:
-               logger.warning(f"Required field {field.name} is missing in object")
-               # Continue anyway - Cassandra doesn't enforce NOT NULL
-
-           # Check if primary key field is NULL
-           if field.primary and raw_value is None:
-               logger.error(f"Primary key field {field.name} cannot be NULL - skipping object")
-               return
-
-           # Convert value to appropriate type
-           converted_value = self.convert_value(raw_value, field.type)
-           columns.append(safe_field_name)
-           values.append(converted_value)
-           placeholders.append("%s")
-
-       # Build and execute insert query
-       insert_cql = f"""
-           INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
-           VALUES ({', '.join(placeholders)})
-       """
-
-       # Debug: Show data being inserted
-       logger.debug(f"Storing {obj.schema_name}: {dict(zip(columns, values))}")
-
-       if len(columns) != len(values) or len(columns) != len(placeholders):
-           raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
-
-       try:
-           # Convert to tuple - Cassandra driver requires tuple for parameters
-           self.session.execute(insert_cql, tuple(values))
-       except Exception as e:
-           logger.error(f"Failed to insert object: {e}", exc_info=True)
-           raise
+       # Process each object in the batch
+       for obj_index, value_map in enumerate(obj.values):
+           # Build column names and values for this object
+           columns = ["collection"]
+           values = [obj.metadata.collection]
+           placeholders = ["%s"]
+
+           # Check if we need a synthetic ID
+           has_primary_key = any(field.primary for field in schema.fields)
+           if not has_primary_key:
+               import uuid
+               columns.append("synthetic_id")
+               values.append(uuid.uuid4())
+               placeholders.append("%s")
+
+           # Process fields for this object
+           skip_object = False
+           for field in schema.fields:
+               safe_field_name = self.sanitize_name(field.name)
+               raw_value = value_map.get(field.name)
+
+               # Handle required fields
+               if field.required and raw_value is None:
+                   logger.warning(f"Required field {field.name} is missing in object {obj_index}")
+                   # Continue anyway - Cassandra doesn't enforce NOT NULL
+
+               # Check if primary key field is NULL
+               if field.primary and raw_value is None:
+                   logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}")
+                   skip_object = True
+                   break
+
+               # Convert value to appropriate type
+               converted_value = self.convert_value(raw_value, field.type)
+               columns.append(safe_field_name)
+               values.append(converted_value)
+               placeholders.append("%s")
+
+           # Skip this object if primary key validation failed
+           if skip_object:
+               continue
+
+           # Build and execute insert query for this object
+           insert_cql = f"""
+               INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
+               VALUES ({', '.join(placeholders)})
+           """
+
+           # Debug: Show data being inserted
+           logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}")
+
+           if len(columns) != len(values) or len(columns) != len(placeholders):
+               raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
+
+           try:
+               # Convert to tuple - Cassandra driver requires tuple for parameters
+               self.session.execute(insert_cql, tuple(values))
+           except Exception as e:
+               logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
+               raise

    def close(self):
        """Clean up Cassandra connections"""