feat(engine): same-space validation for @embed model (RFC-012 Phase 4)

resolve_nearest_query_vec rejects a nearest($v, string) query with a typed error when the property recorded a model (via @embed) that differs from the resolved query embedder's model, closing the silent cross-space ranking. An @embed without a recorded model keeps working with no check. EmbeddingConfig::mock() honors OMNIGRAPH_EMBED_MODEL so the check is exercisable under mock. Two new search tests.
This commit is contained in:
Ragnor Comerford 2026-06-15 21:09:35 +02:00
parent 1a06150c33
commit 0a34f9011b
No known key found for this signature in database
3 changed files with 113 additions and 11 deletions

View file

@ -134,7 +134,9 @@ impl EmbeddingConfig {
fn mock() -> Self {
Self {
provider: Provider::Mock,
model: String::new(),
// Honor OMNIGRAPH_EMBED_MODEL so the same-space check is exercisable
// under mock; the mock vectors themselves don't depend on the model.
model: env_string("OMNIGRAPH_EMBED_MODEL").unwrap_or_default(),
base_url: String::new(),
api_key: String::new(),
}

View file

@ -256,13 +256,29 @@ async fn resolve_nearest_query_vec(
match lit {
Literal::List(_) => literal_to_f32_vec(&lit),
Literal::String(text) => {
let expected_dim = nearest_property_dimension(ir, catalog, variable, property)?;
let (expected_dim, recorded_model) =
nearest_property_dim_and_model(ir, catalog, variable, property)?;
// Lazily resolve the per-handle client once, then reuse it across
// queries (keeps the provider connection pool warm); a graph that
// never embeds never builds a client and needs no provider key.
let client = embedding
.get_or_try_init(|| async { EmbeddingClient::from_env() })
.await?;
// Same-space guarantee: if the property recorded the model that
// produced its stored vectors (`@embed("…", model="…")`), the query
// embedder must resolve to that same model — otherwise the comparison
// is across vector spaces. Reject loudly instead of ranking garbage.
if let Some(recorded) = &recorded_model {
let resolved = &client.config().model;
if resolved != recorded {
return Err(OmniError::manifest(format!(
"nearest() on '{property}': its stored vectors were embedded with model \
'{recorded}', but the query embedder resolves to '{resolved}'. Set \
OMNIGRAPH_EMBED_MODEL='{recorded}' (and the matching provider) or re-embed \
the stored vectors."
)));
}
}
client.embed_query_text(&text, expected_dim).await
}
_ => Err(OmniError::manifest(
@ -305,12 +321,14 @@ fn literal_to_f32_vec(lit: &Literal) -> Result<Vec<f32>> {
}
}
fn nearest_property_dimension(
/// Resolve the nearest() target property's vector dimension and the embedding
/// model recorded for it via `@embed("…", model="…")` (`None` if unrecorded).
fn nearest_property_dim_and_model(
ir: &QueryIR,
catalog: &Catalog,
variable: &str,
property: &str,
) -> Result<usize> {
) -> Result<(usize, Option<String>)> {
let type_name = resolve_binding_type_name(&ir.pipeline, variable).ok_or_else(|| {
OmniError::manifest_internal(format!(
"nearest() variable '${}' is not bound to a node type in the lowered pipeline",
@ -329,13 +347,20 @@ fn nearest_property_dimension(
type_name, property
))
})?;
match prop.scalar {
ScalarType::Vector(dim) if !prop.list => Ok(dim as usize),
_ => Err(OmniError::manifest_internal(format!(
"nearest() property '{}.{}' is not a scalar vector",
type_name, property
))),
}
let dim = match prop.scalar {
ScalarType::Vector(dim) if !prop.list => dim as usize,
_ => {
return Err(OmniError::manifest_internal(format!(
"nearest() property '{}.{}' is not a scalar vector",
type_name, property
)));
}
};
let recorded_model = node_type
.embed_sources
.get(property)
.and_then(|embed| embed.model.clone());
Ok((dim, recorded_model))
}
fn resolve_binding_type_name<'a>(pipeline: &'a [IROp], variable: &str) -> Option<&'a str> {

View file

@ -60,6 +60,15 @@ query hybrid_search_string($vq: String, $tq: String) {
limit 3
}
"#;
// Same shape as MOCK_SEARCH_SCHEMA but the vector records the model that
// produced its stored vectors, opting into the query-time same-space check.
const MODEL_RECORDED_SCHEMA: &str = r#"
node Doc {
slug: String @key
title: String @index
embedding: Vector(4) @embed("title", model="test-model-a") @index
}
"#;
const SEARCH_MUTATIONS: &str = r#"
query insert_doc($slug: String, $title: String, $body: String, $embedding: Vector(4)) {
insert Doc {
@ -89,6 +98,15 @@ async fn init_mock_embedding_search_db(dir: &tempfile::TempDir) -> Omnigraph {
db
}
async fn init_model_recorded_search_db(dir: &tempfile::TempDir) -> Omnigraph {
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, MODEL_RECORDED_SCHEMA).await.unwrap();
load_jsonl(&mut db, &mock_embedding_seed_data(), LoadMode::Overwrite)
.await
.unwrap();
db
}
fn mock_embedding_seed_data() -> String {
[
("alpha-doc", "alpha guide", mock_embedding("alpha", 4)),
@ -540,6 +558,63 @@ async fn string_nearest_requires_provider_credentials_when_mock_is_disabled() {
);
}
#[tokio::test]
#[serial]
async fn nearest_string_passes_when_query_model_matches_recorded_model() {
let _guard = EnvGuard::set(&[
("OMNIGRAPH_EMBEDDINGS_MOCK", Some("1")),
("OMNIGRAPH_EMBED_MODEL", Some("test-model-a")),
("OMNIGRAPH_EMBED_PROVIDER", None),
("OPENROUTER_API_KEY", None),
("OPENAI_API_KEY", None),
("GEMINI_API_KEY", None),
]);
let dir = tempfile::tempdir().unwrap();
let mut db = init_model_recorded_search_db(&dir).await;
let result = query_main(
&mut db,
MOCK_SEARCH_QUERIES,
"vector_search_string",
&params(&[("$q", "alpha")]),
)
.await
.unwrap();
assert_eq!(result_slugs(&result)[0], "alpha-doc");
}
#[tokio::test]
#[serial]
async fn nearest_string_errors_when_query_model_differs_from_recorded_model() {
let _guard = EnvGuard::set(&[
("OMNIGRAPH_EMBEDDINGS_MOCK", Some("1")),
("OMNIGRAPH_EMBED_MODEL", Some("test-model-b")),
("OMNIGRAPH_EMBED_PROVIDER", None),
("OPENROUTER_API_KEY", None),
("OPENAI_API_KEY", None),
("GEMINI_API_KEY", None),
]);
let dir = tempfile::tempdir().unwrap();
let mut db = init_model_recorded_search_db(&dir).await;
let err = query_main(
&mut db,
MOCK_SEARCH_QUERIES,
"vector_search_string",
&params(&[("$q", "alpha")]),
)
.await
.unwrap_err();
// The error must name both the recorded model and the resolved one.
let msg = err.to_string();
assert!(msg.contains("test-model-a"), "got: {msg}");
assert!(msg.contains("test-model-b"), "got: {msg}");
}
// ─── BM25 search ────────────────────────────────────────────────────────────
#[tokio::test]