mirror of
https://github.com/ModernRelay/omnigraph.git
synced 2026-06-30 02:49:39 +02:00
Merge pull request #6 from ModernRelay/claude/omnigraph-aggregates-a53rG
Implement aggregate functions with GROUP BY support
This commit is contained in:
commit
c5a88cacb5
6 changed files with 868 additions and 152 deletions
|
|
@ -189,6 +189,29 @@ fn typecheck_read_query(catalog: &Catalog, query: &QueryDecl) -> Result<TypeCont
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// T9: If any return expression is an aggregate, non-aggregate expressions
|
||||||
|
// must be valid group-by keys (PropAccess or Variable).
|
||||||
|
let has_agg = query
|
||||||
|
.return_clause
|
||||||
|
.iter()
|
||||||
|
.any(|p| matches!(p.expr, Expr::Aggregate { .. }));
|
||||||
|
if has_agg {
|
||||||
|
for proj in &query.return_clause {
|
||||||
|
if !matches!(proj.expr, Expr::Aggregate { .. }) {
|
||||||
|
match &proj.expr {
|
||||||
|
Expr::PropAccess { .. } | Expr::Variable(_) => {}
|
||||||
|
_ => {
|
||||||
|
return Err(NanoError::Type(
|
||||||
|
"T9: non-aggregate expressions in an aggregate query must be \
|
||||||
|
property accesses or variables"
|
||||||
|
.to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(ctx)
|
Ok(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1298,9 +1321,9 @@ fn resolve_expr_type(
|
||||||
Expr::Aggregate { func, arg } => {
|
Expr::Aggregate { func, arg } => {
|
||||||
let arg_type = resolve_expr_type(catalog, arg, ctx, params)?;
|
let arg_type = resolve_expr_type(catalog, arg, ctx, params)?;
|
||||||
|
|
||||||
// T8: sum/avg/min/max require numeric
|
// T8: sum/avg require numeric; min/max require numeric or string
|
||||||
match func {
|
match func {
|
||||||
AggFunc::Sum | AggFunc::Avg | AggFunc::Min | AggFunc::Max => {
|
AggFunc::Sum | AggFunc::Avg => {
|
||||||
if let ResolvedType::Scalar(s) = &arg_type
|
if let ResolvedType::Scalar(s) = &arg_type
|
||||||
&& (s.list || !s.scalar.is_numeric())
|
&& (s.list || !s.scalar.is_numeric())
|
||||||
{
|
{
|
||||||
|
|
@ -1311,6 +1334,17 @@ fn resolve_expr_type(
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
AggFunc::Min | AggFunc::Max => {
|
||||||
|
if let ResolvedType::Scalar(s) = &arg_type
|
||||||
|
&& (s.list || (!s.scalar.is_numeric() && s.scalar != ScalarType::String))
|
||||||
|
{
|
||||||
|
return Err(NanoError::Type(format!(
|
||||||
|
"T8: {} requires numeric or string type, got {}",
|
||||||
|
func,
|
||||||
|
s.display_name()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => {} // count works on any type
|
_ => {} // count works on any type
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1340,8 +1374,8 @@ fn infer_projection_field(
|
||||||
Expr::Aggregate { func, arg } => {
|
Expr::Aggregate { func, arg } => {
|
||||||
let (data_type, nullable) = match func {
|
let (data_type, nullable) = match func {
|
||||||
AggFunc::Count => (DataType::Int64, true),
|
AggFunc::Count => (DataType::Int64, true),
|
||||||
AggFunc::Avg => (DataType::Float64, true),
|
AggFunc::Avg | AggFunc::Sum => (DataType::Float64, true),
|
||||||
_ => {
|
AggFunc::Min | AggFunc::Max => {
|
||||||
let resolved = resolve_expr_type(catalog, arg, ctx, params)?;
|
let resolved = resolve_expr_type(catalog, arg, ctx, params)?;
|
||||||
let (data_type, _) = resolved_type_to_field_shape(catalog, &resolved)?;
|
let (data_type, _) = resolved_type_to_field_shape(catalog, &resolved)?;
|
||||||
(data_type, true)
|
(data_type, true)
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ use omnigraph_compiler::ir::{
|
||||||
};
|
};
|
||||||
use omnigraph_compiler::lower_mutation_query;
|
use omnigraph_compiler::lower_mutation_query;
|
||||||
use omnigraph_compiler::lower_query;
|
use omnigraph_compiler::lower_query;
|
||||||
use omnigraph_compiler::query::ast::{CompOp, Literal, NOW_PARAM_NAME};
|
use omnigraph_compiler::query::ast::{AggFunc, CompOp, Literal, NOW_PARAM_NAME};
|
||||||
use omnigraph_compiler::query::typecheck::{CheckedQuery, typecheck_query, typecheck_query_decl};
|
use omnigraph_compiler::query::typecheck::{CheckedQuery, typecheck_query, typecheck_query_decl};
|
||||||
use omnigraph_compiler::result::{MutationResult, QueryResult};
|
use omnigraph_compiler::result::{MutationResult, QueryResult};
|
||||||
use omnigraph_compiler::types::Direction;
|
use omnigraph_compiler::types::Direction;
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,14 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub(super) fn apply_filter(
|
pub(super) fn apply_filter(
|
||||||
bindings: &mut HashMap<String, RecordBatch>,
|
batch: &mut RecordBatch,
|
||||||
filter: &IRFilter,
|
filter: &IRFilter,
|
||||||
params: &ParamMap,
|
params: &ParamMap,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// Find which binding this filter applies to
|
|
||||||
let var_name = match &filter.left {
|
|
||||||
IRExpr::PropAccess { variable, .. } => variable.clone(),
|
|
||||||
_ => return Ok(()), // Can't determine variable
|
|
||||||
};
|
|
||||||
|
|
||||||
let batch = bindings.get(&var_name).ok_or_else(|| {
|
|
||||||
OmniError::manifest(format!("filter references unbound variable '{}'", var_name))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let mask = evaluate_filter(batch, filter, params)?;
|
let mask = evaluate_filter(batch, filter, params)?;
|
||||||
let filtered = arrow_select::filter::filter_record_batch(batch, &mask)
|
let filtered = arrow_select::filter::filter_record_batch(batch, &mask)
|
||||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||||
|
*batch = filtered;
|
||||||
bindings.insert(var_name, filtered);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -59,12 +48,13 @@ fn evaluate_filter(
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Evaluate an IR expression against a batch, producing an array.
|
/// Evaluate an IR expression against a wide batch, producing an array.
|
||||||
fn evaluate_expr(batch: &RecordBatch, expr: &IRExpr, params: &ParamMap) -> Result<ArrayRef> {
|
fn evaluate_expr(batch: &RecordBatch, expr: &IRExpr, params: &ParamMap) -> Result<ArrayRef> {
|
||||||
match expr {
|
match expr {
|
||||||
IRExpr::PropAccess { property, .. } => {
|
IRExpr::PropAccess { variable, property } => {
|
||||||
batch.column_by_name(property).cloned().ok_or_else(|| {
|
let col_name = format!("{}.{}", variable, property);
|
||||||
OmniError::manifest(format!("column '{}' not found in batch", property))
|
batch.column_by_name(&col_name).cloned().ok_or_else(|| {
|
||||||
|
OmniError::manifest(format!("column '{}' not found in wide batch", col_name))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
IRExpr::Literal(lit) => literal_to_array(lit, batch.num_rows()),
|
IRExpr::Literal(lit) => literal_to_array(lit, batch.num_rows()),
|
||||||
|
|
@ -307,7 +297,7 @@ fn literal_scalar_type(lit: &Literal) -> Result<ScalarType> {
|
||||||
|
|
||||||
/// Project return expressions into a result batch.
|
/// Project return expressions into a result batch.
|
||||||
pub(super) fn project_return(
|
pub(super) fn project_return(
|
||||||
bindings: &HashMap<String, RecordBatch>,
|
wide_batch: &RecordBatch,
|
||||||
projections: &[IRProjection],
|
projections: &[IRProjection],
|
||||||
params: &ParamMap,
|
params: &ParamMap,
|
||||||
) -> Result<RecordBatch> {
|
) -> Result<RecordBatch> {
|
||||||
|
|
@ -317,11 +307,19 @@ pub(super) fn project_return(
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Route to aggregate path if any projection contains an aggregate
|
||||||
|
let has_aggregates = projections
|
||||||
|
.iter()
|
||||||
|
.any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
|
||||||
|
if has_aggregates {
|
||||||
|
return aggregate_return(wide_batch, projections, params);
|
||||||
|
}
|
||||||
|
|
||||||
let mut fields = Vec::with_capacity(projections.len());
|
let mut fields = Vec::with_capacity(projections.len());
|
||||||
let mut columns: Vec<ArrayRef> = Vec::with_capacity(projections.len());
|
let mut columns: Vec<ArrayRef> = Vec::with_capacity(projections.len());
|
||||||
|
|
||||||
for proj in projections {
|
for proj in projections {
|
||||||
let (name, col) = evaluate_projection(bindings, &proj.expr, params)?;
|
let (name, col) = evaluate_projection(wide_batch, &proj.expr, params)?;
|
||||||
let field_name = proj.alias.as_deref().unwrap_or(&name);
|
let field_name = proj.alias.as_deref().unwrap_or(&name);
|
||||||
fields.push(Field::new(
|
fields.push(Field::new(
|
||||||
field_name,
|
field_name,
|
||||||
|
|
@ -335,42 +333,41 @@ pub(super) fn project_return(
|
||||||
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
|
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Evaluate a single projection expression.
|
/// Evaluate a single projection expression against a wide batch.
|
||||||
fn evaluate_projection(
|
fn evaluate_projection(
|
||||||
bindings: &HashMap<String, RecordBatch>,
|
wide_batch: &RecordBatch,
|
||||||
expr: &IRExpr,
|
expr: &IRExpr,
|
||||||
params: &ParamMap,
|
params: &ParamMap,
|
||||||
) -> Result<(String, ArrayRef)> {
|
) -> Result<(String, ArrayRef)> {
|
||||||
match expr {
|
match expr {
|
||||||
IRExpr::PropAccess { variable, property } => {
|
IRExpr::PropAccess { variable, property } => {
|
||||||
let batch = bindings.get(variable).ok_or_else(|| {
|
let col_name = format!("{}.{}", variable, property);
|
||||||
|
let col = wide_batch.column_by_name(&col_name).ok_or_else(|| {
|
||||||
OmniError::manifest(format!(
|
OmniError::manifest(format!(
|
||||||
"projection references unbound variable '{}'",
|
"column '{}' not found in wide batch",
|
||||||
variable
|
col_name
|
||||||
))
|
))
|
||||||
})?;
|
})?;
|
||||||
let col = batch.column_by_name(property).ok_or_else(|| {
|
Ok((col_name, col.clone()))
|
||||||
OmniError::manifest(format!(
|
|
||||||
"column '{}' not found in binding '{}'",
|
|
||||||
property, variable
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
Ok((format!("{}.{}", variable, property), col.clone()))
|
|
||||||
}
|
}
|
||||||
IRExpr::Literal(lit) => {
|
IRExpr::Literal(lit) => {
|
||||||
// Get row count from first binding
|
let arr = literal_to_array(lit, wide_batch.num_rows())?;
|
||||||
let num_rows = bindings.values().next().map(|b| b.num_rows()).unwrap_or(0);
|
|
||||||
let arr = literal_to_array(lit, num_rows)?;
|
|
||||||
Ok(("literal".to_string(), arr))
|
Ok(("literal".to_string(), arr))
|
||||||
}
|
}
|
||||||
IRExpr::Param(name) => {
|
IRExpr::Param(name) => {
|
||||||
let lit = params
|
let lit = params
|
||||||
.get(name)
|
.get(name)
|
||||||
.ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?;
|
.ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?;
|
||||||
let num_rows = bindings.values().next().map(|b| b.num_rows()).unwrap_or(0);
|
let arr = literal_to_array(lit, wide_batch.num_rows())?;
|
||||||
let arr = literal_to_array(lit, num_rows)?;
|
|
||||||
Ok((name.clone(), arr))
|
Ok((name.clone(), arr))
|
||||||
}
|
}
|
||||||
|
IRExpr::Variable(name) => {
|
||||||
|
let col_name = format!("{}.id", name);
|
||||||
|
let col = wide_batch.column_by_name(&col_name).ok_or_else(|| {
|
||||||
|
OmniError::manifest(format!("column '{}' not found in wide batch", col_name))
|
||||||
|
})?;
|
||||||
|
Ok((name.clone(), col.clone()))
|
||||||
|
}
|
||||||
_ => Err(OmniError::manifest(format!(
|
_ => Err(OmniError::manifest(format!(
|
||||||
"unsupported projection expression: {:?}",
|
"unsupported projection expression: {:?}",
|
||||||
expr
|
expr
|
||||||
|
|
@ -378,11 +375,12 @@ fn evaluate_projection(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply ordering to a batch.
|
/// Apply ordering to a batch. `source` is used to resolve PropAccess columns
|
||||||
|
/// (typically the wide batch, or the result batch itself for aggregates).
|
||||||
pub(super) fn apply_ordering(
|
pub(super) fn apply_ordering(
|
||||||
batch: RecordBatch,
|
batch: RecordBatch,
|
||||||
orderings: &[IROrdering],
|
orderings: &[IROrdering],
|
||||||
bindings: &HashMap<String, RecordBatch>,
|
source: &RecordBatch,
|
||||||
_params: &ParamMap,
|
_params: &ParamMap,
|
||||||
) -> Result<RecordBatch> {
|
) -> Result<RecordBatch> {
|
||||||
use arrow_ord::sort::{SortColumn, lexsort_to_indices};
|
use arrow_ord::sort::{SortColumn, lexsort_to_indices};
|
||||||
|
|
@ -392,16 +390,11 @@ pub(super) fn apply_ordering(
|
||||||
for ordering in orderings {
|
for ordering in orderings {
|
||||||
let col = match &ordering.expr {
|
let col = match &ordering.expr {
|
||||||
IRExpr::PropAccess { variable, property } => {
|
IRExpr::PropAccess { variable, property } => {
|
||||||
let binding = bindings.get(variable).ok_or_else(|| {
|
let col_name = format!("{}.{}", variable, property);
|
||||||
OmniError::manifest(format!(
|
source
|
||||||
"ordering references unbound variable '{}'",
|
.column_by_name(&col_name)
|
||||||
variable
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
binding
|
|
||||||
.column_by_name(property)
|
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
OmniError::manifest(format!("column '{}' not found for ordering", property))
|
OmniError::manifest(format!("column '{}' not found for ordering", col_name))
|
||||||
})?
|
})?
|
||||||
.clone()
|
.clone()
|
||||||
}
|
}
|
||||||
|
|
@ -442,3 +435,294 @@ pub(super) fn apply_ordering(
|
||||||
|
|
||||||
RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
|
RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─── Aggregate execution ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Project return expressions that contain aggregates.
|
||||||
|
fn aggregate_return(
|
||||||
|
wide: &RecordBatch,
|
||||||
|
projections: &[IRProjection],
|
||||||
|
params: &ParamMap,
|
||||||
|
) -> Result<RecordBatch> {
|
||||||
|
let num_rows = wide.num_rows();
|
||||||
|
|
||||||
|
struct GroupKey {
|
||||||
|
proj_idx: usize,
|
||||||
|
name: String,
|
||||||
|
column: ArrayRef,
|
||||||
|
}
|
||||||
|
struct AggProj {
|
||||||
|
proj_idx: usize,
|
||||||
|
name: String,
|
||||||
|
func: AggFunc,
|
||||||
|
arg_column: ArrayRef,
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut group_keys: Vec<GroupKey> = Vec::new();
|
||||||
|
let mut agg_projs: Vec<AggProj> = Vec::new();
|
||||||
|
|
||||||
|
for (i, proj) in projections.iter().enumerate() {
|
||||||
|
match &proj.expr {
|
||||||
|
IRExpr::Aggregate { func, arg } => {
|
||||||
|
let (name, col) = evaluate_projection(wide, arg, params)?;
|
||||||
|
let alias = proj.alias.as_deref().unwrap_or(&name);
|
||||||
|
agg_projs.push(AggProj {
|
||||||
|
proj_idx: i,
|
||||||
|
name: alias.to_string(),
|
||||||
|
func: *func,
|
||||||
|
arg_column: col,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let (name, col) = evaluate_projection(wide, &proj.expr, params)?;
|
||||||
|
let alias = proj.alias.as_deref().unwrap_or(&name);
|
||||||
|
group_keys.push(GroupKey {
|
||||||
|
proj_idx: i,
|
||||||
|
name: alias.to_string(),
|
||||||
|
column: col,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle empty input: return a single row with count=0, others=null
|
||||||
|
if num_rows == 0 && group_keys.is_empty() {
|
||||||
|
return build_empty_aggregate_result(projections);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build group assignments
|
||||||
|
let mut group_map: HashMap<String, usize> = HashMap::new();
|
||||||
|
let mut group_indices: Vec<Vec<usize>> = Vec::new();
|
||||||
|
let group_cols: Vec<&ArrayRef> = group_keys.iter().map(|gk| &gk.column).collect();
|
||||||
|
|
||||||
|
if group_keys.is_empty() {
|
||||||
|
group_indices.push((0..num_rows).collect());
|
||||||
|
} else {
|
||||||
|
for row in 0..num_rows {
|
||||||
|
let key = build_group_key(&group_cols, row);
|
||||||
|
let group_idx = match group_map.get(&key) {
|
||||||
|
Some(&idx) => idx,
|
||||||
|
None => {
|
||||||
|
let idx = group_indices.len();
|
||||||
|
group_map.insert(key, idx);
|
||||||
|
group_indices.push(Vec::new());
|
||||||
|
idx
|
||||||
|
}
|
||||||
|
};
|
||||||
|
group_indices[group_idx].push(row);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let num_groups = group_indices.len();
|
||||||
|
let mut result_columns: Vec<(usize, String, ArrayRef)> =
|
||||||
|
Vec::with_capacity(projections.len());
|
||||||
|
|
||||||
|
for gk in &group_keys {
|
||||||
|
let first_row_indices: Vec<u32> =
|
||||||
|
group_indices.iter().map(|rows| rows[0] as u32).collect();
|
||||||
|
let take_idx = UInt32Array::from(first_row_indices);
|
||||||
|
let col = arrow_select::take::take(gk.column.as_ref(), &take_idx, None)
|
||||||
|
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||||
|
result_columns.push((gk.proj_idx, gk.name.clone(), col));
|
||||||
|
}
|
||||||
|
|
||||||
|
for ap in &agg_projs {
|
||||||
|
let col = compute_aggregate(ap.func, &ap.arg_column, &group_indices, num_groups)?;
|
||||||
|
result_columns.push((ap.proj_idx, ap.name.clone(), col));
|
||||||
|
}
|
||||||
|
|
||||||
|
result_columns.sort_by_key(|(idx, _, _)| *idx);
|
||||||
|
|
||||||
|
let fields: Vec<Field> = result_columns
|
||||||
|
.iter()
|
||||||
|
.map(|(_, name, col)| Field::new(name, col.data_type().clone(), true))
|
||||||
|
.collect();
|
||||||
|
let columns: Vec<ArrayRef> = result_columns.into_iter().map(|(_, _, col)| col).collect();
|
||||||
|
|
||||||
|
let schema = Arc::new(Schema::new(fields));
|
||||||
|
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a string key for grouping using length-prefixed encoding.
|
||||||
|
fn build_group_key(group_columns: &[&ArrayRef], row: usize) -> String {
|
||||||
|
let mut key = String::new();
|
||||||
|
for col in group_columns {
|
||||||
|
if col.is_null(row) {
|
||||||
|
key.push('N');
|
||||||
|
} else {
|
||||||
|
let val = arrow_cast::display::array_value_to_string(col, row).unwrap_or_default();
|
||||||
|
key.push('L');
|
||||||
|
key.push_str(&val.len().to_string());
|
||||||
|
key.push(':');
|
||||||
|
key.push_str(&val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
key
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_aggregate(
|
||||||
|
func: AggFunc,
|
||||||
|
arg: &ArrayRef,
|
||||||
|
group_indices: &[Vec<usize>],
|
||||||
|
num_groups: usize,
|
||||||
|
) -> Result<ArrayRef> {
|
||||||
|
match func {
|
||||||
|
AggFunc::Count => {
|
||||||
|
let mut builder = Int64Builder::with_capacity(num_groups);
|
||||||
|
for group in group_indices {
|
||||||
|
let count = group.iter().filter(|&&i| !arg.is_null(i)).count();
|
||||||
|
builder.append_value(count as i64);
|
||||||
|
}
|
||||||
|
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||||
|
}
|
||||||
|
AggFunc::Sum => compute_sum(arg, group_indices, num_groups),
|
||||||
|
AggFunc::Avg => compute_avg(arg, group_indices, num_groups),
|
||||||
|
AggFunc::Min => compute_min_max(arg, group_indices, num_groups, true),
|
||||||
|
AggFunc::Max => compute_min_max(arg, group_indices, num_groups, false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_sum(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize) -> Result<ArrayRef> {
|
||||||
|
macro_rules! sum_numeric {
|
||||||
|
($arr_type:ty, $arg:expr, $dt:expr) => {{
|
||||||
|
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
|
||||||
|
OmniError::manifest(format!("sum: expected {:?}, got {:?}", $dt, $arg.data_type()))
|
||||||
|
})?;
|
||||||
|
let mut builder = Float64Builder::with_capacity(num_groups);
|
||||||
|
for group in group_indices {
|
||||||
|
let mut sum = None;
|
||||||
|
for &i in group {
|
||||||
|
if !arr.is_null(i) {
|
||||||
|
*sum.get_or_insert(0.0f64) += arr.value(i) as f64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
match sum {
|
||||||
|
Some(v) => builder.append_value(v),
|
||||||
|
None => builder.append_null(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
match arg.data_type() {
|
||||||
|
dt @ DataType::Int32 => sum_numeric!(Int32Array, arg, dt),
|
||||||
|
dt @ DataType::Int64 => sum_numeric!(Int64Array, arg, dt),
|
||||||
|
dt @ DataType::UInt32 => sum_numeric!(UInt32Array, arg, dt),
|
||||||
|
dt @ DataType::UInt64 => sum_numeric!(UInt64Array, arg, dt),
|
||||||
|
dt @ DataType::Float32 => sum_numeric!(Float32Array, arg, dt),
|
||||||
|
dt @ DataType::Float64 => sum_numeric!(Float64Array, arg, dt),
|
||||||
|
dt => Err(OmniError::manifest(format!("sum: unsupported type {:?}", dt))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_avg(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize) -> Result<ArrayRef> {
|
||||||
|
macro_rules! avg_typed {
|
||||||
|
($arr_type:ty, $arg:expr) => {{
|
||||||
|
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
|
||||||
|
OmniError::manifest(format!("avg: expected {:?}, got {:?}", stringify!($arr_type), $arg.data_type()))
|
||||||
|
})?;
|
||||||
|
let mut builder = Float64Builder::with_capacity(num_groups);
|
||||||
|
for group in group_indices {
|
||||||
|
let mut sum = 0.0f64;
|
||||||
|
let mut count = 0usize;
|
||||||
|
for &i in group {
|
||||||
|
if !arr.is_null(i) { sum += arr.value(i) as f64; count += 1; }
|
||||||
|
}
|
||||||
|
if count > 0 { builder.append_value(sum / count as f64); } else { builder.append_null(); }
|
||||||
|
}
|
||||||
|
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
match arg.data_type() {
|
||||||
|
DataType::Int32 => avg_typed!(Int32Array, arg),
|
||||||
|
DataType::Int64 => avg_typed!(Int64Array, arg),
|
||||||
|
DataType::UInt32 => avg_typed!(UInt32Array, arg),
|
||||||
|
DataType::UInt64 => avg_typed!(UInt64Array, arg),
|
||||||
|
DataType::Float32 => avg_typed!(Float32Array, arg),
|
||||||
|
DataType::Float64 => avg_typed!(Float64Array, arg),
|
||||||
|
dt => Err(OmniError::manifest(format!("avg: unsupported type {:?}", dt))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_min_max(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize, is_min: bool) -> Result<ArrayRef> {
|
||||||
|
macro_rules! minmax_typed {
|
||||||
|
($arr_type:ty, $builder_type:ty, $arg:expr, $is_min:expr) => {{
|
||||||
|
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
|
||||||
|
OmniError::manifest(format!("min/max: expected {:?}, got {:?}", stringify!($arr_type), $arg.data_type()))
|
||||||
|
})?;
|
||||||
|
let mut builder = <$builder_type>::with_capacity(num_groups);
|
||||||
|
for group in group_indices {
|
||||||
|
let mut result = None;
|
||||||
|
for &i in group {
|
||||||
|
if !arr.is_null(i) {
|
||||||
|
let v = arr.value(i);
|
||||||
|
result = Some(match result {
|
||||||
|
None => v,
|
||||||
|
Some(cur) => if $is_min { if v < cur { v } else { cur } } else { if v > cur { v } else { cur } },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
match result { Some(v) => builder.append_value(v), None => builder.append_null() }
|
||||||
|
}
|
||||||
|
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
match arg.data_type() {
|
||||||
|
DataType::Int32 => minmax_typed!(Int32Array, Int32Builder, arg, is_min),
|
||||||
|
DataType::Int64 => minmax_typed!(Int64Array, Int64Builder, arg, is_min),
|
||||||
|
DataType::UInt32 => minmax_typed!(UInt32Array, UInt32Builder, arg, is_min),
|
||||||
|
DataType::UInt64 => minmax_typed!(UInt64Array, UInt64Builder, arg, is_min),
|
||||||
|
DataType::Float32 => minmax_typed!(Float32Array, Float32Builder, arg, is_min),
|
||||||
|
DataType::Float64 => minmax_typed!(Float64Array, Float64Builder, arg, is_min),
|
||||||
|
DataType::Utf8 => {
|
||||||
|
let arr = arg.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
|
||||||
|
OmniError::manifest(format!("min/max: expected Utf8, got {:?}", arg.data_type()))
|
||||||
|
})?;
|
||||||
|
let mut builder = StringBuilder::with_capacity(num_groups, num_groups * 16);
|
||||||
|
for group in group_indices {
|
||||||
|
let mut result: Option<&str> = None;
|
||||||
|
for &i in group {
|
||||||
|
if !arr.is_null(i) {
|
||||||
|
let v = arr.value(i);
|
||||||
|
result = Some(match result {
|
||||||
|
None => v,
|
||||||
|
Some(cur) => if is_min { if v < cur { v } else { cur } } else { if v > cur { v } else { cur } },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
match result { Some(v) => builder.append_value(v), None => builder.append_null() }
|
||||||
|
}
|
||||||
|
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||||
|
}
|
||||||
|
dt => Err(OmniError::manifest(format!("min/max: unsupported type {:?}", dt))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a single-row result for an aggregate query with zero input rows and no group keys.
|
||||||
|
fn build_empty_aggregate_result(projections: &[IRProjection]) -> Result<RecordBatch> {
|
||||||
|
let mut fields = Vec::with_capacity(projections.len());
|
||||||
|
let mut columns: Vec<ArrayRef> = Vec::with_capacity(projections.len());
|
||||||
|
|
||||||
|
for proj in projections {
|
||||||
|
let name = proj.alias.as_deref().unwrap_or("?");
|
||||||
|
match &proj.expr {
|
||||||
|
IRExpr::Aggregate { func, .. } => match func {
|
||||||
|
AggFunc::Count => {
|
||||||
|
fields.push(Field::new(name, DataType::Int64, true));
|
||||||
|
columns.push(Arc::new(Int64Array::from(vec![0i64])) as ArrayRef);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
fields.push(Field::new(name, DataType::Float64, true));
|
||||||
|
columns.push(Arc::new(Float64Array::from(vec![None as Option<f64>])) as ArrayRef);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
fields.push(Field::new(name, DataType::Utf8, true));
|
||||||
|
columns.push(Arc::new(StringArray::from(vec![None as Option<&str>])) as ArrayRef);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let schema = Arc::new(Schema::new(fields));
|
||||||
|
RecordBatch::try_new(schema, columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -358,25 +358,21 @@ pub async fn execute_query(
|
||||||
return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
|
return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut bindings: HashMap<String, RecordBatch> = HashMap::new();
|
let mut wide: Option<RecordBatch> = None;
|
||||||
|
execute_pipeline(&ir.pipeline, params, snapshot, graph_index, catalog, &mut wide, &search_mode).await?;
|
||||||
execute_pipeline(
|
let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty())));
|
||||||
&ir.pipeline,
|
|
||||||
params,
|
|
||||||
snapshot,
|
|
||||||
graph_index,
|
|
||||||
catalog,
|
|
||||||
&mut bindings,
|
|
||||||
&search_mode,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Project return expressions
|
// Project return expressions
|
||||||
let mut result_batch = project_return(&bindings, &ir.return_exprs, params)?;
|
let has_aggregates = ir.return_exprs.iter().any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
|
||||||
|
let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;
|
||||||
|
|
||||||
// Apply ordering (skip if search mode already ordered the results)
|
// Apply ordering (skip if search mode already ordered the results)
|
||||||
if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) {
|
if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) {
|
||||||
result_batch = apply_ordering(result_batch, &ir.order_by, &bindings, params)?;
|
result_batch = if has_aggregates {
|
||||||
|
apply_ordering(result_batch.clone(), &ir.order_by, &result_batch, params)?
|
||||||
|
} else {
|
||||||
|
apply_ordering(result_batch, &ir.order_by, &wide_batch, params)?
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply limit
|
// Apply limit
|
||||||
|
|
@ -403,27 +399,27 @@ async fn execute_rrf_query(
|
||||||
rrf: &RrfMode,
|
rrf: &RrfMode,
|
||||||
) -> Result<QueryResult> {
|
) -> Result<QueryResult> {
|
||||||
// Execute primary search
|
// Execute primary search
|
||||||
let mut primary_bindings: HashMap<String, RecordBatch> = HashMap::new();
|
let mut primary_wide: Option<RecordBatch> = None;
|
||||||
execute_pipeline(
|
execute_pipeline(
|
||||||
&ir.pipeline,
|
&ir.pipeline,
|
||||||
params,
|
params,
|
||||||
snapshot,
|
snapshot,
|
||||||
graph_index,
|
graph_index,
|
||||||
catalog,
|
catalog,
|
||||||
&mut primary_bindings,
|
&mut primary_wide,
|
||||||
&rrf.primary,
|
&rrf.primary,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Execute secondary search
|
// Execute secondary search
|
||||||
let mut secondary_bindings: HashMap<String, RecordBatch> = HashMap::new();
|
let mut secondary_wide: Option<RecordBatch> = None;
|
||||||
execute_pipeline(
|
execute_pipeline(
|
||||||
&ir.pipeline,
|
&ir.pipeline,
|
||||||
params,
|
params,
|
||||||
snapshot,
|
snapshot,
|
||||||
graph_index,
|
graph_index,
|
||||||
catalog,
|
catalog,
|
||||||
&mut secondary_bindings,
|
&mut secondary_wide,
|
||||||
&rrf.secondary,
|
&rrf.secondary,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
@ -438,13 +434,13 @@ async fn execute_rrf_query(
|
||||||
.or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
|
.or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
|
||||||
.ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;
|
.ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;
|
||||||
|
|
||||||
let primary_batch = primary_bindings.get(primary_var).ok_or_else(|| {
|
let primary_batch = primary_wide.as_ref().ok_or_else(|| {
|
||||||
OmniError::manifest(format!(
|
OmniError::manifest(format!(
|
||||||
"rrf primary variable '{}' not in bindings",
|
"rrf primary variable '{}' not in bindings",
|
||||||
primary_var
|
primary_var
|
||||||
))
|
))
|
||||||
})?;
|
})?;
|
||||||
let secondary_batch = secondary_bindings.get(primary_var).ok_or_else(|| {
|
let secondary_batch = secondary_wide.as_ref().ok_or_else(|| {
|
||||||
OmniError::manifest(format!(
|
OmniError::manifest(format!(
|
||||||
"rrf secondary variable '{}' not in bindings",
|
"rrf secondary variable '{}' not in bindings",
|
||||||
primary_var
|
primary_var
|
||||||
|
|
@ -452,8 +448,9 @@ async fn execute_rrf_query(
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Build ID → rank maps
|
// Build ID → rank maps
|
||||||
let primary_ids = extract_id_column(primary_batch)?;
|
let id_col_name = format!("{}.id", primary_var);
|
||||||
let secondary_ids = extract_id_column(secondary_batch)?;
|
let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?;
|
||||||
|
let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?;
|
||||||
|
|
||||||
let mut primary_rank: HashMap<String, usize> = HashMap::new();
|
let mut primary_rank: HashMap<String, usize> = HashMap::new();
|
||||||
for (i, id) in primary_ids.iter().enumerate() {
|
for (i, id) in primary_ids.iter().enumerate() {
|
||||||
|
|
@ -510,24 +507,21 @@ async fn execute_rrf_query(
|
||||||
// Reconstruct a combined batch for the binding in winning order
|
// Reconstruct a combined batch for the binding in winning order
|
||||||
let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;
|
let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;
|
||||||
|
|
||||||
// Replace the binding and project
|
// Project directly from fused batch
|
||||||
let mut fused_bindings = primary_bindings;
|
let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?;
|
||||||
fused_bindings.insert(primary_var.to_string(), fused_batch);
|
|
||||||
|
|
||||||
let result_batch = project_return(&fused_bindings, &ir.return_exprs, params)?;
|
|
||||||
|
|
||||||
// Already ordered by RRF score + already limited
|
// Already ordered by RRF score + already limited
|
||||||
Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
|
Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_id_column(batch: &RecordBatch) -> Result<Vec<String>> {
|
fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
|
||||||
let col = batch
|
let col = batch
|
||||||
.column_by_name("id")
|
.column_by_name(col_name)
|
||||||
.ok_or_else(|| OmniError::manifest("batch missing 'id' column for RRF".to_string()))?;
|
.ok_or_else(|| OmniError::manifest(format!("batch missing '{}' column for RRF", col_name)))?;
|
||||||
let ids = col
|
let ids = col
|
||||||
.as_any()
|
.as_any()
|
||||||
.downcast_ref::<StringArray>()
|
.downcast_ref::<StringArray>()
|
||||||
.ok_or_else(|| OmniError::manifest("'id' column is not Utf8".to_string()))?;
|
.ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", col_name)))?;
|
||||||
Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect())
|
Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -585,7 +579,7 @@ fn execute_pipeline<'a>(
|
||||||
snapshot: &'a Snapshot,
|
snapshot: &'a Snapshot,
|
||||||
graph_index: Option<&'a GraphIndex>,
|
graph_index: Option<&'a GraphIndex>,
|
||||||
catalog: &'a Catalog,
|
catalog: &'a Catalog,
|
||||||
bindings: &'a mut HashMap<String, RecordBatch>,
|
wide: &'a mut Option<RecordBatch>,
|
||||||
search_mode: &'a SearchMode,
|
search_mode: &'a SearchMode,
|
||||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
|
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
|
|
@ -632,10 +626,16 @@ fn execute_pipeline<'a>(
|
||||||
search_mode,
|
search_mode,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
bindings.insert(variable.clone(), batch);
|
let prefixed = prefix_batch(&batch, variable)?;
|
||||||
|
*wide = Some(match wide.take() {
|
||||||
|
None => prefixed,
|
||||||
|
Some(existing) => cross_join_batches(&existing, &prefixed)?,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
IROp::Filter(filter) => {
|
IROp::Filter(filter) => {
|
||||||
apply_filter(bindings, filter, params)?;
|
if let Some(batch) = wide.as_mut() {
|
||||||
|
apply_filter(batch, filter, params)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
IROp::Expand {
|
IROp::Expand {
|
||||||
src_var,
|
src_var,
|
||||||
|
|
@ -649,17 +649,20 @@ fn execute_pipeline<'a>(
|
||||||
let gi = graph_index.ok_or_else(|| {
|
let gi = graph_index.ok_or_else(|| {
|
||||||
OmniError::manifest("graph index required for traversal".to_string())
|
OmniError::manifest("graph index required for traversal".to_string())
|
||||||
})?;
|
})?;
|
||||||
let batch = execute_expand(
|
if let Some(batch) = wide.as_mut() {
|
||||||
bindings, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
|
execute_expand(
|
||||||
dst_type, *min_hops, *max_hops,
|
batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
|
||||||
)
|
dst_type, *min_hops, *max_hops,
|
||||||
.await?;
|
)
|
||||||
bindings.insert(dst_var.clone(), batch);
|
.await?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
IROp::AntiJoin { outer_var, inner } => {
|
IROp::AntiJoin { outer_var, inner } => {
|
||||||
let gi = graph_index;
|
let gi = graph_index;
|
||||||
execute_anti_join(bindings, inner, params, snapshot, gi, catalog, outer_var)
|
if let Some(batch) = wide.as_mut() {
|
||||||
.await?;
|
execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -669,28 +672,26 @@ fn execute_pipeline<'a>(
|
||||||
|
|
||||||
/// Execute a graph traversal (Expand).
|
/// Execute a graph traversal (Expand).
|
||||||
async fn execute_expand(
|
async fn execute_expand(
|
||||||
bindings: &HashMap<String, RecordBatch>,
|
wide: &mut RecordBatch,
|
||||||
graph_index: &GraphIndex,
|
graph_index: &GraphIndex,
|
||||||
snapshot: &Snapshot,
|
snapshot: &Snapshot,
|
||||||
catalog: &Catalog,
|
catalog: &Catalog,
|
||||||
src_var: &str,
|
src_var: &str,
|
||||||
_dst_var: &str,
|
dst_var: &str,
|
||||||
edge_type: &str,
|
edge_type: &str,
|
||||||
direction: Direction,
|
direction: Direction,
|
||||||
dst_type: &str,
|
dst_type: &str,
|
||||||
min_hops: u32,
|
min_hops: u32,
|
||||||
max_hops: Option<u32>,
|
max_hops: Option<u32>,
|
||||||
) -> Result<RecordBatch> {
|
) -> Result<()> {
|
||||||
let src_batch = bindings.get(src_var).ok_or_else(|| {
|
let src_id_col_name = format!("{}.id", src_var);
|
||||||
OmniError::manifest(format!("expand references unbound variable '{}'", src_var))
|
let src_ids = wide
|
||||||
})?;
|
.column_by_name(&src_id_col_name)
|
||||||
|
.ok_or_else(|| OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name)))?
|
||||||
let src_ids = src_batch
|
|
||||||
.column_by_name("id")
|
|
||||||
.ok_or_else(|| OmniError::manifest("source batch missing 'id' column".to_string()))?
|
|
||||||
.as_any()
|
.as_any()
|
||||||
.downcast_ref::<StringArray>()
|
.downcast_ref::<StringArray>()
|
||||||
.ok_or_else(|| OmniError::manifest("source 'id' column is not Utf8".to_string()))?;
|
.ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
|
||||||
|
.clone();
|
||||||
|
|
||||||
// Determine which type index to use for source and destination
|
// Determine which type index to use for source and destination
|
||||||
let edge_def = catalog
|
let edge_def = catalog
|
||||||
|
|
@ -720,8 +721,9 @@ async fn execute_expand(
|
||||||
|
|
||||||
let same_type = src_type_name == dst_type_name;
|
let same_type = src_type_name == dst_type_name;
|
||||||
|
|
||||||
// BFS to collect reachable destination dense IDs
|
// BFS to collect (src_row_idx, dst_id) pairs with per-source dedup
|
||||||
let mut result_dst_ids: Vec<String> = Vec::new();
|
let mut src_indices: Vec<u32> = Vec::new();
|
||||||
|
let mut dst_id_list: Vec<String> = Vec::new();
|
||||||
for i in 0..src_ids.len() {
|
for i in 0..src_ids.len() {
|
||||||
let src_id = src_ids.value(i);
|
let src_id = src_ids.value(i);
|
||||||
let Some(src_dense) = src_type_idx.to_dense(src_id) else {
|
let Some(src_dense) = src_type_idx.to_dense(src_id) else {
|
||||||
|
|
@ -749,7 +751,8 @@ async fn execute_expand(
|
||||||
if let Some(dst_id) = dst_type_idx.to_id(neighbor) {
|
if let Some(dst_id) = dst_type_idx.to_id(neighbor) {
|
||||||
let dst_id = dst_id.to_string();
|
let dst_id = dst_id.to_string();
|
||||||
if seen_dst_ids.insert(dst_id.clone()) {
|
if seen_dst_ids.insert(dst_id.clone()) {
|
||||||
result_dst_ids.push(dst_id);
|
src_indices.push(i as u32);
|
||||||
|
dst_id_list.push(dst_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -764,7 +767,36 @@ async fn execute_expand(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hydrate destination nodes from the snapshot
|
// Hydrate destination nodes from the snapshot
|
||||||
hydrate_nodes(snapshot, catalog, dst_type, &result_dst_ids).await
|
let dst_batch = hydrate_nodes(snapshot, catalog, dst_type, &dst_id_list).await?;
|
||||||
|
|
||||||
|
// Build a mapping from dst_id to row index in dst_batch
|
||||||
|
let dst_batch_id_col = dst_batch
|
||||||
|
.column_by_name("id")
|
||||||
|
.ok_or_else(|| OmniError::manifest("hydrated batch missing 'id' column".to_string()))?
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<StringArray>()
|
||||||
|
.ok_or_else(|| OmniError::manifest("hydrated 'id' column is not Utf8".to_string()))?;
|
||||||
|
let dst_id_to_row: HashMap<&str, usize> = (0..dst_batch_id_col.len())
|
||||||
|
.map(|i| (dst_batch_id_col.value(i), i))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Build aligned src/dst index arrays (only for IDs that exist in hydrated batch)
|
||||||
|
let mut final_src_indices: Vec<u32> = Vec::new();
|
||||||
|
let mut dst_indices: Vec<u32> = Vec::new();
|
||||||
|
for (src_idx, dst_id) in src_indices.iter().zip(dst_id_list.iter()) {
|
||||||
|
if let Some(&dst_row) = dst_id_to_row.get(dst_id.as_str()) {
|
||||||
|
final_src_indices.push(*src_idx);
|
||||||
|
dst_indices.push(dst_row as u32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let src_take = UInt32Array::from(final_src_indices);
|
||||||
|
let dst_take = UInt32Array::from(dst_indices);
|
||||||
|
let expanded_wide = take_batch(wide, &src_take)?;
|
||||||
|
let dst_prefixed = prefix_batch(&dst_batch, dst_var)?;
|
||||||
|
let aligned_dst = take_batch(&dst_prefixed, &dst_take)?;
|
||||||
|
*wide = hconcat_batches(&expanded_wide, &aligned_dst)?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load full node rows for a set of IDs from a snapshot.
|
/// Load full node rows for a set of IDs from a snapshot.
|
||||||
|
|
@ -829,15 +861,15 @@ async fn hydrate_nodes(
|
||||||
Ok(scan_result)
|
Ok(scan_result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Try bulk anti-join via CSR existence check. Returns Some if the inner
|
/// Try bulk anti-join via CSR existence check. Returns Some(mask) if the inner
|
||||||
/// pipeline is a single Expand from outer_var (the common negation pattern).
|
/// pipeline is a single Expand from outer_var (the common negation pattern).
|
||||||
fn try_bulk_anti_join(
|
fn try_bulk_anti_join_mask(
|
||||||
outer_batch: &RecordBatch,
|
wide: &RecordBatch,
|
||||||
inner_pipeline: &[IROp],
|
inner_pipeline: &[IROp],
|
||||||
graph_index: Option<&GraphIndex>,
|
graph_index: Option<&GraphIndex>,
|
||||||
catalog: &Catalog,
|
catalog: &Catalog,
|
||||||
outer_var: &str,
|
outer_var: &str,
|
||||||
) -> Option<Result<RecordBatch>> {
|
) -> Option<BooleanArray> {
|
||||||
if inner_pipeline.len() != 1 {
|
if inner_pipeline.len() != 1 {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
@ -866,8 +898,9 @@ fn try_bulk_anti_join(
|
||||||
}?;
|
}?;
|
||||||
let type_idx = gi.type_index(src_type_name)?;
|
let type_idx = gi.type_index(src_type_name)?;
|
||||||
|
|
||||||
let outer_ids = outer_batch
|
let id_col_name = format!("{}.id", outer_var);
|
||||||
.column_by_name("id")?
|
let outer_ids = wide
|
||||||
|
.column_by_name(&id_col_name)?
|
||||||
.as_any()
|
.as_any()
|
||||||
.downcast_ref::<StringArray>()?;
|
.downcast_ref::<StringArray>()?;
|
||||||
|
|
||||||
|
|
@ -881,16 +914,12 @@ fn try_bulk_anti_join(
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mask = BooleanArray::from(keep_mask);
|
Some(BooleanArray::from(keep_mask))
|
||||||
Some(
|
|
||||||
arrow_select::filter::filter_record_batch(outer_batch, &mask)
|
|
||||||
.map_err(|e| OmniError::Lance(e.to_string())),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute an AntiJoin: remove rows from outer_var where the inner pipeline finds matches.
|
/// Execute an AntiJoin: remove rows from wide batch where the inner pipeline finds matches.
|
||||||
async fn execute_anti_join(
|
async fn execute_anti_join(
|
||||||
bindings: &mut HashMap<String, RecordBatch>,
|
wide: &mut RecordBatch,
|
||||||
inner_pipeline: &[IROp],
|
inner_pipeline: &[IROp],
|
||||||
params: &ParamMap,
|
params: &ParamMap,
|
||||||
snapshot: &Snapshot,
|
snapshot: &Snapshot,
|
||||||
|
|
@ -898,35 +927,22 @@ async fn execute_anti_join(
|
||||||
catalog: &Catalog,
|
catalog: &Catalog,
|
||||||
outer_var: &str,
|
outer_var: &str,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let outer_batch = bindings.get(outer_var).ok_or_else(|| {
|
|
||||||
OmniError::manifest(format!(
|
|
||||||
"anti-join references unbound variable '{}'",
|
|
||||||
outer_var
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Fast path: bulk CSR existence check (O(N), zero Lance I/O)
|
// Fast path: bulk CSR existence check (O(N), zero Lance I/O)
|
||||||
if let Some(result) =
|
if let Some(mask) =
|
||||||
try_bulk_anti_join(outer_batch, inner_pipeline, graph_index, catalog, outer_var)
|
try_bulk_anti_join_mask(wide, inner_pipeline, graph_index, catalog, outer_var)
|
||||||
{
|
{
|
||||||
bindings.insert(outer_var.to_string(), result?);
|
*wide = arrow_select::filter::filter_record_batch(wide, &mask)
|
||||||
|
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Slow path: per-row inner pipeline execution
|
// Slow path: per-row inner pipeline execution
|
||||||
let outer_ids = outer_batch
|
let num_rows = wide.num_rows();
|
||||||
.column_by_name("id")
|
let mut keep_mask = vec![true; num_rows];
|
||||||
.ok_or_else(|| OmniError::manifest("outer batch missing 'id' column".to_string()))?
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<StringArray>()
|
|
||||||
.ok_or_else(|| OmniError::manifest("outer 'id' column is not Utf8".to_string()))?;
|
|
||||||
|
|
||||||
let mut keep_mask = vec![true; outer_batch.num_rows()];
|
for i in 0..num_rows {
|
||||||
|
let single_row = wide.slice(i, 1);
|
||||||
for i in 0..outer_ids.len() {
|
let mut inner_wide: Option<RecordBatch> = Some(single_row);
|
||||||
let single_row = outer_batch.slice(i, 1);
|
|
||||||
let mut inner_bindings: HashMap<String, RecordBatch> = HashMap::new();
|
|
||||||
inner_bindings.insert(outer_var.to_string(), single_row);
|
|
||||||
|
|
||||||
let no_search = SearchMode::default();
|
let no_search = SearchMode::default();
|
||||||
execute_pipeline(
|
execute_pipeline(
|
||||||
|
|
@ -935,15 +951,15 @@ async fn execute_anti_join(
|
||||||
snapshot,
|
snapshot,
|
||||||
graph_index,
|
graph_index,
|
||||||
catalog,
|
catalog,
|
||||||
&mut inner_bindings,
|
&mut inner_wide,
|
||||||
&no_search,
|
&no_search,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let has_match = inner_bindings
|
let has_match = inner_wide
|
||||||
.iter()
|
.as_ref()
|
||||||
.filter(|(k, _)| *k != outer_var)
|
.map(|batch| batch.num_rows() > 0)
|
||||||
.any(|(_, batch)| batch.num_rows() > 0);
|
.unwrap_or(false);
|
||||||
|
|
||||||
if has_match {
|
if has_match {
|
||||||
keep_mask[i] = false;
|
keep_mask[i] = false;
|
||||||
|
|
@ -951,10 +967,8 @@ async fn execute_anti_join(
|
||||||
}
|
}
|
||||||
|
|
||||||
let mask = BooleanArray::from(keep_mask);
|
let mask = BooleanArray::from(keep_mask);
|
||||||
let filtered = arrow_select::filter::filter_record_batch(outer_batch, &mask)
|
*wide = arrow_select::filter::filter_record_batch(wide, &mask)
|
||||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||||
|
|
||||||
bindings.insert(outer_var.to_string(), filtered);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1220,3 +1234,50 @@ pub(super) fn literal_to_sql(lit: &Literal) -> String {
|
||||||
Literal::List(_) => "NULL".to_string(), // Not supported in SQL pushdown
|
Literal::List(_) => "NULL".to_string(), // Not supported in SQL pushdown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
|
||||||
|
let fields: Vec<Field> = batch.schema().fields().iter().map(|f| {
|
||||||
|
Field::new(format!("{}.{}", variable, f.name()), f.data_type().clone(), f.is_nullable())
|
||||||
|
}).collect();
|
||||||
|
let schema = Arc::new(Schema::new(fields));
|
||||||
|
RecordBatch::try_new(schema, batch.columns().to_vec()).map_err(|e| OmniError::Lance(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
|
||||||
|
let n = left.num_rows();
|
||||||
|
let m = right.num_rows();
|
||||||
|
if n == 0 || m == 0 {
|
||||||
|
let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
|
||||||
|
fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
|
||||||
|
return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields))));
|
||||||
|
}
|
||||||
|
let left_indices: Vec<u32> = (0..n as u32).flat_map(|i| std::iter::repeat(i).take(m)).collect();
|
||||||
|
let right_indices: Vec<u32> = (0..n).flat_map(|_| 0..m as u32).collect();
|
||||||
|
let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?;
|
||||||
|
let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?;
|
||||||
|
hconcat_batches(&left_expanded, &right_expanded)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
|
||||||
|
let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
let left_schema = left.schema();
|
||||||
|
let left_names: HashSet<&str> = left_schema.fields().iter().map(|f| f.name().as_str()).collect();
|
||||||
|
let right_schema = right.schema();
|
||||||
|
for f in right_schema.fields() {
|
||||||
|
debug_assert!(!left_names.contains(f.name().as_str()), "hconcat_batches: duplicate column '{}'", f.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
|
||||||
|
let mut columns: Vec<ArrayRef> = left.columns().to_vec();
|
||||||
|
columns.extend(right.columns().to_vec());
|
||||||
|
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
|
||||||
|
let columns: Vec<ArrayRef> = batch.columns().iter()
|
||||||
|
.map(|col| arrow_select::take::take(col.as_ref(), indices, None))
|
||||||
|
.collect::<std::result::Result<Vec<_>, _>>()
|
||||||
|
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||||
|
RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
|
||||||
|
}
|
||||||
|
|
|
||||||
302
crates/omnigraph/tests/aggregation.rs
Normal file
302
crates/omnigraph/tests/aggregation.rs
Normal file
|
|
@ -0,0 +1,302 @@
|
||||||
|
mod helpers;
|
||||||
|
|
||||||
|
use arrow_array::{Array, Float64Array, Int32Array, Int64Array, StringArray};
|
||||||
|
|
||||||
|
use omnigraph::db::Omnigraph;
|
||||||
|
use omnigraph::loader::{LoadMode, load_jsonl};
|
||||||
|
use omnigraph_compiler::ir::ParamMap;
|
||||||
|
|
||||||
|
use helpers::*;
|
||||||
|
|
||||||
|
// ─── Count aggregate with GROUP BY ─────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn friend_counts_grouped_by_person() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let mut db = init_and_load(&dir).await;
|
||||||
|
|
||||||
|
// Test data: Alice knows Bob, Alice knows Charlie, Bob knows Diana
|
||||||
|
// So: Alice=2 friends, Bob=1 friend (Charlie & Diana have no outgoing knows → dropped)
|
||||||
|
let result = query_main(&mut db, TEST_QUERIES, "friend_counts", &ParamMap::new())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let batch = result.concat_batches().unwrap();
|
||||||
|
assert_eq!(batch.num_rows(), 2);
|
||||||
|
|
||||||
|
let names = batch
|
||||||
|
.column_by_name("p.name")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<StringArray>()
|
||||||
|
.unwrap();
|
||||||
|
let counts = batch
|
||||||
|
.column_by_name("friends")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<Int64Array>()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Ordered by friends desc: Alice(2), Bob(1)
|
||||||
|
assert_eq!(names.value(0), "Alice");
|
||||||
|
assert_eq!(counts.value(0), 2);
|
||||||
|
assert_eq!(names.value(1), "Bob");
|
||||||
|
assert_eq!(counts.value(1), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Global count (no group key) ───────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn total_people_global_count() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let mut db = init_and_load(&dir).await;
|
||||||
|
|
||||||
|
let result = query_main(&mut db, TEST_QUERIES, "total_people", &ParamMap::new())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let batch = result.concat_batches().unwrap();
|
||||||
|
assert_eq!(batch.num_rows(), 1);
|
||||||
|
|
||||||
|
let total = batch
|
||||||
|
.column_by_name("total")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<Int64Array>()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(total.value(0), 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Sum/Avg/Min/Max aggregates ────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn age_stats_per_company() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let mut db = init_and_load(&dir).await;
|
||||||
|
|
||||||
|
// Test data: Alice(30) worksAt Acme, Bob(25) worksAt Globex
|
||||||
|
let result = query_main(&mut db, TEST_QUERIES, "age_stats", &ParamMap::new())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let batch = result.concat_batches().unwrap();
|
||||||
|
assert_eq!(batch.num_rows(), 2);
|
||||||
|
|
||||||
|
let names = batch
|
||||||
|
.column_by_name("c.name")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<StringArray>()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Find the Acme row (order not guaranteed without ORDER BY)
|
||||||
|
let acme_row = (0..batch.num_rows())
|
||||||
|
.find(|&i| names.value(i) == "Acme")
|
||||||
|
.expect("Acme should be in results");
|
||||||
|
let globex_row = (0..batch.num_rows())
|
||||||
|
.find(|&i| names.value(i) == "Globex")
|
||||||
|
.expect("Globex should be in results");
|
||||||
|
|
||||||
|
let headcount = batch
|
||||||
|
.column_by_name("headcount")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<Int64Array>()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(headcount.value(acme_row), 1);
|
||||||
|
assert_eq!(headcount.value(globex_row), 1);
|
||||||
|
|
||||||
|
let avg_age = batch
|
||||||
|
.column_by_name("avg_age")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<Float64Array>()
|
||||||
|
.unwrap();
|
||||||
|
assert!((avg_age.value(acme_row) - 30.0).abs() < 0.01);
|
||||||
|
assert!((avg_age.value(globex_row) - 25.0).abs() < 0.01);
|
||||||
|
|
||||||
|
let youngest = batch
|
||||||
|
.column_by_name("youngest")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<Int32Array>()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(youngest.value(acme_row), 30);
|
||||||
|
assert_eq!(youngest.value(globex_row), 25);
|
||||||
|
|
||||||
|
let total_age = batch
|
||||||
|
.column_by_name("total_age")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<Float64Array>()
|
||||||
|
.unwrap();
|
||||||
|
assert!((total_age.value(acme_row) - 30.0).abs() < 0.01);
|
||||||
|
assert!((total_age.value(globex_row) - 25.0).abs() < 0.01);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Aggregate + order + limit ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn top_connected_person() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let mut db = init_and_load(&dir).await;
|
||||||
|
|
||||||
|
// Alice has 2 friends (most connected)
|
||||||
|
let result = query_main(&mut db, TEST_QUERIES, "top_connected", &ParamMap::new())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let batch = result.concat_batches().unwrap();
|
||||||
|
assert_eq!(batch.num_rows(), 1);
|
||||||
|
|
||||||
|
let names = batch
|
||||||
|
.column_by_name("p.name")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<StringArray>()
|
||||||
|
.unwrap();
|
||||||
|
let counts = batch
|
||||||
|
.column_by_name("friends")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<Int64Array>()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(names.value(0), "Alice");
|
||||||
|
assert_eq!(counts.value(0), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Inline aggregate query (not from fixture) ────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn inline_count_with_no_friends() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let mut db = init_and_load(&dir).await;
|
||||||
|
|
||||||
|
// Diana has no outgoing Knows edges, so she won't appear in friend_counts.
|
||||||
|
// But Alice(2) and Bob(1) do. Verify the count matches.
|
||||||
|
let queries = r#"
|
||||||
|
query fc() {
|
||||||
|
match {
|
||||||
|
$p: Person
|
||||||
|
$p knows $f
|
||||||
|
}
|
||||||
|
return { $p.name, count($f) as friends }
|
||||||
|
order { friends desc }
|
||||||
|
}
|
||||||
|
"#;
|
||||||
|
let result = query_main(&mut db, queries, "fc", &ParamMap::new())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let batch = result.concat_batches().unwrap();
|
||||||
|
// Only Alice and Bob have outgoing Knows edges
|
||||||
|
assert_eq!(batch.num_rows(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn inline_global_count_empty() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let uri = dir.path().to_str().unwrap();
|
||||||
|
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
|
||||||
|
|
||||||
|
// Load only nodes, no edges
|
||||||
|
let data = r#"{"type": "Person", "data": {"name": "Alice", "age": 30}}
|
||||||
|
{"type": "Company", "data": {"name": "Acme"}}"#;
|
||||||
|
load_jsonl(&mut db, data, LoadMode::Overwrite)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Global count — should return 1 row with count=1
|
||||||
|
let queries = r#"
|
||||||
|
query tc() {
|
||||||
|
match { $p: Person }
|
||||||
|
return { count($p) as total }
|
||||||
|
}
|
||||||
|
"#;
|
||||||
|
let result = query_main(&mut db, queries, "tc", &ParamMap::new())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let batch = result.concat_batches().unwrap();
|
||||||
|
assert_eq!(batch.num_rows(), 1);
|
||||||
|
let total = batch
|
||||||
|
.column_by_name("total")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<Int64Array>()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(total.value(0), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn inline_min_max_string() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let mut db = init_and_load(&dir).await;
|
||||||
|
|
||||||
|
let queries = r#"
|
||||||
|
query name_range() {
|
||||||
|
match { $p: Person }
|
||||||
|
return {
|
||||||
|
min($p.name) as first_name
|
||||||
|
max($p.name) as last_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"#;
|
||||||
|
let result = query_main(&mut db, queries, "name_range", &ParamMap::new())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let batch = result.concat_batches().unwrap();
|
||||||
|
assert_eq!(batch.num_rows(), 1);
|
||||||
|
|
||||||
|
let first = batch
|
||||||
|
.column_by_name("first_name")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<StringArray>()
|
||||||
|
.unwrap();
|
||||||
|
let last = batch
|
||||||
|
.column_by_name("last_name")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<StringArray>()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(first.value(0), "Alice");
|
||||||
|
assert_eq!(last.value(0), "Diana");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn inline_multi_hop_aggregate() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let mut db = init_and_load(&dir).await;
|
||||||
|
|
||||||
|
// Friends of friends count per person
|
||||||
|
let queries = r#"
|
||||||
|
query fof_counts($name: String) {
|
||||||
|
match {
|
||||||
|
$p: Person { name: $name }
|
||||||
|
$p knows $mid
|
||||||
|
$mid knows $fof
|
||||||
|
}
|
||||||
|
return { $p.name, count($fof) as fof_count }
|
||||||
|
}
|
||||||
|
"#;
|
||||||
|
// Alice → Bob, Charlie. Bob → Diana. Charlie → nobody.
|
||||||
|
// So Alice's fof = [Diana] → count = 1
|
||||||
|
let result = query_main(
|
||||||
|
&mut db,
|
||||||
|
queries,
|
||||||
|
"fof_counts",
|
||||||
|
¶ms(&[("$name", "Alice")]),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let batch = result.concat_batches().unwrap();
|
||||||
|
assert_eq!(batch.num_rows(), 1);
|
||||||
|
|
||||||
|
let count = batch
|
||||||
|
.column_by_name("fof_count")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<Int64Array>()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(count.value(0), 1);
|
||||||
|
}
|
||||||
35
crates/omnigraph/tests/fixtures/test.gq
vendored
35
crates/omnigraph/tests/fixtures/test.gq
vendored
|
|
@ -76,3 +76,38 @@ query top_by_age() {
|
||||||
order { $p.age desc }
|
order { $p.age desc }
|
||||||
limit 2
|
limit 2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Aggregation: global count (no group key)
|
||||||
|
query total_people() {
|
||||||
|
match {
|
||||||
|
$p: Person
|
||||||
|
}
|
||||||
|
return { count($p) as total }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregation: sum/avg/min/max of age per company
|
||||||
|
query age_stats() {
|
||||||
|
match {
|
||||||
|
$p: Person
|
||||||
|
$p worksAt $c
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
$c.name
|
||||||
|
sum($p.age) as total_age
|
||||||
|
avg($p.age) as avg_age
|
||||||
|
min($p.age) as youngest
|
||||||
|
max($p.age) as oldest
|
||||||
|
count($p) as headcount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregation: most connected person
|
||||||
|
query top_connected() {
|
||||||
|
match {
|
||||||
|
$p: Person
|
||||||
|
$p knows $f
|
||||||
|
}
|
||||||
|
return { $p.name, count($f) as friends }
|
||||||
|
order { friends desc }
|
||||||
|
limit 1
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue