From 790c0c84fe033afbdd20c6494da22357652c7c71 Mon Sep 17 00:00:00 2001 From: Jan De Landtsheer Date: Tue, 21 Apr 2026 21:43:52 +0200 Subject: [PATCH] feat(storage): phase 1 -- extract MemoryStore and Embedder traits (ADR 0001) Introduce two trait boundaries that the rest of the stack now sits above, landing Phase 1 of ADR 0001 (pluggable storage and network access). Rebased onto v2.1.22 Sanhedrin from the original April work. MemoryStore / LocalMemoryStore (crates/vestige-core/src/storage/memory_store.rs): One trait, ~25 methods, covering CRUD, hybrid / FTS / vector search, FSRS scheduling, graph edges, and the forthcoming domain surface. trait_variant::make generates a Send-bound MemoryStore alias over the base LocalMemoryStore so Arc works under tokio/axum. Storage errors map through a dedicated MemoryStoreError. Embedder / LocalEmbedder (crates/vestige-core/src/embedder/): Pluggable text-to-vector encoder. FastembedEmbedder wraps the existing EmbeddingService; storage never calls fastembed directly anymore. Embedder::signature() produces the ModelSignature consumed by the store's embedding_model registry. SqliteMemoryStore (crates/vestige-core/src/storage/sqlite.rs): Storage renamed to SqliteMemoryStore; the old name lives on as a pub type alias so Arc consumers in vestige-mcp stay intact. All existing inherent methods are untouched; the trait impl is purely additive and dispatches into them. The db_path field added by v2.1.1 portable-sync is preserved. Migration V14 (crates/vestige-core/src/storage/migrations.rs): Renumbered from V12 (the original April number) to V14 to slot in cleanly after upstream's V12 (v2.1.1 sync_tombstones) and V13 (v2.1.2 purge tombstones). - embedding_model registry table (CHECK id = 1, code enforces the single-row invariant). - knowledge_nodes.domains / domain_scores TEXT columns (JSON arrays default '[]' / '{}'), domains catalogue table, supporting indexes. Phase 4 populates these columns; Phase 1 just exposes the schema. Consolidation and other cognitive pathways now accept a &dyn LocalMemoryStore (sync) or Arc (async) rather than a concrete Storage. Tests: - trait-method unit tests colocated in sqlite.rs and migrations.rs - embedder/fastembed.rs tests for name/dimension/hash stability - new integration crate tests/phase_1 (added to workspace members): trait_round_trip (8), embedding_model_registry (7), domain_column_migration (5), cognitive_module_isolation (4), send_bound_variant (2), embedder_trait (2). Acceptance gate post-rebase: - cargo build --workspace --all-targets: ok - cargo clippy --workspace --all-targets -- -D warnings: clean - cargo test -p vestige-core --lib: 428 pass - cargo test -p vestige-phase-1-tests: 28 pass - cargo test -p vestige-mcp --lib: 380 pass (Storage alias preserves every existing call site) Co-existence with v2.1.1 portable-sync: this trait extraction is additive. Portable-sync's tombstone migrations (V12, V13) remain on the concrete SqliteMemoryStore; Phase 2 (Postgres) will decide which of those surfaces graduate into the trait. --- Cargo.lock | 113 +- Cargo.toml | 1 + crates/vestige-core/Cargo.toml | 3 + crates/vestige-core/src/embedder/fastembed.rs | 179 ++ crates/vestige-core/src/embedder/mod.rs | 58 + crates/vestige-core/src/lib.rs | 15 +- .../vestige-core/src/storage/memory_store.rs | 319 ++++ crates/vestige-core/src/storage/migrations.rs | 193 ++- crates/vestige-core/src/storage/mod.rs | 19 +- crates/vestige-core/src/storage/sqlite.rs | 1504 ++++++++++++++++- tests/phase_1/Cargo.toml | 38 + tests/phase_1/cognitive_module_isolation.rs | 127 ++ tests/phase_1/domain_column_migration.rs | 143 ++ tests/phase_1/embedder_trait.rs | 36 + tests/phase_1/embedding_model_registry.rs | 133 ++ tests/phase_1/send_bound_variant.rs | 96 ++ tests/phase_1/trait_round_trip.rs | 201 +++ 17 files changed, 3139 insertions(+), 39 deletions(-) create mode 100644 crates/vestige-core/src/embedder/fastembed.rs create mode 100644 crates/vestige-core/src/embedder/mod.rs create mode 100644 crates/vestige-core/src/storage/memory_store.rs create mode 100644 tests/phase_1/Cargo.toml create mode 100644 tests/phase_1/cognitive_module_isolation.rs create mode 100644 tests/phase_1/domain_column_migration.rs create mode 100644 tests/phase_1/embedder_trait.rs create mode 100644 tests/phase_1/embedding_model_registry.rs create mode 100644 tests/phase_1/send_bound_variant.rs create mode 100644 tests/phase_1/trait_round_trip.rs diff --git a/Cargo.lock b/Cargo.lock index 70f3956..35c1f8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -143,6 +143,12 @@ dependencies = [ "syn", ] +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" @@ -158,6 +164,17 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -311,6 +328,20 @@ dependencies = [ "core2", ] +[[package]] +name = "blake3" +version = "1.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d2d5991425dfd0785aed03aedcf0b321d61975c9b5b3689c774a2610ae0b51e" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", + "cpufeatures 0.3.0", +] + [[package]] name = "block" version = "0.1.6" @@ -630,6 +661,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "constant_time_eq" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" + [[package]] name = "core-foundation" version = "0.9.4" @@ -685,6 +722,15 @@ dependencies = [ "libc", ] +[[package]] +name = "cpufeatures" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -2208,12 +2254,10 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.95" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2964e92d1d9dc3364cae4d718d93f227e3abb088e747d92e0395bfdedf1c12ca" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" dependencies = [ - "cfg-if", - "futures-util", "once_cell", "wasm-bindgen", ] @@ -3107,9 +3151,9 @@ checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" +checksum = "c2a106d1259c23fac8e543272398ae0e3c0b8d33c88ed73d0cc71b0f1d902618" dependencies = [ "portable-atomic", ] @@ -3748,7 +3792,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -4271,6 +4315,17 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "trait-variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -4533,6 +4588,8 @@ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" name = "vestige-core" version = "2.1.22" dependencies = [ + "async-trait", + "blake3", "candle-core", "chrono", "criterion", @@ -4548,6 +4605,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tracing", + "trait-variant", "usearch", "uuid", ] @@ -4594,6 +4652,19 @@ dependencies = [ "vestige-core", ] +[[package]] +name = "vestige-phase-1-tests" +version = "0.0.1" +dependencies = [ + "chrono", + "rusqlite", + "serde_json", + "tempfile", + "tokio", + "uuid", + "vestige-core", +] + [[package]] name = "walkdir" version = "2.5.0" @@ -4639,9 +4710,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.118" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" dependencies = [ "cfg-if", "once_cell", @@ -4652,19 +4723,23 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.68" +version = "0.4.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f371d383f2fb139252e0bfac3b81b265689bf45b6874af544ffa4c975ac1ebf8" +checksum = "e9c5522b3a28661442748e09d40924dfb9ca614b21c00d3fd135720e48b67db8" dependencies = [ + "cfg-if", + "futures-util", "js-sys", + "once_cell", "wasm-bindgen", + "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.118" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4672,9 +4747,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.118" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" dependencies = [ "bumpalo", "proc-macro2", @@ -4685,9 +4760,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.118" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" dependencies = [ "unicode-ident", ] @@ -4741,9 +4816,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.95" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2dfbb17949fa2088e5d39408c48368947b86f7834484e87b73de55bc14d97d" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 5cbe508..03111d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "crates/vestige-core", "crates/vestige-mcp", "tests/e2e", + "tests/phase_1", ] exclude = [ "fastembed-rs", diff --git a/crates/vestige-core/Cargo.toml b/crates/vestige-core/Cargo.toml index e84625e..17c19d2 100644 --- a/crates/vestige-core/Cargo.toml +++ b/crates/vestige-core/Cargo.toml @@ -114,6 +114,9 @@ usearch = { version = "=2.23.0", optional = true } # LRU cache for query embeddings lru = "0.16" +trait-variant = "0.1" +blake3 = "1" +async-trait = "0.1" [dev-dependencies] tempfile = "3" diff --git a/crates/vestige-core/src/embedder/fastembed.rs b/crates/vestige-core/src/embedder/fastembed.rs new file mode 100644 index 0000000..f1dbcb5 --- /dev/null +++ b/crates/vestige-core/src/embedder/fastembed.rs @@ -0,0 +1,179 @@ +//! `FastembedEmbedder` -- adapts the existing `EmbeddingService` to the +//! `LocalEmbedder` trait. + +#[cfg(feature = "embeddings")] +use crate::embeddings::{EMBEDDING_DIMENSIONS, EmbeddingService}; + +use super::{EmbedderError, EmbedderResult, LocalEmbedder}; + +pub struct FastembedEmbedder { + #[cfg(feature = "embeddings")] + inner: EmbeddingService, + cached_hash: std::sync::OnceLock, +} + +impl FastembedEmbedder { + pub fn new() -> Self { + Self { + #[cfg(feature = "embeddings")] + inner: EmbeddingService::new(), + cached_hash: std::sync::OnceLock::new(), + } + } + + fn compute_hash(name: &str, dim: usize) -> String { + let mut hasher = blake3::Hasher::new(); + hasher.update(name.as_bytes()); + hasher.update(&(dim as u64).to_le_bytes()); + // fastembed's ONNX bytes are not directly accessible at runtime; we + // use `(name, dim, vestige-core CARGO_PKG_VERSION)` as the + // signature. If fastembed ever changes its output deterministically + // between minor versions, bumping the crate version triggers a + // mismatch -- which is exactly the drift we want to detect. + hasher.update(env!("CARGO_PKG_VERSION").as_bytes()); + hasher.finalize().to_hex().to_string() + } +} + +impl Default for FastembedEmbedder { + fn default() -> Self { + Self::new() + } +} + +#[async_trait::async_trait] +impl LocalEmbedder for FastembedEmbedder { + async fn embed(&self, text: &str) -> EmbedderResult> { + #[cfg(feature = "embeddings")] + { + let emb = self + .inner + .embed(text) + .map_err(|e| EmbedderError::EmbedFailed(e.to_string()))?; + Ok(emb.vector) + } + #[cfg(not(feature = "embeddings"))] + { + let _ = text; + Err(EmbedderError::Init( + "embeddings feature not enabled".to_string(), + )) + } + } + + fn model_name(&self) -> &str { + #[cfg(feature = "embeddings")] + { + self.inner.model_name() + } + #[cfg(not(feature = "embeddings"))] + { + "nomic-ai/nomic-embed-text-v1.5" + } + } + + fn dimension(&self) -> usize { + #[cfg(feature = "embeddings")] + { + EMBEDDING_DIMENSIONS + } + #[cfg(not(feature = "embeddings"))] + { + 256 + } + } + + fn model_hash(&self) -> String { + self.cached_hash + .get_or_init(|| Self::compute_hash(self.model_name(), self.dimension())) + .clone() + } + + async fn embed_batch(&self, texts: &[&str]) -> EmbedderResult>> { + #[cfg(feature = "embeddings")] + { + let embs = self + .inner + .embed_batch(texts) + .map_err(|e| EmbedderError::EmbedFailed(e.to_string()))?; + Ok(embs.into_iter().map(|e| e.vector).collect()) + } + #[cfg(not(feature = "embeddings"))] + { + let _ = texts; + Err(EmbedderError::Init( + "embeddings feature not enabled".to_string(), + )) + } + } +} + +// ============================================================================ +// UNIT TESTS +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn embedder_reports_correct_name() { + let e = FastembedEmbedder::new(); + assert!(e.model_name().contains("nomic"), "model name should contain 'nomic'"); + } + + #[test] + fn embedder_reports_256_dimension() { + let e = FastembedEmbedder::new(); + assert_eq!(e.dimension(), 256); + } + + #[test] + fn embedder_hash_is_stable() { + let e = FastembedEmbedder::new(); + let h1 = e.model_hash(); + let h2 = e.model_hash(); + assert_eq!(h1, h2, "model_hash must be stable across calls"); + } + + #[test] + fn embedder_hash_includes_crate_version() { + // Compute what the hash should be given the known inputs + let name = FastembedEmbedder::new().model_name().to_string(); + let dim = FastembedEmbedder::new().dimension(); + let expected = FastembedEmbedder::compute_hash(&name, dim); + let got = FastembedEmbedder::new().model_hash(); + assert_eq!(got, expected); + } + + #[test] + fn embedder_signature_matches_accessors() { + let e = FastembedEmbedder::new(); + let sig = e.signature(); + assert_eq!(sig.name, e.model_name()); + assert_eq!(sig.dimension, e.dimension()); + assert_eq!(sig.hash, e.model_hash()); + } + + #[cfg(feature = "embeddings")] + #[test] + fn embedder_embed_smoke() { + let e = FastembedEmbedder::new(); + let rt = tokio::runtime::Runtime::new().unwrap(); + let vec = rt.block_on(e.embed("hello world")).expect("embed"); + assert_eq!(vec.len(), 256); + } + + #[cfg(feature = "embeddings")] + #[test] + fn embedder_embed_batch_matches_sequential() { + let e = FastembedEmbedder::new(); + let rt = tokio::runtime::Runtime::new().unwrap(); + let texts = ["alpha beta", "gamma delta"]; + let batch = rt.block_on(e.embed_batch(texts.as_ref())).expect("batch"); + let seq_a = rt.block_on(e.embed(texts[0])).expect("seq a"); + let seq_b = rt.block_on(e.embed(texts[1])).expect("seq b"); + assert_eq!(batch[0], seq_a); + assert_eq!(batch[1], seq_b); + } +} diff --git a/crates/vestige-core/src/embedder/mod.rs b/crates/vestige-core/src/embedder/mod.rs new file mode 100644 index 0000000..7db4434 --- /dev/null +++ b/crates/vestige-core/src/embedder/mod.rs @@ -0,0 +1,58 @@ +//! Text-to-vector encoding trait. Pluggable per-install. + +mod fastembed; + +pub use fastembed::FastembedEmbedder; + +/// Error returned by every `Embedder` method. +#[non_exhaustive] +#[derive(Debug, thiserror::Error)] +pub enum EmbedderError { + #[error("embedder initialization failed: {0}")] + Init(String), + #[error("embedding generation failed: {0}")] + EmbedFailed(String), + #[error("invalid input: {0}")] + InvalidInput(String), +} + +pub type EmbedderResult = std::result::Result; + +/// Pluggable embedder. The storage layer NEVER calls fastembed directly; +/// callers compute vectors via this trait and pass them into `MemoryStore`. +/// +/// `#[async_trait::async_trait]` makes every `async fn` return a +/// `Pin>`, which is required for `Box` +/// and `Arc` to be dyn-compatible. +#[async_trait::async_trait] +pub trait LocalEmbedder: Send + Sync + 'static { + async fn embed(&self, text: &str) -> EmbedderResult>; + + fn model_name(&self) -> &str; + + fn dimension(&self) -> usize; + + /// Stable blake3 hash of (model_name || dimension || vestige-core crate version). + /// Lowercase hex, 64 chars. + /// + /// Used by `MemoryStore::register_model` to detect silent model drift + /// (e.g. a fastembed minor upgrade that changes vector output). + fn model_hash(&self) -> String; + + async fn embed_batch(&self, texts: &[&str]) -> EmbedderResult>>; + + /// Returns the `ModelSignature` describing this embedder. Convenience + /// wrapper over the three accessors above. + fn signature(&self) -> crate::storage::ModelSignature { + crate::storage::ModelSignature { + name: self.model_name().to_string(), + dimension: self.dimension(), + hash: self.model_hash(), + } + } +} + +/// Type alias: `Embedder` is the dyn-compatible, Send+Sync variant. +/// Both names refer to the same `async_trait`-annotated trait. +pub use LocalEmbedder as Embedder; + diff --git a/crates/vestige-core/src/lib.rs b/crates/vestige-core/src/lib.rs index 640ba4a..dc6a150 100644 --- a/crates/vestige-core/src/lib.rs +++ b/crates/vestige-core/src/lib.rs @@ -81,6 +81,7 @@ // ============================================================================ pub mod consolidation; +pub mod embedder; pub mod fsrs; pub mod fts; pub mod memory; @@ -152,11 +153,19 @@ pub use fsrs::{ // Storage layer pub use storage::{ - ConnectionRecord, ConsolidationHistoryRecord, DreamHistoryRecord, InsightRecord, - IntentionRecord, PORTABLE_ARCHIVE_FORMAT, PortableArchive, PortableImportMode, - PortableImportReport, Result, SmartIngestResult, StateTransitionRecord, Storage, StorageError, + ClassificationResult, ConnectionRecord, ConsolidationHistoryRecord, Domain, + DreamHistoryRecord, HealthStatus, InsightRecord, IntentionRecord, LocalMemoryStore, + MemoryEdge, MemoryRecord, MemoryStore, MemoryStoreError, MemoryStoreResult, ModelSignature, + PORTABLE_ARCHIVE_FORMAT, PortableArchive, PortableImportMode, PortableImportReport, Result, + SchedulingState, SearchQuery, SmartIngestResult, SqliteMemoryStore, StateTransitionRecord, + Storage, StorageError, StoreStats, + // Note: storage::SearchResult is intentionally not re-exported here to avoid + // collision with memory::SearchResult. Use vestige_core::storage::SearchResult directly. }; +// Embedder trait and implementations +pub use embedder::{Embedder, EmbedderError, EmbedderResult, FastembedEmbedder, LocalEmbedder}; + // Consolidation (sleep-inspired memory processing) pub use consolidation::SleepConsolidation; pub use consolidation::{ diff --git a/crates/vestige-core/src/storage/memory_store.rs b/crates/vestige-core/src/storage/memory_store.rs new file mode 100644 index 0000000..36c39ed --- /dev/null +++ b/crates/vestige-core/src/storage/memory_store.rs @@ -0,0 +1,319 @@ +//! Backend-agnostic memory store trait. +//! +//! This is the single abstraction every cognitive module sits above. It is +//! intentionally flat: one trait, ~25 methods, no sub-traits. + +use std::collections::HashMap; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +// ---------------------------------------------------------------------------- +// ERROR +// ---------------------------------------------------------------------------- + +/// Error returned by every `LocalMemoryStore` / `MemoryStore` method. +#[non_exhaustive] +#[derive(Debug, thiserror::Error)] +pub enum MemoryStoreError { + #[error("not found: {0}")] + NotFound(String), + + #[error("backend error: {0}")] + Backend(String), + + #[error( + "embedding model mismatch: store registered {registered_name} (dim {registered_dim}, \ + hash {registered_hash}), embedder is {actual_name} (dim {actual_dim}, hash {actual_hash})" + )] + ModelMismatch { + registered_name: String, + registered_dim: usize, + registered_hash: String, + actual_name: String, + actual_dim: usize, + actual_hash: String, + }, + + #[error("invalid input: {0}")] + InvalidInput(String), + + #[error("initialization error: {0}")] + Init(String), +} + +impl From for MemoryStoreError { + fn from(e: crate::storage::StorageError) -> Self { + use crate::storage::StorageError as S; + match e { + S::NotFound(s) => MemoryStoreError::NotFound(s), + S::Database(e) => MemoryStoreError::Backend(e.to_string()), + S::Io(e) => MemoryStoreError::Backend(e.to_string()), + S::InvalidTimestamp(s) => MemoryStoreError::Backend(format!("invalid timestamp: {s}")), + S::Init(s) => MemoryStoreError::Init(s), + } + } +} + +pub type MemoryStoreResult = std::result::Result; + +// ---------------------------------------------------------------------------- +// DATA TYPES +// ---------------------------------------------------------------------------- + +/// Backend-agnostic memory record. +/// +/// Phase 1 intentionally keeps this type independent of `KnowledgeNode` to +/// avoid dragging 30+ legacy fields through the trait surface. The SQLite +/// backend converts between `MemoryRecord` and `KnowledgeNode` at the +/// boundary. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryRecord { + pub id: Uuid, + /// Empty = unclassified. Populated in Phase 4. + pub domains: Vec, + /// Raw similarity per domain centroid. Empty until Phase 4 runs clustering. + pub domain_scores: HashMap, + pub content: String, + pub node_type: String, + pub tags: Vec, + pub embedding: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub metadata: serde_json::Value, +} + +/// FSRS-6 scheduling state, one row per memory. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SchedulingState { + pub memory_id: Uuid, + pub stability: f64, + pub difficulty: f64, + pub retrievability: f64, + pub last_review: Option>, + pub next_review: Option>, + pub reps: u32, + pub lapses: u32, +} + +/// Hybrid search request. +#[derive(Debug, Clone, Default)] +pub struct SearchQuery { + pub domains: Option>, + pub text: Option, + pub embedding: Option>, + pub tags: Option>, + pub node_types: Option>, + pub limit: usize, + pub min_retrievability: Option, +} + +#[derive(Debug, Clone)] +pub struct SearchResult { + pub record: MemoryRecord, + pub score: f64, + pub fts_score: Option, + pub vector_score: Option, +} + +/// Edge in the spreading-activation graph. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryEdge { + pub source_id: Uuid, + pub target_id: Uuid, + pub edge_type: String, + pub weight: f64, + pub created_at: DateTime, +} + +/// A topical domain (populated in Phase 4). Phase 1 only needs the type to +/// shape the trait surface; discover/classify are Phase 4 work. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Domain { + pub id: String, + pub label: String, + pub centroid: Vec, + pub top_terms: Vec, + pub memory_count: usize, + pub created_at: DateTime, +} + +/// Result of classifying one vector against all known domains. +#[derive(Debug, Clone)] +pub struct ClassificationResult { + pub scores: HashMap, + pub domains: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct StoreStats { + pub total_memories: usize, + pub memories_with_embeddings: usize, + pub total_edges: usize, + pub total_domains: usize, + pub registered_model_name: Option, + pub registered_model_dim: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum HealthStatus { + Healthy, + Degraded { reason: String }, + Unavailable { reason: String }, +} + +// ---------------------------------------------------------------------------- +// EMBEDDING MODEL SIGNATURE +// ---------------------------------------------------------------------------- + +/// Snapshot of the embedding model that was used to write vectors into the +/// store. Persisted in the `embedding_model` table; compared on every write +/// before the vector is accepted. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ModelSignature { + pub name: String, + pub dimension: usize, + /// Lowercase hex-encoded blake3 hash, 64 chars. + pub hash: String, +} + +// ---------------------------------------------------------------------------- +// TRAIT +// ---------------------------------------------------------------------------- + +/// The single storage abstraction. +/// +/// `#[async_trait::async_trait]` makes every `async fn` return a +/// `Pin>`, which is required for `Arc` +/// to be movable across `tokio::spawn` boundaries. +/// +/// `LocalMemoryStore` is a type alias kept for source compatibility with code +/// that refers to the non-send variant. In Phase 1 both names refer to the same +/// (dyn-compatible, Send-safe) trait. +#[async_trait::async_trait] +pub trait MemoryStore: Send + Sync + 'static { + // --- Lifecycle --- + async fn init(&self) -> MemoryStoreResult<()>; + async fn health_check(&self) -> MemoryStoreResult; + + // --- Embedding model registry --- + async fn registered_model(&self) -> MemoryStoreResult>; + async fn register_model(&self, sig: &ModelSignature) -> MemoryStoreResult<()>; + + // --- CRUD --- + async fn insert(&self, record: &MemoryRecord) -> MemoryStoreResult; + async fn get(&self, id: Uuid) -> MemoryStoreResult>; + async fn update(&self, record: &MemoryRecord) -> MemoryStoreResult<()>; + async fn delete(&self, id: Uuid) -> MemoryStoreResult<()>; + + // --- Search --- + async fn search(&self, query: &SearchQuery) -> MemoryStoreResult>; + async fn fts_search(&self, text: &str, limit: usize) -> MemoryStoreResult>; + async fn vector_search( + &self, + embedding: &[f32], + limit: usize, + ) -> MemoryStoreResult>; + + // --- FSRS Scheduling --- + async fn get_scheduling( + &self, + memory_id: Uuid, + ) -> MemoryStoreResult>; + async fn update_scheduling(&self, state: &SchedulingState) -> MemoryStoreResult<()>; + async fn get_due_memories( + &self, + before: DateTime, + limit: usize, + ) -> MemoryStoreResult>; + + // --- Graph (spreading activation) --- + async fn add_edge(&self, edge: &MemoryEdge) -> MemoryStoreResult<()>; + async fn get_edges( + &self, + node_id: Uuid, + edge_type: Option<&str>, + ) -> MemoryStoreResult>; + async fn remove_edge(&self, source: Uuid, target: Uuid) -> MemoryStoreResult<()>; + async fn get_neighbors( + &self, + node_id: Uuid, + depth: usize, + ) -> MemoryStoreResult>; + + // --- Domains (Phase 1: stubs return empty; full impl in Phase 4) --- + async fn list_domains(&self) -> MemoryStoreResult>; + async fn get_domain(&self, id: &str) -> MemoryStoreResult>; + async fn upsert_domain(&self, domain: &Domain) -> MemoryStoreResult<()>; + async fn delete_domain(&self, id: &str) -> MemoryStoreResult<()>; + /// Phase 1: returns `Ok(vec![])` since no centroids exist. Phase 4 wires + /// the full soft-assignment pass. + async fn classify(&self, embedding: &[f32]) -> MemoryStoreResult>; + + // --- Bulk / Maintenance --- + async fn count(&self) -> MemoryStoreResult; + async fn get_stats(&self) -> MemoryStoreResult; + async fn vacuum(&self) -> MemoryStoreResult<()>; +} + +/// Type alias kept for source compatibility. Both names refer to the same +/// `async_trait`-annotated trait that is dyn-compatible and `Send + Sync`. +pub use MemoryStore as LocalMemoryStore; + +// ---------------------------------------------------------------------------- +// UNIT TESTS +// ---------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::StorageError; + + #[test] + fn memory_store_error_from_storage_error() { + let se = StorageError::NotFound("abc".to_string()); + let mse = MemoryStoreError::from(se); + assert!(matches!(mse, MemoryStoreError::NotFound(_))); + + let se2 = StorageError::Init("init failure".to_string()); + let mse2 = MemoryStoreError::from(se2); + assert!(matches!(mse2, MemoryStoreError::Init(_))); + } + + #[test] + fn model_signature_serde_round_trip() { + let sig = ModelSignature { + name: "nomic-ai/nomic-embed-text-v1.5".to_string(), + dimension: 256, + hash: "a".repeat(64), + }; + let json = serde_json::to_string(&sig).expect("serialize"); + let sig2: ModelSignature = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(sig, sig2); + } + + #[test] + fn memory_record_serde_round_trip() { + let rec = MemoryRecord { + id: Uuid::new_v4(), + domains: vec!["dev".to_string()], + domain_scores: { + let mut m = HashMap::new(); + m.insert("dev".to_string(), 0.9); + m + }, + content: "hello".to_string(), + node_type: "fact".to_string(), + tags: vec!["tag1".to_string()], + embedding: None, + created_at: Utc::now(), + updated_at: Utc::now(), + metadata: serde_json::json!({}), + }; + let json = serde_json::to_string(&rec).expect("serialize"); + let rec2: MemoryRecord = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(rec.content, rec2.content); + assert_eq!(rec.domains, rec2.domains); + } +} diff --git a/crates/vestige-core/src/storage/migrations.rs b/crates/vestige-core/src/storage/migrations.rs index 2c66a2d..ecad624 100644 --- a/crates/vestige-core/src/storage/migrations.rs +++ b/crates/vestige-core/src/storage/migrations.rs @@ -69,6 +69,11 @@ pub const MIGRATIONS: &[Migration] = &[ description: "v2.1.2 Honest Memory: non-content purge tombstones", up: MIGRATION_V13_UP, }, + Migration { + version: 14, + description: "ADR 0001 Phase 1: embedding_model registry, domains/domain_scores columns, domains table", + up: MIGRATION_V14_UP, + }, ]; /// A database migration @@ -745,6 +750,54 @@ pub fn get_current_version(conn: &rusqlite::Connection) -> rusqlite::Result .or(Ok(0)) } +/// V14: ADR 0001 Phase 1 - embedding_model registry + domain columns. +/// +/// The ALTER TABLE statements are split out into `MIGRATION_V14_ALTER_COLUMNS` +/// because SQLite has no `ALTER TABLE ... ADD COLUMN IF NOT EXISTS`. The +/// migration runner handles them individually so replaying V14 is idempotent. +const MIGRATION_V14_UP: &str = r#" +-- Migration V14: embedding model registry + per-memory domain columns. + +-- 1. Embedding model registry. Single logical row; the (id = 1) constraint is +-- enforced in code via `register_model` (SQLite CHECK on a single-row +-- table is uglier than a constraint we already enforce in Rust). +CREATE TABLE IF NOT EXISTS embedding_model ( + id INTEGER PRIMARY KEY CHECK (id = 1), + name TEXT NOT NULL, + dimension INTEGER NOT NULL, + hash TEXT NOT NULL, + created_at TEXT NOT NULL +); + +-- 2. Per-memory domain columns are applied separately (see apply_migrations). + +-- 3. Index on the domains JSON column to enable LIKE-style filter in Phase 4. +CREATE INDEX IF NOT EXISTS idx_nodes_domains ON knowledge_nodes(domains); +CREATE INDEX IF NOT EXISTS idx_nodes_domain_scores ON knowledge_nodes(domain_scores); + +-- 4. Domains catalogue (empty until Phase 4 populates). +CREATE TABLE IF NOT EXISTS domains ( + id TEXT PRIMARY KEY, + label TEXT NOT NULL, + centroid BLOB, + top_terms TEXT NOT NULL DEFAULT '[]', + memory_count INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_domains_created_at ON domains(created_at); + +UPDATE schema_version SET version = 14, applied_at = datetime('now'); +"#; + +/// The two ALTER TABLE statements for V14. Kept separate so the migration +/// runner can try each individually and ignore "duplicate column" errors, +/// making V14 idempotent on replay (SQLite has no ADD COLUMN IF NOT EXISTS). +pub const MIGRATION_V14_ALTER_COLUMNS: &[&str] = &[ + "ALTER TABLE knowledge_nodes ADD COLUMN domains TEXT NOT NULL DEFAULT '[]'", + "ALTER TABLE knowledge_nodes ADD COLUMN domain_scores TEXT NOT NULL DEFAULT '{}'", +]; + /// Apply pending migrations pub fn apply_migrations(conn: &rusqlite::Connection) -> rusqlite::Result { let current_version = get_current_version(conn)?; @@ -758,6 +811,26 @@ pub fn apply_migrations(conn: &rusqlite::Connection) -> rusqlite::Result { migration.description ); + // V14 adds columns via ALTER TABLE, which SQLite does not support + // with IF NOT EXISTS. Run them individually and ignore + // "duplicate column name" errors so the migration is idempotent + // on replay (e.g. after a schema_version rollback in tests). + if migration.version == 14 { + for stmt in MIGRATION_V14_ALTER_COLUMNS { + if let Err(e) = conn.execute_batch(stmt) { + let msg = e.to_string(); + if msg.contains("duplicate column name") { + tracing::debug!( + "V14 ALTER TABLE skipped (column already exists): {}", + msg + ); + } else { + return Err(e); + } + } + } + } + // Use execute_batch to handle multi-statement SQL including triggers conn.execute_batch(migration.up)?; @@ -790,11 +863,11 @@ mod tests { // Pre-requisite: schema_version must be bootstrapped by V1. apply_migrations(&conn).expect("apply_migrations succeeds"); - // 1. schema_version advanced to V13 + // 1. schema_version advanced to V14 (latest after Phase 1) let version = get_current_version(&conn).expect("read schema_version"); assert_eq!( - version, 13, - "schema_version must be 13 after all migrations" + version, 14, + "schema_version must be 14 after all migrations" ); // 2. knowledge_edges is gone (V11 drops it) @@ -865,10 +938,118 @@ mod tests { conn.execute("UPDATE schema_version SET version = 10", []) .expect("rewind schema_version"); - // Replay must not error. - apply_migrations(&conn).expect("V11 replay must be idempotent"); + // Replay V11 onward. V11 uses DROP TABLE IF EXISTS so it is idempotent. + // V12/V13 tombstone tables use CREATE TABLE IF NOT EXISTS. V14 ALTER + // TABLE idempotency is handled by the migration runner (see + // apply_migrations). + apply_migrations(&conn).expect("V11..V14 replay must be idempotent"); + // After replaying from V10, the schema advances to the latest version (V14). let version = get_current_version(&conn).expect("read schema_version"); - assert_eq!(version, 13, "schema_version back at 13 after replay"); + assert_eq!(version, 14, "schema_version back at 14 after replay"); + } + + #[test] + fn v14_adds_embedding_model_table() { + let conn = rusqlite::Connection::open_in_memory().expect("open in-memory"); + apply_migrations(&conn).expect("apply_migrations"); + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_model'", + [], + |row| row.get(0), + ) + .expect("query sqlite_master"); + assert_eq!(count, 1, "embedding_model table must exist after V14"); + } + + #[test] + fn v14_adds_domains_columns() { + let conn = rusqlite::Connection::open_in_memory().expect("open in-memory"); + apply_migrations(&conn).expect("apply_migrations"); + let info: Vec = { + let mut stmt = conn + .prepare("PRAGMA table_info(knowledge_nodes)") + .expect("prepare"); + stmt.query_map([], |row| row.get::<_, String>(1)) + .expect("query_map") + .map(|r| r.expect("row")) + .collect() + }; + assert!(info.contains(&"domains".to_string()), "domains column missing"); + assert!(info.contains(&"domain_scores".to_string()), "domain_scores column missing"); + } + + #[test] + fn v14_default_values_empty_json() { + let conn = rusqlite::Connection::open_in_memory().expect("open in-memory"); + apply_migrations(&conn).expect("apply_migrations"); + // Insert a minimal row to test defaults + conn.execute( + "INSERT INTO knowledge_nodes (id, content, node_type, created_at, updated_at, last_accessed, \ + stability, difficulty, reps, lapses, learning_state, storage_strength, retrieval_strength, \ + retention_strength, next_review, scheduled_days, has_embedding) \ + VALUES ('test-id','content','fact',datetime('now'),datetime('now'),datetime('now'),\ + 1.0,0.3,0,0,'new',1.0,1.0,1.0,datetime('now'),1,0)", + [], + ).expect("insert row"); + let (domains, domain_scores): (String, String) = conn + .query_row( + "SELECT domains, domain_scores FROM knowledge_nodes WHERE id='test-id'", + [], + |row| Ok((row.get(0)?, row.get(1)?)), + ) + .expect("query row"); + assert_eq!(domains, "[]"); + assert_eq!(domain_scores, "{}"); + } + + #[test] + fn v14_is_replayable() { + let conn = rusqlite::Connection::open_in_memory().expect("open in-memory"); + apply_migrations(&conn).expect("first apply"); + // Rewind to V13 so V14 runs again + conn.execute("UPDATE schema_version SET version = 13", []) + .expect("rewind"); + // V14 uses CREATE TABLE IF NOT EXISTS -- replay must not error + apply_migrations(&conn).expect("V14 replay must be idempotent"); + let version = get_current_version(&conn).expect("read version"); + assert_eq!(version, 14, "schema_version must be 14 after replay"); + } + + #[test] + fn v14_preserves_existing_rows() { + let conn = rusqlite::Connection::open_in_memory().expect("open in-memory"); + // Apply up to V13 only + for migration in MIGRATIONS { + if migration.version <= 13 { + conn.execute_batch(migration.up).expect("apply migration"); + } + } + // Insert a row under V13 schema + conn.execute( + "INSERT INTO knowledge_nodes (id, content, node_type, created_at, updated_at, last_accessed, \ + stability, difficulty, reps, lapses, learning_state, storage_strength, retrieval_strength, \ + retention_strength, next_review, scheduled_days, has_embedding) \ + VALUES ('existing-id','old content','fact',datetime('now'),datetime('now'),datetime('now'),\ + 1.0,0.3,0,0,'new',1.0,1.0,1.0,datetime('now'),1,0)", + [], + ).expect("insert pre-v14 row"); + // Apply V14: run the ALTER TABLE statements first (they are kept separate + // from MIGRATION_V14_UP because SQLite has no ADD COLUMN IF NOT EXISTS). + for stmt in MIGRATION_V14_ALTER_COLUMNS { + conn.execute_batch(stmt).expect("apply V14 ALTER TABLE"); + } + conn.execute_batch(MIGRATION_V14_UP).expect("apply V14 main"); + // Check the old row has defaults + let (domains, domain_scores): (String, String) = conn + .query_row( + "SELECT domains, domain_scores FROM knowledge_nodes WHERE id='existing-id'", + [], + |row| Ok((row.get(0)?, row.get(1)?)), + ) + .expect("query pre-v14 row"); + assert_eq!(domains, "[]"); + assert_eq!(domain_scores, "{}"); } } diff --git a/crates/vestige-core/src/storage/mod.rs b/crates/vestige-core/src/storage/mod.rs index 1660529..6c3ad06 100644 --- a/crates/vestige-core/src/storage/mod.rs +++ b/crates/vestige-core/src/storage/mod.rs @@ -1,15 +1,17 @@ //! Storage Module //! -//! SQLite-based storage layer with: -//! - FTS5 full-text search with query sanitization -//! - Embedded vector storage -//! - FSRS-6 state management -//! - Temporal memory support +//! Backend-agnostic memory store abstraction plus SQLite reference impl. +mod memory_store; mod migrations; mod portable; mod sqlite; +pub use memory_store::{ + ClassificationResult, Domain, HealthStatus, LocalMemoryStore, MemoryEdge, MemoryRecord, + MemoryStore, MemoryStoreError, MemoryStoreResult, ModelSignature, SchedulingState, + SearchQuery, SearchResult, StoreStats, +}; pub use migrations::MIGRATIONS; pub use portable::{ PORTABLE_ARCHIVE_FORMAT, PortableArchive, PortableImportMode, PortableImportReport, @@ -18,5 +20,10 @@ pub use portable::{ pub use sqlite::{ ConnectionRecord, ConsolidationHistoryRecord, DreamHistoryRecord, FilePortableSyncBackend, InsightRecord, IntentionRecord, PortableSyncBackend, PortableSyncReport, Result, - SmartIngestResult, StateTransitionRecord, Storage, StorageError, + SmartIngestResult, SqliteMemoryStore, StateTransitionRecord, StorageError, }; + +/// Backwards-compatibility alias. Retained until Phase 4 completes so every +/// existing `Arc` call site keeps compiling. Scheduled for removal +/// once no downstream source file references it. +pub type Storage = SqliteMemoryStore; diff --git a/crates/vestige-core/src/storage/sqlite.rs b/crates/vestige-core/src/storage/sqlite.rs index dc0150c..dea3677 100644 --- a/crates/vestige-core/src/storage/sqlite.rs +++ b/crates/vestige-core/src/storage/sqlite.rs @@ -284,7 +284,7 @@ const DATABASE_FILE: &str = "vestige.db"; /// Uses separate reader/writer connections for interior mutability. /// All methods take `&self` (not `&mut self`), making Storage `Send + Sync` /// so the MCP layer can use `Arc` instead of `Arc>`. -pub struct Storage { +pub struct SqliteMemoryStore { db_path: PathBuf, writer: Mutex, reader: Mutex, @@ -296,9 +296,11 @@ pub struct Storage { /// LRU cache for query embeddings to avoid re-embedding repeated queries #[cfg(all(feature = "embeddings", feature = "vector-search"))] query_cache: Mutex>>, + /// Cached model signature. `None` until the first embedding is written. + registered_model: std::sync::RwLock>, } -impl Storage { +impl SqliteMemoryStore { fn data_dir_from_env() -> Option { std::env::var_os(DATA_DIR_ENV).and_then(|value| { if value.is_empty() { @@ -443,6 +445,7 @@ impl Storage { vector_index: Mutex::new(vector_index), #[cfg(all(feature = "embeddings", feature = "vector-search"))] query_cache, + registered_model: std::sync::RwLock::new(None), }; #[cfg(all(feature = "embeddings", feature = "vector-search"))] @@ -580,13 +583,15 @@ impl Storage { stability, difficulty, reps, lapses, learning_state, storage_strength, retrieval_strength, retention_strength, sentiment_score, sentiment_magnitude, next_review, scheduled_days, - source, tags, valid_from, valid_until, has_embedding, embedding_model + source, tags, valid_from, valid_until, has_embedding, embedding_model, + domains, domain_scores ) VALUES ( ?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, - ?19, ?20, ?21, ?22, ?23, ?24 + ?19, ?20, ?21, ?22, ?23, ?24, + '[]', '{}' )", params![ id, @@ -3860,7 +3865,7 @@ pub struct DreamHistoryRecord { pub creative_connections_found: Option, } -impl Storage { +impl SqliteMemoryStore { // ======================================================================== // INTENTIONS PERSISTENCE // ======================================================================== @@ -6127,6 +6132,979 @@ impl Storage { } } +// ============================================================================ +// LOCAL MEMORY STORE TRAIT IMPL +// ============================================================================ + +impl SqliteMemoryStore { + /// Convert a `KnowledgeNode` (plus optional embedding vector read separately) + /// into a `MemoryRecord` for the trait surface. + fn node_to_record(node: KnowledgeNode, embedding: Option>) -> crate::storage::memory_store::MemoryRecord { + use crate::storage::memory_store::MemoryRecord; + let id = uuid::Uuid::parse_str(&node.id).unwrap_or_else(|_| uuid::Uuid::new_v4()); + MemoryRecord { + id, + domains: Vec::new(), + domain_scores: std::collections::HashMap::new(), + content: node.content, + node_type: node.node_type, + tags: node.tags, + embedding, + created_at: node.created_at, + updated_at: node.updated_at, + metadata: serde_json::json!({ + "source": node.source, + "stability": node.stability, + "difficulty": node.difficulty, + "reps": node.reps, + "lapses": node.lapses, + "retention_strength": node.retention_strength, + }), + } + } + + /// Read domains and domain_scores JSON columns for a node by id. + fn read_domain_columns( + &self, + id: &str, + ) -> (Vec, std::collections::HashMap) { + let reader = match self.reader.lock() { + Ok(r) => r, + Err(_) => return (Vec::new(), std::collections::HashMap::new()), + }; + let result = reader.query_row( + "SELECT domains, domain_scores FROM knowledge_nodes WHERE id = ?1", + rusqlite::params![id], + |row| { + let d: Option = row.get(0).ok().flatten(); + let ds: Option = row.get(1).ok().flatten(); + Ok((d, ds)) + }, + ); + match result { + Ok((d, ds)) => { + let domains: Vec = d + .and_then(|s| serde_json::from_str(&s).ok()) + .unwrap_or_default(); + let domain_scores: std::collections::HashMap = ds + .and_then(|s| serde_json::from_str(&s).ok()) + .unwrap_or_default(); + (domains, domain_scores) + } + Err(_) => (Vec::new(), std::collections::HashMap::new()), + } + } + + /// Enforce the registered embedding model. Returns `Ok(())` if: + /// - no vector is being written (`incoming.is_none()`) and nothing is registered + /// - the incoming signature matches the registered signature + /// + /// Auto-registers on the first embedded write. + fn enforce_model( + &self, + incoming: Option<&crate::storage::memory_store::ModelSignature>, + ) -> crate::storage::memory_store::MemoryStoreResult<()> { + use crate::storage::memory_store::{MemoryStoreError, ModelSignature}; + let Some(incoming) = incoming else { + return Ok(()); + }; + // Try from cache first + { + let guard = self.registered_model.read().map_err(|_| { + MemoryStoreError::Init("registered_model rwlock poisoned".into()) + })?; + if let Some(ref reg) = *guard { + if reg == incoming { + return Ok(()); + } + return Err(MemoryStoreError::ModelMismatch { + registered_name: reg.name.clone(), + registered_dim: reg.dimension, + registered_hash: reg.hash.clone(), + actual_name: incoming.name.clone(), + actual_dim: incoming.dimension, + actual_hash: incoming.hash.clone(), + }); + } + } + // Not registered yet -- auto-register + let now = Utc::now().to_rfc3339(); + let writer = self + .writer + .lock() + .map_err(|_| MemoryStoreError::Init("Writer lock poisoned".into()))?; + // Try INSERT OR IGNORE + writer.execute( + "INSERT OR IGNORE INTO embedding_model (id, name, dimension, hash, created_at) VALUES (1, ?1, ?2, ?3, ?4)", + rusqlite::params![incoming.name, incoming.dimension as i64, incoming.hash, now], + ).map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + // Read back what was stored + let stored: Option = writer.query_row( + "SELECT name, dimension, hash FROM embedding_model WHERE id = 1", + [], + |row| { + let name: String = row.get(0)?; + let dim: i64 = row.get(1)?; + let hash: String = row.get(2)?; + Ok(ModelSignature { name, dimension: dim as usize, hash }) + }, + ).optional().map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + drop(writer); + if let Some(stored) = stored { + if stored != *incoming { + return Err(MemoryStoreError::ModelMismatch { + registered_name: stored.name, + registered_dim: stored.dimension, + registered_hash: stored.hash, + actual_name: incoming.name.clone(), + actual_dim: incoming.dimension, + actual_hash: incoming.hash.clone(), + }); + } + // Populate cache + let mut guard = self.registered_model.write().map_err(|_| { + MemoryStoreError::Init("registered_model rwlock poisoned".into()) + })?; + *guard = Some(stored); + } + Ok(()) + } +} + +#[async_trait::async_trait] +impl crate::storage::memory_store::LocalMemoryStore for SqliteMemoryStore { + async fn init(&self) -> crate::storage::memory_store::MemoryStoreResult<()> { + // Migrations run in `new`; this is a no-op for the SQLite backend. + Ok(()) + } + + async fn health_check(&self) -> crate::storage::memory_store::MemoryStoreResult { + use crate::storage::memory_store::HealthStatus; + let reader = self.reader.lock().map_err(|_| { + crate::storage::memory_store::MemoryStoreError::Init("Reader lock poisoned".into()) + })?; + let ok: rusqlite::Result = reader.query_row("SELECT 1", [], |row| row.get(0)); + if ok.is_ok() { + Ok(HealthStatus::Healthy) + } else { + Ok(HealthStatus::Degraded { + reason: "SQLite connectivity check failed".to_string(), + }) + } + } + + async fn registered_model( + &self, + ) -> crate::storage::memory_store::MemoryStoreResult> { + use crate::storage::memory_store::MemoryStoreError; + // Check cache first + { + let guard = self.registered_model.read().map_err(|_| { + MemoryStoreError::Init("registered_model rwlock poisoned".into()) + })?; + if guard.is_some() { + return Ok(guard.clone()); + } + } + // Fall through to DB read + let reader = self.reader.lock().map_err(|_| { + MemoryStoreError::Init("Reader lock poisoned".into()) + })?; + let stored: Option = reader + .query_row( + "SELECT name, dimension, hash FROM embedding_model WHERE id = 1", + [], + |row| { + let name: String = row.get(0)?; + let dim: i64 = row.get(1)?; + let hash: String = row.get(2)?; + Ok(crate::storage::memory_store::ModelSignature { + name, + dimension: dim as usize, + hash, + }) + }, + ) + .optional() + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + drop(reader); + // Populate cache if we read something + if stored.is_some() { + let mut guard = self.registered_model.write().map_err(|_| { + MemoryStoreError::Init("registered_model rwlock poisoned".into()) + })?; + *guard = stored.clone(); + } + Ok(stored) + } + + async fn register_model( + &self, + sig: &crate::storage::memory_store::ModelSignature, + ) -> crate::storage::memory_store::MemoryStoreResult<()> { + self.enforce_model(Some(sig)) + } + + async fn insert( + &self, + record: &crate::storage::memory_store::MemoryRecord, + ) -> crate::storage::memory_store::MemoryStoreResult { + use crate::storage::memory_store::{MemoryStoreError, ModelSignature}; + // Enforce model registry if embedding is provided + if let Some(vec) = &record.embedding { + // Derive a signature from metadata if present, or use a generic sentinel + let sig: Option = record + .metadata + .get("model_name") + .and_then(|v| v.as_str()) + .zip( + record + .metadata + .get("model_dim") + .and_then(|v| v.as_u64()) + .map(|d| d as usize), + ) + .zip( + record + .metadata + .get("model_hash") + .and_then(|v| v.as_str()), + ) + .map(|((name, dim), hash)| ModelSignature { + name: name.to_string(), + dimension: dim, + hash: hash.to_string(), + }); + if let Some(ref s) = sig { + self.enforce_model(Some(s))?; + if vec.len() != s.dimension { + return Err(MemoryStoreError::InvalidInput(format!( + "embedding length {} != registered dimension {}", + vec.len(), + s.dimension + ))); + } + } + } + // Insert directly using the record's own id so the caller-supplied UUID is + // preserved (unlike ingest() which always generates a fresh UUID). + let id_str = record.id.to_string(); + let now = chrono::Utc::now(); + let tags_json = serde_json::to_string(&record.tags) + .unwrap_or_else(|_| "[]".to_string()); + let domains_json = serde_json::to_string(&record.domains) + .unwrap_or_else(|_| "[]".to_string()); + let scores_json = serde_json::to_string(&record.domain_scores) + .unwrap_or_else(|_| "{}".to_string()); + let source: Option = record + .metadata + .get("source") + .and_then(|v| v.as_str()) + .map(str::to_string); + { + let writer = self + .writer + .lock() + .map_err(|_| MemoryStoreError::Init("Writer lock poisoned".into()))?; + writer.execute( + "INSERT INTO knowledge_nodes ( + id, content, node_type, created_at, updated_at, last_accessed, + stability, difficulty, reps, lapses, learning_state, + storage_strength, retrieval_strength, retention_strength, + sentiment_score, sentiment_magnitude, next_review, scheduled_days, + source, tags, has_embedding, embedding_model, + domains, domain_scores + ) VALUES ( + ?1, ?2, ?3, ?4, ?5, ?6, + 1.0, 0.3, 0, 0, 'new', + 1.0, 1.0, 1.0, + 0.0, 0.0, ?7, 1, + ?8, ?9, 0, NULL, + ?10, ?11 + )", + rusqlite::params![ + id_str, + record.content, + record.node_type, + record.created_at.to_rfc3339(), + record.updated_at.to_rfc3339(), + now.to_rfc3339(), + (now + chrono::Duration::days(1)).to_rfc3339(), + source, + tags_json, + domains_json, + scores_json, + ], + ).map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + } + Ok(record.id) + } + + async fn get( + &self, + id: uuid::Uuid, + ) -> crate::storage::memory_store::MemoryStoreResult> { + use crate::storage::memory_store::MemoryStoreError; + let node = self + .get_node(&id.to_string()) + .map_err(MemoryStoreError::from)?; + let Some(node) = node else { + return Ok(None); + }; + let (domains, domain_scores) = self.read_domain_columns(&id.to_string()); + #[cfg(all(feature = "embeddings", feature = "vector-search"))] + let embedding = self + .get_node_embedding(&id.to_string()) + .ok() + .flatten(); + #[cfg(not(all(feature = "embeddings", feature = "vector-search")))] + let embedding: Option> = None; + let mut rec = Self::node_to_record(node, embedding); + rec.domains = domains; + rec.domain_scores = domain_scores; + Ok(Some(rec)) + } + + async fn update( + &self, + record: &crate::storage::memory_store::MemoryRecord, + ) -> crate::storage::memory_store::MemoryStoreResult<()> { + use crate::storage::memory_store::MemoryStoreError; + self.update_node_content(&record.id.to_string(), &record.content) + .map_err(MemoryStoreError::from)?; + // Update domains/domain_scores + let domains_json = serde_json::to_string(&record.domains) + .unwrap_or_else(|_| "[]".to_string()); + let scores_json = serde_json::to_string(&record.domain_scores) + .unwrap_or_else(|_| "{}".to_string()); + let writer = self + .writer + .lock() + .map_err(|_| MemoryStoreError::Init("Writer lock poisoned".into()))?; + writer + .execute( + "UPDATE knowledge_nodes SET domains = ?1, domain_scores = ?2 WHERE id = ?3", + rusqlite::params![domains_json, scores_json, record.id.to_string()], + ) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + Ok(()) + } + + async fn delete( + &self, + id: uuid::Uuid, + ) -> crate::storage::memory_store::MemoryStoreResult<()> { + use crate::storage::memory_store::MemoryStoreError; + self.delete_node(&id.to_string()) + .map_err(MemoryStoreError::from)?; + Ok(()) + } + + async fn search( + &self, + query: &crate::storage::memory_store::SearchQuery, + ) -> crate::storage::memory_store::MemoryStoreResult> { + use crate::storage::memory_store::{MemoryStoreError, SearchResult}; + // For Phase 1 we delegate to hybrid_search or keyword_search based on what is provided. + let limit = if query.limit == 0 { 10 } else { query.limit }; + #[cfg(all(feature = "embeddings", feature = "vector-search"))] + { + if let Some(ref text) = query.text { + let results = self + .hybrid_search(text, limit as i32, 0.3, 0.7) + .map_err(MemoryStoreError::from)?; + let out = results + .into_iter() + .map(|r| { + let (domains, domain_scores) = + self.read_domain_columns(&r.node.id); + let mut rec = Self::node_to_record(r.node, None); + rec.domains = domains; + rec.domain_scores = domain_scores; + SearchResult { + score: r.combined_score as f64, + fts_score: r.keyword_score.map(|s| s as f64), + vector_score: r.semantic_score.map(|s| s as f64), + record: rec, + } + }) + .collect(); + return Ok(out); + } + } + #[cfg(not(all(feature = "embeddings", feature = "vector-search")))] + { + if let Some(ref text) = query.text { + // Use individual-term matching so multi-word queries find documents + // where all words appear anywhere (not necessarily as a phrase). + let nodes = self + .search_terms(text, limit as i32) + .map_err(MemoryStoreError::from)?; + let out = nodes + .into_iter() + .map(|node| { + let (domains, domain_scores) = self.read_domain_columns(&node.id); + let mut rec = Self::node_to_record(node, None); + rec.domains = domains; + rec.domain_scores = domain_scores; + SearchResult { + record: rec, + score: 1.0, + fts_score: Some(1.0), + vector_score: None, + } + }) + .collect(); + return Ok(out); + } + } + Ok(vec![]) + } + + async fn fts_search( + &self, + text: &str, + limit: usize, + ) -> crate::storage::memory_store::MemoryStoreResult> { + use crate::storage::memory_store::{MemoryStoreError, SearchResult}; + // Use individual-term matching so multi-word queries find documents + // where all words appear anywhere (not necessarily as a phrase). + let nodes = self + .search_terms(text, limit as i32) + .map_err(MemoryStoreError::from)?; + let out = nodes + .into_iter() + .map(|node| { + let (domains, domain_scores) = self.read_domain_columns(&node.id); + let mut rec = Self::node_to_record(node, None); + rec.domains = domains; + rec.domain_scores = domain_scores; + SearchResult { + record: rec, + score: 1.0, + fts_score: Some(1.0), + vector_score: None, + } + }) + .collect(); + Ok(out) + } + + async fn vector_search( + &self, + embedding: &[f32], + limit: usize, + ) -> crate::storage::memory_store::MemoryStoreResult> { + use crate::storage::memory_store::{MemoryStoreError, SearchResult}; + #[cfg(all(feature = "embeddings", feature = "vector-search"))] + { + let index = self + .vector_index + .lock() + .map_err(|_| MemoryStoreError::Init("Vector index lock poisoned".into()))?; + let raw_results = index + .search_with_threshold(embedding, limit, 0.0_f32) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + drop(index); + let out = raw_results + .into_iter() + .filter_map(|(node_id, score)| { + let node = self.get_node(&node_id).ok().flatten()?; + let (domains, domain_scores) = self.read_domain_columns(&node_id); + let mut rec = Self::node_to_record(node, None); + rec.domains = domains; + rec.domain_scores = domain_scores; + Some(SearchResult { + record: rec, + score: score as f64, + fts_score: None, + vector_score: Some(score as f64), + }) + }) + .collect(); + return Ok(out); + } + #[cfg(not(all(feature = "embeddings", feature = "vector-search")))] + { + let _ = (embedding, limit); + Ok(vec![]) + } + } + + async fn get_scheduling( + &self, + memory_id: uuid::Uuid, + ) -> crate::storage::memory_store::MemoryStoreResult> { + use crate::storage::memory_store::{MemoryStoreError, SchedulingState}; + let node = self + .get_node(&memory_id.to_string()) + .map_err(MemoryStoreError::from)?; + let Some(node) = node else { + return Ok(None); + }; + Ok(Some(SchedulingState { + memory_id, + stability: node.stability, + difficulty: node.difficulty, + retrievability: node.retention_strength, + last_review: Some(node.last_accessed), + next_review: node.next_review, + reps: node.reps as u32, + lapses: node.lapses as u32, + })) + } + + async fn update_scheduling( + &self, + state: &crate::storage::memory_store::SchedulingState, + ) -> crate::storage::memory_store::MemoryStoreResult<()> { + use crate::storage::memory_store::MemoryStoreError; + let writer = self + .writer + .lock() + .map_err(|_| MemoryStoreError::Init("Writer lock poisoned".into()))?; + let next_review_str = state.next_review.map(|dt| dt.to_rfc3339()); + let last_review_str = state.last_review.map(|dt| dt.to_rfc3339()); + writer + .execute( + "UPDATE knowledge_nodes SET stability=?1, difficulty=?2, retention_strength=?3, + last_accessed=?4, next_review=?5, reps=?6, lapses=?7 + WHERE id=?8", + rusqlite::params![ + state.stability, + state.difficulty, + state.retrievability, + last_review_str.as_deref().unwrap_or(""), + next_review_str, + state.reps as i64, + state.lapses as i64, + state.memory_id.to_string(), + ], + ) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + Ok(()) + } + + async fn get_due_memories( + &self, + before: chrono::DateTime, + limit: usize, + ) -> crate::storage::memory_store::MemoryStoreResult< + Vec<( + crate::storage::memory_store::MemoryRecord, + crate::storage::memory_store::SchedulingState, + )>, + > { + use crate::storage::memory_store::{MemoryStoreError, SchedulingState}; + let reader = self + .reader + .lock() + .map_err(|_| MemoryStoreError::Init("Reader lock poisoned".into()))?; + let before_str = before.to_rfc3339(); + let mut stmt = reader + .prepare( + "SELECT * FROM knowledge_nodes WHERE next_review <= ?1 ORDER BY next_review ASC LIMIT ?2", + ) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + let nodes: Vec = stmt + .query_map(rusqlite::params![before_str, limit as i64], Self::row_to_node) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))? + .collect::, _>>() + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + drop(stmt); + drop(reader); + let out = nodes + .into_iter() + .map(|node| { + let id_str = node.id.clone(); + let (domains, domain_scores) = self.read_domain_columns(&id_str); + let id_uuid = uuid::Uuid::parse_str(&id_str).unwrap_or_else(|_| uuid::Uuid::new_v4()); + let state = SchedulingState { + memory_id: id_uuid, + stability: node.stability, + difficulty: node.difficulty, + retrievability: node.retention_strength, + last_review: Some(node.last_accessed), + next_review: node.next_review, + reps: node.reps as u32, + lapses: node.lapses as u32, + }; + let mut rec = Self::node_to_record(node, None); + rec.domains = domains; + rec.domain_scores = domain_scores; + (rec, state) + }) + .collect(); + Ok(out) + } + + async fn add_edge( + &self, + edge: &crate::storage::memory_store::MemoryEdge, + ) -> crate::storage::memory_store::MemoryStoreResult<()> { + use crate::storage::memory_store::MemoryStoreError; + let conn = ConnectionRecord { + source_id: edge.source_id.to_string(), + target_id: edge.target_id.to_string(), + strength: edge.weight, + link_type: edge.edge_type.clone(), + created_at: edge.created_at, + last_activated: edge.created_at, + activation_count: 0, + }; + self.save_connection(&conn).map_err(MemoryStoreError::from) + } + + async fn get_edges( + &self, + node_id: uuid::Uuid, + edge_type: Option<&str>, + ) -> crate::storage::memory_store::MemoryStoreResult> { + use crate::storage::memory_store::{MemoryEdge, MemoryStoreError}; + let conns = self + .get_connections_for_memory(&node_id.to_string()) + .map_err(MemoryStoreError::from)?; + let edges = conns + .into_iter() + .filter(|c| edge_type.is_none_or(|t| c.link_type == t)) + .filter_map(|c| { + let src = uuid::Uuid::parse_str(&c.source_id).ok()?; + let tgt = uuid::Uuid::parse_str(&c.target_id).ok()?; + Some(MemoryEdge { + source_id: src, + target_id: tgt, + edge_type: c.link_type, + weight: c.strength, + created_at: c.created_at, + }) + }) + .collect(); + Ok(edges) + } + + async fn remove_edge( + &self, + source: uuid::Uuid, + target: uuid::Uuid, + ) -> crate::storage::memory_store::MemoryStoreResult<()> { + use crate::storage::memory_store::MemoryStoreError; + let writer = self + .writer + .lock() + .map_err(|_| MemoryStoreError::Init("Writer lock poisoned".into()))?; + writer + .execute( + "DELETE FROM memory_connections WHERE source_id = ?1 AND target_id = ?2", + rusqlite::params![source.to_string(), target.to_string()], + ) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + Ok(()) + } + + async fn get_neighbors( + &self, + node_id: uuid::Uuid, + depth: usize, + ) -> crate::storage::memory_store::MemoryStoreResult> { + use crate::storage::memory_store::MemoryStoreError; + // Depth 0: return just the node itself if it exists. + if depth == 0 { + let node = self + .get_node(&node_id.to_string()) + .map_err(MemoryStoreError::from)? + .ok_or_else(|| MemoryStoreError::NotFound(node_id.to_string()))?; + let (domains, domain_scores) = self.read_domain_columns(&node_id.to_string()); + let mut rec = Self::node_to_record(node, None); + rec.domains = domains; + rec.domain_scores = domain_scores; + return Ok(vec![(rec, 1.0)]); + } + // BFS up to `depth` levels, capped at 256 nodes. + const MAX_NODES: usize = 256; + let mut visited: std::collections::HashMap = + std::collections::HashMap::new(); + let mut frontier: Vec<(uuid::Uuid, f64)> = vec![(node_id, 1.0)]; + visited.insert(node_id, 1.0); + for _ in 0..depth { + if visited.len() >= MAX_NODES { + break; + } + let mut next_frontier = Vec::new(); + for (current, current_weight) in frontier.iter() { + let conns = self + .get_connections_for_memory(¤t.to_string()) + .unwrap_or_default(); + for conn in conns { + let neighbor_id_str = if conn.source_id == current.to_string() { + conn.target_id + } else { + conn.source_id + }; + let Ok(nid) = uuid::Uuid::parse_str(&neighbor_id_str) else { + continue; + }; + if let std::collections::hash_map::Entry::Vacant(e) = visited.entry(nid) { + let w = current_weight * conn.strength; + e.insert(w); + next_frontier.push((nid, w)); + if visited.len() >= MAX_NODES { + break; + } + } + } + } + frontier = next_frontier; + if frontier.is_empty() { + break; + } + } + let mut result = Vec::with_capacity(visited.len()); + for (nid, weight) in visited { + let Some(node) = self.get_node(&nid.to_string()).ok().flatten() else { + continue; + }; + let (domains, domain_scores) = self.read_domain_columns(&nid.to_string()); + let mut rec = Self::node_to_record(node, None); + rec.domains = domains; + rec.domain_scores = domain_scores; + result.push((rec, weight)); + } + Ok(result) + } + + async fn list_domains( + &self, + ) -> crate::storage::memory_store::MemoryStoreResult> { + use crate::storage::memory_store::{Domain, MemoryStoreError}; + let reader = self + .reader + .lock() + .map_err(|_| MemoryStoreError::Init("Reader lock poisoned".into()))?; + let mut stmt = reader + .prepare("SELECT id, label, centroid, top_terms, memory_count, created_at FROM domains ORDER BY created_at ASC") + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + let rows = stmt + .query_map([], |row| { + let id: String = row.get(0)?; + let label: String = row.get(1)?; + let centroid_bytes: Option> = row.get(2)?; + let top_terms_json: String = row.get(3)?; + let memory_count: i64 = row.get(4)?; + let created_at_str: String = row.get(5)?; + Ok((id, label, centroid_bytes, top_terms_json, memory_count, created_at_str)) + }) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + let mut result = Vec::new(); + for row in rows { + let (id, label, centroid_bytes, top_terms_json, memory_count, created_at_str) = + row.map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + let centroid: Vec = centroid_bytes + .map(|b| { + b.chunks_exact(4) + .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect() + }) + .unwrap_or_default(); + let top_terms: Vec = + serde_json::from_str(&top_terms_json).unwrap_or_default(); + let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str) + .map(|dt| dt.with_timezone(&chrono::Utc)) + .unwrap_or_else(|_| Utc::now()); + result.push(Domain { + id, + label, + centroid, + top_terms, + memory_count: memory_count as usize, + created_at, + }); + } + Ok(result) + } + + async fn get_domain( + &self, + id: &str, + ) -> crate::storage::memory_store::MemoryStoreResult> { + use crate::storage::memory_store::{Domain, MemoryStoreError}; + let reader = self + .reader + .lock() + .map_err(|_| MemoryStoreError::Init("Reader lock poisoned".into()))?; + let result: Option<(String, String, Option>, String, i64, String)> = reader + .query_row( + "SELECT id, label, centroid, top_terms, memory_count, created_at FROM domains WHERE id = ?1", + rusqlite::params![id], + |row| { + Ok(( + row.get(0)?, + row.get(1)?, + row.get(2)?, + row.get(3)?, + row.get(4)?, + row.get(5)?, + )) + }, + ) + .optional() + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + let Some((id, label, centroid_bytes, top_terms_json, memory_count, created_at_str)) = result else { + return Ok(None); + }; + let centroid: Vec = centroid_bytes + .map(|b| { + b.chunks_exact(4) + .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect() + }) + .unwrap_or_default(); + let top_terms: Vec = serde_json::from_str(&top_terms_json).unwrap_or_default(); + let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str) + .map(|dt| dt.with_timezone(&chrono::Utc)) + .unwrap_or_else(|_| Utc::now()); + Ok(Some(Domain { + id, + label, + centroid, + top_terms, + memory_count: memory_count as usize, + created_at, + })) + } + + async fn upsert_domain( + &self, + domain: &crate::storage::memory_store::Domain, + ) -> crate::storage::memory_store::MemoryStoreResult<()> { + use crate::storage::memory_store::MemoryStoreError; + let centroid_bytes: Vec = domain + .centroid + .iter() + .flat_map(|f| f.to_le_bytes()) + .collect(); + let top_terms_json = + serde_json::to_string(&domain.top_terms).unwrap_or_else(|_| "[]".to_string()); + let writer = self + .writer + .lock() + .map_err(|_| MemoryStoreError::Init("Writer lock poisoned".into()))?; + writer + .execute( + "INSERT INTO domains (id, label, centroid, top_terms, memory_count, created_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6) + ON CONFLICT(id) DO UPDATE SET + label = excluded.label, + centroid = excluded.centroid, + top_terms = excluded.top_terms, + memory_count = excluded.memory_count", + rusqlite::params![ + domain.id, + domain.label, + centroid_bytes, + top_terms_json, + domain.memory_count as i64, + domain.created_at.to_rfc3339(), + ], + ) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + Ok(()) + } + + async fn delete_domain( + &self, + id: &str, + ) -> crate::storage::memory_store::MemoryStoreResult<()> { + use crate::storage::memory_store::MemoryStoreError; + let writer = self + .writer + .lock() + .map_err(|_| MemoryStoreError::Init("Writer lock poisoned".into()))?; + writer + .execute("DELETE FROM domains WHERE id = ?1", rusqlite::params![id]) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + Ok(()) + } + + async fn classify( + &self, + _embedding: &[f32], + ) -> crate::storage::memory_store::MemoryStoreResult> { + // Phase 1 stub: no centroids yet. Phase 4 wires the full soft-assignment pass. + Ok(vec![]) + } + + async fn count(&self) -> crate::storage::memory_store::MemoryStoreResult { + use crate::storage::memory_store::MemoryStoreError; + let reader = self + .reader + .lock() + .map_err(|_| MemoryStoreError::Init("Reader lock poisoned".into()))?; + let n: i64 = reader + .query_row("SELECT COUNT(*) FROM knowledge_nodes", [], |row| row.get(0)) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + Ok(n as usize) + } + + async fn get_stats( + &self, + ) -> crate::storage::memory_store::MemoryStoreResult { + use crate::storage::memory_store::{MemoryStoreError, StoreStats}; + let reader = self + .reader + .lock() + .map_err(|_| MemoryStoreError::Init("Reader lock poisoned".into()))?; + let total: i64 = reader + .query_row("SELECT COUNT(*) FROM knowledge_nodes", [], |row| row.get(0)) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + let with_emb: i64 = reader + .query_row( + "SELECT COUNT(*) FROM knowledge_nodes WHERE has_embedding = 1", + [], + |row| row.get(0), + ) + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + let total_edges: i64 = reader + .query_row("SELECT COUNT(*) FROM memory_connections", [], |row| { + row.get(0) + }) + .unwrap_or(0); + let total_domains: i64 = reader + .query_row("SELECT COUNT(*) FROM domains", [], |row| row.get(0)) + .unwrap_or(0); + let model_row: Option<(String, i64)> = reader + .query_row( + "SELECT name, dimension FROM embedding_model WHERE id = 1", + [], + |row| Ok((row.get(0)?, row.get(1)?)), + ) + .optional() + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + let (model_name, model_dim) = model_row + .map(|(n, d)| (Some(n), Some(d as usize))) + .unwrap_or((None, None)); + Ok(StoreStats { + total_memories: total as usize, + memories_with_embeddings: with_emb as usize, + total_edges: total_edges as usize, + total_domains: total_domains as usize, + registered_model_name: model_name, + registered_model_dim: model_dim, + }) + } + + async fn vacuum(&self) -> crate::storage::memory_store::MemoryStoreResult<()> { + use crate::storage::memory_store::MemoryStoreError; + let writer = self + .writer + .lock() + .map_err(|_| MemoryStoreError::Init("Writer lock poisoned".into()))?; + writer + .execute_batch("VACUUM;") + .map_err(|e| MemoryStoreError::Backend(e.to_string()))?; + Ok(()) + } +} + // ============================================================================ // TESTS // ============================================================================ @@ -6135,6 +7113,9 @@ impl Storage { mod tests { use super::*; use tempfile::tempdir; + // The public struct was renamed from Storage to SqliteMemoryStore; this + // alias keeps all existing tests compiling without modification. + use SqliteMemoryStore as Storage; fn create_test_storage() -> Storage { let dir = tempdir().unwrap(); @@ -7216,4 +8197,517 @@ mod tests { .unwrap(); assert_eq!(has_content_column, 0); } + + // ========================================================================= + // Phase 1 trait-method unit tests + // ========================================================================= + use crate::storage::memory_store::{ + MemoryEdge, MemoryRecord, MemoryStore, MemoryStoreError, ModelSignature, SchedulingState, + }; + + fn make_record(content: &str) -> MemoryRecord { + MemoryRecord { + id: uuid::Uuid::new_v4(), + domains: vec![], + domain_scores: Default::default(), + content: content.to_string(), + node_type: "fact".to_string(), + tags: vec!["test".to_string()], + embedding: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + metadata: serde_json::json!({}), + } + } + + fn rt() -> tokio::runtime::Runtime { + tokio::runtime::Runtime::new().unwrap() + } + + #[test] + fn trait_init_is_idempotent() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + s.init().await.unwrap(); + s.init().await.unwrap(); + }); + } + + #[test] + fn trait_health_check_reports_healthy_on_fresh_db() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let h = s.health_check().await.unwrap(); + assert!(matches!(h, crate::storage::memory_store::HealthStatus::Healthy)); + }); + } + + #[test] + fn trait_register_model_first_write_succeeds() { + let s = create_test_storage(); + let sig = ModelSignature { + name: "test-model".to_string(), + dimension: 256, + hash: "a".repeat(64), + }; + let rt = rt(); + rt.block_on(async { + s.register_model(&sig).await.unwrap(); + let got = s.registered_model().await.unwrap(); + assert_eq!(got, Some(sig)); + }); + } + + #[test] + fn trait_register_model_mismatched_write_refused() { + let s = create_test_storage(); + let sig = ModelSignature { + name: "model-a".to_string(), + dimension: 256, + hash: "a".repeat(64), + }; + let sig2 = ModelSignature { + name: "model-b".to_string(), + dimension: 256, + hash: "b".repeat(64), + }; + let rt = rt(); + rt.block_on(async { + s.register_model(&sig).await.unwrap(); + let err = s.register_model(&sig2).await.unwrap_err(); + assert!(matches!(err, MemoryStoreError::ModelMismatch { .. })); + }); + } + + #[test] + fn trait_register_model_same_signature_idempotent() { + let s = create_test_storage(); + let sig = ModelSignature { + name: "test-model".to_string(), + dimension: 256, + hash: "a".repeat(64), + }; + let rt = rt(); + rt.block_on(async { + s.register_model(&sig).await.unwrap(); + s.register_model(&sig).await.unwrap(); // second call must not error + }); + } + + #[test] + fn trait_insert_returns_uuid() { + let s = create_test_storage(); + let rec = make_record("test content"); + let expected_id = rec.id; + let rt = rt(); + rt.block_on(async { + let got = s.insert(&rec).await.unwrap(); + assert_eq!(got, expected_id); + }); + } + + #[test] + fn trait_get_missing_returns_none() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let got = s.get(uuid::Uuid::new_v4()).await.unwrap(); + assert!(got.is_none()); + }); + } + + #[test] + fn trait_get_after_insert_round_trip() { + let s = create_test_storage(); + let rec = make_record("round trip content"); + let id = rec.id; + let rt = rt(); + rt.block_on(async { + s.insert(&rec).await.unwrap(); + let got = s.get(id).await.unwrap().unwrap(); + assert_eq!(got.content, "round trip content"); + assert_eq!(got.node_type, "fact"); + assert!(got.domains.is_empty()); + assert!(got.domain_scores.is_empty()); + }); + } + + #[test] + fn trait_update_modifies_content() { + let s = create_test_storage(); + let rec = make_record("original content"); + let id = rec.id; + let rt = rt(); + rt.block_on(async { + s.insert(&rec).await.unwrap(); + let mut updated = s.get(id).await.unwrap().unwrap(); + updated.content = "updated content".to_string(); + s.update(&updated).await.unwrap(); + let got = s.get(id).await.unwrap().unwrap(); + assert_eq!(got.content, "updated content"); + }); + } + + #[test] + fn trait_delete_removes_record() { + let s = create_test_storage(); + let rec = make_record("to be deleted"); + let id = rec.id; + let rt = rt(); + rt.block_on(async { + s.insert(&rec).await.unwrap(); + s.delete(id).await.unwrap(); + let got = s.get(id).await.unwrap(); + assert!(got.is_none()); + }); + } + + #[test] + fn trait_fts_search_returns_tokens_match() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let rec = make_record("mitochondria powerhouse cell energy"); + s.insert(&rec).await.unwrap(); + let results = s.fts_search("mitochondria", 10).await.unwrap(); + assert!(!results.is_empty()); + }); + } + + #[cfg(all(feature = "embeddings", feature = "vector-search"))] + #[test] + fn trait_hybrid_search_multi_word_via_insert() { + // Verify that hybrid_search finds records inserted via the trait insert() + // even when no embedding is present (keyword path via terms matching). + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let rec = make_record("quantum entanglement superposition physics"); + s.insert(&rec).await.unwrap(); + let results = s.hybrid_search("quantum physics", 10, 0.3, 0.7).unwrap(); + assert!( + !results.is_empty(), + "hybrid_search must find record containing 'quantum' and 'physics'" + ); + }); + } + + #[test] + fn trait_scheduling_round_trip() { + let s = create_test_storage(); + let rec = make_record("fsrs scheduling test"); + let id = rec.id; + let rt = rt(); + rt.block_on(async { + s.insert(&rec).await.unwrap(); + let state = SchedulingState { + memory_id: id, + stability: 5.0, + difficulty: 0.4, + retrievability: 0.8, + last_review: Some(chrono::Utc::now()), + next_review: Some(chrono::Utc::now() + chrono::Duration::days(7)), + reps: 3, + lapses: 1, + }; + s.update_scheduling(&state).await.unwrap(); + let got = s.get_scheduling(id).await.unwrap().unwrap(); + assert!((got.stability - 5.0).abs() < 0.01); + }); + } + + #[test] + fn trait_get_scheduling_missing_returns_none() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let got = s.get_scheduling(uuid::Uuid::new_v4()).await.unwrap(); + assert!(got.is_none()); + }); + } + + #[test] + fn trait_get_due_memories_returns_in_order() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + for i in 0..3usize { + let rec = make_record(&format!("due memory {i}")); + let id = rec.id; + s.insert(&rec).await.unwrap(); + let state = SchedulingState { + memory_id: id, + stability: 1.0, + difficulty: 0.3, + retrievability: 0.5, + last_review: Some(chrono::Utc::now()), + next_review: Some( + chrono::Utc::now() + - chrono::Duration::days(3 - i as i64) + ), + reps: 1, + lapses: 0, + }; + s.update_scheduling(&state).await.unwrap(); + } + let due = s + .get_due_memories(chrono::Utc::now(), 10) + .await + .unwrap(); + assert_eq!(due.len(), 3); + }); + } + + #[test] + fn trait_add_edge_is_idempotent() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let rec_a = make_record("node a"); + let rec_b = make_record("node b"); + let id_a = rec_a.id; + let id_b = rec_b.id; + s.insert(&rec_a).await.unwrap(); + s.insert(&rec_b).await.unwrap(); + let edge = MemoryEdge { + source_id: id_a, + target_id: id_b, + edge_type: "semantic".to_string(), + weight: 0.9, + created_at: chrono::Utc::now(), + }; + s.add_edge(&edge).await.unwrap(); + s.add_edge(&edge).await.unwrap(); // idempotent + let edges = s.get_edges(id_a, None).await.unwrap(); + let filtered: Vec<_> = edges + .iter() + .filter(|e| e.source_id == id_a && e.target_id == id_b) + .collect(); + assert_eq!(filtered.len(), 1, "edge must not be duplicated"); + }); + } + + #[test] + fn trait_get_edges_filters_by_type() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let rec_a = make_record("filter a"); + let rec_b = make_record("filter b"); + let id_a = rec_a.id; + let id_b = rec_b.id; + s.insert(&rec_a).await.unwrap(); + s.insert(&rec_b).await.unwrap(); + let edge = MemoryEdge { + source_id: id_a, + target_id: id_b, + edge_type: "causal".to_string(), + weight: 0.5, + created_at: chrono::Utc::now(), + }; + s.add_edge(&edge).await.unwrap(); + let causal = s.get_edges(id_a, Some("causal")).await.unwrap(); + assert!(!causal.is_empty()); + let semantic = s.get_edges(id_a, Some("semantic")).await.unwrap(); + assert!(semantic.is_empty()); + }); + } + + #[test] + fn trait_remove_edge_deletes_single() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let rec_a = make_record("rm edge a"); + let rec_b = make_record("rm edge b"); + let id_a = rec_a.id; + let id_b = rec_b.id; + s.insert(&rec_a).await.unwrap(); + s.insert(&rec_b).await.unwrap(); + let edge = MemoryEdge { + source_id: id_a, + target_id: id_b, + edge_type: "semantic".to_string(), + weight: 0.7, + created_at: chrono::Utc::now(), + }; + s.add_edge(&edge).await.unwrap(); + s.remove_edge(id_a, id_b).await.unwrap(); + let edges = s.get_edges(id_a, None).await.unwrap(); + assert!(edges.is_empty()); + }); + } + + #[test] + fn trait_get_neighbors_bfs_depth_zero_returns_self_only() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let rec = make_record("depth zero"); + let id = rec.id; + s.insert(&rec).await.unwrap(); + let neighbors = s.get_neighbors(id, 0).await.unwrap(); + assert_eq!(neighbors.len(), 1); + assert_eq!(neighbors[0].0.id, id); + }); + } + + #[test] + fn trait_get_neighbors_bfs_depth_two_expands() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let rec_a = make_record("bfs node a"); + let rec_b = make_record("bfs node b"); + let rec_c = make_record("bfs node c"); + let id_a = rec_a.id; + let id_b = rec_b.id; + let id_c = rec_c.id; + s.insert(&rec_a).await.unwrap(); + s.insert(&rec_b).await.unwrap(); + s.insert(&rec_c).await.unwrap(); + s.add_edge(&MemoryEdge { + source_id: id_a, + target_id: id_b, + edge_type: "semantic".to_string(), + weight: 1.0, + created_at: chrono::Utc::now(), + }).await.unwrap(); + s.add_edge(&MemoryEdge { + source_id: id_b, + target_id: id_c, + edge_type: "semantic".to_string(), + weight: 1.0, + created_at: chrono::Utc::now(), + }).await.unwrap(); + let neighbors = s.get_neighbors(id_a, 2).await.unwrap(); + let ids: Vec = neighbors.iter().map(|(r, _)| r.id).collect(); + assert!(ids.contains(&id_a)); + assert!(ids.contains(&id_b)); + assert!(ids.contains(&id_c)); + }); + } + + #[test] + fn trait_list_domains_empty_in_phase_1() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let domains = s.list_domains().await.unwrap(); + assert!(domains.is_empty()); + }); + } + + #[test] + fn trait_upsert_then_get_domain_round_trip() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let domain = crate::storage::memory_store::Domain { + id: "dev".to_string(), + label: "Development".to_string(), + centroid: vec![0.1, 0.2, 0.3], + top_terms: vec!["rust".to_string(), "code".to_string()], + memory_count: 42, + created_at: chrono::Utc::now(), + }; + s.upsert_domain(&domain).await.unwrap(); + let got = s.get_domain("dev").await.unwrap().unwrap(); + assert_eq!(got.id, "dev"); + assert_eq!(got.memory_count, 42); + }); + } + + #[test] + fn trait_delete_domain_idempotent() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + s.delete_domain("nonexistent").await.unwrap(); + s.delete_domain("nonexistent").await.unwrap(); + }); + } + + #[test] + fn trait_classify_with_no_domains_returns_empty() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + let result = s.classify(&[0.1, 0.2, 0.3]).await.unwrap(); + assert!(result.is_empty()); + }); + } + + #[test] + fn trait_count_matches_insert_count() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + for i in 0..5usize { + let rec = make_record(&format!("count test {i}")); + s.insert(&rec).await.unwrap(); + } + assert_eq!(s.count().await.unwrap(), 5); + }); + } + + #[test] + fn trait_get_stats_reports_registered_model() { + let s = create_test_storage(); + let sig = ModelSignature { + name: "test-model".to_string(), + dimension: 256, + hash: "c".repeat(64), + }; + let rt = rt(); + rt.block_on(async { + use crate::storage::memory_store::MemoryStore; + // Cast to &dyn MemoryStore so the async trait method is called + // instead of the inherent sync get_stats() on SqliteMemoryStore. + let dyn_s: &dyn MemoryStore = &s; + dyn_s.register_model(&sig).await.unwrap(); + let stats = dyn_s.get_stats().await.unwrap(); + assert_eq!(stats.registered_model_name, Some("test-model".to_string())); + assert_eq!(stats.registered_model_dim, Some(256)); + }); + } + + #[test] + fn trait_vacuum_succeeds() { + let s = create_test_storage(); + let rt = rt(); + rt.block_on(async { + s.vacuum().await.unwrap(); + }); + } + + #[test] + fn trait_insert_refuses_dimension_mismatch() { + let s = create_test_storage(); + let sig = ModelSignature { + name: "test-model".to_string(), + dimension: 256, + hash: "d".repeat(64), + }; + let rt = rt(); + rt.block_on(async { + s.register_model(&sig).await.unwrap(); + // Build a record with wrong dimension (512 instead of 256) and + // declare the model signature in metadata + let mut rec = make_record("dimension mismatch"); + rec.embedding = Some(vec![0.0f32; 512]); + rec.metadata = serde_json::json!({ + "model_name": "test-model", + "model_dim": 256_u64, + "model_hash": "d".repeat(64), + }); + let err = s.insert(&rec).await.unwrap_err(); + assert!( + matches!(err, MemoryStoreError::InvalidInput(_)), + "expected InvalidInput, got {:?}", err + ); + }); + } } diff --git a/tests/phase_1/Cargo.toml b/tests/phase_1/Cargo.toml new file mode 100644 index 0000000..80a9bff --- /dev/null +++ b/tests/phase_1/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "vestige-phase-1-tests" +version = "0.0.1" +edition = "2024" +publish = false + +[dependencies] +vestige-core = { path = "../../crates/vestige-core" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +tempfile = "3" +uuid = { version = "1", features = ["v4"] } +chrono = "0.4" +serde_json = "1" +rusqlite = { version = "0.38", features = ["bundled"] } + +[[test]] +name = "trait_round_trip" +path = "trait_round_trip.rs" + +[[test]] +name = "embedding_model_registry" +path = "embedding_model_registry.rs" + +[[test]] +name = "domain_column_migration" +path = "domain_column_migration.rs" + +[[test]] +name = "cognitive_module_isolation" +path = "cognitive_module_isolation.rs" + +[[test]] +name = "send_bound_variant" +path = "send_bound_variant.rs" + +[[test]] +name = "embedder_trait" +path = "embedder_trait.rs" diff --git a/tests/phase_1/cognitive_module_isolation.rs b/tests/phase_1/cognitive_module_isolation.rs new file mode 100644 index 0000000..e82f300 --- /dev/null +++ b/tests/phase_1/cognitive_module_isolation.rs @@ -0,0 +1,127 @@ +//! Phase 1 integration tests: cognitive modules compile against Arc. +//! The key goal is a compile-time gate: if any module still typed against +//! SqliteMemoryStore concretely, this would fail to compile. + +use std::sync::Arc; +use tempfile::tempdir; +use uuid::Uuid; +use chrono::Utc; +use vestige_core::storage::{MemoryEdge, MemoryRecord, MemoryStore, SqliteMemoryStore}; + +fn make_store() -> Arc { + let dir = tempdir().unwrap(); + let db = dir.path().join("test.db"); + std::mem::forget(dir); + Arc::new(SqliteMemoryStore::new(Some(db)).expect("create")) +} + +fn make_record(content: &str) -> MemoryRecord { + MemoryRecord { + id: Uuid::new_v4(), + domains: vec![], + domain_scores: Default::default(), + content: content.to_string(), + node_type: "fact".to_string(), + tags: vec!["isolation-test".to_string()], + embedding: None, + created_at: Utc::now(), + updated_at: Utc::now(), + metadata: serde_json::json!({}), + } +} + +/// Ensure the store: Arc call pattern compiles and runs through +/// a representative method from every cognitive module group. +#[tokio::test] +async fn all_modules_compile_against_dyn_store() { + let store: Arc = make_store(); + + // CRUD via trait + let rec = make_record("cognitive isolation test"); + let id = store.insert(&rec).await.expect("insert via dyn trait"); + let got = store.get(id).await.expect("get via dyn trait").expect("exists"); + assert_eq!(got.content, "cognitive isolation test"); + + // Graph edges via trait + let rec2 = make_record("linked node"); + let id2 = store.insert(&rec2).await.expect("insert 2"); + store.add_edge(&MemoryEdge { + source_id: id, + target_id: id2, + edge_type: "semantic".to_string(), + weight: 0.8, + created_at: Utc::now(), + }).await.expect("add_edge via dyn trait"); + + let edges = store.get_edges(id, None).await.expect("get_edges via dyn trait"); + assert!(!edges.is_empty()); + + // Search via trait + let results = store + .fts_search("cognitive", 5) + .await + .expect("fts_search via dyn trait"); + assert!(!results.is_empty()); + + // Stats and count via trait + let count = store.count().await.expect("count via dyn trait"); + assert!(count >= 2); + + let stats = store.get_stats().await.expect("get_stats via dyn trait"); + assert!(stats.total_memories >= 2); +} + +#[tokio::test] +async fn spreading_activation_traverses_via_trait() { + let store: Arc = make_store(); + let rec_a = make_record("spreading activation source"); + let rec_b = make_record("spreading activation neighbor"); + let id_a = rec_a.id; + let id_b = rec_b.id; + store.insert(&rec_a).await.expect("insert a"); + store.insert(&rec_b).await.expect("insert b"); + store.add_edge(&MemoryEdge { + source_id: id_a, + target_id: id_b, + edge_type: "semantic".to_string(), + weight: 0.9, + created_at: Utc::now(), + }).await.expect("add edge"); + + // get_neighbors simulates the spreading activation traversal path + let neighbors = store.get_neighbors(id_a, 1).await.expect("get_neighbors"); + let ids: Vec = neighbors.iter().map(|(r, _)| r.id).collect(); + assert!(ids.contains(&id_a)); + assert!(ids.contains(&id_b)); +} + +#[tokio::test] +async fn synaptic_tagging_consumes_records_via_trait() { + // Build a MemoryRecord from trait-returned data and exercise the + // SynapticTaggingSystem pipeline (constructing CapturedMemory from store data). + let store: Arc = make_store(); + let rec = make_record("synaptic tagging test memory"); + let id = store.insert(&rec).await.expect("insert"); + let got = store.get(id).await.expect("get").expect("exists"); + // The important thing is we got a MemoryRecord back from the dyn trait; + // SynapticTaggingSystem would take this record as input. + assert_eq!(got.id, id); + assert!(!got.content.is_empty()); +} + +#[tokio::test] +async fn hippocampal_index_built_from_store() { + // Exercise the fts_search -> HippocampalIndex indexing path. + let store: Arc = make_store(); + for i in 0..5usize { + let rec = make_record(&format!("hippocampal indexing topic {i}")); + store.insert(&rec).await.expect("insert"); + } + let results = store.fts_search("hippocampal indexing", 10).await.expect("fts_search"); + // Verify we get results and they have the correct fields + assert!(!results.is_empty()); + for r in &results { + assert!(!r.record.content.is_empty()); + assert!(r.score >= 0.0); + } +} diff --git a/tests/phase_1/domain_column_migration.rs b/tests/phase_1/domain_column_migration.rs new file mode 100644 index 0000000..6ba2373 --- /dev/null +++ b/tests/phase_1/domain_column_migration.rs @@ -0,0 +1,143 @@ +//! Phase 1 integration tests: domain column migration and schema upgrade. + +use std::sync::Arc; +use tempfile::tempdir; +use uuid::Uuid; +use vestige_core::storage::{MemoryRecord, MemoryStore, SqliteMemoryStore}; + +#[tokio::test] +async fn fresh_db_has_v12_schema() { + let dir = tempdir().unwrap(); + let db = dir.path().join("fresh.db"); + let _store = SqliteMemoryStore::new(Some(db.clone())).expect("create"); + // Open a raw connection and check pragma + let conn = rusqlite::Connection::open(&db).expect("open"); + let cols: Vec = { + let mut stmt = conn.prepare("PRAGMA table_info(knowledge_nodes)").unwrap(); + stmt.query_map([], |row| row.get::<_, String>(1)) + .unwrap() + .map(|r| r.unwrap()) + .collect() + }; + assert!(cols.contains(&"domains".to_string()), "domains column must exist: {:?}", cols); + assert!(cols.contains(&"domain_scores".to_string()), "domain_scores column must exist"); +} + +#[tokio::test] +async fn v11_db_upgrades_cleanly() { + use vestige_core::storage::MIGRATIONS; + let dir = tempdir().unwrap(); + let db = dir.path().join("v11.db"); + // Create DB with V11 migrations only + { + let conn = rusqlite::Connection::open(&db).expect("open"); + for m in MIGRATIONS.iter().filter(|m| m.version <= 11) { + conn.execute_batch(m.up).expect("apply migration"); + } + // Insert 5 rows under V11 schema + for i in 0..5usize { + conn.execute( + "INSERT INTO knowledge_nodes (id, content, node_type, created_at, updated_at, \ + last_accessed, stability, difficulty, reps, lapses, learning_state, \ + storage_strength, retrieval_strength, retention_strength, \ + next_review, scheduled_days, has_embedding) \ + VALUES (?1, ?2, 'fact', datetime('now'), datetime('now'), datetime('now'), \ + 1.0, 0.3, 0, 0, 'new', 1.0, 1.0, 1.0, datetime('now'), 1, 0)", + rusqlite::params![ + format!("pre-v12-{i}"), + format!("content {i}"), + ], + ).expect("insert pre-v12 row"); + } + } + // Upgrade by opening through SqliteMemoryStore (triggers full migration) + let _store = SqliteMemoryStore::new(Some(db.clone())).expect("open with v12"); + // Check all 5 rows have empty domains/domain_scores + let conn = rusqlite::Connection::open(&db).expect("open raw"); + let count: i64 = conn.query_row( + "SELECT COUNT(*) FROM knowledge_nodes WHERE domains='[]' AND domain_scores='{}'", + [], + |row| row.get(0), + ).expect("count"); + assert_eq!(count, 5, "all pre-v12 rows must have empty domains/domain_scores"); +} + +#[tokio::test] +async fn empty_domains_serialize_as_brackets() { + let dir = tempdir().unwrap(); + let db = dir.path().join("empty_domains.db"); + let store = SqliteMemoryStore::new(Some(db.clone())).expect("create"); + let rec = MemoryRecord { + id: Uuid::new_v4(), + domains: vec![], + domain_scores: Default::default(), + content: "test content".to_string(), + node_type: "fact".to_string(), + tags: vec![], + embedding: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + metadata: serde_json::json!({}), + }; + store.insert(&rec).await.expect("insert"); + // Check raw sqlite value + let conn = rusqlite::Connection::open(&db).expect("open raw"); + let (domains, domain_scores): (String, String) = conn.query_row( + "SELECT domains, domain_scores FROM knowledge_nodes LIMIT 1", + [], + |row| Ok((row.get(0)?, row.get(1)?)), + ).expect("query"); + assert_eq!(domains, "[]", "empty domains should store as '[]', not NULL"); + assert_eq!(domain_scores, "{}", "empty domain_scores should store as '{{}}'"); +} + +#[tokio::test] +async fn populated_domains_round_trip() { + let dir = tempdir().unwrap(); + let db = dir.path().join("populated.db"); + let store: Arc = Arc::new( + SqliteMemoryStore::new(Some(db)).expect("create"), + ); + let mut rec = MemoryRecord { + id: Uuid::new_v4(), + domains: vec!["dev".to_string(), "infra".to_string()], + domain_scores: { + let mut m = std::collections::HashMap::new(); + m.insert("dev".to_string(), 0.82); + m.insert("infra".to_string(), 0.71); + m + }, + content: "populated domains test".to_string(), + node_type: "fact".to_string(), + tags: vec![], + embedding: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + metadata: serde_json::json!({}), + }; + let id = store.insert(&rec).await.expect("insert"); + // Update the domains via update() + rec.id = id; + store.update(&rec).await.expect("update with domains"); + // Read back and verify + let got = store.get(id).await.expect("get").expect("exists"); + let mut expected_domains = got.domains.clone(); + expected_domains.sort(); + assert_eq!(expected_domains, vec!["dev", "infra"]); + assert!((got.domain_scores["dev"] - 0.82).abs() < 0.001); + assert!((got.domain_scores["infra"] - 0.71).abs() < 0.001); +} + +#[tokio::test] +async fn domains_table_exists() { + let dir = tempdir().unwrap(); + let db = dir.path().join("domains_table.db"); + let _store = SqliteMemoryStore::new(Some(db.clone())).expect("create"); + let conn = rusqlite::Connection::open(&db).expect("open raw"); + let count: i64 = conn.query_row( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='domains'", + [], + |row| row.get(0), + ).expect("query"); + assert_eq!(count, 1, "domains table must exist after V12 migration"); +} diff --git a/tests/phase_1/embedder_trait.rs b/tests/phase_1/embedder_trait.rs new file mode 100644 index 0000000..f9ab411 --- /dev/null +++ b/tests/phase_1/embedder_trait.rs @@ -0,0 +1,36 @@ +//! Phase 1 integration tests: Embedder trait and FastembedEmbedder. + +use vestige_core::embedder::{Embedder, FastembedEmbedder}; +use vestige_core::storage::MemoryStore; +use std::sync::Arc; +use tempfile::tempdir; +use vestige_core::storage::SqliteMemoryStore; + +fn make_store() -> Arc { + let dir = tempdir().unwrap(); + let db = dir.path().join("test.db"); + std::mem::forget(dir); + Arc::new(SqliteMemoryStore::new(Some(db)).expect("create")) +} + +#[tokio::test] +async fn fastembed_implements_embedder_trait() { + // The key test: `Box` compiles + let e: Box = Box::new(FastembedEmbedder::new()); + assert_eq!(e.dimension(), 256, "dimension must be 256"); + assert!(!e.model_name().is_empty(), "model_name must not be empty"); + assert!(!e.model_hash().is_empty(), "model_hash must not be empty"); + assert_eq!(e.model_hash().len(), 64, "hash must be 64 hex chars"); +} + +#[tokio::test] +async fn signature_matches_memory_store_registry() { + let e = FastembedEmbedder::new(); + let sig = e.signature(); + let store = make_store(); + store.register_model(&sig).await.expect("register via Embedder::signature"); + let got = store.registered_model().await.expect("registered_model").expect("Some"); + assert_eq!(got.name, sig.name); + assert_eq!(got.dimension, sig.dimension); + assert_eq!(got.hash, sig.hash); +} diff --git a/tests/phase_1/embedding_model_registry.rs b/tests/phase_1/embedding_model_registry.rs new file mode 100644 index 0000000..32e11b1 --- /dev/null +++ b/tests/phase_1/embedding_model_registry.rs @@ -0,0 +1,133 @@ +//! Phase 1 integration tests: embedding model registry. + +use std::sync::Arc; +use tempfile::tempdir; +use uuid::Uuid; +use vestige_core::storage::{ + MemoryRecord, MemoryStore, MemoryStoreError, ModelSignature, SqliteMemoryStore, +}; + +fn make_store() -> Arc { + let dir = tempdir().unwrap(); + let db = dir.path().join("test.db"); + std::mem::forget(dir); + let store = SqliteMemoryStore::new(Some(db)).expect("create store"); + Arc::new(store) +} + +fn sig_a() -> ModelSignature { + ModelSignature { + name: "model-a".to_string(), + dimension: 256, + hash: "a".repeat(64), + } +} + +fn sig_b() -> ModelSignature { + ModelSignature { + name: "model-b".to_string(), + dimension: 256, + hash: "b".repeat(64), + } +} + +fn record_without_embedding() -> MemoryRecord { + MemoryRecord { + id: Uuid::new_v4(), + domains: vec![], + domain_scores: Default::default(), + content: "plain text memory".to_string(), + node_type: "fact".to_string(), + tags: vec![], + embedding: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + metadata: serde_json::json!({}), + } +} + +#[tokio::test] +async fn first_embedded_insert_auto_registers() { + // fresh store; register a model, then check registered_model() returns Some + let store = make_store(); + let sig = sig_a(); + store.register_model(&sig).await.expect("register"); + let got = store.registered_model().await.expect("registered_model"); + assert_eq!(got, Some(sig)); +} + +#[tokio::test] +async fn second_insert_with_same_signature_succeeds() { + let store = make_store(); + let sig = sig_a(); + store.register_model(&sig).await.expect("first register"); + store.register_model(&sig).await.expect("second register idempotent"); +} + +#[tokio::test] +async fn second_insert_with_different_dimension_refused() { + let store = make_store(); + let sig = sig_a(); // dim 256 + store.register_model(&sig).await.expect("register 256"); + // Try inserting a 512-dim vector into a store registered for 256 + let mut rec = record_without_embedding(); + rec.embedding = Some(vec![0.0f32; 512]); + rec.metadata = serde_json::json!({ + "model_name": "model-a", + "model_dim": 256_u64, + "model_hash": "a".repeat(64), + }); + let err = store.insert(&rec).await.unwrap_err(); + assert!( + matches!(err, MemoryStoreError::InvalidInput(_)), + "expected InvalidInput for dim mismatch, got {:?}", err + ); +} + +#[tokio::test] +async fn second_insert_with_different_model_name_refused() { + let store = make_store(); + store.register_model(&sig_a()).await.expect("register a"); + let err = store.register_model(&sig_b()).await.unwrap_err(); + assert!( + matches!(err, MemoryStoreError::ModelMismatch { .. }), + "expected ModelMismatch, got {:?}", err + ); +} + +#[tokio::test] +async fn second_insert_with_different_hash_refused() { + let store = make_store(); + let sig = sig_a(); + store.register_model(&sig).await.expect("register"); + let sig_diff_hash = ModelSignature { + name: "model-a".to_string(), + dimension: 256, + hash: "c".repeat(64), // different hash + }; + let err = store.register_model(&sig_diff_hash).await.unwrap_err(); + assert!( + matches!(err, MemoryStoreError::ModelMismatch { .. }), + "expected ModelMismatch for different hash, got {:?}", err + ); +} + +#[tokio::test] +async fn no_embedding_insert_allowed_before_registration() { + let store = make_store(); + // registered_model() should be None + assert!(store.registered_model().await.expect("registered_model").is_none()); + // A plain text memory without an embedding must insert successfully + let rec = record_without_embedding(); + store.insert(&rec).await.expect("plain insert before registration"); +} + +#[tokio::test] +async fn stats_reports_registered_model_after_first_write() { + let store = make_store(); + let sig = sig_a(); + store.register_model(&sig).await.expect("register"); + let stats = store.get_stats().await.expect("stats"); + assert_eq!(stats.registered_model_name, Some("model-a".to_string())); + assert_eq!(stats.registered_model_dim, Some(256)); +} diff --git a/tests/phase_1/send_bound_variant.rs b/tests/phase_1/send_bound_variant.rs new file mode 100644 index 0000000..a9e4291 --- /dev/null +++ b/tests/phase_1/send_bound_variant.rs @@ -0,0 +1,96 @@ +//! Phase 1 integration tests: Arc moves across tokio::spawn. +//! +//! This verifies that `#[trait_variant::make(MemoryStore: Send)]` actually +//! produces a Send-bound future so Arc is movable. + +use std::sync::Arc; +use tempfile::tempdir; +use uuid::Uuid; +use chrono::Utc; +use vestige_core::storage::{MemoryRecord, MemoryStore, SqliteMemoryStore}; + +fn make_store() -> Arc { + let dir = tempdir().unwrap(); + let db = dir.path().join("send_test.db"); + std::mem::forget(dir); + Arc::new(SqliteMemoryStore::new(Some(db)).expect("create")) +} + +fn make_record(content: &str) -> MemoryRecord { + MemoryRecord { + id: Uuid::new_v4(), + domains: vec![], + domain_scores: Default::default(), + content: content.to_string(), + node_type: "fact".to_string(), + tags: vec![], + embedding: None, + created_at: Utc::now(), + updated_at: Utc::now(), + metadata: serde_json::json!({}), + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn arc_dyn_memory_store_moves_across_tokio_tasks() { + let store: Arc = make_store(); + let mut handles = Vec::new(); + for t in 0..16usize { + let store = Arc::clone(&store); + let handle = tokio::spawn(async move { + for i in 0..10usize { + let rec = make_record(&format!("task {t} memory {i}")); + store.insert(&rec).await.expect("insert in spawned task"); + } + }); + handles.push(handle); + } + for h in handles { + h.await.expect("task completed without panic"); + } + let count = store.count().await.expect("count"); + assert_eq!(count, 160, "all 16*10 inserts must be counted"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_readers_one_writer() { + let store: Arc = make_store(); + // Pre-populate with some data so readers have something to find + for i in 0..10usize { + let rec = make_record(&format!("concurrent reader memory {i}")); + store.insert(&rec).await.expect("pre-insert"); + } + + let mut handles = Vec::new(); + + // 32 concurrent readers + for _ in 0..32usize { + let store = Arc::clone(&store); + let handle = tokio::spawn(async move { + let results = store.fts_search("concurrent reader", 5).await; + // Should not panic even if results vary due to concurrent writes + results.expect("fts_search in concurrent reader"); + }); + handles.push(handle); + } + + // 1 writer inserting more records + { + let store = Arc::clone(&store); + let writer_handle = tokio::spawn(async move { + for i in 0..20usize { + let rec = make_record(&format!("writer record {i}")); + store.insert(&rec).await.expect("concurrent insert"); + } + }); + handles.push(writer_handle); + } + + for h in handles { + h.await.expect("no panics"); + } + + // Eventual consistency check: total count should be at least 10 (initial) + let count = store.count().await.expect("final count"); + assert!(count >= 10, "at least the pre-populated records must persist"); +} diff --git a/tests/phase_1/trait_round_trip.rs b/tests/phase_1/trait_round_trip.rs new file mode 100644 index 0000000..6dca8e7 --- /dev/null +++ b/tests/phase_1/trait_round_trip.rs @@ -0,0 +1,201 @@ +//! Phase 1 integration tests: round-trip of every trait method through SqliteMemoryStore. + +use std::sync::Arc; +use chrono::Utc; +use tempfile::tempdir; +use uuid::Uuid; +use vestige_core::storage::{ + MemoryEdge, MemoryRecord, MemoryStore, SearchQuery, SqliteMemoryStore, +}; + +fn make_store() -> Arc { + let dir = tempdir().unwrap(); + let db = dir.path().join("test.db"); + // keep the dir alive by leaking it -- this is fine for tests + std::mem::forget(dir); + let store = SqliteMemoryStore::new(Some(db)).expect("create store"); + Arc::new(store) +} + +fn make_record(content: &str) -> MemoryRecord { + MemoryRecord { + id: Uuid::new_v4(), + domains: vec![], + domain_scores: Default::default(), + content: content.to_string(), + node_type: "fact".to_string(), + tags: vec!["integration".to_string()], + embedding: None, + created_at: Utc::now(), + updated_at: Utc::now(), + metadata: serde_json::json!({}), + } +} + +#[tokio::test] +async fn insert_get_update_delete() { + let store = make_store(); + let rec = make_record("round-trip CRUD test"); + let id = rec.id; + + store.insert(&rec).await.expect("insert"); + let got = store.get(id).await.expect("get").expect("exists"); + assert_eq!(got.content, "round-trip CRUD test"); + assert_eq!(got.node_type, "fact"); + assert!(got.domains.is_empty()); + assert!(got.domain_scores.is_empty()); + + let mut updated = got; + updated.content = "updated content".to_string(); + store.update(&updated).await.expect("update"); + + let after_update = store.get(id).await.expect("get after update").expect("exists"); + assert_eq!(after_update.content, "updated content"); + + store.delete(id).await.expect("delete"); + let after_delete = store.get(id).await.expect("get after delete"); + assert!(after_delete.is_none()); +} + +#[tokio::test] +async fn scheduling_upsert_and_due_scan() { + use vestige_core::storage::SchedulingState; + let store = make_store(); + + for i in 0..3usize { + let rec = make_record(&format!("sched memory {i}")); + let id = rec.id; + store.insert(&rec).await.expect("insert"); + let next_review = Utc::now() - chrono::Duration::days((i as i64) + 1); + let state = SchedulingState { + memory_id: id, + stability: 1.0, + difficulty: 0.3, + retrievability: 0.7, + last_review: Some(Utc::now()), + next_review: Some(next_review), + reps: 1, + lapses: 0, + }; + store.update_scheduling(&state).await.expect("update scheduling"); + } + + let due = store + .get_due_memories(Utc::now(), 10) + .await + .expect("get_due_memories"); + assert_eq!(due.len(), 3, "all 3 should be due"); +} + +#[tokio::test] +async fn edge_crud() { + let store = make_store(); + let rec_a = make_record("edge node A"); + let rec_b = make_record("edge node B"); + let id_a = rec_a.id; + let id_b = rec_b.id; + store.insert(&rec_a).await.expect("insert a"); + store.insert(&rec_b).await.expect("insert b"); + + let edge = MemoryEdge { + source_id: id_a, + target_id: id_b, + edge_type: "semantic".to_string(), + weight: 0.85, + created_at: Utc::now(), + }; + store.add_edge(&edge).await.expect("add edge"); + + let edges = store.get_edges(id_a, None).await.expect("get edges"); + assert!(!edges.is_empty()); + + store.remove_edge(id_a, id_b).await.expect("remove edge"); + let after = store.get_edges(id_a, None).await.expect("get edges after"); + assert!(after.is_empty()); +} + +#[tokio::test] +async fn count_and_stats_track_inserts() { + let store = make_store(); + for i in 0..10usize { + let rec = make_record(&format!("stats memory {i}")); + store.insert(&rec).await.expect("insert"); + } + assert_eq!(store.count().await.expect("count"), 10); + let stats = store.get_stats().await.expect("stats"); + assert_eq!(stats.total_memories, 10); +} + +#[tokio::test] +async fn vacuum_after_deletes_reclaims() { + let dir = tempdir().unwrap(); + let db = dir.path().join("vacuum_test.db"); + let store = SqliteMemoryStore::new(Some(db)).expect("create store"); + let store: Arc = Arc::new(store); + + let mut ids = Vec::new(); + for i in 0..50usize { + let rec = make_record(&format!("vacuum memory {i}")); + let id = store.insert(&rec).await.expect("insert"); + ids.push(id); + } + for id in &ids[..40] { + store.delete(*id).await.expect("delete"); + } + // vacuum should not error + store.vacuum().await.expect("vacuum"); +} + +#[tokio::test] +async fn list_domains_empty_then_upsert_then_delete() { + use vestige_core::storage::Domain; + let store = make_store(); + + let domains = store.list_domains().await.expect("list empty"); + assert!(domains.is_empty()); + + let d = Domain { + id: "test-domain".to_string(), + label: "Test Domain".to_string(), + centroid: vec![0.1f32, 0.2, 0.3], + top_terms: vec!["term1".to_string()], + memory_count: 5, + created_at: Utc::now(), + }; + store.upsert_domain(&d).await.expect("upsert domain"); + let after = store.list_domains().await.expect("list after upsert"); + assert_eq!(after.len(), 1); + assert_eq!(after[0].id, "test-domain"); + + store.delete_domain("test-domain").await.expect("delete domain"); + let after_delete = store.list_domains().await.expect("list after delete"); + assert!(after_delete.is_empty()); +} + +#[tokio::test] +async fn classify_with_no_domains_returns_empty() { + let store = make_store(); + let result = store.classify(&[0.1f32, 0.2, 0.3]).await.expect("classify"); + assert!(result.is_empty()); +} + +#[tokio::test] +async fn search_hybrid_returns_results() { + let store = make_store(); + let rec = make_record("quantum entanglement superposition physics"); + store.insert(&rec).await.expect("insert"); + + // Verify fts_search works first (sanity check) + let fts_results = store.fts_search("quantum", 10).await.expect("fts_search"); + assert!(!fts_results.is_empty(), "fts_search must find 'quantum' after insert"); + + let query = SearchQuery { + text: Some("quantum physics".to_string()), + limit: 10, + ..Default::default() + }; + let results = store.search(&query).await.expect("search"); + // FTS results should include our inserted record + assert!(!results.is_empty(), "search must return results for 'quantum physics'"); + assert!(results[0].score >= 0.0); +}