mirror of
https://github.com/ModernRelay/omnigraph.git
synced 2026-06-18 02:24:27 +02:00
feat(engine)!: provider-independent embedding client (RFC-012 Phase 2)
Replace the Gemini-only EmbeddingClient with one resolved EmbeddingConfig { provider, model, base_url, api_key } behind a sealed Provider enum (OpenAiCompatible | Gemini | Mock). OpenAiCompatible (POST {base}/embeddings, bearer, {model, input, dimensions}) covers OpenRouter — the new default gateway — OpenAI direct, and self-hosted endpoints; Gemini keeps its RETRIEVAL_QUERY/RETRIEVAL_DOCUMENT task types; Mock is offline/deterministic. EmbedRole replaces the task-type string.
from_env() resolves provider via OMNIGRAPH_EMBED_PROVIDER (default openai-compatible), base/model via OMNIGRAPH_EMBED_BASE_URL/_MODEL, and the api key from OPENROUTER_API_KEY/OPENAI_API_KEY or GEMINI_API_KEY. BREAKING (pre-release, no back-compat): the default provider is now OpenRouter, OMNIGRAPH_GEMINI_BASE_URL is dropped, and Gemini-only users set OMNIGRAPH_EMBED_PROVIDER=gemini.
Folds in RFC-012 Phase 1 NFR floor: a total-operation OMNIGRAPH_EMBED_QUERY_DEADLINE_MS deadline (default 60s; 0=unbounded) bounds the ~121s worst case, and tracing spans (target omnigraph::embedding) record provider/model/dim/attempt/elapsed/outcome. The offline 'omnigraph embed' CLI follows the resolved provider (its hardcoded gemini-only bail removed). 17 engine embedding unit tests, 4 CLI embed tests, and the search integration suite (22) pass.
Cross-query client reuse and the docs refresh land in follow-up commits.
This commit is contained in:
parent
7c916f5b98
commit
b999ae3753
4 changed files with 518 additions and 106 deletions
|
|
@ -9,8 +9,6 @@ use omnigraph::embedding::EmbeddingClient;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value, json};
|
||||
|
||||
const DEFAULT_EMBED_MODEL: &str = "gemini-embedding-2-preview";
|
||||
|
||||
#[derive(Debug, Args, Clone)]
|
||||
pub(crate) struct EmbedArgs {
|
||||
/// Seed manifest path
|
||||
|
|
@ -85,7 +83,7 @@ impl EmbedMode {
|
|||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct EmbedSpec {
|
||||
#[serde(default = "default_embed_model")]
|
||||
#[serde(default)]
|
||||
model: String,
|
||||
dimension: usize,
|
||||
types: BTreeMap<String, EmbedTypeSpec>,
|
||||
|
|
@ -180,13 +178,6 @@ pub(crate) fn resolve_embed_job(args: &EmbedArgs) -> Result<EmbedJob> {
|
|||
(input, output, spec)
|
||||
};
|
||||
|
||||
if spec.model != DEFAULT_EMBED_MODEL {
|
||||
bail!(
|
||||
"only {} is supported for explicit seed embeddings right now",
|
||||
DEFAULT_EMBED_MODEL
|
||||
);
|
||||
}
|
||||
|
||||
Ok(EmbedJob {
|
||||
input,
|
||||
output,
|
||||
|
|
@ -315,10 +306,6 @@ fn temp_output_path(output: &Path) -> PathBuf {
|
|||
PathBuf::from(temp)
|
||||
}
|
||||
|
||||
fn default_embed_model() -> String {
|
||||
DEFAULT_EMBED_MODEL.to_string()
|
||||
}
|
||||
|
||||
fn load_embed_spec(path: &Path) -> Result<EmbedSpec> {
|
||||
Ok(serde_json::from_str(&fs::read_to_string(path)?)?)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1111,6 +1111,11 @@ query vector_search($q: String) {
|
|||
|
||||
let result = parse_stdout_json(&output_success(
|
||||
cli()
|
||||
// Stored vectors above were produced with gemini-embedding-2-preview;
|
||||
// pin the query-time embedder to the same provider/model so the
|
||||
// auto-embedded `$q` lands in the same vector space.
|
||||
.env("OMNIGRAPH_EMBED_PROVIDER", "gemini")
|
||||
.env("OMNIGRAPH_EMBED_MODEL", "gemini-embedding-2-preview")
|
||||
.arg("read")
|
||||
.arg(&graph)
|
||||
.arg("--query")
|
||||
|
|
|
|||
|
|
@ -8,29 +8,149 @@ use tokio::time::sleep;
|
|||
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
const GEMINI_EMBED_MODEL: &str = "gemini-embedding-2-preview";
|
||||
const DEFAULT_OPENAI_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
||||
const DEFAULT_OPENAI_MODEL: &str = "openai/text-embedding-3-large";
|
||||
const DEFAULT_GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
|
||||
const DEFAULT_GEMINI_MODEL: &str = "gemini-embedding-2";
|
||||
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
|
||||
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
|
||||
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
|
||||
const QUERY_TASK_TYPE: &str = "RETRIEVAL_QUERY";
|
||||
const DOCUMENT_TASK_TYPE: &str = "RETRIEVAL_DOCUMENT";
|
||||
const DEFAULT_QUERY_DEADLINE_MS: u64 = 60_000;
|
||||
const GEMINI_QUERY_TASK_TYPE: &str = "RETRIEVAL_QUERY";
|
||||
const GEMINI_DOCUMENT_TASK_TYPE: &str = "RETRIEVAL_DOCUMENT";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum EmbeddingTransport {
|
||||
/// Which embedding API a client speaks. Each variant owns its request shape,
|
||||
/// auth, and response parsing; everything else (retry, deadline, normalization,
|
||||
/// tracing) is provider-independent.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Provider {
|
||||
/// OpenAI-compatible (`POST {base}/embeddings`, bearer auth,
|
||||
/// `{model, input, dimensions}`). Covers OpenRouter (the default gateway),
|
||||
/// OpenAI direct, and self-hosted endpoints (vLLM/Ollama/LM Studio).
|
||||
OpenAiCompatible,
|
||||
/// Google Gemini `generativelanguage` (`POST {base}/models/{model}:embedContent`,
|
||||
/// `x-goog-api-key`), with `RETRIEVAL_QUERY` / `RETRIEVAL_DOCUMENT` task types.
|
||||
Gemini,
|
||||
/// Deterministic, offline. No network, no key.
|
||||
Mock,
|
||||
Gemini {
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
http: Client,
|
||||
},
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
fn default_base_url(self) -> &'static str {
|
||||
match self {
|
||||
Provider::OpenAiCompatible => DEFAULT_OPENAI_BASE_URL,
|
||||
Provider::Gemini => DEFAULT_GEMINI_BASE_URL,
|
||||
Provider::Mock => "",
|
||||
}
|
||||
}
|
||||
|
||||
fn default_model(self) -> &'static str {
|
||||
match self {
|
||||
Provider::OpenAiCompatible => DEFAULT_OPENAI_MODEL,
|
||||
Provider::Gemini => DEFAULT_GEMINI_MODEL,
|
||||
Provider::Mock => "",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether the text being embedded is a search query or a stored document.
|
||||
/// Only Gemini distinguishes these (`RETRIEVAL_QUERY` vs `RETRIEVAL_DOCUMENT`);
|
||||
/// OpenAI-compatible providers and Mock produce the identical request for both,
|
||||
/// which is also the same-space property a query relies on.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum EmbedRole {
|
||||
Query,
|
||||
Document,
|
||||
}
|
||||
|
||||
/// The single source of truth for how embedding text becomes a vector:
|
||||
/// provider + model + endpoint + key. Resolved once (from env today; from the
|
||||
/// cluster `providers.embedding` profile in a later RFC-012 phase) and shared by
|
||||
/// the query path and the offline CLI so stored and query vectors stay
|
||||
/// same-space by construction.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct EmbeddingConfig {
|
||||
pub provider: Provider,
|
||||
pub model: String,
|
||||
pub base_url: String,
|
||||
pub api_key: String,
|
||||
}
|
||||
|
||||
impl EmbeddingConfig {
|
||||
/// Resolve from the environment. Precedence:
|
||||
/// 1. `OMNIGRAPH_EMBEDDINGS_MOCK` → Mock.
|
||||
/// 2. `OMNIGRAPH_EMBED_PROVIDER` (`openai-compatible`|`openai`|`gemini`|`mock`);
|
||||
/// unset defaults to `openai-compatible` (OpenRouter).
|
||||
/// 3. `OMNIGRAPH_EMBED_BASE_URL` else the provider default.
|
||||
/// 4. `OMNIGRAPH_EMBED_MODEL` else the provider default.
|
||||
/// 5. provider api-key env (`OPENROUTER_API_KEY`/`OPENAI_API_KEY`, or `GEMINI_API_KEY`).
|
||||
pub fn from_env() -> Result<Self> {
|
||||
if env_flag("OMNIGRAPH_EMBEDDINGS_MOCK") {
|
||||
return Ok(Self::mock());
|
||||
}
|
||||
|
||||
let provider = match env_string("OMNIGRAPH_EMBED_PROVIDER").as_deref() {
|
||||
None | Some("openai-compatible") | Some("openai") => Provider::OpenAiCompatible,
|
||||
Some("gemini") => Provider::Gemini,
|
||||
Some("mock") => return Ok(Self::mock()),
|
||||
Some(other) => {
|
||||
return Err(OmniError::manifest_internal(format!(
|
||||
"unknown OMNIGRAPH_EMBED_PROVIDER '{}' (expected openai-compatible|gemini|mock)",
|
||||
other
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let base_url = env_string("OMNIGRAPH_EMBED_BASE_URL")
|
||||
.unwrap_or_else(|| provider.default_base_url().to_string())
|
||||
.trim_end_matches('/')
|
||||
.to_string();
|
||||
let model =
|
||||
env_string("OMNIGRAPH_EMBED_MODEL").unwrap_or_else(|| provider.default_model().to_string());
|
||||
|
||||
let api_key = match provider {
|
||||
Provider::OpenAiCompatible => env_string("OPENROUTER_API_KEY")
|
||||
.or_else(|| env_string("OPENAI_API_KEY"))
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(
|
||||
"OPENROUTER_API_KEY or OPENAI_API_KEY is required for the openai-compatible embedding provider",
|
||||
)
|
||||
})?,
|
||||
Provider::Gemini => env_string("GEMINI_API_KEY").ok_or_else(|| {
|
||||
OmniError::manifest_internal(
|
||||
"GEMINI_API_KEY is required for the gemini embedding provider",
|
||||
)
|
||||
})?,
|
||||
Provider::Mock => unreachable!("mock returns early"),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
model,
|
||||
base_url,
|
||||
api_key,
|
||||
})
|
||||
}
|
||||
|
||||
fn mock() -> Self {
|
||||
Self {
|
||||
provider: Provider::Mock,
|
||||
model: String::new(),
|
||||
base_url: String::new(),
|
||||
api_key: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct EmbeddingClient {
|
||||
config: EmbeddingConfig,
|
||||
http: Client,
|
||||
retry_attempts: usize,
|
||||
retry_backoff_ms: u64,
|
||||
transport: EmbeddingTransport,
|
||||
/// Total wall-clock budget for one embed call, across all retries
|
||||
/// (`OMNIGRAPH_EMBED_QUERY_DEADLINE_MS`). `0` = unbounded.
|
||||
query_deadline_ms: u64,
|
||||
}
|
||||
|
||||
struct EmbedCallError {
|
||||
|
|
@ -58,35 +178,39 @@ struct GoogleErrorBody {
|
|||
message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiEmbeddingResponse {
|
||||
data: Vec<OpenAiEmbeddingDatum>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiEmbeddingDatum {
|
||||
index: usize,
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiErrorEnvelope {
|
||||
error: OpenAiErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiErrorBody {
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl EmbeddingClient {
|
||||
pub fn from_env() -> Result<Self> {
|
||||
Self::new(EmbeddingConfig::from_env()?)
|
||||
}
|
||||
|
||||
pub fn new(config: EmbeddingConfig) -> Result<Self> {
|
||||
let retry_attempts =
|
||||
parse_env_usize("OMNIGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
|
||||
let retry_backoff_ms =
|
||||
parse_env_u64("OMNIGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);
|
||||
|
||||
if env_flag("OMNIGRAPH_EMBEDDINGS_MOCK") {
|
||||
return Ok(Self {
|
||||
retry_attempts,
|
||||
retry_backoff_ms,
|
||||
transport: EmbeddingTransport::Mock,
|
||||
});
|
||||
}
|
||||
|
||||
let api_key = std::env::var("GEMINI_API_KEY")
|
||||
.ok()
|
||||
.map(|v| v.trim().to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(
|
||||
"GEMINI_API_KEY is required when nearest() needs a string embedding",
|
||||
)
|
||||
})?;
|
||||
let base_url = std::env::var("OMNIGRAPH_GEMINI_BASE_URL")
|
||||
.ok()
|
||||
.map(|v| v.trim_end_matches('/').to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
.unwrap_or_else(|| DEFAULT_GEMINI_BASE_URL.to_string());
|
||||
let query_deadline_ms =
|
||||
parse_env_u64_allow_zero("OMNIGRAPH_EMBED_QUERY_DEADLINE_MS", DEFAULT_QUERY_DEADLINE_MS);
|
||||
let timeout_ms = parse_env_u64("OMNIGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
|
||||
let http = Client::builder()
|
||||
.timeout(Duration::from_millis(timeout_ms))
|
||||
|
|
@ -96,39 +220,36 @@ impl EmbeddingClient {
|
|||
})?;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
http,
|
||||
retry_attempts,
|
||||
retry_backoff_ms,
|
||||
transport: EmbeddingTransport::Gemini {
|
||||
api_key,
|
||||
base_url,
|
||||
http,
|
||||
},
|
||||
query_deadline_ms,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &EmbeddingConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn mock_for_tests() -> Self {
|
||||
Self {
|
||||
retry_attempts: DEFAULT_RETRY_ATTEMPTS,
|
||||
retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS,
|
||||
transport: EmbeddingTransport::Mock,
|
||||
}
|
||||
Self::new(EmbeddingConfig::mock()).expect("mock client builds")
|
||||
}
|
||||
|
||||
pub async fn embed_query_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
|
||||
self.embed_text(input, expected_dim, QUERY_TASK_TYPE).await
|
||||
self.embed_text(input, expected_dim, EmbedRole::Query).await
|
||||
}
|
||||
|
||||
pub async fn embed_document_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
|
||||
self.embed_text(input, expected_dim, DOCUMENT_TASK_TYPE)
|
||||
.await
|
||||
self.embed_text(input, expected_dim, EmbedRole::Document).await
|
||||
}
|
||||
|
||||
async fn embed_text(
|
||||
&self,
|
||||
input: &str,
|
||||
expected_dim: usize,
|
||||
task_type: &'static str,
|
||||
role: EmbedRole,
|
||||
) -> Result<Vec<f32>> {
|
||||
if expected_dim == 0 {
|
||||
return Err(OmniError::manifest_internal(
|
||||
|
|
@ -136,10 +257,70 @@ impl EmbeddingClient {
|
|||
));
|
||||
}
|
||||
|
||||
match &self.transport {
|
||||
EmbeddingTransport::Mock => Ok(mock_embedding(input, expected_dim)),
|
||||
EmbeddingTransport::Gemini { .. } => {
|
||||
self.with_retry(|| self.embed_text_gemini_once(input, expected_dim, task_type))
|
||||
let started = std::time::Instant::now();
|
||||
let result = self
|
||||
.run_with_deadline(self.embed_text_inner(input, expected_dim, role))
|
||||
.await;
|
||||
let elapsed_ms = started.elapsed().as_millis() as u64;
|
||||
|
||||
match &result {
|
||||
Ok(_) => tracing::info!(
|
||||
target: "omnigraph::embedding",
|
||||
provider = ?self.config.provider,
|
||||
model = %self.config.model,
|
||||
dim = expected_dim,
|
||||
elapsed_ms,
|
||||
outcome = "ok",
|
||||
"embedding succeeded"
|
||||
),
|
||||
Err(err) => tracing::warn!(
|
||||
target: "omnigraph::embedding",
|
||||
provider = ?self.config.provider,
|
||||
model = %self.config.model,
|
||||
dim = expected_dim,
|
||||
elapsed_ms,
|
||||
outcome = "error",
|
||||
error = %err,
|
||||
"embedding failed"
|
||||
),
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Bound the whole embed operation (all retries + backoff) by
|
||||
/// `query_deadline_ms`, so a degraded provider can never hang a read for the
|
||||
/// full retry envelope. `0` = unbounded. Read-path only, so cancelling the
|
||||
/// in-flight request future on elapse is safe.
|
||||
async fn run_with_deadline<F>(&self, fut: F) -> Result<Vec<f32>>
|
||||
where
|
||||
F: Future<Output = Result<Vec<f32>>>,
|
||||
{
|
||||
if self.query_deadline_ms == 0 {
|
||||
return fut.await;
|
||||
}
|
||||
match tokio::time::timeout(Duration::from_millis(self.query_deadline_ms), fut).await {
|
||||
Ok(res) => res,
|
||||
Err(_elapsed) => Err(OmniError::manifest_internal(format!(
|
||||
"embedding deadline exceeded after {} ms (provider={:?}, model={})",
|
||||
self.query_deadline_ms, self.config.provider, self.config.model
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn embed_text_inner(
|
||||
&self,
|
||||
input: &str,
|
||||
expected_dim: usize,
|
||||
role: EmbedRole,
|
||||
) -> Result<Vec<f32>> {
|
||||
match self.config.provider {
|
||||
Provider::Mock => Ok(mock_embedding(input, expected_dim)),
|
||||
Provider::Gemini => {
|
||||
self.with_retry(|| self.embed_gemini_once(input, expected_dim, role))
|
||||
.await
|
||||
}
|
||||
Provider::OpenAiCompatible => {
|
||||
self.with_retry(|| self.embed_openai_once(input, expected_dim))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
|
@ -160,6 +341,14 @@ impl EmbeddingClient {
|
|||
if !err.retryable || attempt >= max_attempt {
|
||||
return Err(OmniError::manifest_internal(err.message));
|
||||
}
|
||||
tracing::warn!(
|
||||
target: "omnigraph::embedding",
|
||||
provider = ?self.config.provider,
|
||||
model = %self.config.model,
|
||||
attempt,
|
||||
error = %err.message,
|
||||
"embedding attempt failed, retrying"
|
||||
);
|
||||
let shift = (attempt - 1).min(10) as u32;
|
||||
let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
|
|
@ -168,25 +357,27 @@ impl EmbeddingClient {
|
|||
}
|
||||
}
|
||||
|
||||
async fn embed_text_gemini_once(
|
||||
async fn embed_gemini_once(
|
||||
&self,
|
||||
input: &str,
|
||||
expected_dim: usize,
|
||||
task_type: &'static str,
|
||||
role: EmbedRole,
|
||||
) -> std::result::Result<Vec<f32>, EmbedCallError> {
|
||||
let (api_key, base_url, http) = match &self.transport {
|
||||
EmbeddingTransport::Gemini {
|
||||
api_key,
|
||||
base_url,
|
||||
http,
|
||||
} => (api_key, base_url, http),
|
||||
EmbeddingTransport::Mock => unreachable!("mock transport should not call Gemini"),
|
||||
let task_type = match role {
|
||||
EmbedRole::Query => GEMINI_QUERY_TASK_TYPE,
|
||||
EmbedRole::Document => GEMINI_DOCUMENT_TASK_TYPE,
|
||||
};
|
||||
|
||||
let response = http
|
||||
.post(gemini_endpoint(base_url))
|
||||
.header("x-goog-api-key", api_key)
|
||||
.json(&build_gemini_request(input, expected_dim, task_type))
|
||||
let response = self
|
||||
.http
|
||||
.post(gemini_endpoint(&self.config.base_url, &self.config.model))
|
||||
.header("x-goog-api-key", &self.config.api_key)
|
||||
.json(&build_gemini_request(
|
||||
&self.config.model,
|
||||
input,
|
||||
expected_dim,
|
||||
task_type,
|
||||
))
|
||||
.send()
|
||||
.await;
|
||||
let response = match response {
|
||||
|
|
@ -205,10 +396,7 @@ impl EmbeddingClient {
|
|||
Ok(body) => body,
|
||||
Err(err) => {
|
||||
return Err(EmbedCallError {
|
||||
message: format!(
|
||||
"embedding response read failed (status {}): {}",
|
||||
status, err
|
||||
),
|
||||
message: format!("embedding response read failed (status {}): {}", status, err),
|
||||
retryable: status.is_server_error() || status.as_u16() == 429,
|
||||
});
|
||||
}
|
||||
|
|
@ -217,10 +405,7 @@ impl EmbeddingClient {
|
|||
if !status.is_success() {
|
||||
let message = parse_google_error_message(&body).unwrap_or(body);
|
||||
return Err(EmbedCallError {
|
||||
message: format!(
|
||||
"embedding request failed with status {}: {}",
|
||||
status, message
|
||||
),
|
||||
message: format!("embedding request failed with status {}: {}", status, message),
|
||||
retryable: status.is_server_error() || status.as_u16() == 429,
|
||||
});
|
||||
}
|
||||
|
|
@ -238,19 +423,85 @@ impl EmbeddingClient {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn embed_openai_once(
|
||||
&self,
|
||||
input: &str,
|
||||
expected_dim: usize,
|
||||
) -> std::result::Result<Vec<f32>, EmbedCallError> {
|
||||
let response = self
|
||||
.http
|
||||
.post(format!("{}/embeddings", self.config.base_url))
|
||||
.bearer_auth(&self.config.api_key)
|
||||
.json(&build_openai_request(&self.config.model, input, expected_dim))
|
||||
.send()
|
||||
.await;
|
||||
let response = match response {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
let retryable = err.is_timeout() || err.is_connect() || err.is_request();
|
||||
return Err(EmbedCallError {
|
||||
message: format!("embedding request failed: {}", err),
|
||||
retryable,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
let body = match response.text().await {
|
||||
Ok(body) => body,
|
||||
Err(err) => {
|
||||
return Err(EmbedCallError {
|
||||
message: format!("embedding response read failed (status {}): {}", status, err),
|
||||
retryable: status.is_server_error() || status.as_u16() == 429,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if !status.is_success() {
|
||||
let message = parse_openai_error_message(&body).unwrap_or(body);
|
||||
return Err(EmbedCallError {
|
||||
message: format!("embedding request failed with status {}: {}", status, message),
|
||||
retryable: status.is_server_error() || status.as_u16() == 429,
|
||||
});
|
||||
}
|
||||
|
||||
let parsed: OpenAiEmbeddingResponse =
|
||||
serde_json::from_str(&body).map_err(|err| EmbedCallError {
|
||||
message: format!("embedding response decode failed: {}", err),
|
||||
retryable: false,
|
||||
})?;
|
||||
|
||||
// The query path embeds exactly one string, so expect one datum at index 0.
|
||||
let datum = parsed
|
||||
.data
|
||||
.into_iter()
|
||||
.find(|d| d.index == 0)
|
||||
.ok_or_else(|| EmbedCallError {
|
||||
message: "embedding response missing data[0]".to_string(),
|
||||
retryable: false,
|
||||
})?;
|
||||
|
||||
validate_and_normalize_embedding(datum.embedding, expected_dim).map_err(|message| {
|
||||
EmbedCallError {
|
||||
message,
|
||||
retryable: false,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn gemini_endpoint(base_url: &str) -> String {
|
||||
fn gemini_endpoint(base_url: &str, model: &str) -> String {
|
||||
format!(
|
||||
"{}/models/{}:embedContent",
|
||||
base_url.trim_end_matches('/'),
|
||||
GEMINI_EMBED_MODEL
|
||||
model
|
||||
)
|
||||
}
|
||||
|
||||
fn build_gemini_request(input: &str, expected_dim: usize, task_type: &'static str) -> Value {
|
||||
fn build_gemini_request(model: &str, input: &str, expected_dim: usize, task_type: &str) -> Value {
|
||||
json!({
|
||||
"model": format!("models/{}", GEMINI_EMBED_MODEL),
|
||||
"model": format!("models/{}", model),
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
|
|
@ -263,6 +514,14 @@ fn build_gemini_request(input: &str, expected_dim: usize, task_type: &'static st
|
|||
})
|
||||
}
|
||||
|
||||
fn build_openai_request(model: &str, input: &str, expected_dim: usize) -> Value {
|
||||
json!({
|
||||
"model": model,
|
||||
"input": [input],
|
||||
"dimensions": expected_dim,
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_and_normalize_embedding(
|
||||
values: Vec<f32>,
|
||||
expected_dim: usize,
|
||||
|
|
@ -298,6 +557,20 @@ fn parse_google_error_message(body: &str) -> Option<String> {
|
|||
.filter(|msg| !msg.trim().is_empty())
|
||||
}
|
||||
|
||||
fn parse_openai_error_message(body: &str) -> Option<String> {
|
||||
serde_json::from_str::<OpenAiErrorEnvelope>(body)
|
||||
.ok()
|
||||
.map(|e| e.error.message)
|
||||
.filter(|msg| !msg.trim().is_empty())
|
||||
}
|
||||
|
||||
fn env_string(name: &str) -> Option<String> {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.map(|v| v.trim().to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
}
|
||||
|
||||
fn parse_env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
|
|
@ -314,6 +587,15 @@ fn parse_env_u64(name: &str, default: u64) -> u64 {
|
|||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
/// Like [`parse_env_u64`] but accepts `0` as a meaningful value (the deadline
|
||||
/// uses `0` for "unbounded").
|
||||
fn parse_env_u64_allow_zero(name: &str, default: u64) -> u64 {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|v| v.trim().parse::<u64>().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn env_flag(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
|
|
@ -395,6 +677,25 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
// Every test that calls `EmbeddingConfig::from_env` clears the full set of
|
||||
// embedding env vars first so the host environment can't leak in.
|
||||
const EMBED_ENV: &[&str] = &[
|
||||
"OMNIGRAPH_EMBEDDINGS_MOCK",
|
||||
"OMNIGRAPH_EMBED_PROVIDER",
|
||||
"OMNIGRAPH_EMBED_BASE_URL",
|
||||
"OMNIGRAPH_EMBED_MODEL",
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"GEMINI_API_KEY",
|
||||
];
|
||||
|
||||
fn cleared_env(extra: &[(&'static str, Option<&str>)]) -> EnvGuard {
|
||||
let mut vars: Vec<(&'static str, Option<&str>)> =
|
||||
EMBED_ENV.iter().map(|n| (*n, None)).collect();
|
||||
vars.extend_from_slice(extra);
|
||||
EnvGuard::set(&vars)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_embeddings_are_deterministic() {
|
||||
let client = EmbeddingClient::mock_for_tests();
|
||||
|
|
@ -407,18 +708,30 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn gemini_request_uses_preview_model_retrieval_query_and_dimension() {
|
||||
let request = build_gemini_request("alpha", 4, QUERY_TASK_TYPE);
|
||||
assert_eq!(request["model"], "models/gemini-embedding-2-preview");
|
||||
assert_eq!(request["taskType"], QUERY_TASK_TYPE);
|
||||
fn gemini_request_uses_model_retrieval_query_and_dimension() {
|
||||
let request =
|
||||
build_gemini_request("gemini-embedding-2", "alpha", 4, GEMINI_QUERY_TASK_TYPE);
|
||||
assert_eq!(request["model"], "models/gemini-embedding-2");
|
||||
assert_eq!(request["taskType"], GEMINI_QUERY_TASK_TYPE);
|
||||
assert_eq!(request["outputDimensionality"], 4);
|
||||
assert_eq!(request["content"]["parts"][0]["text"], "alpha");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gemini_document_request_uses_retrieval_document_task_type() {
|
||||
let request = build_gemini_request("alpha", 4, DOCUMENT_TASK_TYPE);
|
||||
assert_eq!(request["taskType"], DOCUMENT_TASK_TYPE);
|
||||
let request =
|
||||
build_gemini_request("gemini-embedding-2", "alpha", 4, GEMINI_DOCUMENT_TASK_TYPE);
|
||||
assert_eq!(request["taskType"], GEMINI_DOCUMENT_TASK_TYPE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn openai_request_uses_model_input_array_and_dimensions() {
|
||||
let request = build_openai_request("openai/text-embedding-3-large", "alpha", 4);
|
||||
assert_eq!(request["model"], "openai/text-embedding-3-large");
|
||||
assert_eq!(request["input"][0], "alpha");
|
||||
assert!(request["input"].is_array());
|
||||
assert_eq!(request["dimensions"], 4);
|
||||
assert!(request.get("taskType").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -475,15 +788,113 @@ mod tests {
|
|||
assert!(err.to_string().contains("do not retry"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_with_deadline_aborts_slow_future() {
|
||||
let mut client = EmbeddingClient::mock_for_tests();
|
||||
client.query_deadline_ms = 20;
|
||||
let slow = async {
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
Ok(vec![0.0_f32])
|
||||
};
|
||||
let err = client.run_with_deadline(slow).await.unwrap_err();
|
||||
assert!(err.to_string().contains("deadline exceeded"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_with_deadline_passes_through_fast_future() {
|
||||
let client = EmbeddingClient::mock_for_tests();
|
||||
let ok = client
|
||||
.run_with_deadline(async { Ok(vec![1.0_f32, 2.0]) })
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(ok, vec![1.0, 2.0]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_with_deadline_zero_is_unbounded() {
|
||||
let mut client = EmbeddingClient::mock_for_tests();
|
||||
client.query_deadline_ms = 0;
|
||||
let ok = client
|
||||
.run_with_deadline(async { Ok(vec![3.0_f32]) })
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(ok, vec![3.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn from_env_requires_gemini_api_key_when_not_mocking() {
|
||||
let _guard = EnvGuard::set(&[
|
||||
("OMNIGRAPH_EMBEDDINGS_MOCK", None),
|
||||
("GEMINI_API_KEY", None),
|
||||
]);
|
||||
fn from_env_defaults_to_openai_compatible_openrouter() {
|
||||
let _guard = cleared_env(&[("OPENROUTER_API_KEY", Some("sk-test"))]);
|
||||
let config = EmbeddingConfig::from_env().unwrap();
|
||||
assert_eq!(config.provider, Provider::OpenAiCompatible);
|
||||
assert_eq!(config.base_url, DEFAULT_OPENAI_BASE_URL);
|
||||
assert_eq!(config.model, DEFAULT_OPENAI_MODEL);
|
||||
assert_eq!(config.api_key, "sk-test");
|
||||
}
|
||||
|
||||
let err = EmbeddingClient::from_env().unwrap_err();
|
||||
assert!(err.to_string().contains("GEMINI_API_KEY"));
|
||||
#[test]
|
||||
#[serial]
|
||||
fn from_env_openai_compatible_prefers_openrouter_key() {
|
||||
let _guard = cleared_env(&[
|
||||
("OPENROUTER_API_KEY", Some("router")),
|
||||
("OPENAI_API_KEY", Some("openai")),
|
||||
]);
|
||||
let config = EmbeddingConfig::from_env().unwrap();
|
||||
assert_eq!(config.api_key, "router");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn from_env_explicit_gemini_provider() {
|
||||
let _guard = cleared_env(&[
|
||||
("OMNIGRAPH_EMBED_PROVIDER", Some("gemini")),
|
||||
("GEMINI_API_KEY", Some("g-key")),
|
||||
]);
|
||||
let config = EmbeddingConfig::from_env().unwrap();
|
||||
assert_eq!(config.provider, Provider::Gemini);
|
||||
assert_eq!(config.base_url, DEFAULT_GEMINI_BASE_URL);
|
||||
assert_eq!(config.model, DEFAULT_GEMINI_MODEL);
|
||||
assert_eq!(config.api_key, "g-key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn from_env_base_url_and_model_overrides_apply() {
|
||||
let _guard = cleared_env(&[
|
||||
("OMNIGRAPH_EMBED_PROVIDER", Some("openai-compatible")),
|
||||
("OMNIGRAPH_EMBED_BASE_URL", Some("https://example.test/v1/")),
|
||||
("OMNIGRAPH_EMBED_MODEL", Some("custom/model")),
|
||||
("OPENAI_API_KEY", Some("k")),
|
||||
]);
|
||||
let config = EmbeddingConfig::from_env().unwrap();
|
||||
assert_eq!(config.base_url, "https://example.test/v1"); // trailing slash trimmed
|
||||
assert_eq!(config.model, "custom/model");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn from_env_unknown_provider_errors() {
|
||||
let _guard = cleared_env(&[("OMNIGRAPH_EMBED_PROVIDER", Some("cohere"))]);
|
||||
let err = EmbeddingConfig::from_env().unwrap_err();
|
||||
assert!(err.to_string().contains("unknown OMNIGRAPH_EMBED_PROVIDER"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn from_env_errors_when_no_key_present() {
|
||||
let _guard = cleared_env(&[]);
|
||||
let err = EmbeddingConfig::from_env().unwrap_err();
|
||||
assert!(err.to_string().contains("OPENROUTER_API_KEY or OPENAI_API_KEY"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn from_env_mock_flag_wins() {
|
||||
let _guard = cleared_env(&[
|
||||
("OMNIGRAPH_EMBEDDINGS_MOCK", Some("1")),
|
||||
("OMNIGRAPH_EMBED_PROVIDER", Some("gemini")),
|
||||
]);
|
||||
let config = EmbeddingConfig::from_env().unwrap();
|
||||
assert_eq!(config.provider, Provider::Mock);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -510,9 +510,14 @@ async fn explicit_vector_nearest_does_not_require_gemini_credentials() {
|
|||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn string_nearest_requires_gemini_credentials_when_mock_is_disabled() {
|
||||
async fn string_nearest_requires_provider_credentials_when_mock_is_disabled() {
|
||||
// With mock off and no provider key, the default (openai-compatible)
|
||||
// provider fails loudly rather than silently producing garbage vectors.
|
||||
let _guard = EnvGuard::set(&[
|
||||
("OMNIGRAPH_EMBEDDINGS_MOCK", None),
|
||||
("OMNIGRAPH_EMBED_PROVIDER", None),
|
||||
("OPENROUTER_API_KEY", None),
|
||||
("OPENAI_API_KEY", None),
|
||||
("GEMINI_API_KEY", None),
|
||||
]);
|
||||
|
||||
|
|
@ -528,7 +533,11 @@ async fn string_nearest_requires_gemini_credentials_when_mock_is_disabled() {
|
|||
.await
|
||||
.unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("GEMINI_API_KEY"));
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("OPENROUTER_API_KEY or OPENAI_API_KEY"),
|
||||
"unexpected error: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── BM25 search ────────────────────────────────────────────────────────────
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue