mirror of
https://github.com/ModernRelay/omnigraph.git
synced 2026-06-09 01:35:18 +02:00
Merge pull request #6 from ModernRelay/claude/omnigraph-aggregates-a53rG
Implement aggregate functions with GROUP BY support
This commit is contained in:
commit
c5a88cacb5
6 changed files with 868 additions and 152 deletions
|
|
@ -189,6 +189,29 @@ fn typecheck_read_query(catalog: &Catalog, query: &QueryDecl) -> Result<TypeCont
|
|||
));
|
||||
}
|
||||
|
||||
// T9: If any return expression is an aggregate, non-aggregate expressions
|
||||
// must be valid group-by keys (PropAccess or Variable).
|
||||
let has_agg = query
|
||||
.return_clause
|
||||
.iter()
|
||||
.any(|p| matches!(p.expr, Expr::Aggregate { .. }));
|
||||
if has_agg {
|
||||
for proj in &query.return_clause {
|
||||
if !matches!(proj.expr, Expr::Aggregate { .. }) {
|
||||
match &proj.expr {
|
||||
Expr::PropAccess { .. } | Expr::Variable(_) => {}
|
||||
_ => {
|
||||
return Err(NanoError::Type(
|
||||
"T9: non-aggregate expressions in an aggregate query must be \
|
||||
property accesses or variables"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
|
|
@ -1298,9 +1321,9 @@ fn resolve_expr_type(
|
|||
Expr::Aggregate { func, arg } => {
|
||||
let arg_type = resolve_expr_type(catalog, arg, ctx, params)?;
|
||||
|
||||
// T8: sum/avg/min/max require numeric
|
||||
// T8: sum/avg require numeric; min/max require numeric or string
|
||||
match func {
|
||||
AggFunc::Sum | AggFunc::Avg | AggFunc::Min | AggFunc::Max => {
|
||||
AggFunc::Sum | AggFunc::Avg => {
|
||||
if let ResolvedType::Scalar(s) = &arg_type
|
||||
&& (s.list || !s.scalar.is_numeric())
|
||||
{
|
||||
|
|
@ -1311,6 +1334,17 @@ fn resolve_expr_type(
|
|||
)));
|
||||
}
|
||||
}
|
||||
AggFunc::Min | AggFunc::Max => {
|
||||
if let ResolvedType::Scalar(s) = &arg_type
|
||||
&& (s.list || (!s.scalar.is_numeric() && s.scalar != ScalarType::String))
|
||||
{
|
||||
return Err(NanoError::Type(format!(
|
||||
"T8: {} requires numeric or string type, got {}",
|
||||
func,
|
||||
s.display_name()
|
||||
)));
|
||||
}
|
||||
}
|
||||
_ => {} // count works on any type
|
||||
}
|
||||
|
||||
|
|
@ -1340,8 +1374,8 @@ fn infer_projection_field(
|
|||
Expr::Aggregate { func, arg } => {
|
||||
let (data_type, nullable) = match func {
|
||||
AggFunc::Count => (DataType::Int64, true),
|
||||
AggFunc::Avg => (DataType::Float64, true),
|
||||
_ => {
|
||||
AggFunc::Avg | AggFunc::Sum => (DataType::Float64, true),
|
||||
AggFunc::Min | AggFunc::Max => {
|
||||
let resolved = resolve_expr_type(catalog, arg, ctx, params)?;
|
||||
let (data_type, _) = resolved_type_to_field_shape(catalog, &resolved)?;
|
||||
(data_type, true)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ use omnigraph_compiler::ir::{
|
|||
};
|
||||
use omnigraph_compiler::lower_mutation_query;
|
||||
use omnigraph_compiler::lower_query;
|
||||
use omnigraph_compiler::query::ast::{CompOp, Literal, NOW_PARAM_NAME};
|
||||
use omnigraph_compiler::query::ast::{AggFunc, CompOp, Literal, NOW_PARAM_NAME};
|
||||
use omnigraph_compiler::query::typecheck::{CheckedQuery, typecheck_query, typecheck_query_decl};
|
||||
use omnigraph_compiler::result::{MutationResult, QueryResult};
|
||||
use omnigraph_compiler::types::Direction;
|
||||
|
|
|
|||
|
|
@ -1,25 +1,14 @@
|
|||
use super::*;
|
||||
|
||||
pub(super) fn apply_filter(
|
||||
bindings: &mut HashMap<String, RecordBatch>,
|
||||
batch: &mut RecordBatch,
|
||||
filter: &IRFilter,
|
||||
params: &ParamMap,
|
||||
) -> Result<()> {
|
||||
// Find which binding this filter applies to
|
||||
let var_name = match &filter.left {
|
||||
IRExpr::PropAccess { variable, .. } => variable.clone(),
|
||||
_ => return Ok(()), // Can't determine variable
|
||||
};
|
||||
|
||||
let batch = bindings.get(&var_name).ok_or_else(|| {
|
||||
OmniError::manifest(format!("filter references unbound variable '{}'", var_name))
|
||||
})?;
|
||||
|
||||
let mask = evaluate_filter(batch, filter, params)?;
|
||||
let filtered = arrow_select::filter::filter_record_batch(batch, &mask)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
|
||||
bindings.insert(var_name, filtered);
|
||||
*batch = filtered;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -59,12 +48,13 @@ fn evaluate_filter(
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
/// Evaluate an IR expression against a batch, producing an array.
|
||||
/// Evaluate an IR expression against a wide batch, producing an array.
|
||||
fn evaluate_expr(batch: &RecordBatch, expr: &IRExpr, params: &ParamMap) -> Result<ArrayRef> {
|
||||
match expr {
|
||||
IRExpr::PropAccess { property, .. } => {
|
||||
batch.column_by_name(property).cloned().ok_or_else(|| {
|
||||
OmniError::manifest(format!("column '{}' not found in batch", property))
|
||||
IRExpr::PropAccess { variable, property } => {
|
||||
let col_name = format!("{}.{}", variable, property);
|
||||
batch.column_by_name(&col_name).cloned().ok_or_else(|| {
|
||||
OmniError::manifest(format!("column '{}' not found in wide batch", col_name))
|
||||
})
|
||||
}
|
||||
IRExpr::Literal(lit) => literal_to_array(lit, batch.num_rows()),
|
||||
|
|
@ -307,7 +297,7 @@ fn literal_scalar_type(lit: &Literal) -> Result<ScalarType> {
|
|||
|
||||
/// Project return expressions into a result batch.
|
||||
pub(super) fn project_return(
|
||||
bindings: &HashMap<String, RecordBatch>,
|
||||
wide_batch: &RecordBatch,
|
||||
projections: &[IRProjection],
|
||||
params: &ParamMap,
|
||||
) -> Result<RecordBatch> {
|
||||
|
|
@ -317,11 +307,19 @@ pub(super) fn project_return(
|
|||
));
|
||||
}
|
||||
|
||||
// Route to aggregate path if any projection contains an aggregate
|
||||
let has_aggregates = projections
|
||||
.iter()
|
||||
.any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
|
||||
if has_aggregates {
|
||||
return aggregate_return(wide_batch, projections, params);
|
||||
}
|
||||
|
||||
let mut fields = Vec::with_capacity(projections.len());
|
||||
let mut columns: Vec<ArrayRef> = Vec::with_capacity(projections.len());
|
||||
|
||||
for proj in projections {
|
||||
let (name, col) = evaluate_projection(bindings, &proj.expr, params)?;
|
||||
let (name, col) = evaluate_projection(wide_batch, &proj.expr, params)?;
|
||||
let field_name = proj.alias.as_deref().unwrap_or(&name);
|
||||
fields.push(Field::new(
|
||||
field_name,
|
||||
|
|
@ -335,42 +333,41 @@ pub(super) fn project_return(
|
|||
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
/// Evaluate a single projection expression.
|
||||
/// Evaluate a single projection expression against a wide batch.
|
||||
fn evaluate_projection(
|
||||
bindings: &HashMap<String, RecordBatch>,
|
||||
wide_batch: &RecordBatch,
|
||||
expr: &IRExpr,
|
||||
params: &ParamMap,
|
||||
) -> Result<(String, ArrayRef)> {
|
||||
match expr {
|
||||
IRExpr::PropAccess { variable, property } => {
|
||||
let batch = bindings.get(variable).ok_or_else(|| {
|
||||
let col_name = format!("{}.{}", variable, property);
|
||||
let col = wide_batch.column_by_name(&col_name).ok_or_else(|| {
|
||||
OmniError::manifest(format!(
|
||||
"projection references unbound variable '{}'",
|
||||
variable
|
||||
"column '{}' not found in wide batch",
|
||||
col_name
|
||||
))
|
||||
})?;
|
||||
let col = batch.column_by_name(property).ok_or_else(|| {
|
||||
OmniError::manifest(format!(
|
||||
"column '{}' not found in binding '{}'",
|
||||
property, variable
|
||||
))
|
||||
})?;
|
||||
Ok((format!("{}.{}", variable, property), col.clone()))
|
||||
Ok((col_name, col.clone()))
|
||||
}
|
||||
IRExpr::Literal(lit) => {
|
||||
// Get row count from first binding
|
||||
let num_rows = bindings.values().next().map(|b| b.num_rows()).unwrap_or(0);
|
||||
let arr = literal_to_array(lit, num_rows)?;
|
||||
let arr = literal_to_array(lit, wide_batch.num_rows())?;
|
||||
Ok(("literal".to_string(), arr))
|
||||
}
|
||||
IRExpr::Param(name) => {
|
||||
let lit = params
|
||||
.get(name)
|
||||
.ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?;
|
||||
let num_rows = bindings.values().next().map(|b| b.num_rows()).unwrap_or(0);
|
||||
let arr = literal_to_array(lit, num_rows)?;
|
||||
let arr = literal_to_array(lit, wide_batch.num_rows())?;
|
||||
Ok((name.clone(), arr))
|
||||
}
|
||||
IRExpr::Variable(name) => {
|
||||
let col_name = format!("{}.id", name);
|
||||
let col = wide_batch.column_by_name(&col_name).ok_or_else(|| {
|
||||
OmniError::manifest(format!("column '{}' not found in wide batch", col_name))
|
||||
})?;
|
||||
Ok((name.clone(), col.clone()))
|
||||
}
|
||||
_ => Err(OmniError::manifest(format!(
|
||||
"unsupported projection expression: {:?}",
|
||||
expr
|
||||
|
|
@ -378,11 +375,12 @@ fn evaluate_projection(
|
|||
}
|
||||
}
|
||||
|
||||
/// Apply ordering to a batch.
|
||||
/// Apply ordering to a batch. `source` is used to resolve PropAccess columns
|
||||
/// (typically the wide batch, or the result batch itself for aggregates).
|
||||
pub(super) fn apply_ordering(
|
||||
batch: RecordBatch,
|
||||
orderings: &[IROrdering],
|
||||
bindings: &HashMap<String, RecordBatch>,
|
||||
source: &RecordBatch,
|
||||
_params: &ParamMap,
|
||||
) -> Result<RecordBatch> {
|
||||
use arrow_ord::sort::{SortColumn, lexsort_to_indices};
|
||||
|
|
@ -392,16 +390,11 @@ pub(super) fn apply_ordering(
|
|||
for ordering in orderings {
|
||||
let col = match &ordering.expr {
|
||||
IRExpr::PropAccess { variable, property } => {
|
||||
let binding = bindings.get(variable).ok_or_else(|| {
|
||||
OmniError::manifest(format!(
|
||||
"ordering references unbound variable '{}'",
|
||||
variable
|
||||
))
|
||||
})?;
|
||||
binding
|
||||
.column_by_name(property)
|
||||
let col_name = format!("{}.{}", variable, property);
|
||||
source
|
||||
.column_by_name(&col_name)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest(format!("column '{}' not found for ordering", property))
|
||||
OmniError::manifest(format!("column '{}' not found for ordering", col_name))
|
||||
})?
|
||||
.clone()
|
||||
}
|
||||
|
|
@ -442,3 +435,294 @@ pub(super) fn apply_ordering(
|
|||
|
||||
RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
// ─── Aggregate execution ───────────────────────────────────────────────────
|
||||
|
||||
/// Project return expressions that contain aggregates.
|
||||
fn aggregate_return(
|
||||
wide: &RecordBatch,
|
||||
projections: &[IRProjection],
|
||||
params: &ParamMap,
|
||||
) -> Result<RecordBatch> {
|
||||
let num_rows = wide.num_rows();
|
||||
|
||||
struct GroupKey {
|
||||
proj_idx: usize,
|
||||
name: String,
|
||||
column: ArrayRef,
|
||||
}
|
||||
struct AggProj {
|
||||
proj_idx: usize,
|
||||
name: String,
|
||||
func: AggFunc,
|
||||
arg_column: ArrayRef,
|
||||
}
|
||||
|
||||
let mut group_keys: Vec<GroupKey> = Vec::new();
|
||||
let mut agg_projs: Vec<AggProj> = Vec::new();
|
||||
|
||||
for (i, proj) in projections.iter().enumerate() {
|
||||
match &proj.expr {
|
||||
IRExpr::Aggregate { func, arg } => {
|
||||
let (name, col) = evaluate_projection(wide, arg, params)?;
|
||||
let alias = proj.alias.as_deref().unwrap_or(&name);
|
||||
agg_projs.push(AggProj {
|
||||
proj_idx: i,
|
||||
name: alias.to_string(),
|
||||
func: *func,
|
||||
arg_column: col,
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
let (name, col) = evaluate_projection(wide, &proj.expr, params)?;
|
||||
let alias = proj.alias.as_deref().unwrap_or(&name);
|
||||
group_keys.push(GroupKey {
|
||||
proj_idx: i,
|
||||
name: alias.to_string(),
|
||||
column: col,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle empty input: return a single row with count=0, others=null
|
||||
if num_rows == 0 && group_keys.is_empty() {
|
||||
return build_empty_aggregate_result(projections);
|
||||
}
|
||||
|
||||
// Build group assignments
|
||||
let mut group_map: HashMap<String, usize> = HashMap::new();
|
||||
let mut group_indices: Vec<Vec<usize>> = Vec::new();
|
||||
let group_cols: Vec<&ArrayRef> = group_keys.iter().map(|gk| &gk.column).collect();
|
||||
|
||||
if group_keys.is_empty() {
|
||||
group_indices.push((0..num_rows).collect());
|
||||
} else {
|
||||
for row in 0..num_rows {
|
||||
let key = build_group_key(&group_cols, row);
|
||||
let group_idx = match group_map.get(&key) {
|
||||
Some(&idx) => idx,
|
||||
None => {
|
||||
let idx = group_indices.len();
|
||||
group_map.insert(key, idx);
|
||||
group_indices.push(Vec::new());
|
||||
idx
|
||||
}
|
||||
};
|
||||
group_indices[group_idx].push(row);
|
||||
}
|
||||
}
|
||||
|
||||
let num_groups = group_indices.len();
|
||||
let mut result_columns: Vec<(usize, String, ArrayRef)> =
|
||||
Vec::with_capacity(projections.len());
|
||||
|
||||
for gk in &group_keys {
|
||||
let first_row_indices: Vec<u32> =
|
||||
group_indices.iter().map(|rows| rows[0] as u32).collect();
|
||||
let take_idx = UInt32Array::from(first_row_indices);
|
||||
let col = arrow_select::take::take(gk.column.as_ref(), &take_idx, None)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
result_columns.push((gk.proj_idx, gk.name.clone(), col));
|
||||
}
|
||||
|
||||
for ap in &agg_projs {
|
||||
let col = compute_aggregate(ap.func, &ap.arg_column, &group_indices, num_groups)?;
|
||||
result_columns.push((ap.proj_idx, ap.name.clone(), col));
|
||||
}
|
||||
|
||||
result_columns.sort_by_key(|(idx, _, _)| *idx);
|
||||
|
||||
let fields: Vec<Field> = result_columns
|
||||
.iter()
|
||||
.map(|(_, name, col)| Field::new(name, col.data_type().clone(), true))
|
||||
.collect();
|
||||
let columns: Vec<ArrayRef> = result_columns.into_iter().map(|(_, _, col)| col).collect();
|
||||
|
||||
let schema = Arc::new(Schema::new(fields));
|
||||
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
/// Build a string key for grouping using length-prefixed encoding.
|
||||
fn build_group_key(group_columns: &[&ArrayRef], row: usize) -> String {
|
||||
let mut key = String::new();
|
||||
for col in group_columns {
|
||||
if col.is_null(row) {
|
||||
key.push('N');
|
||||
} else {
|
||||
let val = arrow_cast::display::array_value_to_string(col, row).unwrap_or_default();
|
||||
key.push('L');
|
||||
key.push_str(&val.len().to_string());
|
||||
key.push(':');
|
||||
key.push_str(&val);
|
||||
}
|
||||
}
|
||||
key
|
||||
}
|
||||
|
||||
fn compute_aggregate(
|
||||
func: AggFunc,
|
||||
arg: &ArrayRef,
|
||||
group_indices: &[Vec<usize>],
|
||||
num_groups: usize,
|
||||
) -> Result<ArrayRef> {
|
||||
match func {
|
||||
AggFunc::Count => {
|
||||
let mut builder = Int64Builder::with_capacity(num_groups);
|
||||
for group in group_indices {
|
||||
let count = group.iter().filter(|&&i| !arg.is_null(i)).count();
|
||||
builder.append_value(count as i64);
|
||||
}
|
||||
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||
}
|
||||
AggFunc::Sum => compute_sum(arg, group_indices, num_groups),
|
||||
AggFunc::Avg => compute_avg(arg, group_indices, num_groups),
|
||||
AggFunc::Min => compute_min_max(arg, group_indices, num_groups, true),
|
||||
AggFunc::Max => compute_min_max(arg, group_indices, num_groups, false),
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_sum(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize) -> Result<ArrayRef> {
|
||||
macro_rules! sum_numeric {
|
||||
($arr_type:ty, $arg:expr, $dt:expr) => {{
|
||||
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
|
||||
OmniError::manifest(format!("sum: expected {:?}, got {:?}", $dt, $arg.data_type()))
|
||||
})?;
|
||||
let mut builder = Float64Builder::with_capacity(num_groups);
|
||||
for group in group_indices {
|
||||
let mut sum = None;
|
||||
for &i in group {
|
||||
if !arr.is_null(i) {
|
||||
*sum.get_or_insert(0.0f64) += arr.value(i) as f64;
|
||||
}
|
||||
}
|
||||
match sum {
|
||||
Some(v) => builder.append_value(v),
|
||||
None => builder.append_null(),
|
||||
}
|
||||
}
|
||||
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||
}};
|
||||
}
|
||||
match arg.data_type() {
|
||||
dt @ DataType::Int32 => sum_numeric!(Int32Array, arg, dt),
|
||||
dt @ DataType::Int64 => sum_numeric!(Int64Array, arg, dt),
|
||||
dt @ DataType::UInt32 => sum_numeric!(UInt32Array, arg, dt),
|
||||
dt @ DataType::UInt64 => sum_numeric!(UInt64Array, arg, dt),
|
||||
dt @ DataType::Float32 => sum_numeric!(Float32Array, arg, dt),
|
||||
dt @ DataType::Float64 => sum_numeric!(Float64Array, arg, dt),
|
||||
dt => Err(OmniError::manifest(format!("sum: unsupported type {:?}", dt))),
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_avg(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize) -> Result<ArrayRef> {
|
||||
macro_rules! avg_typed {
|
||||
($arr_type:ty, $arg:expr) => {{
|
||||
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
|
||||
OmniError::manifest(format!("avg: expected {:?}, got {:?}", stringify!($arr_type), $arg.data_type()))
|
||||
})?;
|
||||
let mut builder = Float64Builder::with_capacity(num_groups);
|
||||
for group in group_indices {
|
||||
let mut sum = 0.0f64;
|
||||
let mut count = 0usize;
|
||||
for &i in group {
|
||||
if !arr.is_null(i) { sum += arr.value(i) as f64; count += 1; }
|
||||
}
|
||||
if count > 0 { builder.append_value(sum / count as f64); } else { builder.append_null(); }
|
||||
}
|
||||
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||
}};
|
||||
}
|
||||
match arg.data_type() {
|
||||
DataType::Int32 => avg_typed!(Int32Array, arg),
|
||||
DataType::Int64 => avg_typed!(Int64Array, arg),
|
||||
DataType::UInt32 => avg_typed!(UInt32Array, arg),
|
||||
DataType::UInt64 => avg_typed!(UInt64Array, arg),
|
||||
DataType::Float32 => avg_typed!(Float32Array, arg),
|
||||
DataType::Float64 => avg_typed!(Float64Array, arg),
|
||||
dt => Err(OmniError::manifest(format!("avg: unsupported type {:?}", dt))),
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_min_max(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize, is_min: bool) -> Result<ArrayRef> {
|
||||
macro_rules! minmax_typed {
|
||||
($arr_type:ty, $builder_type:ty, $arg:expr, $is_min:expr) => {{
|
||||
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
|
||||
OmniError::manifest(format!("min/max: expected {:?}, got {:?}", stringify!($arr_type), $arg.data_type()))
|
||||
})?;
|
||||
let mut builder = <$builder_type>::with_capacity(num_groups);
|
||||
for group in group_indices {
|
||||
let mut result = None;
|
||||
for &i in group {
|
||||
if !arr.is_null(i) {
|
||||
let v = arr.value(i);
|
||||
result = Some(match result {
|
||||
None => v,
|
||||
Some(cur) => if $is_min { if v < cur { v } else { cur } } else { if v > cur { v } else { cur } },
|
||||
});
|
||||
}
|
||||
}
|
||||
match result { Some(v) => builder.append_value(v), None => builder.append_null() }
|
||||
}
|
||||
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||
}};
|
||||
}
|
||||
match arg.data_type() {
|
||||
DataType::Int32 => minmax_typed!(Int32Array, Int32Builder, arg, is_min),
|
||||
DataType::Int64 => minmax_typed!(Int64Array, Int64Builder, arg, is_min),
|
||||
DataType::UInt32 => minmax_typed!(UInt32Array, UInt32Builder, arg, is_min),
|
||||
DataType::UInt64 => minmax_typed!(UInt64Array, UInt64Builder, arg, is_min),
|
||||
DataType::Float32 => minmax_typed!(Float32Array, Float32Builder, arg, is_min),
|
||||
DataType::Float64 => minmax_typed!(Float64Array, Float64Builder, arg, is_min),
|
||||
DataType::Utf8 => {
|
||||
let arr = arg.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
|
||||
OmniError::manifest(format!("min/max: expected Utf8, got {:?}", arg.data_type()))
|
||||
})?;
|
||||
let mut builder = StringBuilder::with_capacity(num_groups, num_groups * 16);
|
||||
for group in group_indices {
|
||||
let mut result: Option<&str> = None;
|
||||
for &i in group {
|
||||
if !arr.is_null(i) {
|
||||
let v = arr.value(i);
|
||||
result = Some(match result {
|
||||
None => v,
|
||||
Some(cur) => if is_min { if v < cur { v } else { cur } } else { if v > cur { v } else { cur } },
|
||||
});
|
||||
}
|
||||
}
|
||||
match result { Some(v) => builder.append_value(v), None => builder.append_null() }
|
||||
}
|
||||
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||
}
|
||||
dt => Err(OmniError::manifest(format!("min/max: unsupported type {:?}", dt))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a single-row result for an aggregate query with zero input rows and no group keys.
|
||||
fn build_empty_aggregate_result(projections: &[IRProjection]) -> Result<RecordBatch> {
|
||||
let mut fields = Vec::with_capacity(projections.len());
|
||||
let mut columns: Vec<ArrayRef> = Vec::with_capacity(projections.len());
|
||||
|
||||
for proj in projections {
|
||||
let name = proj.alias.as_deref().unwrap_or("?");
|
||||
match &proj.expr {
|
||||
IRExpr::Aggregate { func, .. } => match func {
|
||||
AggFunc::Count => {
|
||||
fields.push(Field::new(name, DataType::Int64, true));
|
||||
columns.push(Arc::new(Int64Array::from(vec![0i64])) as ArrayRef);
|
||||
}
|
||||
_ => {
|
||||
fields.push(Field::new(name, DataType::Float64, true));
|
||||
columns.push(Arc::new(Float64Array::from(vec![None as Option<f64>])) as ArrayRef);
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
fields.push(Field::new(name, DataType::Utf8, true));
|
||||
columns.push(Arc::new(StringArray::from(vec![None as Option<&str>])) as ArrayRef);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let schema = Arc::new(Schema::new(fields));
|
||||
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -358,25 +358,21 @@ pub async fn execute_query(
|
|||
return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
|
||||
}
|
||||
|
||||
let mut bindings: HashMap<String, RecordBatch> = HashMap::new();
|
||||
|
||||
execute_pipeline(
|
||||
&ir.pipeline,
|
||||
params,
|
||||
snapshot,
|
||||
graph_index,
|
||||
catalog,
|
||||
&mut bindings,
|
||||
&search_mode,
|
||||
)
|
||||
.await?;
|
||||
let mut wide: Option<RecordBatch> = None;
|
||||
execute_pipeline(&ir.pipeline, params, snapshot, graph_index, catalog, &mut wide, &search_mode).await?;
|
||||
let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty())));
|
||||
|
||||
// Project return expressions
|
||||
let mut result_batch = project_return(&bindings, &ir.return_exprs, params)?;
|
||||
let has_aggregates = ir.return_exprs.iter().any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
|
||||
let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;
|
||||
|
||||
// Apply ordering (skip if search mode already ordered the results)
|
||||
if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) {
|
||||
result_batch = apply_ordering(result_batch, &ir.order_by, &bindings, params)?;
|
||||
result_batch = if has_aggregates {
|
||||
apply_ordering(result_batch.clone(), &ir.order_by, &result_batch, params)?
|
||||
} else {
|
||||
apply_ordering(result_batch, &ir.order_by, &wide_batch, params)?
|
||||
};
|
||||
}
|
||||
|
||||
// Apply limit
|
||||
|
|
@ -403,27 +399,27 @@ async fn execute_rrf_query(
|
|||
rrf: &RrfMode,
|
||||
) -> Result<QueryResult> {
|
||||
// Execute primary search
|
||||
let mut primary_bindings: HashMap<String, RecordBatch> = HashMap::new();
|
||||
let mut primary_wide: Option<RecordBatch> = None;
|
||||
execute_pipeline(
|
||||
&ir.pipeline,
|
||||
params,
|
||||
snapshot,
|
||||
graph_index,
|
||||
catalog,
|
||||
&mut primary_bindings,
|
||||
&mut primary_wide,
|
||||
&rrf.primary,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Execute secondary search
|
||||
let mut secondary_bindings: HashMap<String, RecordBatch> = HashMap::new();
|
||||
let mut secondary_wide: Option<RecordBatch> = None;
|
||||
execute_pipeline(
|
||||
&ir.pipeline,
|
||||
params,
|
||||
snapshot,
|
||||
graph_index,
|
||||
catalog,
|
||||
&mut secondary_bindings,
|
||||
&mut secondary_wide,
|
||||
&rrf.secondary,
|
||||
)
|
||||
.await?;
|
||||
|
|
@ -438,13 +434,13 @@ async fn execute_rrf_query(
|
|||
.or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
|
||||
.ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;
|
||||
|
||||
let primary_batch = primary_bindings.get(primary_var).ok_or_else(|| {
|
||||
let primary_batch = primary_wide.as_ref().ok_or_else(|| {
|
||||
OmniError::manifest(format!(
|
||||
"rrf primary variable '{}' not in bindings",
|
||||
primary_var
|
||||
))
|
||||
})?;
|
||||
let secondary_batch = secondary_bindings.get(primary_var).ok_or_else(|| {
|
||||
let secondary_batch = secondary_wide.as_ref().ok_or_else(|| {
|
||||
OmniError::manifest(format!(
|
||||
"rrf secondary variable '{}' not in bindings",
|
||||
primary_var
|
||||
|
|
@ -452,8 +448,9 @@ async fn execute_rrf_query(
|
|||
})?;
|
||||
|
||||
// Build ID → rank maps
|
||||
let primary_ids = extract_id_column(primary_batch)?;
|
||||
let secondary_ids = extract_id_column(secondary_batch)?;
|
||||
let id_col_name = format!("{}.id", primary_var);
|
||||
let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?;
|
||||
let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?;
|
||||
|
||||
let mut primary_rank: HashMap<String, usize> = HashMap::new();
|
||||
for (i, id) in primary_ids.iter().enumerate() {
|
||||
|
|
@ -510,24 +507,21 @@ async fn execute_rrf_query(
|
|||
// Reconstruct a combined batch for the binding in winning order
|
||||
let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;
|
||||
|
||||
// Replace the binding and project
|
||||
let mut fused_bindings = primary_bindings;
|
||||
fused_bindings.insert(primary_var.to_string(), fused_batch);
|
||||
|
||||
let result_batch = project_return(&fused_bindings, &ir.return_exprs, params)?;
|
||||
// Project directly from fused batch
|
||||
let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?;
|
||||
|
||||
// Already ordered by RRF score + already limited
|
||||
Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
|
||||
}
|
||||
|
||||
fn extract_id_column(batch: &RecordBatch) -> Result<Vec<String>> {
|
||||
fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
|
||||
let col = batch
|
||||
.column_by_name("id")
|
||||
.ok_or_else(|| OmniError::manifest("batch missing 'id' column for RRF".to_string()))?;
|
||||
.column_by_name(col_name)
|
||||
.ok_or_else(|| OmniError::manifest(format!("batch missing '{}' column for RRF", col_name)))?;
|
||||
let ids = col
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.ok_or_else(|| OmniError::manifest("'id' column is not Utf8".to_string()))?;
|
||||
.ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", col_name)))?;
|
||||
Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect())
|
||||
}
|
||||
|
||||
|
|
@ -585,7 +579,7 @@ fn execute_pipeline<'a>(
|
|||
snapshot: &'a Snapshot,
|
||||
graph_index: Option<&'a GraphIndex>,
|
||||
catalog: &'a Catalog,
|
||||
bindings: &'a mut HashMap<String, RecordBatch>,
|
||||
wide: &'a mut Option<RecordBatch>,
|
||||
search_mode: &'a SearchMode,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
|
|
@ -632,10 +626,16 @@ fn execute_pipeline<'a>(
|
|||
search_mode,
|
||||
)
|
||||
.await?;
|
||||
bindings.insert(variable.clone(), batch);
|
||||
let prefixed = prefix_batch(&batch, variable)?;
|
||||
*wide = Some(match wide.take() {
|
||||
None => prefixed,
|
||||
Some(existing) => cross_join_batches(&existing, &prefixed)?,
|
||||
});
|
||||
}
|
||||
IROp::Filter(filter) => {
|
||||
apply_filter(bindings, filter, params)?;
|
||||
if let Some(batch) = wide.as_mut() {
|
||||
apply_filter(batch, filter, params)?;
|
||||
}
|
||||
}
|
||||
IROp::Expand {
|
||||
src_var,
|
||||
|
|
@ -649,17 +649,20 @@ fn execute_pipeline<'a>(
|
|||
let gi = graph_index.ok_or_else(|| {
|
||||
OmniError::manifest("graph index required for traversal".to_string())
|
||||
})?;
|
||||
let batch = execute_expand(
|
||||
bindings, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
|
||||
dst_type, *min_hops, *max_hops,
|
||||
)
|
||||
.await?;
|
||||
bindings.insert(dst_var.clone(), batch);
|
||||
if let Some(batch) = wide.as_mut() {
|
||||
execute_expand(
|
||||
batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
|
||||
dst_type, *min_hops, *max_hops,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
IROp::AntiJoin { outer_var, inner } => {
|
||||
let gi = graph_index;
|
||||
execute_anti_join(bindings, inner, params, snapshot, gi, catalog, outer_var)
|
||||
.await?;
|
||||
if let Some(batch) = wide.as_mut() {
|
||||
execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -669,28 +672,26 @@ fn execute_pipeline<'a>(
|
|||
|
||||
/// Execute a graph traversal (Expand).
|
||||
async fn execute_expand(
|
||||
bindings: &HashMap<String, RecordBatch>,
|
||||
wide: &mut RecordBatch,
|
||||
graph_index: &GraphIndex,
|
||||
snapshot: &Snapshot,
|
||||
catalog: &Catalog,
|
||||
src_var: &str,
|
||||
_dst_var: &str,
|
||||
dst_var: &str,
|
||||
edge_type: &str,
|
||||
direction: Direction,
|
||||
dst_type: &str,
|
||||
min_hops: u32,
|
||||
max_hops: Option<u32>,
|
||||
) -> Result<RecordBatch> {
|
||||
let src_batch = bindings.get(src_var).ok_or_else(|| {
|
||||
OmniError::manifest(format!("expand references unbound variable '{}'", src_var))
|
||||
})?;
|
||||
|
||||
let src_ids = src_batch
|
||||
.column_by_name("id")
|
||||
.ok_or_else(|| OmniError::manifest("source batch missing 'id' column".to_string()))?
|
||||
) -> Result<()> {
|
||||
let src_id_col_name = format!("{}.id", src_var);
|
||||
let src_ids = wide
|
||||
.column_by_name(&src_id_col_name)
|
||||
.ok_or_else(|| OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name)))?
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.ok_or_else(|| OmniError::manifest("source 'id' column is not Utf8".to_string()))?;
|
||||
.ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
|
||||
.clone();
|
||||
|
||||
// Determine which type index to use for source and destination
|
||||
let edge_def = catalog
|
||||
|
|
@ -720,8 +721,9 @@ async fn execute_expand(
|
|||
|
||||
let same_type = src_type_name == dst_type_name;
|
||||
|
||||
// BFS to collect reachable destination dense IDs
|
||||
let mut result_dst_ids: Vec<String> = Vec::new();
|
||||
// BFS to collect (src_row_idx, dst_id) pairs with per-source dedup
|
||||
let mut src_indices: Vec<u32> = Vec::new();
|
||||
let mut dst_id_list: Vec<String> = Vec::new();
|
||||
for i in 0..src_ids.len() {
|
||||
let src_id = src_ids.value(i);
|
||||
let Some(src_dense) = src_type_idx.to_dense(src_id) else {
|
||||
|
|
@ -749,7 +751,8 @@ async fn execute_expand(
|
|||
if let Some(dst_id) = dst_type_idx.to_id(neighbor) {
|
||||
let dst_id = dst_id.to_string();
|
||||
if seen_dst_ids.insert(dst_id.clone()) {
|
||||
result_dst_ids.push(dst_id);
|
||||
src_indices.push(i as u32);
|
||||
dst_id_list.push(dst_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -764,7 +767,36 @@ async fn execute_expand(
|
|||
}
|
||||
|
||||
// Hydrate destination nodes from the snapshot
|
||||
hydrate_nodes(snapshot, catalog, dst_type, &result_dst_ids).await
|
||||
let dst_batch = hydrate_nodes(snapshot, catalog, dst_type, &dst_id_list).await?;
|
||||
|
||||
// Build a mapping from dst_id to row index in dst_batch
|
||||
let dst_batch_id_col = dst_batch
|
||||
.column_by_name("id")
|
||||
.ok_or_else(|| OmniError::manifest("hydrated batch missing 'id' column".to_string()))?
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.ok_or_else(|| OmniError::manifest("hydrated 'id' column is not Utf8".to_string()))?;
|
||||
let dst_id_to_row: HashMap<&str, usize> = (0..dst_batch_id_col.len())
|
||||
.map(|i| (dst_batch_id_col.value(i), i))
|
||||
.collect();
|
||||
|
||||
// Build aligned src/dst index arrays (only for IDs that exist in hydrated batch)
|
||||
let mut final_src_indices: Vec<u32> = Vec::new();
|
||||
let mut dst_indices: Vec<u32> = Vec::new();
|
||||
for (src_idx, dst_id) in src_indices.iter().zip(dst_id_list.iter()) {
|
||||
if let Some(&dst_row) = dst_id_to_row.get(dst_id.as_str()) {
|
||||
final_src_indices.push(*src_idx);
|
||||
dst_indices.push(dst_row as u32);
|
||||
}
|
||||
}
|
||||
|
||||
let src_take = UInt32Array::from(final_src_indices);
|
||||
let dst_take = UInt32Array::from(dst_indices);
|
||||
let expanded_wide = take_batch(wide, &src_take)?;
|
||||
let dst_prefixed = prefix_batch(&dst_batch, dst_var)?;
|
||||
let aligned_dst = take_batch(&dst_prefixed, &dst_take)?;
|
||||
*wide = hconcat_batches(&expanded_wide, &aligned_dst)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load full node rows for a set of IDs from a snapshot.
|
||||
|
|
@ -829,15 +861,15 @@ async fn hydrate_nodes(
|
|||
Ok(scan_result)
|
||||
}
|
||||
|
||||
/// Try bulk anti-join via CSR existence check. Returns Some if the inner
|
||||
/// Try bulk anti-join via CSR existence check. Returns Some(mask) if the inner
|
||||
/// pipeline is a single Expand from outer_var (the common negation pattern).
|
||||
fn try_bulk_anti_join(
|
||||
outer_batch: &RecordBatch,
|
||||
fn try_bulk_anti_join_mask(
|
||||
wide: &RecordBatch,
|
||||
inner_pipeline: &[IROp],
|
||||
graph_index: Option<&GraphIndex>,
|
||||
catalog: &Catalog,
|
||||
outer_var: &str,
|
||||
) -> Option<Result<RecordBatch>> {
|
||||
) -> Option<BooleanArray> {
|
||||
if inner_pipeline.len() != 1 {
|
||||
return None;
|
||||
}
|
||||
|
|
@ -866,8 +898,9 @@ fn try_bulk_anti_join(
|
|||
}?;
|
||||
let type_idx = gi.type_index(src_type_name)?;
|
||||
|
||||
let outer_ids = outer_batch
|
||||
.column_by_name("id")?
|
||||
let id_col_name = format!("{}.id", outer_var);
|
||||
let outer_ids = wide
|
||||
.column_by_name(&id_col_name)?
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()?;
|
||||
|
||||
|
|
@ -881,16 +914,12 @@ fn try_bulk_anti_join(
|
|||
})
|
||||
.collect();
|
||||
|
||||
let mask = BooleanArray::from(keep_mask);
|
||||
Some(
|
||||
arrow_select::filter::filter_record_batch(outer_batch, &mask)
|
||||
.map_err(|e| OmniError::Lance(e.to_string())),
|
||||
)
|
||||
Some(BooleanArray::from(keep_mask))
|
||||
}
|
||||
|
||||
/// Execute an AntiJoin: remove rows from outer_var where the inner pipeline finds matches.
|
||||
/// Execute an AntiJoin: remove rows from wide batch where the inner pipeline finds matches.
|
||||
async fn execute_anti_join(
|
||||
bindings: &mut HashMap<String, RecordBatch>,
|
||||
wide: &mut RecordBatch,
|
||||
inner_pipeline: &[IROp],
|
||||
params: &ParamMap,
|
||||
snapshot: &Snapshot,
|
||||
|
|
@ -898,35 +927,22 @@ async fn execute_anti_join(
|
|||
catalog: &Catalog,
|
||||
outer_var: &str,
|
||||
) -> Result<()> {
|
||||
let outer_batch = bindings.get(outer_var).ok_or_else(|| {
|
||||
OmniError::manifest(format!(
|
||||
"anti-join references unbound variable '{}'",
|
||||
outer_var
|
||||
))
|
||||
})?;
|
||||
|
||||
// Fast path: bulk CSR existence check (O(N), zero Lance I/O)
|
||||
if let Some(result) =
|
||||
try_bulk_anti_join(outer_batch, inner_pipeline, graph_index, catalog, outer_var)
|
||||
if let Some(mask) =
|
||||
try_bulk_anti_join_mask(wide, inner_pipeline, graph_index, catalog, outer_var)
|
||||
{
|
||||
bindings.insert(outer_var.to_string(), result?);
|
||||
*wide = arrow_select::filter::filter_record_batch(wide, &mask)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Slow path: per-row inner pipeline execution
|
||||
let outer_ids = outer_batch
|
||||
.column_by_name("id")
|
||||
.ok_or_else(|| OmniError::manifest("outer batch missing 'id' column".to_string()))?
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.ok_or_else(|| OmniError::manifest("outer 'id' column is not Utf8".to_string()))?;
|
||||
let num_rows = wide.num_rows();
|
||||
let mut keep_mask = vec![true; num_rows];
|
||||
|
||||
let mut keep_mask = vec![true; outer_batch.num_rows()];
|
||||
|
||||
for i in 0..outer_ids.len() {
|
||||
let single_row = outer_batch.slice(i, 1);
|
||||
let mut inner_bindings: HashMap<String, RecordBatch> = HashMap::new();
|
||||
inner_bindings.insert(outer_var.to_string(), single_row);
|
||||
for i in 0..num_rows {
|
||||
let single_row = wide.slice(i, 1);
|
||||
let mut inner_wide: Option<RecordBatch> = Some(single_row);
|
||||
|
||||
let no_search = SearchMode::default();
|
||||
execute_pipeline(
|
||||
|
|
@ -935,15 +951,15 @@ async fn execute_anti_join(
|
|||
snapshot,
|
||||
graph_index,
|
||||
catalog,
|
||||
&mut inner_bindings,
|
||||
&mut inner_wide,
|
||||
&no_search,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let has_match = inner_bindings
|
||||
.iter()
|
||||
.filter(|(k, _)| *k != outer_var)
|
||||
.any(|(_, batch)| batch.num_rows() > 0);
|
||||
let has_match = inner_wide
|
||||
.as_ref()
|
||||
.map(|batch| batch.num_rows() > 0)
|
||||
.unwrap_or(false);
|
||||
|
||||
if has_match {
|
||||
keep_mask[i] = false;
|
||||
|
|
@ -951,10 +967,8 @@ async fn execute_anti_join(
|
|||
}
|
||||
|
||||
let mask = BooleanArray::from(keep_mask);
|
||||
let filtered = arrow_select::filter::filter_record_batch(outer_batch, &mask)
|
||||
*wide = arrow_select::filter::filter_record_batch(wide, &mask)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
|
||||
bindings.insert(outer_var.to_string(), filtered);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -1220,3 +1234,50 @@ pub(super) fn literal_to_sql(lit: &Literal) -> String {
|
|||
Literal::List(_) => "NULL".to_string(), // Not supported in SQL pushdown
|
||||
}
|
||||
}
|
||||
|
||||
fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
|
||||
let fields: Vec<Field> = batch.schema().fields().iter().map(|f| {
|
||||
Field::new(format!("{}.{}", variable, f.name()), f.data_type().clone(), f.is_nullable())
|
||||
}).collect();
|
||||
let schema = Arc::new(Schema::new(fields));
|
||||
RecordBatch::try_new(schema, batch.columns().to_vec()).map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
|
||||
let n = left.num_rows();
|
||||
let m = right.num_rows();
|
||||
if n == 0 || m == 0 {
|
||||
let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
|
||||
fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
|
||||
return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields))));
|
||||
}
|
||||
let left_indices: Vec<u32> = (0..n as u32).flat_map(|i| std::iter::repeat(i).take(m)).collect();
|
||||
let right_indices: Vec<u32> = (0..n).flat_map(|_| 0..m as u32).collect();
|
||||
let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?;
|
||||
let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?;
|
||||
hconcat_batches(&left_expanded, &right_expanded)
|
||||
}
|
||||
|
||||
fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
|
||||
let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
|
||||
if cfg!(debug_assertions) {
|
||||
let left_schema = left.schema();
|
||||
let left_names: HashSet<&str> = left_schema.fields().iter().map(|f| f.name().as_str()).collect();
|
||||
let right_schema = right.schema();
|
||||
for f in right_schema.fields() {
|
||||
debug_assert!(!left_names.contains(f.name().as_str()), "hconcat_batches: duplicate column '{}'", f.name());
|
||||
}
|
||||
}
|
||||
fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
|
||||
let mut columns: Vec<ArrayRef> = left.columns().to_vec();
|
||||
columns.extend(right.columns().to_vec());
|
||||
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
|
||||
let columns: Vec<ArrayRef> = batch.columns().iter()
|
||||
.map(|col| arrow_select::take::take(col.as_ref(), indices, None))
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
|
|
|||
302
crates/omnigraph/tests/aggregation.rs
Normal file
302
crates/omnigraph/tests/aggregation.rs
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
mod helpers;
|
||||
|
||||
use arrow_array::{Array, Float64Array, Int32Array, Int64Array, StringArray};
|
||||
|
||||
use omnigraph::db::Omnigraph;
|
||||
use omnigraph::loader::{LoadMode, load_jsonl};
|
||||
use omnigraph_compiler::ir::ParamMap;
|
||||
|
||||
use helpers::*;
|
||||
|
||||
// ─── Count aggregate with GROUP BY ─────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn friend_counts_grouped_by_person() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
// Test data: Alice knows Bob, Alice knows Charlie, Bob knows Diana
|
||||
// So: Alice=2 friends, Bob=1 friend (Charlie & Diana have no outgoing knows → dropped)
|
||||
let result = query_main(&mut db, TEST_QUERIES, "friend_counts", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let batch = result.concat_batches().unwrap();
|
||||
assert_eq!(batch.num_rows(), 2);
|
||||
|
||||
let names = batch
|
||||
.column_by_name("p.name")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let counts = batch
|
||||
.column_by_name("friends")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int64Array>()
|
||||
.unwrap();
|
||||
|
||||
// Ordered by friends desc: Alice(2), Bob(1)
|
||||
assert_eq!(names.value(0), "Alice");
|
||||
assert_eq!(counts.value(0), 2);
|
||||
assert_eq!(names.value(1), "Bob");
|
||||
assert_eq!(counts.value(1), 1);
|
||||
}
|
||||
|
||||
// ─── Global count (no group key) ───────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn total_people_global_count() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
let result = query_main(&mut db, TEST_QUERIES, "total_people", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let batch = result.concat_batches().unwrap();
|
||||
assert_eq!(batch.num_rows(), 1);
|
||||
|
||||
let total = batch
|
||||
.column_by_name("total")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int64Array>()
|
||||
.unwrap();
|
||||
assert_eq!(total.value(0), 4);
|
||||
}
|
||||
|
||||
// ─── Sum/Avg/Min/Max aggregates ────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn age_stats_per_company() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
// Test data: Alice(30) worksAt Acme, Bob(25) worksAt Globex
|
||||
let result = query_main(&mut db, TEST_QUERIES, "age_stats", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let batch = result.concat_batches().unwrap();
|
||||
assert_eq!(batch.num_rows(), 2);
|
||||
|
||||
let names = batch
|
||||
.column_by_name("c.name")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
|
||||
// Find the Acme row (order not guaranteed without ORDER BY)
|
||||
let acme_row = (0..batch.num_rows())
|
||||
.find(|&i| names.value(i) == "Acme")
|
||||
.expect("Acme should be in results");
|
||||
let globex_row = (0..batch.num_rows())
|
||||
.find(|&i| names.value(i) == "Globex")
|
||||
.expect("Globex should be in results");
|
||||
|
||||
let headcount = batch
|
||||
.column_by_name("headcount")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int64Array>()
|
||||
.unwrap();
|
||||
assert_eq!(headcount.value(acme_row), 1);
|
||||
assert_eq!(headcount.value(globex_row), 1);
|
||||
|
||||
let avg_age = batch
|
||||
.column_by_name("avg_age")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Float64Array>()
|
||||
.unwrap();
|
||||
assert!((avg_age.value(acme_row) - 30.0).abs() < 0.01);
|
||||
assert!((avg_age.value(globex_row) - 25.0).abs() < 0.01);
|
||||
|
||||
let youngest = batch
|
||||
.column_by_name("youngest")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap();
|
||||
assert_eq!(youngest.value(acme_row), 30);
|
||||
assert_eq!(youngest.value(globex_row), 25);
|
||||
|
||||
let total_age = batch
|
||||
.column_by_name("total_age")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Float64Array>()
|
||||
.unwrap();
|
||||
assert!((total_age.value(acme_row) - 30.0).abs() < 0.01);
|
||||
assert!((total_age.value(globex_row) - 25.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
// ─── Aggregate + order + limit ─────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn top_connected_person() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
// Alice has 2 friends (most connected)
|
||||
let result = query_main(&mut db, TEST_QUERIES, "top_connected", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let batch = result.concat_batches().unwrap();
|
||||
assert_eq!(batch.num_rows(), 1);
|
||||
|
||||
let names = batch
|
||||
.column_by_name("p.name")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let counts = batch
|
||||
.column_by_name("friends")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int64Array>()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(names.value(0), "Alice");
|
||||
assert_eq!(counts.value(0), 2);
|
||||
}
|
||||
|
||||
// ─── Inline aggregate query (not from fixture) ────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn inline_count_with_no_friends() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
// Diana has no outgoing Knows edges, so she won't appear in friend_counts.
|
||||
// But Alice(2) and Bob(1) do. Verify the count matches.
|
||||
let queries = r#"
|
||||
query fc() {
|
||||
match {
|
||||
$p: Person
|
||||
$p knows $f
|
||||
}
|
||||
return { $p.name, count($f) as friends }
|
||||
order { friends desc }
|
||||
}
|
||||
"#;
|
||||
let result = query_main(&mut db, queries, "fc", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = result.concat_batches().unwrap();
|
||||
// Only Alice and Bob have outgoing Knows edges
|
||||
assert_eq!(batch.num_rows(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn inline_global_count_empty() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
|
||||
|
||||
// Load only nodes, no edges
|
||||
let data = r#"{"type": "Person", "data": {"name": "Alice", "age": 30}}
|
||||
{"type": "Company", "data": {"name": "Acme"}}"#;
|
||||
load_jsonl(&mut db, data, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Global count — should return 1 row with count=1
|
||||
let queries = r#"
|
||||
query tc() {
|
||||
match { $p: Person }
|
||||
return { count($p) as total }
|
||||
}
|
||||
"#;
|
||||
let result = query_main(&mut db, queries, "tc", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = result.concat_batches().unwrap();
|
||||
assert_eq!(batch.num_rows(), 1);
|
||||
let total = batch
|
||||
.column_by_name("total")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int64Array>()
|
||||
.unwrap();
|
||||
assert_eq!(total.value(0), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn inline_min_max_string() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
let queries = r#"
|
||||
query name_range() {
|
||||
match { $p: Person }
|
||||
return {
|
||||
min($p.name) as first_name
|
||||
max($p.name) as last_name
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let result = query_main(&mut db, queries, "name_range", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = result.concat_batches().unwrap();
|
||||
assert_eq!(batch.num_rows(), 1);
|
||||
|
||||
let first = batch
|
||||
.column_by_name("first_name")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let last = batch
|
||||
.column_by_name("last_name")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
assert_eq!(first.value(0), "Alice");
|
||||
assert_eq!(last.value(0), "Diana");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn inline_multi_hop_aggregate() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
// Friends of friends count per person
|
||||
let queries = r#"
|
||||
query fof_counts($name: String) {
|
||||
match {
|
||||
$p: Person { name: $name }
|
||||
$p knows $mid
|
||||
$mid knows $fof
|
||||
}
|
||||
return { $p.name, count($fof) as fof_count }
|
||||
}
|
||||
"#;
|
||||
// Alice → Bob, Charlie. Bob → Diana. Charlie → nobody.
|
||||
// So Alice's fof = [Diana] → count = 1
|
||||
let result = query_main(
|
||||
&mut db,
|
||||
queries,
|
||||
"fof_counts",
|
||||
&params(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = result.concat_batches().unwrap();
|
||||
assert_eq!(batch.num_rows(), 1);
|
||||
|
||||
let count = batch
|
||||
.column_by_name("fof_count")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int64Array>()
|
||||
.unwrap();
|
||||
assert_eq!(count.value(0), 1);
|
||||
}
|
||||
35
crates/omnigraph/tests/fixtures/test.gq
vendored
35
crates/omnigraph/tests/fixtures/test.gq
vendored
|
|
@ -76,3 +76,38 @@ query top_by_age() {
|
|||
order { $p.age desc }
|
||||
limit 2
|
||||
}
|
||||
|
||||
// Aggregation: global count (no group key)
|
||||
query total_people() {
|
||||
match {
|
||||
$p: Person
|
||||
}
|
||||
return { count($p) as total }
|
||||
}
|
||||
|
||||
// Aggregation: sum/avg/min/max of age per company
|
||||
query age_stats() {
|
||||
match {
|
||||
$p: Person
|
||||
$p worksAt $c
|
||||
}
|
||||
return {
|
||||
$c.name
|
||||
sum($p.age) as total_age
|
||||
avg($p.age) as avg_age
|
||||
min($p.age) as youngest
|
||||
max($p.age) as oldest
|
||||
count($p) as headcount
|
||||
}
|
||||
}
|
||||
|
||||
// Aggregation: most connected person
|
||||
query top_connected() {
|
||||
match {
|
||||
$p: Person
|
||||
$p knows $f
|
||||
}
|
||||
return { $p.name, count($f) as friends }
|
||||
order { friends desc }
|
||||
limit 1
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue