From ceedd02353c74f063188fc103500fb848909a810 Mon Sep 17 00:00:00 2001
From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com>
Date: Wed, 20 May 2026 02:01:36 +0530
Subject: [PATCH] refactor: extract shared memory service
---
.../builtins/memory/tools/update_memory.py | 351 ++-------------
.../app/agents/new_chat/memory_extraction.py | 196 +--------
.../agents/new_chat/tools/update_memory.py | 414 ++----------------
.../app/services/memory/__init__.py | 29 ++
.../app/services/memory/prompts.py | 110 +++++
.../app/services/memory/rewrite.py | 35 ++
.../app/services/memory/schemas.py | 23 +
.../app/services/memory/service.py | 300 +++++++++++++
.../app/services/memory/validation.py | 158 +++++++
.../unit/services/test_memory_service.py | 204 +++++++++
10 files changed, 946 insertions(+), 874 deletions(-)
create mode 100644 surfsense_backend/app/services/memory/__init__.py
create mode 100644 surfsense_backend/app/services/memory/prompts.py
create mode 100644 surfsense_backend/app/services/memory/rewrite.py
create mode 100644 surfsense_backend/app/services/memory/schemas.py
create mode 100644 surfsense_backend/app/services/memory/service.py
create mode 100644 surfsense_backend/app/services/memory/validation.py
create mode 100644 surfsense_backend/tests/unit/services/test_memory_service.py
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py
index 23375a081..67bcc3e06 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py
@@ -1,280 +1,23 @@
-"""Overwrite one markdown memory document per user or team, with size and shrink guards."""
+"""Memory update tools backed by the canonical memory service."""
from __future__ import annotations
import logging
-import re
-from typing import Any, Literal
+from typing import Any
from uuid import UUID
-from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
-from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
-from app.db import SearchSpace, User
+from app.services.memory import (
+ MEMORY_HARD_LIMIT,
+ MEMORY_SOFT_LIMIT,
+ MemoryScope,
+ save_memory,
+)
logger = logging.getLogger(__name__)
-MEMORY_SOFT_LIMIT = 18_000
-MEMORY_HARD_LIMIT = 25_000
-
-_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
-_HEADING_NORMALIZE_RE = re.compile(r"\s+")
-
-_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
-_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
-_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
-
-
-# ---------------------------------------------------------------------------
-# Diff validation
-# ---------------------------------------------------------------------------
-
-
-def _extract_headings(memory: str) -> set[str]:
- """Return all ``## …`` heading texts (without the ``## `` prefix)."""
- return set(_SECTION_HEADING_RE.findall(memory))
-
-
-def _normalize_heading(heading: str) -> str:
- """Normalize heading text for robust scope checks."""
- return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
-
-
-def _validate_memory_scope(
- content: str, scope: Literal["user", "team"]
-) -> dict[str, Any] | None:
- """Reject personal-only markers ([pref], [instr]) in team memory."""
- if scope != "team":
- return None
-
- markers = set(_MARKER_RE.findall(content))
- leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
- if leaked:
- tags = ", ".join(f"[{m}]" for m in leaked)
- return {
- "status": "error",
- "message": (
- f"Team memory cannot include personal markers: {tags}. "
- "Use [fact] only in team memory."
- ),
- }
- return None
-
-
-def _validate_bullet_format(content: str) -> list[str]:
- """Return warnings for bullet lines that don't match the required format.
-
- Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
- """
- warnings: list[str] = []
- for line in content.splitlines():
- stripped = line.strip()
- if not stripped.startswith("- "):
- continue
- if not _BULLET_FORMAT_RE.match(stripped):
- short = stripped[:80] + ("..." if len(stripped) > 80 else "")
- warnings.append(f"Malformed bullet: {short}")
- return warnings
-
-
-def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
- """Return a list of warning strings about suspicious changes."""
- if not old_memory:
- return []
-
- warnings: list[str] = []
- old_headings = _extract_headings(old_memory)
- new_headings = _extract_headings(new_memory)
- dropped = old_headings - new_headings
- if dropped:
- names = ", ".join(sorted(dropped))
- warnings.append(
- f"Sections removed: {names}. "
- "If unintentional, the user can restore from the settings page."
- )
-
- old_len = len(old_memory)
- new_len = len(new_memory)
- if old_len > 0 and new_len < old_len * 0.4:
- warnings.append(
- f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
- "Possible data loss."
- )
- return warnings
-
-
-# ---------------------------------------------------------------------------
-# Size validation & soft warning
-# ---------------------------------------------------------------------------
-
-
-def _validate_memory_size(content: str) -> dict[str, Any] | None:
- """Return an error/warning dict if *content* is too large, else None."""
- length = len(content)
- if length > MEMORY_HARD_LIMIT:
- return {
- "status": "error",
- "message": (
- f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
- f"({length:,} chars). Consolidate by merging related items, "
- "removing outdated entries, and shortening descriptions. "
- "Then call update_memory again."
- ),
- }
- return None
-
-
-def _soft_warning(content: str) -> str | None:
- """Return a warning string if content exceeds the soft limit."""
- length = len(content)
- if length > MEMORY_SOFT_LIMIT:
- return (
- f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
- "Consolidate by merging related items and removing less important "
- "entries on your next update."
- )
- return None
-
-
-# ---------------------------------------------------------------------------
-# Forced rewrite when memory exceeds the hard limit
-# ---------------------------------------------------------------------------
-
-_FORCED_REWRITE_PROMPT = """\
-You are a memory curator. The following memory document exceeds the character \
-limit and must be shortened.
-
-RULES:
-1. Rewrite the document to be under {target} characters.
-2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
- or rename headings to consolidate, but keep names personal and descriptive.
-3. Priority for keeping content: [instr] > [pref] > [fact].
-4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
-5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
-6. Preserve the user's first name in entries — do not replace it with "the user".
-7. Output ONLY the consolidated markdown — no explanations, no wrapping.
-
-
-{content}
-"""
-
-
-async def _forced_rewrite(content: str, llm: Any) -> str | None:
- """Use a focused LLM call to compress *content* under the hard limit.
-
- Returns the rewritten string, or ``None`` if the call fails.
- """
- try:
- prompt = _FORCED_REWRITE_PROMPT.format(
- target=MEMORY_HARD_LIMIT, content=content
- )
- response = await llm.ainvoke(
- [HumanMessage(content=prompt)],
- config={"tags": ["surfsense:internal"]},
- )
- text = (
- response.content
- if isinstance(response.content, str)
- else str(response.content)
- )
- return text.strip()
- except Exception:
- logger.exception("Forced rewrite LLM call failed")
- return None
-
-
-# ---------------------------------------------------------------------------
-# Shared save-and-respond logic
-# ---------------------------------------------------------------------------
-
-
-async def _save_memory(
- *,
- updated_memory: str,
- old_memory: str | None,
- llm: Any | None,
- apply_fn,
- commit_fn,
- rollback_fn,
- label: str,
- scope: Literal["user", "team"],
-) -> dict[str, Any]:
- """Validate, optionally force-rewrite if over the hard limit, save, and
- return a response dict.
-
- Parameters
- ----------
- updated_memory : str
- The new document the agent submitted.
- old_memory : str | None
- The previously persisted document (for diff checks).
- llm : Any | None
- LLM instance for forced rewrite (may be ``None``).
- apply_fn : callable(str) -> None
- Callback that sets the new memory on the ORM object.
- commit_fn : coroutine
- ``session.commit``.
- rollback_fn : coroutine
- ``session.rollback``.
- label : str
- Human label for log messages (e.g. "user memory", "team memory").
- """
- content = updated_memory
-
- # --- forced rewrite if over the hard limit ---
- if len(content) > MEMORY_HARD_LIMIT and llm is not None:
- rewritten = await _forced_rewrite(content, llm)
- if rewritten is not None and len(rewritten) < len(content):
- content = rewritten
-
- # --- hard-limit gate (reject if still too large after rewrite) ---
- size_err = _validate_memory_size(content)
- if size_err:
- return size_err
-
- scope_err = _validate_memory_scope(content, scope)
- if scope_err:
- return scope_err
-
- # --- persist ---
- try:
- apply_fn(content)
- await commit_fn()
- except Exception as e:
- logger.exception("Failed to update %s: %s", label, e)
- await rollback_fn()
- return {"status": "error", "message": f"Failed to update {label}: {e}"}
-
- # --- build response ---
- resp: dict[str, Any] = {
- "status": "saved",
- "message": f"{label.capitalize()} updated.",
- }
-
- if content is not updated_memory:
- resp["notice"] = "Memory was automatically rewritten to fit within limits."
-
- diff_warnings = _validate_diff(old_memory, content)
- if diff_warnings:
- resp["diff_warnings"] = diff_warnings
-
- format_warnings = _validate_bullet_format(content)
- if format_warnings:
- resp["format_warnings"] = format_warnings
-
- warning = _soft_warning(content)
- if warning:
- resp["warning"] = warning
-
- return resp
-
-
-# ---------------------------------------------------------------------------
-# Tool factories
-# ---------------------------------------------------------------------------
-
def create_update_memory_tool(
user_id: str | UUID,
@@ -287,40 +30,22 @@ def create_update_memory_tool(
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the user's personal memory document.
- Your current memory is shown in in the system prompt.
- When the user shares important long-term information (preferences,
- facts, instructions, context), rewrite the memory document to include
- the new information. Merge new facts with existing ones, update
- contradictions, remove outdated entries, and keep it concise.
-
- Args:
- updated_memory: The FULL updated markdown document (not a diff).
+ The current memory is shown in . Pass the FULL updated
+ markdown document, not a diff.
"""
try:
- result = await db_session.execute(select(User).where(User.id == uid))
- user = result.scalars().first()
- if not user:
- return {"status": "error", "message": "User not found."}
-
- old_memory = user.memory_md
-
- return await _save_memory(
- updated_memory=updated_memory,
- old_memory=old_memory,
+ result = await save_memory(
+ scope=MemoryScope.USER,
+ target_id=uid,
+ content=updated_memory,
+ session=db_session,
llm=llm,
- apply_fn=lambda content: setattr(user, "memory_md", content),
- commit_fn=db_session.commit,
- rollback_fn=db_session.rollback,
- label="memory",
- scope="user",
)
+ return result.to_dict()
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
await db_session.rollback()
- return {
- "status": "error",
- "message": f"Failed to update memory: {e}",
- }
+ return {"status": "error", "message": f"Failed to update memory: {e}"}
return update_memory
@@ -334,36 +59,18 @@ def create_update_team_memory_tool(
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space.
- Your current team memory is shown in in the system
- prompt. When the team shares important long-term information
- (decisions, conventions, key facts, priorities), rewrite the memory
- document to include the new information. Merge new facts with
- existing ones, update contradictions, remove outdated entries, and
- keep it concise.
-
- Args:
- updated_memory: The FULL updated markdown document (not a diff).
+ The current team memory is shown in . Pass the FULL updated
+ markdown document, not a diff.
"""
try:
- result = await db_session.execute(
- select(SearchSpace).where(SearchSpace.id == search_space_id)
- )
- space = result.scalars().first()
- if not space:
- return {"status": "error", "message": "Search space not found."}
-
- old_memory = space.shared_memory_md
-
- return await _save_memory(
- updated_memory=updated_memory,
- old_memory=old_memory,
+ result = await save_memory(
+ scope=MemoryScope.TEAM,
+ target_id=search_space_id,
+ content=updated_memory,
+ session=db_session,
llm=llm,
- apply_fn=lambda content: setattr(space, "shared_memory_md", content),
- commit_fn=db_session.commit,
- rollback_fn=db_session.rollback,
- label="team memory",
- scope="team",
)
+ return result.to_dict()
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
await db_session.rollback()
@@ -373,3 +80,11 @@ def create_update_team_memory_tool(
}
return update_memory
+
+
+__all__ = [
+ "MEMORY_HARD_LIMIT",
+ "MEMORY_SOFT_LIMIT",
+ "create_update_memory_tool",
+ "create_update_team_memory_tool",
+]
diff --git a/surfsense_backend/app/agents/new_chat/memory_extraction.py b/surfsense_backend/app/agents/new_chat/memory_extraction.py
index e31774a7c..d44b58f7b 100644
--- a/surfsense_backend/app/agents/new_chat/memory_extraction.py
+++ b/surfsense_backend/app/agents/new_chat/memory_extraction.py
@@ -1,9 +1,4 @@
-"""Background memory extraction for the SurfSense agent.
-
-After each agent response, if the agent did not call ``update_memory`` during
-the turn, this module can run a lightweight LLM call to decide whether the
-latest message contains long-term information worth persisting.
-"""
+"""Background memory extraction for the SurfSense agent."""
from __future__ import annotations
@@ -11,102 +6,11 @@ import logging
from typing import Any
from uuid import UUID
-from langchain_core.messages import HumanMessage
-from sqlalchemy import select
-
-from app.agents.new_chat.tools.update_memory import _save_memory
-from app.db import SearchSpace, User, shielded_async_session
-from app.utils.content_utils import extract_text_content
+from app.db import User, shielded_async_session
+from app.services.memory import MemoryScope, extract_and_save
logger = logging.getLogger(__name__)
-_MEMORY_EXTRACT_PROMPT = """\
-You are a memory extraction assistant. Analyze the user's message and decide \
-if it contains any long-term information worth persisting to memory.
-
-Worth remembering: preferences, background/identity, goals, projects, \
-instructions, tools/languages they use, decisions, expertise, workplace — \
-durable facts that will matter in future conversations.
-
-NOT worth remembering: greetings, one-off factual questions, session \
-logistics, ephemeral requests, follow-up clarifications with no new personal \
-info, things that only matter for the current task.
-
-If the message contains memorizable information, output the FULL updated \
-memory document with the new facts merged into the existing content. Follow \
-these rules:
-- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
- freely. Keep heading names short (2-3 words) and natural. Do NOT include the user's
- name in headings.
-- Keep entries as single bullet points. Be descriptive but concise — include relevant
- details and context rather than just a few words.
-- Every bullet MUST use format: - (YYYY-MM-DD) [fact|pref|instr] text
- [fact] = durable facts, [pref] = preferences, [instr] = standing instructions.
-- Use the user's first name (from ) in entry text, not "the user".
-- If a new fact contradicts an existing entry, update the existing entry.
-- Do not duplicate information that is already present.
-
-If nothing is worth remembering, output exactly: NO_UPDATE
-
-{user_name}
-
-
-{current_memory}
-
-
-
-{user_message}
-"""
-
-_TEAM_MEMORY_EXTRACT_PROMPT = """\
-You are a team-memory extraction assistant. Analyze the latest message and \
-decide if it contains durable TEAM-level information worth persisting.
-
-Decision policy:
-- Prioritize recall for durable team context, while avoiding personal-only facts.
-- Do NOT require explicit consensus language. A direct team-level statement can
- be stored if it is stable and broadly useful for future team chats.
-- If evidence is weak or clearly tentative, output NO_UPDATE.
-
-Worth remembering (team-level only):
-- Decisions and defaults that guide future team work
-- Team conventions/standards (naming, review policy, coding norms)
-- Stable org/project facts (locations, ownership, constraints)
-- Long-lived architecture/process facts
-- Ongoing priorities that are likely relevant beyond this turn
-
-NOT worth remembering:
-- Personal preferences or biography of one person
-- Questions, brainstorming, tentative ideas, or speculation
-- One-off requests, status updates, TODOs, logistics for this session
-- Information scoped only to a single ephemeral task
-
-If the message contains memorizable team information, output the FULL updated \
-team memory document with new facts merged into existing content. Follow rules:
-- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
- freely. Keep heading names short (2-3 words) and natural.
-- Keep entries as single bullet points. Be descriptive but concise — include relevant
- details and context rather than just a few words.
-- Every bullet MUST use format: - (YYYY-MM-DD) [fact] text
- Team memory uses ONLY the [fact] marker. Never use [pref] or [instr].
-- If a new fact contradicts an existing entry, update the existing entry.
-- Do not duplicate existing information.
-- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
-
-If nothing is worth remembering, output exactly: NO_UPDATE
-
-
-{current_memory}
-
-
-
-{author}
-
-
-
-{user_message}
-"""
-
async def extract_and_save_memory(
*,
@@ -114,57 +18,31 @@ async def extract_and_save_memory(
user_id: str | None,
llm: Any,
) -> None:
- """Background task: extract memorizable info and persist it.
+ """Fire-and-forget personal memory extraction.
- Designed to be fire-and-forget — catches all exceptions internally.
+ The service uses structured output, so free-form ``NO_UPDATE`` text can no
+ longer be accidentally persisted as memory.
"""
if not user_id:
return
try:
uid = UUID(user_id) if isinstance(user_id, str) else user_id
-
async with shielded_async_session() as session:
- result = await session.execute(select(User).where(User.id == uid))
- user = result.scalars().first()
- if not user:
- return
-
- old_memory = user.memory_md
- first_name = (
- user.display_name.strip().split()[0]
- if user.display_name and user.display_name.strip()
- else "The user"
- )
- prompt = _MEMORY_EXTRACT_PROMPT.format(
- current_memory=old_memory or "(empty)",
+ user = await session.get(User, uid)
+ actor_display_name = user.display_name if user else None
+ result = await extract_and_save(
+ scope=MemoryScope.USER,
+ target_id=uid,
user_message=user_message,
- user_name=first_name,
- )
- response = await llm.ainvoke(
- [HumanMessage(content=prompt)],
- config={"tags": ["surfsense:internal", "memory-extraction"]},
- )
- text = extract_text_content(response.content).strip()
-
- if text == "NO_UPDATE" or not text:
- logger.debug("Memory extraction: no update needed (user %s)", uid)
- return
-
- save_result = await _save_memory(
- updated_memory=text,
- old_memory=old_memory,
+ actor_display_name=actor_display_name,
+ session=session,
llm=llm,
- apply_fn=lambda content: setattr(user, "memory_md", content),
- commit_fn=session.commit,
- rollback_fn=session.rollback,
- label="memory",
- scope="user",
)
logger.info(
"Background memory extraction for user %s: %s",
uid,
- save_result.get("status"),
+ result.status,
)
except Exception:
logger.exception("Background user memory extraction failed")
@@ -177,56 +55,24 @@ async def extract_and_save_team_memory(
llm: Any,
author_display_name: str | None = None,
) -> None:
- """Background task: extract team-level memory and persist it.
-
- Runs only for shared threads. Designed to be fire-and-forget and catches
- exceptions internally.
- """
+ """Fire-and-forget team-level memory extraction."""
if not search_space_id:
return
try:
async with shielded_async_session() as session:
- result = await session.execute(
- select(SearchSpace).where(SearchSpace.id == search_space_id)
- )
- space = result.scalars().first()
- if not space:
- return
-
- old_memory = space.shared_memory_md
- prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
- current_memory=old_memory or "(empty)",
- author=author_display_name or "Unknown team member",
+ result = await extract_and_save(
+ scope=MemoryScope.TEAM,
+ target_id=search_space_id,
user_message=user_message,
- )
- response = await llm.ainvoke(
- [HumanMessage(content=prompt)],
- config={"tags": ["surfsense:internal", "team-memory-extraction"]},
- )
- text = extract_text_content(response.content).strip()
-
- if text == "NO_UPDATE" or not text:
- logger.debug(
- "Team memory extraction: no update needed (space %s)",
- search_space_id,
- )
- return
-
- save_result = await _save_memory(
- updated_memory=text,
- old_memory=old_memory,
+ actor_display_name=author_display_name,
+ session=session,
llm=llm,
- apply_fn=lambda content: setattr(space, "shared_memory_md", content),
- commit_fn=session.commit,
- rollback_fn=session.rollback,
- label="team memory",
- scope="team",
)
logger.info(
"Background team memory extraction for space %s: %s",
search_space_id,
- save_result.get("status"),
+ result.status,
)
except Exception:
logger.exception("Background team memory extraction failed")
diff --git a/surfsense_backend/app/agents/new_chat/tools/update_memory.py b/surfsense_backend/app/agents/new_chat/tools/update_memory.py
index 062668aac..78a65201b 100644
--- a/surfsense_backend/app/agents/new_chat/tools/update_memory.py
+++ b/surfsense_backend/app/agents/new_chat/tools/update_memory.py
@@ -1,369 +1,53 @@
-"""Markdown-document memory tool for the SurfSense agent.
-
-Replaces the old row-per-fact save_memory / recall_memory tools with a single
-update_memory tool that overwrites a freeform markdown TEXT column. The LLM
-always sees the current memory in / tags injected
-by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
-
-Overflow handling:
- - Soft limit (18K chars): a warning is returned telling the agent to
- consolidate on the next update.
- - Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
- If it still exceeds the limit after rewriting, the save is rejected.
- - Diff validation: warns when entire ``##`` sections are dropped or when the
- document shrinks by more than 60%.
-"""
+"""Memory update tools backed by the canonical memory service."""
from __future__ import annotations
import logging
-import re
-from typing import Any, Literal
+from typing import Any
from uuid import UUID
-from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
-from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
-from app.db import SearchSpace, User, async_session_maker
-from app.utils.content_utils import extract_text_content
+from app.db import async_session_maker
+from app.services.memory import MemoryScope, save_memory
logger = logging.getLogger(__name__)
-MEMORY_SOFT_LIMIT = 18_000
-MEMORY_HARD_LIMIT = 25_000
-
-_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
-_HEADING_NORMALIZE_RE = re.compile(r"\s+")
-
-_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
-_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
-_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
-
-
-# ---------------------------------------------------------------------------
-# Diff validation
-# ---------------------------------------------------------------------------
-
-
-def _extract_headings(memory: str) -> set[str]:
- """Return all ``## …`` heading texts (without the ``## `` prefix)."""
- return set(_SECTION_HEADING_RE.findall(memory))
-
-
-def _normalize_heading(heading: str) -> str:
- """Normalize heading text for robust scope checks."""
- return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
-
-
-def _validate_memory_scope(
- content: str, scope: Literal["user", "team"]
-) -> dict[str, Any] | None:
- """Reject personal-only markers ([pref], [instr]) in team memory."""
- if scope != "team":
- return None
-
- markers = set(_MARKER_RE.findall(content))
- leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
- if leaked:
- tags = ", ".join(f"[{m}]" for m in leaked)
- return {
- "status": "error",
- "message": (
- f"Team memory cannot include personal markers: {tags}. "
- "Use [fact] only in team memory."
- ),
- }
- return None
-
-
-def _validate_bullet_format(content: str) -> list[str]:
- """Return warnings for bullet lines that don't match the required format.
-
- Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
- """
- warnings: list[str] = []
- for line in content.splitlines():
- stripped = line.strip()
- if not stripped.startswith("- "):
- continue
- if not _BULLET_FORMAT_RE.match(stripped):
- short = stripped[:80] + ("..." if len(stripped) > 80 else "")
- warnings.append(f"Malformed bullet: {short}")
- return warnings
-
-
-def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
- """Return a list of warning strings about suspicious changes."""
- if not old_memory:
- return []
-
- warnings: list[str] = []
- old_headings = _extract_headings(old_memory)
- new_headings = _extract_headings(new_memory)
- dropped = old_headings - new_headings
- if dropped:
- names = ", ".join(sorted(dropped))
- warnings.append(
- f"Sections removed: {names}. "
- "If unintentional, the user can restore from the settings page."
- )
-
- old_len = len(old_memory)
- new_len = len(new_memory)
- if old_len > 0 and new_len < old_len * 0.4:
- warnings.append(
- f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
- "Possible data loss."
- )
- return warnings
-
-
-# ---------------------------------------------------------------------------
-# Size validation & soft warning
-# ---------------------------------------------------------------------------
-
-
-def _validate_memory_size(content: str) -> dict[str, Any] | None:
- """Return an error/warning dict if *content* is too large, else None."""
- length = len(content)
- if length > MEMORY_HARD_LIMIT:
- return {
- "status": "error",
- "message": (
- f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
- f"({length:,} chars). Consolidate by merging related items, "
- "removing outdated entries, and shortening descriptions. "
- "Then call update_memory again."
- ),
- }
- return None
-
-
-def _soft_warning(content: str) -> str | None:
- """Return a warning string if content exceeds the soft limit."""
- length = len(content)
- if length > MEMORY_SOFT_LIMIT:
- return (
- f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
- "Consolidate by merging related items and removing less important "
- "entries on your next update."
- )
- return None
-
-
-# ---------------------------------------------------------------------------
-# Forced rewrite when memory exceeds the hard limit
-# ---------------------------------------------------------------------------
-
-_FORCED_REWRITE_PROMPT = """\
-You are a memory curator. The following memory document exceeds the character \
-limit and must be shortened.
-
-RULES:
-1. Rewrite the document to be under {target} characters.
-2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
- or rename headings to consolidate, but keep names personal and descriptive.
-3. Priority for keeping content: [instr] > [pref] > [fact].
-4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
-5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
-6. Preserve the user's first name in entries — do not replace it with "the user".
-7. Output ONLY the consolidated markdown — no explanations, no wrapping.
-
-
-{content}
-"""
-
-
-async def _forced_rewrite(content: str, llm: Any) -> str | None:
- """Use a focused LLM call to compress *content* under the hard limit.
-
- Returns the rewritten string, or ``None`` if the call fails.
- """
- try:
- prompt = _FORCED_REWRITE_PROMPT.format(
- target=MEMORY_HARD_LIMIT, content=content
- )
- response = await llm.ainvoke(
- [HumanMessage(content=prompt)],
- config={"tags": ["surfsense:internal"]},
- )
- text = extract_text_content(response.content).strip()
- if not text:
- logger.warning("Forced rewrite returned empty text; aborting rewrite")
- return None
- return text
- except Exception:
- logger.exception("Forced rewrite LLM call failed")
- return None
-
-
-# ---------------------------------------------------------------------------
-# Shared save-and-respond logic
-# ---------------------------------------------------------------------------
-
-
-async def _save_memory(
- *,
- updated_memory: str,
- old_memory: str | None,
- llm: Any | None,
- apply_fn,
- commit_fn,
- rollback_fn,
- label: str,
- scope: Literal["user", "team"],
-) -> dict[str, Any]:
- """Validate, optionally force-rewrite if over the hard limit, save, and
- return a response dict.
-
- Parameters
- ----------
- updated_memory : str
- The new document the agent submitted.
- old_memory : str | None
- The previously persisted document (for diff checks).
- llm : Any | None
- LLM instance for forced rewrite (may be ``None``).
- apply_fn : callable(str) -> None
- Callback that sets the new memory on the ORM object.
- commit_fn : coroutine
- ``session.commit``.
- rollback_fn : coroutine
- ``session.rollback``.
- label : str
- Human label for log messages (e.g. "user memory", "team memory").
- """
- if not isinstance(updated_memory, str):
- logger.warning(
- "Refusing non-string memory payload (type=%s)",
- type(updated_memory).__name__,
- )
- return {
- "status": "error",
- "message": "Internal error: memory payload must be a string.",
- }
-
- content = updated_memory
-
- # --- forced rewrite if over the hard limit ---
- if len(content) > MEMORY_HARD_LIMIT and llm is not None:
- rewritten = await _forced_rewrite(content, llm)
- if rewritten is not None and len(rewritten) < len(content):
- content = rewritten
-
- # --- hard-limit gate (reject if still too large after rewrite) ---
- size_err = _validate_memory_size(content)
- if size_err:
- return size_err
-
- scope_err = _validate_memory_scope(content, scope)
- if scope_err:
- return scope_err
-
- # --- persist ---
- try:
- apply_fn(content)
- await commit_fn()
- except Exception as e:
- logger.exception("Failed to update %s: %s", label, e)
- await rollback_fn()
- return {"status": "error", "message": f"Failed to update {label}: {e}"}
-
- # --- build response ---
- resp: dict[str, Any] = {
- "status": "saved",
- "message": f"{label.capitalize()} updated.",
- }
-
- if content is not updated_memory:
- resp["notice"] = "Memory was automatically rewritten to fit within limits."
-
- diff_warnings = _validate_diff(old_memory, content)
- if diff_warnings:
- resp["diff_warnings"] = diff_warnings
-
- format_warnings = _validate_bullet_format(content)
- if format_warnings:
- resp["format_warnings"] = format_warnings
-
- warning = _soft_warning(content)
- if warning:
- resp["warning"] = warning
-
- return resp
-
-
-# ---------------------------------------------------------------------------
-# Tool factories
-# ---------------------------------------------------------------------------
-
def create_update_memory_tool(
user_id: str | UUID,
db_session: AsyncSession,
llm: Any | None = None,
):
- """Factory function to create the user-memory update tool.
+ """Factory for the user-memory update tool.
- The tool acquires its own short-lived ``AsyncSession`` per call via
- :data:`async_session_maker` so the closure is safe to share across
- HTTP requests by the compiled-agent cache. Capturing a per-request
- session here would surface stale/closed sessions on cache hits.
- The session's bound ``commit``/``rollback`` methods are captured at
- call time, after ``async with`` has bound ``db_session`` locally.
-
- Args:
- user_id: ID of the user whose memory document is being updated.
- db_session: Reserved for registry compatibility. Per-call sessions
- are opened via :data:`async_session_maker` inside the tool body.
- llm: Optional LLM for the forced-rewrite path.
-
- Returns:
- Configured update_memory tool for the user-memory scope.
+ Uses a fresh short-lived session per call so compiled-agent caches never
+ retain a stale request-scoped session.
"""
- del db_session # per-call session — see docstring
+ del db_session
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the user's personal memory document.
- Your current memory is shown in in the system prompt.
- When the user shares important long-term information (preferences,
- facts, instructions, context), rewrite the memory document to include
- the new information. Merge new facts with existing ones, update
- contradictions, remove outdated entries, and keep it concise.
-
- Args:
- updated_memory: The FULL updated markdown document (not a diff).
+ The current memory is shown in . Pass the FULL updated
+ markdown document, not a diff.
"""
try:
async with async_session_maker() as db_session:
- result = await db_session.execute(select(User).where(User.id == uid))
- user = result.scalars().first()
- if not user:
- return {"status": "error", "message": "User not found."}
-
- old_memory = user.memory_md
-
- return await _save_memory(
- updated_memory=updated_memory,
- old_memory=old_memory,
+ result = await save_memory(
+ scope=MemoryScope.USER,
+ target_id=uid,
+ content=updated_memory,
+ session=db_session,
llm=llm,
- apply_fn=lambda content: setattr(user, "memory_md", content),
- commit_fn=db_session.commit,
- rollback_fn=db_session.rollback,
- label="memory",
- scope="user",
)
+ return result.to_dict()
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
- return {
- "status": "error",
- "message": f"Failed to update memory: {e}",
- }
+ return {"status": "error", "message": f"Failed to update memory: {e}"}
return update_memory
@@ -373,64 +57,26 @@ def create_update_team_memory_tool(
db_session: AsyncSession,
llm: Any | None = None,
):
- """Factory function to create the team-memory update tool.
-
- The tool acquires its own short-lived ``AsyncSession`` per call via
- :data:`async_session_maker` so the closure is safe to share across
- HTTP requests by the compiled-agent cache. Capturing a per-request
- session here would surface stale/closed sessions on cache hits.
- The session's bound ``commit``/``rollback`` methods are captured at
- call time, after ``async with`` has bound ``db_session`` locally.
-
- Args:
- search_space_id: ID of the search space whose team memory is being
- updated.
- db_session: Reserved for registry compatibility. Per-call sessions
- are opened via :data:`async_session_maker` inside the tool body.
- llm: Optional LLM for the forced-rewrite path.
-
- Returns:
- Configured update_memory tool for the team-memory scope.
- """
- del db_session # per-call session — see docstring
+ """Factory for the team-memory update tool."""
+ del db_session
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space.
- Your current team memory is shown in in the system
- prompt. When the team shares important long-term information
- (decisions, conventions, key facts, priorities), rewrite the memory
- document to include the new information. Merge new facts with
- existing ones, update contradictions, remove outdated entries, and
- keep it concise.
-
- Args:
- updated_memory: The FULL updated markdown document (not a diff).
+ The current team memory is shown in . Pass the FULL updated
+ markdown document, not a diff.
"""
try:
async with async_session_maker() as db_session:
- result = await db_session.execute(
- select(SearchSpace).where(SearchSpace.id == search_space_id)
- )
- space = result.scalars().first()
- if not space:
- return {"status": "error", "message": "Search space not found."}
-
- old_memory = space.shared_memory_md
-
- return await _save_memory(
- updated_memory=updated_memory,
- old_memory=old_memory,
+ result = await save_memory(
+ scope=MemoryScope.TEAM,
+ target_id=search_space_id,
+ content=updated_memory,
+ session=db_session,
llm=llm,
- apply_fn=lambda content: setattr(
- space, "shared_memory_md", content
- ),
- commit_fn=db_session.commit,
- rollback_fn=db_session.rollback,
- label="team memory",
- scope="team",
)
+ return result.to_dict()
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
return {
@@ -439,3 +85,9 @@ def create_update_team_memory_tool(
}
return update_memory
+
+
+__all__ = [
+ "create_update_memory_tool",
+ "create_update_team_memory_tool",
+]
diff --git a/surfsense_backend/app/services/memory/__init__.py b/surfsense_backend/app/services/memory/__init__.py
new file mode 100644
index 000000000..d72f45e1f
--- /dev/null
+++ b/surfsense_backend/app/services/memory/__init__.py
@@ -0,0 +1,29 @@
+"""First-class memory service for user and team markdown memory."""
+
+from .service import (
+ MemoryScope,
+ SaveResult,
+ extract_and_save,
+ read_memory,
+ reset_memory,
+ save_memory,
+)
+from .validation import (
+ MEMORY_HARD_LIMIT,
+ MEMORY_SOFT_LIMIT,
+ validate_bullet_format,
+ validate_memory_scope,
+)
+
+__all__ = [
+ "MEMORY_HARD_LIMIT",
+ "MEMORY_SOFT_LIMIT",
+ "MemoryScope",
+ "SaveResult",
+ "extract_and_save",
+ "read_memory",
+ "reset_memory",
+ "save_memory",
+ "validate_bullet_format",
+ "validate_memory_scope",
+]
diff --git a/surfsense_backend/app/services/memory/prompts.py b/surfsense_backend/app/services/memory/prompts.py
new file mode 100644
index 000000000..fbf27fd08
--- /dev/null
+++ b/surfsense_backend/app/services/memory/prompts.py
@@ -0,0 +1,110 @@
+"""Prompts used by the memory service."""
+
+FORCED_REWRITE_PROMPT = """\
+You are a memory curator. The following memory document exceeds the character \
+limit and must be shortened.
+
+RULES:
+1. Rewrite the document to be under {target} characters.
+2. Output Markdown only. Use clear `##` headings and concise bullet points.
+3. New-format bullets should look like: `- YYYY-MM-DD: memory text`.
+4. If the input contains legacy markers like `(YYYY-MM-DD) [fact]`, preserve the
+ information but remove the inline marker in the output.
+5. Preserve durable instructions and preferences before generic facts when
+ compressing personal memory.
+6. Preserve existing headings when useful; merge duplicate headings and bullets.
+7. Output ONLY the consolidated markdown — no explanations, no wrapping.
+
+
+{content}
+"""
+
+USER_MEMORY_EXTRACT_PROMPT = """\
+You are a memory extraction assistant. Analyze the user's message and decide \
+if it contains any long-term information worth persisting to personal memory.
+
+Worth remembering: preferences, background/identity, goals, projects, \
+instructions, tools/languages they use, decisions, expertise, workplace — \
+durable facts that will matter in future conversations.
+
+NOT worth remembering: greetings, one-off factual questions, session \
+logistics, ephemeral requests, follow-up clarifications with no new personal \
+info, things that only matter for the current task.
+
+If there is nothing durable to remember, choose `action = no_update`.
+
+If the message contains memorizable information, choose `action = save` and \
+return the FULL updated memory document with the new information merged into \
+existing content.
+
+FORMAT RULES FOR `updated_memory`:
+- Markdown only.
+- Every entry should be under a `##` heading.
+- Recommended headings: `## Facts`, `## Preferences`, `## Instructions`.
+- New bullets should use: `- YYYY-MM-DD: memory text`.
+- If current memory uses legacy `(YYYY-MM-DD) [fact|pref|instr]` markers,
+ preserve the information but write the updated document in the new
+ heading-based format.
+- Use the user's first name from `` when helpful, not "the user".
+- Do not duplicate existing information.
+
+{user_name}
+
+
+{current_memory}
+
+
+
+{user_message}
+"""
+
+TEAM_MEMORY_EXTRACT_PROMPT = """\
+You are a team-memory extraction assistant. Analyze the latest message and \
+decide if it contains durable TEAM-level information worth persisting.
+
+Decision policy:
+- Prioritize recall for durable team context, while avoiding personal-only facts.
+- Do NOT require explicit consensus language. A direct team-level statement can
+ be stored if it is stable and broadly useful for future team chats.
+- If evidence is weak or clearly tentative, choose `action = no_update`.
+
+Worth remembering (team-level only):
+- Decisions and defaults that guide future team work
+- Team conventions/standards (naming, review policy, coding norms)
+- Stable org/project facts (locations, ownership, constraints)
+- Long-lived architecture/process facts
+- Ongoing priorities that are likely relevant beyond this turn
+
+NOT worth remembering:
+- Personal preferences or biography of one person
+- Questions, brainstorming, tentative ideas, or speculation
+- One-off requests, status updates, TODOs, logistics for this session
+- Information scoped only to a single ephemeral task
+
+If the message contains memorizable team information, choose `action = save` \
+and return the FULL updated team memory document with new facts merged into \
+existing content.
+
+FORMAT RULES FOR `updated_memory`:
+- Markdown only.
+- Every entry should be under a `##` heading.
+- Recommended headings: `## Product Decisions`, `## Engineering Conventions`,
+ `## Project Facts`, `## Open Questions`.
+- New bullets should use: `- YYYY-MM-DD: memory text`.
+- If current memory uses legacy `(YYYY-MM-DD) [fact]` markers, preserve the
+ information but write the updated document in the new heading-based format.
+- Do not create personal headings such as `## Preferences`, `## Instructions`,
+ or `## Personal Notes`.
+- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
+
+
+{current_memory}
+
+
+
+{author}
+
+
+
+{user_message}
+"""
diff --git a/surfsense_backend/app/services/memory/rewrite.py b/surfsense_backend/app/services/memory/rewrite.py
new file mode 100644
index 000000000..270904ce7
--- /dev/null
+++ b/surfsense_backend/app/services/memory/rewrite.py
@@ -0,0 +1,35 @@
+"""LLM-backed memory rewrite helpers."""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+
+from app.services.memory.prompts import FORCED_REWRITE_PROMPT
+from app.services.memory.validation import MEMORY_HARD_LIMIT
+from app.utils.content_utils import extract_text_content
+
+logger = logging.getLogger(__name__)
+
+
+async def forced_rewrite(content: str, llm: Any) -> str | None:
+ """Use a focused LLM call to compress memory under the hard limit."""
+ try:
+ prompt = FORCED_REWRITE_PROMPT.format(
+ target=MEMORY_HARD_LIMIT,
+ content=content,
+ )
+ response = await llm.ainvoke(
+ [HumanMessage(content=prompt)],
+ config={"tags": ["surfsense:internal", "memory-rewrite"]},
+ )
+ text = extract_text_content(response.content).strip()
+ if not text:
+ logger.warning("Forced memory rewrite returned empty text")
+ return None
+ return text
+ except Exception:
+ logger.exception("Forced memory rewrite LLM call failed")
+ return None
diff --git a/surfsense_backend/app/services/memory/schemas.py b/surfsense_backend/app/services/memory/schemas.py
new file mode 100644
index 000000000..9b40ee5b1
--- /dev/null
+++ b/surfsense_backend/app/services/memory/schemas.py
@@ -0,0 +1,23 @@
+"""Structured output schemas for memory extraction."""
+
+from __future__ import annotations
+
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class MemoryExtractionDecision(BaseModel):
+ """Structured extraction result; avoids string sentinel parsing."""
+
+ action: Literal["no_update", "save"] = Field(
+ description="Choose no_update when nothing durable should be saved; choose save otherwise."
+ )
+ reason: str | None = Field(
+ default=None,
+ description="Short reason for no_update, or brief summary of the memory update.",
+ )
+ updated_memory: str | None = Field(
+ default=None,
+ description="The full updated markdown memory document when action is save.",
+ )
diff --git a/surfsense_backend/app/services/memory/service.py b/surfsense_backend/app/services/memory/service.py
new file mode 100644
index 000000000..85459c28c
--- /dev/null
+++ b/surfsense_backend/app/services/memory/service.py
@@ -0,0 +1,300 @@
+"""Canonical read/write/reset/extract service for markdown memory."""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass, field
+from enum import StrEnum
+from typing import Any, Literal
+from uuid import UUID
+
+from langchain_core.messages import HumanMessage
+from pydantic import BaseModel
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import SearchSpace, User
+from app.services.memory.prompts import (
+ TEAM_MEMORY_EXTRACT_PROMPT,
+ USER_MEMORY_EXTRACT_PROMPT,
+)
+from app.services.memory.rewrite import forced_rewrite
+from app.services.memory.schemas import MemoryExtractionDecision
+from app.services.memory.validation import (
+ MEMORY_HARD_LIMIT,
+ soft_limit_warning,
+ strip_preamble_to_first_heading,
+ validate_bullet_format,
+ validate_diff,
+ validate_heading_sanity,
+ validate_memory_scope,
+ validate_memory_size,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryScope(StrEnum):
+ USER = "user"
+ TEAM = "team"
+
+
+@dataclass(frozen=True)
+class SaveResult:
+ status: Literal["saved", "error", "no_op"]
+ message: str
+ memory_md: str = ""
+ warnings: list[str] = field(default_factory=list)
+ diff_warnings: list[str] = field(default_factory=list)
+ format_warnings: list[str] = field(default_factory=list)
+ notice: str | None = None
+
+ def to_dict(self) -> dict[str, Any]:
+ data: dict[str, Any] = {
+ "status": self.status,
+ "message": self.message,
+ "memory_md": self.memory_md,
+ }
+ if self.notice:
+ data["notice"] = self.notice
+ if self.warnings:
+ data["warnings"] = self.warnings
+ if len(self.warnings) == 1:
+ data["warning"] = self.warnings[0]
+ if self.diff_warnings:
+ data["diff_warnings"] = self.diff_warnings
+ if self.format_warnings:
+ data["format_warnings"] = self.format_warnings
+ return data
+
+
+class MemoryRead(BaseModel):
+ memory_md: str
+
+
+def _normalize_scope(scope: MemoryScope | str) -> MemoryScope:
+ return scope if isinstance(scope, MemoryScope) else MemoryScope(scope)
+
+
+def _normalize_user_id(target_id: str | UUID) -> UUID:
+ return UUID(target_id) if isinstance(target_id, str) else target_id
+
+
+async def _load_target(
+ *,
+ scope: MemoryScope | str,
+ target_id: str | int | UUID,
+ session: AsyncSession,
+) -> User | SearchSpace | None:
+ normalized = _normalize_scope(scope)
+ if normalized is MemoryScope.USER:
+ result = await session.execute(
+ select(User).where(User.id == _normalize_user_id(target_id)) # type: ignore[arg-type]
+ )
+ return result.scalars().first()
+ result = await session.execute(select(SearchSpace).where(SearchSpace.id == int(target_id)))
+ return result.scalars().first()
+
+
+def _get_memory(target: User | SearchSpace, scope: MemoryScope) -> str:
+ if scope is MemoryScope.USER:
+ return getattr(target, "memory_md", None) or ""
+ return getattr(target, "shared_memory_md", None) or ""
+
+
+def _set_memory(target: User | SearchSpace, scope: MemoryScope, content: str) -> None:
+ if scope is MemoryScope.USER:
+ target.memory_md = content
+ else:
+ target.shared_memory_md = content
+
+
+async def read_memory(
+ *,
+ scope: MemoryScope | str,
+ target_id: str | int | UUID,
+ session: AsyncSession,
+) -> str:
+ normalized = _normalize_scope(scope)
+ target = await _load_target(scope=normalized, target_id=target_id, session=session)
+ if target is None:
+ return ""
+ return _get_memory(target, normalized)
+
+
+async def save_memory(
+ *,
+ scope: MemoryScope | str,
+ target_id: str | int | UUID,
+ content: str,
+ session: AsyncSession,
+ llm: Any | None = None,
+) -> SaveResult:
+ normalized = _normalize_scope(scope)
+ if not isinstance(content, str):
+ return SaveResult(
+ status="error",
+ message="Internal error: memory payload must be a string.",
+ )
+
+ target = await _load_target(scope=normalized, target_id=target_id, session=session)
+ if target is None:
+ return SaveResult(
+ status="error",
+ message="User not found." if normalized is MemoryScope.USER else "Search space not found.",
+ )
+
+ old_memory = _get_memory(target, normalized)
+ next_content = strip_preamble_to_first_heading(content.strip())
+ notice: str | None = None
+ warnings: list[str] = []
+
+ if len(next_content) > MEMORY_HARD_LIMIT and llm is not None:
+ rewritten = await forced_rewrite(next_content, llm)
+ if rewritten is not None and len(rewritten) < len(next_content):
+ next_content = strip_preamble_to_first_heading(rewritten)
+ notice = "Memory was automatically rewritten to fit within limits."
+
+ for validation in (
+ validate_memory_size(next_content),
+ validate_heading_sanity(next_content),
+ ):
+ if validation:
+ return SaveResult(
+ status="error",
+ message=validation["message"],
+ memory_md=old_memory,
+ )
+
+ scope_error, scope_warnings = validate_memory_scope(
+ next_content,
+ normalized.value,
+ old_memory=old_memory,
+ )
+ warnings.extend(scope_warnings)
+ if scope_error:
+ return SaveResult(
+ status="error",
+ message=scope_error["message"],
+ memory_md=old_memory,
+ warnings=warnings,
+ )
+
+ try:
+ _set_memory(target, normalized, next_content)
+ session.add(target)
+ await session.commit()
+ except Exception as e:
+ logger.exception("Failed to update %s memory: %s", normalized.value, e)
+ await session.rollback()
+ return SaveResult(
+ status="error",
+ message=f"Failed to update {normalized.value} memory: {e}",
+ memory_md=old_memory,
+ )
+
+ diff_warnings = validate_diff(old_memory, next_content)
+ format_warnings = validate_bullet_format(next_content)
+ warning = soft_limit_warning(next_content)
+ if warning:
+ warnings.append(warning)
+
+ return SaveResult(
+ status="saved",
+ message=(
+ "Memory updated."
+ if normalized is MemoryScope.USER
+ else "Team memory updated."
+ ),
+ memory_md=next_content,
+ warnings=warnings,
+ diff_warnings=diff_warnings,
+ format_warnings=format_warnings,
+ notice=notice,
+ )
+
+
+async def reset_memory(
+ *,
+ scope: MemoryScope | str,
+ target_id: str | int | UUID,
+ session: AsyncSession,
+) -> SaveResult:
+ return await save_memory(
+ scope=scope,
+ target_id=target_id,
+ content="",
+ session=session,
+ llm=None,
+ )
+
+
+async def extract_and_save(
+ *,
+ scope: MemoryScope | str,
+ target_id: str | int | UUID,
+ user_message: str,
+ actor_display_name: str | None,
+ session: AsyncSession,
+ llm: Any,
+) -> SaveResult:
+ normalized = _normalize_scope(scope)
+ current_memory = await read_memory(
+ scope=normalized,
+ target_id=target_id,
+ session=session,
+ )
+
+ if normalized is MemoryScope.USER:
+ first_name = (
+ actor_display_name.strip().split()[0]
+ if actor_display_name and actor_display_name.strip()
+ else "The user"
+ )
+ prompt = USER_MEMORY_EXTRACT_PROMPT.format(
+ current_memory=current_memory or "(empty)",
+ user_message=user_message,
+ user_name=first_name,
+ )
+ else:
+ prompt = TEAM_MEMORY_EXTRACT_PROMPT.format(
+ current_memory=current_memory or "(empty)",
+ author=actor_display_name or "Unknown team member",
+ user_message=user_message,
+ )
+
+ try:
+ structured = llm.with_structured_output(MemoryExtractionDecision)
+ decision = await structured.ainvoke(
+ [HumanMessage(content=prompt)],
+ config={"tags": ["surfsense:internal", "memory-extraction"]},
+ )
+ except Exception:
+ logger.exception("Structured memory extraction failed")
+ return SaveResult(
+ status="error",
+ message="Structured memory extraction failed.",
+ memory_md=current_memory,
+ )
+
+ if decision.action == "no_update":
+ return SaveResult(
+ status="no_op",
+ message=decision.reason or "No durable memory to persist.",
+ memory_md=current_memory,
+ )
+
+ if not decision.updated_memory:
+ return SaveResult(
+ status="error",
+ message="Structured memory extraction chose save without updated_memory.",
+ memory_md=current_memory,
+ )
+
+ return await save_memory(
+ scope=normalized,
+ target_id=target_id,
+ content=decision.updated_memory,
+ session=session,
+ llm=llm,
+ )
diff --git a/surfsense_backend/app/services/memory/validation.py b/surfsense_backend/app/services/memory/validation.py
new file mode 100644
index 000000000..0e856943b
--- /dev/null
+++ b/surfsense_backend/app/services/memory/validation.py
@@ -0,0 +1,158 @@
+"""Validation helpers for markdown-backed memory."""
+
+from __future__ import annotations
+
+import re
+from typing import Literal
+
+MEMORY_SOFT_LIMIT = 18_000
+MEMORY_HARD_LIMIT = 25_000
+
+_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
+_HEADING_LINE_RE = re.compile(r"^##\s+\S+", re.MULTILINE)
+_HEADING_NORMALIZE_RE = re.compile(r"[^a-z0-9]+")
+_LEGACY_BULLET_RE = re.compile(r"^-\s+\(\d{4}-\d{2}-\d{2}\)\s+\[(fact|pref|instr)\]\s+.+$")
+_NEW_BULLET_RE = re.compile(r"^-\s+\d{4}-\d{2}-\d{2}:\s+.+$")
+
+_FORBIDDEN_TEAM_HEADINGS = {
+ "preferences",
+ "instructions",
+ "personal notes",
+ "personal instructions",
+}
+
+
+def has_markdown_heading(content: str) -> bool:
+ return bool(_HEADING_LINE_RE.search(content))
+
+
+def strip_preamble_to_first_heading(content: str) -> str:
+ """Drop model preamble before the first ``##`` heading, if one exists."""
+ match = _HEADING_LINE_RE.search(content)
+ if not match:
+ return content.strip()
+ return content[match.start() :].strip()
+
+
+def extract_headings(memory: str | None) -> set[str]:
+ if not memory:
+ return set()
+ return {_normalize_heading(h) for h in _SECTION_HEADING_RE.findall(memory)}
+
+
+def _normalize_heading(heading: str) -> str:
+ return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower()).strip()
+
+
+def validate_memory_size(content: str) -> dict[str, str] | None:
+ length = len(content)
+ if length > MEMORY_HARD_LIMIT:
+ return {
+ "status": "error",
+ "message": (
+ f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
+ f"({length:,} chars). Consolidate by merging related items, "
+ "removing outdated entries, and shortening descriptions."
+ ),
+ }
+ return None
+
+
+def validate_heading_sanity(content: str) -> dict[str, str] | None:
+ """Block long prose blobs without headings unless they are legacy bullets."""
+ stripped = content.strip()
+ if not stripped:
+ return None
+ if has_markdown_heading(stripped):
+ return None
+ if len(stripped) <= 40:
+ return None
+ if any(_LEGACY_BULLET_RE.match(line.strip()) for line in stripped.splitlines()):
+ return None
+ return {
+ "status": "error",
+ "message": "Memory must be markdown with at least one ## heading.",
+ }
+
+
+def validate_memory_scope(
+ content: str,
+ scope: Literal["user", "team"],
+ *,
+ old_memory: str | None = None,
+) -> tuple[dict[str, str] | None, list[str]]:
+ """Reject new personal headings in team memory, grandfather existing ones."""
+ if scope != "team":
+ return None, []
+
+ old_forbidden = extract_headings(old_memory) & _FORBIDDEN_TEAM_HEADINGS
+ new_forbidden = extract_headings(content) & _FORBIDDEN_TEAM_HEADINGS
+ introduced = sorted(new_forbidden - old_forbidden)
+ grandfathered = sorted(new_forbidden & old_forbidden)
+
+ warnings: list[str] = []
+ if grandfathered:
+ warnings.append(
+ "Team memory contains legacy personal headings: "
+ + ", ".join(grandfathered)
+ + ". Please consolidate them into team-safe headings."
+ )
+ if introduced:
+ return (
+ {
+ "status": "error",
+ "message": (
+ "Team memory cannot introduce personal headings: "
+ + ", ".join(introduced)
+ + ". Use team-safe headings instead."
+ ),
+ },
+ warnings,
+ )
+ return None, warnings
+
+
+def validate_bullet_format(content: str) -> list[str]:
+ warnings: list[str] = []
+ for line in content.splitlines():
+ stripped = line.strip()
+ if not stripped.startswith("- "):
+ continue
+ if _NEW_BULLET_RE.match(stripped) or _LEGACY_BULLET_RE.match(stripped):
+ continue
+ short = stripped[:80] + ("..." if len(stripped) > 80 else "")
+ warnings.append(f"Non-standard memory bullet: {short}")
+ return warnings
+
+
+def validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
+ if not old_memory:
+ return []
+
+ warnings: list[str] = []
+ old_headings = extract_headings(old_memory)
+ new_headings = extract_headings(new_memory)
+ dropped = old_headings - new_headings
+ if dropped:
+ names = ", ".join(sorted(dropped))
+ warnings.append(
+ f"Sections removed: {names}. If unintentional, restore from the settings page."
+ )
+
+ old_len = len(old_memory)
+ new_len = len(new_memory)
+ if old_len > 0 and new_len < old_len * 0.4:
+ warnings.append(
+ f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). Possible data loss."
+ )
+ return warnings
+
+
+def soft_limit_warning(content: str) -> str | None:
+ length = len(content)
+ if length > MEMORY_SOFT_LIMIT:
+ return (
+ f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
+ "Consolidate by merging related items and removing less important entries."
+ )
+ return None
diff --git a/surfsense_backend/tests/unit/services/test_memory_service.py b/surfsense_backend/tests/unit/services/test_memory_service.py
new file mode 100644
index 000000000..c16e34062
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_memory_service.py
@@ -0,0 +1,204 @@
+"""Unit tests for the first-class memory service."""
+
+from types import SimpleNamespace
+
+import pytest
+
+from app.services.memory import (
+ MemoryScope,
+ extract_and_save,
+ reset_memory,
+ save_memory,
+)
+from app.services.memory.schemas import MemoryExtractionDecision
+
+pytestmark = pytest.mark.unit
+
+
+class _FakeSession:
+ def __init__(self) -> None:
+ self.commit_calls = 0
+ self.rollback_calls = 0
+ self.added = []
+
+ def add(self, obj) -> None:
+ self.added.append(obj)
+
+ async def commit(self) -> None:
+ self.commit_calls += 1
+
+ async def rollback(self) -> None:
+ self.rollback_calls += 1
+
+
+class _StructuredLLM:
+ def __init__(self, decision: MemoryExtractionDecision) -> None:
+ self.decision = decision
+
+ def with_structured_output(self, _schema):
+ return self
+
+ async def ainvoke(self, *_args, **_kwargs):
+ return self.decision
+
+
+@pytest.mark.asyncio
+async def test_save_memory_saves_heading_based_memory(monkeypatch) -> None:
+ target = SimpleNamespace(memory_md="")
+ session = _FakeSession()
+
+ async def fake_load_target(**_kwargs):
+ return target
+
+ monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
+
+ result = await save_memory(
+ scope=MemoryScope.USER,
+ target_id="00000000-0000-0000-0000-000000000000",
+ content="## Facts\n- 2026-05-19: Anish works on SurfSense\n",
+ session=session,
+ )
+
+ assert result.status == "saved"
+ assert target.memory_md.startswith("## Facts")
+ assert session.commit_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_save_memory_accepts_legacy_marker_payload(monkeypatch) -> None:
+ target = SimpleNamespace(memory_md="")
+ session = _FakeSession()
+
+ async def fake_load_target(**_kwargs):
+ return target
+
+ monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
+
+ result = await save_memory(
+ scope=MemoryScope.USER,
+ target_id="00000000-0000-0000-0000-000000000000",
+ content="- (2026-05-19) [fact] Legacy marker memory\n",
+ session=session,
+ )
+
+ assert result.status == "saved"
+ assert "[fact]" in target.memory_md
+
+
+@pytest.mark.asyncio
+async def test_save_memory_rejects_long_no_heading_payload(monkeypatch) -> None:
+ target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
+ session = _FakeSession()
+
+ async def fake_load_target(**_kwargs):
+ return target
+
+ monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
+
+ result = await save_memory(
+ scope=MemoryScope.USER,
+ target_id="00000000-0000-0000-0000-000000000000",
+ content="reasoning text before NO_UPDATE should not become saved memory",
+ session=session,
+ )
+
+ assert result.status == "error"
+ assert session.commit_calls == 0
+ assert target.memory_md.startswith("## Facts")
+
+
+@pytest.mark.asyncio
+async def test_save_memory_grandfathers_existing_team_personal_heading(monkeypatch) -> None:
+ content = "## Preferences\n- 2026-05-19: Existing legacy heading\n"
+ target = SimpleNamespace(shared_memory_md=content)
+ session = _FakeSession()
+
+ async def fake_load_target(**_kwargs):
+ return target
+
+ monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
+
+ result = await save_memory(
+ scope=MemoryScope.TEAM,
+ target_id=1,
+ content=content,
+ session=session,
+ )
+
+ assert result.status == "saved"
+ assert result.warnings
+ assert session.commit_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_reset_memory_clears_memory(monkeypatch) -> None:
+ target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
+ session = _FakeSession()
+
+ async def fake_load_target(**_kwargs):
+ return target
+
+ monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
+
+ result = await reset_memory(
+ scope=MemoryScope.USER,
+ target_id="00000000-0000-0000-0000-000000000000",
+ session=session,
+ )
+
+ assert result.status == "saved"
+ assert target.memory_md == ""
+
+
+@pytest.mark.asyncio
+async def test_extract_and_save_no_update_does_not_commit(monkeypatch) -> None:
+ target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
+ session = _FakeSession()
+
+ async def fake_load_target(**_kwargs):
+ return target
+
+ monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
+
+ result = await extract_and_save(
+ scope=MemoryScope.USER,
+ target_id="00000000-0000-0000-0000-000000000000",
+ user_message="hello",
+ actor_display_name="Anish",
+ session=session,
+ llm=_StructuredLLM(
+ MemoryExtractionDecision(action="no_update", reason="Greeting only")
+ ),
+ )
+
+ assert result.status == "no_op"
+ assert session.commit_calls == 0
+
+
+@pytest.mark.asyncio
+async def test_extract_and_save_persists_structured_update(monkeypatch) -> None:
+ target = SimpleNamespace(memory_md="")
+ session = _FakeSession()
+
+ async def fake_load_target(**_kwargs):
+ return target
+
+ monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
+
+ result = await extract_and_save(
+ scope=MemoryScope.USER,
+ target_id="00000000-0000-0000-0000-000000000000",
+ user_message="I work on SurfSense",
+ actor_display_name="Anish",
+ session=session,
+ llm=_StructuredLLM(
+ MemoryExtractionDecision(
+ action="save",
+ updated_memory="## Facts\n- 2026-05-19: Anish works on SurfSense\n",
+ )
+ ),
+ )
+
+ assert result.status == "saved"
+ assert "SurfSense" in target.memory_md
+ assert session.commit_calls == 1