vestige/crates/vestige-core/src/advanced/chains.rs
Sam Valladares f9c60eb5a7 Initial commit: Vestige v1.0.0 - Cognitive memory MCP server
FSRS-6 spaced repetition, spreading activation, synaptic tagging,
hippocampal indexing, and 130 years of memory research.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-25 01:31:03 -06:00

687 lines
21 KiB
Rust

//! # Memory Chains (Reasoning)
//!
//! Build chains of reasoning from memory, connecting concepts through
//! their relationships. This enables Vestige to explain HOW it arrived
//! at a conclusion, not just WHAT the conclusion is.
//!
//! ## Use Cases
//!
//! - **Explanation**: "Why do you think X is related to Y?"
//! - **Discovery**: Find non-obvious connections between concepts
//! - **Debugging**: Trace how a bug in A could affect component B
//! - **Learning**: Understand relationships in a domain
//!
//! ## How It Works
//!
//! 1. **Graph Traversal**: Navigate the knowledge graph with a best-first search (priority queue ordered by path score)
//! 2. **Path Scoring**: Score paths by relevance and connection strength
//! 3. **Chain Building**: Construct reasoning chains from paths
//! 4. **Explanation Generation**: Generate human-readable explanations
//!
//! ## Example
//!
//! ```rust,ignore
//! let builder = MemoryChainBuilder::new();
//!
//! // Build a reasoning chain from "database" to "performance".
//! // `build_chain` returns `None` when no path connects the two concepts.
//! if let Some(chain) = builder.build_chain("database", "performance") {
//!     // Shows: database -> indexes -> query optimization -> performance
//!     for step in &chain.steps {
//!         println!(
//!             "{}: {} -> {}",
//!             step.reasoning,
//!             step.memory_preview,
//!             step.connection_type.description()
//!         );
//!     }
//! }
//! ```
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// Maximum depth for chain building; search states that have reached this depth are not expanded
const MAX_CHAIN_DEPTH: usize = 10;
/// Maximum paths to explore: cap on the number of search states popped per starting memory
const MAX_PATHS_TO_EXPLORE: usize = 1000;
/// Minimum connection strength to consider; weaker stored connections are skipped during search
const MIN_CONNECTION_STRENGTH: f64 = 0.2;
/// Types of connections between memories
///
/// Each variant has a human-readable phrase (`description`) and a default
/// strength (`default_strength`) defined in the `impl` block of this type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ConnectionType {
/// Direct semantic similarity
SemanticSimilarity,
/// Same topic/tag
SharedTopic,
/// Temporal proximity (happened around same time)
TemporalProximity,
/// Causal relationship (A causes B)
Causal,
/// Part-whole relationship
PartOf,
/// Example-of relationship
ExampleOf,
/// Prerequisite relationship (need A to understand B)
Prerequisite,
/// Contradiction/conflict
Contradicts,
/// Elaboration (B provides more detail on A)
Elaborates,
/// Same entity/concept
SameEntity,
/// Used together
UsedTogether,
/// Custom relationship carrying a caller-supplied string
/// (the string is not used by `description` or `default_strength`)
Custom(String),
}
impl ConnectionType {
    /// Verb phrase that joins two memories in generated text
    /// (e.g. `"is part of"`). Every `Custom` value maps to the same
    /// generic phrase, regardless of its payload.
    pub fn description(&self) -> &str {
        match self {
            Self::Custom(_) => "is related to",
            Self::UsedTogether => "is commonly used with",
            Self::SameEntity => "refers to the same thing as",
            Self::Elaborates => "provides more detail about",
            Self::Contradicts => "contradicts",
            Self::Prerequisite => "is a prerequisite for",
            Self::ExampleOf => "is an example of",
            Self::PartOf => "is part of",
            Self::Causal => "causes or leads to",
            Self::TemporalProximity => "happened around the same time as",
            Self::SharedTopic => "shares topic with",
            Self::SemanticSimilarity => "is semantically similar to",
        }
    }

    /// Default strength for a connection of this type, listed from the
    /// strongest tier (1.0) down to the weakest (0.4).
    pub fn default_strength(&self) -> f64 {
        match self {
            Self::SameEntity => 1.0,
            Self::PartOf | Self::Causal => 0.9,
            Self::Elaborates | Self::Prerequisite => 0.8,
            Self::ExampleOf | Self::SemanticSimilarity => 0.7,
            Self::UsedTogether | Self::SharedTopic => 0.6,
            Self::Custom(_) | Self::Contradicts => 0.5,
            Self::TemporalProximity => 0.4,
        }
    }
}
/// A step in a reasoning chain
///
/// Steps are built by `MemoryChainBuilder::path_to_chain`, which pairs each
/// memory on a path with its outgoing connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainStep {
/// Memory at this step
pub memory_id: String,
/// Content preview (empty when the memory ID is not present in the graph)
pub memory_preview: String,
/// How this connects to the next step
/// (the final step of a chain carries a `SemanticSimilarity` placeholder)
pub connection_type: ConnectionType,
/// Strength of this connection (0.0 to 1.0); the final step's placeholder uses 1.0
pub connection_strength: f64,
/// Human-readable reasoning for this step
pub reasoning: String,
}
/// A complete reasoning chain
///
/// Produced by `MemoryChainBuilder::build_chain` from the best-scoring path
/// between two concepts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningChain {
/// Starting concept/memory (the `from` argument as given by the caller)
pub from: String,
/// Ending concept/memory (the `to` argument as given by the caller)
pub to: String,
/// Steps in the chain
pub steps: Vec<ChainStep>,
/// Overall confidence in this chain
pub confidence: f64,
/// Total number of hops
// NOTE(review): `path_to_chain` sets this to the number of memories on the
// path, not the number of connections between them — confirm which is intended.
pub total_hops: usize,
/// Human-readable explanation of the chain
pub explanation: String,
}
impl ReasoningChain {
    /// Reports whether any step's memory ID equals `to`.
    ///
    /// An empty chain is never complete.
    pub fn is_complete(&self) -> bool {
        // The last step is one of the steps, so scanning all of them
        // also covers the "last step is the destination" case.
        self.steps.iter().any(|step| step.memory_id == self.to)
    }

    /// Collects the memory IDs of all steps, in chain order.
    pub fn path_ids(&self) -> Vec<String> {
        let mut ids = Vec::with_capacity(self.steps.len());
        for step in &self.steps {
            ids.push(step.memory_id.clone());
        }
        ids
    }
}
/// A path between memories (used during search)
///
/// Returned by `MemoryChainBuilder::find_paths`, sorted by descending `score`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryPath {
/// Memory IDs in order
pub memories: Vec<String>,
/// Connections between consecutive memories
/// (paths produced by the search hold one fewer connection than memories)
pub connections: Vec<Connection>,
/// Total path score: starts at 1.0 and is multiplied by `strength * 0.9` for every hop
pub score: f64,
}
/// A connection between two memories
///
/// Connections are directed: `MemoryChainBuilder::add_connection` stores a
/// connection only on the node named by `from_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Connection {
/// Source memory
pub from_id: String,
/// Target memory
pub to_id: String,
/// Type of connection
pub connection_type: ConnectionType,
/// Strength (0.0 to 1.0)
// NOTE(review): nothing in this file validates the range — confirm callers enforce it.
pub strength: f64,
/// When this connection was established
pub created_at: DateTime<Utc>,
}
/// Memory node for graph operations
#[derive(Debug, Clone)]
pub struct MemoryNode {
/// Memory ID (used as the key in the builder's graph map)
pub id: String,
/// Content preview (searched case-insensitively when a concept is neither an ID nor a tag)
pub content_preview: String,
/// Tags/topics (each tag is indexed so memories sharing a tag can be linked during search)
pub tags: Vec<String>,
/// Connections to other memories (outgoing connections of this node)
pub connections: Vec<Connection>,
}
/// State for path search (used in priority queue)
///
/// Equality and ordering of states consider only `score`.
#[derive(Debug, Clone)]
struct SearchState {
// Memory the search is currently positioned at (last element of `path`)
memory_id: String,
// Memory IDs visited so far, from the start memory to `memory_id`
path: Vec<String>,
// Connections followed so far, one per hop in `path`
connections: Vec<Connection>,
// Accumulated path score; higher scores are popped first
score: f64,
// Number of hops taken from the start memory
depth: usize,
}
/// Two search states are equal exactly when [`Ord::cmp`] reports `Equal`.
///
/// Equality is derived from `cmp` so that `Eq` and `Ord` can never disagree,
/// including for NaN scores (a plain `f64 ==` would make a NaN state unequal
/// to itself, breaking the reflexivity `Eq` promises).
impl PartialEq for SearchState {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchState {}

impl PartialOrd for SearchState {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// States are ordered by `score` only; `BinaryHeap` is a max-heap, so the
/// highest-scoring state is popped first.
impl Ord for SearchState {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order over all f64 values, so a NaN
        // score (e.g. from a NaN connection strength) cannot produce the
        // inconsistent comparisons that `partial_cmp(..).unwrap_or(Equal)` does.
        self.score.total_cmp(&other.score)
    }
}
/// Builder for memory reasoning chains
///
/// Holds an in-memory graph of `MemoryNode`s plus a tag index, and answers
/// path queries between concepts (memory IDs, tags, or content substrings).
pub struct MemoryChainBuilder {
/// Memory graph (loaded from storage), keyed by memory ID
graph: HashMap<String, MemoryNode>,
/// Reverse index: tag -> memory IDs
tag_index: HashMap<String, Vec<String>>,
}
impl MemoryChainBuilder {
/// Creates a chain builder with an empty graph and an empty tag index.
pub fn new() -> Self {
    let graph = HashMap::new();
    let tag_index = HashMap::new();
    Self { graph, tag_index }
}
/// Load a memory node into the graph
///
/// The node is stored under its `id` and every one of its tags is recorded
/// in the tag index. Adding a node whose ID is already present replaces the
/// old node; the old node's tag entries are removed first so the index never
/// keeps stale or duplicate references to the ID.
pub fn add_memory(&mut self, node: MemoryNode) {
    // Replacing an existing node: purge its ID from the tags it used to carry.
    if let Some(previous) = self.graph.remove(&node.id) {
        for tag in &previous.tags {
            let now_empty = match self.tag_index.get_mut(tag) {
                Some(ids) => {
                    ids.retain(|id| id != &node.id);
                    ids.is_empty()
                }
                None => false,
            };
            // Drop emptied entries so a tag with no members no longer
            // resolves as a (member-less) tag.
            if now_empty {
                self.tag_index.remove(tag);
            }
        }
    }

    // Update tag index
    for tag in &node.tags {
        let ids = self.tag_index.entry(tag.clone()).or_default();
        // After the purge above, this ID can only already be present if the
        // node lists the same tag twice; in that case it is the last entry.
        if ids.last() != Some(&node.id) {
            ids.push(node.id.clone());
        }
    }

    self.graph.insert(node.id.clone(), node);
}
/// Attaches `connection` to the node named by its `from_id`.
///
/// If no such node has been loaded, the connection is discarded.
pub fn add_connection(&mut self, connection: Connection) {
    let source = match self.graph.get_mut(&connection.from_id) {
        Some(node) => node,
        // Unknown source memory: nothing to attach the connection to.
        None => return,
    };
    source.connections.push(connection);
}
/// Builds a reasoning chain from the highest-scoring path between two concepts.
///
/// Returns `None` when no path is found or the path cannot be turned into a chain.
pub fn build_chain(&self, from: &str, to: &str) -> Option<ReasoningChain> {
    // `find_paths` returns paths sorted best-first, so the first one (if any) wins.
    self.find_paths(from, to)
        .into_iter()
        .next()
        .and_then(|best| self.path_to_chain(from, to, best))
}
/// Find all paths between two concepts
pub fn find_paths(&self, concept_a: &str, concept_b: &str) -> Vec<MemoryPath> {
// Resolve concepts to memory IDs
let start_ids = self.resolve_concept(concept_a);
let end_ids: HashSet<_> = self.resolve_concept(concept_b).into_iter().collect();
if start_ids.is_empty() || end_ids.is_empty() {
return vec![];
}
let mut all_paths = Vec::new();
// BFS from each starting point
for start_id in start_ids {
let paths = self.bfs_find_paths(&start_id, &end_ids);
all_paths.extend(paths);
}
// Sort by score (descending)
all_paths.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
// Return top paths
all_paths.into_iter().take(10).collect()
}
/// Returns the textual explanation of the best chain linking `from` to `to`,
/// or `None` when no chain can be built.
pub fn explain_relationship(&self, from: &str, to: &str) -> Option<String> {
    self.build_chain(from, to).map(|chain| chain.explanation)
}
/// Find memories that connect two concepts
///
/// A bridge is a memory that appears strictly between the first and last
/// memory of a path returned by `find_paths`. Bridges are returned most
/// frequent first; bridges that appear equally often are ordered by ID so
/// the result is the same on every call.
pub fn find_bridge_memories(&self, concept_a: &str, concept_b: &str) -> Vec<String> {
    // Count how many paths each intermediate memory appears on.
    let mut counts: HashMap<String, usize> = HashMap::new();
    for path in self.find_paths(concept_a, concept_b) {
        // Skip the first and last memory; paths with fewer than three
        // memories have no interior and contribute nothing.
        let without_last = path.memories.len().saturating_sub(1);
        for id in path.memories.iter().take(without_last).skip(1) {
            *counts.entry(id.clone()).or_insert(0) += 1;
        }
    }

    // Sort by frequency (descending). The ID tie-break matters because
    // `HashMap` iteration order is unspecified and would otherwise leak
    // into the output.
    let mut ranked: Vec<(String, usize)> = counts.into_iter().collect();
    ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    ranked.into_iter().map(|(id, _)| id).collect()
}
/// Number of memory nodes currently held in the graph.
pub fn memory_count(&self) -> usize {
    let nodes = &self.graph;
    nodes.len()
}
/// Total number of connections stored across all nodes in the graph.
pub fn connection_count(&self) -> usize {
    let mut total = 0;
    for node in self.graph.values() {
        total += node.connections.len();
    }
    total
}
// ========================================================================
// Private implementation
// ========================================================================
/// Resolves a concept string to the memory IDs it refers to.
///
/// Resolution order:
/// 1. an exact memory ID,
/// 2. an exact tag (all memories indexed under that tag),
/// 3. a case-insensitive substring match on content previews, limited to
///    the ten lexicographically smallest matching IDs.
///
/// Returns an empty vector when nothing matches. An empty concept never
/// falls through to the content search.
fn resolve_concept(&self, concept: &str) -> Vec<String> {
    // First, check if it's a direct memory ID
    if self.graph.contains_key(concept) {
        return vec![concept.to_string()];
    }

    // Check tag index
    if let Some(ids) = self.tag_index.get(concept) {
        return ids.clone();
    }

    // An empty needle is a substring of every preview, which would turn an
    // empty query into "ten arbitrary memories"; treat it as no match.
    if concept.is_empty() {
        return Vec::new();
    }

    // Search by content (simplified - would use embeddings in production)
    let needle = concept.to_lowercase();
    let mut matches: Vec<String> = self
        .graph
        .values()
        .filter(|node| node.content_preview.to_lowercase().contains(&needle))
        .map(|node| node.id.clone())
        .collect();

    // `HashMap` iteration order is unspecified; sort before truncating so
    // the same graph always yields the same IDs.
    matches.sort_unstable();
    matches.truncate(10);
    matches
}
fn bfs_find_paths(&self, start: &str, targets: &HashSet<String>) -> Vec<MemoryPath> {
let mut paths = Vec::new();
let mut visited = HashSet::new();
let mut queue = BinaryHeap::new();
queue.push(SearchState {
memory_id: start.to_string(),
path: vec![start.to_string()],
connections: vec![],
score: 1.0,
depth: 0,
});
let mut explored = 0;
while let Some(state) = queue.pop() {
explored += 1;
if explored > MAX_PATHS_TO_EXPLORE {
break;
}
// Check if we reached a target
if targets.contains(&state.memory_id) {
paths.push(MemoryPath {
memories: state.path,
connections: state.connections,
score: state.score,
});
continue;
}
// Don't revisit or go too deep
if state.depth >= MAX_CHAIN_DEPTH {
continue;
}
let visit_key = (state.memory_id.clone(), state.depth);
if visited.contains(&visit_key) {
continue;
}
visited.insert(visit_key);
// Expand neighbors
if let Some(node) = self.graph.get(&state.memory_id) {
for conn in &node.connections {
if conn.strength < MIN_CONNECTION_STRENGTH {
continue;
}
if state.path.contains(&conn.to_id) {
continue; // Avoid cycles
}
let mut new_path = state.path.clone();
new_path.push(conn.to_id.clone());
let mut new_connections = state.connections.clone();
new_connections.push(conn.clone());
// Score decays with depth and connection strength
let new_score = state.score * conn.strength * 0.9;
queue.push(SearchState {
memory_id: conn.to_id.clone(),
path: new_path,
connections: new_connections,
score: new_score,
depth: state.depth + 1,
});
}
}
// Also explore tag-based connections
if let Some(node) = self.graph.get(&state.memory_id) {
for tag in &node.tags {
if let Some(related_ids) = self.tag_index.get(tag) {
for related_id in related_ids {
if state.path.contains(related_id) {
continue;
}
let mut new_path = state.path.clone();
new_path.push(related_id.clone());
let mut new_connections = state.connections.clone();
new_connections.push(Connection {
from_id: state.memory_id.clone(),
to_id: related_id.clone(),
connection_type: ConnectionType::SharedTopic,
strength: 0.5,
created_at: Utc::now(),
});
let new_score = state.score * 0.5 * 0.9;
queue.push(SearchState {
memory_id: related_id.clone(),
path: new_path,
connections: new_connections,
score: new_score,
depth: state.depth + 1,
});
}
}
}
}
}
paths
}
fn path_to_chain(&self, from: &str, to: &str, path: MemoryPath) -> Option<ReasoningChain> {
if path.memories.is_empty() {
return None;
}
let mut steps = Vec::new();
for (i, (mem_id, conn)) in path
.memories
.iter()
.zip(path.connections.iter().chain(std::iter::once(&Connection {
from_id: path.memories.last().cloned().unwrap_or_default(),
to_id: to.to_string(),
connection_type: ConnectionType::SemanticSimilarity,
strength: 1.0,
created_at: Utc::now(),
})))
.enumerate()
{
let preview = self
.graph
.get(mem_id)
.map(|n| n.content_preview.clone())
.unwrap_or_default();
let reasoning = if i == 0 {
format!("Starting from '{}'", preview)
} else {
format!(
"'{}' {} '{}'",
self.graph
.get(
&path
.memories
.get(i.saturating_sub(1))
.cloned()
.unwrap_or_default()
)
.map(|n| n.content_preview.as_str())
.unwrap_or(""),
conn.connection_type.description(),
preview
)
};
steps.push(ChainStep {
memory_id: mem_id.clone(),
memory_preview: preview,
connection_type: conn.connection_type.clone(),
connection_strength: conn.strength,
reasoning,
});
}
// Calculate overall confidence
let confidence = path
.connections
.iter()
.map(|c| c.strength)
.fold(1.0, |acc, s| acc * s)
.powf(1.0 / path.memories.len() as f64); // Geometric mean
// Generate explanation
let explanation = self.generate_explanation(&steps);
Some(ReasoningChain {
from: from.to_string(),
to: to.to_string(),
steps,
confidence,
total_hops: path.memories.len(),
explanation,
})
}
/// Renders a chain of steps as a single sentence.
///
/// The output has the form
/// "Starting from 'A', which <relation A→B> 'B', which <relation B→C> 'C'".
/// Because each step's `connection_type` describes how that step connects
/// to the NEXT one, the relation placed in front of a memory is taken from
/// the step before it.
///
/// Returns "No reasoning chain found." for an empty slice.
fn generate_explanation(&self, steps: &[ChainStep]) -> String {
    let first = match steps.first() {
        Some(step) => step,
        None => return "No reasoning chain found.".to_string(),
    };

    let mut parts = Vec::with_capacity(steps.len());
    parts.push(format!("Starting from '{}'", first.memory_preview));

    // Pair every step with its successor: the earlier step owns the
    // connection that links the two.
    for (previous, current) in steps.iter().zip(steps.iter().skip(1)) {
        parts.push(format!(
            "which {} '{}'",
            previous.connection_type.description(),
            current.memory_preview
        ));
    }

    parts.join(", ")
}
}
impl Default for MemoryChainBuilder {
fn default() -> Self {
Self::new()
}
}
// Unit tests for chain building, path finding and bridge detection on a
// small four-memory graph.
#[cfg(test)]
mod tests {
use super::*;
// Builds the shared fixture: four memories linked by explicit connections
// database -> indexes -> query-opt -> perf, with overlapping tags
// ("database" and "performance") that also link them through the tag index.
fn build_test_graph() -> MemoryChainBuilder {
let mut builder = MemoryChainBuilder::new();
// Add test memories
builder.add_memory(MemoryNode {
id: "database".to_string(),
content_preview: "Database design patterns".to_string(),
tags: vec!["database".to_string(), "architecture".to_string()],
connections: vec![],
});
builder.add_memory(MemoryNode {
id: "indexes".to_string(),
content_preview: "Database indexing strategies".to_string(),
tags: vec!["database".to_string(), "performance".to_string()],
connections: vec![],
});
builder.add_memory(MemoryNode {
id: "query-opt".to_string(),
content_preview: "Query optimization techniques".to_string(),
tags: vec!["performance".to_string(), "sql".to_string()],
connections: vec![],
});
builder.add_memory(MemoryNode {
id: "perf".to_string(),
content_preview: "Performance best practices".to_string(),
tags: vec!["performance".to_string()],
connections: vec![],
});
// Add connections (each is stored on its `from_id` node, so the memories must already exist)
builder.add_connection(Connection {
from_id: "database".to_string(),
to_id: "indexes".to_string(),
connection_type: ConnectionType::PartOf,
strength: 0.9,
created_at: Utc::now(),
});
builder.add_connection(Connection {
from_id: "indexes".to_string(),
to_id: "query-opt".to_string(),
connection_type: ConnectionType::Causal,
strength: 0.8,
created_at: Utc::now(),
});
builder.add_connection(Connection {
from_id: "query-opt".to_string(),
to_id: "perf".to_string(),
connection_type: ConnectionType::Causal,
strength: 0.85,
created_at: Utc::now(),
});
builder
}
// A chain between two memory IDs exists, spans at least two memories and
// has positive confidence.
#[test]
fn test_build_chain() {
let builder = build_test_graph();
let chain = builder.build_chain("database", "perf");
assert!(chain.is_some());
let chain = chain.unwrap();
assert!(chain.total_hops >= 2);
assert!(chain.confidence > 0.0);
}
// "database" resolves as a memory ID; "performance" is not an ID and
// resolves through the tag index instead.
#[test]
fn test_find_paths() {
let builder = build_test_graph();
let paths = builder.find_paths("database", "performance");
assert!(!paths.is_empty());
}
// Spot-checks the human-readable phrases for two connection types.
#[test]
fn test_connection_description() {
assert_eq!(ConnectionType::Causal.description(), "causes or leads to");
assert_eq!(ConnectionType::PartOf.description(), "is part of");
}
// Intermediate memories between "database" and "perf" are reported as bridges.
#[test]
fn test_find_bridge_memories() {
let builder = build_test_graph();
let bridges = builder.find_bridge_memories("database", "perf");
// Indexes and query-opt should be bridges
// (the assertion only requires that at least one of them is reported)
assert!(
bridges.contains(&"indexes".to_string()) || bridges.contains(&"query-opt".to_string())
);
}
}