diff --git a/crates/omnigraph/src/embedding.rs b/crates/omnigraph/src/embedding.rs index 9fbf8c0..c141a2b 100644 --- a/crates/omnigraph/src/embedding.rs +++ b/crates/omnigraph/src/embedding.rs @@ -134,7 +134,9 @@ impl EmbeddingConfig { fn mock() -> Self { Self { provider: Provider::Mock, - model: String::new(), + // Honor OMNIGRAPH_EMBED_MODEL so the same-space check is exercisable + // under mock; the mock vectors themselves don't depend on the model. + model: env_string("OMNIGRAPH_EMBED_MODEL").unwrap_or_default(), base_url: String::new(), api_key: String::new(), } diff --git a/crates/omnigraph/src/exec/query.rs b/crates/omnigraph/src/exec/query.rs index 8411dd3..8efadad 100644 --- a/crates/omnigraph/src/exec/query.rs +++ b/crates/omnigraph/src/exec/query.rs @@ -256,13 +256,29 @@ async fn resolve_nearest_query_vec( match lit { Literal::List(_) => literal_to_f32_vec(&lit), Literal::String(text) => { - let expected_dim = nearest_property_dimension(ir, catalog, variable, property)?; + let (expected_dim, recorded_model) = + nearest_property_dim_and_model(ir, catalog, variable, property)?; // Lazily resolve the per-handle client once, then reuse it across // queries (keeps the provider connection pool warm); a graph that // never embeds never builds a client and needs no provider key. let client = embedding .get_or_try_init(|| async { EmbeddingClient::from_env() }) .await?; + // Same-space guarantee: if the property recorded the model that + // produced its stored vectors (`@embed("…", model="…")`), the query + // embedder must resolve to that same model — otherwise the comparison + // is across vector spaces. Reject loudly instead of ranking garbage. + if let Some(recorded) = &recorded_model { + let resolved = &client.config().model; + if resolved != recorded { + return Err(OmniError::manifest(format!( + "nearest() on '{property}': its stored vectors were embedded with model \ + '{recorded}', but the query embedder resolves to '{resolved}'. 
Set \ + OMNIGRAPH_EMBED_MODEL='{recorded}' (and the matching provider) or re-embed \ + the stored vectors." + ))); + } + } client.embed_query_text(&text, expected_dim).await } _ => Err(OmniError::manifest( @@ -305,12 +321,14 @@ fn literal_to_f32_vec(lit: &Literal) -> Result<Vec<f32>> { } } -fn nearest_property_dimension( +/// Resolve the nearest() target property's vector dimension and the embedding +/// model recorded for it via `@embed("…", model="…")` (`None` if unrecorded). +fn nearest_property_dim_and_model( ir: &QueryIR, catalog: &Catalog, variable: &str, property: &str, -) -> Result<usize> { +) -> Result<(usize, Option<String>)> { let type_name = resolve_binding_type_name(&ir.pipeline, variable).ok_or_else(|| { OmniError::manifest_internal(format!( "nearest() variable '${}' is not bound to a node type in the lowered pipeline", @@ -329,13 +347,20 @@ fn nearest_property_dimension( type_name, property )) })?; - match prop.scalar { - ScalarType::Vector(dim) if !prop.list => Ok(dim as usize), - _ => Err(OmniError::manifest_internal(format!( - "nearest() property '{}.{}' is not a scalar vector", - type_name, property - ))), - } + let dim = match prop.scalar { + ScalarType::Vector(dim) if !prop.list => dim as usize, + _ => { + return Err(OmniError::manifest_internal(format!( + "nearest() property '{}.{}' is not a scalar vector", + type_name, property + ))); + } + }; + let recorded_model = node_type + .embed_sources + .get(property) + .and_then(|embed| embed.model.clone()); + Ok((dim, recorded_model)) } fn resolve_binding_type_name<'a>(pipeline: &'a [IROp], variable: &str) -> Option<&'a str> { diff --git a/crates/omnigraph/tests/search.rs b/crates/omnigraph/tests/search.rs index 7537e5f..fb6e853 100644 --- a/crates/omnigraph/tests/search.rs +++ b/crates/omnigraph/tests/search.rs @@ -60,6 +60,15 @@ query hybrid_search_string($vq: String, $tq: String) { limit 3 } "#; +// Same shape as MOCK_SEARCH_SCHEMA but the vector records the model that +// produced its stored vectors, opting into the
query-time same-space check. +const MODEL_RECORDED_SCHEMA: &str = r#" +node Doc { + slug: String @key + title: String @index + embedding: Vector(4) @embed("title", model="test-model-a") @index +} +"#; const SEARCH_MUTATIONS: &str = r#" query insert_doc($slug: String, $title: String, $body: String, $embedding: Vector(4)) { insert Doc { @@ -89,6 +98,15 @@ async fn init_mock_embedding_search_db(dir: &tempfile::TempDir) -> Omnigraph { db } +async fn init_model_recorded_search_db(dir: &tempfile::TempDir) -> Omnigraph { + let uri = dir.path().to_str().unwrap(); + let mut db = Omnigraph::init(uri, MODEL_RECORDED_SCHEMA).await.unwrap(); + load_jsonl(&mut db, &mock_embedding_seed_data(), LoadMode::Overwrite) + .await + .unwrap(); + db +} + fn mock_embedding_seed_data() -> String { [ ("alpha-doc", "alpha guide", mock_embedding("alpha", 4)), @@ -540,6 +558,63 @@ async fn string_nearest_requires_provider_credentials_when_mock_is_disabled() { ); } +#[tokio::test] +#[serial] +async fn nearest_string_passes_when_query_model_matches_recorded_model() { + let _guard = EnvGuard::set(&[ + ("OMNIGRAPH_EMBEDDINGS_MOCK", Some("1")), + ("OMNIGRAPH_EMBED_MODEL", Some("test-model-a")), + ("OMNIGRAPH_EMBED_PROVIDER", None), + ("OPENROUTER_API_KEY", None), + ("OPENAI_API_KEY", None), + ("GEMINI_API_KEY", None), + ]); + + let dir = tempfile::tempdir().unwrap(); + let mut db = init_model_recorded_search_db(&dir).await; + + let result = query_main( + &mut db, + MOCK_SEARCH_QUERIES, + "vector_search_string", + &params(&[("$q", "alpha")]), + ) + .await + .unwrap(); + + assert_eq!(result_slugs(&result)[0], "alpha-doc"); +} + +#[tokio::test] +#[serial] +async fn nearest_string_errors_when_query_model_differs_from_recorded_model() { + let _guard = EnvGuard::set(&[ + ("OMNIGRAPH_EMBEDDINGS_MOCK", Some("1")), + ("OMNIGRAPH_EMBED_MODEL", Some("test-model-b")), + ("OMNIGRAPH_EMBED_PROVIDER", None), + ("OPENROUTER_API_KEY", None), + ("OPENAI_API_KEY", None), + ("GEMINI_API_KEY", None), + ]); + + let dir =
tempfile::tempdir().unwrap(); + let mut db = init_model_recorded_search_db(&dir).await; + + let err = query_main( + &mut db, + MOCK_SEARCH_QUERIES, + "vector_search_string", + &params(&[("$q", "alpha")]), + ) + .await + .unwrap_err(); + + // The error must name both the recorded model and the resolved one. + let msg = err.to_string(); + assert!(msg.contains("test-model-a"), "got: {msg}"); + assert!(msg.contains("test-model-b"), "got: {msg}"); +} + // ─── BM25 search ──────────────────────────────────────────────────────────── #[tokio::test]