diff --git a/crates/omnigraph/tests/proptest_equivalence.rs b/crates/omnigraph/tests/proptest_equivalence.rs index 0566b71..3423a2f 100644 --- a/crates/omnigraph/tests/proptest_equivalence.rs +++ b/crates/omnigraph/tests/proptest_equivalence.rs @@ -138,15 +138,28 @@ fn config() -> Config { } } -fn set_mode(mode: &str) { - // SAFE: every test in this binary is #[serial], so no thread reads the env - // during this write. - unsafe { std::env::set_var("OMNIGRAPH_TRAVERSAL_MODE", mode) }; -} fn clear_mode() { unsafe { std::env::remove_var("OMNIGRAPH_TRAVERSAL_MODE") }; } +/// RAII guard that sets `OMNIGRAPH_TRAVERSAL_MODE` and clears it on drop — so a +/// panic mid-case (e.g. a query `unwrap`) cannot leak the forced mode into +/// proptest's subsequent shrink/cases and mask the divergence under test. SAFE: +/// every test in this binary is `#[serial]`, so no thread reads the env during +/// the write. +struct ModeGuard; +impl ModeGuard { + fn set(mode: &str) -> Self { + unsafe { std::env::set_var("OMNIGRAPH_TRAVERSAL_MODE", mode) }; + ModeGuard + } +} +impl Drop for ModeGuard { + fn drop(&mut self) { + unsafe { std::env::remove_var("OMNIGRAPH_TRAVERSAL_MODE") }; + } +} + async fn load_graph(graph: &GenGraph) -> (tempfile::TempDir, Omnigraph) { let dir = tempfile::tempdir().unwrap(); let uri = dir.path().to_str().unwrap(); @@ -201,11 +214,17 @@ fn prop_expand_indexed_eq_csr() { for start in graph.persons.clone() { let p = one_param(&start); for q in ["friends", "employers"] { - set_mode("csr"); - let csr = col0_sorted(&mut db, q, &p).await; - set_mode("indexed"); - let indexed = col0_sorted(&mut db, q, &p).await; - clear_mode(); + // Each guard clears the mode on drop (end of the block, + // or on panic), so a forced mode never leaks across runs. + let csr = { + let _g = ModeGuard::set("csr"); + col0_sorted(&mut db, q, &p).await + }; + let indexed = { + let _g = ModeGuard::set("indexed"); + col0_sorted(&mut db, q, &p).await + }; + // No guard → env unset → auto (cost-based) path. let auto = col0_sorted(&mut db, q, &p).await; if csr != indexed || csr != auto { return Some((start, q, csr, indexed, auto));