omnigraph/crates/omnigraph/tests/aggregation.rs
Claude 351610d18c
Implement aggregate execution with wide-batch model
Add runtime support for aggregate functions (count, sum, avg, min, max)
with GROUP BY semantics, built on a single wide RecordBatch that
eliminates correlation tracking by construction.

Execution engine (exec/query.rs):
- Replace HashMap<String, RecordBatch> with Option<RecordBatch> where
  columns are prefixed as <variable>.<property>
- NodeScan prefixes columns and cross-joins with existing batch
- Expand collects (src_row, dst_id) pairs, takes wide batch rows,
  appends prefixed destination columns via hconcat
- Filter applies single mask to entire wide batch
- AntiJoin: fast-path returns BooleanArray mask; slow-path slices
  one row for inner pipeline execution

Projection engine (exec/projection.rs):
- aggregate_return groups rows by non-aggregate key columns using
  length-prefixed string encoding, computes per-group aggregates
- SUM accumulates into f64 to avoid integer overflow
- MIN/MAX support both numeric and string types
- Empty input returns count=0, others=null

Compiler (typecheck.rs):
- T8: split MIN/MAX from SUM/AVG — allow string arguments
- T9: non-aggregate expressions in aggregate queries must be
  property accesses or variables
- SUM type inference returns Float64 (matching runtime)

Tests: 8 new integration tests covering grouped count, global count,
sum/avg/min/max per company, aggregate+order+limit, string min/max,
multi-hop aggregates, and edge cases.

https://claude.ai/code/session_019o5NRyYomgETFyd7hpiLey
2026-04-12 20:59:13 +00:00

302 lines
8.7 KiB
Rust

mod helpers;
use arrow_array::{Array, Float64Array, Int32Array, Int64Array, StringArray};
use omnigraph::db::Omnigraph;
use omnigraph::loader::{LoadMode, load_jsonl};
use omnigraph_compiler::ir::ParamMap;
use helpers::*;
// ─── Count aggregate with GROUP BY ─────────────────────────────────────────
/// Grouped `count` over a one-hop `knows` traversal, ordered by the count.
#[tokio::test]
async fn friend_counts_grouped_by_person() {
    // Fixture graph: Alice knows Bob and Charlie; Bob knows Diana.
    // Charlie and Diana have no outgoing `knows` edges, so the traversal produces
    // no rows for them and they never form a group: only Alice (2) and Bob (1) remain.
    let tmp = tempfile::tempdir().unwrap();
    let mut graph = init_and_load(&tmp).await;

    let result = query_main(&mut graph, TEST_QUERIES, "friend_counts", &ParamMap::new())
        .await
        .unwrap();
    let batch = result.concat_batches().unwrap();
    assert_eq!(batch.num_rows(), 2);

    let person = batch
        .column_by_name("p.name")
        .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        .unwrap();
    let friends = batch
        .column_by_name("friends")
        .and_then(|col| col.as_any().downcast_ref::<Int64Array>())
        .unwrap();

    // Rows arrive ordered by `friends` descending, so Alice's group is first.
    let expected = [("Alice", 2), ("Bob", 1)];
    for (row, &(name, count)) in expected.iter().enumerate() {
        assert_eq!(person.value(row), name);
        assert_eq!(friends.value(row), count);
    }
}
// ─── Global count (no group key) ───────────────────────────────────────────
/// A `count` with no grouping key collapses the entire match into a single row.
#[tokio::test]
async fn total_people_global_count() {
    let tmp = tempfile::tempdir().unwrap();
    let mut graph = init_and_load(&tmp).await;

    let outcome = query_main(&mut graph, TEST_QUERIES, "total_people", &ParamMap::new()).await;
    let result = outcome.unwrap();
    let batch = result.concat_batches().unwrap();

    // Exactly one aggregate row, counting all four fixture people.
    assert_eq!(batch.num_rows(), 1);
    let total = batch
        .column_by_name("total")
        .and_then(|col| col.as_any().downcast_ref::<Int64Array>())
        .unwrap();
    assert_eq!(total.value(0), 4);
}
// ─── Sum/Avg/Min/Max aggregates ────────────────────────────────────────────
/// Per-company `count`, `avg`, `min` and `sum` over employee ages.
#[tokio::test]
async fn age_stats_per_company() {
    // Fixture: Alice (30) works at Acme and Bob (25) works at Globex, so each
    // company group holds exactly one person and every statistic equals that age.
    let tmp = tempfile::tempdir().unwrap();
    let mut graph = init_and_load(&tmp).await;

    let result = query_main(&mut graph, TEST_QUERIES, "age_stats", &ParamMap::new())
        .await
        .unwrap();
    let batch = result.concat_batches().unwrap();
    assert_eq!(batch.num_rows(), 2);

    let companies = batch
        .column_by_name("c.name")
        .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        .unwrap();

    // Locate each company's row by name rather than by position, since this test
    // does not rely on any particular output order.
    let row_of = |company: &str| (0..batch.num_rows()).find(|&i| companies.value(i) == company);
    let acme_row = row_of("Acme").expect("Acme should be in results");
    let globex_row = row_of("Globex").expect("Globex should be in results");

    let headcount = batch
        .column_by_name("headcount")
        .and_then(|col| col.as_any().downcast_ref::<Int64Array>())
        .unwrap();
    let avg_age = batch
        .column_by_name("avg_age")
        .and_then(|col| col.as_any().downcast_ref::<Float64Array>())
        .unwrap();
    // `min` keeps the source column's integer type.
    let youngest = batch
        .column_by_name("youngest")
        .and_then(|col| col.as_any().downcast_ref::<Int32Array>())
        .unwrap();
    // `sum` is accumulated and returned as f64.
    let total_age = batch
        .column_by_name("total_age")
        .and_then(|col| col.as_any().downcast_ref::<Float64Array>())
        .unwrap();

    const TOLERANCE: f64 = 0.01;
    let expectations: [(usize, i32); 2] = [(acme_row, 30), (globex_row, 25)];
    for &(row, age) in &expectations {
        let age_f = f64::from(age);
        assert_eq!(headcount.value(row), 1);
        assert!((avg_age.value(row) - age_f).abs() < TOLERANCE);
        assert_eq!(youngest.value(row), age);
        assert!((total_age.value(row) - age_f).abs() < TOLERANCE);
    }
}
// ─── Aggregate + order + limit ─────────────────────────────────────────────
/// Aggregate combined with ordering and a limit: only the best-connected person survives.
#[tokio::test]
async fn top_connected_person() {
    // In the fixture, Alice has the most outgoing `knows` edges (two).
    let tmp = tempfile::tempdir().unwrap();
    let mut graph = init_and_load(&tmp).await;

    let result = query_main(&mut graph, TEST_QUERIES, "top_connected", &ParamMap::new())
        .await
        .unwrap();
    let batch = result.concat_batches().unwrap();
    assert_eq!(batch.num_rows(), 1);

    let person = batch
        .column_by_name("p.name")
        .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        .unwrap();
    let friends = batch
        .column_by_name("friends")
        .and_then(|col| col.as_any().downcast_ref::<Int64Array>())
        .unwrap();

    assert_eq!((person.value(0), friends.value(0)), ("Alice", 2));
}
// ─── Inline aggregate query (not from fixture) ────────────────────────────
#[tokio::test]
async fn inline_count_with_no_friends() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Diana has no outgoing Knows edges, so she won't appear in friend_counts.
// But Alice(2) and Bob(1) do. Verify the count matches.
let queries = r#"
query fc() {
match {
$p: Person
$p knows $f
}
return { $p.name, count($f) as friends }
order { friends desc }
}
"#;
let result = query_main(&mut db, queries, "fc", &ParamMap::new())
.await
.unwrap();
let batch = result.concat_batches().unwrap();
// Only Alice and Bob have outgoing Knows edges
assert_eq!(batch.num_rows(), 2);
}
/// Global `count` on a graph that has nodes but no edges.
///
/// Checks both a node-only count (one row, count=1) and a traversal-based count
/// whose match is empty because no edges exist (one row, count=0).
#[tokio::test]
async fn inline_global_count_empty() {
    let dir = tempfile::tempdir().unwrap();
    let uri = dir.path().to_str().unwrap();
    let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
    // Load only nodes, no edges
    let data = r#"{"type": "Person", "data": {"name": "Alice", "age": 30}}
{"type": "Company", "data": {"name": "Acme"}}"#;
    load_jsonl(&mut db, data, LoadMode::Overwrite)
        .await
        .unwrap();

    // Global count — should return 1 row with count=1
    let queries = r#"
query tc() {
match { $p: Person }
return { count($p) as total }
}
"#;
    let result = query_main(&mut db, queries, "tc", &ParamMap::new())
        .await
        .unwrap();
    let batch = result.concat_batches().unwrap();
    assert_eq!(batch.num_rows(), 1);
    let total = batch
        .column_by_name("total")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(total.value(0), 1);

    // Truly empty input: with no edges loaded the traversal matches nothing, yet a
    // global aggregate must still produce exactly one row, with count = 0.
    let edge_queries = r#"
query tc_edges() {
    match {
        $p: Person
        $p knows $f
    }
    return { count($f) as total }
}
"#;
    let result = query_main(&mut db, edge_queries, "tc_edges", &ParamMap::new())
        .await
        .unwrap();
    let batch = result.concat_batches().unwrap();
    assert_eq!(batch.num_rows(), 1);
    let total = batch
        .column_by_name("total")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(total.value(0), 0);
}
/// `min`/`max` over a string property return the lexicographic extremes.
#[tokio::test]
async fn inline_min_max_string() {
    let tmp = tempfile::tempdir().unwrap();
    let mut graph = init_and_load(&tmp).await;
    let queries = r#"
query name_range() {
match { $p: Person }
return {
min($p.name) as first_name
max($p.name) as last_name
}
}
"#;
    let result = query_main(&mut graph, queries, "name_range", &ParamMap::new())
        .await
        .unwrap();
    let batch = result.concat_batches().unwrap();
    // No grouping key, so a single aggregate row.
    assert_eq!(batch.num_rows(), 1);

    let lowest = batch
        .column_by_name("first_name")
        .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        .unwrap();
    let highest = batch
        .column_by_name("last_name")
        .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        .unwrap();

    // Among Alice, Bob, Charlie and Diana, "Alice" sorts first and "Diana" last.
    assert_eq!((lowest.value(0), highest.value(0)), ("Alice", "Diana"));
}
#[tokio::test]
async fn inline_multi_hop_aggregate() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Friends of friends count per person
let queries = r#"
query fof_counts($name: String) {
match {
$p: Person { name: $name }
$p knows $mid
$mid knows $fof
}
return { $p.name, count($fof) as fof_count }
}
"#;
// Alice → Bob, Charlie. Bob → Diana. Charlie → nobody.
// So Alice's fof = [Diana] → count = 1
let result = query_main(
&mut db,
queries,
"fof_counts",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 1);
let count = batch
.column_by_name("fof_count")
.unwrap()
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(count.value(0), 1);
}