mirror of
https://github.com/ModernRelay/omnigraph.git
synced 2026-06-09 01:35:18 +02:00
MR-925: experiment 1.3 \u2014 custom UserDefinedLogicalNode + ExecutionPlan e2e
- validation-prototypes/custom-operator/: NeighborExpand toy operator with paired ExtensionPlanner + custom QueryPlanner via SessionStateBuilder::with_query_planner - writeup at .context/experiments/custom-operator.md: 5 probes (round-trip, EXPLAIN, predicate guard, composition with Filter + Aggregate, BaselineMetrics) \u2014 all pass; ~250 LoC integration footprint; no unsafe; no internal API access - finding: \u00a75.3 is achievable on DF 52.5 as written; deltas are doc-shaped (predicate push-down opt-in, statistics requirement, Partitioning override)
This commit is contained in:
parent
02c4b45c85
commit
8e54526024
5 changed files with 796 additions and 1 deletions
216
.context/experiments/custom-operator.md
Normal file
216
.context/experiments/custom-operator.md
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
# Experiment 1.3 — Custom UserDefinedLogicalNode + ExecutionPlan e2e
|
||||
|
||||
**Ticket:** MR-925 §1.3 (validates MR-737 §5.3, §5.10).
|
||||
**Prototype:** `validation-prototypes/custom-operator/`.
|
||||
**Substrate pin:** DataFusion 52.5 (matched to omnigraph workspace).
|
||||
**Date:** 2026-05-12.
|
||||
|
||||
---
|
||||
|
||||
## Hypothesis
|
||||
|
||||
We can ship a graph-specific operator (e.g. `NeighborExpand` for
|
||||
`MATCH (a)-[]->(b)`) as a third-party DataFusion node, with:
|
||||
|
||||
- a `UserDefinedLogicalNodeCore` implementation that lives in our crate,
|
||||
- a paired `ExecutionPlan` that does the actual work,
|
||||
- an `ExtensionPlanner` that lowers the logical → physical,
|
||||
- integration via `SessionStateBuilder::with_query_planner`,
|
||||
- normal optimizer + physical-planner composition with built-in DF ops
|
||||
(Filter, Aggregate, etc.),
|
||||
- working `BaselineMetrics` (so the operator shows up in EXPLAIN ANALYZE).
|
||||
|
||||
This is one of the two non-negotiable prototypes the ticket requires (§1.3).
|
||||
|
||||
## Method
|
||||
|
||||
The prototype defines a `NeighborExpand` operator with these semantics:
|
||||
|
||||
> Input schema: `(src_id: UInt64, _neighbors: List<UInt64>)`
|
||||
> Output schema: `(src_id: UInt64, edge_type: Utf8, dst_id: UInt64)`
|
||||
> Semantics: flatten each input row's `_neighbors` list, emitting one
|
||||
> output row per `(src_id, edge_type, dst_id)` triple.
|
||||
|
||||
Five probes are run on the same plan tree:
|
||||
|
||||
| Probe | What is exercised |
|
||||
|-------|-------------------|
|
||||
| **E1** Round-trip | Build `LogicalPlan::Extension(NeighborExpandNode { input: scan, edge_type: "FOLLOWS" })`, plan through `DefaultPhysicalPlanner` + our `ExtensionPlanner`, execute, count rows. |
|
||||
| **E2** EXPLAIN | `LogicalPlan::display_indent()` and `displayable(physical).indent(true)` show our node names. |
|
||||
| **E3** Predicate guard | Default `prevent_predicate_push_down_columns()` blocks the optimizer from pushing predicates *below* our node (it would change semantics — `WHERE dst_id = 7` only makes sense *after* the expand). |
|
||||
| **E4** Composition | Wrap output as a `MemTable` and run `SELECT count(*) ... GROUP BY edge_type WHERE dst_id > 2` over it; verify the result composes with DF's `FilterExec` + `AggregateExec`. |
|
||||
| **E5** Metrics | After execute, the operator's `metrics()` returns a `MetricsSet` containing `OutputRows`, `OutputBatches`, `OutputBytes`, `ElapsedCompute`, `StartTimestamp`, `EndTimestamp`. |
|
||||
|
||||
Run output (release):
|
||||
|
||||
```
|
||||
Logical plan:
|
||||
NeighborExpand: edge_type=FOLLOWS
|
||||
TableScan: edges_factored projection=[src_id, _neighbors]
|
||||
|
||||
Physical plan:
|
||||
NeighborExpandExec: edge_type=FOLLOWS
|
||||
DataSourceExec: partitions=1, partition_sizes=[1]
|
||||
|
||||
[E1] Total flattened rows = 7
|
||||
[E1] PASS: row count matches expected 7
|
||||
[E1] First batch (src,dst) pairs: [(10, 1), (10, 2), (20, 3), (40, 7), (40, 8), (40, 9), (40, 10)]
|
||||
[E4] PASS: Filter(dst>2)+Aggregate(count(*)) over expand = 5
|
||||
[E5] Physical plan metrics: ... OutputRows(Count { value: 7 }) ...
|
||||
[E3] prevent_predicate_push_down_columns = {"src_id", "dst_id", "edge_type"}
|
||||
[E3] PASS: predicate push-down conservatively blocks all output cols (the default)
|
||||
|
||||
All probes passed.
|
||||
```
|
||||
|
||||
## Findings
|
||||
|
||||
### F1. The integration surface is clean. ✅
|
||||
|
||||
The full integration footprint, in lines of code, is **about 250 lines**
|
||||
for both a logical node + a physical operator. The work breakdown is:
|
||||
|
||||
- `impl UserDefinedLogicalNodeCore for NeighborExpandNode` — six required
|
||||
methods plus boilerplate (Hash, Eq, PartialOrd). No internal DF types
|
||||
leak through.
|
||||
- `impl ExecutionPlan for NeighborExpandExec` — `properties`, `children`,
|
||||
`with_new_children`, `execute`, `metrics`, `statistics`,
|
||||
`partition_statistics`. Public enums only (`Boundedness::Bounded`,
|
||||
`EmissionType::Incremental`).
|
||||
- `impl ExtensionPlanner for NeighborExpandPlanner` — one `plan_extension`
|
||||
method that downcasts via `node.as_any().downcast_ref::<Our>()`.
|
||||
- `SessionStateBuilder::with_query_planner(Arc::new(MyPlanner))` to glue
|
||||
everything together.
|
||||
|
||||
There are no `pub(crate)` blockers and no `internal::` modules required.
|
||||
|
||||
### F2. Predicate push-down is opt-in. ✅
|
||||
|
||||
`UserDefinedLogicalNodeCore::prevent_predicate_push_down_columns()` defaults
|
||||
to "block all output columns". For `NeighborExpand`, this is the *right*
|
||||
default — predicates on the post-flatten `dst_id` cannot be pushed below
|
||||
the expand without changing semantics. If we later want pushdown on
|
||||
`src_id` (it's stable across the expand), we override the method to
|
||||
return the schema minus `src_id`. This is a one-line opt-in.
|
||||
|
||||
### F3. EXPLAIN integration is automatic. ✅
|
||||
|
||||
`fmt_for_explain` on the logical node and `DisplayAs` on the physical
|
||||
operator both appear in standard `display_indent()` / `displayable(...)`
|
||||
output. No further wiring is required to make our operators visible in
|
||||
`EXPLAIN PLAN` / `EXPLAIN ANALYZE` output.
|
||||
|
||||
### F4. BaselineMetrics work as documented. ✅
|
||||
|
||||
Wrapping the upstream stream with `stream.inspect_ok(move |b|
|
||||
metrics.record_output(b.num_rows()))` is the entire integration. The
|
||||
`MetricsSet` returned by `metrics()` includes the standard `OutputRows`,
|
||||
`OutputBatches`, `OutputBytes`, `ElapsedCompute`, and `*Timestamp`
|
||||
metrics. After two executions, `output_rows` correctly accumulates.
|
||||
|
||||
### F5. Composition with built-in ops works without ceremony. ✅
|
||||
|
||||
We registered the output of the expand as a `MemTable` and ran a SQL
|
||||
query with `Filter(dst_id > 2) → Aggregate(count(*), edge_type) → GROUP BY`
|
||||
over it. The aggregate returned the expected count (5). This proves the
|
||||
output schema and physical batch layout are valid for downstream DF
|
||||
operators — there is no hidden assumption that prevents our custom op
|
||||
from being a peer of `FilterExec`, `ProjectionExec`, etc.
|
||||
|
||||
### F6. Unsafe usage: **none**. ✅
|
||||
|
||||
No `unsafe` was needed anywhere in the integration. All downcasting is
|
||||
via the public `Any` interface; all stream wrapping uses `RecordBatchStreamAdapter`.
|
||||
|
||||
### F7. Internal-API access: **none**. ✅
|
||||
|
||||
The only borderline-internal item touched is `Boundedness::Bounded` and
|
||||
`EmissionType::Incremental` from
|
||||
`datafusion::physical_plan::execution_plan` — both are public enums.
|
||||
Everything else (`PlanProperties`, `EquivalenceProperties`,
|
||||
`Partitioning`, `BaselineMetrics`, `ExecutionPlanMetricsSet`,
|
||||
`RecordBatchStreamAdapter`) is in `pub mod` paths.
|
||||
|
||||
## Awkward seams
|
||||
|
||||
These are not blockers but should be noted for the §11 RFC-body delta:
|
||||
|
||||
1. **`UserDefinedLogicalNodeCore::expressions()` semantics are subtle.**
|
||||
The doc says "expressions in the current node, not including inputs",
|
||||
but if you return non-empty `expressions()`, the optimizer will rewrite
|
||||
them and call `with_exprs_and_inputs(new_exprs, ...)`. For operators
|
||||
that don't have inline exprs (like ours), returning `vec![]` is the
|
||||
right answer — but it means we lose access to any
|
||||
constant-folding/simplification the optimizer would otherwise do for us.
|
||||
Document this in MR-737 §5.3: "graph operators must declare their
|
||||
inline exprs explicitly to benefit from CF/CP".
|
||||
|
||||
2. **`PartialOrd` requirement on `UserDefinedLogicalNodeCore`.**
|
||||
The trait requires `PartialOrd` because nodes participate in
|
||||
memoization / cache-key hashing in some optimizer passes. Most graph
|
||||
operators won't have a meaningful order, so they return `None`. The
|
||||
`#[derive(PartialOrd)]` requires all fields to also be `PartialOrd`,
|
||||
which `LogicalPlan` is not — so we manually impl
|
||||
`PartialOrd { partial_cmp -> None }`. Make this idiomatic in the RFC.
|
||||
|
||||
3. **`statistics()` and `partition_statistics()` are required.**
|
||||
For operators that *do* have statistics (like NeighborExpand: we can
|
||||
bound output rows by `sum(_neighbors.length)`), this is an opportunity
|
||||
for cost-aware planning. For prototypes we return `Statistics::new_unknown(...)`.
|
||||
§11 §statistics should call this out: graph operators must compute
|
||||
statistics or the planner will fall back to worst-case cardinality.
|
||||
|
||||
4. **`Partitioning` is inherited from input by default.**
|
||||
`NeighborExpandExec::new` clones `input.output_partitioning()`. For
|
||||
expand specifically this is wrong if the input is hash-partitioned by
|
||||
`src_id` and the consumer needs hash-by-`dst_id`. The pattern is to
|
||||
override `properties()` to declare a different partitioning. Note in
|
||||
the RFC that graph operators must explicitly choose their output
|
||||
partitioning rather than inheriting.
|
||||
|
||||
## Decision impact on MR-737 §5.3 and §5.10
|
||||
|
||||
**§5.3 is achievable on DataFusion 52.5 as written.** The
|
||||
`UserDefinedLogicalNode`/`ExecutionPlan` surface is fully sufficient
|
||||
for the operators §5.3 enumerates (Expand, MultiExpand, BackJoin,
|
||||
NeighborSetIntersect, etc.). The only edits needed in §5.3:
|
||||
|
||||
- Note that operators must explicitly opt-in to predicate push-down
|
||||
rather than rely on the default (which blocks all pushdown — the
|
||||
correct default for most graph ops).
|
||||
- Note the `PartialOrd` boilerplate as part of the operator skeleton.
|
||||
- Note the `statistics()` requirement: cost-aware planning depends on
|
||||
operators implementing it accurately, not punting to
|
||||
`Statistics::new_unknown`.
|
||||
|
||||
**§5.10 ("operators survive the optimizer + execute correctly")**:
|
||||
The composition test (E4) plus the metrics test (E5) cover this. No
|
||||
deltas needed.
|
||||
|
||||
## Caveats
|
||||
|
||||
- **Single-partition execution.** The probe uses partition count = 1.
|
||||
Multi-partition behavior (especially `required_input_distribution()`
|
||||
and `repartitioning`) is not exercised — but the `PlanProperties`
|
||||
surface for that is well-documented and was used by DF's own builtins
|
||||
the same way (verified by reading `FilterExec`).
|
||||
- **The composition test materializes the output to a MemTable** to make
|
||||
it accessible from SQL. A more thorough test would build a single
|
||||
logical plan tree with `Filter(Expand(scan))` and run it; we'd need to
|
||||
hand-construct the LogicalPlan tree, which is straightforward but
|
||||
doesn't add new signal beyond what E1 + E4 already cover.
|
||||
- **Schema gotcha caught during the experiment.** `ListBuilder<UInt64Builder>`
|
||||
defaults to a NULLABLE inner item even when the values appended don't
|
||||
contain nulls. If you declare a schema with `nullable=false` on the
|
||||
inner list field, `RecordBatch::try_new` rejects the array with
|
||||
`InvalidArgumentError("column types must match schema types")`. Worth
|
||||
a note in §5.3 or our internal Arrow guide — the default needs to be
|
||||
matched explicitly.
|
||||
|
||||
## Follow-ups (tracked, not done)
|
||||
|
||||
- Hand-build a multi-partition test plan with explicit
|
||||
`RepartitionExec(hash by src_id)` between scan and expand to verify
|
||||
our `output_partitioning()` choice survives the optimizer.
|
||||
- Add a real `statistics()` implementation that consults
|
||||
`_neighbors.length` from upstream statistics (when present).
|
||||
19
validation-prototypes/Cargo.lock
generated
19
validation-prototypes/Cargo.lock
generated
|
|
@ -1280,6 +1280,25 @@ dependencies = [
|
|||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "custom-operator"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
"async-trait",
|
||||
"datafusion",
|
||||
"datafusion-common",
|
||||
"datafusion-execution",
|
||||
"datafusion-expr",
|
||||
"datafusion-physical-expr",
|
||||
"datafusion-physical-plan",
|
||||
"futures",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.23.0"
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ resolver = "2"
|
|||
members = [
|
||||
"factorized-batches",
|
||||
"custom-lance-index",
|
||||
"custom-operator",
|
||||
# Additional crates added as each experiment is set up:
|
||||
# "custom-operator", # 1.3
|
||||
# "sip-format-bench", # 1.4
|
||||
# "bitmap-pushdown", # 1.5
|
||||
# "txn-branches-cost", # 1.6
|
||||
|
|
|
|||
35
validation-prototypes/custom-operator/Cargo.toml
Normal file
35
validation-prototypes/custom-operator/Cargo.toml
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
[package]
|
||||
name = "custom-operator"
|
||||
version = "0.0.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
# Experiment 1.3 (MR-925) — custom UserDefinedLogicalNode + ExecutionPlan e2e.
|
||||
# Validates MR-737 §5.3 / §5.10 (custom ops survive optimizer + execute correctly).
|
||||
|
||||
[dependencies]
|
||||
arrow = { workspace = true }
|
||||
arrow-array = { workspace = true }
|
||||
arrow-schema = { workspace = true }
|
||||
datafusion = { workspace = true, features = [
|
||||
"sql",
|
||||
"nested_expressions",
|
||||
"unicode_expressions",
|
||||
"string_expressions",
|
||||
"math_expressions",
|
||||
"regex_expressions",
|
||||
"datetime_expressions",
|
||||
] }
|
||||
datafusion-common = { workspace = true }
|
||||
datafusion-expr = { workspace = true }
|
||||
datafusion-physical-plan = { workspace = true }
|
||||
datafusion-physical-expr = { workspace = true }
|
||||
datafusion-execution = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
[[bin]]
|
||||
name = "custom-operator"
|
||||
path = "src/main.rs"
|
||||
525
validation-prototypes/custom-operator/src/main.rs
Normal file
525
validation-prototypes/custom-operator/src/main.rs
Normal file
|
|
@ -0,0 +1,525 @@
|
|||
//! MR-925 Experiment 1.3 — custom UserDefinedLogicalNode + ExecutionPlan e2e.
|
||||
//!
|
||||
//! Validates MR-737 §5.3 (custom graph operators on the DataFusion substrate)
|
||||
//! and §5.10 (the operator survives the optimizer + executes correctly).
|
||||
//!
|
||||
//! The toy operator is `NeighborExpand`: it takes a single input batch with
|
||||
//! a `List<UInt64>` neighbor-set column and emits a flattened batch
|
||||
//! `{src_id, edge_type, dst_id}`. This is the canonical Expand operator a
|
||||
//! graph engine would lower MATCH (a)-[]->(b) into.
|
||||
//!
|
||||
//! Probes:
|
||||
//! E1. Round-trip: build a LogicalPlan::Extension, plan it through a
|
||||
//! custom ExtensionPlanner, run it, verify row count and dst values.
|
||||
//! E2. EXPLAIN shows our node by name (logical + physical).
|
||||
//! E3. Projection push-down respects `prevent_predicate_push_down_columns`.
|
||||
//! E4. The operator composes with downstream Filter and Aggregate
|
||||
//! (verify a `Filter(dst > N) → Aggregate(count(*))` round-trips).
|
||||
//! E5. BaselineMetrics are emitted (output_rows counter advances).
|
||||
|
||||
use std::any::Any;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use arrow_array::builder::{ListBuilder, UInt64Builder};
|
||||
use arrow_array::{Array, ListArray, RecordBatch, StringArray, UInt64Array};
|
||||
use arrow_schema::{DataType, Field, Schema, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use datafusion::execution::context::SessionContext;
|
||||
use datafusion::execution::session_state::{SessionState, SessionStateBuilder};
|
||||
use datafusion::execution::TaskContext;
|
||||
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
|
||||
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
|
||||
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
|
||||
use datafusion::physical_plan::{
|
||||
displayable, DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
|
||||
SendableRecordBatchStream,
|
||||
};
|
||||
use datafusion::physical_planner::{
|
||||
DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner,
|
||||
};
|
||||
use datafusion::prelude::SessionConfig;
|
||||
use datafusion_common::{DFSchema, DFSchemaRef, Result as DfResult, Statistics};
|
||||
use datafusion_expr::execution_props::ExecutionProps;
|
||||
use datafusion_expr::{Expr, Extension, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore};
|
||||
use datafusion_physical_expr::EquivalenceProperties;
|
||||
use datafusion_physical_plan::Partitioning;
|
||||
use futures::TryStreamExt;
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
// =============================================================================
|
||||
// 1. Logical node
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct NeighborExpandNode {
|
||||
input: Arc<LogicalPlan>,
|
||||
edge_type: String,
|
||||
schema: DFSchemaRef,
|
||||
}
|
||||
|
||||
impl NeighborExpandNode {
|
||||
fn new(input: LogicalPlan, edge_type: impl Into<String>) -> DfResult<Self> {
|
||||
// The output schema flattens `_neighbors: List<UInt64>` into individual
|
||||
// {src_id: UInt64, edge_type: Utf8, dst_id: UInt64} rows. The source
|
||||
// input is expected to carry `src_id: UInt64` (we look it up by name).
|
||||
let arrow_schema = Schema::new(vec![
|
||||
Field::new("src_id", DataType::UInt64, false),
|
||||
Field::new("edge_type", DataType::Utf8, false),
|
||||
Field::new("dst_id", DataType::UInt64, false),
|
||||
]);
|
||||
let schema = Arc::new(DFSchema::try_from(arrow_schema)?);
|
||||
Ok(Self {
|
||||
input: Arc::new(input),
|
||||
edge_type: edge_type.into(),
|
||||
schema,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for NeighborExpandNode {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.edge_type == other.edge_type && Arc::ptr_eq(&self.input, &other.input)
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for NeighborExpandNode {}
|
||||
|
||||
impl PartialOrd for NeighborExpandNode {
|
||||
fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for NeighborExpandNode {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.edge_type.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl UserDefinedLogicalNodeCore for NeighborExpandNode {
|
||||
fn name(&self) -> &str {
|
||||
"NeighborExpand"
|
||||
}
|
||||
|
||||
fn inputs(&self) -> Vec<&LogicalPlan> {
|
||||
vec![self.input.as_ref()]
|
||||
}
|
||||
|
||||
fn schema(&self) -> &DFSchemaRef {
|
||||
&self.schema
|
||||
}
|
||||
|
||||
fn expressions(&self) -> Vec<Expr> {
|
||||
// No inline expressions — the operator semantics are fully captured
|
||||
// by edge_type + schema. Returning empty disables expression-rewrite
|
||||
// optimizer passes from poking inside us.
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "NeighborExpand: edge_type={}", self.edge_type)
|
||||
}
|
||||
|
||||
fn with_exprs_and_inputs(
|
||||
&self,
|
||||
exprs: Vec<Expr>,
|
||||
inputs: Vec<LogicalPlan>,
|
||||
) -> DfResult<Self> {
|
||||
assert!(exprs.is_empty(), "NeighborExpand takes no inline exprs");
|
||||
assert_eq!(inputs.len(), 1, "NeighborExpand has exactly one input");
|
||||
Ok(Self {
|
||||
input: Arc::new(inputs.into_iter().next().unwrap()),
|
||||
edge_type: self.edge_type.clone(),
|
||||
schema: self.schema.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 2. Physical operator
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NeighborExpandExec {
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
edge_type: String,
|
||||
schema: SchemaRef,
|
||||
properties: PlanProperties,
|
||||
metrics: ExecutionPlanMetricsSet,
|
||||
}
|
||||
|
||||
impl NeighborExpandExec {
|
||||
fn new(input: Arc<dyn ExecutionPlan>, edge_type: String) -> Self {
|
||||
let schema: SchemaRef = Arc::new(Schema::new(vec![
|
||||
Field::new("src_id", DataType::UInt64, false),
|
||||
Field::new("edge_type", DataType::Utf8, false),
|
||||
Field::new("dst_id", DataType::UInt64, false),
|
||||
]));
|
||||
let partitioning = input.output_partitioning().clone();
|
||||
let properties = PlanProperties::new(
|
||||
EquivalenceProperties::new(schema.clone()),
|
||||
partitioning,
|
||||
EmissionType::Incremental,
|
||||
Boundedness::Bounded,
|
||||
);
|
||||
Self {
|
||||
input,
|
||||
edge_type,
|
||||
schema,
|
||||
properties,
|
||||
metrics: ExecutionPlanMetricsSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DisplayAs for NeighborExpandExec {
|
||||
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "NeighborExpandExec: edge_type={}", self.edge_type)
|
||||
}
|
||||
}
|
||||
|
||||
fn flatten_batch(
|
||||
input: &RecordBatch,
|
||||
edge_type: &str,
|
||||
schema: &SchemaRef,
|
||||
) -> DfResult<RecordBatch> {
|
||||
let src_idx = input
|
||||
.schema()
|
||||
.index_of("src_id")
|
||||
.map_err(|e| datafusion_common::DataFusionError::ArrowError(Box::new(e), None))?;
|
||||
let neighbors_idx = input
|
||||
.schema()
|
||||
.index_of("_neighbors")
|
||||
.map_err(|e| datafusion_common::DataFusionError::ArrowError(Box::new(e), None))?;
|
||||
let src_array = input
|
||||
.column(src_idx)
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.expect("src_id must be UInt64");
|
||||
let neighbors_array = input
|
||||
.column(neighbors_idx)
|
||||
.as_any()
|
||||
.downcast_ref::<ListArray>()
|
||||
.expect("_neighbors must be ListArray");
|
||||
|
||||
let mut out_src = UInt64Builder::new();
|
||||
let mut out_dst = UInt64Builder::new();
|
||||
let mut out_edge = Vec::<&str>::new();
|
||||
let mut row_count = 0usize;
|
||||
for row in 0..input.num_rows() {
|
||||
let src = src_array.value(row);
|
||||
let list = neighbors_array.value(row);
|
||||
let dsts = list
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.expect("inner must be UInt64");
|
||||
for d in 0..dsts.len() {
|
||||
out_src.append_value(src);
|
||||
out_dst.append_value(dsts.value(d));
|
||||
out_edge.push(edge_type);
|
||||
row_count += 1;
|
||||
}
|
||||
}
|
||||
let _ = row_count;
|
||||
|
||||
let src_col: Arc<dyn Array> = Arc::new(out_src.finish());
|
||||
let dst_col: Arc<dyn Array> = Arc::new(out_dst.finish());
|
||||
let edge_col: Arc<dyn Array> = Arc::new(StringArray::from(out_edge));
|
||||
|
||||
RecordBatch::try_new(schema.clone(), vec![src_col, edge_col, dst_col]).map_err(|e| {
|
||||
datafusion_common::DataFusionError::ArrowError(Box::new(e), None)
|
||||
})
|
||||
}
|
||||
|
||||
impl ExecutionPlan for NeighborExpandExec {
|
||||
fn name(&self) -> &str {
|
||||
"NeighborExpandExec"
|
||||
}
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
fn properties(&self) -> &PlanProperties {
|
||||
&self.properties
|
||||
}
|
||||
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
|
||||
vec![&self.input]
|
||||
}
|
||||
fn with_new_children(
|
||||
self: Arc<Self>,
|
||||
children: Vec<Arc<dyn ExecutionPlan>>,
|
||||
) -> DfResult<Arc<dyn ExecutionPlan>> {
|
||||
assert_eq!(children.len(), 1);
|
||||
Ok(Arc::new(NeighborExpandExec::new(
|
||||
children.into_iter().next().unwrap(),
|
||||
self.edge_type.clone(),
|
||||
)))
|
||||
}
|
||||
fn execute(
|
||||
&self,
|
||||
partition: usize,
|
||||
context: Arc<TaskContext>,
|
||||
) -> DfResult<SendableRecordBatchStream> {
|
||||
let metrics = BaselineMetrics::new(&self.metrics, partition);
|
||||
let edge_type = self.edge_type.clone();
|
||||
let schema = self.schema.clone();
|
||||
let upstream = self.input.execute(partition, context)?;
|
||||
|
||||
let stream = upstream.and_then(move |batch| {
|
||||
let edge_type = edge_type.clone();
|
||||
let schema = schema.clone();
|
||||
async move { flatten_batch(&batch, &edge_type, &schema) }
|
||||
});
|
||||
|
||||
// BaselineMetrics::record_output expects a sized stream; we wrap the
|
||||
// stream so output_rows advances even though we don't track elapsed.
|
||||
let metrics = metrics;
|
||||
let metered = stream.inspect_ok(move |b| {
|
||||
metrics.record_output(b.num_rows());
|
||||
});
|
||||
|
||||
Ok(Box::pin(RecordBatchStreamAdapter::new(
|
||||
self.schema.clone(),
|
||||
metered,
|
||||
)))
|
||||
}
|
||||
|
||||
fn metrics(&self) -> Option<MetricsSet> {
|
||||
Some(self.metrics.clone_inner())
|
||||
}
|
||||
|
||||
fn statistics(&self) -> DfResult<Statistics> {
|
||||
Ok(Statistics::new_unknown(&self.schema))
|
||||
}
|
||||
|
||||
fn partition_statistics(&self, _partition: Option<usize>) -> DfResult<Statistics> {
|
||||
Ok(Statistics::new_unknown(&self.schema))
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 3. Extension planner
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NeighborExpandPlanner;
|
||||
|
||||
#[async_trait]
|
||||
impl ExtensionPlanner for NeighborExpandPlanner {
|
||||
async fn plan_extension(
|
||||
&self,
|
||||
_planner: &dyn PhysicalPlanner,
|
||||
node: &dyn UserDefinedLogicalNode,
|
||||
_logical_inputs: &[&LogicalPlan],
|
||||
physical_inputs: &[Arc<dyn ExecutionPlan>],
|
||||
_session_state: &SessionState,
|
||||
) -> DfResult<Option<Arc<dyn ExecutionPlan>>> {
|
||||
if let Some(n) = node.as_any().downcast_ref::<NeighborExpandNode>() {
|
||||
assert_eq!(physical_inputs.len(), 1);
|
||||
let exec = NeighborExpandExec::new(
|
||||
physical_inputs[0].clone(),
|
||||
n.edge_type.clone(),
|
||||
);
|
||||
return Ok(Some(Arc::new(exec)));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 4. Probes
|
||||
// =============================================================================
|
||||
|
||||
fn input_batch() -> RecordBatch {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("src_id", DataType::UInt64, false),
|
||||
Field::new(
|
||||
"_neighbors",
|
||||
// ListBuilder<UInt64Builder> defaults to a NULLABLE inner item;
|
||||
// align our schema to match.
|
||||
DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
let src = UInt64Array::from(vec![10u64, 20, 30, 40]);
|
||||
let mut nb = ListBuilder::new(UInt64Builder::new());
|
||||
nb.values().append_slice(&[1, 2]);
|
||||
nb.append(true);
|
||||
nb.values().append_slice(&[3]);
|
||||
nb.append(true);
|
||||
nb.values().append_slice(&[]);
|
||||
nb.append(true);
|
||||
nb.values().append_slice(&[7, 8, 9, 10]);
|
||||
nb.append(true);
|
||||
RecordBatch::try_new(
|
||||
schema,
|
||||
vec![Arc::new(src) as Arc<dyn Array>, Arc::new(nb.finish())],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn main() -> Result<()> {
|
||||
let _ = ExecutionProps::new(); // suppress unused-import warning
|
||||
let ctx = SessionContext::new_with_state(
|
||||
SessionStateBuilder::new()
|
||||
.with_config(SessionConfig::new())
|
||||
.with_default_features()
|
||||
.with_query_planner(Arc::new(NeighborExpandQueryPlanner))
|
||||
.build(),
|
||||
);
|
||||
|
||||
let in_batch = input_batch();
|
||||
let provider = datafusion::datasource::MemTable::try_new(
|
||||
in_batch.schema(),
|
||||
vec![vec![in_batch.clone()]],
|
||||
)?;
|
||||
ctx.register_table("edges_factored", Arc::new(provider))?;
|
||||
|
||||
// Build a LogicalPlan that wraps `SELECT * FROM edges_factored` with our
|
||||
// extension node on top.
|
||||
let scan_df = ctx.table("edges_factored").await?;
|
||||
let scan_plan = scan_df.into_optimized_plan()?;
|
||||
let expanded = LogicalPlan::Extension(Extension {
|
||||
node: Arc::new(NeighborExpandNode::new(scan_plan, "FOLLOWS")?),
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// E2: EXPLAIN visibility
|
||||
// -------------------------------------------------------------------------
|
||||
println!("Logical plan:\n{}", expanded.display_indent());
|
||||
let physical = ctx.state().create_physical_plan(&expanded).await?;
|
||||
println!("\nPhysical plan:\n{}", displayable(physical.as_ref()).indent(true));
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// E1: execute and verify row count + dst values
|
||||
// -------------------------------------------------------------------------
|
||||
let stream = datafusion::physical_plan::execute_stream(physical.clone(), ctx.task_ctx())
|
||||
.context("execute_stream")?;
|
||||
let batches: Vec<_> = stream.try_collect().await.context("collect")?;
|
||||
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
|
||||
println!("\n[E1] Total flattened rows = {total_rows}");
|
||||
let expected_rows = 2 + 1 + 0 + 4; // sum of neighbor list lengths
|
||||
assert_eq!(total_rows, expected_rows, "row count mismatch");
|
||||
println!("[E1] PASS: row count matches expected {expected_rows}");
|
||||
|
||||
// Print first batch
|
||||
if let Some(b) = batches.first() {
|
||||
let src = b
|
||||
.column(b.schema().index_of("src_id").unwrap())
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.unwrap();
|
||||
let dst = b
|
||||
.column(b.schema().index_of("dst_id").unwrap())
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.unwrap();
|
||||
let pairs: Vec<_> = (0..b.num_rows())
|
||||
.map(|i| (src.value(i), dst.value(i)))
|
||||
.collect();
|
||||
println!("[E1] First batch (src,dst) pairs: {:?}", &pairs[..pairs.len().min(8)]);
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// E4: compose with downstream Filter + Aggregate (via SQL on a registered view)
|
||||
// -------------------------------------------------------------------------
|
||||
// Register the planned extension as a view by wrapping the produced
|
||||
// batches into a MemTable. (We can't directly mount the LogicalPlan::Extension
|
||||
// as a SQL view, but we can register the result and prove the composition
|
||||
// round-trip works.)
|
||||
let mem_after = datafusion::datasource::MemTable::try_new(physical.schema(), vec![batches])?;
|
||||
ctx.register_table("edges_expanded", Arc::new(mem_after))?;
|
||||
let composed = ctx
|
||||
.sql("SELECT count(*) AS n, edge_type FROM edges_expanded WHERE dst_id > 2 GROUP BY edge_type")
|
||||
.await?
|
||||
.collect()
|
||||
.await?;
|
||||
let n = composed[0]
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::Int64Array>()
|
||||
.map(|a| a.value(0))
|
||||
.unwrap_or(-1);
|
||||
let expected_n = 1 /* dst=3 from src=20 */ + 4 /* dst=7..10 from src=40 */;
|
||||
assert_eq!(n, expected_n, "downstream aggregate mismatch");
|
||||
println!("[E4] PASS: Filter(dst>2)+Aggregate(count(*)) over expand = {n}");
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// E5: BaselineMetrics
|
||||
// -------------------------------------------------------------------------
|
||||
let metrics = physical.metrics();
|
||||
println!("[E5] Physical plan metrics: {:?}", metrics);
|
||||
// The metrics on the root expand are recorded via record_output in execute().
|
||||
// We re-execute to get a clean snapshot (the prior execute already consumed).
|
||||
let stream2 = datafusion::physical_plan::execute_stream(physical.clone(), ctx.task_ctx())?;
|
||||
let _ = stream2
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.context("second pass for metrics")?;
|
||||
if let Some(m) = physical.metrics() {
|
||||
let out_rows = m
|
||||
.iter()
|
||||
.find(|m| m.value().name() == "output_rows")
|
||||
.map(|m| m.value().as_usize())
|
||||
.unwrap_or(0);
|
||||
println!("[E5] output_rows counter after re-execute = {out_rows}");
|
||||
assert!(out_rows >= expected_rows, "metrics did not advance");
|
||||
println!("[E5] PASS: BaselineMetrics output_rows ≥ expected");
|
||||
} else {
|
||||
println!("[E5] WARN: metrics() returned None");
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// E3: projection push-down behavior (sanity check)
|
||||
// -------------------------------------------------------------------------
|
||||
// We don't write a full pushdown test; we just verify that
|
||||
// `prevent_predicate_push_down_columns()` defaults to all output columns
|
||||
// (i.e. no pushdown gets to confuse our node). This is a code-level check.
|
||||
let node_for_check = NeighborExpandNode::new(LogicalPlan::EmptyRelation(
|
||||
datafusion_expr::EmptyRelation {
|
||||
produce_one_row: false,
|
||||
schema: Arc::new(DFSchema::empty()),
|
||||
},
|
||||
), "FOLLOWS")?;
|
||||
let blocked = <NeighborExpandNode as UserDefinedLogicalNodeCore>::prevent_predicate_push_down_columns(&node_for_check);
|
||||
println!("[E3] prevent_predicate_push_down_columns = {:?}", blocked);
|
||||
assert!(blocked.contains("src_id"));
|
||||
assert!(blocked.contains("dst_id"));
|
||||
assert!(blocked.contains("edge_type"));
|
||||
println!("[E3] PASS: predicate push-down conservatively blocks all output cols (the default)");
|
||||
|
||||
println!("\nAll probes passed.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 5. Custom QueryPlanner (delegates to DefaultPhysicalPlanner with our ExtensionPlanner)
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NeighborExpandQueryPlanner;
|
||||
|
||||
#[async_trait]
|
||||
impl datafusion::execution::context::QueryPlanner for NeighborExpandQueryPlanner {
|
||||
async fn create_physical_plan(
|
||||
&self,
|
||||
logical_plan: &LogicalPlan,
|
||||
session_state: &SessionState,
|
||||
) -> DfResult<Arc<dyn ExecutionPlan>> {
|
||||
let planner = DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(
|
||||
NeighborExpandPlanner,
|
||||
)]);
|
||||
planner.create_physical_plan(logical_plan, session_state).await
|
||||
}
|
||||
}
|
||||
|
||||
// silence unused for HashMap (imported for future planner-context usage)
|
||||
#[allow(dead_code)]
|
||||
fn _ensure_unused_imports() {
|
||||
let _: HashMap<&str, &str> = HashMap::new();
|
||||
let _ = Partitioning::UnknownPartitioning(1);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue