refactor: update memory management tools to enforce character limits, enhance pinned section handling, and improve user feedback in MemoryContent and TeamMemoryManager components

This commit is contained in:
Anish Sarkar 2026-04-09 17:24:31 +05:30
parent ab3cb0e1c5
commit a335f7621a
9 changed files with 324 additions and 52 deletions

View file

@ -282,6 +282,7 @@ async def create_surfsense_deep_agent(
"available_connectors": available_connectors,
"available_document_types": available_document_types,
"max_input_tokens": _max_input_tokens,
"llm": llm,
}
# Disable Notion action tools if no Notion connector is configured

View file

@ -117,7 +117,7 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
if display_name:
first_name = display_name.split()[0]
seed = f"## About the user\n- Name: {first_name}"
seed = f"## About the user (pinned)\n- Name: {first_name}"
await session.execute(
User.__table__.update()
.where(User.id == self.user_id)

View file

@ -276,11 +276,16 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
Include inline dates (YYYY-MM) on entries where temporal context matters (facts that
may change, decisions, context). Skip dates on timeless preferences and instructions.
- Keep it concise and well under the character limit shown in <user_memory>.
- Organize using markdown sections as appropriate (suggested but not required):
## About the user — name, role, background, company (with date if it may change)
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
## About the user (pinned) — name, role, background, company (with date if it may change)
## Preferences — languages, tools, frameworks, response style
## Instructions — standing instructions, things to always/never do
## Instructions (pinned) — standing instructions, things to always/never do
## Current context — ongoing projects, goals, deadlines (with date)
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
- Each time-sensitive entry MUST include a (YYYY-MM) date suffix.
- Sections with `(pinned)` in the heading are protected the system will reject any
update that removes them. Users can add `(pinned)` to any `##` heading to protect it.
- During consolidation, prioritize keeping: pinned sections > preferences > current context.
""",
"shared": """
- update_memory: Update the team's shared memory document for this search space.
@ -300,11 +305,16 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
Treat every update as a curation pass consolidate, don't just append.
Include inline dates (YYYY-MM) on decisions and time-sensitive entries.
- Keep it concise and well under the character limit shown in <team_memory>.
- Organize using markdown sections as appropriate (suggested but not required):
## Team decisions — agreed-upon choices with rationale and date
## Conventions — coding standards, tools, processes, naming patterns
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
## Team decisions (pinned) — agreed-upon choices with rationale and date
## Conventions (pinned) — coding standards, tools, processes, naming patterns
## Key facts — where things are, how things work, team structure
## Current priorities — active projects, deadlines, blockers
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
- Each time-sensitive entry MUST include a (YYYY-MM) date suffix.
- Sections with `(pinned)` in the heading are protected the system will reject any
update that removes them. Users can add `(pinned)` to any `##` heading to protect it.
- During consolidation, prioritize keeping: pinned sections > key facts > current priorities.
""",
},
}
@ -312,25 +322,25 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
_MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = {
"update_memory": {
"private": """
- <user_memory persisted="false"> contains "## About the user\\n- Name: Alex"
- <user_memory persisted="false"> contains "## About the user (pinned)\\n- Name: Alex"
User: "I'm a university student, explain astrophage to me"
- Memory is not yet persisted AND the user casually shared that they are a student.
You MUST call update_memory to persist the seed plus the new fact:
update_memory(updated_memory="## About the user\\n- Name: Alex\\n- University student\\n")
update_memory(updated_memory="## About the user (pinned)\\n- Name: Alex\\n- University student\\n")
- User: "Remember that I prefer TypeScript over JavaScript"
- Timeless preference, no date needed. You see the current <user_memory> and merge:
update_memory(updated_memory="## About the user\\n- Senior developer\\n\\n## Preferences\\n- Prefers TypeScript over JavaScript\\n...")
update_memory(updated_memory="## About the user (pinned)\\n- Senior developer\\n\\n## Preferences\\n- Prefers TypeScript over JavaScript\\n...")
- User: "I actually moved to Google last month"
- Fact that changes over time, include date:
update_memory(updated_memory="## About the user\\n- Senior developer at Google (since 2026-03, previously Acme Corp)\\n...")
update_memory(updated_memory="## About the user (pinned)\\n- Senior developer at Google (since 2026-03, previously Acme Corp)\\n...")
- User: "I'm building a SaaS app with Next.js and Supabase"
- Implicit project info shared as context. Save it:
update_memory(updated_memory="## About the user\\n- Name: Alex\\n\\n## Current context\\n- Building a SaaS app with Next.js and Supabase (2026-04)\\n")
update_memory(updated_memory="## About the user (pinned)\\n- Name: Alex\\n\\n## Current context\\n- Building a SaaS app with Next.js and Supabase (2026-04)\\n")
""",
"shared": """
- User: "Let's remember that we decided to use GraphQL"
- Decision with date:
update_memory(updated_memory="## Team decisions\\n- 2026-04: Adopted GraphQL over REST for new APIs\\n...")
update_memory(updated_memory="## Team decisions (pinned)\\n- 2026-04: Adopted GraphQL over REST for new APIs\\n...")
- User: "Our deploy process uses Railway auto-deploys"
- Key fact, no date needed:
update_memory(updated_memory="## Key facts\\n- Deploy pipeline: git push -> Railway auto-deploys in ~3min\\n...")

View file

@ -219,14 +219,16 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
create_update_team_memory_tool(
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
llm=deps.get("llm"),
)
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
else create_update_memory_tool(
user_id=deps["user_id"],
db_session=deps["db_session"],
llm=deps.get("llm"),
)
),
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
requires=["user_id", "search_space_id", "db_session", "thread_visibility", "llm"],
),
# =========================================================================
# LINEAR TOOLS - create, update, delete issues

View file

@ -4,14 +4,25 @@ Replaces the old row-per-fact save_memory / recall_memory tools with a single
update_memory tool that overwrites a freeform markdown TEXT column. The LLM
always sees the current memory in <user_memory> / <team_memory> tags injected
by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
Overflow handling:
- Soft limit (15K chars): advisory warning returned alongside a successful save.
- Hard limit (25K chars): save rejected; an automatic LLM-driven consolidation
is attempted before falling back to the error.
- Pinned sections: headings containing ``(pinned)`` are protected — the system
rejects any update that drops them and auto-restores them during consolidation.
- Diff validation: warns when entire ``##`` sections are dropped or when the
document shrinks by more than 60%.
"""
from __future__ import annotations
import logging
import re
from typing import Any
from uuid import UUID
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@ -20,9 +31,120 @@ from app.db import SearchSpace, User
logger = logging.getLogger(__name__)
MEMORY_SOFT_LIMIT = 20_000
MEMORY_SOFT_LIMIT = 15_000
MEMORY_HARD_LIMIT = 25_000
_PINNED_RE = re.compile(r"^##\s+.+\(pinned\)", re.MULTILINE)
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
# ---------------------------------------------------------------------------
# Pinned-section helpers
# ---------------------------------------------------------------------------
def _extract_pinned_headings(memory: str) -> set[str]:
    """Return the set of ``## …`` headings that contain ``(pinned)``.

    Each element is the text matched by ``_PINNED_RE`` (from the ``##`` marker
    up to and including the last ``(pinned)`` on the line), normalized so that
    exactly one space separates ``##`` from the heading text.

    The normalization matters because callers compare the sets extracted from
    two versions of the document: without it, rewriting ``##  Title (pinned)``
    as ``## Title (pinned)`` would look like a removed pinned section.
    """
    return {
        "## " + match.removeprefix("##").strip()
        for match in _PINNED_RE.findall(memory)
    }
def _extract_section_map(memory: str) -> dict[str, str]:
    """Split *memory* into ``{heading_text: full_section_content}``.

    Keys are the heading text without the ``##`` marker, stripped of
    surrounding whitespace. Values are the heading line followed by everything
    up to the next ``##`` heading. Text before the first heading is not
    included in the map.

    If the same heading occurs more than once, the sections are concatenated
    in document order so that no content is lost.
    """
    sections: dict[str, str] = {}
    # With one capture group, split() yields
    # [preamble, heading1, body1, heading2, body2, …].
    parts = _SECTION_HEADING_RE.split(memory)
    for raw_heading, body in zip(parts[1::2], parts[2::2]):
        heading = raw_heading.strip()
        # *body* already starts with the newline that ended the heading line
        # (the pattern's ``$`` matches just before it), so no extra "\n" is
        # inserted here; adding one would create a blank line that was not in
        # the original document.
        section = f"## {heading}{body}"
        previous = sections.get(heading)
        if previous is None:
            sections[heading] = section
        else:
            joiner = "" if previous.endswith("\n") else "\n"
            sections[heading] = previous + joiner + section
    return sections
def _validate_pinned_preserved(
    old_memory: str | None, new_memory: str
) -> str | None:
    """Check that no pinned heading of *old_memory* is absent from *new_memory*.

    Returns ``None`` when there is nothing to protect (no previous document or
    no pinned headings in it) or when every pinned heading is still present;
    otherwise returns an error message listing the missing headings in sorted
    order.
    """
    previously_pinned = _extract_pinned_headings(old_memory) if old_memory else set()
    if not previously_pinned:
        return None

    missing = previously_pinned.difference(_extract_pinned_headings(new_memory))
    if not missing:
        return None

    listing = ", ".join(sorted(missing))
    return (
        f"Cannot remove pinned sections: {listing}. "
        "These sections are protected and must be preserved. "
        "Re-include them and call update_memory again."
    )
def _restore_missing_pinned(
    old_memory: str, consolidated: str
) -> str:
    """Prepend any pinned sections from *old_memory* that are absent in
    *consolidated*.

    Returns *consolidated* unchanged when *old_memory* has no pinned headings
    or when all of them are still present. Otherwise the missing sections
    (heading plus body, taken from *old_memory*) are placed in front of
    *consolidated*, separated by blank lines, in sorted heading order.
    """
    old_pinned = _extract_pinned_headings(old_memory)
    if not old_pinned:
        return consolidated

    dropped = old_pinned - _extract_pinned_headings(consolidated)
    if not dropped:
        return consolidated

    old_sections = _extract_section_map(old_memory)
    restored_parts: list[str] = []
    for heading in sorted(dropped):
        # Remove the "##" marker and whatever whitespace followed it (space,
        # several spaces, or a tab) to obtain the section-map key.
        raw_heading = heading.removeprefix("##").strip()
        section = old_sections.get(raw_heading)
        if section is None:
            # The pinned match ends at "(pinned)" while section keys hold the
            # whole heading line, so a heading with trailing text is only
            # found by prefix.
            section = next(
                (
                    text
                    for key, text in old_sections.items()
                    if key.startswith(raw_heading)
                ),
                None,
            )
        if section is None:
            continue
        restored = section.rstrip()
        if restored not in restored_parts:
            restored_parts.append(restored)

    if restored_parts:
        return "\n\n".join(restored_parts) + "\n\n" + consolidated
    return consolidated
# ---------------------------------------------------------------------------
# Diff validation
# ---------------------------------------------------------------------------
def _extract_headings(memory: str) -> set[str]:
    """Return all ``## …`` heading texts (without the ``## `` prefix).

    Heading texts are stripped of surrounding whitespace, matching the keys
    produced by ``_extract_section_map``. Without the strip, a trailing space
    or carriage return would make an otherwise unchanged heading compare as
    different between two versions of the document.
    """
    return {heading.strip() for heading in _SECTION_HEADING_RE.findall(memory)}
def _validate_diff(
    old_memory: str | None, new_memory: str
) -> list[str]:
    """Compare two versions of the memory document and collect warnings.

    Produces at most two warnings: one naming ``##`` sections present in
    *old_memory* but not in *new_memory*, and one when the new document is
    shorter than 40% of the old one. Returns an empty list when there is no
    previous document.
    """
    if not old_memory:
        return []

    notes: list[str] = []

    missing = _extract_headings(old_memory) - _extract_headings(new_memory)
    if missing:
        listing = ", ".join(sorted(missing))
        notes.append(
            f"Sections removed: {listing}. "
            "If unintentional, the user can restore from the settings page."
        )

    before, after = len(old_memory), len(new_memory)
    shrank_too_much = before > 0 and after < before * 0.4
    if shrank_too_much:
        notes.append(
            f"Memory shrank significantly ({before:,} -> {after:,} chars). "
            "Possible data loss."
        )

    return notes
# ---------------------------------------------------------------------------
# Size validation & soft warning
# ---------------------------------------------------------------------------
def _validate_memory_size(content: str) -> dict[str, Any] | None:
"""Return an error/warning dict if *content* is too large, else None."""
@ -52,9 +174,153 @@ def _soft_warning(content: str) -> str | None:
return None
# ---------------------------------------------------------------------------
# Auto-consolidation via a separate LLM call
# ---------------------------------------------------------------------------
_CONSOLIDATION_PROMPT = """\
You are a memory curator. The following memory document exceeds the character \
limit and must be shortened.
RULES:
1. Rewrite the document to be under {target} characters.
2. Sections whose headings contain "(pinned)" MUST be preserved EXACTLY as-is \
do not modify, shorten, or remove them.
3. Only consolidate non-pinned content.
4. Priority for keeping content: pinned sections > identity/instructions > \
preferences > current context.
5. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
6. Each entry must be a single bullet point.
7. Preserve (YYYY-MM) date suffixes on time-sensitive entries.
8. Output ONLY the consolidated markdown no explanations, no wrapping.
<memory_document>
{content}
</memory_document>"""
async def _auto_consolidate(
    content: str, llm: Any
) -> str | None:
    """Use a focused LLM call to consolidate *content* under the soft limit.

    The prompt asks for a document below ``MEMORY_SOFT_LIMIT`` characters; the
    result is not length-checked here.

    Returns the consolidated string with surrounding whitespace removed, or
    ``None`` if the call raises or yields no usable text. An empty reply is
    treated as a failure so that callers never persist a blank document in
    place of the user's memory.
    """
    try:
        prompt = _CONSOLIDATION_PROMPT.format(
            target=MEMORY_SOFT_LIMIT, content=content
        )
        response = await llm.ainvoke(
            [HumanMessage(content=prompt)],
            config={"tags": ["surfsense:internal"]},
        )
        raw = response.content
        if isinstance(raw, str):
            text = raw
        elif isinstance(raw, list):
            # Message content may be a list of plain strings and/or content
            # blocks (dicts); keep only the textual pieces instead of storing
            # the list's Python repr.
            pieces: list[str] = []
            for block in raw:
                if isinstance(block, str):
                    pieces.append(block)
                elif (
                    isinstance(block, dict)
                    and block.get("type", "text") == "text"
                    and isinstance(block.get("text"), str)
                ):
                    pieces.append(block["text"])
            text = "".join(pieces)
        else:
            text = str(raw)

        text = text.strip()
        if not text:
            logger.warning("Auto-consolidation returned an empty document")
            return None
        return text
    except Exception:
        logger.exception("Auto-consolidation LLM call failed")
        return None
# ---------------------------------------------------------------------------
# Shared save-and-respond logic
# ---------------------------------------------------------------------------
async def _save_memory(
    *,
    updated_memory: str,
    old_memory: str | None,
    llm: Any | None,
    apply_fn,
    commit_fn,
    rollback_fn,
    label: str,
) -> dict[str, Any]:
    """Validate, optionally auto-consolidate, save, and return a response dict.

    Parameters
    ----------
    updated_memory : str
        The new document the agent submitted.
    old_memory : str | None
        The previously persisted document (for diff / pinned checks).
    llm : Any | None
        LLM instance for auto-consolidation (may be ``None``).
    apply_fn : callable(str) -> None
        Callback that sets the new memory on the ORM object.
    commit_fn : coroutine
        ``session.commit``.
    rollback_fn : coroutine
        ``session.rollback``.
    label : str
        Human label for log messages (e.g. "user memory", "team memory").

    Returns
    -------
    dict
        ``{"status": "error", "message": ...}`` (or the dict produced by the
        size validator) when the update is rejected or persisting fails;
        otherwise ``{"status": "saved", "message": ...}`` with optional
        ``notice``, ``diff_warnings`` and ``warning`` entries.
    """
    content = updated_memory
    was_consolidated = False

    # --- pinned-section gate (before any size check) ---
    pinned_err = _validate_pinned_preserved(old_memory, content)
    if pinned_err:
        return {"status": "error", "message": pinned_err}

    # --- hard-limit gate with auto-consolidation fallback ---
    size_err = _validate_memory_size(content)
    if size_err:
        if llm is None:
            return size_err

        consolidated = await _auto_consolidate(content, llm)
        # Treat an empty result like a failed call: never replace the
        # document with nothing.
        if not consolidated:
            return size_err

        # Restore pinned sections the consolidation LLM may have dropped.
        # They are taken from the SUBMITTED document, which already passed the
        # pinned gate above, so the agent's latest edits to those sections are
        # kept and pinned sections are protected even on the first save.
        consolidated = _restore_missing_pinned(updated_memory, consolidated)

        # If a previously pinned heading is still missing, fall back to the
        # size error rather than persisting a document without it.
        if _validate_pinned_preserved(old_memory, consolidated):
            return size_err

        recheck = _validate_memory_size(consolidated)
        if recheck:
            return recheck

        content = consolidated
        was_consolidated = True

    # --- persist ---
    try:
        apply_fn(content)
        await commit_fn()
    except Exception as e:
        logger.exception("Failed to update %s: %s", label, e)
        await rollback_fn()
        return {"status": "error", "message": f"Failed to update {label}: {e}"}

    # --- build response ---
    resp: dict[str, Any] = {"status": "saved", "message": f"{label.capitalize()} updated."}

    if was_consolidated:
        resp["notice"] = (
            "Memory was automatically consolidated to fit within limits."
        )

    diff_warnings = _validate_diff(old_memory, content)
    if diff_warnings:
        resp["diff_warnings"] = diff_warnings

    warning = _soft_warning(content)
    if warning:
        resp["warning"] = warning

    return resp
# ---------------------------------------------------------------------------
# Tool factories
# ---------------------------------------------------------------------------
def create_update_memory_tool(
user_id: str | UUID,
db_session: AsyncSession,
llm: Any | None = None,
):
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@ -71,10 +337,6 @@ def create_update_memory_tool(
Args:
updated_memory: The FULL updated markdown document (not a diff).
"""
error = _validate_memory_size(updated_memory)
if error:
return error
try:
result = await db_session.execute(
select(User).where(User.id == uid)
@ -83,17 +345,17 @@ def create_update_memory_tool(
if not user:
return {"status": "error", "message": "User not found."}
user.memory_md = updated_memory
await db_session.commit()
old_memory = user.memory_md
resp: dict[str, Any] = {
"status": "saved",
"message": "Memory updated.",
}
warning = _soft_warning(updated_memory)
if warning:
resp["warning"] = warning
return resp
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="memory",
)
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
await db_session.rollback()
@ -108,6 +370,7 @@ def create_update_memory_tool(
def create_update_team_memory_tool(
search_space_id: int,
db_session: AsyncSession,
llm: Any | None = None,
):
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
@ -123,10 +386,6 @@ def create_update_team_memory_tool(
Args:
updated_memory: The FULL updated markdown document (not a diff).
"""
error = _validate_memory_size(updated_memory)
if error:
return error
try:
result = await db_session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
@ -135,17 +394,17 @@ def create_update_team_memory_tool(
if not space:
return {"status": "error", "message": "Search space not found."}
space.shared_memory_md = updated_memory
await db_session.commit()
old_memory = space.shared_memory_md
resp: dict[str, Any] = {
"status": "saved",
"message": "Team memory updated.",
}
warning = _soft_warning(updated_memory)
if warning:
resp["warning"] = warning
return resp
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="team memory",
)
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
await db_session.rollback()

View file

@ -81,8 +81,8 @@ export function MemoryContent() {
const getCounterColor = () => {
if (charCount > MEMORY_HARD_LIMIT) return "text-red-500";
if (charCount > 20_000) return "text-orange-500";
if (charCount > 15_000) return "text-yellow-500";
if (charCount > 15_000) return "text-orange-500";
if (charCount > 10_000) return "text-yellow-500";
return "text-muted-foreground";
};
@ -119,7 +119,7 @@ export function MemoryContent() {
<div className="flex items-center justify-between">
<span className={`text-xs ${getCounterColor()}`}>
{charCount.toLocaleString()} / {MEMORY_HARD_LIMIT.toLocaleString()} characters
{charCount > 20_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
{charCount > 15_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
{isOverLimit && " - Exceeds limit"}
</span>
</div>

View file

@ -85,8 +85,8 @@ export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) {
const getCounterColor = () => {
if (charCount > MEMORY_HARD_LIMIT) return "text-red-500";
if (charCount > 20_000) return "text-orange-500";
if (charCount > 15_000) return "text-yellow-500";
if (charCount > 15_000) return "text-orange-500";
if (charCount > 10_000) return "text-yellow-500";
return "text-muted-foreground";
};
@ -123,7 +123,7 @@ export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) {
<div className="flex items-center justify-between">
<span className={`text-xs ${getCounterColor()}`}>
{charCount.toLocaleString()} / {MEMORY_HARD_LIMIT.toLocaleString()} characters
{charCount > 20_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
{charCount > 15_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
{isOverLimit && " - Exceeds limit"}
</span>
</div>

View file

@ -54,8 +54,8 @@ const editorVariants = cva(
cn(
"group/editor",
"relative w-full cursor-text select-text overflow-x-hidden whitespace-pre-wrap break-words",
"rounded-md ring-offset-background focus-visible:outline-none",
"**:data-slate-placeholder:!top-1/2 **:data-slate-placeholder:-translate-y-1/2 placeholder:text-muted-foreground/80 **:data-slate-placeholder:text-muted-foreground/80 **:data-slate-placeholder:opacity-100!",
"rounded-none ring-offset-background focus-visible:outline-none",
"placeholder:text-muted-foreground/80 **:data-slate-placeholder:text-muted-foreground/80 **:data-slate-placeholder:py-1",
"[&_strong]:font-bold"
),
{

View file

@ -6,7 +6,7 @@ import type { PlateElementProps } from "platejs/react";
import { PlateElement } from "platejs/react";
import * as React from "react";
const headingVariants = cva("relative mb-1", {
const headingVariants = cva("relative mb-1 first:mt-0", {
variants: {
variant: {
h1: "mt-[1.6em] pb-1 font-bold font-heading text-4xl",