mirror of
https://github.com/ModernRelay/omnigraph.git
synced 2026-06-21 02:28:07 +02:00
Add runtime support for aggregate functions (count, sum, avg, min, max) with GROUP BY semantics, built on a single wide RecordBatch that eliminates correlation tracking by construction. Execution engine (exec/query.rs): - Replace HashMap<String, RecordBatch> with Option<RecordBatch> where columns are prefixed as <variable>.<property> - NodeScan prefixes columns and cross-joins with existing batch - Expand collects (src_row, dst_id) pairs, takes wide batch rows, appends prefixed destination columns via hconcat - Filter applies single mask to entire wide batch - AntiJoin: fast-path returns BooleanArray mask; slow-path slices one row for inner pipeline execution Projection engine (exec/projection.rs): - aggregate_return groups rows by non-aggregate key columns using length-prefixed string encoding, computes per-group aggregates - SUM accumulates into f64 to avoid integer overflow - MIN/MAX support both numeric and string types - Empty input returns count=0, others=null Compiler (typecheck.rs): - T8: split MIN/MAX from SUM/AVG — allow string arguments - T9: non-aggregate expressions in aggregate queries must be property accesses or variables - SUM type inference returns Float64 (matching runtime) Tests: 8 new integration tests covering grouped count, global count, sum/avg/min/max per company, aggregate+order+limit, string min/max, multi-hop aggregates, and edge cases. https://claude.ai/code/session_019o5NRyYomgETFyd7hpiLey
302 lines
8.7 KiB
Rust
302 lines
8.7 KiB
Rust
mod helpers;
|
|
|
|
use arrow_array::{Array, Float64Array, Int32Array, Int64Array, StringArray};
|
|
|
|
use omnigraph::db::Omnigraph;
|
|
use omnigraph::loader::{LoadMode, load_jsonl};
|
|
use omnigraph_compiler::ir::ParamMap;
|
|
|
|
use helpers::*;
|
|
|
|
// ─── Count aggregate with GROUP BY ─────────────────────────────────────────
|
|
|
|
#[tokio::test]
|
|
async fn friend_counts_grouped_by_person() {
|
|
let dir = tempfile::tempdir().unwrap();
|
|
let mut db = init_and_load(&dir).await;
|
|
|
|
// Test data: Alice knows Bob, Alice knows Charlie, Bob knows Diana
|
|
// So: Alice=2 friends, Bob=1 friend (Charlie & Diana have no outgoing knows → dropped)
|
|
let result = query_main(&mut db, TEST_QUERIES, "friend_counts", &ParamMap::new())
|
|
.await
|
|
.unwrap();
|
|
|
|
let batch = result.concat_batches().unwrap();
|
|
assert_eq!(batch.num_rows(), 2);
|
|
|
|
let names = batch
|
|
.column_by_name("p.name")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<StringArray>()
|
|
.unwrap();
|
|
let counts = batch
|
|
.column_by_name("friends")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<Int64Array>()
|
|
.unwrap();
|
|
|
|
// Ordered by friends desc: Alice(2), Bob(1)
|
|
assert_eq!(names.value(0), "Alice");
|
|
assert_eq!(counts.value(0), 2);
|
|
assert_eq!(names.value(1), "Bob");
|
|
assert_eq!(counts.value(1), 1);
|
|
}
|
|
|
|
// ─── Global count (no group key) ───────────────────────────────────────────
|
|
|
|
#[tokio::test]
|
|
async fn total_people_global_count() {
|
|
let dir = tempfile::tempdir().unwrap();
|
|
let mut db = init_and_load(&dir).await;
|
|
|
|
let result = query_main(&mut db, TEST_QUERIES, "total_people", &ParamMap::new())
|
|
.await
|
|
.unwrap();
|
|
|
|
let batch = result.concat_batches().unwrap();
|
|
assert_eq!(batch.num_rows(), 1);
|
|
|
|
let total = batch
|
|
.column_by_name("total")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<Int64Array>()
|
|
.unwrap();
|
|
assert_eq!(total.value(0), 4);
|
|
}
|
|
|
|
// ─── Sum/Avg/Min/Max aggregates ────────────────────────────────────────────
|
|
|
|
#[tokio::test]
|
|
async fn age_stats_per_company() {
|
|
let dir = tempfile::tempdir().unwrap();
|
|
let mut db = init_and_load(&dir).await;
|
|
|
|
// Test data: Alice(30) worksAt Acme, Bob(25) worksAt Globex
|
|
let result = query_main(&mut db, TEST_QUERIES, "age_stats", &ParamMap::new())
|
|
.await
|
|
.unwrap();
|
|
|
|
let batch = result.concat_batches().unwrap();
|
|
assert_eq!(batch.num_rows(), 2);
|
|
|
|
let names = batch
|
|
.column_by_name("c.name")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<StringArray>()
|
|
.unwrap();
|
|
|
|
// Find the Acme row (order not guaranteed without ORDER BY)
|
|
let acme_row = (0..batch.num_rows())
|
|
.find(|&i| names.value(i) == "Acme")
|
|
.expect("Acme should be in results");
|
|
let globex_row = (0..batch.num_rows())
|
|
.find(|&i| names.value(i) == "Globex")
|
|
.expect("Globex should be in results");
|
|
|
|
let headcount = batch
|
|
.column_by_name("headcount")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<Int64Array>()
|
|
.unwrap();
|
|
assert_eq!(headcount.value(acme_row), 1);
|
|
assert_eq!(headcount.value(globex_row), 1);
|
|
|
|
let avg_age = batch
|
|
.column_by_name("avg_age")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<Float64Array>()
|
|
.unwrap();
|
|
assert!((avg_age.value(acme_row) - 30.0).abs() < 0.01);
|
|
assert!((avg_age.value(globex_row) - 25.0).abs() < 0.01);
|
|
|
|
let youngest = batch
|
|
.column_by_name("youngest")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<Int32Array>()
|
|
.unwrap();
|
|
assert_eq!(youngest.value(acme_row), 30);
|
|
assert_eq!(youngest.value(globex_row), 25);
|
|
|
|
let total_age = batch
|
|
.column_by_name("total_age")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<Float64Array>()
|
|
.unwrap();
|
|
assert!((total_age.value(acme_row) - 30.0).abs() < 0.01);
|
|
assert!((total_age.value(globex_row) - 25.0).abs() < 0.01);
|
|
}
|
|
|
|
// ─── Aggregate + order + limit ─────────────────────────────────────────────
|
|
|
|
#[tokio::test]
|
|
async fn top_connected_person() {
|
|
let dir = tempfile::tempdir().unwrap();
|
|
let mut db = init_and_load(&dir).await;
|
|
|
|
// Alice has 2 friends (most connected)
|
|
let result = query_main(&mut db, TEST_QUERIES, "top_connected", &ParamMap::new())
|
|
.await
|
|
.unwrap();
|
|
|
|
let batch = result.concat_batches().unwrap();
|
|
assert_eq!(batch.num_rows(), 1);
|
|
|
|
let names = batch
|
|
.column_by_name("p.name")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<StringArray>()
|
|
.unwrap();
|
|
let counts = batch
|
|
.column_by_name("friends")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<Int64Array>()
|
|
.unwrap();
|
|
|
|
assert_eq!(names.value(0), "Alice");
|
|
assert_eq!(counts.value(0), 2);
|
|
}
|
|
|
|
// ─── Inline aggregate query (not from fixture) ────────────────────────────
|
|
|
|
#[tokio::test]
|
|
async fn inline_count_with_no_friends() {
|
|
let dir = tempfile::tempdir().unwrap();
|
|
let mut db = init_and_load(&dir).await;
|
|
|
|
// Diana has no outgoing Knows edges, so she won't appear in friend_counts.
|
|
// But Alice(2) and Bob(1) do. Verify the count matches.
|
|
let queries = r#"
|
|
query fc() {
|
|
match {
|
|
$p: Person
|
|
$p knows $f
|
|
}
|
|
return { $p.name, count($f) as friends }
|
|
order { friends desc }
|
|
}
|
|
"#;
|
|
let result = query_main(&mut db, queries, "fc", &ParamMap::new())
|
|
.await
|
|
.unwrap();
|
|
let batch = result.concat_batches().unwrap();
|
|
// Only Alice and Bob have outgoing Knows edges
|
|
assert_eq!(batch.num_rows(), 2);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn inline_global_count_empty() {
|
|
let dir = tempfile::tempdir().unwrap();
|
|
let uri = dir.path().to_str().unwrap();
|
|
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
|
|
|
|
// Load only nodes, no edges
|
|
let data = r#"{"type": "Person", "data": {"name": "Alice", "age": 30}}
|
|
{"type": "Company", "data": {"name": "Acme"}}"#;
|
|
load_jsonl(&mut db, data, LoadMode::Overwrite)
|
|
.await
|
|
.unwrap();
|
|
|
|
// Global count — should return 1 row with count=1
|
|
let queries = r#"
|
|
query tc() {
|
|
match { $p: Person }
|
|
return { count($p) as total }
|
|
}
|
|
"#;
|
|
let result = query_main(&mut db, queries, "tc", &ParamMap::new())
|
|
.await
|
|
.unwrap();
|
|
let batch = result.concat_batches().unwrap();
|
|
assert_eq!(batch.num_rows(), 1);
|
|
let total = batch
|
|
.column_by_name("total")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<Int64Array>()
|
|
.unwrap();
|
|
assert_eq!(total.value(0), 1);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn inline_min_max_string() {
|
|
let dir = tempfile::tempdir().unwrap();
|
|
let mut db = init_and_load(&dir).await;
|
|
|
|
let queries = r#"
|
|
query name_range() {
|
|
match { $p: Person }
|
|
return {
|
|
min($p.name) as first_name
|
|
max($p.name) as last_name
|
|
}
|
|
}
|
|
"#;
|
|
let result = query_main(&mut db, queries, "name_range", &ParamMap::new())
|
|
.await
|
|
.unwrap();
|
|
let batch = result.concat_batches().unwrap();
|
|
assert_eq!(batch.num_rows(), 1);
|
|
|
|
let first = batch
|
|
.column_by_name("first_name")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<StringArray>()
|
|
.unwrap();
|
|
let last = batch
|
|
.column_by_name("last_name")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<StringArray>()
|
|
.unwrap();
|
|
assert_eq!(first.value(0), "Alice");
|
|
assert_eq!(last.value(0), "Diana");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn inline_multi_hop_aggregate() {
|
|
let dir = tempfile::tempdir().unwrap();
|
|
let mut db = init_and_load(&dir).await;
|
|
|
|
// Friends of friends count per person
|
|
let queries = r#"
|
|
query fof_counts($name: String) {
|
|
match {
|
|
$p: Person { name: $name }
|
|
$p knows $mid
|
|
$mid knows $fof
|
|
}
|
|
return { $p.name, count($fof) as fof_count }
|
|
}
|
|
"#;
|
|
// Alice → Bob, Charlie. Bob → Diana. Charlie → nobody.
|
|
// So Alice's fof = [Diana] → count = 1
|
|
let result = query_main(
|
|
&mut db,
|
|
queries,
|
|
"fof_counts",
|
|
¶ms(&[("$name", "Alice")]),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
let batch = result.concat_batches().unwrap();
|
|
assert_eq!(batch.num_rows(), 1);
|
|
|
|
let count = batch
|
|
.column_by_name("fof_count")
|
|
.unwrap()
|
|
.as_any()
|
|
.downcast_ref::<Int64Array>()
|
|
.unwrap();
|
|
assert_eq!(count.value(0), 1);
|
|
}
|