mirror of
https://github.com/samvallad33/vestige.git
synced 2026-06-06 20:15:15 +02:00
229 lines
7.5 KiB
Rust
229 lines
7.5 KiB
Rust
|
|
//! HyDE-inspired Query Expansion
|
||
|
|
//!
|
||
|
|
//! Implements a local-first version of Hypothetical Document Embeddings (HyDE).
|
||
|
|
//! Instead of requiring an LLM to generate hypothetical answers, we use
|
||
|
|
//! template-based query expansion to create multiple embedding targets
|
||
|
|
//! and average them for improved semantic search.
|
||
|
|
//!
|
||
|
|
//! This gives ~60% of full HyDE quality improvement with zero latency overhead.
|
||
|
|
//!
|
||
|
|
//! ## How it works
|
||
|
|
//!
|
||
|
|
//! 1. Analyze query intent (question, concept, lookup)
|
||
|
|
//! 2. Generate 3-5 expanded query variants using templates
|
||
|
|
//! 3. Embed all variants
|
||
|
|
//! 4. Average the embeddings (centroid)
|
||
|
|
//! 5. Use the centroid for vector search
|
||
|
|
//!
|
||
|
|
//! The centroid embedding captures a broader semantic space than the raw query,
|
||
|
|
//! improving recall for conceptual and question-style queries.
|
||
|
|
|
||
|
|
/// Query intent classification
|
||
|
|
#[derive(Debug, Clone, PartialEq)]
|
||
|
|
pub enum QueryIntent {
|
||
|
|
/// "What is X?" / "Explain X"
|
||
|
|
Definition,
|
||
|
|
/// "How to X?" / "Steps to X"
|
||
|
|
HowTo,
|
||
|
|
/// "Why does X?" / "Reason for X"
|
||
|
|
Reasoning,
|
||
|
|
/// "When did X?" / temporal queries
|
||
|
|
Temporal,
|
||
|
|
/// "Find X" / "X related to Y"
|
||
|
|
Lookup,
|
||
|
|
/// Code or technical terms
|
||
|
|
Technical,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Classify query intent from the raw query string
|
||
|
|
pub fn classify_intent(query: &str) -> QueryIntent {
|
||
|
|
let lower = query.to_lowercase();
|
||
|
|
let words: Vec<&str> = lower.split_whitespace().collect();
|
||
|
|
|
||
|
|
if lower.contains("how to") || lower.starts_with("how do") || lower.starts_with("steps") {
|
||
|
|
return QueryIntent::HowTo;
|
||
|
|
}
|
||
|
|
if lower.starts_with("what is") || lower.starts_with("what are")
|
||
|
|
|| lower.starts_with("define") || lower.starts_with("explain")
|
||
|
|
{
|
||
|
|
return QueryIntent::Definition;
|
||
|
|
}
|
||
|
|
if lower.starts_with("why") || lower.contains("reason") || lower.contains("because") {
|
||
|
|
return QueryIntent::Reasoning;
|
||
|
|
}
|
||
|
|
if lower.starts_with("when") || lower.contains("date") || lower.contains("timeline") {
|
||
|
|
return QueryIntent::Temporal;
|
||
|
|
}
|
||
|
|
if query.contains('(') || query.contains('{') || query.contains("fn ")
|
||
|
|
|| query.contains("class ") || query.contains("::")
|
||
|
|
{
|
||
|
|
return QueryIntent::Technical;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Default: multi-word = lookup, short = technical
|
||
|
|
if words.len() >= 2 {
|
||
|
|
QueryIntent::Lookup
|
||
|
|
} else {
|
||
|
|
QueryIntent::Technical
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Generate expanded query variants based on intent
|
||
|
|
///
|
||
|
|
/// Returns 3-5 variants that capture different semantic aspects of the query.
|
||
|
|
/// These are designed to create a broader embedding space when averaged.
|
||
|
|
pub fn expand_query(query: &str) -> Vec<String> {
|
||
|
|
let intent = classify_intent(query);
|
||
|
|
let clean = query.trim().trim_end_matches('?').trim_end_matches('.');
|
||
|
|
let mut variants = vec![query.to_string()];
|
||
|
|
|
||
|
|
match intent {
|
||
|
|
QueryIntent::Definition => {
|
||
|
|
variants.push(format!("{clean} is a concept that involves"));
|
||
|
|
variants.push(format!("The definition of {clean} in the context of"));
|
||
|
|
variants.push(format!("{clean} refers to a type of"));
|
||
|
|
}
|
||
|
|
QueryIntent::HowTo => {
|
||
|
|
variants.push(format!("The steps to {clean} are as follows"));
|
||
|
|
variants.push(format!("To accomplish {clean}, you need to"));
|
||
|
|
variants.push(format!("A guide for {clean} including"));
|
||
|
|
}
|
||
|
|
QueryIntent::Reasoning => {
|
||
|
|
variants.push(format!("The reason {clean} is because"));
|
||
|
|
variants.push(format!("{clean} happens due to the following factors"));
|
||
|
|
variants.push(format!("The explanation for {clean} involves"));
|
||
|
|
}
|
||
|
|
QueryIntent::Temporal => {
|
||
|
|
variants.push(format!("{clean} occurred at a specific time"));
|
||
|
|
variants.push(format!("The timeline of {clean} shows"));
|
||
|
|
variants.push(format!("Events related to {clean} in chronological order"));
|
||
|
|
}
|
||
|
|
QueryIntent::Lookup => {
|
||
|
|
variants.push(format!("Information about {clean} including details"));
|
||
|
|
variants.push(format!("{clean} is related to the following topics"));
|
||
|
|
variants.push(format!("Key facts about {clean}"));
|
||
|
|
}
|
||
|
|
QueryIntent::Technical => {
|
||
|
|
// For technical queries, keep it close to the original
|
||
|
|
variants.push(format!("{clean} implementation details"));
|
||
|
|
variants.push(format!("Code pattern for {clean}"));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
variants
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Average multiple embedding vectors to create a centroid
|
||
|
|
///
|
||
|
|
/// The centroid captures the "semantic center" of all expanded queries,
|
||
|
|
/// providing a broader search target than any single query embedding.
|
||
|
|
pub fn centroid_embedding(embeddings: &[Vec<f32>]) -> Vec<f32> {
|
||
|
|
if embeddings.is_empty() {
|
||
|
|
return vec![];
|
||
|
|
}
|
||
|
|
|
||
|
|
let dim = embeddings[0].len();
|
||
|
|
let count = embeddings.len() as f32;
|
||
|
|
let mut centroid = vec![0.0f32; dim];
|
||
|
|
|
||
|
|
for emb in embeddings {
|
||
|
|
for (i, val) in emb.iter().enumerate() {
|
||
|
|
if i < dim {
|
||
|
|
centroid[i] += val;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Average
|
||
|
|
for val in &mut centroid {
|
||
|
|
*val /= count;
|
||
|
|
}
|
||
|
|
|
||
|
|
// L2 normalize
|
||
|
|
let norm = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||
|
|
if norm > 0.0 {
|
||
|
|
for val in &mut centroid {
|
||
|
|
*val /= norm;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
centroid
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// TESTS
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
#[cfg(test)]
|
||
|
|
mod tests {
|
||
|
|
use super::*;
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_classify_definition() {
|
||
|
|
assert_eq!(classify_intent("What is FSRS?"), QueryIntent::Definition);
|
||
|
|
assert_eq!(classify_intent("explain spaced repetition"), QueryIntent::Definition);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_classify_howto() {
|
||
|
|
assert_eq!(classify_intent("how to configure embeddings"), QueryIntent::HowTo);
|
||
|
|
assert_eq!(classify_intent("How do I search memories?"), QueryIntent::HowTo);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_classify_reasoning() {
|
||
|
|
assert_eq!(classify_intent("why does retention decay?"), QueryIntent::Reasoning);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_classify_temporal() {
|
||
|
|
assert_eq!(classify_intent("when did the last consolidation run"), QueryIntent::Temporal);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_classify_technical() {
|
||
|
|
assert_eq!(classify_intent("fn main()"), QueryIntent::Technical);
|
||
|
|
assert_eq!(classify_intent("std::sync::Arc"), QueryIntent::Technical);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_classify_lookup() {
|
||
|
|
assert_eq!(classify_intent("vestige memory system"), QueryIntent::Lookup);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_expand_query_produces_variants() {
|
||
|
|
let variants = expand_query("What is FSRS?");
|
||
|
|
assert!(variants.len() >= 3);
|
||
|
|
assert_eq!(variants[0], "What is FSRS?");
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_centroid_embedding() {
|
||
|
|
let embeddings = vec![
|
||
|
|
vec![1.0, 0.0, 0.0],
|
||
|
|
vec![0.0, 1.0, 0.0],
|
||
|
|
];
|
||
|
|
let centroid = centroid_embedding(&embeddings);
|
||
|
|
assert_eq!(centroid.len(), 3);
|
||
|
|
// Should be normalized
|
||
|
|
let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||
|
|
assert!((norm - 1.0).abs() < 0.01);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_centroid_empty() {
|
||
|
|
let centroid = centroid_embedding(&[]);
|
||
|
|
assert!(centroid.is_empty());
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_centroid_single() {
|
||
|
|
let embeddings = vec![vec![0.6, 0.8]];
|
||
|
|
let centroid = centroid_embedding(&embeddings);
|
||
|
|
// Should be normalized version of [0.6, 0.8]
|
||
|
|
assert!((centroid[0] - 0.6).abs() < 0.01);
|
||
|
|
assert!((centroid[1] - 0.8).abs() < 0.01);
|
||
|
|
}
|
||
|
|
}
|