Implement aggregate execution with wide-batch model

Add runtime support for aggregate functions (count, sum, avg, min, max)
with GROUP BY semantics, built on a single wide RecordBatch that
eliminates correlation tracking by construction.

Execution engine (exec/query.rs):
- Replace HashMap<String, RecordBatch> with Option<RecordBatch> where
  columns are prefixed as <variable>.<property>
- NodeScan prefixes columns and cross-joins with existing batch
- Expand collects (src_row, dst_id) pairs, takes wide batch rows,
  appends prefixed destination columns via hconcat
- Filter applies single mask to entire wide batch
- AntiJoin: fast-path returns BooleanArray mask; slow-path slices
  one row for inner pipeline execution

Projection engine (exec/projection.rs):
- aggregate_return groups rows by non-aggregate key columns using
  length-prefixed string encoding, computes per-group aggregates
- SUM accumulates into f64 to avoid integer overflow
- MIN/MAX support both numeric and string types
- Empty input returns count=0, others=null

Compiler (typecheck.rs):
- T8: split MIN/MAX from SUM/AVG — allow string arguments
- T9: non-aggregate expressions in aggregate queries must be
  property accesses or variables
- SUM type inference returns Float64 (matching runtime)

Tests: 8 new integration tests covering grouped count, global count,
sum/avg/min/max per company, aggregate+order+limit, string min/max,
multi-hop aggregates, and edge cases.

https://claude.ai/code/session_019o5NRyYomgETFyd7hpiLey
This commit is contained in:
Claude 2026-04-12 20:59:13 +00:00
parent 34cc2ccf3a
commit 351610d18c
No known key found for this signature in database
6 changed files with 868 additions and 152 deletions

View file

@ -189,6 +189,29 @@ fn typecheck_read_query(catalog: &Catalog, query: &QueryDecl) -> Result<TypeCont
));
}
// T9: If any return expression is an aggregate, non-aggregate expressions
// must be valid group-by keys (PropAccess or Variable).
let has_agg = query
.return_clause
.iter()
.any(|p| matches!(p.expr, Expr::Aggregate { .. }));
if has_agg {
for proj in &query.return_clause {
if !matches!(proj.expr, Expr::Aggregate { .. }) {
match &proj.expr {
Expr::PropAccess { .. } | Expr::Variable(_) => {}
_ => {
return Err(NanoError::Type(
"T9: non-aggregate expressions in an aggregate query must be \
property accesses or variables"
.to_string(),
));
}
}
}
}
}
Ok(ctx)
}
@ -1298,9 +1321,9 @@ fn resolve_expr_type(
Expr::Aggregate { func, arg } => {
let arg_type = resolve_expr_type(catalog, arg, ctx, params)?;
// T8: sum/avg/min/max require numeric
// T8: sum/avg require numeric; min/max require numeric or string
match func {
AggFunc::Sum | AggFunc::Avg | AggFunc::Min | AggFunc::Max => {
AggFunc::Sum | AggFunc::Avg => {
if let ResolvedType::Scalar(s) = &arg_type
&& (s.list || !s.scalar.is_numeric())
{
@ -1311,6 +1334,17 @@ fn resolve_expr_type(
)));
}
}
AggFunc::Min | AggFunc::Max => {
if let ResolvedType::Scalar(s) = &arg_type
&& (s.list || (!s.scalar.is_numeric() && s.scalar != ScalarType::String))
{
return Err(NanoError::Type(format!(
"T8: {} requires numeric or string type, got {}",
func,
s.display_name()
)));
}
}
_ => {} // count works on any type
}
@ -1340,8 +1374,8 @@ fn infer_projection_field(
Expr::Aggregate { func, arg } => {
let (data_type, nullable) = match func {
AggFunc::Count => (DataType::Int64, true),
AggFunc::Avg => (DataType::Float64, true),
_ => {
AggFunc::Avg | AggFunc::Sum => (DataType::Float64, true),
AggFunc::Min | AggFunc::Max => {
let resolved = resolve_expr_type(catalog, arg, ctx, params)?;
let (data_type, _) = resolved_type_to_field_shape(catalog, &resolved)?;
(data_type, true)

View file

@ -25,7 +25,7 @@ use omnigraph_compiler::ir::{
};
use omnigraph_compiler::lower_mutation_query;
use omnigraph_compiler::lower_query;
use omnigraph_compiler::query::ast::{CompOp, Literal, NOW_PARAM_NAME};
use omnigraph_compiler::query::ast::{AggFunc, CompOp, Literal, NOW_PARAM_NAME};
use omnigraph_compiler::query::typecheck::{CheckedQuery, typecheck_query, typecheck_query_decl};
use omnigraph_compiler::result::{MutationResult, QueryResult};
use omnigraph_compiler::types::Direction;

View file

@ -1,25 +1,14 @@
use super::*;
pub(super) fn apply_filter(
bindings: &mut HashMap<String, RecordBatch>,
batch: &mut RecordBatch,
filter: &IRFilter,
params: &ParamMap,
) -> Result<()> {
// Find which binding this filter applies to
let var_name = match &filter.left {
IRExpr::PropAccess { variable, .. } => variable.clone(),
_ => return Ok(()), // Can't determine variable
};
let batch = bindings.get(&var_name).ok_or_else(|| {
OmniError::manifest(format!("filter references unbound variable '{}'", var_name))
})?;
let mask = evaluate_filter(batch, filter, params)?;
let filtered = arrow_select::filter::filter_record_batch(batch, &mask)
.map_err(|e| OmniError::Lance(e.to_string()))?;
bindings.insert(var_name, filtered);
*batch = filtered;
Ok(())
}
@ -59,12 +48,13 @@ fn evaluate_filter(
Ok(result)
}
/// Evaluate an IR expression against a batch, producing an array.
/// Evaluate an IR expression against a wide batch, producing an array.
fn evaluate_expr(batch: &RecordBatch, expr: &IRExpr, params: &ParamMap) -> Result<ArrayRef> {
match expr {
IRExpr::PropAccess { property, .. } => {
batch.column_by_name(property).cloned().ok_or_else(|| {
OmniError::manifest(format!("column '{}' not found in batch", property))
IRExpr::PropAccess { variable, property } => {
let col_name = format!("{}.{}", variable, property);
batch.column_by_name(&col_name).cloned().ok_or_else(|| {
OmniError::manifest(format!("column '{}' not found in wide batch", col_name))
})
}
IRExpr::Literal(lit) => literal_to_array(lit, batch.num_rows()),
@ -307,7 +297,7 @@ fn literal_scalar_type(lit: &Literal) -> Result<ScalarType> {
/// Project return expressions into a result batch.
pub(super) fn project_return(
bindings: &HashMap<String, RecordBatch>,
wide_batch: &RecordBatch,
projections: &[IRProjection],
params: &ParamMap,
) -> Result<RecordBatch> {
@ -317,11 +307,19 @@ pub(super) fn project_return(
));
}
// Route to aggregate path if any projection contains an aggregate
let has_aggregates = projections
.iter()
.any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
if has_aggregates {
return aggregate_return(wide_batch, projections, params);
}
let mut fields = Vec::with_capacity(projections.len());
let mut columns: Vec<ArrayRef> = Vec::with_capacity(projections.len());
for proj in projections {
let (name, col) = evaluate_projection(bindings, &proj.expr, params)?;
let (name, col) = evaluate_projection(wide_batch, &proj.expr, params)?;
let field_name = proj.alias.as_deref().unwrap_or(&name);
fields.push(Field::new(
field_name,
@ -335,42 +333,41 @@ pub(super) fn project_return(
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
}
/// Evaluate a single projection expression.
/// Evaluate a single projection expression against a wide batch.
fn evaluate_projection(
bindings: &HashMap<String, RecordBatch>,
wide_batch: &RecordBatch,
expr: &IRExpr,
params: &ParamMap,
) -> Result<(String, ArrayRef)> {
match expr {
IRExpr::PropAccess { variable, property } => {
let batch = bindings.get(variable).ok_or_else(|| {
let col_name = format!("{}.{}", variable, property);
let col = wide_batch.column_by_name(&col_name).ok_or_else(|| {
OmniError::manifest(format!(
"projection references unbound variable '{}'",
variable
"column '{}' not found in wide batch",
col_name
))
})?;
let col = batch.column_by_name(property).ok_or_else(|| {
OmniError::manifest(format!(
"column '{}' not found in binding '{}'",
property, variable
))
})?;
Ok((format!("{}.{}", variable, property), col.clone()))
Ok((col_name, col.clone()))
}
IRExpr::Literal(lit) => {
// Get row count from first binding
let num_rows = bindings.values().next().map(|b| b.num_rows()).unwrap_or(0);
let arr = literal_to_array(lit, num_rows)?;
let arr = literal_to_array(lit, wide_batch.num_rows())?;
Ok(("literal".to_string(), arr))
}
IRExpr::Param(name) => {
let lit = params
.get(name)
.ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?;
let num_rows = bindings.values().next().map(|b| b.num_rows()).unwrap_or(0);
let arr = literal_to_array(lit, num_rows)?;
let arr = literal_to_array(lit, wide_batch.num_rows())?;
Ok((name.clone(), arr))
}
IRExpr::Variable(name) => {
let col_name = format!("{}.id", name);
let col = wide_batch.column_by_name(&col_name).ok_or_else(|| {
OmniError::manifest(format!("column '{}' not found in wide batch", col_name))
})?;
Ok((name.clone(), col.clone()))
}
_ => Err(OmniError::manifest(format!(
"unsupported projection expression: {:?}",
expr
@ -378,11 +375,12 @@ fn evaluate_projection(
}
}
/// Apply ordering to a batch.
/// Apply ordering to a batch. `source` is used to resolve PropAccess columns
/// (typically the wide batch, or the result batch itself for aggregates).
pub(super) fn apply_ordering(
batch: RecordBatch,
orderings: &[IROrdering],
bindings: &HashMap<String, RecordBatch>,
source: &RecordBatch,
_params: &ParamMap,
) -> Result<RecordBatch> {
use arrow_ord::sort::{SortColumn, lexsort_to_indices};
@ -392,16 +390,11 @@ pub(super) fn apply_ordering(
for ordering in orderings {
let col = match &ordering.expr {
IRExpr::PropAccess { variable, property } => {
let binding = bindings.get(variable).ok_or_else(|| {
OmniError::manifest(format!(
"ordering references unbound variable '{}'",
variable
))
})?;
binding
.column_by_name(property)
let col_name = format!("{}.{}", variable, property);
source
.column_by_name(&col_name)
.ok_or_else(|| {
OmniError::manifest(format!("column '{}' not found for ordering", property))
OmniError::manifest(format!("column '{}' not found for ordering", col_name))
})?
.clone()
}
@ -442,3 +435,294 @@ pub(super) fn apply_ordering(
RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
}
// ─── Aggregate execution ───────────────────────────────────────────────────
/// Project return expressions that contain aggregates.
fn aggregate_return(
wide: &RecordBatch,
projections: &[IRProjection],
params: &ParamMap,
) -> Result<RecordBatch> {
let num_rows = wide.num_rows();
struct GroupKey {
proj_idx: usize,
name: String,
column: ArrayRef,
}
struct AggProj {
proj_idx: usize,
name: String,
func: AggFunc,
arg_column: ArrayRef,
}
let mut group_keys: Vec<GroupKey> = Vec::new();
let mut agg_projs: Vec<AggProj> = Vec::new();
for (i, proj) in projections.iter().enumerate() {
match &proj.expr {
IRExpr::Aggregate { func, arg } => {
let (name, col) = evaluate_projection(wide, arg, params)?;
let alias = proj.alias.as_deref().unwrap_or(&name);
agg_projs.push(AggProj {
proj_idx: i,
name: alias.to_string(),
func: *func,
arg_column: col,
});
}
_ => {
let (name, col) = evaluate_projection(wide, &proj.expr, params)?;
let alias = proj.alias.as_deref().unwrap_or(&name);
group_keys.push(GroupKey {
proj_idx: i,
name: alias.to_string(),
column: col,
});
}
}
}
// Handle empty input: return a single row with count=0, others=null
if num_rows == 0 && group_keys.is_empty() {
return build_empty_aggregate_result(projections);
}
// Build group assignments
let mut group_map: HashMap<String, usize> = HashMap::new();
let mut group_indices: Vec<Vec<usize>> = Vec::new();
let group_cols: Vec<&ArrayRef> = group_keys.iter().map(|gk| &gk.column).collect();
if group_keys.is_empty() {
group_indices.push((0..num_rows).collect());
} else {
for row in 0..num_rows {
let key = build_group_key(&group_cols, row);
let group_idx = match group_map.get(&key) {
Some(&idx) => idx,
None => {
let idx = group_indices.len();
group_map.insert(key, idx);
group_indices.push(Vec::new());
idx
}
};
group_indices[group_idx].push(row);
}
}
let num_groups = group_indices.len();
let mut result_columns: Vec<(usize, String, ArrayRef)> =
Vec::with_capacity(projections.len());
for gk in &group_keys {
let first_row_indices: Vec<u32> =
group_indices.iter().map(|rows| rows[0] as u32).collect();
let take_idx = UInt32Array::from(first_row_indices);
let col = arrow_select::take::take(gk.column.as_ref(), &take_idx, None)
.map_err(|e| OmniError::Lance(e.to_string()))?;
result_columns.push((gk.proj_idx, gk.name.clone(), col));
}
for ap in &agg_projs {
let col = compute_aggregate(ap.func, &ap.arg_column, &group_indices, num_groups)?;
result_columns.push((ap.proj_idx, ap.name.clone(), col));
}
result_columns.sort_by_key(|(idx, _, _)| *idx);
let fields: Vec<Field> = result_columns
.iter()
.map(|(_, name, col)| Field::new(name, col.data_type().clone(), true))
.collect();
let columns: Vec<ArrayRef> = result_columns.into_iter().map(|(_, _, col)| col).collect();
let schema = Arc::new(Schema::new(fields));
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
}
/// Build a string key for grouping using length-prefixed encoding.
fn build_group_key(group_columns: &[&ArrayRef], row: usize) -> String {
let mut key = String::new();
for col in group_columns {
if col.is_null(row) {
key.push('N');
} else {
let val = arrow_cast::display::array_value_to_string(col, row).unwrap_or_default();
key.push('L');
key.push_str(&val.len().to_string());
key.push(':');
key.push_str(&val);
}
}
key
}
fn compute_aggregate(
func: AggFunc,
arg: &ArrayRef,
group_indices: &[Vec<usize>],
num_groups: usize,
) -> Result<ArrayRef> {
match func {
AggFunc::Count => {
let mut builder = Int64Builder::with_capacity(num_groups);
for group in group_indices {
let count = group.iter().filter(|&&i| !arg.is_null(i)).count();
builder.append_value(count as i64);
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}
AggFunc::Sum => compute_sum(arg, group_indices, num_groups),
AggFunc::Avg => compute_avg(arg, group_indices, num_groups),
AggFunc::Min => compute_min_max(arg, group_indices, num_groups, true),
AggFunc::Max => compute_min_max(arg, group_indices, num_groups, false),
}
}
fn compute_sum(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize) -> Result<ArrayRef> {
macro_rules! sum_numeric {
($arr_type:ty, $arg:expr, $dt:expr) => {{
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
OmniError::manifest(format!("sum: expected {:?}, got {:?}", $dt, $arg.data_type()))
})?;
let mut builder = Float64Builder::with_capacity(num_groups);
for group in group_indices {
let mut sum = None;
for &i in group {
if !arr.is_null(i) {
*sum.get_or_insert(0.0f64) += arr.value(i) as f64;
}
}
match sum {
Some(v) => builder.append_value(v),
None => builder.append_null(),
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}};
}
match arg.data_type() {
dt @ DataType::Int32 => sum_numeric!(Int32Array, arg, dt),
dt @ DataType::Int64 => sum_numeric!(Int64Array, arg, dt),
dt @ DataType::UInt32 => sum_numeric!(UInt32Array, arg, dt),
dt @ DataType::UInt64 => sum_numeric!(UInt64Array, arg, dt),
dt @ DataType::Float32 => sum_numeric!(Float32Array, arg, dt),
dt @ DataType::Float64 => sum_numeric!(Float64Array, arg, dt),
dt => Err(OmniError::manifest(format!("sum: unsupported type {:?}", dt))),
}
}
fn compute_avg(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize) -> Result<ArrayRef> {
macro_rules! avg_typed {
($arr_type:ty, $arg:expr) => {{
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
OmniError::manifest(format!("avg: expected {:?}, got {:?}", stringify!($arr_type), $arg.data_type()))
})?;
let mut builder = Float64Builder::with_capacity(num_groups);
for group in group_indices {
let mut sum = 0.0f64;
let mut count = 0usize;
for &i in group {
if !arr.is_null(i) { sum += arr.value(i) as f64; count += 1; }
}
if count > 0 { builder.append_value(sum / count as f64); } else { builder.append_null(); }
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}};
}
match arg.data_type() {
DataType::Int32 => avg_typed!(Int32Array, arg),
DataType::Int64 => avg_typed!(Int64Array, arg),
DataType::UInt32 => avg_typed!(UInt32Array, arg),
DataType::UInt64 => avg_typed!(UInt64Array, arg),
DataType::Float32 => avg_typed!(Float32Array, arg),
DataType::Float64 => avg_typed!(Float64Array, arg),
dt => Err(OmniError::manifest(format!("avg: unsupported type {:?}", dt))),
}
}
fn compute_min_max(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize, is_min: bool) -> Result<ArrayRef> {
macro_rules! minmax_typed {
($arr_type:ty, $builder_type:ty, $arg:expr, $is_min:expr) => {{
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
OmniError::manifest(format!("min/max: expected {:?}, got {:?}", stringify!($arr_type), $arg.data_type()))
})?;
let mut builder = <$builder_type>::with_capacity(num_groups);
for group in group_indices {
let mut result = None;
for &i in group {
if !arr.is_null(i) {
let v = arr.value(i);
result = Some(match result {
None => v,
Some(cur) => if $is_min { if v < cur { v } else { cur } } else { if v > cur { v } else { cur } },
});
}
}
match result { Some(v) => builder.append_value(v), None => builder.append_null() }
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}};
}
match arg.data_type() {
DataType::Int32 => minmax_typed!(Int32Array, Int32Builder, arg, is_min),
DataType::Int64 => minmax_typed!(Int64Array, Int64Builder, arg, is_min),
DataType::UInt32 => minmax_typed!(UInt32Array, UInt32Builder, arg, is_min),
DataType::UInt64 => minmax_typed!(UInt64Array, UInt64Builder, arg, is_min),
DataType::Float32 => minmax_typed!(Float32Array, Float32Builder, arg, is_min),
DataType::Float64 => minmax_typed!(Float64Array, Float64Builder, arg, is_min),
DataType::Utf8 => {
let arr = arg.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
OmniError::manifest(format!("min/max: expected Utf8, got {:?}", arg.data_type()))
})?;
let mut builder = StringBuilder::with_capacity(num_groups, num_groups * 16);
for group in group_indices {
let mut result: Option<&str> = None;
for &i in group {
if !arr.is_null(i) {
let v = arr.value(i);
result = Some(match result {
None => v,
Some(cur) => if is_min { if v < cur { v } else { cur } } else { if v > cur { v } else { cur } },
});
}
}
match result { Some(v) => builder.append_value(v), None => builder.append_null() }
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}
dt => Err(OmniError::manifest(format!("min/max: unsupported type {:?}", dt))),
}
}
/// Build a single-row result for an aggregate query with zero input rows and no group keys.
fn build_empty_aggregate_result(projections: &[IRProjection]) -> Result<RecordBatch> {
let mut fields = Vec::with_capacity(projections.len());
let mut columns: Vec<ArrayRef> = Vec::with_capacity(projections.len());
for proj in projections {
let name = proj.alias.as_deref().unwrap_or("?");
match &proj.expr {
IRExpr::Aggregate { func, .. } => match func {
AggFunc::Count => {
fields.push(Field::new(name, DataType::Int64, true));
columns.push(Arc::new(Int64Array::from(vec![0i64])) as ArrayRef);
}
_ => {
fields.push(Field::new(name, DataType::Float64, true));
columns.push(Arc::new(Float64Array::from(vec![None as Option<f64>])) as ArrayRef);
}
},
_ => {
fields.push(Field::new(name, DataType::Utf8, true));
columns.push(Arc::new(StringArray::from(vec![None as Option<&str>])) as ArrayRef);
}
}
}
let schema = Arc::new(Schema::new(fields));
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
}

View file

@ -358,25 +358,21 @@ pub async fn execute_query(
return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
}
let mut bindings: HashMap<String, RecordBatch> = HashMap::new();
execute_pipeline(
&ir.pipeline,
params,
snapshot,
graph_index,
catalog,
&mut bindings,
&search_mode,
)
.await?;
let mut wide: Option<RecordBatch> = None;
execute_pipeline(&ir.pipeline, params, snapshot, graph_index, catalog, &mut wide, &search_mode).await?;
let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty())));
// Project return expressions
let mut result_batch = project_return(&bindings, &ir.return_exprs, params)?;
let has_aggregates = ir.return_exprs.iter().any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;
// Apply ordering (skip if search mode already ordered the results)
if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) {
result_batch = apply_ordering(result_batch, &ir.order_by, &bindings, params)?;
result_batch = if has_aggregates {
apply_ordering(result_batch.clone(), &ir.order_by, &result_batch, params)?
} else {
apply_ordering(result_batch, &ir.order_by, &wide_batch, params)?
};
}
// Apply limit
@ -403,27 +399,27 @@ async fn execute_rrf_query(
rrf: &RrfMode,
) -> Result<QueryResult> {
// Execute primary search
let mut primary_bindings: HashMap<String, RecordBatch> = HashMap::new();
let mut primary_wide: Option<RecordBatch> = None;
execute_pipeline(
&ir.pipeline,
params,
snapshot,
graph_index,
catalog,
&mut primary_bindings,
&mut primary_wide,
&rrf.primary,
)
.await?;
// Execute secondary search
let mut secondary_bindings: HashMap<String, RecordBatch> = HashMap::new();
let mut secondary_wide: Option<RecordBatch> = None;
execute_pipeline(
&ir.pipeline,
params,
snapshot,
graph_index,
catalog,
&mut secondary_bindings,
&mut secondary_wide,
&rrf.secondary,
)
.await?;
@ -438,13 +434,13 @@ async fn execute_rrf_query(
.or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
.ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;
let primary_batch = primary_bindings.get(primary_var).ok_or_else(|| {
let primary_batch = primary_wide.as_ref().ok_or_else(|| {
OmniError::manifest(format!(
"rrf primary variable '{}' not in bindings",
primary_var
))
})?;
let secondary_batch = secondary_bindings.get(primary_var).ok_or_else(|| {
let secondary_batch = secondary_wide.as_ref().ok_or_else(|| {
OmniError::manifest(format!(
"rrf secondary variable '{}' not in bindings",
primary_var
@ -452,8 +448,9 @@ async fn execute_rrf_query(
})?;
// Build ID → rank maps
let primary_ids = extract_id_column(primary_batch)?;
let secondary_ids = extract_id_column(secondary_batch)?;
let id_col_name = format!("{}.id", primary_var);
let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?;
let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?;
let mut primary_rank: HashMap<String, usize> = HashMap::new();
for (i, id) in primary_ids.iter().enumerate() {
@ -510,24 +507,21 @@ async fn execute_rrf_query(
// Reconstruct a combined batch for the binding in winning order
let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;
// Replace the binding and project
let mut fused_bindings = primary_bindings;
fused_bindings.insert(primary_var.to_string(), fused_batch);
let result_batch = project_return(&fused_bindings, &ir.return_exprs, params)?;
// Project directly from fused batch
let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?;
// Already ordered by RRF score + already limited
Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
}
fn extract_id_column(batch: &RecordBatch) -> Result<Vec<String>> {
fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
let col = batch
.column_by_name("id")
.ok_or_else(|| OmniError::manifest("batch missing 'id' column for RRF".to_string()))?;
.column_by_name(col_name)
.ok_or_else(|| OmniError::manifest(format!("batch missing '{}' column for RRF", col_name)))?;
let ids = col
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| OmniError::manifest("'id' column is not Utf8".to_string()))?;
.ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", col_name)))?;
Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect())
}
@ -585,7 +579,7 @@ fn execute_pipeline<'a>(
snapshot: &'a Snapshot,
graph_index: Option<&'a GraphIndex>,
catalog: &'a Catalog,
bindings: &'a mut HashMap<String, RecordBatch>,
wide: &'a mut Option<RecordBatch>,
search_mode: &'a SearchMode,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
Box::pin(async move {
@ -632,10 +626,16 @@ fn execute_pipeline<'a>(
search_mode,
)
.await?;
bindings.insert(variable.clone(), batch);
let prefixed = prefix_batch(&batch, variable)?;
*wide = Some(match wide.take() {
None => prefixed,
Some(existing) => cross_join_batches(&existing, &prefixed)?,
});
}
IROp::Filter(filter) => {
apply_filter(bindings, filter, params)?;
if let Some(batch) = wide.as_mut() {
apply_filter(batch, filter, params)?;
}
}
IROp::Expand {
src_var,
@ -649,17 +649,20 @@ fn execute_pipeline<'a>(
let gi = graph_index.ok_or_else(|| {
OmniError::manifest("graph index required for traversal".to_string())
})?;
let batch = execute_expand(
bindings, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
dst_type, *min_hops, *max_hops,
)
.await?;
bindings.insert(dst_var.clone(), batch);
if let Some(batch) = wide.as_mut() {
execute_expand(
batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
dst_type, *min_hops, *max_hops,
)
.await?;
}
}
IROp::AntiJoin { outer_var, inner } => {
let gi = graph_index;
execute_anti_join(bindings, inner, params, snapshot, gi, catalog, outer_var)
.await?;
if let Some(batch) = wide.as_mut() {
execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var)
.await?;
}
}
}
}
@ -669,28 +672,26 @@ fn execute_pipeline<'a>(
/// Execute a graph traversal (Expand).
async fn execute_expand(
bindings: &HashMap<String, RecordBatch>,
wide: &mut RecordBatch,
graph_index: &GraphIndex,
snapshot: &Snapshot,
catalog: &Catalog,
src_var: &str,
_dst_var: &str,
dst_var: &str,
edge_type: &str,
direction: Direction,
dst_type: &str,
min_hops: u32,
max_hops: Option<u32>,
) -> Result<RecordBatch> {
let src_batch = bindings.get(src_var).ok_or_else(|| {
OmniError::manifest(format!("expand references unbound variable '{}'", src_var))
})?;
let src_ids = src_batch
.column_by_name("id")
.ok_or_else(|| OmniError::manifest("source batch missing 'id' column".to_string()))?
) -> Result<()> {
let src_id_col_name = format!("{}.id", src_var);
let src_ids = wide
.column_by_name(&src_id_col_name)
.ok_or_else(|| OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name)))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| OmniError::manifest("source 'id' column is not Utf8".to_string()))?;
.ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
.clone();
// Determine which type index to use for source and destination
let edge_def = catalog
@ -720,8 +721,9 @@ async fn execute_expand(
let same_type = src_type_name == dst_type_name;
// BFS to collect reachable destination dense IDs
let mut result_dst_ids: Vec<String> = Vec::new();
// BFS to collect (src_row_idx, dst_id) pairs with per-source dedup
let mut src_indices: Vec<u32> = Vec::new();
let mut dst_id_list: Vec<String> = Vec::new();
for i in 0..src_ids.len() {
let src_id = src_ids.value(i);
let Some(src_dense) = src_type_idx.to_dense(src_id) else {
@ -749,7 +751,8 @@ async fn execute_expand(
if let Some(dst_id) = dst_type_idx.to_id(neighbor) {
let dst_id = dst_id.to_string();
if seen_dst_ids.insert(dst_id.clone()) {
result_dst_ids.push(dst_id);
src_indices.push(i as u32);
dst_id_list.push(dst_id);
}
}
}
@ -764,7 +767,36 @@ async fn execute_expand(
}
// Hydrate destination nodes from the snapshot
hydrate_nodes(snapshot, catalog, dst_type, &result_dst_ids).await
let dst_batch = hydrate_nodes(snapshot, catalog, dst_type, &dst_id_list).await?;
// Build a mapping from dst_id to row index in dst_batch
let dst_batch_id_col = dst_batch
.column_by_name("id")
.ok_or_else(|| OmniError::manifest("hydrated batch missing 'id' column".to_string()))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| OmniError::manifest("hydrated 'id' column is not Utf8".to_string()))?;
let dst_id_to_row: HashMap<&str, usize> = (0..dst_batch_id_col.len())
.map(|i| (dst_batch_id_col.value(i), i))
.collect();
// Build aligned src/dst index arrays (only for IDs that exist in hydrated batch)
let mut final_src_indices: Vec<u32> = Vec::new();
let mut dst_indices: Vec<u32> = Vec::new();
for (src_idx, dst_id) in src_indices.iter().zip(dst_id_list.iter()) {
if let Some(&dst_row) = dst_id_to_row.get(dst_id.as_str()) {
final_src_indices.push(*src_idx);
dst_indices.push(dst_row as u32);
}
}
let src_take = UInt32Array::from(final_src_indices);
let dst_take = UInt32Array::from(dst_indices);
let expanded_wide = take_batch(wide, &src_take)?;
let dst_prefixed = prefix_batch(&dst_batch, dst_var)?;
let aligned_dst = take_batch(&dst_prefixed, &dst_take)?;
*wide = hconcat_batches(&expanded_wide, &aligned_dst)?;
Ok(())
}
/// Load full node rows for a set of IDs from a snapshot.
@ -829,15 +861,15 @@ async fn hydrate_nodes(
Ok(scan_result)
}
/// Try bulk anti-join via CSR existence check. Returns Some if the inner
/// Try bulk anti-join via CSR existence check. Returns Some(mask) if the inner
/// pipeline is a single Expand from outer_var (the common negation pattern).
fn try_bulk_anti_join(
outer_batch: &RecordBatch,
fn try_bulk_anti_join_mask(
wide: &RecordBatch,
inner_pipeline: &[IROp],
graph_index: Option<&GraphIndex>,
catalog: &Catalog,
outer_var: &str,
) -> Option<Result<RecordBatch>> {
) -> Option<BooleanArray> {
if inner_pipeline.len() != 1 {
return None;
}
@ -866,8 +898,9 @@ fn try_bulk_anti_join(
}?;
let type_idx = gi.type_index(src_type_name)?;
let outer_ids = outer_batch
.column_by_name("id")?
let id_col_name = format!("{}.id", outer_var);
let outer_ids = wide
.column_by_name(&id_col_name)?
.as_any()
.downcast_ref::<StringArray>()?;
@ -881,16 +914,12 @@ fn try_bulk_anti_join(
})
.collect();
let mask = BooleanArray::from(keep_mask);
Some(
arrow_select::filter::filter_record_batch(outer_batch, &mask)
.map_err(|e| OmniError::Lance(e.to_string())),
)
Some(BooleanArray::from(keep_mask))
}
/// Execute an AntiJoin: remove rows from outer_var where the inner pipeline finds matches.
/// Execute an AntiJoin: remove rows from wide batch where the inner pipeline finds matches.
async fn execute_anti_join(
bindings: &mut HashMap<String, RecordBatch>,
wide: &mut RecordBatch,
inner_pipeline: &[IROp],
params: &ParamMap,
snapshot: &Snapshot,
@ -898,35 +927,22 @@ async fn execute_anti_join(
catalog: &Catalog,
outer_var: &str,
) -> Result<()> {
let outer_batch = bindings.get(outer_var).ok_or_else(|| {
OmniError::manifest(format!(
"anti-join references unbound variable '{}'",
outer_var
))
})?;
// Fast path: bulk CSR existence check (O(N), zero Lance I/O)
if let Some(result) =
try_bulk_anti_join(outer_batch, inner_pipeline, graph_index, catalog, outer_var)
if let Some(mask) =
try_bulk_anti_join_mask(wide, inner_pipeline, graph_index, catalog, outer_var)
{
bindings.insert(outer_var.to_string(), result?);
*wide = arrow_select::filter::filter_record_batch(wide, &mask)
.map_err(|e| OmniError::Lance(e.to_string()))?;
return Ok(());
}
// Slow path: per-row inner pipeline execution
let outer_ids = outer_batch
.column_by_name("id")
.ok_or_else(|| OmniError::manifest("outer batch missing 'id' column".to_string()))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| OmniError::manifest("outer 'id' column is not Utf8".to_string()))?;
let num_rows = wide.num_rows();
let mut keep_mask = vec![true; num_rows];
let mut keep_mask = vec![true; outer_batch.num_rows()];
for i in 0..outer_ids.len() {
let single_row = outer_batch.slice(i, 1);
let mut inner_bindings: HashMap<String, RecordBatch> = HashMap::new();
inner_bindings.insert(outer_var.to_string(), single_row);
for i in 0..num_rows {
let single_row = wide.slice(i, 1);
let mut inner_wide: Option<RecordBatch> = Some(single_row);
let no_search = SearchMode::default();
execute_pipeline(
@ -935,15 +951,15 @@ async fn execute_anti_join(
snapshot,
graph_index,
catalog,
&mut inner_bindings,
&mut inner_wide,
&no_search,
)
.await?;
let has_match = inner_bindings
.iter()
.filter(|(k, _)| *k != outer_var)
.any(|(_, batch)| batch.num_rows() > 0);
let has_match = inner_wide
.as_ref()
.map(|batch| batch.num_rows() > 0)
.unwrap_or(false);
if has_match {
keep_mask[i] = false;
@ -951,10 +967,8 @@ async fn execute_anti_join(
}
let mask = BooleanArray::from(keep_mask);
let filtered = arrow_select::filter::filter_record_batch(outer_batch, &mask)
*wide = arrow_select::filter::filter_record_batch(wide, &mask)
.map_err(|e| OmniError::Lance(e.to_string()))?;
bindings.insert(outer_var.to_string(), filtered);
Ok(())
}
@ -1220,3 +1234,50 @@ pub(super) fn literal_to_sql(lit: &Literal) -> String {
Literal::List(_) => "NULL".to_string(), // Not supported in SQL pushdown
}
}
fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
let fields: Vec<Field> = batch.schema().fields().iter().map(|f| {
Field::new(format!("{}.{}", variable, f.name()), f.data_type().clone(), f.is_nullable())
}).collect();
let schema = Arc::new(Schema::new(fields));
RecordBatch::try_new(schema, batch.columns().to_vec()).map_err(|e| OmniError::Lance(e.to_string()))
}
fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
let n = left.num_rows();
let m = right.num_rows();
if n == 0 || m == 0 {
let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields))));
}
let left_indices: Vec<u32> = (0..n as u32).flat_map(|i| std::iter::repeat(i).take(m)).collect();
let right_indices: Vec<u32> = (0..n).flat_map(|_| 0..m as u32).collect();
let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?;
let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?;
hconcat_batches(&left_expanded, &right_expanded)
}
fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
if cfg!(debug_assertions) {
let left_schema = left.schema();
let left_names: HashSet<&str> = left_schema.fields().iter().map(|f| f.name().as_str()).collect();
let right_schema = right.schema();
for f in right_schema.fields() {
debug_assert!(!left_names.contains(f.name().as_str()), "hconcat_batches: duplicate column '{}'", f.name());
}
}
fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
let mut columns: Vec<ArrayRef> = left.columns().to_vec();
columns.extend(right.columns().to_vec());
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).map_err(|e| OmniError::Lance(e.to_string()))
}
fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
let columns: Vec<ArrayRef> = batch.columns().iter()
.map(|col| arrow_select::take::take(col.as_ref(), indices, None))
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| OmniError::Lance(e.to_string()))?;
RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
}

View file

@ -0,0 +1,302 @@
mod helpers;
use arrow_array::{Array, Float64Array, Int32Array, Int64Array, StringArray};
use omnigraph::db::Omnigraph;
use omnigraph::loader::{LoadMode, load_jsonl};
use omnigraph_compiler::ir::ParamMap;
use helpers::*;
// ─── Count aggregate with GROUP BY ─────────────────────────────────────────
#[tokio::test]
async fn friend_counts_grouped_by_person() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Test data: Alice knows Bob, Alice knows Charlie, Bob knows Diana
// So: Alice=2 friends, Bob=1 friend (Charlie & Diana have no outgoing knows → dropped)
let result = query_main(&mut db, TEST_QUERIES, "friend_counts", &ParamMap::new())
.await
.unwrap();
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 2);
let names = batch
.column_by_name("p.name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let counts = batch
.column_by_name("friends")
.unwrap()
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
// Ordered by friends desc: Alice(2), Bob(1)
assert_eq!(names.value(0), "Alice");
assert_eq!(counts.value(0), 2);
assert_eq!(names.value(1), "Bob");
assert_eq!(counts.value(1), 1);
}
// ─── Global count (no group key) ───────────────────────────────────────────
#[tokio::test]
async fn total_people_global_count() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let result = query_main(&mut db, TEST_QUERIES, "total_people", &ParamMap::new())
.await
.unwrap();
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 1);
let total = batch
.column_by_name("total")
.unwrap()
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(total.value(0), 4);
}
// ─── Sum/Avg/Min/Max aggregates ────────────────────────────────────────────
#[tokio::test]
async fn age_stats_per_company() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Test data: Alice(30) worksAt Acme, Bob(25) worksAt Globex
let result = query_main(&mut db, TEST_QUERIES, "age_stats", &ParamMap::new())
.await
.unwrap();
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 2);
let names = batch
.column_by_name("c.name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
// Find the Acme row (order not guaranteed without ORDER BY)
let acme_row = (0..batch.num_rows())
.find(|&i| names.value(i) == "Acme")
.expect("Acme should be in results");
let globex_row = (0..batch.num_rows())
.find(|&i| names.value(i) == "Globex")
.expect("Globex should be in results");
let headcount = batch
.column_by_name("headcount")
.unwrap()
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(headcount.value(acme_row), 1);
assert_eq!(headcount.value(globex_row), 1);
let avg_age = batch
.column_by_name("avg_age")
.unwrap()
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((avg_age.value(acme_row) - 30.0).abs() < 0.01);
assert!((avg_age.value(globex_row) - 25.0).abs() < 0.01);
let youngest = batch
.column_by_name("youngest")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert_eq!(youngest.value(acme_row), 30);
assert_eq!(youngest.value(globex_row), 25);
let total_age = batch
.column_by_name("total_age")
.unwrap()
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((total_age.value(acme_row) - 30.0).abs() < 0.01);
assert!((total_age.value(globex_row) - 25.0).abs() < 0.01);
}
// ─── Aggregate + order + limit ─────────────────────────────────────────────
#[tokio::test]
async fn top_connected_person() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Alice has 2 friends (most connected)
let result = query_main(&mut db, TEST_QUERIES, "top_connected", &ParamMap::new())
.await
.unwrap();
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 1);
let names = batch
.column_by_name("p.name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let counts = batch
.column_by_name("friends")
.unwrap()
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(names.value(0), "Alice");
assert_eq!(counts.value(0), 2);
}
// ─── Inline aggregate query (not from fixture) ────────────────────────────
#[tokio::test]
async fn inline_count_with_no_friends() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Diana has no outgoing Knows edges, so she won't appear in friend_counts.
// But Alice(2) and Bob(1) do. Verify the count matches.
let queries = r#"
query fc() {
match {
$p: Person
$p knows $f
}
return { $p.name, count($f) as friends }
order { friends desc }
}
"#;
let result = query_main(&mut db, queries, "fc", &ParamMap::new())
.await
.unwrap();
let batch = result.concat_batches().unwrap();
// Only Alice and Bob have outgoing Knows edges
assert_eq!(batch.num_rows(), 2);
}
#[tokio::test]
async fn inline_global_count_empty() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
// Load only nodes, no edges
let data = r#"{"type": "Person", "data": {"name": "Alice", "age": 30}}
{"type": "Company", "data": {"name": "Acme"}}"#;
load_jsonl(&mut db, data, LoadMode::Overwrite)
.await
.unwrap();
// Global count — should return 1 row with count=1
let queries = r#"
query tc() {
match { $p: Person }
return { count($p) as total }
}
"#;
let result = query_main(&mut db, queries, "tc", &ParamMap::new())
.await
.unwrap();
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 1);
let total = batch
.column_by_name("total")
.unwrap()
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(total.value(0), 1);
}
#[tokio::test]
async fn inline_min_max_string() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let queries = r#"
query name_range() {
match { $p: Person }
return {
min($p.name) as first_name
max($p.name) as last_name
}
}
"#;
let result = query_main(&mut db, queries, "name_range", &ParamMap::new())
.await
.unwrap();
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 1);
let first = batch
.column_by_name("first_name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let last = batch
.column_by_name("last_name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(first.value(0), "Alice");
assert_eq!(last.value(0), "Diana");
}
#[tokio::test]
async fn inline_multi_hop_aggregate() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Friends of friends count per person
let queries = r#"
query fof_counts($name: String) {
match {
$p: Person { name: $name }
$p knows $mid
$mid knows $fof
}
return { $p.name, count($fof) as fof_count }
}
"#;
// Alice → Bob, Charlie. Bob → Diana. Charlie → nobody.
// So Alice's fof = [Diana] → count = 1
let result = query_main(
&mut db,
queries,
"fof_counts",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 1);
let count = batch
.column_by_name("fof_count")
.unwrap()
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(count.value(0), 1);
}

View file

@ -76,3 +76,38 @@ query top_by_age() {
order { $p.age desc }
limit 2
}
// Aggregation: global count (no group key)
query total_people() {
match {
$p: Person
}
return { count($p) as total }
}
// Aggregation: sum/avg/min/max of age per company
query age_stats() {
match {
$p: Person
$p worksAt $c
}
return {
$c.name
sum($p.age) as total_age
avg($p.age) as avg_age
min($p.age) as youngest
max($p.age) as oldest
count($p) as headcount
}
}
// Aggregation: most connected person
query top_connected() {
match {
$p: Person
$p knows $f
}
return { $p.name, count($f) as friends }
order { friends desc }
limit 1
}