From cdfbccbfdc57597c2bdad5c9079a1c73cc55f21c Mon Sep 17 00:00:00 2001 From: Ragnor Comerford Date: Fri, 1 May 2026 10:42:21 +0200 Subject: [PATCH] MR-794 step 2: scaffold MutationStaging accumulator + scan_with_pending MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the scaffolding for the in-memory staged-write rewire — no behavior change yet: * New crates/omnigraph/src/exec/staging.rs with MutationStaging, PendingTable, PendingMode, StagedTablePath, plus the end-of-query finalize() that issues one stage_* + commit_staged per pending table (Merge mode dedupes by id, last-write-wins). * TableStore::scan_with_pending and count_rows_with_pending helpers — Lance scan committed + DataFusion MemTable scan pending, concat. Sidesteps the Scanner::with_fragments filter-pushdown limitation documented on scan_with_staged. * Add datafusion = "52" to workspace + omnigraph-engine deps for MemTable (transitively pulled by Lance already). Engine code still uses the legacy MutationStaging shape; the rewire lands in subsequent commits. Co-Authored-By: Claude Opus 4.7 (1M context) --- Cargo.lock | 1 + Cargo.toml | 1 + crates/omnigraph/Cargo.toml | 1 + crates/omnigraph/src/exec/mod.rs | 1 + crates/omnigraph/src/exec/staging.rs | 391 +++++++++++++++++++++++++++ crates/omnigraph/src/table_store.rs | 140 ++++++++++ 6 files changed, 535 insertions(+) create mode 100644 crates/omnigraph/src/exec/staging.rs diff --git a/Cargo.lock b/Cargo.lock index 9eafcf9..77efe8f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4647,6 +4647,7 @@ dependencies = [ "async-trait", "base64", "chrono", + "datafusion", "fail", "futures", "lance", diff --git a/Cargo.toml b/Cargo.toml index 1766b07..2878bf7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ arrow-select = "57" arrow-cast = { version = "57", features = ["prettyprint"] } arrow-ord = "57" +datafusion = { version = "52", default-features = false } datafusion-physical-plan = "52" datafusion-physical-expr = "52" datafusion-execution = "52" diff --git a/crates/omnigraph/Cargo.toml b/crates/omnigraph/Cargo.toml index 0e1bae9..afa608b 100644 --- a/crates/omnigraph/Cargo.toml +++ b/crates/omnigraph/Cargo.toml @@ -19,6 +19,7 @@ failpoints = ["dep:fail", "fail/failpoints"] omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.3.1" } lance = { workspace = true } lance-datafusion = { workspace = true } +datafusion = { workspace = true } lance-file = { workspace = true } lance-index = { workspace = true } lance-linalg = { workspace = true } diff --git a/crates/omnigraph/src/exec/mod.rs b/crates/omnigraph/src/exec/mod.rs index 8171461..33a7e41 100644 --- a/crates/omnigraph/src/exec/mod.rs +++ b/crates/omnigraph/src/exec/mod.rs @@ -46,3 +46,4 @@ mod merge; mod mutation; mod projection; mod query; +pub(crate) mod staging; diff --git a/crates/omnigraph/src/exec/staging.rs b/crates/omnigraph/src/exec/staging.rs new file mode 100644 index 0000000..2b2c193 --- /dev/null +++ b/crates/omnigraph/src/exec/staging.rs @@ -0,0 +1,391 @@ +//! Per-query staging accumulator for direct-publish writes (MR-794 step 2+). +//! +//! `MutationStaging` accumulates per-table input batches in memory during a +//! `mutate_as` or `load` query, then at end-of-query commits each touched +//! table via Lance's distributed-write API (one `stage_*` + `commit_staged` +//! per table) and returns the publisher inputs (`SubTableUpdate` list + +//! `expected_table_versions`). +//! +//! Read-your-writes within the same query is satisfied by the in-memory +//! pending batches (see `pending_batches`) — read sites union the committed +//! Lance scan with the pending Arrow batches via DataFusion `MemTable` (see +//! `crate::table_store::TableStore::scan_with_pending`). +//! +//! This module is shared by the engine's mutation path (`exec/mutation.rs`) +//! and the bulk loader (`loader/mod.rs`); both feed insert/update batches +//! into `pending` and route end-of-query commits through `finalize`. +//! Deletes follow the inline-commit path and are recorded via +//! `record_inline` (parse-time D₂ rule prevents mixed insert/delete in a +//! single query, so no flushing is required). + +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use arrow_array::{Array, RecordBatch, StringArray, UInt32Array}; +use arrow_schema::SchemaRef; + +use crate::db::SubTableUpdate; +use crate::error::{OmniError, Result}; + +/// Whether the per-table accumulator should commit via `stage_append` +/// (no @key inserts, edge inserts) or `stage_merge_insert` (any @key insert +/// or update). Once set to `Merge` for a table within a query, subsequent +/// inserts on that table are rolled into the same merge — a `WhenNotMatched +/// = InsertAll` merge is correct for both cases. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum PendingMode { + Append, + Merge, +} + +/// Per-table accumulator. Each insert/update op pushes a `RecordBatch` into +/// `batches`; at end-of-query the accumulated batches concat into a single +/// stage call. +#[derive(Debug)] +pub(crate) struct PendingTable { + pub(crate) schema: SchemaRef, + pub(crate) mode: PendingMode, + pub(crate) batches: Vec, +} + +impl PendingTable { + fn new(schema: SchemaRef, mode: PendingMode) -> Self { + Self { + schema, + mode, + batches: Vec::new(), + } + } + + fn total_rows(&self) -> usize { + self.batches.iter().map(|b| b.num_rows()).sum() + } +} + +/// Stable per-table identifiers captured on first touch and reused at +/// finalize time. Avoids re-resolving the dataset path / branch. +#[derive(Debug, Clone)] +pub(crate) struct StagedTablePath { + pub(crate) full_path: String, + pub(crate) table_branch: Option, +} + +/// Per-query staging state. +/// +/// Replaces the post-MR-771 inline-commit `MutationStaging.latest` map with +/// an in-memory accumulator that defers all Lance HEAD advances to +/// end-of-query. After this rewire the bug class "Lance HEAD drifts ahead +/// of `__manifest`" is unreachable in `mutate_as` and `load` for inserts +/// and updates by construction. +#[derive(Default)] +pub(crate) struct MutationStaging { + /// Pre-write manifest version per table — the publisher's CAS fence at + /// end-of-query. + pub(crate) expected_versions: HashMap, + /// Per-table identifiers captured on first touch. + pub(crate) paths: HashMap, + /// In-memory accumulated batches per table (insert/update path). + pub(crate) pending: HashMap, + /// Inline-committed updates from delete-touching ops (D₂ guarantees no + /// pending batches exist on a delete-touched table). + pub(crate) inline_committed: HashMap, +} + +impl MutationStaging { + /// Capture pre-write metadata on first touch of a table. Subsequent + /// touches are no-ops (paths and `expected_version` are stable for the + /// lifetime of one query). + pub(crate) fn ensure_path( + &mut self, + table_key: &str, + full_path: String, + table_branch: Option, + expected_version: u64, + ) { + self.paths.entry(table_key.to_string()).or_insert(StagedTablePath { + full_path, + table_branch, + }); + self.expected_versions + .entry(table_key.to_string()) + .or_insert(expected_version); + } + + /// Append a batch to the per-table accumulator. + /// + /// `mode` is asserted-consistent with prior pushes for the same table: + /// `Append`+`Append` stays Append; any `Merge` upgrades the table to + /// Merge (e.g. an `update Person` after `insert Knows from='X' to='Y'` + /// when both produce content on `node:Person`). Once Merge is set, + /// subsequent appends roll into the merge stream — `WhenNotMatched = + /// InsertAll` correctly inserts append-shaped rows. + pub(crate) fn append_batch( + &mut self, + table_key: &str, + schema: SchemaRef, + mode: PendingMode, + batch: RecordBatch, + ) -> Result<()> { + if batch.num_rows() == 0 { + // No-op — staging is purely additive; an empty batch should not + // be appended. + return Ok(()); + } + let entry = self + .pending + .entry(table_key.to_string()) + .or_insert_with(|| PendingTable::new(schema.clone(), mode)); + // Upgrade Append -> Merge if any op needs merge semantics. + if mode == PendingMode::Merge { + entry.mode = PendingMode::Merge; + } + entry.batches.push(batch); + Ok(()) + } + + /// Record a delete that already inline-committed at the Lance layer. + pub(crate) fn record_inline(&mut self, update: SubTableUpdate) { + self.inline_committed.insert(update.table_key.clone(), update); + } + + /// Read-your-writes accessor: the accumulated pending batches for + /// `table_key`, or `&[]` if none. + pub(crate) fn pending_batches(&self, table_key: &str) -> &[RecordBatch] { + self.pending + .get(table_key) + .map(|p| p.batches.as_slice()) + .unwrap_or(&[]) + } + + /// Schema of the accumulated batches for `table_key`, or `None` if no + /// op has touched the table. Used by `scan_with_pending` to construct + /// the in-memory `MemTable`. + pub(crate) fn pending_schema(&self, table_key: &str) -> Option { + self.pending.get(table_key).map(|p| p.schema.clone()) + } + + /// `true` if neither pending nor inline_committed has any state — the + /// query made no observable writes. + pub(crate) fn is_empty(&self) -> bool { + self.pending.is_empty() && self.inline_committed.is_empty() + } + + /// Total count of pending rows across all tables. Used by tests and + /// (eventually) memory-budget enforcement. + #[allow(dead_code)] + pub(crate) fn pending_row_count(&self) -> usize { + self.pending.values().map(|p| p.total_rows()).sum() + } + + /// End-of-query: for each pending table, concat batches and commit via + /// `stage_append` or `stage_merge_insert` followed by `commit_staged`. + /// Merge with inline-committed entries. Return `(updates, + /// expected_versions)` for `commit_updates_on_branch_with_expected`. + /// + /// Sequential per-table — no cross-table dependency, but a parallel + /// version is a perf optimization for multi-table writes (loader with + /// many node + edge types). v1 ships sequential; the fan-out can land + /// in a follow-up. + pub(crate) async fn finalize( + self, + db: &crate::db::Omnigraph, + _branch: Option<&str>, + ) -> Result<(Vec, HashMap)> { + let MutationStaging { + expected_versions, + paths, + pending, + inline_committed, + } = self; + + let mut updates: Vec = + inline_committed.into_values().collect(); + + for (table_key, table) in pending { + let path = paths.get(&table_key).ok_or_else(|| { + OmniError::manifest_internal(format!( + "MutationStaging::finalize: missing path for table '{}'", + table_key + )) + })?; + let expected = *expected_versions.get(&table_key).ok_or_else(|| { + OmniError::manifest_internal(format!( + "MutationStaging::finalize: missing expected version for table '{}'", + table_key + )) + })?; + + // Reopen at the pre-write version. Lance HEAD has not advanced + // since `ensure_path` captured it — no prior op committed to + // this dataset. + let ds = db + .reopen_for_mutation( + &table_key, + &path.full_path, + path.table_branch.as_deref(), + expected, + ) + .await?; + + if table.batches.is_empty() { + continue; + } + + // For Merge mode, dedupe accumulated batches by `id`, keeping + // the LAST occurrence (last-write-wins for the query). This + // is required because Lance's `MergeInsertBuilder` produces + // arbitrary results on duplicate keys in the source. Append + // mode is exempt because no-key node and edge inserts use + // ULID-generated ids that are unique within a query. + let combined = match table.mode { + PendingMode::Merge => { + dedupe_merge_batches_by_id(&table.schema, table.batches)? + } + PendingMode::Append => { + if table.batches.len() == 1 { + table.batches.into_iter().next().unwrap() + } else { + arrow_select::concat::concat_batches( + &table.schema, + &table.batches, + ) + .map_err(|e| OmniError::Lance(e.to_string()))? + } + } + }; + + // Commit via Lance's two-phase write: stage produces + // uncommitted fragments + transaction; commit advances HEAD. + let staged = match table.mode { + PendingMode::Append => { + db.table_store().stage_append(&ds, combined, &[]).await? + } + PendingMode::Merge => { + db.table_store() + .stage_merge_insert( + ds.clone(), + combined, + vec!["id".to_string()], + lance::dataset::WhenMatched::UpdateAll, + lance::dataset::WhenNotMatched::InsertAll, + ) + .await? + } + }; + let new_ds = db + .table_store() + .commit_staged(Arc::new(ds), staged.transaction) + .await?; + let state = db + .table_store() + .table_state(&path.full_path, &new_ds) + .await?; + updates.push(SubTableUpdate { + table_key: table_key.clone(), + table_version: state.version, + table_branch: path.table_branch.clone(), + row_count: state.row_count, + version_metadata: state.version_metadata, + }); + } + + Ok((updates, expected_versions)) + } +} + +/// Walk `batches` in reverse, tracking seen `id` values; for each row +/// whose id we have NOT seen yet, mark it as a keeper. After the walk, +/// take the kept rows in forward (input) order and concat into one batch. +/// +/// Result: a deduped batch where each `id` appears at most once, with +/// the LAST occurrence's column values. Required by `stage_merge_insert`, +/// which needs unique source keys (Lance's `MergeInsertBuilder` produces +/// arbitrary results on duplicates). +/// +/// `batches` must be non-empty and all share `schema` (caller enforces). +fn dedupe_merge_batches_by_id( + schema: &SchemaRef, + batches: Vec, +) -> Result { + if batches.is_empty() { + return Err(OmniError::manifest_internal( + "dedupe_merge_batches_by_id: batches is empty".to_string(), + )); + } + + // Walk in reverse, tracking seen ids. For each row whose id we + // haven't seen yet, record (batch_idx, row_idx) for the kept set. + let mut seen: HashSet = HashSet::new(); + let mut keep: Vec> = vec![Vec::new(); batches.len()]; + let mut any_duplicates = false; + + for (b_idx, batch) in batches.iter().enumerate().rev() { + let id_col = batch + .column_by_name("id") + .ok_or_else(|| { + OmniError::manifest_internal( + "dedupe_merge_batches_by_id: batch has no 'id' column".to_string(), + ) + })? + .as_any() + .downcast_ref::() + .ok_or_else(|| { + OmniError::manifest_internal( + "dedupe_merge_batches_by_id: 'id' column is not Utf8".to_string(), + ) + })?; + for r_idx in (0..batch.num_rows()).rev() { + if !id_col.is_valid(r_idx) { + // NULL ids — keep all (NULL != NULL in Lance/SQL semantics). + keep[b_idx].push(r_idx as u32); + continue; + } + let id = id_col.value(r_idx); + if seen.insert(id.to_string()) { + keep[b_idx].push(r_idx as u32); + } else { + any_duplicates = true; + } + } + // We pushed in reverse-row order; flip to forward order so the + // emitted batch reflects insertion order. + keep[b_idx].reverse(); + } + + // Fast path: no duplicates → simple concat. + if !any_duplicates { + if batches.len() == 1 { + return Ok(batches.into_iter().next().unwrap()); + } + return arrow_select::concat::concat_batches(schema, &batches) + .map_err(|e| OmniError::Lance(e.to_string())); + } + + // Slow path: build per-batch slices via `take`, then concat. + let mut sliced: Vec = Vec::with_capacity(batches.len()); + for (b_idx, idxs) in keep.into_iter().enumerate() { + if idxs.is_empty() { + continue; + } + let take_array = UInt32Array::from(idxs); + let columns: Vec> = batches[b_idx] + .columns() + .iter() + .map(|col| arrow_select::take::take(col, &take_array, None)) + .collect::>() + .map_err(|e| OmniError::Lance(e.to_string()))?; + let new_batch = RecordBatch::try_new(batches[b_idx].schema(), columns) + .map_err(|e| OmniError::Lance(e.to_string()))?; + sliced.push(new_batch); + } + if sliced.is_empty() { + return Err(OmniError::manifest_internal( + "dedupe_merge_batches_by_id: all rows were dropped (unexpected)".to_string(), + )); + } + if sliced.len() == 1 { + return Ok(sliced.into_iter().next().unwrap()); + } + arrow_select::concat::concat_batches(schema, &sliced) + .map_err(|e| OmniError::Lance(e.to_string())) +} diff --git a/crates/omnigraph/src/table_store.rs b/crates/omnigraph/src/table_store.rs index c9ff90a..52deb2f 100644 --- a/crates/omnigraph/src/table_store.rs +++ b/crates/omnigraph/src/table_store.rs @@ -821,6 +821,103 @@ impl TableStore { .map_err(|e| OmniError::Lance(e.to_string())) } + /// Scan committed via Lance + apply the same filter to in-memory + /// pending batches via DataFusion `MemTable`, concat the two result + /// streams. The replacement for `scan_with_staged` in engine code: + /// MR-794 step 2+ accumulates input batches in memory and unions + /// them with the committed snapshot at read time, sidestepping the + /// `Scanner::with_fragments` filter-pushdown limitation documented + /// on `scan_with_staged`. + /// + /// `committed_ds` should be opened at the pre-mutation + /// `expected_version` (the same version captured in `MutationStaging::expected_versions` + /// at first touch of the table). `pending_batches` are the per-table + /// accumulator's batches in their input shape. `pending_schema` is + /// the schema of the accumulated batches; passing `None` falls back + /// to the schema of the first pending batch. + /// + /// `filter` is the Lance / DataFusion SQL predicate. It is applied + /// to both sides — Lance pushes it down on the committed side; the + /// pending side runs it through a fresh DataFusion `SessionContext` + /// with the batches registered as a `MemTable` named `pending`. + /// + /// When `pending_batches` is empty this delegates to the regular + /// scan path. + pub async fn scan_with_pending( + &self, + committed_ds: &Dataset, + pending_batches: &[RecordBatch], + pending_schema: Option, + projection: Option<&[&str]>, + filter: Option<&str>, + ) -> Result> { + let mut out = self.scan(committed_ds, projection, filter, None).await?; + if pending_batches.is_empty() { + return Ok(out); + } + let pending = scan_pending_batches( + pending_batches, + pending_schema, + projection, + filter, + ) + .await?; + out.extend(pending); + Ok(out) + } + + /// `count_rows` variant that respects in-memory pending batches via + /// DataFusion `MemTable`. Used by edge-cardinality validation that + /// needs to see staged edges before commit. + /// + /// Cheaper than `scan_with_pending` for the count case because we + /// don't materialize columns on the pending side. + pub async fn count_rows_with_pending( + &self, + committed_ds: &Dataset, + pending_batches: &[RecordBatch], + pending_schema: Option, + filter: Option, + ) -> Result { + let committed = self.count_rows(committed_ds, filter.clone()).await?; + if pending_batches.is_empty() { + return Ok(committed); + } + // Count via DataFusion: COUNT(*) from pending [WHERE filter]. + let schema = + pending_schema.unwrap_or_else(|| pending_batches[0].schema()); + let ctx = datafusion::execution::context::SessionContext::new(); + let mem = datafusion::datasource::MemTable::try_new( + schema, + vec![pending_batches.to_vec()], + ) + .map_err(|e| OmniError::Lance(e.to_string()))?; + ctx.register_table("pending", Arc::new(mem)) + .map_err(|e| OmniError::Lance(e.to_string()))?; + let where_clause = filter + .map(|f| format!("WHERE {f}")) + .unwrap_or_default(); + let sql = format!("SELECT COUNT(*) AS c FROM pending {where_clause}"); + let df = ctx + .sql(&sql) + .await + .map_err(|e| OmniError::Lance(e.to_string()))?; + let batches = df + .collect() + .await + .map_err(|e| OmniError::Lance(e.to_string()))?; + let pending_count = batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .map(|b| { + use arrow_array::cast::AsArray; + let arr = b.column(0).as_primitive::(); + arr.value(0) as usize + }) + .sum::(); + Ok(committed + pending_count) + } + /// `count_rows` variant that respects staged writes. Used for /// edge-cardinality validation that needs to see staged edges before /// commit. Same `committed - removed + new` composition as @@ -1068,6 +1165,49 @@ fn assign_row_id_meta(fragments: &mut [Fragment], start_row_id: u64) -> Result<( Ok(()) } +/// Apply `projection` and `filter` to in-memory pending batches via a +/// fresh DataFusion `SessionContext`. Used by `scan_with_pending` for +/// the read-your-writes side of MR-794's in-memory accumulator. +/// +/// `pending_batches` must be non-empty (the caller short-circuits on +/// empty). +async fn scan_pending_batches( + pending_batches: &[RecordBatch], + pending_schema: Option, + projection: Option<&[&str]>, + filter: Option<&str>, +) -> Result> { + let schema = pending_schema.unwrap_or_else(|| pending_batches[0].schema()); + let ctx = datafusion::execution::context::SessionContext::new(); + let mem = datafusion::datasource::MemTable::try_new( + schema, + vec![pending_batches.to_vec()], + ) + .map_err(|e| OmniError::Lance(e.to_string()))?; + ctx.register_table("pending", Arc::new(mem)) + .map_err(|e| OmniError::Lance(e.to_string()))?; + + let proj = projection + .map(|cols| { + cols.iter() + .map(|c| format!("\"{}\"", c.replace('"', "\"\""))) + .collect::>() + .join(", ") + }) + .unwrap_or_else(|| "*".to_string()); + let where_clause = filter + .map(|f| format!("WHERE {f}")) + .unwrap_or_default(); + let sql = format!("SELECT {proj} FROM pending {where_clause}"); + let df = ctx + .sql(&sql) + .await + .map_err(|e| OmniError::Lance(e.to_string()))?; + df.collect() + .await + .map_err(|e| OmniError::Lance(e.to_string())) +} + fn combine_committed_with_staged(ds: &Dataset, staged: &[StagedWrite]) -> Vec { let removed: std::collections::HashSet = staged .iter()