From b999ae37531e9547fefe67bbada428f97b8c90f8 Mon Sep 17 00:00:00 2001 From: Ragnor Comerford Date: Mon, 15 Jun 2026 17:15:11 +0200 Subject: [PATCH] feat(engine)!: provider-independent embedding client (RFC-012 Phase 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the Gemini-only EmbeddingClient with one resolved EmbeddingConfig { provider, model, base_url, api_key } behind a sealed Provider enum (OpenAiCompatible | Gemini | Mock). OpenAiCompatible (POST {base}/embeddings, bearer, {model, input, dimensions}) covers OpenRouter — the new default gateway — OpenAI direct, and self-hosted endpoints; Gemini keeps its RETRIEVAL_QUERY/RETRIEVAL_DOCUMENT task types; Mock is offline/deterministic. EmbedRole replaces the task-type string. from_env() resolves provider via OMNIGRAPH_EMBED_PROVIDER (default openai-compatible), base/model via OMNIGRAPH_EMBED_BASE_URL/_MODEL, and the api key from OPENROUTER_API_KEY/OPENAI_API_KEY or GEMINI_API_KEY. BREAKING (pre-release, no back-compat): the default provider is now OpenRouter, OMNIGRAPH_GEMINI_BASE_URL is dropped, and Gemini-only users set OMNIGRAPH_EMBED_PROVIDER=gemini. Folds in RFC-012 Phase 1 NFR floor: a total-operation OMNIGRAPH_EMBED_QUERY_DEADLINE_MS deadline (default 60s; 0=unbounded) bounds the ~121s worst case, and tracing spans (target omnigraph::embedding) record provider/model/dim/attempt/elapsed/outcome. The offline 'omnigraph embed' CLI follows the resolved provider (its hardcoded gemini-only bail removed). 17 engine embedding unit tests, 4 CLI embed tests, and the search integration suite (22) pass. Cross-query client reuse and the docs refresh land in follow-up commits. --- crates/omnigraph-cli/src/embed.rs | 15 +- crates/omnigraph-cli/tests/system_local.rs | 5 + crates/omnigraph/src/embedding.rs | 591 +++++++++++++++++---- crates/omnigraph/tests/search.rs | 13 +- 4 files changed, 518 insertions(+), 106 deletions(-) diff --git a/crates/omnigraph-cli/src/embed.rs b/crates/omnigraph-cli/src/embed.rs index 2e1c6d9..b1773f6 100644 --- a/crates/omnigraph-cli/src/embed.rs +++ b/crates/omnigraph-cli/src/embed.rs @@ -9,8 +9,6 @@ use omnigraph::embedding::EmbeddingClient; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value, json}; -const DEFAULT_EMBED_MODEL: &str = "gemini-embedding-2-preview"; - #[derive(Debug, Args, Clone)] pub(crate) struct EmbedArgs { /// Seed manifest path @@ -85,7 +83,7 @@ impl EmbedMode { #[derive(Debug, Clone, Deserialize)] struct EmbedSpec { - #[serde(default = "default_embed_model")] + #[serde(default)] model: String, dimension: usize, types: BTreeMap, @@ -180,13 +178,6 @@ pub(crate) fn resolve_embed_job(args: &EmbedArgs) -> Result { (input, output, spec) }; - if spec.model != DEFAULT_EMBED_MODEL { - bail!( - "only {} is supported for explicit seed embeddings right now", - DEFAULT_EMBED_MODEL - ); - } - Ok(EmbedJob { input, output, @@ -315,10 +306,6 @@ fn temp_output_path(output: &Path) -> PathBuf { PathBuf::from(temp) } -fn default_embed_model() -> String { - DEFAULT_EMBED_MODEL.to_string() -} - fn load_embed_spec(path: &Path) -> Result { Ok(serde_json::from_str(&fs::read_to_string(path)?)?) } diff --git a/crates/omnigraph-cli/tests/system_local.rs b/crates/omnigraph-cli/tests/system_local.rs index b6a87f1..ddedaf7 100644 --- a/crates/omnigraph-cli/tests/system_local.rs +++ b/crates/omnigraph-cli/tests/system_local.rs @@ -1111,6 +1111,11 @@ query vector_search($q: String) { let result = parse_stdout_json(&output_success( cli() + // Stored vectors above were produced with gemini-embedding-2-preview; + // pin the query-time embedder to the same provider/model so the + // auto-embedded `$q` lands in the same vector space. + .env("OMNIGRAPH_EMBED_PROVIDER", "gemini") + .env("OMNIGRAPH_EMBED_MODEL", "gemini-embedding-2-preview") .arg("read") .arg(&graph) .arg("--query") diff --git a/crates/omnigraph/src/embedding.rs b/crates/omnigraph/src/embedding.rs index cfd4071..70ac9df 100644 --- a/crates/omnigraph/src/embedding.rs +++ b/crates/omnigraph/src/embedding.rs @@ -8,29 +8,149 @@ use tokio::time::sleep; use crate::error::{OmniError, Result}; -const GEMINI_EMBED_MODEL: &str = "gemini-embedding-2-preview"; +const DEFAULT_OPENAI_BASE_URL: &str = "https://openrouter.ai/api/v1"; +const DEFAULT_OPENAI_MODEL: &str = "openai/text-embedding-3-large"; const DEFAULT_GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta"; +const DEFAULT_GEMINI_MODEL: &str = "gemini-embedding-2"; const DEFAULT_TIMEOUT_MS: u64 = 30_000; const DEFAULT_RETRY_ATTEMPTS: usize = 4; const DEFAULT_RETRY_BACKOFF_MS: u64 = 200; -const QUERY_TASK_TYPE: &str = "RETRIEVAL_QUERY"; -const DOCUMENT_TASK_TYPE: &str = "RETRIEVAL_DOCUMENT"; +const DEFAULT_QUERY_DEADLINE_MS: u64 = 60_000; +const GEMINI_QUERY_TASK_TYPE: &str = "RETRIEVAL_QUERY"; +const GEMINI_DOCUMENT_TASK_TYPE: &str = "RETRIEVAL_DOCUMENT"; -#[derive(Clone, Debug)] -enum EmbeddingTransport { +/// Which embedding API a client speaks. Each variant owns its request shape, +/// auth, and response parsing; everything else (retry, deadline, normalization, +/// tracing) is provider-independent. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Provider { + /// OpenAI-compatible (`POST {base}/embeddings`, bearer auth, + /// `{model, input, dimensions}`). Covers OpenRouter (the default gateway), + /// OpenAI direct, and self-hosted endpoints (vLLM/Ollama/LM Studio). + OpenAiCompatible, + /// Google Gemini `generativelanguage` (`POST {base}/models/{model}:embedContent`, + /// `x-goog-api-key`), with `RETRIEVAL_QUERY` / `RETRIEVAL_DOCUMENT` task types. + Gemini, + /// Deterministic, offline. No network, no key. Mock, - Gemini { - api_key: String, - base_url: String, - http: Client, - }, +} + +impl Provider { + fn default_base_url(self) -> &'static str { + match self { + Provider::OpenAiCompatible => DEFAULT_OPENAI_BASE_URL, + Provider::Gemini => DEFAULT_GEMINI_BASE_URL, + Provider::Mock => "", + } + } + + fn default_model(self) -> &'static str { + match self { + Provider::OpenAiCompatible => DEFAULT_OPENAI_MODEL, + Provider::Gemini => DEFAULT_GEMINI_MODEL, + Provider::Mock => "", + } + } +} + +/// Whether the text being embedded is a search query or a stored document. +/// Only Gemini distinguishes these (`RETRIEVAL_QUERY` vs `RETRIEVAL_DOCUMENT`); +/// OpenAI-compatible providers and Mock produce the identical request for both, +/// which is also the same-space property a query relies on. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum EmbedRole { + Query, + Document, +} + +/// The single source of truth for how embedding text becomes a vector: +/// provider + model + endpoint + key. Resolved once (from env today; from the +/// cluster `providers.embedding` profile in a later RFC-012 phase) and shared by +/// the query path and the offline CLI so stored and query vectors stay +/// same-space by construction. +#[derive(Clone, Debug)] +pub struct EmbeddingConfig { + pub provider: Provider, + pub model: String, + pub base_url: String, + pub api_key: String, +} + +impl EmbeddingConfig { + /// Resolve from the environment. Precedence: + /// 1. `OMNIGRAPH_EMBEDDINGS_MOCK` → Mock. + /// 2. `OMNIGRAPH_EMBED_PROVIDER` (`openai-compatible`|`openai`|`gemini`|`mock`); + /// unset defaults to `openai-compatible` (OpenRouter). + /// 3. `OMNIGRAPH_EMBED_BASE_URL` else the provider default. + /// 4. `OMNIGRAPH_EMBED_MODEL` else the provider default. + /// 5. provider api-key env (`OPENROUTER_API_KEY`/`OPENAI_API_KEY`, or `GEMINI_API_KEY`). + pub fn from_env() -> Result { + if env_flag("OMNIGRAPH_EMBEDDINGS_MOCK") { + return Ok(Self::mock()); + } + + let provider = match env_string("OMNIGRAPH_EMBED_PROVIDER").as_deref() { + None | Some("openai-compatible") | Some("openai") => Provider::OpenAiCompatible, + Some("gemini") => Provider::Gemini, + Some("mock") => return Ok(Self::mock()), + Some(other) => { + return Err(OmniError::manifest_internal(format!( + "unknown OMNIGRAPH_EMBED_PROVIDER '{}' (expected openai-compatible|gemini|mock)", + other + ))); + } + }; + + let base_url = env_string("OMNIGRAPH_EMBED_BASE_URL") + .unwrap_or_else(|| provider.default_base_url().to_string()) + .trim_end_matches('/') + .to_string(); + let model = + env_string("OMNIGRAPH_EMBED_MODEL").unwrap_or_else(|| provider.default_model().to_string()); + + let api_key = match provider { + Provider::OpenAiCompatible => env_string("OPENROUTER_API_KEY") + .or_else(|| env_string("OPENAI_API_KEY")) + .ok_or_else(|| { + OmniError::manifest_internal( + "OPENROUTER_API_KEY or OPENAI_API_KEY is required for the openai-compatible embedding provider", + ) + })?, + Provider::Gemini => env_string("GEMINI_API_KEY").ok_or_else(|| { + OmniError::manifest_internal( + "GEMINI_API_KEY is required for the gemini embedding provider", + ) + })?, + Provider::Mock => unreachable!("mock returns early"), + }; + + Ok(Self { + provider, + model, + base_url, + api_key, + }) + } + + fn mock() -> Self { + Self { + provider: Provider::Mock, + model: String::new(), + base_url: String::new(), + api_key: String::new(), + } + } } #[derive(Clone, Debug)] pub struct EmbeddingClient { + config: EmbeddingConfig, + http: Client, retry_attempts: usize, retry_backoff_ms: u64, - transport: EmbeddingTransport, + /// Total wall-clock budget for one embed call, across all retries + /// (`OMNIGRAPH_EMBED_QUERY_DEADLINE_MS`). `0` = unbounded. + query_deadline_ms: u64, } struct EmbedCallError { @@ -58,35 +178,39 @@ struct GoogleErrorBody { message: String, } +#[derive(Debug, Deserialize)] +struct OpenAiEmbeddingResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAiEmbeddingDatum { + index: usize, + embedding: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAiErrorEnvelope { + error: OpenAiErrorBody, +} + +#[derive(Debug, Deserialize)] +struct OpenAiErrorBody { + message: String, +} + impl EmbeddingClient { pub fn from_env() -> Result { + Self::new(EmbeddingConfig::from_env()?) + } + + pub fn new(config: EmbeddingConfig) -> Result { let retry_attempts = parse_env_usize("OMNIGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS); let retry_backoff_ms = parse_env_u64("OMNIGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS); - - if env_flag("OMNIGRAPH_EMBEDDINGS_MOCK") { - return Ok(Self { - retry_attempts, - retry_backoff_ms, - transport: EmbeddingTransport::Mock, - }); - } - - let api_key = std::env::var("GEMINI_API_KEY") - .ok() - .map(|v| v.trim().to_string()) - .filter(|v| !v.is_empty()) - .ok_or_else(|| { - OmniError::manifest_internal( - "GEMINI_API_KEY is required when nearest() needs a string embedding", - ) - })?; - let base_url = std::env::var("OMNIGRAPH_GEMINI_BASE_URL") - .ok() - .map(|v| v.trim_end_matches('/').to_string()) - .filter(|v| !v.is_empty()) - .unwrap_or_else(|| DEFAULT_GEMINI_BASE_URL.to_string()); + let query_deadline_ms = + parse_env_u64_allow_zero("OMNIGRAPH_EMBED_QUERY_DEADLINE_MS", DEFAULT_QUERY_DEADLINE_MS); let timeout_ms = parse_env_u64("OMNIGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS); let http = Client::builder() .timeout(Duration::from_millis(timeout_ms)) @@ -96,39 +220,36 @@ impl EmbeddingClient { })?; Ok(Self { + config, + http, retry_attempts, retry_backoff_ms, - transport: EmbeddingTransport::Gemini { - api_key, - base_url, - http, - }, + query_deadline_ms, }) } + pub fn config(&self) -> &EmbeddingConfig { + &self.config + } + #[cfg(test)] fn mock_for_tests() -> Self { - Self { - retry_attempts: DEFAULT_RETRY_ATTEMPTS, - retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS, - transport: EmbeddingTransport::Mock, - } + Self::new(EmbeddingConfig::mock()).expect("mock client builds") } pub async fn embed_query_text(&self, input: &str, expected_dim: usize) -> Result> { - self.embed_text(input, expected_dim, QUERY_TASK_TYPE).await + self.embed_text(input, expected_dim, EmbedRole::Query).await } pub async fn embed_document_text(&self, input: &str, expected_dim: usize) -> Result> { - self.embed_text(input, expected_dim, DOCUMENT_TASK_TYPE) - .await + self.embed_text(input, expected_dim, EmbedRole::Document).await } async fn embed_text( &self, input: &str, expected_dim: usize, - task_type: &'static str, + role: EmbedRole, ) -> Result> { if expected_dim == 0 { return Err(OmniError::manifest_internal( @@ -136,10 +257,70 @@ impl EmbeddingClient { )); } - match &self.transport { - EmbeddingTransport::Mock => Ok(mock_embedding(input, expected_dim)), - EmbeddingTransport::Gemini { .. } => { - self.with_retry(|| self.embed_text_gemini_once(input, expected_dim, task_type)) + let started = std::time::Instant::now(); + let result = self + .run_with_deadline(self.embed_text_inner(input, expected_dim, role)) + .await; + let elapsed_ms = started.elapsed().as_millis() as u64; + + match &result { + Ok(_) => tracing::info!( + target: "omnigraph::embedding", + provider = ?self.config.provider, + model = %self.config.model, + dim = expected_dim, + elapsed_ms, + outcome = "ok", + "embedding succeeded" + ), + Err(err) => tracing::warn!( + target: "omnigraph::embedding", + provider = ?self.config.provider, + model = %self.config.model, + dim = expected_dim, + elapsed_ms, + outcome = "error", + error = %err, + "embedding failed" + ), + } + result + } + + /// Bound the whole embed operation (all retries + backoff) by + /// `query_deadline_ms`, so a degraded provider can never hang a read for the + /// full retry envelope. `0` = unbounded. Read-path only, so cancelling the + /// in-flight request future on elapse is safe. + async fn run_with_deadline(&self, fut: F) -> Result> + where + F: Future>>, + { + if self.query_deadline_ms == 0 { + return fut.await; + } + match tokio::time::timeout(Duration::from_millis(self.query_deadline_ms), fut).await { + Ok(res) => res, + Err(_elapsed) => Err(OmniError::manifest_internal(format!( + "embedding deadline exceeded after {} ms (provider={:?}, model={})", + self.query_deadline_ms, self.config.provider, self.config.model + ))), + } + } + + async fn embed_text_inner( + &self, + input: &str, + expected_dim: usize, + role: EmbedRole, + ) -> Result> { + match self.config.provider { + Provider::Mock => Ok(mock_embedding(input, expected_dim)), + Provider::Gemini => { + self.with_retry(|| self.embed_gemini_once(input, expected_dim, role)) + .await + } + Provider::OpenAiCompatible => { + self.with_retry(|| self.embed_openai_once(input, expected_dim)) .await } } @@ -160,6 +341,14 @@ impl EmbeddingClient { if !err.retryable || attempt >= max_attempt { return Err(OmniError::manifest_internal(err.message)); } + tracing::warn!( + target: "omnigraph::embedding", + provider = ?self.config.provider, + model = %self.config.model, + attempt, + error = %err.message, + "embedding attempt failed, retrying" + ); let shift = (attempt - 1).min(10) as u32; let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift); sleep(Duration::from_millis(delay)).await; @@ -168,25 +357,27 @@ impl EmbeddingClient { } } - async fn embed_text_gemini_once( + async fn embed_gemini_once( &self, input: &str, expected_dim: usize, - task_type: &'static str, + role: EmbedRole, ) -> std::result::Result, EmbedCallError> { - let (api_key, base_url, http) = match &self.transport { - EmbeddingTransport::Gemini { - api_key, - base_url, - http, - } => (api_key, base_url, http), - EmbeddingTransport::Mock => unreachable!("mock transport should not call Gemini"), + let task_type = match role { + EmbedRole::Query => GEMINI_QUERY_TASK_TYPE, + EmbedRole::Document => GEMINI_DOCUMENT_TASK_TYPE, }; - let response = http - .post(gemini_endpoint(base_url)) - .header("x-goog-api-key", api_key) - .json(&build_gemini_request(input, expected_dim, task_type)) + let response = self + .http + .post(gemini_endpoint(&self.config.base_url, &self.config.model)) + .header("x-goog-api-key", &self.config.api_key) + .json(&build_gemini_request( + &self.config.model, + input, + expected_dim, + task_type, + )) .send() .await; let response = match response { @@ -205,10 +396,7 @@ impl EmbeddingClient { Ok(body) => body, Err(err) => { return Err(EmbedCallError { - message: format!( - "embedding response read failed (status {}): {}", - status, err - ), + message: format!("embedding response read failed (status {}): {}", status, err), retryable: status.is_server_error() || status.as_u16() == 429, }); } @@ -217,10 +405,7 @@ impl EmbeddingClient { if !status.is_success() { let message = parse_google_error_message(&body).unwrap_or(body); return Err(EmbedCallError { - message: format!( - "embedding request failed with status {}: {}", - status, message - ), + message: format!("embedding request failed with status {}: {}", status, message), retryable: status.is_server_error() || status.as_u16() == 429, }); } @@ -238,19 +423,85 @@ impl EmbeddingClient { } }) } + + async fn embed_openai_once( + &self, + input: &str, + expected_dim: usize, + ) -> std::result::Result, EmbedCallError> { + let response = self + .http + .post(format!("{}/embeddings", self.config.base_url)) + .bearer_auth(&self.config.api_key) + .json(&build_openai_request(&self.config.model, input, expected_dim)) + .send() + .await; + let response = match response { + Ok(response) => response, + Err(err) => { + let retryable = err.is_timeout() || err.is_connect() || err.is_request(); + return Err(EmbedCallError { + message: format!("embedding request failed: {}", err), + retryable, + }); + } + }; + + let status = response.status(); + let body = match response.text().await { + Ok(body) => body, + Err(err) => { + return Err(EmbedCallError { + message: format!("embedding response read failed (status {}): {}", status, err), + retryable: status.is_server_error() || status.as_u16() == 429, + }); + } + }; + + if !status.is_success() { + let message = parse_openai_error_message(&body).unwrap_or(body); + return Err(EmbedCallError { + message: format!("embedding request failed with status {}: {}", status, message), + retryable: status.is_server_error() || status.as_u16() == 429, + }); + } + + let parsed: OpenAiEmbeddingResponse = + serde_json::from_str(&body).map_err(|err| EmbedCallError { + message: format!("embedding response decode failed: {}", err), + retryable: false, + })?; + + // The query path embeds exactly one string, so expect one datum at index 0. + let datum = parsed + .data + .into_iter() + .find(|d| d.index == 0) + .ok_or_else(|| EmbedCallError { + message: "embedding response missing data[0]".to_string(), + retryable: false, + })?; + + validate_and_normalize_embedding(datum.embedding, expected_dim).map_err(|message| { + EmbedCallError { + message, + retryable: false, + } + }) + } } -fn gemini_endpoint(base_url: &str) -> String { +fn gemini_endpoint(base_url: &str, model: &str) -> String { format!( "{}/models/{}:embedContent", base_url.trim_end_matches('/'), - GEMINI_EMBED_MODEL + model ) } -fn build_gemini_request(input: &str, expected_dim: usize, task_type: &'static str) -> Value { +fn build_gemini_request(model: &str, input: &str, expected_dim: usize, task_type: &str) -> Value { json!({ - "model": format!("models/{}", GEMINI_EMBED_MODEL), + "model": format!("models/{}", model), "content": { "parts": [ { @@ -263,6 +514,14 @@ fn build_gemini_request(input: &str, expected_dim: usize, task_type: &'static st }) } +fn build_openai_request(model: &str, input: &str, expected_dim: usize) -> Value { + json!({ + "model": model, + "input": [input], + "dimensions": expected_dim, + }) +} + fn validate_and_normalize_embedding( values: Vec, expected_dim: usize, @@ -298,6 +557,20 @@ fn parse_google_error_message(body: &str) -> Option { .filter(|msg| !msg.trim().is_empty()) } +fn parse_openai_error_message(body: &str) -> Option { + serde_json::from_str::(body) + .ok() + .map(|e| e.error.message) + .filter(|msg| !msg.trim().is_empty()) +} + +fn env_string(name: &str) -> Option { + std::env::var(name) + .ok() + .map(|v| v.trim().to_string()) + .filter(|v| !v.is_empty()) +} + fn parse_env_usize(name: &str, default: usize) -> usize { std::env::var(name) .ok() @@ -314,6 +587,15 @@ fn parse_env_u64(name: &str, default: u64) -> u64 { .unwrap_or(default) } +/// Like [`parse_env_u64`] but accepts `0` as a meaningful value (the deadline +/// uses `0` for "unbounded"). +fn parse_env_u64_allow_zero(name: &str, default: u64) -> u64 { + std::env::var(name) + .ok() + .and_then(|v| v.trim().parse::().ok()) + .unwrap_or(default) +} + fn env_flag(name: &str) -> bool { std::env::var(name) .ok() @@ -395,6 +677,25 @@ mod tests { } } + // Every test that calls `EmbeddingConfig::from_env` clears the full set of + // embedding env vars first so the host environment can't leak in. + const EMBED_ENV: &[&str] = &[ + "OMNIGRAPH_EMBEDDINGS_MOCK", + "OMNIGRAPH_EMBED_PROVIDER", + "OMNIGRAPH_EMBED_BASE_URL", + "OMNIGRAPH_EMBED_MODEL", + "OPENROUTER_API_KEY", + "OPENAI_API_KEY", + "GEMINI_API_KEY", + ]; + + fn cleared_env(extra: &[(&'static str, Option<&str>)]) -> EnvGuard { + let mut vars: Vec<(&'static str, Option<&str>)> = + EMBED_ENV.iter().map(|n| (*n, None)).collect(); + vars.extend_from_slice(extra); + EnvGuard::set(&vars) + } + #[tokio::test] async fn mock_embeddings_are_deterministic() { let client = EmbeddingClient::mock_for_tests(); @@ -407,18 +708,30 @@ mod tests { } #[test] - fn gemini_request_uses_preview_model_retrieval_query_and_dimension() { - let request = build_gemini_request("alpha", 4, QUERY_TASK_TYPE); - assert_eq!(request["model"], "models/gemini-embedding-2-preview"); - assert_eq!(request["taskType"], QUERY_TASK_TYPE); + fn gemini_request_uses_model_retrieval_query_and_dimension() { + let request = + build_gemini_request("gemini-embedding-2", "alpha", 4, GEMINI_QUERY_TASK_TYPE); + assert_eq!(request["model"], "models/gemini-embedding-2"); + assert_eq!(request["taskType"], GEMINI_QUERY_TASK_TYPE); assert_eq!(request["outputDimensionality"], 4); assert_eq!(request["content"]["parts"][0]["text"], "alpha"); } #[test] fn gemini_document_request_uses_retrieval_document_task_type() { - let request = build_gemini_request("alpha", 4, DOCUMENT_TASK_TYPE); - assert_eq!(request["taskType"], DOCUMENT_TASK_TYPE); + let request = + build_gemini_request("gemini-embedding-2", "alpha", 4, GEMINI_DOCUMENT_TASK_TYPE); + assert_eq!(request["taskType"], GEMINI_DOCUMENT_TASK_TYPE); + } + + #[test] + fn openai_request_uses_model_input_array_and_dimensions() { + let request = build_openai_request("openai/text-embedding-3-large", "alpha", 4); + assert_eq!(request["model"], "openai/text-embedding-3-large"); + assert_eq!(request["input"][0], "alpha"); + assert!(request["input"].is_array()); + assert_eq!(request["dimensions"], 4); + assert!(request.get("taskType").is_none()); } #[test] @@ -475,15 +788,113 @@ mod tests { assert!(err.to_string().contains("do not retry")); } + #[tokio::test] + async fn run_with_deadline_aborts_slow_future() { + let mut client = EmbeddingClient::mock_for_tests(); + client.query_deadline_ms = 20; + let slow = async { + tokio::time::sleep(Duration::from_secs(5)).await; + Ok(vec![0.0_f32]) + }; + let err = client.run_with_deadline(slow).await.unwrap_err(); + assert!(err.to_string().contains("deadline exceeded")); + } + + #[tokio::test] + async fn run_with_deadline_passes_through_fast_future() { + let client = EmbeddingClient::mock_for_tests(); + let ok = client + .run_with_deadline(async { Ok(vec![1.0_f32, 2.0]) }) + .await + .unwrap(); + assert_eq!(ok, vec![1.0, 2.0]); + } + + #[tokio::test] + async fn run_with_deadline_zero_is_unbounded() { + let mut client = EmbeddingClient::mock_for_tests(); + client.query_deadline_ms = 0; + let ok = client + .run_with_deadline(async { Ok(vec![3.0_f32]) }) + .await + .unwrap(); + assert_eq!(ok, vec![3.0]); + } + #[test] #[serial] - fn from_env_requires_gemini_api_key_when_not_mocking() { - let _guard = EnvGuard::set(&[ - ("OMNIGRAPH_EMBEDDINGS_MOCK", None), - ("GEMINI_API_KEY", None), - ]); + fn from_env_defaults_to_openai_compatible_openrouter() { + let _guard = cleared_env(&[("OPENROUTER_API_KEY", Some("sk-test"))]); + let config = EmbeddingConfig::from_env().unwrap(); + assert_eq!(config.provider, Provider::OpenAiCompatible); + assert_eq!(config.base_url, DEFAULT_OPENAI_BASE_URL); + assert_eq!(config.model, DEFAULT_OPENAI_MODEL); + assert_eq!(config.api_key, "sk-test"); + } - let err = EmbeddingClient::from_env().unwrap_err(); - assert!(err.to_string().contains("GEMINI_API_KEY")); + #[test] + #[serial] + fn from_env_openai_compatible_prefers_openrouter_key() { + let _guard = cleared_env(&[ + ("OPENROUTER_API_KEY", Some("router")), + ("OPENAI_API_KEY", Some("openai")), + ]); + let config = EmbeddingConfig::from_env().unwrap(); + assert_eq!(config.api_key, "router"); + } + + #[test] + #[serial] + fn from_env_explicit_gemini_provider() { + let _guard = cleared_env(&[ + ("OMNIGRAPH_EMBED_PROVIDER", Some("gemini")), + ("GEMINI_API_KEY", Some("g-key")), + ]); + let config = EmbeddingConfig::from_env().unwrap(); + assert_eq!(config.provider, Provider::Gemini); + assert_eq!(config.base_url, DEFAULT_GEMINI_BASE_URL); + assert_eq!(config.model, DEFAULT_GEMINI_MODEL); + assert_eq!(config.api_key, "g-key"); + } + + #[test] + #[serial] + fn from_env_base_url_and_model_overrides_apply() { + let _guard = cleared_env(&[ + ("OMNIGRAPH_EMBED_PROVIDER", Some("openai-compatible")), + ("OMNIGRAPH_EMBED_BASE_URL", Some("https://example.test/v1/")), + ("OMNIGRAPH_EMBED_MODEL", Some("custom/model")), + ("OPENAI_API_KEY", Some("k")), + ]); + let config = EmbeddingConfig::from_env().unwrap(); + assert_eq!(config.base_url, "https://example.test/v1"); // trailing slash trimmed + assert_eq!(config.model, "custom/model"); + } + + #[test] + #[serial] + fn from_env_unknown_provider_errors() { + let _guard = cleared_env(&[("OMNIGRAPH_EMBED_PROVIDER", Some("cohere"))]); + let err = EmbeddingConfig::from_env().unwrap_err(); + assert!(err.to_string().contains("unknown OMNIGRAPH_EMBED_PROVIDER")); + } + + #[test] + #[serial] + fn from_env_errors_when_no_key_present() { + let _guard = cleared_env(&[]); + let err = EmbeddingConfig::from_env().unwrap_err(); + assert!(err.to_string().contains("OPENROUTER_API_KEY or OPENAI_API_KEY")); + } + + #[test] + #[serial] + fn from_env_mock_flag_wins() { + let _guard = cleared_env(&[ + ("OMNIGRAPH_EMBEDDINGS_MOCK", Some("1")), + ("OMNIGRAPH_EMBED_PROVIDER", Some("gemini")), + ]); + let config = EmbeddingConfig::from_env().unwrap(); + assert_eq!(config.provider, Provider::Mock); } } diff --git a/crates/omnigraph/tests/search.rs b/crates/omnigraph/tests/search.rs index 480ec3c..7537e5f 100644 --- a/crates/omnigraph/tests/search.rs +++ b/crates/omnigraph/tests/search.rs @@ -510,9 +510,14 @@ async fn explicit_vector_nearest_does_not_require_gemini_credentials() { #[tokio::test] #[serial] -async fn string_nearest_requires_gemini_credentials_when_mock_is_disabled() { +async fn string_nearest_requires_provider_credentials_when_mock_is_disabled() { + // With mock off and no provider key, the default (openai-compatible) + // provider fails loudly rather than silently producing garbage vectors. let _guard = EnvGuard::set(&[ ("OMNIGRAPH_EMBEDDINGS_MOCK", None), + ("OMNIGRAPH_EMBED_PROVIDER", None), + ("OPENROUTER_API_KEY", None), + ("OPENAI_API_KEY", None), ("GEMINI_API_KEY", None), ]); @@ -528,7 +533,11 @@ async fn string_nearest_requires_gemini_credentials_when_mock_is_disabled() { .await .unwrap_err(); - assert!(err.to_string().contains("GEMINI_API_KEY")); + assert!( + err.to_string() + .contains("OPENROUTER_API_KEY or OPENAI_API_KEY"), + "unexpected error: {err}" + ); } // ─── BM25 search ────────────────────────────────────────────────────────────