Merge pull request #33 from Unforgettable-93/unforgettable-fixes

fix: keyword-first search (Stage 0) + type filter SQL pushdown
This commit is contained in:
Sam Valladares 2026-04-14 14:23:35 -05:00 committed by GitHub
commit 27b562d64d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 674 additions and 328 deletions

605
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1392,7 +1392,7 @@ impl Storage {
Ok(similarity_results) Ok(similarity_results)
} }
/// Hybrid search /// Hybrid search (delegates to hybrid_search_filtered with no type filters)
#[cfg(all(feature = "embeddings", feature = "vector-search"))] #[cfg(all(feature = "embeddings", feature = "vector-search"))]
pub fn hybrid_search( pub fn hybrid_search(
&self, &self,
@ -1401,10 +1401,40 @@ impl Storage {
keyword_weight: f32, keyword_weight: f32,
semantic_weight: f32, semantic_weight: f32,
) -> Result<Vec<SearchResult>> { ) -> Result<Vec<SearchResult>> {
let keyword_results = self.keyword_search_with_scores(query, limit * 2)?; self.hybrid_search_filtered(query, limit, keyword_weight, semantic_weight, None, None)
}
/// Hybrid search with optional type filtering pushed into the storage layer.
///
/// When `include_types` is `Some`, only nodes whose `node_type` matches one of
/// the given strings are returned. When `exclude_types` is `Some`, nodes whose
/// `node_type` matches are excluded. `include_types` takes precedence over
/// `exclude_types`. Both are case-sensitive and compared against the stored
/// `node_type` value.
#[cfg(all(feature = "embeddings", feature = "vector-search"))]
pub fn hybrid_search_filtered(
&self,
query: &str,
limit: i32,
keyword_weight: f32,
semantic_weight: f32,
include_types: Option<&[String]>,
exclude_types: Option<&[String]>,
) -> Result<Vec<SearchResult>> {
let has_type_filter = include_types.is_some() || exclude_types.is_some();
// Over-fetch more aggressively when type filters are active so that
// after filtering we still have enough candidates to fill `limit`.
let overfetch_factor = if has_type_filter { 4 } else { 2 };
let keyword_results = self.keyword_search_with_scores(
query,
limit * overfetch_factor,
include_types,
exclude_types,
)?;
let semantic_results = if self.embedding_service.is_ready() { let semantic_results = if self.embedding_service.is_ready() {
self.semantic_search_raw(query, limit * 2)? self.semantic_search_raw(query, limit * overfetch_factor)?
} else { } else {
vec![] vec![]
}; };
@ -1417,8 +1447,22 @@ impl Storage {
let mut results = Vec::with_capacity(limit as usize); let mut results = Vec::with_capacity(limit as usize);
for (node_id, combined_score) in combined.into_iter().take(limit as usize) { for (node_id, combined_score) in combined.into_iter() {
if results.len() >= limit as usize {
break;
}
if let Some(node) = self.get_node(&node_id)? { if let Some(node) = self.get_node(&node_id)? {
// Apply type filtering for results that came from semantic search
// (keyword search already filters in SQL, but semantic search cannot)
if let Some(includes) = include_types {
if !includes.iter().any(|t| t == &node.node_type) {
continue;
}
} else if let Some(excludes) = exclude_types {
if excludes.iter().any(|t| t == &node.node_type) {
continue;
}
}
let keyword_score = keyword_results let keyword_score = keyword_results
.iter() .iter()
.find(|(id, _)| id == &node_id) .find(|(id, _)| id == &node_id)
@ -1486,23 +1530,71 @@ impl Storage {
Ok(results) Ok(results)
} }
/// Keyword search returning scores /// Keyword search returning scores, with optional type filtering in the SQL query.
#[cfg(all(feature = "embeddings", feature = "vector-search"))] #[cfg(all(feature = "embeddings", feature = "vector-search"))]
fn keyword_search_with_scores(&self, query: &str, limit: i32) -> Result<Vec<(String, f32)>> { fn keyword_search_with_scores(
&self,
query: &str,
limit: i32,
include_types: Option<&[String]>,
exclude_types: Option<&[String]>,
) -> Result<Vec<(String, f32)>> {
let sanitized_query = sanitize_fts5_query(query); let sanitized_query = sanitize_fts5_query(query);
// Build the type filter clause and collect parameter values.
// We use numbered parameters: ?1 = query, ?2 = limit, ?3.. = type strings.
let mut type_clause = String::new();
let type_values: Vec<&str>;
if let Some(includes) = include_types {
if !includes.is_empty() {
let placeholders: Vec<String> = (0..includes.len())
.map(|i| format!("?{}", i + 3))
.collect();
type_clause = format!(" AND n.node_type IN ({})", placeholders.join(","));
type_values = includes.iter().map(|s| s.as_str()).collect();
} else {
type_values = vec![];
}
} else if let Some(excludes) = exclude_types {
if !excludes.is_empty() {
let placeholders: Vec<String> = (0..excludes.len())
.map(|i| format!("?{}", i + 3))
.collect();
type_clause = format!(" AND n.node_type NOT IN ({})", placeholders.join(","));
type_values = excludes.iter().map(|s| s.as_str()).collect();
} else {
type_values = vec![];
}
} else {
type_values = vec![];
}
let sql = format!(
"SELECT n.id, rank FROM knowledge_nodes n
JOIN knowledge_fts fts ON n.id = fts.id
WHERE knowledge_fts MATCH ?1{}
ORDER BY rank
LIMIT ?2",
type_clause
);
let reader = self.reader.lock() let reader = self.reader.lock()
.map_err(|_| StorageError::Init("Reader lock poisoned".into()))?; .map_err(|_| StorageError::Init("Reader lock poisoned".into()))?;
let mut stmt = reader.prepare( let mut stmt = reader.prepare(&sql)?;
"SELECT n.id, rank FROM knowledge_nodes n
JOIN knowledge_fts fts ON n.id = fts.id // Build the parameter list: [query, limit, ...type_values]
WHERE knowledge_fts MATCH ?1 let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
ORDER BY rank param_values.push(Box::new(sanitized_query.clone()));
LIMIT ?2", param_values.push(Box::new(limit));
)?; for tv in &type_values {
param_values.push(Box::new(tv.to_string()));
}
let params_ref: Vec<&dyn rusqlite::ToSql> =
param_values.iter().map(|p| p.as_ref()).collect();
let results: Vec<(String, f32)> = stmt let results: Vec<(String, f32)> = stmt
.query_map(params![sanitized_query, limit], |row| { .query_map(params_ref.as_slice(), |row| {
Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)? as f32)) Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)? as f32))
})? })?
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
@ -3836,4 +3928,135 @@ mod tests {
// Static method should not panic even if no backups exist // Static method should not panic even if no backups exist
let _ = Storage::get_last_backup_timestamp(); let _ = Storage::get_last_backup_timestamp();
} }
#[test]
fn test_keyword_search_with_include_types() {
    let store = create_test_storage();

    // Helper: ingest one node with the given content and node type.
    let add = |content: &str, kind: &str| {
        store
            .ingest(IngestInput {
                content: content.to_string(),
                node_type: kind.to_string(),
                ..Default::default()
            })
            .unwrap();
    };

    // Three nodes of different types, all matching the keyword "quantum".
    add("Quantum mechanics is fundamental to physics", "fact");
    add("Quantum computing uses qubits for calculation", "concept");
    add("Quantum entanglement was demonstrated in the lab", "event");

    // Restricting to include_types = ["fact"] must return facts only.
    let only_facts = vec!["fact".to_string()];
    let results = store
        .hybrid_search_filtered("quantum", 10, 0.3, 0.7, Some(&only_facts), None)
        .unwrap();

    assert!(!results.is_empty(), "should return at least one result");
    for r in &results {
        assert_eq!(r.node.node_type, "fact",
            "include_types=[fact] should only return facts, got: {}", r.node.node_type);
    }
}
#[test]
fn test_keyword_search_with_exclude_types() {
    let store = create_test_storage();

    // Two nodes matching "photosynthesis": one fact, one reflection.
    for (content, kind) in [
        ("Photosynthesis converts sunlight to energy", "fact"),
        ("Photosynthesis is a complex biochemical process", "reflection"),
    ] {
        store
            .ingest(IngestInput {
                content: content.to_string(),
                node_type: kind.to_string(),
                ..Default::default()
            })
            .unwrap();
    }

    // With exclude_types = ["reflection"], the reflection must be filtered out.
    let skip = vec!["reflection".to_string()];
    let results = store
        .hybrid_search_filtered("photosynthesis", 10, 0.3, 0.7, None, Some(&skip))
        .unwrap();

    assert!(!results.is_empty(), "should return at least one result");
    for r in &results {
        assert_ne!(r.node.node_type, "reflection",
            "exclude_types=[reflection] should not return reflections");
    }
}
#[test]
fn test_include_types_takes_precedence_over_exclude() {
    let store = create_test_storage();

    // Two nodes matching "gravity" with different types.
    for (content, kind) in [
        ("Gravity holds planets in orbit around stars", "fact"),
        ("Gravity waves were first detected by LIGO", "event"),
    ] {
        store
            .ingest(IngestInput {
                content: content.to_string(),
                node_type: kind.to_string(),
                ..Default::default()
            })
            .unwrap();
    }

    // Supplying the same type in both lists: include_types wins, so
    // facts are still returned despite also appearing in exclude_types.
    let wanted = vec!["fact".to_string()];
    let unwanted = vec!["fact".to_string()];
    let results = store
        .hybrid_search_filtered("gravity", 10, 0.3, 0.7, Some(&wanted), Some(&unwanted))
        .unwrap();

    assert!(!results.is_empty());
    assert!(results.iter().all(|r| r.node.node_type == "fact"));
}
#[test]
fn test_type_filter_with_no_matches_returns_empty() {
    let store = create_test_storage();

    // One matching node, but of a type the filter will not accept.
    store
        .ingest(IngestInput {
            content: "DNA carries genetic information in cells".to_string(),
            node_type: "fact".to_string(),
            ..Default::default()
        })
        .unwrap();

    // include_types names a type absent from the corpus — expect no hits.
    let missing_type = vec!["person".to_string()];
    let results = store
        .hybrid_search_filtered("DNA", 10, 0.3, 0.7, Some(&missing_type), None)
        .unwrap();

    assert!(results.is_empty(),
        "filtering for a non-matching type should return empty results");
}
#[test]
fn test_hybrid_search_backward_compat() {
    // The legacy entry point (no type filters) must keep working unchanged.
    let store = create_test_storage();
    store
        .ingest(IngestInput {
            content: "Neurons transmit electrical signals in the brain".to_string(),
            node_type: "fact".to_string(),
            ..Default::default()
        })
        .unwrap();

    let hits = store.hybrid_search("neurons", 10, 0.3, 0.7).unwrap();
    assert!(!hits.is_empty());
    assert!(hits[0].node.content.contains("Neurons"));
}
} }

View file

@ -66,6 +66,16 @@ pub fn schema() -> Value {
"items": { "type": "string" }, "items": { "type": "string" },
"description": "Optional topics for context-dependent retrieval boosting" "description": "Optional topics for context-dependent retrieval boosting"
}, },
"exclude_types": {
"type": "array",
"items": { "type": "string" },
"description": "Node types to exclude from results (e.g., ['reflection']). Reflections are excluded by default to prevent polluting factual queries."
},
"include_types": {
"type": "array",
"items": { "type": "string" },
"description": "If set, only return nodes of these types. Overrides exclude_types."
},
"token_budget": { "token_budget": {
"type": "integer", "type": "integer",
"description": "Max tokens for response. Server truncates content to fit budget. Use memory(action='get') for full content of specific IDs. With 1M context models, budgets up to 100K are practical.", "description": "Max tokens for response. Server truncates content to fit budget. Use memory(action='get') for full content of specific IDs. With 1M context models, budgets up to 100K are practical.",
@ -96,6 +106,10 @@ struct SearchArgs {
detail_level: Option<String>, detail_level: Option<String>,
#[serde(alias = "context_topics")] #[serde(alias = "context_topics")]
context_topics: Option<Vec<String>>, context_topics: Option<Vec<String>>,
#[serde(alias = "exclude_types")]
exclude_types: Option<Vec<String>>,
#[serde(alias = "include_types")]
include_types: Option<Vec<String>>,
#[serde(alias = "token_budget")] #[serde(alias = "token_budget")]
token_budget: Option<i32>, token_budget: Option<i32>,
#[serde(alias = "retrieval_mode")] #[serde(alias = "retrieval_mode")]
@ -163,6 +177,39 @@ pub async fn execute(
let keyword_weight = 0.3_f32; let keyword_weight = 0.3_f32;
let semantic_weight = 0.7_f32; let semantic_weight = 0.7_f32;
// ====================================================================
// STAGE 0: Keyword-first search (dedicated keyword-only pass)
// ====================================================================
// Run a small keyword-only search to guarantee strong keyword matches
// survive into the candidate pool, even with small limits/overfetch.
// Without this, exact keyword matches (e.g. unique proper nouns) get
// buried by semantic scoring in the hybrid search.
let keyword_first_limit = 10_i32;
let keyword_priority_threshold: f32 = 0.8;
let keyword_first_results = storage
.hybrid_search_filtered(
&args.query,
keyword_first_limit,
1.0, // keyword_weight = 1.0 (keyword-only)
0.0, // semantic_weight = 0.0
args.include_types.as_deref(),
args.exclude_types.as_deref(),
)
.map_err(|e| e.to_string())?;
// Collect keyword-priority results (keyword_score >= threshold)
let mut keyword_priority_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut keyword_priority_results: Vec<vestige_core::SearchResult> = Vec::new();
for r in keyword_first_results {
if r.keyword_score.unwrap_or(0.0) >= keyword_priority_threshold
&& r.node.retention_strength >= min_retention
{
keyword_priority_ids.insert(r.node.id.clone());
keyword_priority_results.push(r);
}
}
// ==================================================================== // ====================================================================
// STAGE 1: Hybrid search with Nx over-fetch for reranking pool // STAGE 1: Hybrid search with Nx over-fetch for reranking pool
// ==================================================================== // ====================================================================
@ -174,7 +221,14 @@ pub async fn execute(
let overfetch_limit = (limit * overfetch_multiplier).min(100); // Cap at 100 to avoid excessive DB load let overfetch_limit = (limit * overfetch_multiplier).min(100); // Cap at 100 to avoid excessive DB load
let results = storage let results = storage
.hybrid_search(&args.query, overfetch_limit, keyword_weight, semantic_weight) .hybrid_search_filtered(
&args.query,
overfetch_limit,
keyword_weight,
semantic_weight,
args.include_types.as_deref(),
args.exclude_types.as_deref(),
)
.map_err(|e| e.to_string())?; .map_err(|e| e.to_string())?;
// Filter by min_retention and min_similarity first (cheap filters) // Filter by min_retention and min_similarity first (cheap filters)
@ -194,24 +248,86 @@ pub async fn execute(
.collect(); .collect();
// ==================================================================== // ====================================================================
// STAGE 2: Reranker (BM25-like rescoring, trim to requested limit) // Dedup: merge Stage 0 keyword-priority results into Stage 1 results
// ==================================================================== // ====================================================================
if let Ok(mut cog) = cognitive.try_lock() { for kp in &keyword_priority_results {
let candidates: Vec<_> = filtered_results if let Some(existing) = filtered_results.iter_mut().find(|r| r.node.id == kp.node.id) {
.iter() // Preserve keyword_score from Stage 0 (keyword-only search is authoritative)
.map(|r| (r.clone(), r.node.content.clone())) if kp.keyword_score.unwrap_or(0.0) > existing.keyword_score.unwrap_or(0.0) {
.collect(); existing.keyword_score = kp.keyword_score;
}
if let Ok(reranked) = cog.reranker.rerank(&args.query, candidates, Some(limit as usize)) { if kp.combined_score > existing.combined_score {
// Replace filtered_results with reranked items (preserves original SearchResult) existing.combined_score = kp.combined_score;
filtered_results = reranked.into_iter().map(|rr| rr.item).collect();
} else {
// Reranker failed — fall back to original order, just truncate
filtered_results.truncate(limit as usize);
} }
} else { } else {
// Couldn't acquire cognitive lock — truncate to limit // New result from Stage 0 not in Stage 1 — add it
filtered_results.truncate(limit as usize); filtered_results.push(kp.clone());
}
}
// ====================================================================
// STAGE 2: Reranker (BM25-like rescoring, trim to requested limit)
// ====================================================================
// Keyword bypass: results with strong keyword matches (>= 0.8) skip the
// cross-encoder entirely and are placed above reranked results. This
// prevents the cross-encoder from burying exact/near-exact keyword hits
// (e.g. unique proper nouns) beneath semantically-similar but unrelated
// results.
{
let keyword_bypass_threshold: f32 = 0.8;
let limit_usize = limit as usize;
// Partition: keyword bypass vs. candidates for reranking
let mut bypass_results: Vec<vestige_core::SearchResult> = Vec::new();
let mut rerank_candidates: Vec<(vestige_core::SearchResult, String)> = Vec::new();
for r in filtered_results.iter() {
if r.keyword_score.unwrap_or(0.0) >= keyword_bypass_threshold {
bypass_results.push(r.clone());
} else {
rerank_candidates.push((r.clone(), r.node.content.clone()));
}
}
// Boost bypass results so they survive later pipeline stages
// (temporal, FSRS, utility, competition) and the final re-sort.
for r in bypass_results.iter_mut() {
r.combined_score *= 2.0;
}
bypass_results.sort_by(|a, b| {
b.combined_score
.partial_cmp(&a.combined_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
// Rerank the remaining candidates
let reranked_results: Vec<vestige_core::SearchResult> = if rerank_candidates.is_empty() {
Vec::new()
} else if let Ok(mut cog) = cognitive.try_lock() {
if let Ok(reranked) = cog.reranker.rerank(&args.query, rerank_candidates, Some(limit_usize)) {
reranked.into_iter().map(|rr| rr.item).collect()
} else {
// Reranker failed — fall back to original order for non-bypass candidates
filtered_results
.iter()
.filter(|r| r.keyword_score.unwrap_or(0.0) < keyword_bypass_threshold)
.cloned()
.collect()
}
} else {
// Couldn't acquire cognitive lock — use original order
filtered_results
.iter()
.filter(|r| r.keyword_score.unwrap_or(0.0) < keyword_bypass_threshold)
.cloned()
.collect()
};
// Merge: bypass first, then reranked, trim to limit
filtered_results = bypass_results;
filtered_results.extend(reranked_results);
filtered_results.truncate(limit_usize);
} }
// ==================================================================== // ====================================================================