Merge pull request #78 from ModernRelay/devin/1778363660-mr-901-blob-branch-merge

Fix branch merge with blob columns
This commit is contained in:
Ragnor Comerford 2026-05-12 07:31:04 -07:00 committed by GitHub
commit 676c9eab05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 384 additions and 49 deletions

View file

@ -26,12 +26,14 @@ struct StagedMergeResult {
/// A single row handed out by an `OrderedTableCursor`, together with the
/// context needed to copy that row elsewhere.
struct CursorRow {
/// Row identifier, as computed by `row_id_at` for this row.
id: String,
/// Row content signature, as computed by `row_signature` for this row.
signature: String,
/// Clone of the dataset the cursor is scanning (the row's source dataset).
dataset: Dataset,
/// The record batch that contains this row (the whole batch, not a slice).
batch: RecordBatch,
/// Index of this row within `batch`.
row_index: usize,
}
struct OrderedTableCursor {
stream: Option<std::pin::Pin<Box<DatasetRecordBatchStream>>>,
dataset: Option<Dataset>,
current_batch: Option<RecordBatch>,
current_row: usize,
peeked: Option<CursorRow>,
@ -47,14 +49,15 @@ impl OrderedTableCursor {
}
async fn from_dataset(dataset: Option<Dataset>) -> Result<Self> {
let stream = if let Some(ds) = dataset {
let stream = if let Some(ds) = &dataset {
Some(Box::pin(
crate::table_store::TableStore::scan_stream(
&ds,
crate::table_store::TableStore::scan_stream_with(
ds,
None,
None,
Some(vec![ColumnOrdering::asc_nulls_last("id".to_string())]),
false,
true,
|_| Ok(()),
)
.await?,
))
@ -64,6 +67,7 @@ impl OrderedTableCursor {
Ok(Self {
stream,
dataset,
current_batch: None,
current_row: 0,
peeked: None,
@ -90,9 +94,13 @@ impl OrderedTableCursor {
if self.current_row < batch.num_rows() {
let row_index = self.current_row;
self.current_row += 1;
let dataset = self.dataset.clone().ok_or_else(|| {
OmniError::manifest("cursor row missing source dataset".to_string())
})?;
return Ok(Some(CursorRow {
id: row_id_at(batch, row_index)?,
signature: row_signature(batch, row_index)?,
dataset,
batch: batch.clone(),
row_index,
}));
@ -146,13 +154,38 @@ impl StagedTableWriter {
async fn push_row(&mut self, row: &CursorRow) -> Result<()> {
self.row_count += 1;
self.buffered_rows += 1;
self.batches.push(row.batch.slice(row.row_index, 1));
self.batches.push(self.row_batch(row).await?);
if self.buffered_rows >= MERGE_STAGE_BATCH_ROWS {
self.flush().await?;
}
Ok(())
}
/// Builds the single-row `RecordBatch` that should be staged for `row`.
///
/// When the row's source dataset declares any blob field, the one-row slice
/// is handed to `TableStore::materialize_blob_batch` and its result is
/// returned as-is. Otherwise the slice is re-assembled so that it contains
/// exactly the columns named by `self.schema`, in that order.
///
/// # Errors
/// Returns `OmniError::Lance` if a column named in `self.schema` is absent
/// from the row's batch, or if the projected batch cannot be constructed.
async fn row_batch(&self, row: &CursorRow) -> Result<RecordBatch> {
    let single_row = row.batch.slice(row.row_index, 1);

    let source_has_blobs = row
        .dataset
        .schema()
        .fields_pre_order()
        .any(|f| f.is_blob());
    if source_has_blobs {
        return crate::table_store::TableStore::materialize_blob_batch(&row.dataset, single_row)
            .await;
    }

    // Pick the writer's columns out of the slice by name, in schema order.
    let mut projected = Vec::with_capacity(self.schema.fields().len());
    for field in self.schema.fields().iter() {
        let Some(column) = single_row.column_by_name(field.name()) else {
            return Err(OmniError::Lance(format!(
                "batch missing column '{}'",
                field.name()
            )));
        };
        projected.push(column.clone());
    }

    match RecordBatch::try_new(self.schema.clone(), projected) {
        Ok(batch) => Ok(batch),
        Err(e) => Err(OmniError::Lance(e.to_string())),
    }
}
async fn finish(mut self) -> Result<StagedTable> {
self.flush().await?;
if self.dataset.is_none() {
@ -494,7 +527,10 @@ fn classify_merge_conflict(
fn row_signature(batch: &RecordBatch, row: usize) -> Result<String> {
let mut values = Vec::with_capacity(batch.num_columns());
for column in batch.columns() {
for (field, column) in batch.schema().fields().iter().zip(batch.columns()) {
if field.name().starts_with("_row") {
continue;
}
values.push(
array_value_to_string(column.as_ref(), row)
.map_err(|e| OmniError::Lance(e.to_string()))?,
@ -503,6 +539,10 @@ fn row_signature(batch: &RecordBatch, row: usize) -> Result<String> {
Ok(values.join("\u{1f}"))
}
async fn scan_validation_stream(ds: &Dataset) -> Result<DatasetRecordBatchStream> {
crate::table_store::TableStore::scan_stream_with(ds, None, None, None, false, |_| Ok(())).await
}
async fn validate_merge_candidates(
db: &Omnigraph,
source_snapshot: &Snapshot,
@ -520,8 +560,7 @@ async fn validate_merge_candidates(
if let Some(ds) =
candidate_dataset(source_snapshot, target_snapshot, candidates, &table_key).await?
{
let mut stream =
crate::table_store::TableStore::scan_stream(&ds, None, None, None, false).await?;
let mut stream = scan_validation_stream(&ds).await?;
while let Some(batch) = stream
.try_next()
.await
@ -568,8 +607,7 @@ async fn validate_merge_candidates(
if let Some(ds) =
candidate_dataset(source_snapshot, target_snapshot, candidates, &table_key).await?
{
let mut stream =
crate::table_store::TableStore::scan_stream(&ds, None, None, None, false).await?;
let mut stream = scan_validation_stream(&ds).await?;
while let Some(batch) = stream
.try_next()
.await
@ -929,7 +967,7 @@ async fn publish_rewritten_merge_table(
if let Some(delta) = &staged.delta_staged {
let batches: Vec<RecordBatch> = target_db
.table_store()
.scan_batches(&delta.dataset)
.scan_batches_for_rewrite(&delta.dataset)
.await?
.into_iter()
.filter(|batch| batch.num_rows() > 0)
@ -1269,7 +1307,8 @@ impl Omnigraph {
.filter(|table_key| {
matches!(
candidates.get(*table_key),
Some(CandidateTableState::RewriteMerged(_)) | Some(CandidateTableState::AdoptSourceState)
Some(CandidateTableState::RewriteMerged(_))
| Some(CandidateTableState::AdoptSourceState)
)
})
.map(|table_key| (table_key.clone(), active_branch_for_keys.clone()))

View file

@ -1,15 +1,18 @@
use arrow_array::{RecordBatch, UInt64Array};
use arrow_array::{
Array, ArrayRef, RecordBatch, StringArray, StructArray, UInt8Array, UInt32Array, UInt64Array,
};
use arrow_schema::SchemaRef;
use arrow_select::concat::concat_batches;
use futures::TryStreamExt;
use lance::Dataset;
use lance::blob::BlobArrayBuilder;
use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner};
use lance::dataset::transaction::{Operation, Transaction, TransactionBuilder};
use lance::dataset::{
CommitBuilder, InsertBuilder, MergeInsertBuilder, WhenMatched, WhenNotMatched, WriteMode,
WriteParams,
};
use lance::datatypes::BlobHandling;
use lance::datatypes::BlobKind;
use lance::index::scalar::IndexDetails;
use lance_file::version::LanceFileVersion;
use lance_index::scalar::{InvertedIndexParams, ScalarIndexParams};
@ -272,17 +275,189 @@ impl TableStore {
return self.scan_batches(ds).await;
}
let mut scanner = ds.scan();
scanner.blob_handling(BlobHandling::AllBinary);
scanner
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.try_collect()
let batches = Self::scan_stream(ds, None, None, None, true)
.await?
.try_collect::<Vec<RecordBatch>>()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mut materialized = Vec::with_capacity(batches.len());
for batch in batches {
materialized.push(Self::materialize_blob_batch(ds, batch).await?);
}
Ok(materialized)
}
pub(crate) async fn materialize_blob_batch(
ds: &Dataset,
batch: RecordBatch,
) -> Result<RecordBatch> {
let has_blob_columns = ds.schema().fields_pre_order().any(|field| field.is_blob());
if !has_blob_columns {
return Ok(batch);
}
let row_ids = batch
.column_by_name("_rowid")
.and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
.ok_or_else(|| {
OmniError::Lance("expected _rowid column when materializing blobs".to_string())
})?
.values()
.iter()
.copied()
.collect::<Vec<_>>();
let schema: SchemaRef = Arc::new(ds.schema().into());
let mut columns = Vec::with_capacity(schema.fields().len());
for field in schema.fields() {
let lance_field = lance::datatypes::Field::try_from(field.as_ref())
.map_err(|e| OmniError::Lance(e.to_string()))?;
let column = batch.column_by_name(field.name()).ok_or_else(|| {
OmniError::Lance(format!("batch missing column '{}'", field.name()))
})?;
if lance_field.is_blob() {
let descriptions =
column
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| {
OmniError::Lance(format!(
"expected blob descriptions for '{}'",
field.name()
))
})?;
columns.push(
Self::rebuild_blob_column(ds, field.name(), descriptions, &row_ids).await?,
);
} else {
columns.push(column.clone());
}
}
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
}
async fn rebuild_blob_column(
ds: &Dataset,
column_name: &str,
descriptions: &StructArray,
row_ids: &[u64],
) -> Result<ArrayRef> {
let mut builder = BlobArrayBuilder::new(row_ids.len());
let mut non_null_row_ids = Vec::new();
let mut row_has_blob = Vec::with_capacity(row_ids.len());
for row in 0..row_ids.len() {
let is_null = Self::blob_description_is_null(descriptions, row)?;
row_has_blob.push(!is_null);
if !is_null {
non_null_row_ids.push(row_ids[row]);
}
}
let blob_files = if non_null_row_ids.is_empty() {
Vec::new()
} else {
Arc::new(ds.clone())
.take_blobs(&non_null_row_ids, column_name)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
};
let mut files = blob_files.into_iter();
for has_blob in row_has_blob {
if !has_blob {
builder
.push_null()
.map_err(|e| OmniError::Lance(e.to_string()))?;
continue;
}
let blob = files.next().ok_or_else(|| {
OmniError::Lance(format!(
"blob rewrite for '{}' lost alignment with source rows",
column_name
))
})?;
builder
.push_bytes(
blob.read()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?,
)
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
if files.next().is_some() {
return Err(OmniError::Lance(format!(
"blob rewrite for '{}' produced extra source blobs",
column_name
)));
}
builder
.finish()
.map_err(|e| OmniError::Lance(e.to_string()))
}
/// Reports whether the blob description at `row` represents "no blob".
///
/// Decision order:
/// 1. A null struct entry is null.
/// 2. Without a `kind` field, the entry is null iff `position` or `size`
///    is null.
/// 3. With a `kind` field (`UInt8` or `UInt32`), a null kind is null and
///    any kind other than `BlobKind::Inline` is treated as present.
/// 4. An inline entry is null only when `position` is null or 0, `size` is
///    null or 0, and `blob_uri` is absent, null, or empty.
///
/// # Errors
/// Returns `OmniError::Lance` when `position` or `size` is missing or not
/// `UInt64`, when `kind` has an unsupported array type, when a `UInt32`
/// kind value does not fit in `u8`, or when the kind value is not a known
/// `BlobKind`.
fn blob_description_is_null(descriptions: &StructArray, row: usize) -> Result<bool> {
    if descriptions.is_null(row) {
        return Ok(true);
    }

    let position = descriptions
        .column_by_name("position")
        .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
        .ok_or_else(|| {
            OmniError::Lance(format!(
                "unrecognized blob description schema {:?}: missing UInt64 position field",
                descriptions.fields()
            ))
        })?;
    let size = descriptions
        .column_by_name("size")
        .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
        .ok_or_else(|| {
            OmniError::Lance(format!(
                "unrecognized blob description schema {:?}: missing UInt64 size field",
                descriptions.fields()
            ))
        })?;

    // Descriptions without a `kind` field: nullness of position/size decides.
    let Some(kind_column) = descriptions.column_by_name("kind") else {
        return Ok(position.is_null(row) || size.is_null(row));
    };

    let kind = if let Some(kind) = kind_column.as_any().downcast_ref::<UInt8Array>() {
        if kind.is_null(row) {
            return Ok(true);
        }
        kind.value(row)
    } else if let Some(kind) = kind_column.as_any().downcast_ref::<UInt32Array>() {
        if kind.is_null(row) {
            return Ok(true);
        }
        // Checked narrowing: an `as u8` cast would wrap out-of-range values
        // (e.g. 256 -> 0) and misreport them as a valid blob kind.
        let raw = kind.value(row);
        u8::try_from(raw).map_err(|_| {
            OmniError::Lance(format!("blob description kind {raw} does not fit in u8"))
        })?
    } else {
        return Err(OmniError::Lance(format!(
            "unrecognized blob description schema {:?}: kind field must be UInt8 or UInt32",
            descriptions.fields()
        )));
    };

    let kind = BlobKind::try_from(kind).map_err(|e| OmniError::Lance(e.to_string()))?;
    if kind != BlobKind::Inline {
        return Ok(false);
    }

    // Inline blobs: an all-empty description (no offset, no size, no URI)
    // is how a missing value shows up.
    let blob_uri = descriptions
        .column_by_name("blob_uri")
        .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        .and_then(|arr| (!arr.is_null(row)).then(|| arr.value(row)));
    Ok((position.is_null(row) || position.value(row) == 0)
        && (size.is_null(row) || size.value(row) == 0)
        && blob_uri.unwrap_or("").is_empty())
}
pub async fn scan_stream(
ds: &Dataset,
projection: Option<&[&str]>,
@ -772,11 +947,7 @@ impl TableStore {
///
/// MR-793 Phase 2: introduces this for the schema_apply rewrite path.
/// Lance API verified in `.context/mr-793-design.md` Appendix A.1.
pub async fn stage_overwrite(
&self,
ds: &Dataset,
batch: RecordBatch,
) -> Result<StagedWrite> {
pub async fn stage_overwrite(&self, ds: &Dataset, batch: RecordBatch) -> Result<StagedWrite> {
if batch.num_rows() == 0 {
return Err(OmniError::manifest_internal(
"stage_overwrite called with empty batch".to_string(),
@ -821,8 +992,7 @@ impl TableStore {
// read-your-writes via scan_with_staged, list every committed
// fragment in removed_fragment_ids so the post-stage view shows
// ONLY the staged fragments.
let removed_fragment_ids: Vec<u64> =
ds.manifest.fragments.iter().map(|f| f.id).collect();
let removed_fragment_ids: Vec<u64> = ds.manifest.fragments.iter().map(|f| f.id).collect();
Ok(StagedWrite {
transaction,
new_fragments,
@ -859,9 +1029,7 @@ impl TableStore {
.replace(true)
.execute_uncommitted()
.await
.map_err(|e| {
OmniError::Lance(format!("stage_create_btree_index: {}", e))
})?;
.map_err(|e| OmniError::Lance(format!("stage_create_btree_index: {}", e)))?;
let removed_indices: Vec<IndexMetadata> = ds
.load_indices()
.await
@ -900,9 +1068,7 @@ impl TableStore {
.replace(true)
.execute_uncommitted()
.await
.map_err(|e| {
OmniError::Lance(format!("stage_create_inverted_index: {}", e))
})?;
.map_err(|e| OmniError::Lance(format!("stage_create_inverted_index: {}", e)))?;
let removed_indices: Vec<IndexMetadata> = ds
.load_indices()
.await
@ -1074,13 +1240,8 @@ impl TableStore {
None => committed,
};
let pending = scan_pending_batches(
pending_batches,
pending_schema,
projection,
filter,
)
.await?;
let pending =
scan_pending_batches(pending_batches, pending_schema, projection, filter).await?;
let mut out = committed;
out.extend(pending);
@ -1438,11 +1599,8 @@ async fn scan_pending_batches(
) -> Result<Vec<RecordBatch>> {
let schema = pending_schema.unwrap_or_else(|| pending_batches[0].schema());
let ctx = datafusion::execution::context::SessionContext::new();
let mem = datafusion::datasource::MemTable::try_new(
schema,
vec![pending_batches.to_vec()],
)
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mem = datafusion::datasource::MemTable::try_new(schema, vec![pending_batches.to_vec()])
.map_err(|e| OmniError::Lance(e.to_string()))?;
ctx.register_table("pending", Arc::new(mem))
.map_err(|e| OmniError::Lance(e.to_string()))?;
@ -1454,9 +1612,7 @@ async fn scan_pending_batches(
.join(", ")
})
.unwrap_or_else(|| "*".to_string());
let where_clause = filter
.map(|f| format!("WHERE {f}"))
.unwrap_or_default();
let where_clause = filter.map(|f| format!("WHERE {f}")).unwrap_or_default();
let sql = format!("SELECT {proj} FROM pending {where_clause}");
let df = ctx
.sql(&sql)

View file

@ -60,6 +60,24 @@ query add_employment($person: String, $company: String) {
}
"#;
// Test schema with a single `Document` node type: `title` is the key, and
// `content` (Blob) and `note` (String) carry a `?` suffix (presumably
// marking them optional — confirm against the schema language).
const BLOB_SCHEMA: &str = r#"
node Document {
title: String @key
content: Blob?
note: String?
}
"#;
// Test mutations against `BLOB_SCHEMA`: `insert_doc` inserts a Document with
// title, content and note; `update_doc_note` sets `note` on the Document
// whose title matches.
const BLOB_MUTATIONS: &str = r#"
query insert_doc($title: String, $content: Blob, $note: String) {
insert Document { title: $title, content: $content, note: $note }
}
query update_doc_note($title: String, $note: String) {
update Document set { note: $note } where title = $title
}
"#;
async fn init_search_db(dir: &tempfile::TempDir) -> Omnigraph {
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, SEARCH_SCHEMA).await.unwrap();
@ -297,6 +315,128 @@ async fn branch_merge_updates_main_traversal() {
assert_eq!(merged.num_rows(), 3);
}
// Merging a branch into main must keep blob payloads readable when both
// sides changed the blob-bearing `Document` table after the branch point.
#[tokio::test]
async fn branch_merge_with_blob_columns_preserves_blob_data() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut main = Omnigraph::init(uri, BLOB_SCHEMA).await.unwrap();
// Seed main with two documents; the base64 payloads correspond to the
// "Seed" and "Main" bytes asserted at the end of the test.
load_jsonl(
&mut main,
concat!(
"{\"type\":\"Document\",\"data\":{\"title\":\"seed\",\"content\":\"base64:U2VlZA==\",\"note\":\"original\"}}\n",
"{\"type\":\"Document\",\"data\":{\"title\":\"main-doc\",\"content\":\"base64:TWFpbg==\",\"note\":\"main\"}}",
),
LoadMode::Overwrite,
)
.await
.unwrap();
// Create the branch, then open a second handle on the same URI to drive it.
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(uri).await.unwrap();
// Change on main: update the note of "main-doc" (its blob is not touched).
mutate_main(
&mut main,
BLOB_MUTATIONS,
"update_doc_note",
&params(&[("$title", "main-doc"), ("$note", "updated on main")]),
)
.await
.unwrap();
// Change on the branch: insert a new document "readme" with a blob payload.
mutate_branch(
&mut feature,
"feature",
BLOB_MUTATIONS,
"insert_doc",
&params(&[
("$title", "readme"),
("$content", "base64:SGVsbG8="),
("$note", "branch insert"),
]),
)
.await
.unwrap();
// Change on the branch: update the note of the pre-existing "seed" document.
mutate_branch(
&mut feature,
"feature",
BLOB_MUTATIONS,
"update_doc_note",
&params(&[("$title", "seed"), ("$note", "updated on branch")]),
)
.await
.unwrap();
// Merge "feature" into "main" (presumably source then target) and expect a real merge.
let outcome = main.branch_merge("feature", "main").await.unwrap();
assert_eq!(outcome, MergeOutcome::Merged);
// Blob inserted on the branch is readable from main after the merge.
let readme = main
.read_blob("Document", "readme", "content")
.await
.unwrap();
let readme_bytes = readme.read().await.unwrap();
assert_eq!(&readme_bytes[..], b"Hello");
// Blob of the row whose note was updated on the branch is still intact.
let seed = main.read_blob("Document", "seed", "content").await.unwrap();
let seed_bytes = seed.read().await.unwrap();
assert_eq!(&seed_bytes[..], b"Seed");
// Blob of the row whose note was updated on main is still intact.
let main_doc = main
.read_blob("Document", "main-doc", "content")
.await
.unwrap();
let main_doc_bytes = main_doc.read().await.unwrap();
assert_eq!(&main_doc_bytes[..], b"Main");
}
// A document whose blob content is given as a `file://` URI on the branch
// must be readable from main after the merge, yielding the file's bytes.
#[tokio::test]
async fn branch_merge_with_external_blob_uri_materializes_payload() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
// Write the external payload to a file in a separate temp directory and
// build a file:// URI pointing at it.
let external_dir = tempfile::tempdir().unwrap();
let external_path = external_dir.path().join("external.txt");
fs::write(&external_path, b"External").unwrap();
let external_uri = format!("file://{}", external_path.display());
let mut main = Omnigraph::init(uri, BLOB_SCHEMA).await.unwrap();
// Overwrite-load with empty input before branching (presumably so the
// tables exist at the branch point — confirm against `load_jsonl`).
load_jsonl(&mut main, "", LoadMode::Overwrite)
.await
.unwrap();
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(uri).await.unwrap();
// Change on main after branching: append a document with an inline base64 blob.
load_jsonl(
&mut main,
"{\"type\":\"Document\",\"data\":{\"title\":\"main-doc\",\"content\":\"base64:TWFpbg==\",\"note\":\"main\"}}",
LoadMode::Append,
)
.await
.unwrap();
// Change on the branch: append a document whose content is the external URI.
let external_data = serde_json::json!({
"type": "Document",
"data": {
"title": "external",
"content": external_uri,
"note": "branch insert",
}
})
.to_string();
feature
.load("feature", &external_data, LoadMode::Append)
.await
.unwrap();
let outcome = main.branch_merge("feature", "main").await.unwrap();
assert_eq!(outcome, MergeOutcome::Merged);
// Reading the merged row's blob from main returns the external file's bytes.
let external = main
.read_blob("Document", "external", "content")
.await
.unwrap();
let external_bytes = external.read().await.unwrap();
assert_eq!(&external_bytes[..], b"External");
}
#[tokio::test]
async fn branch_merge_applies_node_insert_to_main() {
let dir = tempfile::tempdir().unwrap();