diff --git a/crates/omnigraph-compiler/src/ir/lower.rs b/crates/omnigraph-compiler/src/ir/lower.rs index 61e7eb9..a077b4b 100644 --- a/crates/omnigraph-compiler/src/ir/lower.rs +++ b/crates/omnigraph-compiler/src/ir/lower.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet, VecDeque}; use crate::catalog::Catalog; use crate::error::Result; @@ -147,158 +147,214 @@ fn lower_clauses( } } - // Lower bindings into NodeScan ops + // ── Determine which bindings are "deferred" ───────────────────────── + // + // When multiple bindings in the same match clause are connected by + // traversals, only the first-declared binding needs a NodeScan; the + // rest will be introduced by Expand operations. Making them all + // NodeScans triggers expensive cross-joins followed by cycle-closing + // filters. + // + // Algorithm: build an undirected graph of variables connected by + // traversals, then walk connected components in binding declaration + // order. The first binding in each component becomes the root (gets + // a NodeScan); all other bindings in the same component are deferred + // — their inline filters become post-Expand Filter ops. + + let binding_set: HashSet<&str> = bindings.iter().map(|b| b.variable.as_str()).collect(); + + // Build undirected traversal adjacency (variable → neighbours). + // Exclude the anonymous wildcard "_" so it cannot falsely bridge + // otherwise-independent components. + let mut adj: HashMap<&str, Vec<&str>> = HashMap::new(); + for t in &traversals { + let src = t.src.as_str(); + let dst = t.dst.as_str(); + if src != "_" && dst != "_" { + adj.entry(src).or_default().push(dst); + adj.entry(dst).or_default().push(src); + } + } + + // Walk components to find deferred binding variables + let mut deferred_set: HashSet = HashSet::new(); + let mut component_visited: HashSet<&str> = HashSet::new(); + + for binding in &bindings { + if component_visited.contains(binding.variable.as_str()) { + continue; + } + // BFS from this binding through the traversal graph + let mut queue = VecDeque::new(); + queue.push_back(binding.variable.as_str()); + let mut component_bindings: Vec<&str> = Vec::new(); + + while let Some(var) = queue.pop_front() { + if !component_visited.insert(var) { + continue; + } + if binding_set.contains(var) { + component_bindings.push(var); + } + if let Some(neighbours) = adj.get(var) { + for &n in neighbours { + if !component_visited.contains(n) { + queue.push_back(n); + } + } + } + } + + // First binding in the component is the root; defer the rest. + for var in component_bindings.into_iter().skip(1) { + deferred_set.insert(var.to_string()); + } + } + + // Build deferred filters map for variables introduced by traversals + let mut deferred_filters: HashMap> = HashMap::new(); + + // Lower bindings into NodeScan ops (skip deferred ones) for binding in &bindings { let node_type = catalog .node_types .get(&binding.type_name) .expect("binding type was validated during typecheck"); - // Collect inline filters from prop matches - let mut scan_filters = Vec::new(); - for pm in &binding.prop_matches { - let prop = node_type - .properties - .get(&pm.prop_name) - .expect("binding property was validated during typecheck"); - let op = if prop.list { - CompOp::Contains - } else { - CompOp::Eq - }; - match &pm.value { - MatchValue::Literal(lit) => { - scan_filters.push(IRFilter { - left: IRExpr::PropAccess { - variable: binding.variable.clone(), - property: pm.prop_name.clone(), - }, - op, - right: IRExpr::Literal(lit.clone()), - }); - } - MatchValue::Now => { - scan_filters.push(IRFilter { - left: IRExpr::PropAccess { - variable: binding.variable.clone(), - property: pm.prop_name.clone(), - }, - op, - right: IRExpr::Param(NOW_PARAM_NAME.to_string()), - }); - } - MatchValue::Variable(v) => { - let right = if param_names.contains(v) { - IRExpr::Param(v.clone()) - } else { - IRExpr::Variable(v.clone()) - }; - scan_filters.push(IRFilter { - left: IRExpr::PropAccess { - variable: binding.variable.clone(), - property: pm.prop_name.clone(), - }, - op, - right, - }); - } + + let binding_filters = build_binding_filters(binding, node_type, param_names); + + if deferred_set.contains(&binding.variable) { + // Save filters for emission after the Expand that introduces + // this variable. + if !binding_filters.is_empty() { + deferred_filters.insert(binding.variable.clone(), binding_filters); } + continue; } pipeline.push(IROp::NodeScan { variable: binding.variable.clone(), type_name: binding.type_name.clone(), - filters: scan_filters, + filters: binding_filters, }); bound_vars.insert(binding.variable.clone()); } - // Lower traversals into Expand ops - // Handle "cycle closing" — if both src and dst are already bound, use a filter - for traversal in &traversals { - let edge = catalog - .lookup_edge_by_name(&traversal.edge_name) - .ok_or_else(|| { - crate::error::NanoError::Plan(format!( - "lowering traversal referenced missing edge '{}' after typecheck", - traversal.edge_name - )) - })?; - - // Determine direction from type context - let direction = type_ctx - .traversals - .iter() - .find(|rt| { - rt.src == traversal.src && rt.dst == traversal.dst && rt.edge_type == edge.name - }) - .map(|rt| rt.direction) - .unwrap_or(Direction::Out); - - let dst_type = match direction { - Direction::Out => edge.to_type.clone(), - Direction::In => edge.from_type.clone(), - }; - - if bound_vars.contains(&traversal.src) && bound_vars.contains(&traversal.dst) { - // Cycle closing: emit expand to a temp var, then filter temp.id = dst.id - let temp_var = format!("__temp_{}", traversal.dst); - pipeline.push(IROp::Expand { - src_var: traversal.src.clone(), - dst_var: temp_var.clone(), - edge_type: edge.name.clone(), - direction, - dst_type, - min_hops: traversal.min_hops, - max_hops: traversal.max_hops, - }); - pipeline.push(IROp::Filter(IRFilter { - left: IRExpr::PropAccess { - variable: temp_var, - property: "id".to_string(), - }, - op: CompOp::Eq, - right: IRExpr::PropAccess { - variable: traversal.dst.clone(), - property: "id".to_string(), - }, - })); - } else if !bound_vars.contains(&traversal.src) && bound_vars.contains(&traversal.dst) { - // Reverse expand: dst is bound, src is not. - // Swap direction and expand from dst to discover src. - let reverse_dir = match direction { - Direction::Out => Direction::In, - Direction::In => Direction::Out, - }; - let src_type = match direction { - Direction::Out => edge.from_type.clone(), - Direction::In => edge.to_type.clone(), - }; - pipeline.push(IROp::Expand { - src_var: traversal.dst.clone(), - dst_var: traversal.src.clone(), - edge_type: edge.name.clone(), - direction: reverse_dir, - dst_type: src_type, - min_hops: traversal.min_hops, - max_hops: traversal.max_hops, - }); - if traversal.src != "_" { - bound_vars.insert(traversal.src.clone()); + // Lower traversals into Expand ops. + // + // Traversals are processed iteratively rather than in a single pass + // because deferred bindings mean a traversal's source might not be + // bound until a prior traversal introduces it. Each pass processes + // every traversal that has at least one bound endpoint; this repeats + // until all traversals are consumed. + let mut remaining: Vec<&Traversal> = traversals.to_vec(); + while !remaining.is_empty() { + let mut next_remaining = Vec::new(); + for traversal in &remaining { + let src_bound = bound_vars.contains(&traversal.src); + let dst_bound = bound_vars.contains(&traversal.dst); + if !src_bound && !dst_bound { + next_remaining.push(*traversal); + continue; } - } else { - pipeline.push(IROp::Expand { - src_var: traversal.src.clone(), - dst_var: traversal.dst.clone(), - edge_type: edge.name.clone(), - direction, - dst_type, - min_hops: traversal.min_hops, - max_hops: traversal.max_hops, - }); - if traversal.dst != "_" { - bound_vars.insert(traversal.dst.clone()); + + let edge = catalog + .lookup_edge_by_name(&traversal.edge_name) + .ok_or_else(|| { + crate::error::NanoError::Plan(format!( + "lowering traversal referenced missing edge '{}' after typecheck", + traversal.edge_name + )) + })?; + + let direction = type_ctx + .traversals + .iter() + .find(|rt| { + rt.src == traversal.src + && rt.dst == traversal.dst + && rt.edge_type == edge.name + }) + .map(|rt| rt.direction) + .unwrap_or(Direction::Out); + + let dst_type = match direction { + Direction::Out => edge.to_type.clone(), + Direction::In => edge.from_type.clone(), + }; + + if src_bound && dst_bound { + // Cycle closing: expand to a temp var, then filter temp.id = dst.id + let temp_var = format!("__temp_{}", traversal.dst); + pipeline.push(IROp::Expand { + src_var: traversal.src.clone(), + dst_var: temp_var.clone(), + edge_type: edge.name.clone(), + direction, + dst_type, + min_hops: traversal.min_hops, + max_hops: traversal.max_hops, + dst_filters: vec![], + }); + pipeline.push(IROp::Filter(IRFilter { + left: IRExpr::PropAccess { + variable: temp_var, + property: "id".to_string(), + }, + op: CompOp::Eq, + right: IRExpr::PropAccess { + variable: traversal.dst.clone(), + property: "id".to_string(), + }, + })); + } else if !src_bound && dst_bound { + // Reverse expand: dst is bound, src is not. + let reverse_dir = match direction { + Direction::Out => Direction::In, + Direction::In => Direction::Out, + }; + let src_type = match direction { + Direction::Out => edge.from_type.clone(), + Direction::In => edge.to_type.clone(), + }; + let introduced_filters = + deferred_filters.remove(&traversal.src).unwrap_or_default(); + pipeline.push(IROp::Expand { + src_var: traversal.dst.clone(), + dst_var: traversal.src.clone(), + edge_type: edge.name.clone(), + direction: reverse_dir, + dst_type: src_type, + min_hops: traversal.min_hops, + max_hops: traversal.max_hops, + dst_filters: introduced_filters, + }); + if traversal.src != "_" { + bound_vars.insert(traversal.src.clone()); + } + } else { + // Normal expand: src is bound, dst is not. + let introduced_filters = + deferred_filters.remove(&traversal.dst).unwrap_or_default(); + pipeline.push(IROp::Expand { + src_var: traversal.src.clone(), + dst_var: traversal.dst.clone(), + edge_type: edge.name.clone(), + direction, + dst_type, + min_hops: traversal.min_hops, + max_hops: traversal.max_hops, + dst_filters: introduced_filters, + }); + if traversal.dst != "_" { + bound_vars.insert(traversal.dst.clone()); + } } } + if next_remaining.len() == remaining.len() { + break; + } + remaining = next_remaining; } // Lower explicit filters @@ -335,6 +391,46 @@ fn lower_clauses( Ok(()) } +/// Build IR filters from a binding's inline property matches. +fn build_binding_filters( + binding: &Binding, + node_type: &crate::catalog::NodeType, + param_names: &HashSet, +) -> Vec { + let mut filters = Vec::new(); + for pm in &binding.prop_matches { + let prop = node_type + .properties + .get(&pm.prop_name) + .expect("binding property was validated during typecheck"); + let op = if prop.list { + CompOp::Contains + } else { + CompOp::Eq + }; + let right = match &pm.value { + MatchValue::Literal(lit) => IRExpr::Literal(lit.clone()), + MatchValue::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()), + MatchValue::Variable(v) => { + if param_names.contains(v) { + IRExpr::Param(v.clone()) + } else { + IRExpr::Variable(v.clone()) + } + } + }; + filters.push(IRFilter { + left: IRExpr::PropAccess { + variable: binding.variable.clone(), + property: pm.prop_name.clone(), + }, + op, + right, + }); + } + filters +} + fn find_outer_var(clauses: &[Clause], outer_bound: &HashSet) -> Option { for clause in clauses { match clause { @@ -692,4 +788,452 @@ query q($name: String, $age: I32, $friend: String) { matches!(&ir.ops[1], MutationOpIR::Insert { type_name, .. } if type_name == "Knows") ); } + + /// Destination binding is deferred: NodeScan + Expand + Filter (no cross-join). + #[test] + fn test_lower_traversal_with_destination_binding() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person + $p worksAt $c + $c: Company { name: "Acme" } + } + return { $p.name, $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // Should be: NodeScan($p) → Expand($p→$c, dst_filters=[name=="Acme"]) + // NOT: NodeScan($p) → NodeScan($c) → cross-join → cycle-close + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "p" && dst_var == "c" && dst_filters.len() == 1 + )); + } + + /// Multi-hop chain: all intermediate and final bindings are deferred. + #[test] + fn test_lower_chain_defers_all_intermediate_bindings() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person { name: "Alice" } + $p knows $f + $f: Person { name: "Bob" } + $f worksAt $c + $c: Company { name: "Acme" } + } + return { $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // Should be: NodeScan($p,[name=Alice]) → Expand($p→$f, [name==Bob]) + // → Expand($f→$c, [name==Acme]) + assert_eq!(ir.pipeline.len(), 3); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "p" && dst_var == "f" && dst_filters.len() == 1 + )); + assert!(matches!( + &ir.pipeline[2], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "f" && dst_var == "c" && dst_filters.len() == 1 + )); + } + + /// Reverse traversal: source binding is deferred when destination is the root. + #[test] + fn test_lower_reverse_traversal_defers_source_binding() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $c: Company { name: "Acme" } + $p worksAt $c + $p: Person { name: "Alice" } + } + return { $p.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // $c is root (first declared). $p is deferred (connected via traversal). + // Traversal $p worksAt $c: $c is bound, $p is not → reverse expand. + // Pipeline: NodeScan($c,[name=Acme]) → Expand($c→$p, In, [name==Alice]) + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "c")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "c" && dst_var == "p" && dst_filters.len() == 1 + )); + } + + /// Independent bindings (no traversal) still cross-join. + #[test] + fn test_lower_independent_bindings_still_cross_join() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person + $c: Company + } + return { $p.name, $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // No traversal connecting them → both get NodeScans (cross-join at runtime) + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + assert!(matches!(&ir.pipeline[1], IROp::NodeScan { variable, .. } if variable == "c")); + } + + /// Destination binding without filters: no NodeScan, no post-expand filter. + #[test] + fn test_lower_destination_binding_without_filters() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person + $p worksAt $c + $c: Company + } + return { $p.name, $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // $c binding is deferred (no filters) → just NodeScan + Expand + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, .. } + if src_var == "p" && dst_var == "c" + )); + } + + /// Traversals declared in non-topological order are reordered automatically. + #[test] + fn test_lower_out_of_order_traversals() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person + $f worksAt $c + $p knows $f + $f: Person + $c: Company { name: "Acme" } + } + return { $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // Even though "$f worksAt $c" is declared before "$p knows $f", + // the iterative lowering processes "$p knows $f" first (because $p + // is bound) and then "$f worksAt $c" (once $f is bound). + assert_eq!(ir.pipeline.len(), 3); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + // First expand: $p → $f (knows) + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, .. } + if src_var == "p" && dst_var == "f" + )); + // Second expand: $f → $c (worksAt), with filter from $c binding + assert!(matches!( + &ir.pipeline[2], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "f" && dst_var == "c" && dst_filters.len() == 1 + )); + } + + /// Wildcard $_ must not bridge unrelated components in the adjacency graph. + #[test] + fn test_lower_wildcard_does_not_bridge_components() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person + $p knows $_ + $c: Company + } + return { $p.name, $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // $p and $c are in separate components (connected only through $_). + // Both must get their own NodeScan — $c must NOT be deferred. + // Bindings are emitted first, then traversals. + assert_eq!(ir.pipeline.len(), 3); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + assert!(matches!(&ir.pipeline[1], IROp::NodeScan { variable, .. } if variable == "c")); + // The expand for $p knows $_ (wildcard destination) + assert!(matches!( + &ir.pipeline[2], + IROp::Expand { src_var, dst_var, .. } + if src_var == "p" && dst_var == "_" + )); + } + + /// Fan-out: one root fans to two deferred destinations via different edges. + #[test] + fn test_lower_fan_out_topology() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person { name: "Alice" } + $p knows $f + $f: Person { name: "Bob" } + $p worksAt $c + $c: Company { name: "Acme" } + } + return { $f.name, $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // Root: $p. Deferred: $f, $c (both reachable from $p). + assert_eq!(ir.pipeline.len(), 3); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "p" && dst_var == "f" && dst_filters.len() == 1 + )); + assert!(matches!( + &ir.pipeline[2], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "p" && dst_var == "c" && dst_filters.len() == 1 + )); + } + + /// Fan-in: two sources converge on one destination; second source is + /// introduced via reverse expand from the shared destination. + #[test] + fn test_lower_fan_in_topology() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $a: Person { name: "Alice" } + $a knows $c + $b: Person { name: "Bob" } + $b knows $c + $c: Person + } + return { $a.name, $b.name, $c.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // Root: $a (first in component {a,b,c}). Deferred: $b, $c. + // $a knows $c: expand(a→c). $b knows $c: reverse expand(c→b). + assert_eq!(ir.pipeline.len(), 3); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "a")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "a" && dst_var == "c" && dst_filters.is_empty() + )); + assert!(matches!( + &ir.pipeline[2], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "c" && dst_var == "b" && dst_filters.len() == 1 + )); + } + + /// Genuine graph cycle: deferred binding is introduced by first traversal, + /// second traversal triggers cycle-closing. + #[test] + fn test_lower_cycle_with_deferred_binding() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $a: Person + $a knows $b + $b: Person { name: "Bob" } + $b knows $a + } + return { $a.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // $b is deferred, introduced by first expand. + // Second traversal ($b knows $a) is genuine cycle-closing. + assert_eq!(ir.pipeline.len(), 4); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "a")); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "a" && dst_var == "b" && dst_filters.len() == 1 + )); + // Cycle-closing expand to __temp_a + assert!(matches!( + &ir.pipeline[2], + IROp::Expand { src_var, dst_var, dst_filters, .. } + if src_var == "b" && dst_var.starts_with("__temp_") && dst_filters.is_empty() + )); + // Cycle-closing filter: __temp_a.id == a.id + assert!(matches!(&ir.pipeline[3], IROp::Filter(_))); + } + + /// Multiple filters on a single deferred binding. + #[test] + fn test_lower_multiple_filters_on_deferred_binding() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person + $p knows $f + $f: Person { name: "Bob", age: 25 } + } + return { $f.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // Two prop_matches → two dst_filters on the Expand. + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { dst_filters, .. } + if dst_filters.len() == 2 + )); + } + + /// Parameter in a deferred binding filter (unit test level). + #[test] + fn test_lower_param_filter_on_deferred_binding() { + let catalog = setup(); + let qf = parse_query( + r#" +query q($company: String) { + match { + $p: Person + $p worksAt $c + $c: Company { name: $company } + } + return { $p.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!( + &ir.pipeline[1], + IROp::Expand { dst_filters, .. } + if dst_filters.len() == 1 + )); + // The filter's right-hand side should be a Param, not a Literal + if let IROp::Expand { dst_filters, .. } = &ir.pipeline[1] { + assert!(matches!(&dst_filters[0].right, IRExpr::Param(name) if name == "company")); + } + } + + /// Negation with inner binding: inner binding is NOT deferred because + /// bound_vars (from outer scope) is not in binding_set for the inner call. + /// This documents current behavior — the inner pipeline uses a NodeScan + + /// cycle-closing, which is correct but less efficient than deferral. + #[test] + fn test_lower_negation_with_inner_binding() { + let catalog = setup(); + let qf = parse_query( + r#" +query q() { + match { + $p: Person + not { + $p worksAt $c + $c: Company { name: "Acme" } + } + } + return { $p.name } +} +"#, + ) + .unwrap(); + let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap(); + let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap(); + + // Outer: NodeScan($p) + AntiJoin + assert_eq!(ir.pipeline.len(), 2); + assert!(matches!(&ir.pipeline[0], IROp::NodeScan { variable, .. } if variable == "p")); + let IROp::AntiJoin { inner, .. } = &ir.pipeline[1] else { + panic!("expected AntiJoin"); + }; + // Inner pipeline: $c is NOT deferred (it's the only binding in the + // inner scope), so it gets a NodeScan + cycle-closing (3 ops). + assert_eq!(inner.len(), 3); + assert!(matches!(&inner[0], IROp::NodeScan { variable, .. } if variable == "c")); + assert!(matches!(&inner[1], IROp::Expand { .. })); + assert!(matches!(&inner[2], IROp::Filter(_))); + } } diff --git a/crates/omnigraph-compiler/src/ir/mod.rs b/crates/omnigraph-compiler/src/ir/mod.rs index f277ef7..49dfb3d 100644 --- a/crates/omnigraph-compiler/src/ir/mod.rs +++ b/crates/omnigraph-compiler/src/ir/mod.rs @@ -70,6 +70,10 @@ pub enum IROp { dst_type: String, min_hops: u32, max_hops: Option, + /// Filters from a deferred destination binding, pushed into the + /// Expand so the executor can apply them during hydration (Lance + /// SQL pushdown) rather than as a separate post-expand pass. + dst_filters: Vec, }, Filter(IRFilter), AntiJoin { diff --git a/crates/omnigraph/src/exec/query.rs b/crates/omnigraph/src/exec/query.rs index f4547a7..9a69381 100644 --- a/crates/omnigraph/src/exec/query.rs +++ b/crates/omnigraph/src/exec/query.rs @@ -645,6 +645,7 @@ fn execute_pipeline<'a>( dst_type, min_hops, max_hops, + dst_filters, } => { let gi = graph_index.ok_or_else(|| { OmniError::manifest("graph index required for traversal".to_string()) @@ -652,7 +653,7 @@ fn execute_pipeline<'a>( if let Some(batch) = wide.as_mut() { execute_expand( batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction, - dst_type, *min_hops, *max_hops, + dst_type, *min_hops, *max_hops, dst_filters, params, ) .await?; } @@ -683,6 +684,8 @@ async fn execute_expand( dst_type: &str, min_hops: u32, max_hops: Option, + dst_filters: &[IRFilter], + params: &ParamMap, ) -> Result<()> { let src_id_col_name = format!("{}.id", src_var); let src_ids = wide @@ -766,8 +769,15 @@ async fn execute_expand( } } - // Hydrate destination nodes from the snapshot - let dst_batch = hydrate_nodes(snapshot, catalog, dst_type, &dst_id_list).await?; + // Split dst_filters: SQL-pushable go to Lance, the rest applied post-hconcat + let pushdown_sql = build_lance_filter(dst_filters, params); + let non_pushable: Vec<&IRFilter> = dst_filters + .iter() + .filter(|f| ir_filter_to_sql(f, params).is_none()) + .collect(); + + // Hydrate destination nodes from the snapshot (with pushed-down filters) + let dst_batch = hydrate_nodes(snapshot, catalog, dst_type, &dst_id_list, pushdown_sql.as_deref()).await?; // Build a mapping from dst_id to row index in dst_batch let dst_batch_id_col = dst_batch @@ -796,15 +806,26 @@ async fn execute_expand( let dst_prefixed = prefix_batch(&dst_batch, dst_var)?; let aligned_dst = take_batch(&dst_prefixed, &dst_take)?; *wide = hconcat_batches(&expanded_wide, &aligned_dst)?; + + // Apply any non-pushable destination filters (e.g. list-contains) in memory + for f in &non_pushable { + apply_filter(wide, f, params)?; + } + Ok(()) } /// Load full node rows for a set of IDs from a snapshot. +/// +/// When `extra_filter_sql` is provided (from deferred destination-binding +/// filters), it is ANDed with the `id IN (...)` clause so that Lance can +/// skip non-matching rows at the storage level. async fn hydrate_nodes( snapshot: &Snapshot, catalog: &Catalog, type_name: &str, ids: &[String], + extra_filter_sql: Option<&str>, ) -> Result { let node_type = catalog .node_types @@ -823,7 +844,10 @@ async fn hydrate_nodes( .iter() .map(|id| format!("'{}'", id.replace('\'', "''"))) .collect(); - let filter_sql = format!("id IN ({})", escaped.join(", ")); + let mut filter_sql = format!("id IN ({})", escaped.join(", ")); + if let Some(extra) = extra_filter_sql { + filter_sql = format!("({}) AND ({})", filter_sql, extra); + } let has_blobs = !node_type.blob_properties.is_empty(); let non_blob_cols: Vec<&str> = node_type .arrow_schema @@ -877,6 +901,7 @@ fn try_bulk_anti_join_mask( src_var, edge_type, direction, + dst_filters, .. } = &inner_pipeline[0] else { @@ -885,6 +910,11 @@ fn try_bulk_anti_join_mask( if src_var != outer_var { return None; } + // Bulk CSR check only tests neighbor existence, not destination + // properties. Fall back to the slow path when dst_filters are present. + if !dst_filters.is_empty() { + return None; + } let gi = graph_index?; let edge_def = catalog.edge_types.get(edge_type.as_str())?; diff --git a/crates/omnigraph/tests/traversal.rs b/crates/omnigraph/tests/traversal.rs index cc3228f..6b6fbe3 100644 --- a/crates/omnigraph/tests/traversal.rs +++ b/crates/omnigraph/tests/traversal.rs @@ -396,3 +396,318 @@ query insert_no_name($age: I32) { assert!(result.is_err(), "insert without @key property should fail"); } + +// ─── Join alignment: traversal + destination binding ─────────────────────── + +/// Traversal with destination binding filter constrains the source. +/// Regression: previously over-returned because the lowering created a +/// cross-join followed by cycle-closing instead of Expand + post-filter. +#[tokio::test] +async fn traversal_destination_binding_constrains_source() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Only Alice works at Acme. The binding on $c must constrain $p. + let queries = r#" +query at_acme() { + match { + $p: Person + $p worksAt $c + $c: Company { name: "Acme" } + } + return { $p.name } +} +"#; + let result = query_main(&mut db, queries, "at_acme", &ParamMap::new()) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + let names = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.len(), 1); + assert_eq!(names.value(0), "Alice"); +} + +/// Multi-variable projection: columns from source and destination must be +/// row-aligned. Previously this could fail with "all columns must have +/// the same length" when variables had different cardinalities. +#[tokio::test] +async fn traversal_multi_variable_projection_aligned() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let queries = r#" +query employee_companies() { + match { + $p: Person + $p worksAt $c + $c: Company + } + return { $p.name, $c.name } +} +"#; + let result = query_main(&mut db, queries, "employee_companies", &ParamMap::new()) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + // Alice→Acme, Bob→Globex + assert_eq!(batch.num_rows(), 2); + let person_names = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let company_names = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + let mut pairs: Vec<(&str, &str)> = (0..batch.num_rows()) + .map(|i| (person_names.value(i), company_names.value(i))) + .collect(); + pairs.sort(); + assert_eq!(pairs, vec![("Alice", "Acme"), ("Bob", "Globex")]); +} + +/// Multi-hop projection: all three variables must be row-aligned. +#[tokio::test] +async fn multi_hop_projection_aligned() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Alice knows Bob, Bob knows Diana. + // Alice→Bob→Diana is the only 2-hop path. + let queries = r#" +query fof_chain($name: String) { + match { + $p: Person { name: $name } + $p knows $mid + $mid knows $fof + } + return { $p.name, $mid.name, $fof.name } +} +"#; + let result = query_main( + &mut db, + queries, + "fof_chain", + ¶ms(&[("$name", "Alice")]), + ) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + let col0 = batch.column(0).as_any().downcast_ref::().unwrap(); + let col1 = batch.column(1).as_any().downcast_ref::().unwrap(); + let col2 = batch.column(2).as_any().downcast_ref::().unwrap(); + assert_eq!(col0.value(0), "Alice"); + assert_eq!(col1.value(0), "Bob"); + assert_eq!(col2.value(0), "Diana"); +} + +/// Multi-hop with destination binding filters at each hop. +#[tokio::test] +async fn multi_hop_with_intermediate_binding_filters() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + // Alice knows Bob and Charlie. + // Bob knows Diana. Charlie knows nobody. + // Filter $mid to only "Bob" → only Alice→Bob→Diana survives. + let queries = r#" +query fof_via($name: String, $mid_name: String) { + match { + $p: Person { name: $name } + $p knows $mid + $mid: Person { name: $mid_name } + $mid knows $fof + } + return { $fof.name } +} +"#; + let result = query_main( + &mut db, + queries, + "fof_via", + ¶ms(&[("$name", "Alice"), ("$mid_name", "Bob")]), + ) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + let names = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.len(), 1); + assert_eq!(names.value(0), "Diana"); +} + +/// Destination binding with filter + multi-variable return: the classic +/// "join across a traversal" scenario that triggers the bug. +#[tokio::test] +async fn traversal_destination_filter_with_multi_return() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let queries = r#" +query at_acme_named() { + match { + $p: Person + $p worksAt $c + $c: Company { name: "Acme" } + } + return { $p.name, $c.name } +} +"#; + let result = query_main(&mut db, queries, "at_acme_named", &ParamMap::new()) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + let person = batch.column(0).as_any().downcast_ref::().unwrap(); + let company = batch.column(1).as_any().downcast_ref::().unwrap(); + assert_eq!(person.value(0), "Alice"); + assert_eq!(company.value(0), "Acme"); +} + +/// Parameterized destination filter exercises param resolution through the +/// Lance SQL pushdown path (params are resolved to literals in ir_expr_to_sql). +#[tokio::test] +async fn traversal_destination_filter_pushdown_with_param() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let queries = r#" +query at_company($company: String) { + match { + $p: Person + $p worksAt $c + $c: Company { name: $company } + } + return { $p.name, $c.name } +} +"#; + let result = query_main( + &mut db, + queries, + "at_company", + ¶ms(&[("$company", "Globex")]), + ) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 1); + let person = batch.column(0).as_any().downcast_ref::().unwrap(); + let company = batch.column(1).as_any().downcast_ref::().unwrap(); + assert_eq!(person.value(0), "Bob"); + assert_eq!(company.value(0), "Globex"); +} + +/// Fan-out: one source expanded to two different destination types. +/// Each (friend, company) pair should be a cross-product per source row. +#[tokio::test] +async fn fan_out_two_destinations() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let queries = r#" +query fan_out($name: String) { + match { + $p: Person { name: $name } + $p knows $f + $p worksAt $c + } + return { $f.name, $c.name } +} +"#; + // Alice knows Bob and Charlie, works at Acme. + // Each friend paired with her company → 2 rows. + let result = query_main( + &mut db, + queries, + "fan_out", + ¶ms(&[("$name", "Alice")]), + ) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + assert_eq!(batch.num_rows(), 2); + let friends = batch.column(0).as_any().downcast_ref::().unwrap(); + let companies = batch.column(1).as_any().downcast_ref::().unwrap(); + + let mut pairs: Vec<(&str, &str)> = (0..batch.num_rows()) + .map(|i| (friends.value(i), companies.value(i))) + .collect(); + pairs.sort(); + assert_eq!(pairs, vec![("Bob", "Acme"), ("Charlie", "Acme")]); +} + +/// Deferred destination filter that matches nothing → empty result. +#[tokio::test] +async fn traversal_destination_filter_no_match() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let queries = r#" +query at_phantom() { + match { + $p: Person + $p worksAt $c + $c: Company { name: "NonExistent" } + } + return { $p.name } +} +"#; + let result = query_main(&mut db, queries, "at_phantom", &ParamMap::new()) + .await + .unwrap(); + + assert_eq!(result.num_rows(), 0); +} + +/// Negation with inner destination binding filter. +/// "People who do NOT work at Acme" — uses binding syntax inside negation. +#[tokio::test] +async fn negation_with_inner_destination_binding() { + let dir = tempfile::tempdir().unwrap(); + let mut db = init_and_load(&dir).await; + + let queries = r#" +query not_at_acme_binding() { + match { + $p: Person + not { + $p worksAt $c + $c: Company { name: "Acme" } + } + } + return { $p.name } +} +"#; + // Alice→Acme. Everyone else should be returned. + let result = query_main(&mut db, queries, "not_at_acme_binding", &ParamMap::new()) + .await + .unwrap(); + + let batch = result.concat_batches().unwrap(); + let names = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let mut names_vec: Vec<&str> = (0..names.len()).map(|i| names.value(i)).collect(); + names_vec.sort(); + assert_eq!(names_vec, vec!["Bob", "Charlie", "Diana"]); +}