mirror of
https://github.com/samvallad33/vestige.git
synced 2026-04-24 16:26:22 +02:00
Merge pull request #33 from Unforgettable-93/unforgettable-fixes
fix: keyword-first search (Stage 0) + type filter SQL pushdown
This commit is contained in:
commit
27b562d64d
3 changed files with 674 additions and 328 deletions
605
Cargo.lock
generated
605
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1392,7 +1392,7 @@ impl Storage {
|
|||
Ok(similarity_results)
|
||||
}
|
||||
|
||||
/// Hybrid search
|
||||
/// Hybrid search (delegates to hybrid_search_filtered with no type filters)
|
||||
#[cfg(all(feature = "embeddings", feature = "vector-search"))]
|
||||
pub fn hybrid_search(
|
||||
&self,
|
||||
|
|
@ -1401,10 +1401,40 @@ impl Storage {
|
|||
keyword_weight: f32,
|
||||
semantic_weight: f32,
|
||||
) -> Result<Vec<SearchResult>> {
|
||||
let keyword_results = self.keyword_search_with_scores(query, limit * 2)?;
|
||||
self.hybrid_search_filtered(query, limit, keyword_weight, semantic_weight, None, None)
|
||||
}
|
||||
|
||||
/// Hybrid search with optional type filtering pushed into the storage layer.
|
||||
///
|
||||
/// When `include_types` is `Some`, only nodes whose `node_type` matches one of
|
||||
/// the given strings are returned. When `exclude_types` is `Some`, nodes whose
|
||||
/// `node_type` matches are excluded. `include_types` takes precedence over
|
||||
/// `exclude_types`. Both are case-sensitive and compared against the stored
|
||||
/// `node_type` value.
|
||||
#[cfg(all(feature = "embeddings", feature = "vector-search"))]
|
||||
pub fn hybrid_search_filtered(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: i32,
|
||||
keyword_weight: f32,
|
||||
semantic_weight: f32,
|
||||
include_types: Option<&[String]>,
|
||||
exclude_types: Option<&[String]>,
|
||||
) -> Result<Vec<SearchResult>> {
|
||||
let has_type_filter = include_types.is_some() || exclude_types.is_some();
|
||||
// Over-fetch more aggressively when type filters are active so that
|
||||
// after filtering we still have enough candidates to fill `limit`.
|
||||
let overfetch_factor = if has_type_filter { 4 } else { 2 };
|
||||
|
||||
let keyword_results = self.keyword_search_with_scores(
|
||||
query,
|
||||
limit * overfetch_factor,
|
||||
include_types,
|
||||
exclude_types,
|
||||
)?;
|
||||
|
||||
let semantic_results = if self.embedding_service.is_ready() {
|
||||
self.semantic_search_raw(query, limit * 2)?
|
||||
self.semantic_search_raw(query, limit * overfetch_factor)?
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
|
@ -1417,8 +1447,22 @@ impl Storage {
|
|||
|
||||
let mut results = Vec::with_capacity(limit as usize);
|
||||
|
||||
for (node_id, combined_score) in combined.into_iter().take(limit as usize) {
|
||||
for (node_id, combined_score) in combined.into_iter() {
|
||||
if results.len() >= limit as usize {
|
||||
break;
|
||||
}
|
||||
if let Some(node) = self.get_node(&node_id)? {
|
||||
// Apply type filtering for results that came from semantic search
|
||||
// (keyword search already filters in SQL, but semantic search cannot)
|
||||
if let Some(includes) = include_types {
|
||||
if !includes.iter().any(|t| t == &node.node_type) {
|
||||
continue;
|
||||
}
|
||||
} else if let Some(excludes) = exclude_types {
|
||||
if excludes.iter().any(|t| t == &node.node_type) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let keyword_score = keyword_results
|
||||
.iter()
|
||||
.find(|(id, _)| id == &node_id)
|
||||
|
|
@ -1486,23 +1530,71 @@ impl Storage {
|
|||
Ok(results)
|
||||
}
|
||||
|
||||
/// Keyword search returning scores
|
||||
/// Keyword search returning scores, with optional type filtering in the SQL query.
|
||||
#[cfg(all(feature = "embeddings", feature = "vector-search"))]
|
||||
fn keyword_search_with_scores(&self, query: &str, limit: i32) -> Result<Vec<(String, f32)>> {
|
||||
fn keyword_search_with_scores(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: i32,
|
||||
include_types: Option<&[String]>,
|
||||
exclude_types: Option<&[String]>,
|
||||
) -> Result<Vec<(String, f32)>> {
|
||||
let sanitized_query = sanitize_fts5_query(query);
|
||||
|
||||
// Build the type filter clause and collect parameter values.
|
||||
// We use numbered parameters: ?1 = query, ?2 = limit, ?3.. = type strings.
|
||||
let mut type_clause = String::new();
|
||||
let type_values: Vec<&str>;
|
||||
|
||||
if let Some(includes) = include_types {
|
||||
if !includes.is_empty() {
|
||||
let placeholders: Vec<String> = (0..includes.len())
|
||||
.map(|i| format!("?{}", i + 3))
|
||||
.collect();
|
||||
type_clause = format!(" AND n.node_type IN ({})", placeholders.join(","));
|
||||
type_values = includes.iter().map(|s| s.as_str()).collect();
|
||||
} else {
|
||||
type_values = vec![];
|
||||
}
|
||||
} else if let Some(excludes) = exclude_types {
|
||||
if !excludes.is_empty() {
|
||||
let placeholders: Vec<String> = (0..excludes.len())
|
||||
.map(|i| format!("?{}", i + 3))
|
||||
.collect();
|
||||
type_clause = format!(" AND n.node_type NOT IN ({})", placeholders.join(","));
|
||||
type_values = excludes.iter().map(|s| s.as_str()).collect();
|
||||
} else {
|
||||
type_values = vec![];
|
||||
}
|
||||
} else {
|
||||
type_values = vec![];
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
"SELECT n.id, rank FROM knowledge_nodes n
|
||||
JOIN knowledge_fts fts ON n.id = fts.id
|
||||
WHERE knowledge_fts MATCH ?1{}
|
||||
ORDER BY rank
|
||||
LIMIT ?2",
|
||||
type_clause
|
||||
);
|
||||
|
||||
let reader = self.reader.lock()
|
||||
.map_err(|_| StorageError::Init("Reader lock poisoned".into()))?;
|
||||
let mut stmt = reader.prepare(
|
||||
"SELECT n.id, rank FROM knowledge_nodes n
|
||||
JOIN knowledge_fts fts ON n.id = fts.id
|
||||
WHERE knowledge_fts MATCH ?1
|
||||
ORDER BY rank
|
||||
LIMIT ?2",
|
||||
)?;
|
||||
let mut stmt = reader.prepare(&sql)?;
|
||||
|
||||
// Build the parameter list: [query, limit, ...type_values]
|
||||
let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
|
||||
param_values.push(Box::new(sanitized_query.clone()));
|
||||
param_values.push(Box::new(limit));
|
||||
for tv in &type_values {
|
||||
param_values.push(Box::new(tv.to_string()));
|
||||
}
|
||||
let params_ref: Vec<&dyn rusqlite::ToSql> =
|
||||
param_values.iter().map(|p| p.as_ref()).collect();
|
||||
|
||||
let results: Vec<(String, f32)> = stmt
|
||||
.query_map(params![sanitized_query, limit], |row| {
|
||||
.query_map(params_ref.as_slice(), |row| {
|
||||
Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)? as f32))
|
||||
})?
|
||||
.filter_map(|r| r.ok())
|
||||
|
|
@ -3836,4 +3928,135 @@ mod tests {
|
|||
// Static method should not panic even if no backups exist
|
||||
let _ = Storage::get_last_backup_timestamp();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keyword_search_with_include_types() {
|
||||
let storage = create_test_storage();
|
||||
|
||||
// Ingest nodes of different types all containing the word "quantum"
|
||||
storage.ingest(IngestInput {
|
||||
content: "Quantum mechanics is fundamental to physics".to_string(),
|
||||
node_type: "fact".to_string(),
|
||||
..Default::default()
|
||||
}).unwrap();
|
||||
storage.ingest(IngestInput {
|
||||
content: "Quantum computing uses qubits for calculation".to_string(),
|
||||
node_type: "concept".to_string(),
|
||||
..Default::default()
|
||||
}).unwrap();
|
||||
storage.ingest(IngestInput {
|
||||
content: "Quantum entanglement was demonstrated in the lab".to_string(),
|
||||
node_type: "event".to_string(),
|
||||
..Default::default()
|
||||
}).unwrap();
|
||||
|
||||
// Search with include_types = ["fact"] — should only return the fact
|
||||
let include = vec!["fact".to_string()];
|
||||
let results = storage.hybrid_search_filtered(
|
||||
"quantum", 10, 0.3, 0.7,
|
||||
Some(&include), None,
|
||||
).unwrap();
|
||||
|
||||
assert!(!results.is_empty(), "should return at least one result");
|
||||
for r in &results {
|
||||
assert_eq!(r.node.node_type, "fact",
|
||||
"include_types=[fact] should only return facts, got: {}", r.node.node_type);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keyword_search_with_exclude_types() {
|
||||
let storage = create_test_storage();
|
||||
|
||||
storage.ingest(IngestInput {
|
||||
content: "Photosynthesis converts sunlight to energy".to_string(),
|
||||
node_type: "fact".to_string(),
|
||||
..Default::default()
|
||||
}).unwrap();
|
||||
storage.ingest(IngestInput {
|
||||
content: "Photosynthesis is a complex biochemical process".to_string(),
|
||||
node_type: "reflection".to_string(),
|
||||
..Default::default()
|
||||
}).unwrap();
|
||||
|
||||
// Search with exclude_types = ["reflection"] — should skip the reflection
|
||||
let exclude = vec!["reflection".to_string()];
|
||||
let results = storage.hybrid_search_filtered(
|
||||
"photosynthesis", 10, 0.3, 0.7,
|
||||
None, Some(&exclude),
|
||||
).unwrap();
|
||||
|
||||
assert!(!results.is_empty(), "should return at least one result");
|
||||
for r in &results {
|
||||
assert_ne!(r.node.node_type, "reflection",
|
||||
"exclude_types=[reflection] should not return reflections");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_include_types_takes_precedence_over_exclude() {
|
||||
let storage = create_test_storage();
|
||||
|
||||
storage.ingest(IngestInput {
|
||||
content: "Gravity holds planets in orbit around stars".to_string(),
|
||||
node_type: "fact".to_string(),
|
||||
..Default::default()
|
||||
}).unwrap();
|
||||
storage.ingest(IngestInput {
|
||||
content: "Gravity waves were first detected by LIGO".to_string(),
|
||||
node_type: "event".to_string(),
|
||||
..Default::default()
|
||||
}).unwrap();
|
||||
|
||||
// When both are provided, include_types wins
|
||||
let include = vec!["fact".to_string()];
|
||||
let exclude = vec!["fact".to_string()];
|
||||
let results = storage.hybrid_search_filtered(
|
||||
"gravity", 10, 0.3, 0.7,
|
||||
Some(&include), Some(&exclude),
|
||||
).unwrap();
|
||||
|
||||
// include_types takes precedence — facts should be returned
|
||||
assert!(!results.is_empty());
|
||||
for r in &results {
|
||||
assert_eq!(r.node.node_type, "fact");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_type_filter_with_no_matches_returns_empty() {
|
||||
let storage = create_test_storage();
|
||||
|
||||
storage.ingest(IngestInput {
|
||||
content: "DNA carries genetic information in cells".to_string(),
|
||||
node_type: "fact".to_string(),
|
||||
..Default::default()
|
||||
}).unwrap();
|
||||
|
||||
// Search for a type that doesn't exist among matches
|
||||
let include = vec!["person".to_string()];
|
||||
let results = storage.hybrid_search_filtered(
|
||||
"DNA", 10, 0.3, 0.7,
|
||||
Some(&include), None,
|
||||
).unwrap();
|
||||
|
||||
assert!(results.is_empty(),
|
||||
"filtering for a non-matching type should return empty results");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hybrid_search_backward_compat() {
|
||||
// Ensure the original hybrid_search (no type filters) still works
|
||||
let storage = create_test_storage();
|
||||
|
||||
storage.ingest(IngestInput {
|
||||
content: "Neurons transmit electrical signals in the brain".to_string(),
|
||||
node_type: "fact".to_string(),
|
||||
..Default::default()
|
||||
}).unwrap();
|
||||
|
||||
let results = storage.hybrid_search("neurons", 10, 0.3, 0.7).unwrap();
|
||||
assert!(!results.is_empty());
|
||||
assert!(results[0].node.content.contains("Neurons"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,6 +66,16 @@ pub fn schema() -> Value {
|
|||
"items": { "type": "string" },
|
||||
"description": "Optional topics for context-dependent retrieval boosting"
|
||||
},
|
||||
"exclude_types": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "Node types to exclude from results (e.g., ['reflection']). Reflections are excluded by default to prevent polluting factual queries."
|
||||
},
|
||||
"include_types": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "If set, only return nodes of these types. Overrides exclude_types."
|
||||
},
|
||||
"token_budget": {
|
||||
"type": "integer",
|
||||
"description": "Max tokens for response. Server truncates content to fit budget. Use memory(action='get') for full content of specific IDs. With 1M context models, budgets up to 100K are practical.",
|
||||
|
|
@ -96,6 +106,10 @@ struct SearchArgs {
|
|||
detail_level: Option<String>,
|
||||
#[serde(alias = "context_topics")]
|
||||
context_topics: Option<Vec<String>>,
|
||||
#[serde(alias = "exclude_types")]
|
||||
exclude_types: Option<Vec<String>>,
|
||||
#[serde(alias = "include_types")]
|
||||
include_types: Option<Vec<String>>,
|
||||
#[serde(alias = "token_budget")]
|
||||
token_budget: Option<i32>,
|
||||
#[serde(alias = "retrieval_mode")]
|
||||
|
|
@ -163,6 +177,39 @@ pub async fn execute(
|
|||
let keyword_weight = 0.3_f32;
|
||||
let semantic_weight = 0.7_f32;
|
||||
|
||||
// ====================================================================
|
||||
// STAGE 0: Keyword-first search (dedicated keyword-only pass)
|
||||
// ====================================================================
|
||||
// Run a small keyword-only search to guarantee strong keyword matches
|
||||
// survive into the candidate pool, even with small limits/overfetch.
|
||||
// Without this, exact keyword matches (e.g. unique proper nouns) get
|
||||
// buried by semantic scoring in the hybrid search.
|
||||
let keyword_first_limit = 10_i32;
|
||||
let keyword_priority_threshold: f32 = 0.8;
|
||||
|
||||
let keyword_first_results = storage
|
||||
.hybrid_search_filtered(
|
||||
&args.query,
|
||||
keyword_first_limit,
|
||||
1.0, // keyword_weight = 1.0 (keyword-only)
|
||||
0.0, // semantic_weight = 0.0
|
||||
args.include_types.as_deref(),
|
||||
args.exclude_types.as_deref(),
|
||||
)
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
// Collect keyword-priority results (keyword_score >= threshold)
|
||||
let mut keyword_priority_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
|
||||
let mut keyword_priority_results: Vec<vestige_core::SearchResult> = Vec::new();
|
||||
for r in keyword_first_results {
|
||||
if r.keyword_score.unwrap_or(0.0) >= keyword_priority_threshold
|
||||
&& r.node.retention_strength >= min_retention
|
||||
{
|
||||
keyword_priority_ids.insert(r.node.id.clone());
|
||||
keyword_priority_results.push(r);
|
||||
}
|
||||
}
|
||||
|
||||
// ====================================================================
|
||||
// STAGE 1: Hybrid search with Nx over-fetch for reranking pool
|
||||
// ====================================================================
|
||||
|
|
@ -174,7 +221,14 @@ pub async fn execute(
|
|||
let overfetch_limit = (limit * overfetch_multiplier).min(100); // Cap at 100 to avoid excessive DB load
|
||||
|
||||
let results = storage
|
||||
.hybrid_search(&args.query, overfetch_limit, keyword_weight, semantic_weight)
|
||||
.hybrid_search_filtered(
|
||||
&args.query,
|
||||
overfetch_limit,
|
||||
keyword_weight,
|
||||
semantic_weight,
|
||||
args.include_types.as_deref(),
|
||||
args.exclude_types.as_deref(),
|
||||
)
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
// Filter by min_retention and min_similarity first (cheap filters)
|
||||
|
|
@ -193,25 +247,87 @@ pub async fn execute(
|
|||
})
|
||||
.collect();
|
||||
|
||||
// ====================================================================
|
||||
// Dedup: merge Stage 0 keyword-priority results into Stage 1 results
|
||||
// ====================================================================
|
||||
for kp in &keyword_priority_results {
|
||||
if let Some(existing) = filtered_results.iter_mut().find(|r| r.node.id == kp.node.id) {
|
||||
// Preserve keyword_score from Stage 0 (keyword-only search is authoritative)
|
||||
if kp.keyword_score.unwrap_or(0.0) > existing.keyword_score.unwrap_or(0.0) {
|
||||
existing.keyword_score = kp.keyword_score;
|
||||
}
|
||||
if kp.combined_score > existing.combined_score {
|
||||
existing.combined_score = kp.combined_score;
|
||||
}
|
||||
} else {
|
||||
// New result from Stage 0 not in Stage 1 — add it
|
||||
filtered_results.push(kp.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// ====================================================================
|
||||
// STAGE 2: Reranker (BM25-like rescoring, trim to requested limit)
|
||||
// ====================================================================
|
||||
if let Ok(mut cog) = cognitive.try_lock() {
|
||||
let candidates: Vec<_> = filtered_results
|
||||
.iter()
|
||||
.map(|r| (r.clone(), r.node.content.clone()))
|
||||
.collect();
|
||||
// Keyword bypass: results with strong keyword matches (>= 0.8) skip the
|
||||
// cross-encoder entirely and are placed above reranked results. This
|
||||
// prevents the cross-encoder from burying exact/near-exact keyword hits
|
||||
// (e.g. unique proper nouns) beneath semantically-similar but unrelated
|
||||
// results.
|
||||
{
|
||||
let keyword_bypass_threshold: f32 = 0.8;
|
||||
let limit_usize = limit as usize;
|
||||
|
||||
if let Ok(reranked) = cog.reranker.rerank(&args.query, candidates, Some(limit as usize)) {
|
||||
// Replace filtered_results with reranked items (preserves original SearchResult)
|
||||
filtered_results = reranked.into_iter().map(|rr| rr.item).collect();
|
||||
} else {
|
||||
// Reranker failed — fall back to original order, just truncate
|
||||
filtered_results.truncate(limit as usize);
|
||||
// Partition: keyword bypass vs. candidates for reranking
|
||||
let mut bypass_results: Vec<vestige_core::SearchResult> = Vec::new();
|
||||
let mut rerank_candidates: Vec<(vestige_core::SearchResult, String)> = Vec::new();
|
||||
|
||||
for r in filtered_results.iter() {
|
||||
if r.keyword_score.unwrap_or(0.0) >= keyword_bypass_threshold {
|
||||
bypass_results.push(r.clone());
|
||||
} else {
|
||||
rerank_candidates.push((r.clone(), r.node.content.clone()));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Couldn't acquire cognitive lock — truncate to limit
|
||||
filtered_results.truncate(limit as usize);
|
||||
|
||||
// Boost bypass results so they survive later pipeline stages
|
||||
// (temporal, FSRS, utility, competition) and the final re-sort.
|
||||
for r in bypass_results.iter_mut() {
|
||||
r.combined_score *= 2.0;
|
||||
}
|
||||
|
||||
bypass_results.sort_by(|a, b| {
|
||||
b.combined_score
|
||||
.partial_cmp(&a.combined_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// Rerank the remaining candidates
|
||||
let reranked_results: Vec<vestige_core::SearchResult> = if rerank_candidates.is_empty() {
|
||||
Vec::new()
|
||||
} else if let Ok(mut cog) = cognitive.try_lock() {
|
||||
if let Ok(reranked) = cog.reranker.rerank(&args.query, rerank_candidates, Some(limit_usize)) {
|
||||
reranked.into_iter().map(|rr| rr.item).collect()
|
||||
} else {
|
||||
// Reranker failed — fall back to original order for non-bypass candidates
|
||||
filtered_results
|
||||
.iter()
|
||||
.filter(|r| r.keyword_score.unwrap_or(0.0) < keyword_bypass_threshold)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
} else {
|
||||
// Couldn't acquire cognitive lock — use original order
|
||||
filtered_results
|
||||
.iter()
|
||||
.filter(|r| r.keyword_score.unwrap_or(0.0) < keyword_bypass_threshold)
|
||||
.cloned()
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Merge: bypass first, then reranked, trim to limit
|
||||
filtered_results = bypass_results;
|
||||
filtered_results.extend(reranked_results);
|
||||
filtered_results.truncate(limit_usize);
|
||||
}
|
||||
|
||||
// ====================================================================
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue