mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
parent
ebca467ed8
commit
0b7620bc04
12 changed files with 946 additions and 107 deletions
|
|
@ -30,11 +30,11 @@ class TestObjectsCassandraContracts:
|
|||
test_object = ExtractedObject(
|
||||
metadata=test_metadata,
|
||||
schema_name="customer_records",
|
||||
values={
|
||||
values=[{
|
||||
"customer_id": "CUST123",
|
||||
"name": "Test Customer",
|
||||
"email": "test@example.com"
|
||||
},
|
||||
}],
|
||||
confidence=0.95,
|
||||
source_span="Customer data from document..."
|
||||
)
|
||||
|
|
@ -54,7 +54,7 @@ class TestObjectsCassandraContracts:
|
|||
|
||||
# Verify types
|
||||
assert isinstance(test_object.schema_name, str)
|
||||
assert isinstance(test_object.values, dict)
|
||||
assert isinstance(test_object.values, list)
|
||||
assert isinstance(test_object.confidence, float)
|
||||
assert isinstance(test_object.source_span, str)
|
||||
|
||||
|
|
@ -200,7 +200,7 @@ class TestObjectsCassandraContracts:
|
|||
metadata=[]
|
||||
),
|
||||
schema_name="test_schema",
|
||||
values={"field1": "value1", "field2": "123"},
|
||||
values=[{"field1": "value1", "field2": "123"}],
|
||||
confidence=0.85,
|
||||
source_span="Test span"
|
||||
)
|
||||
|
|
@ -292,7 +292,7 @@ class TestObjectsCassandraContracts:
|
|||
metadata=[{"key": "value"}]
|
||||
),
|
||||
schema_name="table789", # -> table name
|
||||
values={"field": "value"},
|
||||
values=[{"field": "value"}],
|
||||
confidence=0.9,
|
||||
source_span="Source"
|
||||
)
|
||||
|
|
@ -303,4 +303,215 @@ class TestObjectsCassandraContracts:
|
|||
# - metadata.collection -> Part of primary key
|
||||
assert test_obj.metadata.user # Required for keyspace
|
||||
assert test_obj.schema_name # Required for table
|
||||
assert test_obj.metadata.collection # Required for partition key
|
||||
assert test_obj.metadata.collection # Required for partition key
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestObjectsCassandraContractsBatch:
    """Contract tests for batched ``ExtractedObject`` values in Cassandra object storage."""

    def test_extracted_object_batch_input_contract(self):
        """A batched ExtractedObject carries its rows as a list of dicts."""
        customer_rows = [
            {
                "customer_id": "CUST123",
                "name": "Test Customer 1",
                "email": "test1@example.com"
            },
            {
                "customer_id": "CUST124",
                "name": "Test Customer 2",
                "email": "test2@example.com"
            },
            {
                "customer_id": "CUST125",
                "name": "Test Customer 3",
                "email": "test3@example.com"
            }
        ]

        batch_object = ExtractedObject(
            metadata=Metadata(
                id="batch-doc-001",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            schema_name="customer_records",
            values=customer_rows,
            confidence=0.88,
            source_span="Multiple customer data from document..."
        )

        # Shape of the batch container
        assert hasattr(batch_object, 'values')
        assert isinstance(batch_object.values, list)
        assert len(batch_object.values) == 3

        # Every entry is a dict with the three customer fields, in input order
        for position, row in enumerate(batch_object.values):
            assert isinstance(row, dict)
            for required_key in ("customer_id", "name", "email"):
                assert required_key in row
            assert row["customer_id"] == f"CUST12{3+position}"
            assert f"Test Customer {position+1}" in row["name"]

    def test_extracted_object_empty_batch_contract(self):
        """An ExtractedObject may be built with an empty values list."""
        empty_batch_object = ExtractedObject(
            metadata=Metadata(
                id="empty-batch-001",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            schema_name="empty_schema",
            values=[],  # batch with no rows
            confidence=1.0,
            source_span="No objects found in document"
        )

        # The values attribute is still a list, just with nothing in it
        assert hasattr(empty_batch_object, 'values')
        assert isinstance(empty_batch_object.values, list)
        assert len(empty_batch_object.values) == 0
        assert empty_batch_object.confidence == 1.0

    def test_extracted_object_single_item_batch_contract(self):
        """A single object is expressed as a one-element list (backward compatibility)."""
        only_row = {
            "customer_id": "CUST999",
            "name": "Single Customer",
            "email": "single@example.com"
        }

        single_batch_object = ExtractedObject(
            metadata=Metadata(
                id="single-batch-001",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            schema_name="customer_records",
            values=[only_row],  # one-item array stands in for the old single dict
            confidence=0.95,
            source_span="Single customer data from document..."
        )

        rows = single_batch_object.values
        assert isinstance(rows, list)
        assert len(rows) == 1
        assert isinstance(rows[0], dict)
        assert rows[0]["customer_id"] == "CUST999"

    def test_extracted_object_batch_serialization_contract(self):
        """A batched ExtractedObject round-trips through AvroSchema encode/decode."""
        original = ExtractedObject(
            metadata=Metadata(
                id="batch-serial-001",
                user="test_user",
                collection="test_coll",
                metadata=[]
            ),
            schema_name="test_schema",
            values=[
                {"field1": "value1", "field2": "123"},
                {"field1": "value2", "field2": "456"},
                {"field1": "value3", "field2": "789"}
            ],
            confidence=0.92,
            source_span="Batch test span"
        )

        # Encode with the object's own schema, then decode the bytes again
        codec = AvroSchema(ExtractedObject)
        decoded = codec.decode(codec.encode(original))

        # Metadata and scalar fields survive the round trip
        for attribute in ("id", "user", "collection"):
            assert getattr(decoded.metadata, attribute) == getattr(original.metadata, attribute)
        assert decoded.schema_name == original.schema_name
        assert len(decoded.values) == len(original.values)
        assert len(decoded.values) == 3

        # Each row survives the round trip and keeps its position
        for position in range(3):
            restored = decoded.values[position]
            assert restored == original.values[position]
            assert restored["field1"] == f"value{position+1}"
            assert restored["field2"] == f"{123 + position*333}"

        assert decoded.confidence == original.confidence
        assert decoded.source_span == original.source_span

    def test_batch_processing_field_validation_contract(self):
        """Illustrate the field-consistency expectation for rows in one batch.

        Only plain dict literals are checked here; enforcing the rule is
        left to the application.
        """
        # Valid batch: every row has the same set of keys
        valid_batch_values = [
            {"id": "1", "name": "Item 1", "value": "100"},
            {"id": "2", "name": "Item 2", "value": "200"},
            {"id": "3", "name": "Item 3", "value": "300"}
        ]
        reference_keys = set(valid_batch_values[0].keys())
        assert all(
            set(row.keys()) == reference_keys for row in valid_batch_values
        ), "All batch items should have consistent fields"

        # Invalid batch: the key sets differ (application logic would catch this)
        invalid_batch_values = [
            {"id": "1", "name": "Item 1", "value": "100"},
            {"id": "2", "name": "Item 2"},  # Missing 'value' field
            {"id": "3", "name": "Item 3", "value": "300", "extra": "field"}  # Extra field
        ]
        first_invalid_keys = set(invalid_batch_values[0].keys())
        assert any(
            set(row.keys()) != first_invalid_keys for row in invalid_batch_values
        ), "Invalid batch should have inconsistent fields"

    def test_batch_storage_partition_key_contract(self):
        """Rows of one batch share user/collection metadata and have distinct ids."""
        # In Cassandra storage, all objects in a batch should:
        # 1. Belong to the same collection (partition key component)
        # 2. Have unique primary keys within the batch
        # 3. Be stored in the same keyspace (user)
        batch_object = ExtractedObject(
            metadata=Metadata(
                id="partition-test-001",
                user="consistent_user",  # Same keyspace
                collection="consistent_collection",  # Same partition
                metadata=[]
            ),
            schema_name="partition_test",
            values=[
                {"id": "pk1", "data": "data1"},  # Unique primary key
                {"id": "pk2", "data": "data2"},  # Unique primary key
                {"id": "pk3", "data": "data3"}  # Unique primary key
            ],
            confidence=0.95,
            source_span="Partition consistency test"
        )

        # Both metadata fields must be non-empty
        assert batch_object.metadata.user  # Must have user for keyspace
        assert batch_object.metadata.collection  # Must have collection for partition key

        # No id may appear twice within the batch
        row_ids = [row["id"] for row in batch_object.values]
        assert len(row_ids) == len(set(row_ids)), "Primary keys must be unique within batch"

        # All batch items will be stored in same keyspace and partition
        # This is enforced by the metadata.user and metadata.collection being shared
|
||||
|
|
@ -128,18 +128,77 @@ class TestStructuredDataSchemaContracts:
|
|||
obj = ExtractedObject(
|
||||
metadata=metadata,
|
||||
schema_name="customer_records",
|
||||
values={"id": "123", "name": "John Doe", "email": "john@example.com"},
|
||||
values=[{"id": "123", "name": "John Doe", "email": "john@example.com"}],
|
||||
confidence=0.95,
|
||||
source_span="John Doe (john@example.com) customer ID 123"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert obj.schema_name == "customer_records"
|
||||
assert obj.values["name"] == "John Doe"
|
||||
assert obj.values[0]["name"] == "John Doe"
|
||||
assert obj.confidence == 0.95
|
||||
assert len(obj.source_span) > 0
|
||||
assert obj.metadata.id == "extracted-obj-001"
|
||||
|
||||
def test_extracted_object_batch_contract(self):
    """ExtractedObject keeps several value rows in the order they were given."""
    # Arrange
    customers = [
        {"id": "123", "name": "John Doe", "email": "john@example.com"},
        {"id": "124", "name": "Jane Smith", "email": "jane@example.com"},
        {"id": "125", "name": "Bob Johnson", "email": "bob@example.com"}
    ]
    metadata = Metadata(
        id="extracted-batch-001",
        user="test_user",
        collection="test_collection",
        metadata=[]
    )

    # Act - build one object holding all three rows
    obj = ExtractedObject(
        metadata=metadata,
        schema_name="customer_records",
        values=customers,
        confidence=0.85,
        source_span="Multiple customers found in document"
    )

    # Assert
    assert obj.schema_name == "customer_records"
    assert len(obj.values) == 3
    # Names are checked first, then ids, position by position
    for position, expected_name in enumerate(("John Doe", "Jane Smith", "Bob Johnson")):
        assert obj.values[position]["name"] == expected_name
    for position, expected_id in enumerate(("123", "124", "125")):
        assert obj.values[position]["id"] == expected_id
    assert obj.confidence == 0.85
    assert "Multiple customers" in obj.source_span
|
||||
|
||||
def test_extracted_object_empty_batch_contract(self):
    """ExtractedObject accepts an empty values array."""
    # Arrange / Act - the object is built directly with no rows
    obj = ExtractedObject(
        metadata=Metadata(
            id="extracted-empty-001",
            user="test_user",
            collection="test_collection",
            metadata=[]
        ),
        schema_name="empty_schema",
        values=[],
        confidence=1.0,
        source_span="No objects found"
    )

    # Assert
    assert obj.schema_name == "empty_schema"
    assert len(obj.values) == 0
    assert obj.confidence == 1.0
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestStructuredQueryServiceContracts:
|
||||
|
|
@ -273,7 +332,7 @@ class TestStructuredDataSerializationContracts:
|
|||
object_data = {
|
||||
"metadata": metadata,
|
||||
"schema_name": "test_schema",
|
||||
"values": {"field1": "value1"},
|
||||
"values": [{"field1": "value1"}],
|
||||
"confidence": 0.8,
|
||||
"source_span": "test span"
|
||||
}
|
||||
|
|
@ -314,4 +373,38 @@ class TestStructuredDataSerializationContracts:
|
|||
"data": '{"customers": [{"id": "1", "name": "John"}]}',
|
||||
"errors": []
|
||||
}
|
||||
assert serialize_deserialize_test(StructuredQueryResponse, response_data)
|
||||
assert serialize_deserialize_test(StructuredQueryResponse, response_data)
|
||||
|
||||
def test_extracted_object_batch_serialization(self):
    """A three-row ExtractedObject payload passes serialize_deserialize_test."""
    # Arrange
    field_pairs = (
        ("value1", "value2"),
        ("value3", "value4"),
        ("value5", "value6"),
    )
    batch_object_data = {
        "metadata": Metadata(id="test", user="user", collection="col", metadata=[]),
        "schema_name": "test_schema",
        "values": [
            {"field1": first, "field2": second} for first, second in field_pairs
        ],
        "confidence": 0.9,
        "source_span": "batch test span"
    }

    # Act & Assert
    round_trip_ok = serialize_deserialize_test(ExtractedObject, batch_object_data)
    assert round_trip_ok
|
||||
|
||||
def test_extracted_object_empty_batch_serialization(self):
    """An ExtractedObject payload with no rows passes serialize_deserialize_test."""
    # Arrange
    empty_batch_data = {
        "metadata": Metadata(id="test", user="user", collection="col", metadata=[]),
        "schema_name": "test_schema",
        "values": [],
        "confidence": 1.0,
        "source_span": "empty batch"
    }

    # Act & Assert
    round_trip_ok = serialize_deserialize_test(ExtractedObject, empty_batch_data)
    assert round_trip_ok
|
||||
|
|
@ -270,9 +270,9 @@ class TestObjectExtractionServiceIntegration:
|
|||
assert len(customer_calls) == 1
|
||||
customer_obj = customer_calls[0]
|
||||
|
||||
assert customer_obj.values["customer_id"] == "CUST001"
|
||||
assert customer_obj.values["name"] == "John Smith"
|
||||
assert customer_obj.values["email"] == "john.smith@email.com"
|
||||
assert customer_obj.values[0]["customer_id"] == "CUST001"
|
||||
assert customer_obj.values[0]["name"] == "John Smith"
|
||||
assert customer_obj.values[0]["email"] == "john.smith@email.com"
|
||||
assert customer_obj.confidence > 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -335,10 +335,10 @@ class TestObjectExtractionServiceIntegration:
|
|||
assert len(product_calls) == 1
|
||||
product_obj = product_calls[0]
|
||||
|
||||
assert product_obj.values["product_id"] == "PROD001"
|
||||
assert product_obj.values["name"] == "Gaming Laptop"
|
||||
assert product_obj.values["price"] == "1299.99"
|
||||
assert product_obj.values["category"] == "electronics"
|
||||
assert product_obj.values[0]["product_id"] == "PROD001"
|
||||
assert product_obj.values[0]["name"] == "Gaming Laptop"
|
||||
assert product_obj.values[0]["price"] == "1299.99"
|
||||
assert product_obj.values[0]["category"] == "electronics"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_extraction_integration(self, integration_config, mock_integrated_flow):
|
||||
|
|
|
|||
|
|
@ -95,12 +95,12 @@ class TestObjectsCassandraIntegration:
|
|||
metadata=[]
|
||||
),
|
||||
schema_name="customer_records",
|
||||
values={
|
||||
values=[{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"age": "30"
|
||||
},
|
||||
}],
|
||||
confidence=0.95,
|
||||
source_span="Customer: John Doe..."
|
||||
)
|
||||
|
|
@ -183,7 +183,7 @@ class TestObjectsCassandraIntegration:
|
|||
product_obj = ExtractedObject(
|
||||
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
|
||||
schema_name="products",
|
||||
values={"product_id": "P001", "name": "Widget", "price": "19.99"},
|
||||
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
|
||||
confidence=0.9,
|
||||
source_span="Product..."
|
||||
)
|
||||
|
|
@ -191,7 +191,7 @@ class TestObjectsCassandraIntegration:
|
|||
order_obj = ExtractedObject(
|
||||
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
|
||||
schema_name="orders",
|
||||
values={"order_id": "O001", "customer_id": "C001", "total": "59.97"},
|
||||
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
|
||||
confidence=0.85,
|
||||
source_span="Order..."
|
||||
)
|
||||
|
|
@ -229,7 +229,7 @@ class TestObjectsCassandraIntegration:
|
|||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
schema_name="test_schema",
|
||||
values={"id": "123"}, # missing required_field
|
||||
values=[{"id": "123"}], # missing required_field
|
||||
confidence=0.8,
|
||||
source_span="Test"
|
||||
)
|
||||
|
|
@ -265,7 +265,7 @@ class TestObjectsCassandraIntegration:
|
|||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
|
||||
schema_name="events",
|
||||
values={"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"},
|
||||
values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}],
|
||||
confidence=1.0,
|
||||
source_span="Event"
|
||||
)
|
||||
|
|
@ -334,7 +334,7 @@ class TestObjectsCassandraIntegration:
|
|||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
schema_name="test",
|
||||
values={"id": "123"},
|
||||
values=[{"id": "123"}],
|
||||
confidence=0.9,
|
||||
source_span="Test"
|
||||
)
|
||||
|
|
@ -364,7 +364,7 @@ class TestObjectsCassandraIntegration:
|
|||
obj = ExtractedObject(
|
||||
metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]),
|
||||
schema_name="data",
|
||||
values={"id": f"ID-{coll}"},
|
||||
values=[{"id": f"ID-{coll}"}],
|
||||
confidence=0.9,
|
||||
source_span="Data"
|
||||
)
|
||||
|
|
@ -381,4 +381,170 @@ class TestObjectsCassandraIntegration:
|
|||
# Check each insert has the correct collection
|
||||
for i, call in enumerate(insert_calls):
|
||||
values = call[0][1]
|
||||
assert collections[i] in values
|
||||
assert collections[i] in values
|
||||
|
||||
@pytest.mark.asyncio
async def test_batch_object_processing(self, processor_with_mocks):
    """A three-row batch yields one CREATE TABLE and three INSERT statements."""
    processor, mock_cluster, mock_session = processor_with_mocks

    # (customer_id, name, email) for each row, in the order they are sent
    customers = [
        ("CUST001", "John Doe", "john@example.com"),
        ("CUST002", "Jane Smith", "jane@example.com"),
        ("CUST003", "Bob Johnson", "bob@example.com"),
    ]

    # NOTE(review): the extent of this `with` block is not recoverable from the
    # flattened source; everything is kept inside it. The assertions only read
    # the recorded mock calls.
    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        # Register the schema through the config handler
        schema_definition = {
            "name": "batch_customers",
            "description": "Customer batch data",
            "fields": [
                {"name": "customer_id", "type": "string", "primary_key": True},
                {"name": "name", "type": "string", "required": True},
                {"name": "email", "type": "string", "indexed": True}
            ]
        }
        config = {"schema": {"batch_customers": json.dumps(schema_definition)}}
        await processor.on_schema_config(config, version=1)

        # One ExtractedObject carrying all three customers
        batch_obj = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_import",
                metadata=[]
            ),
            schema_name="batch_customers",
            values=[
                {"customer_id": customer_id, "name": name, "email": email}
                for customer_id, name, email in customers
            ],
            confidence=0.92,
            source_span="Multiple customers extracted from document"
        )

        msg = MagicMock()
        msg.value.return_value = batch_obj
        await processor.on_object(msg, None, None)

        executed = mock_session.execute.call_args_list

        # Verify table creation
        table_calls = [stmt for stmt in executed if "CREATE TABLE" in str(stmt)]
        assert len(table_calls) == 1
        assert "o_batch_customers" in str(table_calls[0])

        # Should have 3 separate inserts for the 3 objects in the batch
        insert_calls = [stmt for stmt in executed if "INSERT INTO" in str(stmt)]
        assert len(insert_calls) == 3

        # Each insert's bound values hold the collection plus that row's data
        for i, stmt in enumerate(insert_calls):
            bound_values = stmt[0][1]
            _, expected_name, expected_email = customers[i]
            assert "batch_import" in bound_values  # collection
            assert f"CUST00{i+1}" in bound_values  # customer_id
            assert expected_name in bound_values
            assert expected_email in bound_values
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_batch_processing(self, processor_with_mocks):
    """An empty values array still creates the table but issues no inserts."""
    processor, mock_cluster, mock_session = processor_with_mocks

    # NOTE(review): the extent of this `with` block is not recoverable from the
    # flattened source; everything is kept inside it.
    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        # Install the schema directly on the processor
        id_field = Field(name="id", type="string", size=50, primary=True)
        processor.schemas["empty_test"] = RowSchema(
            name="empty_test",
            fields=[id_field]
        )

        # Object with no rows
        empty_obj = ExtractedObject(
            metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
            schema_name="empty_test",
            values=[],  # Empty batch
            confidence=1.0,
            source_span="No objects found"
        )

        msg = MagicMock()
        msg.value.return_value = empty_obj
        await processor.on_object(msg, None, None)

        executed = mock_session.execute.call_args_list

        # Should still create table
        table_calls = [stmt for stmt in executed if "CREATE TABLE" in str(stmt)]
        assert len(table_calls) == 1

        # Should not create any insert statements for empty batch
        insert_calls = [stmt for stmt in executed if "INSERT INTO" in str(stmt)]
        assert len(insert_calls) == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_mixed_single_and_batch_objects(self, processor_with_mocks):
    """A one-row object followed by a two-row object produces three inserts."""
    processor, mock_cluster, mock_session = processor_with_mocks

    # NOTE(review): the extent of this `with` block is not recoverable from the
    # flattened source; everything is kept inside it.
    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        processor.schemas["mixed_test"] = RowSchema(
            name="mixed_test",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="data", type="string", size=100)
            ]
        )

        # Single object (backward compatibility): array with a single item
        single_obj = ExtractedObject(
            metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
            schema_name="mixed_test",
            values=[{"id": "single-1", "data": "single data"}],
            confidence=0.9,
            source_span="Single object"
        )

        # Batch object with two rows
        batch_obj = ExtractedObject(
            metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
            schema_name="mixed_test",
            values=[
                {"id": "batch-1", "data": "batch data 1"},
                {"id": "batch-2", "data": "batch data 2"}
            ],
            confidence=0.85,
            source_span="Batch objects"
        )

        # Feed both through the handler, single first
        for extracted in (single_obj, batch_obj):
            msg = MagicMock()
            msg.value.return_value = extracted
            await processor.on_object(msg, None, None)

        # Should have 3 total inserts (1 + 2)
        insert_calls = [
            stmt for stmt in mock_session.execute.call_args_list
            if "INSERT INTO" in str(stmt)
        ]
        assert len(insert_calls) == 3
|
||||
|
|
@ -66,11 +66,11 @@ def sample_objects_message():
|
|||
"collection": "testcollection"
|
||||
},
|
||||
"schema_name": "person",
|
||||
"values": {
|
||||
"values": [{
|
||||
"name": "John Doe",
|
||||
"age": "30",
|
||||
"city": "New York"
|
||||
},
|
||||
}],
|
||||
"confidence": 0.95,
|
||||
"source_span": "John Doe, age 30, lives in New York"
|
||||
}
|
||||
|
|
@ -86,9 +86,9 @@ def minimal_objects_message():
|
|||
"collection": "testcollection"
|
||||
},
|
||||
"schema_name": "simple_schema",
|
||||
"values": {
|
||||
"values": [{
|
||||
"field1": "value1"
|
||||
}
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -235,8 +235,8 @@ class TestObjectsImportMessageProcessing:
|
|||
sent_object = call_args[0][1]
|
||||
assert isinstance(sent_object, ExtractedObject)
|
||||
assert sent_object.schema_name == "person"
|
||||
assert sent_object.values["name"] == "John Doe"
|
||||
assert sent_object.values["age"] == "30"
|
||||
assert sent_object.values[0]["name"] == "John Doe"
|
||||
assert sent_object.values[0]["age"] == "30"
|
||||
assert sent_object.confidence == 0.95
|
||||
assert sent_object.source_span == "John Doe, age 30, lives in New York"
|
||||
|
||||
|
|
@ -274,7 +274,7 @@ class TestObjectsImportMessageProcessing:
|
|||
sent_object = mock_publisher_instance.send.call_args[0][1]
|
||||
assert isinstance(sent_object, ExtractedObject)
|
||||
assert sent_object.schema_name == "simple_schema"
|
||||
assert sent_object.values["field1"] == "value1"
|
||||
assert sent_object.values[0]["field1"] == "value1"
|
||||
assert sent_object.confidence == 1.0 # Default value
|
||||
assert sent_object.source_span == "" # Default value
|
||||
assert len(sent_object.metadata.metadata) == 0 # Default empty list
|
||||
|
|
@ -302,7 +302,7 @@ class TestObjectsImportMessageProcessing:
|
|||
"collection": "testcollection"
|
||||
},
|
||||
"schema_name": "test_schema",
|
||||
"values": {"key": "value"}
|
||||
"values": [{"key": "value"}]
|
||||
# No confidence or source_span
|
||||
}
|
||||
|
||||
|
|
@ -374,6 +374,134 @@ class TestObjectsImportRunMethod:
|
|||
assert objects_import.ws is None
|
||||
|
||||
|
||||
class TestObjectsImportBatchProcessing:
    """Tests for how ObjectsImport.receive() handles batched ``values`` arrays."""

    @pytest.fixture
    def batch_objects_message(self):
        """Return a message dict whose ``values`` holds three person rows."""
        message_metadata = {
            "id": "batch-001",
            "metadata": [
                {
                    "s": {"v": "batch-001", "e": False},
                    "p": {"v": "source", "e": False},
                    "o": {"v": "test", "e": False}
                }
            ],
            "user": "testuser",
            "collection": "testcollection"
        }
        people = [
            {
                "name": "John Doe",
                "age": "30",
                "city": "New York"
            },
            {
                "name": "Jane Smith",
                "age": "25",
                "city": "Boston"
            },
            {
                "name": "Bob Johnson",
                "age": "45",
                "city": "Chicago"
            }
        ]
        return {
            "metadata": message_metadata,
            "schema_name": "person",
            "values": people,
            "confidence": 0.85,
            "source_span": "Multiple people found in document"
        }

    @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
    @pytest.mark.asyncio
    async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, batch_objects_message):
        """receive() publishes one ExtractedObject containing every row of the batch."""
        mock_publisher_instance = Mock()
        mock_publisher_instance.send = AsyncMock()
        mock_publisher_class.return_value = mock_publisher_instance

        objects_import = ObjectsImport(
            ws=mock_websocket,
            running=mock_running,
            pulsar_client=mock_pulsar_client,
            queue="test-queue"
        )

        # Incoming message whose .json() yields the batch payload
        mock_msg = Mock()
        mock_msg.json.return_value = batch_objects_message

        await objects_import.receive(mock_msg)

        # Exactly one publish for the whole batch
        mock_publisher_instance.send.assert_called_once()

        # Positional arguments of that publish call
        positional = mock_publisher_instance.send.call_args[0]
        assert positional[0] is None  # First argument should be None

        # The published object
        sent_object = positional[1]
        assert isinstance(sent_object, ExtractedObject)
        assert sent_object.schema_name == "person"

        # All three rows are present, in order, with their fields intact
        assert len(sent_object.values) == 3
        expected_people = (
            ("John Doe", "30", "New York"),
            ("Jane Smith", "25", "Boston"),
            ("Bob Johnson", "45", "Chicago"),
        )
        for position, (name, age, city) in enumerate(expected_people):
            row = sent_object.values[position]
            assert row["name"] == name
            assert row["age"] == age
            assert row["city"] == city

        assert sent_object.confidence == 0.85
        assert sent_object.source_span == "Multiple people found in document"

    @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
    @pytest.mark.asyncio
    async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
        """receive() still publishes when the ``values`` array is empty."""
        mock_publisher_instance = Mock()
        mock_publisher_instance.send = AsyncMock()
        mock_publisher_class.return_value = mock_publisher_instance

        objects_import = ObjectsImport(
            ws=mock_websocket,
            running=mock_running,
            pulsar_client=mock_pulsar_client,
            queue="test-queue"
        )

        # Message with empty values array
        empty_batch_message = {
            "metadata": {
                "id": "empty-batch-001",
                "user": "testuser",
                "collection": "testcollection"
            },
            "schema_name": "empty_schema",
            "values": []
        }
        mock_msg = Mock()
        mock_msg.json.return_value = empty_batch_message

        await objects_import.receive(mock_msg)

        # Should still send the message
        mock_publisher_instance.send.assert_called_once()
        sent_object = mock_publisher_instance.send.call_args[0][1]
        assert len(sent_object.values) == 0
|
||||
|
||||
|
||||
class TestObjectsImportErrorHandling:
|
||||
"""Test error handling in ObjectsImport."""
|
||||
|
||||
|
|
|
|||
|
|
@ -317,12 +317,12 @@ class TestObjectExtractionBusinessLogic:
|
|||
metadata=[]
|
||||
)
|
||||
|
||||
values = {
|
||||
values = [{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"status": "active"
|
||||
}
|
||||
}]
|
||||
|
||||
# Act
|
||||
extracted_obj = ExtractedObject(
|
||||
|
|
@ -335,7 +335,7 @@ class TestObjectExtractionBusinessLogic:
|
|||
|
||||
# Assert
|
||||
assert extracted_obj.schema_name == "customer_records"
|
||||
assert extracted_obj.values["customer_id"] == "CUST001"
|
||||
assert extracted_obj.values[0]["customer_id"] == "CUST001"
|
||||
assert extracted_obj.confidence == 0.95
|
||||
assert "John Doe" in extracted_obj.source_span
|
||||
assert extracted_obj.metadata.user == "test_user"
|
||||
|
|
|
|||
|
|
@ -261,7 +261,7 @@ class TestObjectsCassandraStorageLogic:
|
|||
metadata=[]
|
||||
),
|
||||
schema_name="test_schema",
|
||||
values={"id": "123", "value": "456"},
|
||||
values=[{"id": "123", "value": "456"}],
|
||||
confidence=0.9,
|
||||
source_span="test source"
|
||||
)
|
||||
|
|
@ -284,8 +284,8 @@ class TestObjectsCassandraStorageLogic:
|
|||
assert "INSERT INTO test_user.o_test_schema" in insert_cql
|
||||
assert "collection" in insert_cql
|
||||
assert values[0] == "test_collection" # collection value
|
||||
assert values[1] == "123" # id value
|
||||
assert values[2] == 456 # converted integer value
|
||||
assert values[1] == "123" # id value (from values[0])
|
||||
assert values[2] == 456 # converted integer value (from values[0])
|
||||
|
||||
def test_secondary_index_creation(self):
|
||||
"""Test that secondary indexes are created for indexed fields"""
|
||||
|
|
@ -325,4 +325,201 @@ class TestObjectsCassandraStorageLogic:
|
|||
index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]]
|
||||
assert len(index_calls) == 2
|
||||
assert any("o_products_category_idx" in call for call in index_calls)
|
||||
assert any("o_products_price_idx" in call for call in index_calls)
|
||||
assert any("o_products_price_idx" in call for call in index_calls)
|
||||
|
||||
|
||||
class TestObjectsCassandraStorageBatchLogic:
|
||||
"""Test batch processing logic in Cassandra storage"""
|
||||
|
||||
@pytest.mark.asyncio
async def test_batch_object_processing_logic(self):
    """Test processing of batch ExtractedObjects.

    A three-item batch must result in a single ``ensure_table`` call and
    one INSERT per batch item, with every item's values converted
    according to the schema's field types.
    """
    processor = MagicMock()
    processor.schemas = {
        "batch_schema": RowSchema(
            name="batch_schema",
            description="Test batch schema",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="name", type="string", size=100),
                Field(name="value", type="integer", size=4),
            ],
        )
    }
    processor.ensure_table = MagicMock()
    # Bind the real Processor implementations onto the mock so that the
    # genuine on_object logic runs against a mocked Cassandra session.
    processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
    processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
    processor.convert_value = Processor.convert_value.__get__(processor, Processor)
    processor.session = MagicMock()
    processor.on_object = Processor.on_object.__get__(processor, Processor)

    # Create batch object with multiple values
    batch_obj = ExtractedObject(
        metadata=Metadata(
            id="batch-001",
            user="test_user",
            collection="batch_collection",
            metadata=[],
        ),
        schema_name="batch_schema",
        values=[
            {"id": "001", "name": "First", "value": "100"},
            {"id": "002", "name": "Second", "value": "200"},
            {"id": "003", "name": "Third", "value": "300"},
        ],
        confidence=0.95,
        source_span="batch source",
    )

    # Create mock message
    msg = MagicMock()
    msg.value.return_value = batch_obj

    # Process batch object
    await processor.on_object(msg, None, None)

    # Verify table was ensured once
    processor.ensure_table.assert_called_once_with(
        "test_user", "batch_schema", processor.schemas["batch_schema"]
    )

    # Verify 3 separate insert calls (one per batch item)
    assert processor.session.execute.call_count == 3

    # Expected "name" column per batch item, in insertion order.
    expected_names = ["First", "Second", "Third"]

    # Check each insert call
    execute_calls = processor.session.execute.call_args_list
    for i, execute_call in enumerate(execute_calls):
        insert_cql = execute_call[0][0]
        values = execute_call[0][1]

        assert "INSERT INTO test_user.o_batch_schema" in insert_cql
        assert "collection" in insert_cql

        # Check values for each batch item
        assert values[0] == "batch_collection"  # collection
        assert values[1] == f"00{i+1}"  # id from batch item i
        # Compare against an indexed expectation: a chained conditional
        # expression here would bind looser than ``==`` and make the
        # assertion trivially true for every item after the first.
        assert values[2] == expected_names[i]  # name
        assert values[3] == (i+1) * 100  # converted integer value
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_batch_processing_logic(self):
    """Test processing of empty batch ExtractedObjects.

    An ExtractedObject whose ``values`` list is empty should still cause
    the table to be ensured, but must not issue any INSERT.
    """
    id_only_schema = RowSchema(
        name="empty_schema",
        fields=[Field(name="id", type="string", size=50, primary=True)],
    )

    processor = MagicMock()
    processor.schemas = {"empty_schema": id_only_schema}
    processor.ensure_table = MagicMock()
    processor.session = MagicMock()
    # Attach the real Processor methods to the mock as bound methods.
    for method_name in (
        "sanitize_name",
        "sanitize_table",
        "convert_value",
        "on_object",
    ):
        bound = getattr(Processor, method_name).__get__(processor, Processor)
        setattr(processor, method_name, bound)

    # Build an ExtractedObject carrying no rows at all.
    batch_metadata = Metadata(
        id="empty-001",
        user="test_user",
        collection="empty_collection",
        metadata=[],
    )
    empty_batch_obj = ExtractedObject(
        metadata=batch_metadata,
        schema_name="empty_schema",
        values=[],
        confidence=1.0,
        source_span="empty source",
    )

    incoming = MagicMock()
    incoming.value.return_value = empty_batch_obj

    # Run the empty batch through the real handler.
    await processor.on_object(incoming, None, None)

    # The table is still ensured ...
    processor.ensure_table.assert_called_once()

    # ... but nothing is written for an empty batch.
    processor.session.execute.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
async def test_single_item_batch_processing_logic(self):
    """Test processing of single-item batch (backward compatibility).

    A ``values`` list holding exactly one mapping should produce exactly
    one INSERT whose parameters are the collection followed by the
    schema fields in declaration order.
    """
    two_field_schema = RowSchema(
        name="single_schema",
        fields=[
            Field(name="id", type="string", size=50, primary=True),
            Field(name="data", type="string", size=100),
        ],
    )

    processor = MagicMock()
    processor.schemas = {"single_schema": two_field_schema}
    processor.ensure_table = MagicMock()
    processor.session = MagicMock()
    # Attach the real Processor methods to the mock as bound methods.
    for method_name in (
        "sanitize_name",
        "sanitize_table",
        "convert_value",
        "on_object",
    ):
        bound = getattr(Processor, method_name).__get__(processor, Processor)
        setattr(processor, method_name, bound)

    # One-element array: the shape a former single-object message now takes.
    single_batch_obj = ExtractedObject(
        metadata=Metadata(
            id="single-001",
            user="test_user",
            collection="single_collection",
            metadata=[],
        ),
        schema_name="single_schema",
        values=[{"id": "single-1", "data": "single data"}],
        confidence=0.8,
        source_span="single source",
    )

    incoming = MagicMock()
    incoming.value.return_value = single_batch_obj

    # Run the single-item batch through the real handler.
    await processor.on_object(incoming, None, None)

    # Table ensured once, and exactly one row written.
    processor.ensure_table.assert_called_once()
    processor.session.execute.assert_called_once()

    positional_args = processor.session.execute.call_args[0]
    insert_cql = positional_args[0]
    values = positional_args[1]

    assert "INSERT INTO test_user.o_single_schema" in insert_cql
    assert values[0] == "single_collection"  # collection
    assert values[1] == "single-1"  # id value
    assert values[2] == "single data"  # data value
|
||||
|
||||
def test_batch_value_conversion_logic(self):
    """Test value conversion works correctly for batch items.

    Exercises the real ``Processor.convert_value`` (bound onto a mock)
    with the kinds of raw values batch items carry, grouped by the
    target field type.
    """
    processor = MagicMock()
    processor.convert_value = Processor.convert_value.__get__(processor, Processor)

    # Target field type -> (raw input, expected converted value) pairs.
    # Dict and list order are preserved, so cases run in listed order.
    cases_by_type = {
        "integer": [
            ("123", 123),
            ("456", 456),
            ("789", 789),
        ],
        "float": [
            ("12.5", 12.5),
            ("34.7", 34.7),
        ],
        "boolean": [
            ("true", True),
            ("false", False),
            ("1", True),
            ("0", False),
        ],
        "string": [
            (123, "123"),
            (45.6, "45.6"),
        ],
    }

    for field_type, pairs in cases_by_type.items():
        for input_val, expected_output in pairs:
            result = processor.convert_value(input_val, field_type)
            assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}"
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from pulsar.schema import Record, String, Map, Double
|
||||
from pulsar.schema import Record, String, Map, Double, Array
|
||||
|
||||
from ..core.metadata import Metadata
|
||||
from ..core.topic import topic
|
||||
|
|
@ -10,7 +10,7 @@ from ..core.topic import topic
|
|||
class ExtractedObject(Record):
|
||||
metadata = Metadata()
|
||||
schema_name = String() # Which schema this object belongs to
|
||||
values = Map(String()) # Field name -> value
|
||||
values = Array(Map(String())) # Array of objects, each object is field name -> value
|
||||
confidence = Double()
|
||||
source_span = String() # Text span where object was found
|
||||
|
||||
|
|
|
|||
|
|
@ -804,7 +804,18 @@ def load_structured_data(
|
|||
print(f"Target schema: {schema_name}")
|
||||
print(f"Sample record:")
|
||||
if processed_records:
|
||||
print(json.dumps(processed_records[0], indent=2))
|
||||
# Show what the batched format will look like
|
||||
sample_batch = processed_records[:min(3, len(processed_records))]
|
||||
batch_values = [record["values"] for record in sample_batch]
|
||||
first_record = processed_records[0]
|
||||
batched_sample = {
|
||||
"metadata": first_record["metadata"],
|
||||
"schema_name": first_record["schema_name"],
|
||||
"values": batch_values,
|
||||
"confidence": first_record["confidence"],
|
||||
"source_span": first_record["source_span"]
|
||||
}
|
||||
print(json.dumps(batched_sample, indent=2))
|
||||
return
|
||||
|
||||
# Import to TrustGraph using objects import endpoint via WebSocket
|
||||
|
|
@ -828,10 +839,26 @@ def load_structured_data(
|
|||
async with connect(objects_url) as ws:
|
||||
imported_count = 0
|
||||
|
||||
for record in processed_records:
|
||||
# Send individual ExtractedObject records
|
||||
await ws.send(json.dumps(record))
|
||||
imported_count += 1
|
||||
# Process records in batches
|
||||
for i in range(0, len(processed_records), batch_size):
|
||||
batch_records = processed_records[i:i + batch_size]
|
||||
|
||||
# Extract values from each record in the batch
|
||||
batch_values = [record["values"] for record in batch_records]
|
||||
|
||||
# Create batched ExtractedObject message using first record as template
|
||||
first_record = batch_records[0]
|
||||
batched_record = {
|
||||
"metadata": first_record["metadata"],
|
||||
"schema_name": first_record["schema_name"],
|
||||
"values": batch_values, # Array of value dictionaries
|
||||
"confidence": first_record["confidence"],
|
||||
"source_span": first_record["source_span"]
|
||||
}
|
||||
|
||||
# Send batched ExtractedObject
|
||||
await ws.send(json.dumps(batched_record))
|
||||
imported_count += len(batch_records)
|
||||
|
||||
if imported_count % 100 == 0:
|
||||
logger.info(f"Imported {imported_count}/{len(processed_records)} records...")
|
||||
|
|
|
|||
|
|
@ -256,31 +256,34 @@ class Processor(FlowProcessor):
|
|||
flow
|
||||
)
|
||||
|
||||
# Emit each extracted object
|
||||
for obj in objects:
|
||||
# Emit extracted objects as a batch if any were found
|
||||
if objects:
|
||||
|
||||
# Calculate confidence (could be enhanced with actual confidence from prompt)
|
||||
confidence = 0.8 # Default confidence
|
||||
|
||||
# Convert all values to strings for Pulsar compatibility
|
||||
string_values = convert_values_to_strings(obj)
|
||||
# Convert all objects' values to strings for Pulsar compatibility
|
||||
batch_values = []
|
||||
for obj in objects:
|
||||
string_values = convert_values_to_strings(obj)
|
||||
batch_values.append(string_values)
|
||||
|
||||
# Create ExtractedObject
|
||||
# Create ExtractedObject with batched values
|
||||
extracted = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id=f"{v.metadata.id}:{schema_name}:{hash(str(obj))}",
|
||||
id=f"{v.metadata.id}:{schema_name}",
|
||||
metadata=[],
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
schema_name=schema_name,
|
||||
values=string_values,
|
||||
values=batch_values, # Array of objects
|
||||
confidence=confidence,
|
||||
source_span=chunk_text[:100] # First 100 chars as source reference
|
||||
)
|
||||
|
||||
await flow("output").send(extracted)
|
||||
logger.debug(f"Emitted extracted object for schema {schema_name}")
|
||||
logger.debug(f"Emitted batch of {len(objects)} objects for schema {schema_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Object extraction exception: {e}", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -44,6 +44,12 @@ class ObjectsImport:
|
|||
|
||||
data = msg.json()
|
||||
|
||||
# Handle both single object and array of objects for backward compatibility
|
||||
values_data = data["values"]
|
||||
if not isinstance(values_data, list):
|
||||
# Single object - wrap in array
|
||||
values_data = [values_data]
|
||||
|
||||
elt = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
|
|
@ -52,7 +58,7 @@ class ObjectsImport:
|
|||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
schema_name=data["schema_name"],
|
||||
values=data["values"],
|
||||
values=values_data,
|
||||
confidence=data.get("confidence", 1.0),
|
||||
source_span=data.get("source_span", ""),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -311,7 +311,7 @@ class Processor(FlowProcessor):
|
|||
"""Process incoming ExtractedObject and store in Cassandra"""
|
||||
|
||||
obj = msg.value()
|
||||
logger.info(f"Storing object for schema {obj.schema_name} from {obj.metadata.id}")
|
||||
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
|
||||
|
||||
# Get schema definition
|
||||
schema = self.schemas.get(obj.schema_name)
|
||||
|
|
@ -328,59 +328,67 @@ class Processor(FlowProcessor):
|
|||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
safe_table = self.sanitize_table(table_name)
|
||||
|
||||
# Build column names and values
|
||||
columns = ["collection"]
|
||||
values = [obj.metadata.collection]
|
||||
placeholders = ["%s"]
|
||||
|
||||
# Check if we need a synthetic ID
|
||||
has_primary_key = any(field.primary for field in schema.fields)
|
||||
if not has_primary_key:
|
||||
import uuid
|
||||
columns.append("synthetic_id")
|
||||
values.append(uuid.uuid4())
|
||||
placeholders.append("%s")
|
||||
|
||||
# Process fields
|
||||
for field in schema.fields:
|
||||
safe_field_name = self.sanitize_name(field.name)
|
||||
raw_value = obj.values.get(field.name)
|
||||
# Process each object in the batch
|
||||
for obj_index, value_map in enumerate(obj.values):
|
||||
# Build column names and values for this object
|
||||
columns = ["collection"]
|
||||
values = [obj.metadata.collection]
|
||||
placeholders = ["%s"]
|
||||
|
||||
# Handle required fields
|
||||
if field.required and raw_value is None:
|
||||
logger.warning(f"Required field {field.name} is missing in object")
|
||||
# Continue anyway - Cassandra doesn't enforce NOT NULL
|
||||
# Check if we need a synthetic ID
|
||||
has_primary_key = any(field.primary for field in schema.fields)
|
||||
if not has_primary_key:
|
||||
import uuid
|
||||
columns.append("synthetic_id")
|
||||
values.append(uuid.uuid4())
|
||||
placeholders.append("%s")
|
||||
|
||||
# Check if primary key field is NULL
|
||||
if field.primary and raw_value is None:
|
||||
logger.error(f"Primary key field {field.name} cannot be NULL - skipping object")
|
||||
return
|
||||
# Process fields for this object
|
||||
skip_object = False
|
||||
for field in schema.fields:
|
||||
safe_field_name = self.sanitize_name(field.name)
|
||||
raw_value = value_map.get(field.name)
|
||||
|
||||
# Handle required fields
|
||||
if field.required and raw_value is None:
|
||||
logger.warning(f"Required field {field.name} is missing in object {obj_index}")
|
||||
# Continue anyway - Cassandra doesn't enforce NOT NULL
|
||||
|
||||
# Check if primary key field is NULL
|
||||
if field.primary and raw_value is None:
|
||||
logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}")
|
||||
skip_object = True
|
||||
break
|
||||
|
||||
# Convert value to appropriate type
|
||||
converted_value = self.convert_value(raw_value, field.type)
|
||||
|
||||
columns.append(safe_field_name)
|
||||
values.append(converted_value)
|
||||
placeholders.append("%s")
|
||||
|
||||
# Convert value to appropriate type
|
||||
converted_value = self.convert_value(raw_value, field.type)
|
||||
# Skip this object if primary key validation failed
|
||||
if skip_object:
|
||||
continue
|
||||
|
||||
columns.append(safe_field_name)
|
||||
values.append(converted_value)
|
||||
placeholders.append("%s")
|
||||
|
||||
# Build and execute insert query
|
||||
insert_cql = f"""
|
||||
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
|
||||
VALUES ({', '.join(placeholders)})
|
||||
"""
|
||||
|
||||
# Debug: Show data being inserted
|
||||
logger.debug(f"Storing {obj.schema_name}: {dict(zip(columns, values))}")
|
||||
|
||||
if len(columns) != len(values) or len(columns) != len(placeholders):
|
||||
raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
|
||||
|
||||
try:
|
||||
# Convert to tuple - Cassandra driver requires tuple for parameters
|
||||
self.session.execute(insert_cql, tuple(values))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert object: {e}", exc_info=True)
|
||||
raise
|
||||
# Build and execute insert query for this object
|
||||
insert_cql = f"""
|
||||
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
|
||||
VALUES ({', '.join(placeholders)})
|
||||
"""
|
||||
|
||||
# Debug: Show data being inserted
|
||||
logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}")
|
||||
|
||||
if len(columns) != len(values) or len(columns) != len(placeholders):
|
||||
raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
|
||||
|
||||
try:
|
||||
# Convert to tuple - Cassandra driver requires tuple for parameters
|
||||
self.session.execute(insert_cql, tuple(values))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
"""Clean up Cassandra connections"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue