Mirror of https://github.com/MODSetter/SurfSense.git, synced 2026-04-30 19:36:25 +02:00
feat: implement background memory extraction and editing capabilities for user and team memory management, enhancing long-term memory persistence and user interaction
commit 84fc72e596
parent cd72fa9a48
9 changed files with 534 additions and 224 deletions
@@ -6,12 +6,10 @@ always sees the current memory in <user_memory> / <team_memory> tags injected
 by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
 
 Overflow handling:
-- Soft limit (18K chars): an automatic LLM-driven consolidation is attempted
-  to proactively keep memory lean. The save always succeeds.
-- Hard limit (25K chars): save rejected if memory still exceeds this after
-  consolidation.
-- Pinned sections: headings containing ``(pinned)`` are protected — the system
-  rejects any update that drops them and auto-restores them during consolidation.
+- Soft limit (18K chars): a warning is returned telling the agent to
+  consolidate on the next update.
+- Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
+  If it still exceeds the limit after rewriting, the save is rejected.
 - Diff validation: warns when entire ``##`` sections are dropped or when the
   document shrinks by more than 60%.
 """
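Taken together, the new docstring describes a save that only hard-fails when even a rewrite cannot fit the document. A minimal, self-contained sketch of that policy; the function and key names here are invented for illustration, not the module's actual API:

```python
# Illustrative sketch only: mirrors the soft/hard limit behaviour described above.
MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000

def apply_overflow_policy(content: str, rewrite_fn=None) -> dict:
    # Hard limit: attempt one forced rewrite (LLM-backed in the real code)
    # before rejecting the save outright.
    if len(content) > MEMORY_HARD_LIMIT and rewrite_fn is not None:
        rewritten = rewrite_fn(content)
        if rewritten and len(rewritten) < len(content):
            content = rewritten
    if len(content) > MEMORY_HARD_LIMIT:
        return {"status": "error", "message": "memory still exceeds the hard limit"}
    resp = {"status": "success", "content": content}
    # Soft limit: the save succeeds, but the agent is asked to consolidate next time.
    if len(content) > MEMORY_SOFT_LIMIT:
        resp["warning"] = "memory exceeds the soft limit; consolidate on the next update"
    return resp

print(apply_overflow_policy("x" * 20_000)["warning"])
```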
@@ -35,74 +33,9 @@ logger = logging.getLogger(__name__)
 MEMORY_SOFT_LIMIT = 18_000
 MEMORY_HARD_LIMIT = 25_000
 
-_PINNED_RE = re.compile(r"^##\s+.+\(pinned\)", re.MULTILINE)
 _SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
-
-
-# ---------------------------------------------------------------------------
-# Pinned-section helpers
-# ---------------------------------------------------------------------------
-
-
-def _extract_pinned_headings(memory: str) -> set[str]:
-    """Return the set of ``## …`` headings that contain ``(pinned)``."""
-    return set(_PINNED_RE.findall(memory))
-
-
-def _extract_section_map(memory: str) -> dict[str, str]:
-    """Split *memory* into ``{heading_text: full_section_content}``."""
-    sections: dict[str, str] = {}
-    parts = _SECTION_HEADING_RE.split(memory)
-    # parts: [preamble, heading1, body1, heading2, body2, …]
-    for i in range(1, len(parts) - 1, 2):
-        heading = parts[i].strip()
-        body = parts[i + 1]
-        sections[heading] = f"## {heading}\n{body}"
-    return sections
-
-
-def _validate_pinned_preserved(old_memory: str | None, new_memory: str) -> str | None:
-    """Return an error message if pinned headings from *old_memory* are missing
-    in *new_memory*, else ``None``."""
-    if not old_memory:
-        return None
-    old_pinned = _extract_pinned_headings(old_memory)
-    if not old_pinned:
-        return None
-    new_pinned = _extract_pinned_headings(new_memory)
-    dropped = old_pinned - new_pinned
-    if dropped:
-        names = ", ".join(sorted(dropped))
-        return (
-            f"Cannot remove pinned sections: {names}. "
-            "These sections are protected and must be preserved. "
-            "Re-include them and call update_memory again."
-        )
-    return None
-
-
-def _restore_missing_pinned(old_memory: str, consolidated: str) -> str:
-    """Prepend any pinned sections from *old_memory* that are absent in
-    *consolidated*."""
-    old_pinned = _extract_pinned_headings(old_memory)
-    if not old_pinned:
-        return consolidated
-    new_pinned = _extract_pinned_headings(consolidated)
-    dropped = old_pinned - new_pinned
-    if not dropped:
-        return consolidated
-
-    old_sections = _extract_section_map(old_memory)
-    restored_parts: list[str] = []
-    for heading in sorted(dropped):
-        raw_heading = heading.removeprefix("## ").strip()
-        if raw_heading in old_sections:
-            restored_parts.append(old_sections[raw_heading].rstrip())
-    if restored_parts:
-        return "\n\n".join(restored_parts) + "\n\n" + consolidated
-    return consolidated
 
 
 # ---------------------------------------------------------------------------
 # Diff validation
 # ---------------------------------------------------------------------------
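The deleted helpers above treat memory as plain markdown split on ``##`` headings. A standalone sketch of that splitting, using an invented sample document (the headings and dates below are made up for illustration):

```python
import re

_PINNED_RE = re.compile(r"^##\s+.+\(pinned\)", re.MULTILINE)
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)

sample = (
    "## Identity (pinned)\n- (2024-01-05) Prefers concise answers\n\n"
    "## Current context\n- (2024-02-10) Working on the billing service\n"
)

# _extract_pinned_headings: full "## ..." heading lines that contain "(pinned)".
print(_PINNED_RE.findall(sample))  # ['## Identity (pinned)']

# _extract_section_map: re.split with a capture group alternates
# [preamble, heading1, body1, heading2, body2, ...].
parts = _SECTION_HEADING_RE.split(sample)
sections = {
    parts[i].strip(): f"## {parts[i].strip()}\n{parts[i + 1]}"
    for i in range(1, len(parts) - 1, 2)
}
print(list(sections))  # ['Identity (pinned)', 'Current context']
```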
@@ -173,37 +106,35 @@ def _soft_warning(content: str) -> str | None:
 
 
 # ---------------------------------------------------------------------------
-# Auto-consolidation via a separate LLM call
+# Forced rewrite when memory exceeds the hard limit
 # ---------------------------------------------------------------------------
 
-_CONSOLIDATION_PROMPT = """\
+_FORCED_REWRITE_PROMPT = """\
 You are a memory curator. The following memory document exceeds the character \
 limit and must be shortened.
 
 RULES:
 1. Rewrite the document to be under {target} characters.
-2. Sections whose headings contain "(pinned)" MUST be preserved EXACTLY as-is \
-— do not modify, shorten, or remove them.
-3. Only consolidate non-pinned content.
-4. Priority for keeping content: pinned sections > identity/instructions > \
-preferences > current context.
-5. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
-6. Each entry must be a single bullet point.
-7. Every bullet MUST keep its (YYYY-MM-DD) date prefix.
-8. Output ONLY the consolidated markdown — no explanations, no wrapping.
+2. Preserve all ## section headings.
+3. Priority for keeping content: identity/instructions > preferences > \
+current context.
+4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
+5. Each entry must be a single bullet point.
+6. Every bullet MUST keep its (YYYY-MM-DD) date prefix.
+7. Output ONLY the consolidated markdown — no explanations, no wrapping.
 
 <memory_document>
 {content}
 </memory_document>"""
 
 
-async def _auto_consolidate(content: str, llm: Any) -> str | None:
-    """Use a focused LLM call to consolidate *content* under the soft limit.
+async def _forced_rewrite(content: str, llm: Any) -> str | None:
+    """Use a focused LLM call to compress *content* under the hard limit.
 
-    Returns the consolidated string, or ``None`` if consolidation fails.
+    Returns the rewritten string, or ``None`` if the call fails.
     """
     try:
-        prompt = _CONSOLIDATION_PROMPT.format(target=MEMORY_SOFT_LIMIT, content=content)
+        prompt = _FORCED_REWRITE_PROMPT.format(target=MEMORY_HARD_LIMIT, content=content)
         response = await llm.ainvoke(
             [HumanMessage(content=prompt)],
             config={"tags": ["surfsense:internal"]},
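A rough, self-contained sketch of the contract `_forced_rewrite` follows: one focused chat call, the rewritten text on success, ``None`` if the call raises. The `FakeLLM` below is invented for illustration; the real function goes through the project's configured chat model and `HumanMessage`:

```python
import asyncio

class FakeLLM:
    """Stand-in chat model: always answers with a fixed, shorter document."""

    class _Msg:
        content = "## Preferences\n- (2024-03-01) Short answers\n"

    async def ainvoke(self, messages, config=None):
        return self._Msg()

async def forced_rewrite_sketch(content: str, llm) -> str | None:
    # Mirrors the shape shown in the diff: format a prompt, make one tagged
    # call, return the stripped text, or None if anything goes wrong.
    try:
        response = await llm.ainvoke(
            [{"role": "user", "content": f"Rewrite under 25000 chars:\n{content}"}],
            config={"tags": ["surfsense:internal"]},
        )
        return response.content.strip()
    except Exception:
        return None

print(asyncio.run(forced_rewrite_sketch("x" * 30_000, FakeLLM())))
```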
@@ -215,7 +146,7 @@ async def _auto_consolidate(content: str, llm: Any) -> str | None:
         )
         return text.strip()
     except Exception:
-        logger.exception("Auto-consolidation LLM call failed")
+        logger.exception("Forced rewrite LLM call failed")
         return None
 
 
@@ -234,16 +165,17 @@ async def _save_memory(
     rollback_fn,
     label: str,
 ) -> dict[str, Any]:
-    """Validate, optionally auto-consolidate, save, and return a response dict.
+    """Validate, optionally force-rewrite if over the hard limit, save, and
+    return a response dict.
 
     Parameters
     ----------
     updated_memory : str
         The new document the agent submitted.
     old_memory : str | None
-        The previously persisted document (for diff / pinned checks).
+        The previously persisted document (for diff checks).
     llm : Any | None
-        LLM instance for auto-consolidation (may be ``None``).
+        LLM instance for forced rewrite (may be ``None``).
     apply_fn : callable(str) -> None
         Callback that sets the new memory on the ORM object.
     commit_fn : coroutine
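The callback parameters document a small choreography: apply the new document, commit, and roll back on failure. A self-contained sketch of just that choreography, with invented stand-ins for the ORM object and session (the real `_save_memory` adds the validation and rewrite steps shown elsewhere in this diff):

```python
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any

async def save_sketch(
    updated_memory: str,
    apply_fn: Callable[[str], None],
    commit_fn: Callable[[], Awaitable[None]],
    rollback_fn: Callable[[], Awaitable[None]],
    label: str,
) -> dict[str, Any]:
    # Apply the new document, persist it, and undo on failure.
    try:
        apply_fn(updated_memory)
        await commit_fn()
    except Exception:
        await rollback_fn()
        return {"status": "error", "message": f"Failed to save {label}"}
    return {"status": "success", "size": len(updated_memory)}

record = {"memory": ""}  # stand-in for the ORM row

async def commit() -> None: ...    # stand-in for session.commit()
async def rollback() -> None: ...  # stand-in for session.rollback()

print(asyncio.run(save_sketch(
    "## Preferences\n- (2024-03-01) Short answers\n",
    apply_fn=lambda doc: record.update(memory=doc),
    commit_fn=commit,
    rollback_fn=rollback,
    label="user memory",
)))
```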
@@ -255,21 +187,13 @@ async def _save_memory(
     """
     content = updated_memory
 
-    # --- pinned-section gate (before any size check) ---
-    pinned_err = _validate_pinned_preserved(old_memory, content)
-    if pinned_err:
-        return {"status": "error", "message": pinned_err}
+    # --- forced rewrite if over the hard limit ---
+    if len(content) > MEMORY_HARD_LIMIT and llm is not None:
+        rewritten = await _forced_rewrite(content, llm)
+        if rewritten is not None and len(rewritten) < len(content):
+            content = rewritten
 
-    # --- auto-consolidate proactively at the soft limit ---
-    if len(content) > MEMORY_SOFT_LIMIT and llm is not None:
-        consolidated = await _auto_consolidate(content, llm)
-        if consolidated is not None:
-            if old_memory:
-                consolidated = _restore_missing_pinned(old_memory, consolidated)
-            if len(consolidated) < len(content):
-                content = consolidated
-
-    # --- hard-limit gate (reject if still too large after consolidation) ---
+    # --- hard-limit gate (reject if still too large after rewrite) ---
     size_err = _validate_memory_size(content)
     if size_err:
         return size_err
@@ -290,7 +214,7 @@ async def _save_memory(
     }
 
     if content is not updated_memory:
-        resp["notice"] = "Memory was automatically consolidated to fit within limits."
+        resp["notice"] = "Memory was automatically rewritten to fit within limits."
 
     diff_warnings = _validate_diff(old_memory, content)
     if diff_warnings:
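The ``content is not updated_memory`` test is an identity check, not an equality check: ``content`` is only rebound to a new object when a forced rewrite replaced it, so object identity is enough to decide whether to attach the notice. A tiny illustrative sketch (names invented):

```python
def notice_if_rewritten(submitted: str, final: str) -> dict:
    # `final` is the same object as `submitted` unless a rewrite replaced it.
    resp = {"status": "success"}
    if final is not submitted:
        resp["notice"] = "Memory was automatically rewritten to fit within limits."
    return resp

doc = "## Preferences\n- (2024-03-01) Short answers\n"
print(notice_if_rewritten(doc, doc))       # no notice: same object
print(notice_if_rewritten(doc, doc[:23]))  # notice: rewrite produced a new object
```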