mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-30 01:35:14 +02:00
Phase 3: Memory
This commit is contained in:
parent
91b646cf66
commit
63e285c815
1 changed files with 282 additions and 0 deletions
282
trustgraph-flow/trustgraph/agent/confidence/memory.py
Normal file
282
trustgraph-flow/trustgraph/agent/confidence/memory.py
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
"""
|
||||
Memory Manager Module
|
||||
|
||||
Handles inter-step data flow and context preservation for the confidence agent.
|
||||
Manages execution context, result caching, and dependency resolution.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Set
|
||||
from .types import ContextEntry, ExecutionStep, StepResult
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""
|
||||
Manages execution context and inter-step data flow.
|
||||
|
||||
Responsibilities:
|
||||
- Store and retrieve execution context between steps
|
||||
- Manage step dependencies and result passing
|
||||
- Handle context window management
|
||||
- Provide result caching with TTL
|
||||
"""
|
||||
|
||||
def __init__(self, max_context_size: int = 8192, cache_ttl_seconds: int = 300):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.max_context_size = max_context_size
|
||||
self.cache_ttl_seconds = cache_ttl_seconds
|
||||
|
||||
# In-memory storage for Phase 1 (could be Redis/external in Phase 2)
|
||||
self._context: Dict[str, ContextEntry] = {}
|
||||
self._step_results: Dict[str, StepResult] = {}
|
||||
self._dependency_graph: Dict[str, Set[str]] = {} # step_id -> dependent_step_ids
|
||||
|
||||
def store_context(self, key: str, value: Any, step_id: str, ttl_seconds: Optional[int] = None) -> None:
|
||||
"""
|
||||
Store a context entry.
|
||||
|
||||
Args:
|
||||
key: Context key
|
||||
value: Value to store
|
||||
step_id: ID of step that created this entry
|
||||
ttl_seconds: Time to live (defaults to cache_ttl_seconds)
|
||||
"""
|
||||
ttl = ttl_seconds or self.cache_ttl_seconds
|
||||
|
||||
entry = ContextEntry(
|
||||
key=key,
|
||||
value=value,
|
||||
step_id=step_id,
|
||||
timestamp=int(time.time()),
|
||||
ttl_seconds=ttl
|
||||
)
|
||||
|
||||
self._context[key] = entry
|
||||
self.logger.debug(f"Stored context key '{key}' from step '{step_id}'")
|
||||
|
||||
# Clean up expired entries
|
||||
self._cleanup_expired()
|
||||
|
||||
# Manage context size
|
||||
self._manage_context_size()
|
||||
|
||||
def get_context(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Retrieve a context value.
|
||||
|
||||
Args:
|
||||
key: Context key to retrieve
|
||||
|
||||
Returns:
|
||||
Context value or None if not found/expired
|
||||
"""
|
||||
entry = self._context.get(key)
|
||||
if not entry:
|
||||
return None
|
||||
|
||||
# Check if expired
|
||||
if self._is_expired(entry):
|
||||
del self._context[key]
|
||||
return None
|
||||
|
||||
return entry.value
|
||||
|
||||
def get_context_for_step(self, step: ExecutionStep) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all relevant context for a step based on its dependencies.
|
||||
|
||||
Args:
|
||||
step: Execution step needing context
|
||||
|
||||
Returns:
|
||||
Dictionary of relevant context entries
|
||||
"""
|
||||
context = {}
|
||||
|
||||
# Include results from dependency steps
|
||||
for dep_step_id in step.dependencies:
|
||||
result = self._step_results.get(dep_step_id)
|
||||
if result and result.success:
|
||||
context[f"step_{dep_step_id}_output"] = result.output
|
||||
context[f"step_{dep_step_id}_confidence"] = result.confidence.score
|
||||
|
||||
# Include global context entries (filter by relevance if needed)
|
||||
for key, entry in self._context.items():
|
||||
if not self._is_expired(entry):
|
||||
context[key] = entry.value
|
||||
|
||||
self.logger.debug(f"Retrieved context for step '{step.id}': {len(context)} entries")
|
||||
return context
|
||||
|
||||
def store_step_result(self, result: StepResult) -> None:
|
||||
"""
|
||||
Store result from step execution.
|
||||
|
||||
Args:
|
||||
result: Step execution result
|
||||
"""
|
||||
self._step_results[result.step_id] = result
|
||||
|
||||
# Store key outputs in context for easy access
|
||||
if result.success:
|
||||
self.store_context(
|
||||
key=f"result_{result.step_id}",
|
||||
value=result.output,
|
||||
step_id=result.step_id
|
||||
)
|
||||
|
||||
self.logger.debug(f"Stored result for step '{result.step_id}' (success: {result.success})")
|
||||
|
||||
def get_step_result(self, step_id: str) -> Optional[StepResult]:
|
||||
"""
|
||||
Get stored result for a step.
|
||||
|
||||
Args:
|
||||
step_id: ID of step
|
||||
|
||||
Returns:
|
||||
StepResult or None if not found
|
||||
"""
|
||||
return self._step_results.get(step_id)
|
||||
|
||||
def register_dependency(self, step_id: str, depends_on: str) -> None:
|
||||
"""
|
||||
Register a dependency relationship between steps.
|
||||
|
||||
Args:
|
||||
step_id: Step that depends on another
|
||||
depends_on: Step that must complete first
|
||||
"""
|
||||
if depends_on not in self._dependency_graph:
|
||||
self._dependency_graph[depends_on] = set()
|
||||
|
||||
self._dependency_graph[depends_on].add(step_id)
|
||||
|
||||
def get_ready_steps(self, all_steps: List[ExecutionStep], completed_steps: Set[str]) -> List[ExecutionStep]:
|
||||
"""
|
||||
Get steps that are ready to execute (all dependencies completed).
|
||||
|
||||
Args:
|
||||
all_steps: All steps in the plan
|
||||
completed_steps: Set of completed step IDs
|
||||
|
||||
Returns:
|
||||
List of steps ready for execution
|
||||
"""
|
||||
ready = []
|
||||
|
||||
for step in all_steps:
|
||||
if step.id in completed_steps:
|
||||
continue
|
||||
|
||||
# Check if all dependencies are completed
|
||||
deps_completed = all(dep_id in completed_steps for dep_id in step.dependencies)
|
||||
|
||||
if deps_completed:
|
||||
ready.append(step)
|
||||
|
||||
return ready
|
||||
|
||||
def serialize_context(self) -> str:
|
||||
"""
|
||||
Serialize current context for debugging/audit.
|
||||
|
||||
Returns:
|
||||
JSON string of current context
|
||||
"""
|
||||
serializable = {}
|
||||
|
||||
for key, entry in self._context.items():
|
||||
if not self._is_expired(entry):
|
||||
# Convert complex objects to strings for serialization
|
||||
try:
|
||||
value = entry.value
|
||||
if not isinstance(value, (str, int, float, bool, list, dict)):
|
||||
value = str(value)
|
||||
|
||||
serializable[key] = {
|
||||
"value": value,
|
||||
"step_id": entry.step_id,
|
||||
"timestamp": entry.timestamp
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not serialize context key '{key}': {e}")
|
||||
|
||||
return json.dumps(serializable, indent=2)
|
||||
|
||||
def clear_context(self) -> None:
|
||||
"""Clear all stored context (for cleanup between requests)."""
|
||||
self._context.clear()
|
||||
self._step_results.clear()
|
||||
self._dependency_graph.clear()
|
||||
self.logger.debug("Cleared all context")
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get memory usage statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with usage statistics
|
||||
"""
|
||||
return {
|
||||
"context_entries": len(self._context),
|
||||
"step_results": len(self._step_results),
|
||||
"dependencies": sum(len(deps) for deps in self._dependency_graph.values()),
|
||||
"estimated_size_bytes": self._estimate_memory_size()
|
||||
}
|
||||
|
||||
def _is_expired(self, entry: ContextEntry) -> bool:
|
||||
"""Check if a context entry has expired."""
|
||||
if entry.ttl_seconds is None:
|
||||
return False
|
||||
|
||||
age_seconds = int(time.time()) - entry.timestamp
|
||||
return age_seconds > entry.ttl_seconds
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove expired context entries."""
|
||||
expired_keys = [
|
||||
key for key, entry in self._context.items()
|
||||
if self._is_expired(entry)
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self._context[key]
|
||||
|
||||
if expired_keys:
|
||||
self.logger.debug(f"Cleaned up {len(expired_keys)} expired context entries")
|
||||
|
||||
def _manage_context_size(self) -> None:
|
||||
"""Manage context size by removing oldest entries if needed."""
|
||||
current_size = self._estimate_memory_size()
|
||||
|
||||
if current_size > self.max_context_size:
|
||||
# Remove oldest entries first
|
||||
sorted_entries = sorted(
|
||||
self._context.items(),
|
||||
key=lambda x: x[1].timestamp
|
||||
)
|
||||
|
||||
removed_count = 0
|
||||
for key, entry in sorted_entries:
|
||||
del self._context[key]
|
||||
removed_count += 1
|
||||
|
||||
# Check size again
|
||||
if self._estimate_memory_size() <= self.max_context_size * 0.8:
|
||||
break
|
||||
|
||||
self.logger.debug(f"Removed {removed_count} context entries to manage size")
|
||||
|
||||
def _estimate_memory_size(self) -> int:
|
||||
"""Rough estimate of memory usage in bytes."""
|
||||
total = 0
|
||||
|
||||
for key, entry in self._context.items():
|
||||
total += len(key) * 2 # Unicode chars
|
||||
total += len(str(entry.value)) * 2 # Rough estimate
|
||||
total += 100 # Overhead
|
||||
|
||||
return total
|
||||
Loading…
Add table
Add a link
Reference in a new issue