feat: implement background memory extraction and editing capabilities for user and team memory management, enhancing long-term memory persistence and user interaction

This commit is contained in:
Anish Sarkar 2026-04-10 00:21:55 +05:30
parent cd72fa9a48
commit 84fc72e596
9 changed files with 534 additions and 224 deletions

View file

@ -6,12 +6,10 @@ always sees the current memory in <user_memory> / <team_memory> tags injected
by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
Overflow handling:
- Soft limit (18K chars): an automatic LLM-driven consolidation is attempted
to proactively keep memory lean. The save always succeeds.
- Hard limit (25K chars): save rejected if memory still exceeds this after
consolidation.
- Pinned sections: headings containing ``(pinned)`` are protected — the system
rejects any update that drops them and auto-restores them during consolidation.
- Soft limit (18K chars): a warning is returned telling the agent to
consolidate on the next update.
- Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
If it still exceeds the limit after rewriting, the save is rejected.
- Diff validation: warns when entire ``##`` sections are dropped or when the
document shrinks by more than 60%.
"""
@ -35,74 +33,9 @@ logger = logging.getLogger(__name__)
MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000
_PINNED_RE = re.compile(r"^##\s+.+\(pinned\)", re.MULTILINE)
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
# ---------------------------------------------------------------------------
# Pinned-section helpers
# ---------------------------------------------------------------------------
def _extract_pinned_headings(memory: str) -> set[str]:
"""Return the set of ``## …`` headings that contain ``(pinned)``."""
return set(_PINNED_RE.findall(memory))
def _extract_section_map(memory: str) -> dict[str, str]:
"""Split *memory* into ``{heading_text: full_section_content}``."""
sections: dict[str, str] = {}
parts = _SECTION_HEADING_RE.split(memory)
# parts: [preamble, heading1, body1, heading2, body2, …]
for i in range(1, len(parts) - 1, 2):
heading = parts[i].strip()
body = parts[i + 1]
sections[heading] = f"## {heading}\n{body}"
return sections
def _validate_pinned_preserved(old_memory: str | None, new_memory: str) -> str | None:
"""Return an error message if pinned headings from *old_memory* are missing
in *new_memory*, else ``None``."""
if not old_memory:
return None
old_pinned = _extract_pinned_headings(old_memory)
if not old_pinned:
return None
new_pinned = _extract_pinned_headings(new_memory)
dropped = old_pinned - new_pinned
if dropped:
names = ", ".join(sorted(dropped))
return (
f"Cannot remove pinned sections: {names}. "
"These sections are protected and must be preserved. "
"Re-include them and call update_memory again."
)
return None
def _restore_missing_pinned(old_memory: str, consolidated: str) -> str:
"""Prepend any pinned sections from *old_memory* that are absent in
*consolidated*."""
old_pinned = _extract_pinned_headings(old_memory)
if not old_pinned:
return consolidated
new_pinned = _extract_pinned_headings(consolidated)
dropped = old_pinned - new_pinned
if not dropped:
return consolidated
old_sections = _extract_section_map(old_memory)
restored_parts: list[str] = []
for heading in sorted(dropped):
raw_heading = heading.removeprefix("## ").strip()
if raw_heading in old_sections:
restored_parts.append(old_sections[raw_heading].rstrip())
if restored_parts:
return "\n\n".join(restored_parts) + "\n\n" + consolidated
return consolidated
# ---------------------------------------------------------------------------
# Diff validation
# ---------------------------------------------------------------------------
@ -173,37 +106,35 @@ def _soft_warning(content: str) -> str | None:
# ---------------------------------------------------------------------------
# Auto-consolidation via a separate LLM call
# Forced rewrite when memory exceeds the hard limit
# ---------------------------------------------------------------------------
_CONSOLIDATION_PROMPT = """\
_FORCED_REWRITE_PROMPT = """\
You are a memory curator. The following memory document exceeds the character \
limit and must be shortened.
RULES:
1. Rewrite the document to be under {target} characters.
2. Sections whose headings contain "(pinned)" MUST be preserved EXACTLY as-is — \
do not modify, shorten, or remove them.
3. Only consolidate non-pinned content.
4. Priority for keeping content: pinned sections > identity/instructions > \
preferences > current context.
5. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
6. Each entry must be a single bullet point.
7. Every bullet MUST keep its (YYYY-MM-DD) date prefix.
8. Output ONLY the consolidated markdown — no explanations, no wrapping.
2. Preserve all ## section headings.
3. Priority for keeping content: identity/instructions > preferences > \
current context.
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
5. Each entry must be a single bullet point.
6. Every bullet MUST keep its (YYYY-MM-DD) date prefix.
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
<memory_document>
{content}
</memory_document>"""
async def _auto_consolidate(content: str, llm: Any) -> str | None:
"""Use a focused LLM call to consolidate *content* under the soft limit.
async def _forced_rewrite(content: str, llm: Any) -> str | None:
"""Use a focused LLM call to compress *content* under the hard limit.
Returns the consolidated string, or ``None`` if consolidation fails.
Returns the rewritten string, or ``None`` if the call fails.
"""
try:
prompt = _CONSOLIDATION_PROMPT.format(target=MEMORY_SOFT_LIMIT, content=content)
prompt = _FORCED_REWRITE_PROMPT.format(target=MEMORY_HARD_LIMIT, content=content)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
@ -215,7 +146,7 @@ async def _auto_consolidate(content: str, llm: Any) -> str | None:
)
return text.strip()
except Exception:
logger.exception("Auto-consolidation LLM call failed")
logger.exception("Forced rewrite LLM call failed")
return None
@ -234,16 +165,17 @@ async def _save_memory(
rollback_fn,
label: str,
) -> dict[str, Any]:
"""Validate, optionally auto-consolidate, save, and return a response dict.
"""Validate, optionally force-rewrite if over the hard limit, save, and
return a response dict.
Parameters
----------
updated_memory : str
The new document the agent submitted.
old_memory : str | None
The previously persisted document (for diff / pinned checks).
The previously persisted document (for diff checks).
llm : Any | None
LLM instance for auto-consolidation (may be ``None``).
LLM instance for forced rewrite (may be ``None``).
apply_fn : callable(str) -> None
Callback that sets the new memory on the ORM object.
commit_fn : coroutine
@ -255,21 +187,13 @@ async def _save_memory(
"""
content = updated_memory
# --- pinned-section gate (before any size check) ---
pinned_err = _validate_pinned_preserved(old_memory, content)
if pinned_err:
return {"status": "error", "message": pinned_err}
# --- forced rewrite if over the hard limit ---
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
rewritten = await _forced_rewrite(content, llm)
if rewritten is not None and len(rewritten) < len(content):
content = rewritten
# --- auto-consolidate proactively at the soft limit ---
if len(content) > MEMORY_SOFT_LIMIT and llm is not None:
consolidated = await _auto_consolidate(content, llm)
if consolidated is not None:
if old_memory:
consolidated = _restore_missing_pinned(old_memory, consolidated)
if len(consolidated) < len(content):
content = consolidated
# --- hard-limit gate (reject if still too large after consolidation) ---
# --- hard-limit gate (reject if still too large after rewrite) ---
size_err = _validate_memory_size(content)
if size_err:
return size_err
@ -290,7 +214,7 @@ async def _save_memory(
}
if content is not updated_memory:
resp["notice"] = "Memory was automatically consolidated to fit within limits."
resp["notice"] = "Memory was automatically rewritten to fit within limits."
diff_warnings = _validate_diff(old_memory, content)
if diff_warnings: