diff --git a/crates/omnigraph-compiler/src/ir/lower.rs b/crates/omnigraph-compiler/src/ir/lower.rs index 61e7eb9..a8ae6f1 100644 --- a/crates/omnigraph-compiler/src/ir/lower.rs +++ b/crates/omnigraph-compiler/src/ir/lower.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet, VecDeque}; use crate::catalog::Catalog; use crate::error::Result; @@ -147,67 +147,89 @@ fn lower_clauses( } } - // Lower bindings into NodeScan ops + // ── Determine which bindings are "deferred" ───────────────────────── + // + // When multiple bindings in the same match clause are connected by + // traversals, only the first-declared binding needs a NodeScan; the + // rest will be introduced by Expand operations. Making them all + // NodeScans triggers expensive cross-joins followed by cycle-closing + // filters. + // + // Algorithm: build an undirected graph of variables connected by + // traversals, then walk connected components in binding declaration + // order. The first binding in each component becomes the root (gets + // a NodeScan); all other bindings in the same component are deferred + // — their inline filters become post-Expand Filter ops. + + let binding_set: HashSet<&str> = bindings.iter().map(|b| b.variable.as_str()).collect(); + + // Build undirected traversal adjacency (variable → neighbours) + let mut adj: HashMap<&str, Vec<&str>> = HashMap::new(); + for t in &traversals { + adj.entry(t.src.as_str()).or_default().push(t.dst.as_str()); + adj.entry(t.dst.as_str()).or_default().push(t.src.as_str()); + } + + // Walk components to find deferred binding variables + let mut deferred_set: HashSet = HashSet::new(); + let mut component_visited: HashSet<&str> = HashSet::new(); + + for binding in &bindings { + if component_visited.contains(binding.variable.as_str()) { + continue; + } + // BFS from this binding through the traversal graph + let mut queue = VecDeque::new(); + queue.push_back(binding.variable.as_str()); + let mut component_bindings: Vec<&str> = Vec::new(); + + while let Some(var) = queue.pop_front() { + if !component_visited.insert(var) { + continue; + } + if binding_set.contains(var) { + component_bindings.push(var); + } + if let Some(neighbours) = adj.get(var) { + for &n in neighbours { + if !component_visited.contains(n) { + queue.push_back(n); + } + } + } + } + + // First binding in the component is the root; defer the rest. + for var in component_bindings.into_iter().skip(1) { + deferred_set.insert(var.to_string()); + } + } + + // Build deferred filters map for variables introduced by traversals + let mut deferred_filters: HashMap> = HashMap::new(); + + // Lower bindings into NodeScan ops (skip deferred ones) for binding in &bindings { let node_type = catalog .node_types .get(&binding.type_name) .expect("binding type was validated during typecheck"); - // Collect inline filters from prop matches - let mut scan_filters = Vec::new(); - for pm in &binding.prop_matches { - let prop = node_type - .properties - .get(&pm.prop_name) - .expect("binding property was validated during typecheck"); - let op = if prop.list { - CompOp::Contains - } else { - CompOp::Eq - }; - match &pm.value { - MatchValue::Literal(lit) => { - scan_filters.push(IRFilter { - left: IRExpr::PropAccess { - variable: binding.variable.clone(), - property: pm.prop_name.clone(), - }, - op, - right: IRExpr::Literal(lit.clone()), - }); - } - MatchValue::Now => { - scan_filters.push(IRFilter { - left: IRExpr::PropAccess { - variable: binding.variable.clone(), - property: pm.prop_name.clone(), - }, - op, - right: IRExpr::Param(NOW_PARAM_NAME.to_string()), - }); - } - MatchValue::Variable(v) => { - let right = if param_names.contains(v) { - IRExpr::Param(v.clone()) - } else { - IRExpr::Variable(v.clone()) - }; - scan_filters.push(IRFilter { - left: IRExpr::PropAccess { - variable: binding.variable.clone(), - property: pm.prop_name.clone(), - }, - op, - right, - }); - } + + let binding_filters = build_binding_filters(binding, node_type, param_names); + + if deferred_set.contains(&binding.variable) { + // Save filters for emission after the Expand that introduces + // this variable. + if !binding_filters.is_empty() { + deferred_filters.insert(binding.variable.clone(), binding_filters); } + continue; } pipeline.push(IROp::NodeScan { variable: binding.variable.clone(), type_name: binding.type_name.clone(), - filters: scan_filters, + filters: binding_filters, }); bound_vars.insert(binding.variable.clone()); } @@ -250,6 +272,7 @@ fn lower_clauses( dst_type, min_hops: traversal.min_hops, max_hops: traversal.max_hops, + dst_filters: vec![], }); pipeline.push(IROp::Filter(IRFilter { left: IRExpr::PropAccess { @@ -273,6 +296,8 @@ fn lower_clauses( Direction::Out => edge.from_type.clone(), Direction::In => edge.to_type.clone(), }; + let introduced_filters = + deferred_filters.remove(&traversal.src).unwrap_or_default(); pipeline.push(IROp::Expand { src_var: traversal.dst.clone(), dst_var: traversal.src.clone(), @@ -281,11 +306,14 @@ fn lower_clauses( dst_type: src_type, min_hops: traversal.min_hops, max_hops: traversal.max_hops, + dst_filters: introduced_filters, }); if traversal.src != "_" { bound_vars.insert(traversal.src.clone()); } } else { + let introduced_filters = + deferred_filters.remove(&traversal.dst).unwrap_or_default(); pipeline.push(IROp::Expand { src_var: traversal.src.clone(), dst_var: traversal.dst.clone(), @@ -294,6 +322,7 @@ fn lower_clauses( dst_type, min_hops: traversal.min_hops, max_hops: traversal.max_hops, + dst_filters: introduced_filters, }); if traversal.dst != "_" { bound_vars.insert(traversal.dst.clone()); @@ -335,6 +364,46 @@ fn lower_clauses( Ok(()) } +/// Build IR filters from a binding's inline property matches. +fn build_binding_filters( + binding: &Binding, + node_type: &crate::catalog::NodeType, + param_names: &HashSet, +) -> Vec { + let mut filters = Vec::new(); + for pm in &binding.prop_matches { + let prop = node_type + .properties + .get(&pm.prop_name) + .expect("binding property was validated during typecheck"); + let op = if prop.list { + CompOp::Contains + } else { + CompOp::Eq + }; + let right = match &pm.value { + MatchValue::Literal(lit) => IRExpr::Literal(lit.clone()), + MatchValue::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()), + MatchValue::Variable(v) => { + if param_names.contains(v) { + IRExpr::Param(v.clone()) + } else { + IRExpr::Variable(v.clone()) + } + } + }; + filters.push(IRFilter { + left: IRExpr::PropAccess { + variable: binding.variable.clone(), + property: pm.prop_name.clone(), + }, + op, + right, + }); + } + filters +} + fn find_outer_var(clauses: &[Clause], outer_bound: &HashSet) -> Option { for clause in clauses { match clause { @@ -692,4 +761,160 @@ query q($name: String, $age: I32, $friend: String) { matches!(&ir.ops[1], MutationOpIR::Insert { type_name, .. } if type_name == "Knows") ); } + + /// Destination binding is deferred: NodeScan + Expand + Filter (no cross-join). + #[test] + fn test_lower_traversal_with_destination_binding() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person + $p worksAt $c + $c: Company { name: "Acme" } + } + return { $p.name, $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // Should be: NodeScan($p) → Expand($p→$c, dst_filters=[name=="Acme"]) + // NOT: NodeScan($p) → NodeScan($c) → cross-join → cycle-close + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "p" && dst_var == "c" && dst_filters.len() == 1 + )); + } + + /// Multi-hop chain: all intermediate and final bindings are deferred. + #[test] + fn test_lower_chain_defers_all_intermediate_bindings() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person { name: "Alice" } + $p knows $f + $f: Person { name: "Bob" } + $f worksAt $c + $c: Company { name: "Acme" } + } + return { $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // Should be: NodeScan($p,[name=Alice]) → Expand($p→$f, [name==Bob]) + // → Expand($f→$c, [name==Acme]) + assert_eq!(ir.pipeline.len(), 3); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "p" && dst_var == "f" && dst_filters.len() == 1 + )); + assert!(matches!( + &ir.pipeline[2], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "f" && dst_var == "c" && dst_filters.len() == 1 + )); + } + + /// Reverse traversal: source binding is deferred when destination is the root. + #[test] + fn test_lower_reverse_traversal_defers_source_binding() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $c: Company { name: "Acme" } + $p worksAt $c + $p: Person { name: "Alice" } + } + return { $p.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // $c is root (first declared). $p is deferred (connected via traversal). + // Traversal $p worksAt $c: $c is bound, $p is not → reverse expand. + // Pipeline: NodeScan($c,[name=Acme]) → Expand($c→$p, In, [name==Alice]) + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "c")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "c" && dst_var == "p" && dst_filters.len() == 1 + )); + } + + /// Independent bindings (no traversal) still cross-join. + #[test] + fn test_lower_independent_bindings_still_cross_join() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person + $c: Company + } + return { $p.name, $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // No traversal connecting them → both get NodeScans (cross-join at runtime) + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + assert!(matches!(&ir.pipeline[1], IROp::NodeScan { variable, .. } if variable == "c")); + } + + /// Destination binding without filters: no NodeScan, no post-expand filter. + #[test] + fn test_lower_destination_binding_without_filters() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person + $p worksAt $c + $c: Company + } + return { $p.name, $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // $c binding is deferred (no filters) → just NodeScan + Expand + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, .. } + if src_var == "p" && dst_var == "c" + )); + } } diff --git a/crates/omnigraph-compiler/src/ir/mod.rs b/crates/omnigraph-compiler/src/ir/mod.rs index f277ef7..49dfb3d 100644 --- a/crates/omnigraph-compiler/src/ir/mod.rs +++ b/crates/omnigraph-compiler/src/ir/mod.rs @@ -70,6 +70,10 @@ pub enum IROp { dst_type: String, min_hops: u32, max_hops: Option, + /// Filters from a deferred destination binding, pushed into the + /// Expand so the executor can apply them during hydration (Lance + /// SQL pushdown) rather than as a separate post-expand pass. + dst_filters: Vec, }, Filter(IRFilter), AntiJoin { diff --git a/crates/omnigraph/src/exec/query.rs b/crates/omnigraph/src/exec/query.rs index dc95a53..a09184a 100644 --- a/crates/omnigraph/src/exec/query.rs +++ b/crates/omnigraph/src/exec/query.rs @@ -645,6 +645,7 @@ fn execute_pipeline<'a>( dst_type, min_hops, max_hops, + dst_filters, } => { let gi = graph_index.ok_or_else(|| { OmniError::manifest("graph index required for traversal".to_string()) @@ -652,7 +653,7 @@ fn execute_pipeline<'a>( if let Some(batch) = wide.as_mut() { execute_expand( batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction, - dst_type, *min_hops, *max_hops, + dst_type, *min_hops, *max_hops, dst_filters, params, ) .await?; } @@ -683,6 +684,8 @@ async fn execute_expand( dst_type: &str, min_hops: u32, max_hops: Option, + dst_filters: &[IRFilter], + params: &ParamMap, ) -> Result<()> { let src_id_col_name = format!("{}.id", src_var); let src_ids = wide @@ -766,8 +769,15 @@ async fn execute_expand( } } - // Hydrate destination nodes from the snapshot - let dst_batch = hydrate_nodes(snapshot, catalog, dst_type, &dst_id_list).await?; + // Split dst_filters: SQL-pushable go to Lance, the rest applied post-hconcat + let pushdown_sql = build_lance_filter(dst_filters, params); + let non_pushable: Vec<&IRFilter> = dst_filters + .iter() + .filter(|f| ir_filter_to_sql(f, params).is_none()) + .collect(); + + // Hydrate destination nodes from the snapshot (with pushed-down filters) + let dst_batch = hydrate_nodes(snapshot, catalog, dst_type, &dst_id_list, pushdown_sql.as_deref()).await?; // Build a mapping from dst_id to row index in dst_batch let dst_batch_id_col = dst_batch @@ -796,15 +806,26 @@ async fn execute_expand( let dst_prefixed = prefix_batch(&dst_batch, dst_var)?; let aligned_dst = take_batch(&dst_prefixed, &dst_take)?; *wide = hconcat_batches(&expanded_wide, &aligned_dst)?; + + // Apply any non-pushable destination filters (e.g. list-contains) in memory + for f in &non_pushable { + apply_filter(wide, f, params)?; + } + Ok(()) } /// Load full node rows for a set of IDs from a snapshot. +/// +/// When `extra_filter_sql` is provided (from deferred destination-binding +/// filters), it is ANDed with the `id IN (...)` clause so that Lance can +/// skip non-matching rows at the storage level. async fn hydrate_nodes( snapshot: &Snapshot, catalog: &Catalog, type_name: &str, ids: &[String], + extra_filter_sql: Option<&str>, ) -> Result { let node_type = catalog .node_types @@ -823,7 +844,10 @@ async fn hydrate_nodes( .iter() .map(|id| format!("'{}'", id.replace('\'', "''"))) .collect(); - let filter_sql = format!("id IN ({})", escaped.join(", ")); + let mut filter_sql = format!("id IN ({})", escaped.join(", ")); + if let Some(extra) = extra_filter_sql { + filter_sql = format!("({}) AND ({})", filter_sql, extra); + } let has_blobs = !node_type.blob_properties.is_empty(); let non_blob_cols: Vec<&str> = node_type .arrow_schema diff --git a/crates/omnigraph/tests/traversal.rs b/crates/omnigraph/tests/traversal.rs index cc3228f..4d9de77 100644 --- a/crates/omnigraph/tests/traversal.rs +++ b/crates/omnigraph/tests/traversal.rs @@ -396,3 +396,220 @@ query insert_no_name($age: I32) { assert!(result.is_err(), "insert without @key property should fail"); } + +// ─── Join alignment: traversal + destination binding ─────────────────────── + +/// Traversal with destination binding filter constrains the source. +/// Regression: previously over-returned because the lowering created a +/// cross-join followed by cycle-closing instead of Expand + post-filter. +#[tokio::test] +async fn traversal_destination_binding_constrains_source() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Only Alice works at Acme. The binding on $c must constrain $p. + let queries = r#" +query at_acme() { + match { + $p: Person + $p worksAt $c + $c: Company { name: "Acme" } + } + return { $p.name } +} +"#; + let result = query_main(&mut db, queries, "at_acme", &ParamMap::new()) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + let names = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.len(), 1); + assert_eq!(names.value(0), "Alice"); +} + +/// Multi-variable projection: columns from source and destination must be +/// row-aligned. Previously this could fail with "all columns must have +/// the same length" when variables had different cardinalities. +#[tokio::test] +async fn traversal_multi_variable_projection_aligned() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let queries = r#" +query employee_companies() { + match { + $p: Person + $p worksAt $c + $c: Company + } + return { $p.name, $c.name } +} +"#; + let result = query_main(&mut db, queries, "employee_companies", &ParamMap::new()) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + // Alice→Acme, Bob→Globex + assert_eq!(batch.num_rows(), 2); + let person_names = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let company_names = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + let mut pairs: Vec<(&str, &str)> = (0..batch.num_rows()) + .map(|i| (person_names.value(i), company_names.value(i))) + .collect(); + pairs.sort(); + assert_eq!(pairs, vec![("Alice", "Acme"), ("Bob", "Globex")]); +} + +/// Multi-hop projection: all three variables must be row-aligned. +#[tokio::test] +async fn multi_hop_projection_aligned() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Alice knows Bob, Bob knows Diana. + // Alice→Bob→Diana is the only 2-hop path. + let queries = r#" +query fof_chain($name: String) { + match { + $p: Person { name: $name } + $p knows $mid + $mid knows $fof + } + return { $p.name, $mid.name, $fof.name } +} +"#; + let result = query_main( + &mut db, + queries, + "fof_chain", + ¶ms(&[("$name", "Alice")]), + ) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + let col0 = batch.column(0).as_any().downcast_ref::().unwrap(); + let col1 = batch.column(1).as_any().downcast_ref::().unwrap(); + let col2 = batch.column(2).as_any().downcast_ref::().unwrap(); + assert_eq!(col0.value(0), "Alice"); + assert_eq!(col1.value(0), "Bob"); + assert_eq!(col2.value(0), "Diana"); +} + +/// Multi-hop with destination binding filters at each hop. +#[tokio::test] +async fn multi_hop_with_intermediate_binding_filters() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Alice knows Bob and Charlie. + // Bob knows Diana. Charlie knows nobody. + // Filter $mid to only "Bob" → only Alice→Bob→Diana survives. + let queries = r#" +query fof_via($name: String, $mid_name: String) { + match { + $p: Person { name: $name } + $p knows $mid + $mid: Person { name: $mid_name } + $mid knows $fof + } + return { $fof.name } +} +"#; + let result = query_main( + &mut db, + queries, + "fof_via", + ¶ms(&[("$name", "Alice"), ("$mid_name", "Bob")]), + ) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + let names = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.len(), 1); + assert_eq!(names.value(0), "Diana"); +} + +/// Destination binding with filter + multi-variable return: the classic +/// "join across a traversal" scenario that triggers the bug. +#[tokio::test] +async fn traversal_destination_filter_with_multi_return() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let queries = r#" +query at_acme_named() { + match { + $p: Person + $p worksAt $c + $c: Company { name: "Acme" } + } + return { $p.name, $c.name } +} +"#; + let result = query_main(&mut db, queries, "at_acme_named", &ParamMap::new()) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + let person = batch.column(0).as_any().downcast_ref::().unwrap(); + let company = batch.column(1).as_any().downcast_ref::().unwrap(); + assert_eq!(person.value(0), "Alice"); + assert_eq!(company.value(0), "Acme"); +} + +/// Parameterized destination filter exercises param resolution through the +/// Lance SQL pushdown path (params are resolved to literals in ir_expr_to_sql). +#[tokio::test] +async fn traversal_destination_filter_pushdown_with_param() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let queries = r#" +query at_company($company: String) { + match { + $p: Person + $p worksAt $c + $c: Company { name: $company } + } + return { $p.name, $c.name } +} +"#; + let result = query_main( + &mut db, + queries, + "at_company", + ¶ms(&[("$company", "Globex")]), + ) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + let person = batch.column(0).as_any().downcast_ref::().unwrap(); + let company = batch.column(1).as_any().downcast_ref::().unwrap(); + assert_eq!(person.value(0), "Bob"); + assert_eq!(company.value(0), "Globex"); +}