From 351610d18c105862fc8702a8164e381c30dd8a4d Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 12 Apr 2026 20:59:13 +0000 Subject: [PATCH] Implement aggregate execution with wide-batch model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add runtime support for aggregate functions (count, sum, avg, min, max) with GROUP BY semantics, built on a single wide RecordBatch that eliminates correlation tracking by construction. Execution engine (exec/query.rs): - Replace `HashMap<String, RecordBatch>` bindings with a single `Option<RecordBatch>` wide batch whose columns are prefixed as `<variable>.<column>` - NodeScan prefixes columns and cross-joins with existing batch - Expand collects (src_row, dst_id) pairs, takes wide batch rows, appends prefixed destination columns via hconcat - Filter applies single mask to entire wide batch - AntiJoin: fast-path returns BooleanArray mask; slow-path slices one row for inner pipeline execution Projection engine (exec/projection.rs): - aggregate_return groups rows by non-aggregate key columns using length-prefixed string encoding, computes per-group aggregates - SUM accumulates into f64 to avoid integer overflow - MIN/MAX support both numeric and string types - Empty input returns count=0, others=null Compiler (typecheck.rs): - T8: split MIN/MAX from SUM/AVG — allow string arguments - T9: non-aggregate expressions in aggregate queries must be property accesses or variables - SUM type inference returns Float64 (matching runtime) Tests: 8 new integration tests covering grouped count, global count, sum/avg/min/max per company, aggregate+order+limit, string min/max, multi-hop aggregates, and edge cases. 
https://claude.ai/code/session_019o5NRyYomgETFyd7hpiLey --- .../omnigraph-compiler/src/query/typecheck.rs | 42 +- crates/omnigraph/src/exec/mod.rs | 2 +- crates/omnigraph/src/exec/projection.rs | 378 +++++++++++++++--- crates/omnigraph/src/exec/query.rs | 261 +++++++----- crates/omnigraph/tests/aggregation.rs | 302 ++++++++++++++ crates/omnigraph/tests/fixtures/test.gq | 35 ++ 6 files changed, 868 insertions(+), 152 deletions(-) create mode 100644 crates/omnigraph/tests/aggregation.rs diff --git a/crates/omnigraph-compiler/src/query/typecheck.rs b/crates/omnigraph-compiler/src/query/typecheck.rs index 6010661..5364897 100644 --- a/crates/omnigraph-compiler/src/query/typecheck.rs +++ b/crates/omnigraph-compiler/src/query/typecheck.rs @@ -189,6 +189,29 @@ fn typecheck_read_query(catalog: &Catalog, query: &QueryDecl) -> Result {} + _ => { + return Err(NanoError::Type( + "T9: non-aggregate expressions in an aggregate query must be \ + property accesses or variables" + .to_string(), + )); + } + } + } + } + } + Ok(ctx) } @@ -1298,9 +1321,9 @@ fn resolve_expr_type( Expr::Aggregate { func, arg } => { let arg_type = resolve_expr_type(catalog, arg, ctx, params)?; - // T8: sum/avg/min/max require numeric + // T8: sum/avg require numeric; min/max require numeric or string match func { - AggFunc::Sum | AggFunc::Avg | AggFunc::Min | AggFunc::Max => { + AggFunc::Sum | AggFunc::Avg => { if let ResolvedType::Scalar(s) = &arg_type && (s.list || !s.scalar.is_numeric()) { @@ -1311,6 +1334,17 @@ fn resolve_expr_type( ))); } } + AggFunc::Min | AggFunc::Max => { + if let ResolvedType::Scalar(s) = &arg_type + && (s.list || (!s.scalar.is_numeric() && s.scalar != ScalarType::String)) + { + return Err(NanoError::Type(format!( + "T8: {} requires numeric or string type, got {}", + func, + s.display_name() + ))); + } + } _ => {} // count works on any type } @@ -1340,8 +1374,8 @@ fn infer_projection_field( Expr::Aggregate { func, arg } => { let (data_type, nullable) = match func { AggFunc::Count 
=> (DataType::Int64, true), - AggFunc::Avg => (DataType::Float64, true), - _ => { + AggFunc::Avg | AggFunc::Sum => (DataType::Float64, true), + AggFunc::Min | AggFunc::Max => { let resolved = resolve_expr_type(catalog, arg, ctx, params)?; let (data_type, _) = resolved_type_to_field_shape(catalog, &resolved)?; (data_type, true) diff --git a/crates/omnigraph/src/exec/mod.rs b/crates/omnigraph/src/exec/mod.rs index 59059e3..8171461 100644 --- a/crates/omnigraph/src/exec/mod.rs +++ b/crates/omnigraph/src/exec/mod.rs @@ -25,7 +25,7 @@ use omnigraph_compiler::ir::{ }; use omnigraph_compiler::lower_mutation_query; use omnigraph_compiler::lower_query; -use omnigraph_compiler::query::ast::{CompOp, Literal, NOW_PARAM_NAME}; +use omnigraph_compiler::query::ast::{AggFunc, CompOp, Literal, NOW_PARAM_NAME}; use omnigraph_compiler::query::typecheck::{CheckedQuery, typecheck_query, typecheck_query_decl}; use omnigraph_compiler::result::{MutationResult, QueryResult}; use omnigraph_compiler::types::Direction; diff --git a/crates/omnigraph/src/exec/projection.rs b/crates/omnigraph/src/exec/projection.rs index e889207..73d5250 100644 --- a/crates/omnigraph/src/exec/projection.rs +++ b/crates/omnigraph/src/exec/projection.rs @@ -1,25 +1,14 @@ use super::*; pub(super) fn apply_filter( - bindings: &mut HashMap, + batch: &mut RecordBatch, filter: &IRFilter, params: &ParamMap, ) -> Result<()> { - // Find which binding this filter applies to - let var_name = match &filter.left { - IRExpr::PropAccess { variable, .. 
} => variable.clone(), - _ => return Ok(()), // Can't determine variable - }; - - let batch = bindings.get(&var_name).ok_or_else(|| { - OmniError::manifest(format!("filter references unbound variable '{}'", var_name)) - })?; - let mask = evaluate_filter(batch, filter, params)?; let filtered = arrow_select::filter::filter_record_batch(batch, &mask) .map_err(|e| OmniError::Lance(e.to_string()))?; - - bindings.insert(var_name, filtered); + *batch = filtered; Ok(()) } @@ -59,12 +48,13 @@ fn evaluate_filter( Ok(result) } -/// Evaluate an IR expression against a batch, producing an array. +/// Evaluate an IR expression against a wide batch, producing an array. fn evaluate_expr(batch: &RecordBatch, expr: &IRExpr, params: &ParamMap) -> Result { match expr { - IRExpr::PropAccess { property, .. } => { - batch.column_by_name(property).cloned().ok_or_else(|| { - OmniError::manifest(format!("column '{}' not found in batch", property)) + IRExpr::PropAccess { variable, property } => { + let col_name = format!("{}.{}", variable, property); + batch.column_by_name(&col_name).cloned().ok_or_else(|| { + OmniError::manifest(format!("column '{}' not found in wide batch", col_name)) }) } IRExpr::Literal(lit) => literal_to_array(lit, batch.num_rows()), @@ -307,7 +297,7 @@ fn literal_scalar_type(lit: &Literal) -> Result { /// Project return expressions into a result batch. pub(super) fn project_return( - bindings: &HashMap, + wide_batch: &RecordBatch, projections: &[IRProjection], params: &ParamMap, ) -> Result { @@ -317,11 +307,19 @@ pub(super) fn project_return( )); } + // Route to aggregate path if any projection contains an aggregate + let has_aggregates = projections + .iter() + .any(|p| matches!(&p.expr, IRExpr::Aggregate { .. 
})); + if has_aggregates { + return aggregate_return(wide_batch, projections, params); + } + let mut fields = Vec::with_capacity(projections.len()); let mut columns: Vec = Vec::with_capacity(projections.len()); for proj in projections { - let (name, col) = evaluate_projection(bindings, &proj.expr, params)?; + let (name, col) = evaluate_projection(wide_batch, &proj.expr, params)?; let field_name = proj.alias.as_deref().unwrap_or(&name); fields.push(Field::new( field_name, @@ -335,42 +333,41 @@ pub(super) fn project_return( RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string())) } -/// Evaluate a single projection expression. +/// Evaluate a single projection expression against a wide batch. fn evaluate_projection( - bindings: &HashMap, + wide_batch: &RecordBatch, expr: &IRExpr, params: &ParamMap, ) -> Result<(String, ArrayRef)> { match expr { IRExpr::PropAccess { variable, property } => { - let batch = bindings.get(variable).ok_or_else(|| { + let col_name = format!("{}.{}", variable, property); + let col = wide_batch.column_by_name(&col_name).ok_or_else(|| { OmniError::manifest(format!( - "projection references unbound variable '{}'", - variable + "column '{}' not found in wide batch", + col_name )) })?; - let col = batch.column_by_name(property).ok_or_else(|| { - OmniError::manifest(format!( - "column '{}' not found in binding '{}'", - property, variable - )) - })?; - Ok((format!("{}.{}", variable, property), col.clone())) + Ok((col_name, col.clone())) } IRExpr::Literal(lit) => { - // Get row count from first binding - let num_rows = bindings.values().next().map(|b| b.num_rows()).unwrap_or(0); - let arr = literal_to_array(lit, num_rows)?; + let arr = literal_to_array(lit, wide_batch.num_rows())?; Ok(("literal".to_string(), arr)) } IRExpr::Param(name) => { let lit = params .get(name) .ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?; - let num_rows = bindings.values().next().map(|b| 
b.num_rows()).unwrap_or(0); - let arr = literal_to_array(lit, num_rows)?; + let arr = literal_to_array(lit, wide_batch.num_rows())?; Ok((name.clone(), arr)) } + IRExpr::Variable(name) => { + let col_name = format!("{}.id", name); + let col = wide_batch.column_by_name(&col_name).ok_or_else(|| { + OmniError::manifest(format!("column '{}' not found in wide batch", col_name)) + })?; + Ok((name.clone(), col.clone())) + } _ => Err(OmniError::manifest(format!( "unsupported projection expression: {:?}", expr @@ -378,11 +375,12 @@ fn evaluate_projection( } } -/// Apply ordering to a batch. +/// Apply ordering to a batch. `source` is used to resolve PropAccess columns +/// (typically the wide batch, or the result batch itself for aggregates). pub(super) fn apply_ordering( batch: RecordBatch, orderings: &[IROrdering], - bindings: &HashMap, + source: &RecordBatch, _params: &ParamMap, ) -> Result { use arrow_ord::sort::{SortColumn, lexsort_to_indices}; @@ -392,16 +390,11 @@ pub(super) fn apply_ordering( for ordering in orderings { let col = match &ordering.expr { IRExpr::PropAccess { variable, property } => { - let binding = bindings.get(variable).ok_or_else(|| { - OmniError::manifest(format!( - "ordering references unbound variable '{}'", - variable - )) - })?; - binding - .column_by_name(property) + let col_name = format!("{}.{}", variable, property); + source + .column_by_name(&col_name) .ok_or_else(|| { - OmniError::manifest(format!("column '{}' not found for ordering", property)) + OmniError::manifest(format!("column '{}' not found for ordering", col_name)) })? .clone() } @@ -442,3 +435,294 @@ pub(super) fn apply_ordering( RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string())) } + +// ─── Aggregate execution ─────────────────────────────────────────────────── + +/// Project return expressions that contain aggregates. 
+fn aggregate_return( + wide: &RecordBatch, + projections: &[IRProjection], + params: &ParamMap, +) -> Result { + let num_rows = wide.num_rows(); + + struct GroupKey { + proj_idx: usize, + name: String, + column: ArrayRef, + } + struct AggProj { + proj_idx: usize, + name: String, + func: AggFunc, + arg_column: ArrayRef, + } + + let mut group_keys: Vec = Vec::new(); + let mut agg_projs: Vec = Vec::new(); + + for (i, proj) in projections.iter().enumerate() { + match &proj.expr { + IRExpr::Aggregate { func, arg } => { + let (name, col) = evaluate_projection(wide, arg, params)?; + let alias = proj.alias.as_deref().unwrap_or(&name); + agg_projs.push(AggProj { + proj_idx: i, + name: alias.to_string(), + func: *func, + arg_column: col, + }); + } + _ => { + let (name, col) = evaluate_projection(wide, &proj.expr, params)?; + let alias = proj.alias.as_deref().unwrap_or(&name); + group_keys.push(GroupKey { + proj_idx: i, + name: alias.to_string(), + column: col, + }); + } + } + } + + // Handle empty input: return a single row with count=0, others=null + if num_rows == 0 && group_keys.is_empty() { + return build_empty_aggregate_result(projections); + } + + // Build group assignments + let mut group_map: HashMap = HashMap::new(); + let mut group_indices: Vec> = Vec::new(); + let group_cols: Vec<&ArrayRef> = group_keys.iter().map(|gk| &gk.column).collect(); + + if group_keys.is_empty() { + group_indices.push((0..num_rows).collect()); + } else { + for row in 0..num_rows { + let key = build_group_key(&group_cols, row); + let group_idx = match group_map.get(&key) { + Some(&idx) => idx, + None => { + let idx = group_indices.len(); + group_map.insert(key, idx); + group_indices.push(Vec::new()); + idx + } + }; + group_indices[group_idx].push(row); + } + } + + let num_groups = group_indices.len(); + let mut result_columns: Vec<(usize, String, ArrayRef)> = + Vec::with_capacity(projections.len()); + + for gk in &group_keys { + let first_row_indices: Vec = + 
group_indices.iter().map(|rows| rows[0] as u32).collect(); + let take_idx = UInt32Array::from(first_row_indices); + let col = arrow_select::take::take(gk.column.as_ref(), &take_idx, None) + .map_err(|e| OmniError::Lance(e.to_string()))?; + result_columns.push((gk.proj_idx, gk.name.clone(), col)); + } + + for ap in &agg_projs { + let col = compute_aggregate(ap.func, &ap.arg_column, &group_indices, num_groups)?; + result_columns.push((ap.proj_idx, ap.name.clone(), col)); + } + + result_columns.sort_by_key(|(idx, _, _)| *idx); + + let fields: Vec = result_columns + .iter() + .map(|(_, name, col)| Field::new(name, col.data_type().clone(), true)) + .collect(); + let columns: Vec = result_columns.into_iter().map(|(_, _, col)| col).collect(); + + let schema = Arc::new(Schema::new(fields)); + RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string())) +} + +/// Build a string key for grouping using length-prefixed encoding. +fn build_group_key(group_columns: &[&ArrayRef], row: usize) -> String { + let mut key = String::new(); + for col in group_columns { + if col.is_null(row) { + key.push('N'); + } else { + let val = arrow_cast::display::array_value_to_string(col, row).unwrap_or_default(); + key.push('L'); + key.push_str(&val.len().to_string()); + key.push(':'); + key.push_str(&val); + } + } + key +} + +fn compute_aggregate( + func: AggFunc, + arg: &ArrayRef, + group_indices: &[Vec], + num_groups: usize, +) -> Result { + match func { + AggFunc::Count => { + let mut builder = Int64Builder::with_capacity(num_groups); + for group in group_indices { + let count = group.iter().filter(|&&i| !arg.is_null(i)).count(); + builder.append_value(count as i64); + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + AggFunc::Sum => compute_sum(arg, group_indices, num_groups), + AggFunc::Avg => compute_avg(arg, group_indices, num_groups), + AggFunc::Min => compute_min_max(arg, group_indices, num_groups, true), + AggFunc::Max => compute_min_max(arg, group_indices, 
num_groups, false), + } +} + +fn compute_sum(arg: &ArrayRef, group_indices: &[Vec], num_groups: usize) -> Result { + macro_rules! sum_numeric { + ($arr_type:ty, $arg:expr, $dt:expr) => {{ + let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| { + OmniError::manifest(format!("sum: expected {:?}, got {:?}", $dt, $arg.data_type())) + })?; + let mut builder = Float64Builder::with_capacity(num_groups); + for group in group_indices { + let mut sum = None; + for &i in group { + if !arr.is_null(i) { + *sum.get_or_insert(0.0f64) += arr.value(i) as f64; + } + } + match sum { + Some(v) => builder.append_value(v), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + }}; + } + match arg.data_type() { + dt @ DataType::Int32 => sum_numeric!(Int32Array, arg, dt), + dt @ DataType::Int64 => sum_numeric!(Int64Array, arg, dt), + dt @ DataType::UInt32 => sum_numeric!(UInt32Array, arg, dt), + dt @ DataType::UInt64 => sum_numeric!(UInt64Array, arg, dt), + dt @ DataType::Float32 => sum_numeric!(Float32Array, arg, dt), + dt @ DataType::Float64 => sum_numeric!(Float64Array, arg, dt), + dt => Err(OmniError::manifest(format!("sum: unsupported type {:?}", dt))), + } +} + +fn compute_avg(arg: &ArrayRef, group_indices: &[Vec], num_groups: usize) -> Result { + macro_rules! 
avg_typed { + ($arr_type:ty, $arg:expr) => {{ + let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| { + OmniError::manifest(format!("avg: expected {:?}, got {:?}", stringify!($arr_type), $arg.data_type())) + })?; + let mut builder = Float64Builder::with_capacity(num_groups); + for group in group_indices { + let mut sum = 0.0f64; + let mut count = 0usize; + for &i in group { + if !arr.is_null(i) { sum += arr.value(i) as f64; count += 1; } + } + if count > 0 { builder.append_value(sum / count as f64); } else { builder.append_null(); } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + }}; + } + match arg.data_type() { + DataType::Int32 => avg_typed!(Int32Array, arg), + DataType::Int64 => avg_typed!(Int64Array, arg), + DataType::UInt32 => avg_typed!(UInt32Array, arg), + DataType::UInt64 => avg_typed!(UInt64Array, arg), + DataType::Float32 => avg_typed!(Float32Array, arg), + DataType::Float64 => avg_typed!(Float64Array, arg), + dt => Err(OmniError::manifest(format!("avg: unsupported type {:?}", dt))), + } +} + +fn compute_min_max(arg: &ArrayRef, group_indices: &[Vec], num_groups: usize, is_min: bool) -> Result { + macro_rules! 
minmax_typed { + ($arr_type:ty, $builder_type:ty, $arg:expr, $is_min:expr) => {{ + let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| { + OmniError::manifest(format!("min/max: expected {:?}, got {:?}", stringify!($arr_type), $arg.data_type())) + })?; + let mut builder = <$builder_type>::with_capacity(num_groups); + for group in group_indices { + let mut result = None; + for &i in group { + if !arr.is_null(i) { + let v = arr.value(i); + result = Some(match result { + None => v, + Some(cur) => if $is_min { if v < cur { v } else { cur } } else { if v > cur { v } else { cur } }, + }); + } + } + match result { Some(v) => builder.append_value(v), None => builder.append_null() } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + }}; + } + match arg.data_type() { + DataType::Int32 => minmax_typed!(Int32Array, Int32Builder, arg, is_min), + DataType::Int64 => minmax_typed!(Int64Array, Int64Builder, arg, is_min), + DataType::UInt32 => minmax_typed!(UInt32Array, UInt32Builder, arg, is_min), + DataType::UInt64 => minmax_typed!(UInt64Array, UInt64Builder, arg, is_min), + DataType::Float32 => minmax_typed!(Float32Array, Float32Builder, arg, is_min), + DataType::Float64 => minmax_typed!(Float64Array, Float64Builder, arg, is_min), + DataType::Utf8 => { + let arr = arg.as_any().downcast_ref::().ok_or_else(|| { + OmniError::manifest(format!("min/max: expected Utf8, got {:?}", arg.data_type())) + })?; + let mut builder = StringBuilder::with_capacity(num_groups, num_groups * 16); + for group in group_indices { + let mut result: Option<&str> = None; + for &i in group { + if !arr.is_null(i) { + let v = arr.value(i); + result = Some(match result { + None => v, + Some(cur) => if is_min { if v < cur { v } else { cur } } else { if v > cur { v } else { cur } }, + }); + } + } + match result { Some(v) => builder.append_value(v), None => builder.append_null() } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + dt => Err(OmniError::manifest(format!("min/max: unsupported type 
{:?}", dt))), + } +} + +/// Build a single-row result for an aggregate query with zero input rows and no group keys. +fn build_empty_aggregate_result(projections: &[IRProjection]) -> Result { + let mut fields = Vec::with_capacity(projections.len()); + let mut columns: Vec = Vec::with_capacity(projections.len()); + + for proj in projections { + let name = proj.alias.as_deref().unwrap_or("?"); + match &proj.expr { + IRExpr::Aggregate { func, .. } => match func { + AggFunc::Count => { + fields.push(Field::new(name, DataType::Int64, true)); + columns.push(Arc::new(Int64Array::from(vec![0i64])) as ArrayRef); + } + _ => { + fields.push(Field::new(name, DataType::Float64, true)); + columns.push(Arc::new(Float64Array::from(vec![None as Option])) as ArrayRef); + } + }, + _ => { + fields.push(Field::new(name, DataType::Utf8, true)); + columns.push(Arc::new(StringArray::from(vec![None as Option<&str>])) as ArrayRef); + } + } + } + + let schema = Arc::new(Schema::new(fields)); + RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string())) +} diff --git a/crates/omnigraph/src/exec/query.rs b/crates/omnigraph/src/exec/query.rs index 5eb1bb1..dc95a53 100644 --- a/crates/omnigraph/src/exec/query.rs +++ b/crates/omnigraph/src/exec/query.rs @@ -358,25 +358,21 @@ pub async fn execute_query( return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await; } - let mut bindings: HashMap = HashMap::new(); - - execute_pipeline( - &ir.pipeline, - params, - snapshot, - graph_index, - catalog, - &mut bindings, - &search_mode, - ) - .await?; + let mut wide: Option = None; + execute_pipeline(&ir.pipeline, params, snapshot, graph_index, catalog, &mut wide, &search_mode).await?; + let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty()))); // Project return expressions - let mut result_batch = project_return(&bindings, &ir.return_exprs, params)?; + let has_aggregates = ir.return_exprs.iter().any(|p| matches!(&p.expr, 
IRExpr::Aggregate { .. })); + let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?; // Apply ordering (skip if search mode already ordered the results) if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) { - result_batch = apply_ordering(result_batch, &ir.order_by, &bindings, params)?; + result_batch = if has_aggregates { + apply_ordering(result_batch.clone(), &ir.order_by, &result_batch, params)? + } else { + apply_ordering(result_batch, &ir.order_by, &wide_batch, params)? + }; } // Apply limit @@ -403,27 +399,27 @@ async fn execute_rrf_query( rrf: &RrfMode, ) -> Result { // Execute primary search - let mut primary_bindings: HashMap = HashMap::new(); + let mut primary_wide: Option = None; execute_pipeline( &ir.pipeline, params, snapshot, graph_index, catalog, - &mut primary_bindings, + &mut primary_wide, &rrf.primary, ) .await?; // Execute secondary search - let mut secondary_bindings: HashMap = HashMap::new(); + let mut secondary_wide: Option = None; execute_pipeline( &ir.pipeline, params, snapshot, graph_index, catalog, - &mut secondary_bindings, + &mut secondary_wide, &rrf.secondary, ) .await?; @@ -438,13 +434,13 @@ async fn execute_rrf_query( .or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str())) .ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?; - let primary_batch = primary_bindings.get(primary_var).ok_or_else(|| { + let primary_batch = primary_wide.as_ref().ok_or_else(|| { OmniError::manifest(format!( "rrf primary variable '{}' not in bindings", primary_var )) })?; - let secondary_batch = secondary_bindings.get(primary_var).ok_or_else(|| { + let secondary_batch = secondary_wide.as_ref().ok_or_else(|| { OmniError::manifest(format!( "rrf secondary variable '{}' not in bindings", primary_var @@ -452,8 +448,9 @@ async fn execute_rrf_query( })?; // Build ID → rank maps - let primary_ids = extract_id_column(primary_batch)?; - let secondary_ids = 
extract_id_column(secondary_batch)?; + let id_col_name = format!("{}.id", primary_var); + let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?; + let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?; let mut primary_rank: HashMap = HashMap::new(); for (i, id) in primary_ids.iter().enumerate() { @@ -510,24 +507,21 @@ async fn execute_rrf_query( // Reconstruct a combined batch for the binding in winning order let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?; - // Replace the binding and project - let mut fused_bindings = primary_bindings; - fused_bindings.insert(primary_var.to_string(), fused_batch); - - let result_batch = project_return(&fused_bindings, &ir.return_exprs, params)?; + // Project directly from fused batch + let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?; // Already ordered by RRF score + already limited Ok(QueryResult::new(result_batch.schema(), vec![result_batch])) } -fn extract_id_column(batch: &RecordBatch) -> Result> { +fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result> { let col = batch - .column_by_name("id") - .ok_or_else(|| OmniError::manifest("batch missing 'id' column for RRF".to_string()))?; + .column_by_name(col_name) + .ok_or_else(|| OmniError::manifest(format!("batch missing '{}' column for RRF", col_name)))?; let ids = col .as_any() .downcast_ref::() - .ok_or_else(|| OmniError::manifest("'id' column is not Utf8".to_string()))?; + .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", col_name)))?; Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect()) } @@ -585,7 +579,7 @@ fn execute_pipeline<'a>( snapshot: &'a Snapshot, graph_index: Option<&'a GraphIndex>, catalog: &'a Catalog, - bindings: &'a mut HashMap, + wide: &'a mut Option, search_mode: &'a SearchMode, ) -> std::pin::Pin> + Send + 'a>> { Box::pin(async move { @@ -632,10 +626,16 @@ fn execute_pipeline<'a>( 
search_mode, ) .await?; - bindings.insert(variable.clone(), batch); + let prefixed = prefix_batch(&batch, variable)?; + *wide = Some(match wide.take() { + None => prefixed, + Some(existing) => cross_join_batches(&existing, &prefixed)?, + }); } IROp::Filter(filter) => { - apply_filter(bindings, filter, params)?; + if let Some(batch) = wide.as_mut() { + apply_filter(batch, filter, params)?; + } } IROp::Expand { src_var, @@ -649,17 +649,20 @@ fn execute_pipeline<'a>( let gi = graph_index.ok_or_else(|| { OmniError::manifest("graph index required for traversal".to_string()) })?; - let batch = execute_expand( - bindings, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction, - dst_type, *min_hops, *max_hops, - ) - .await?; - bindings.insert(dst_var.clone(), batch); + if let Some(batch) = wide.as_mut() { + execute_expand( + batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction, + dst_type, *min_hops, *max_hops, + ) + .await?; + } } IROp::AntiJoin { outer_var, inner } => { let gi = graph_index; - execute_anti_join(bindings, inner, params, snapshot, gi, catalog, outer_var) - .await?; + if let Some(batch) = wide.as_mut() { + execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var) + .await?; + } } } } @@ -669,28 +672,26 @@ fn execute_pipeline<'a>( /// Execute a graph traversal (Expand). async fn execute_expand( - bindings: &HashMap, + wide: &mut RecordBatch, graph_index: &GraphIndex, snapshot: &Snapshot, catalog: &Catalog, src_var: &str, - _dst_var: &str, + dst_var: &str, edge_type: &str, direction: Direction, dst_type: &str, min_hops: u32, max_hops: Option, -) -> Result { - let src_batch = bindings.get(src_var).ok_or_else(|| { - OmniError::manifest(format!("expand references unbound variable '{}'", src_var)) - })?; - - let src_ids = src_batch - .column_by_name("id") - .ok_or_else(|| OmniError::manifest("source batch missing 'id' column".to_string()))? 
+) -> Result<()> { + let src_id_col_name = format!("{}.id", src_var); + let src_ids = wide + .column_by_name(&src_id_col_name) + .ok_or_else(|| OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name)))? .as_any() .downcast_ref::() - .ok_or_else(|| OmniError::manifest("source 'id' column is not Utf8".to_string()))?; + .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))? + .clone(); // Determine which type index to use for source and destination let edge_def = catalog @@ -720,8 +721,9 @@ async fn execute_expand( let same_type = src_type_name == dst_type_name; - // BFS to collect reachable destination dense IDs - let mut result_dst_ids: Vec = Vec::new(); + // BFS to collect (src_row_idx, dst_id) pairs with per-source dedup + let mut src_indices: Vec = Vec::new(); + let mut dst_id_list: Vec = Vec::new(); for i in 0..src_ids.len() { let src_id = src_ids.value(i); let Some(src_dense) = src_type_idx.to_dense(src_id) else { @@ -749,7 +751,8 @@ async fn execute_expand( if let Some(dst_id) = dst_type_idx.to_id(neighbor) { let dst_id = dst_id.to_string(); if seen_dst_ids.insert(dst_id.clone()) { - result_dst_ids.push(dst_id); + src_indices.push(i as u32); + dst_id_list.push(dst_id); } } } @@ -764,7 +767,36 @@ async fn execute_expand( } // Hydrate destination nodes from the snapshot - hydrate_nodes(snapshot, catalog, dst_type, &result_dst_ids).await + let dst_batch = hydrate_nodes(snapshot, catalog, dst_type, &dst_id_list).await?; + + // Build a mapping from dst_id to row index in dst_batch + let dst_batch_id_col = dst_batch + .column_by_name("id") + .ok_or_else(|| OmniError::manifest("hydrated batch missing 'id' column".to_string()))? 
+ .as_any() + .downcast_ref::() + .ok_or_else(|| OmniError::manifest("hydrated 'id' column is not Utf8".to_string()))?; + let dst_id_to_row: HashMap<&str, usize> = (0..dst_batch_id_col.len()) + .map(|i| (dst_batch_id_col.value(i), i)) + .collect(); + + // Build aligned src/dst index arrays (only for IDs that exist in hydrated batch) + let mut final_src_indices: Vec = Vec::new(); + let mut dst_indices: Vec = Vec::new(); + for (src_idx, dst_id) in src_indices.iter().zip(dst_id_list.iter()) { + if let Some(&dst_row) = dst_id_to_row.get(dst_id.as_str()) { + final_src_indices.push(*src_idx); + dst_indices.push(dst_row as u32); + } + } + + let src_take = UInt32Array::from(final_src_indices); + let dst_take = UInt32Array::from(dst_indices); + let expanded_wide = take_batch(wide, &src_take)?; + let dst_prefixed = prefix_batch(&dst_batch, dst_var)?; + let aligned_dst = take_batch(&dst_prefixed, &dst_take)?; + *wide = hconcat_batches(&expanded_wide, &aligned_dst)?; + Ok(()) } /// Load full node rows for a set of IDs from a snapshot. @@ -829,15 +861,15 @@ async fn hydrate_nodes( Ok(scan_result) } -/// Try bulk anti-join via CSR existence check. Returns Some if the inner +/// Try bulk anti-join via CSR existence check. Returns Some(mask) if the inner /// pipeline is a single Expand from outer_var (the common negation pattern). -fn try_bulk_anti_join( - outer_batch: &RecordBatch, +fn try_bulk_anti_join_mask( + wide: &RecordBatch, inner_pipeline: &[IROp], graph_index: Option<&GraphIndex>, catalog: &Catalog, outer_var: &str, -) -> Option> { +) -> Option { if inner_pipeline.len() != 1 { return None; } @@ -866,8 +898,9 @@ fn try_bulk_anti_join( }?; let type_idx = gi.type_index(src_type_name)?; - let outer_ids = outer_batch - .column_by_name("id")? + let id_col_name = format!("{}.id", outer_var); + let outer_ids = wide + .column_by_name(&id_col_name)? 
.as_any() .downcast_ref::()?; @@ -881,16 +914,12 @@ fn try_bulk_anti_join( }) .collect(); - let mask = BooleanArray::from(keep_mask); - Some( - arrow_select::filter::filter_record_batch(outer_batch, &mask) - .map_err(|e| OmniError::Lance(e.to_string())), - ) + Some(BooleanArray::from(keep_mask)) } -/// Execute an AntiJoin: remove rows from outer_var where the inner pipeline finds matches. +/// Execute an AntiJoin: remove rows from wide batch where the inner pipeline finds matches. async fn execute_anti_join( - bindings: &mut HashMap, + wide: &mut RecordBatch, inner_pipeline: &[IROp], params: &ParamMap, snapshot: &Snapshot, @@ -898,35 +927,22 @@ async fn execute_anti_join( catalog: &Catalog, outer_var: &str, ) -> Result<()> { - let outer_batch = bindings.get(outer_var).ok_or_else(|| { - OmniError::manifest(format!( - "anti-join references unbound variable '{}'", - outer_var - )) - })?; - // Fast path: bulk CSR existence check (O(N), zero Lance I/O) - if let Some(result) = - try_bulk_anti_join(outer_batch, inner_pipeline, graph_index, catalog, outer_var) + if let Some(mask) = + try_bulk_anti_join_mask(wide, inner_pipeline, graph_index, catalog, outer_var) { - bindings.insert(outer_var.to_string(), result?); + *wide = arrow_select::filter::filter_record_batch(wide, &mask) + .map_err(|e| OmniError::Lance(e.to_string()))?; return Ok(()); } // Slow path: per-row inner pipeline execution - let outer_ids = outer_batch - .column_by_name("id") - .ok_or_else(|| OmniError::manifest("outer batch missing 'id' column".to_string()))? 
- .as_any() - .downcast_ref::() - .ok_or_else(|| OmniError::manifest("outer 'id' column is not Utf8".to_string()))?; + let num_rows = wide.num_rows(); + let mut keep_mask = vec![true; num_rows]; - let mut keep_mask = vec![true; outer_batch.num_rows()]; - - for i in 0..outer_ids.len() { - let single_row = outer_batch.slice(i, 1); - let mut inner_bindings: HashMap = HashMap::new(); - inner_bindings.insert(outer_var.to_string(), single_row); + for i in 0..num_rows { + let single_row = wide.slice(i, 1); + let mut inner_wide: Option = Some(single_row); let no_search = SearchMode::default(); execute_pipeline( @@ -935,15 +951,15 @@ async fn execute_anti_join( snapshot, graph_index, catalog, - &mut inner_bindings, + &mut inner_wide, &no_search, ) .await?; - let has_match = inner_bindings - .iter() - .filter(|(k, _)| *k != outer_var) - .any(|(_, batch)| batch.num_rows() > 0); + let has_match = inner_wide + .as_ref() + .map(|batch| batch.num_rows() > 0) + .unwrap_or(false); if has_match { keep_mask[i] = false; @@ -951,10 +967,8 @@ async fn execute_anti_join( } let mask = BooleanArray::from(keep_mask); - let filtered = arrow_select::filter::filter_record_batch(outer_batch, &mask) + *wide = arrow_select::filter::filter_record_batch(wide, &mask) .map_err(|e| OmniError::Lance(e.to_string()))?; - - bindings.insert(outer_var.to_string(), filtered); Ok(()) } @@ -1220,3 +1234,50 @@ pub(super) fn literal_to_sql(lit: &Literal) -> String { Literal::List(_) => "NULL".to_string(), // Not supported in SQL pushdown } } + +fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result { + let fields: Vec = batch.schema().fields().iter().map(|f| { + Field::new(format!("{}.{}", variable, f.name()), f.data_type().clone(), f.is_nullable()) + }).collect(); + let schema = Arc::new(Schema::new(fields)); + RecordBatch::try_new(schema, batch.columns().to_vec()).map_err(|e| OmniError::Lance(e.to_string())) +} + +fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result { + let n = 
left.num_rows(); + let m = right.num_rows(); + if n == 0 || m == 0 { + let mut fields: Vec = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect(); + fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone())); + return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields)))); + } + let left_indices: Vec = (0..n as u32).flat_map(|i| std::iter::repeat(i).take(m)).collect(); + let right_indices: Vec = (0..n).flat_map(|_| 0..m as u32).collect(); + let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?; + let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?; + hconcat_batches(&left_expanded, &right_expanded) +} + +fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result { + let mut fields: Vec = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect(); + if cfg!(debug_assertions) { + let left_schema = left.schema(); + let left_names: HashSet<&str> = left_schema.fields().iter().map(|f| f.name().as_str()).collect(); + let right_schema = right.schema(); + for f in right_schema.fields() { + debug_assert!(!left_names.contains(f.name().as_str()), "hconcat_batches: duplicate column '{}'", f.name()); + } + } + fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone())); + let mut columns: Vec = left.columns().to_vec(); + columns.extend(right.columns().to_vec()); + RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).map_err(|e| OmniError::Lance(e.to_string())) +} + +fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result { + let columns: Vec = batch.columns().iter() + .map(|col| arrow_select::take::take(col.as_ref(), indices, None)) + .collect::, _>>() + .map_err(|e| OmniError::Lance(e.to_string()))?; + RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string())) +} diff --git a/crates/omnigraph/tests/aggregation.rs b/crates/omnigraph/tests/aggregation.rs new file mode 100644 index 0000000..5bf30ee --- /dev/null +++ 
b/crates/omnigraph/tests/aggregation.rs @@ -0,0 +1,302 @@ +mod helpers; + +use arrow_array::{Array, Float64Array, Int32Array, Int64Array, StringArray}; + +use omnigraph::db::Omnigraph; +use omnigraph::loader::{LoadMode, load_jsonl}; +use omnigraph_compiler::ir::ParamMap; + +use helpers::*; + +// ─── Count aggregate with GROUP BY ───────────────────────────────────────── + +#[tokio::test] +async fn friend_counts_grouped_by_person() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Test data: Alice knows Bob, Alice knows Charlie, Bob knows Diana + // So: Alice=2 friends, Bob=1 friend (Charlie & Diana have no outgoing knows → dropped) + let result = query_main(&mut db, TEST_QUERIES, "friend_counts", &ParamMap::new()) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 2); + + let names = batch + .column_by_name("p.name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let counts = batch + .column_by_name("friends") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // Ordered by friends desc: Alice(2), Bob(1) + assert_eq!(names.value(0), "Alice"); + assert_eq!(counts.value(0), 2); + assert_eq!(names.value(1), "Bob"); + assert_eq!(counts.value(1), 1); +} + +// ─── Global count (no group key) ─────────────────────────────────────────── + +#[tokio::test] +async fn total_people_global_count() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let result = query_main(&mut db, TEST_QUERIES, "total_people", &ParamMap::new()) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + + let total = batch + .column_by_name("total") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(total.value(0), 4); +} + +// ─── Sum/Avg/Min/Max aggregates ──────────────────────────────────────────── + +#[tokio::test] +async fn age_stats_per_company() { + let dir = 
tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Test data: Alice(30) worksAt Acme, Bob(25) worksAt Globex + let result = query_main(&mut db, TEST_QUERIES, "age_stats", &ParamMap::new()) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 2); + + let names = batch + .column_by_name("c.name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // Find the Acme row (order not guaranteed without ORDER BY) + let acme_row = (0..batch.num_rows()) + .find(|&i| names.value(i) == "Acme") + .expect("Acme should be in results"); + let globex_row = (0..batch.num_rows()) + .find(|&i| names.value(i) == "Globex") + .expect("Globex should be in results"); + + let headcount = batch + .column_by_name("headcount") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(headcount.value(acme_row), 1); + assert_eq!(headcount.value(globex_row), 1); + + let avg_age = batch + .column_by_name("avg_age") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert!((avg_age.value(acme_row) - 30.0).abs() < 0.01); + assert!((avg_age.value(globex_row) - 25.0).abs() < 0.01); + + let youngest = batch + .column_by_name("youngest") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(youngest.value(acme_row), 30); + assert_eq!(youngest.value(globex_row), 25); + + let total_age = batch + .column_by_name("total_age") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert!((total_age.value(acme_row) - 30.0).abs() < 0.01); + assert!((total_age.value(globex_row) - 25.0).abs() < 0.01); +} + +// ─── Aggregate + order + limit ───────────────────────────────────────────── + +#[tokio::test] +async fn top_connected_person() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Alice has 2 friends (most connected) + let result = query_main(&mut db, TEST_QUERIES, "top_connected", &ParamMap::new()) + .await + .unwrap(); + + let batch = 
result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + + let names = batch + .column_by_name("p.name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let counts = batch + .column_by_name("friends") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(names.value(0), "Alice"); + assert_eq!(counts.value(0), 2); +} + +// ─── Inline aggregate query (not from fixture) ──────────────────────────── + +#[tokio::test] +async fn inline_count_with_no_friends() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Diana has no outgoing Knows edges, so she won't appear in friend_counts. + // But Alice(2) and Bob(1) do. Verify the count matches. + let queries = r#" +query fc() { + match { + $p: Person + $p knows $f + } + return { $p.name, count($f) as friends } + order { friends desc } +} +"#; + let result = query_main(&mut db, queries, "fc", &ParamMap::new()) + .await + .unwrap(); + let batch = result.concat_batches().unwrap(); + // Only Alice and Bob have outgoing Knows edges + assert_eq!(batch.num_rows(), 2); +} + +#[tokio::test] +async fn inline_global_count_empty() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap(); + + // Load only nodes, no edges + let data = r#"{"type": "Person", "data": {"name": "Alice", "age": 30}} +{"type": "Company", "data": {"name": "Acme"}}"#; + load_jsonl(&mut db, data, LoadMode::Overwrite) + .await + .unwrap(); + + // Global count — should return 1 row with count=1 + let queries = r#" +query tc() { + match { $p: Person } + return { count($p) as total } +} +"#; + let result = query_main(&mut db, queries, "tc", &ParamMap::new()) + .await + .unwrap(); + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + let total = batch + .column_by_name("total") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(total.value(0), 1); 
+} + +#[tokio::test] +async fn inline_min_max_string() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let queries = r#" +query name_range() { + match { $p: Person } + return { + min($p.name) as first_name + max($p.name) as last_name + } +} +"#; + let result = query_main(&mut db, queries, "name_range", &ParamMap::new()) + .await + .unwrap(); + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + + let first = batch + .column_by_name("first_name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let last = batch + .column_by_name("last_name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(first.value(0), "Alice"); + assert_eq!(last.value(0), "Diana"); +} + +#[tokio::test] +async fn inline_multi_hop_aggregate() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Friends of friends count per person + let queries = r#" +query fof_counts($name: String) { + match { + $p: Person { name: $name } + $p knows $mid + $mid knows $fof + } + return { $p.name, count($fof) as fof_count } +} +"#; + // Alice → Bob, Charlie. Bob → Diana. Charlie → nobody. 
+ // So Alice's fof = [Diana] → count = 1 + let result = query_main( + &mut db, + queries, + "fof_counts", + ¶ms(&[("$name", "Alice")]), + ) + .await + .unwrap(); + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + + let count = batch + .column_by_name("fof_count") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(count.value(0), 1); +} diff --git a/crates/omnigraph/tests/fixtures/test.gq b/crates/omnigraph/tests/fixtures/test.gq index daf03ed..08944fc 100644 --- a/crates/omnigraph/tests/fixtures/test.gq +++ b/crates/omnigraph/tests/fixtures/test.gq @@ -76,3 +76,38 @@ query top_by_age() { order { $p.age desc } limit 2 } + +// Aggregation: global count (no group key) +query total_people() { + match { + $p: Person + } + return { count($p) as total } +} + +// Aggregation: sum/avg/min/max of age per company +query age_stats() { + match { + $p: Person + $p worksAt $c + } + return { + $c.name + sum($p.age) as total_age + avg($p.age) as avg_age + min($p.age) as youngest + max($p.age) as oldest + count($p) as headcount + } +} + +// Aggregation: most connected person +query top_connected() { + match { + $p: Person + $p knows $f + } + return { $p.name, count($f) as friends } + order { friends desc } + limit 1 +}