From ceedd02353c74f063188fc103500fb848909a810 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 20 May 2026 02:01:36 +0530 Subject: [PATCH] refactor: extract shared memory service --- .../builtins/memory/tools/update_memory.py | 351 ++------------- .../app/agents/new_chat/memory_extraction.py | 196 +-------- .../agents/new_chat/tools/update_memory.py | 414 ++---------------- .../app/services/memory/__init__.py | 29 ++ .../app/services/memory/prompts.py | 110 +++++ .../app/services/memory/rewrite.py | 35 ++ .../app/services/memory/schemas.py | 23 + .../app/services/memory/service.py | 300 +++++++++++++ .../app/services/memory/validation.py | 158 +++++++ .../unit/services/test_memory_service.py | 204 +++++++++ 10 files changed, 946 insertions(+), 874 deletions(-) create mode 100644 surfsense_backend/app/services/memory/__init__.py create mode 100644 surfsense_backend/app/services/memory/prompts.py create mode 100644 surfsense_backend/app/services/memory/rewrite.py create mode 100644 surfsense_backend/app/services/memory/schemas.py create mode 100644 surfsense_backend/app/services/memory/service.py create mode 100644 surfsense_backend/app/services/memory/validation.py create mode 100644 surfsense_backend/tests/unit/services/test_memory_service.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py index 23375a081..67bcc3e06 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/update_memory.py @@ -1,280 +1,23 @@ -"""Overwrite one markdown memory document per user or team, with size and shrink guards.""" +"""Memory update tools backed by the canonical memory service.""" from __future__ import annotations import logging -import re -from typing import Any, Literal +from typing import Any from uuid import UUID -from langchain_core.messages import HumanMessage from langchain_core.tools import tool -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.db import SearchSpace, User +from app.services.memory import ( + MEMORY_HARD_LIMIT, + MEMORY_SOFT_LIMIT, + MemoryScope, + save_memory, +) logger = logging.getLogger(__name__) -MEMORY_SOFT_LIMIT = 18_000 -MEMORY_HARD_LIMIT = 25_000 - -_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE) -_HEADING_NORMALIZE_RE = re.compile(r"\s+") - -_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]") -_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$") -_PERSONAL_ONLY_MARKERS = {"pref", "instr"} - - -# --------------------------------------------------------------------------- -# Diff validation -# --------------------------------------------------------------------------- - - -def _extract_headings(memory: str) -> set[str]: - """Return all ``## …`` heading texts (without the ``## `` prefix).""" - return set(_SECTION_HEADING_RE.findall(memory)) - - -def _normalize_heading(heading: str) -> str: - """Normalize heading text for robust scope checks.""" - return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower()) - - -def _validate_memory_scope( - content: str, scope: Literal["user", "team"] -) -> dict[str, Any] | None: - """Reject personal-only markers ([pref], [instr]) in team memory.""" - if scope != "team": - return None - - markers = set(_MARKER_RE.findall(content)) - leaked = sorted(markers & _PERSONAL_ONLY_MARKERS) - if leaked: - tags = ", ".join(f"[{m}]" for m in leaked) - return { - "status": "error", - "message": ( - f"Team memory cannot include personal markers: {tags}. " - "Use [fact] only in team memory." - ), - } - return None - - -def _validate_bullet_format(content: str) -> list[str]: - """Return warnings for bullet lines that don't match the required format. - - Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text`` - """ - warnings: list[str] = [] - for line in content.splitlines(): - stripped = line.strip() - if not stripped.startswith("- "): - continue - if not _BULLET_FORMAT_RE.match(stripped): - short = stripped[:80] + ("..." if len(stripped) > 80 else "") - warnings.append(f"Malformed bullet: {short}") - return warnings - - -def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]: - """Return a list of warning strings about suspicious changes.""" - if not old_memory: - return [] - - warnings: list[str] = [] - old_headings = _extract_headings(old_memory) - new_headings = _extract_headings(new_memory) - dropped = old_headings - new_headings - if dropped: - names = ", ".join(sorted(dropped)) - warnings.append( - f"Sections removed: {names}. " - "If unintentional, the user can restore from the settings page." - ) - - old_len = len(old_memory) - new_len = len(new_memory) - if old_len > 0 and new_len < old_len * 0.4: - warnings.append( - f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). " - "Possible data loss." - ) - return warnings - - -# --------------------------------------------------------------------------- -# Size validation & soft warning -# --------------------------------------------------------------------------- - - -def _validate_memory_size(content: str) -> dict[str, Any] | None: - """Return an error/warning dict if *content* is too large, else None.""" - length = len(content) - if length > MEMORY_HARD_LIMIT: - return { - "status": "error", - "message": ( - f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit " - f"({length:,} chars). Consolidate by merging related items, " - "removing outdated entries, and shortening descriptions. " - "Then call update_memory again." - ), - } - return None - - -def _soft_warning(content: str) -> str | None: - """Return a warning string if content exceeds the soft limit.""" - length = len(content) - if length > MEMORY_SOFT_LIMIT: - return ( - f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. " - "Consolidate by merging related items and removing less important " - "entries on your next update." - ) - return None - - -# --------------------------------------------------------------------------- -# Forced rewrite when memory exceeds the hard limit -# --------------------------------------------------------------------------- - -_FORCED_REWRITE_PROMPT = """\ -You are a memory curator. The following memory document exceeds the character \ -limit and must be shortened. - -RULES: -1. Rewrite the document to be under {target} characters. -2. Preserve existing ## headings. Every entry must remain under a heading. You may merge - or rename headings to consolidate, but keep names personal and descriptive. -3. Priority for keeping content: [instr] > [pref] > [fact]. -4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions. -5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text -6. Preserve the user's first name in entries — do not replace it with "the user". -7. Output ONLY the consolidated markdown — no explanations, no wrapping. - - -{content} -""" - - -async def _forced_rewrite(content: str, llm: Any) -> str | None: - """Use a focused LLM call to compress *content* under the hard limit. - - Returns the rewritten string, or ``None`` if the call fails. - """ - try: - prompt = _FORCED_REWRITE_PROMPT.format( - target=MEMORY_HARD_LIMIT, content=content - ) - response = await llm.ainvoke( - [HumanMessage(content=prompt)], - config={"tags": ["surfsense:internal"]}, - ) - text = ( - response.content - if isinstance(response.content, str) - else str(response.content) - ) - return text.strip() - except Exception: - logger.exception("Forced rewrite LLM call failed") - return None - - -# --------------------------------------------------------------------------- -# Shared save-and-respond logic -# --------------------------------------------------------------------------- - - -async def _save_memory( - *, - updated_memory: str, - old_memory: str | None, - llm: Any | None, - apply_fn, - commit_fn, - rollback_fn, - label: str, - scope: Literal["user", "team"], -) -> dict[str, Any]: - """Validate, optionally force-rewrite if over the hard limit, save, and - return a response dict. - - Parameters - ---------- - updated_memory : str - The new document the agent submitted. - old_memory : str | None - The previously persisted document (for diff checks). - llm : Any | None - LLM instance for forced rewrite (may be ``None``). - apply_fn : callable(str) -> None - Callback that sets the new memory on the ORM object. - commit_fn : coroutine - ``session.commit``. - rollback_fn : coroutine - ``session.rollback``. - label : str - Human label for log messages (e.g. "user memory", "team memory"). - """ - content = updated_memory - - # --- forced rewrite if over the hard limit --- - if len(content) > MEMORY_HARD_LIMIT and llm is not None: - rewritten = await _forced_rewrite(content, llm) - if rewritten is not None and len(rewritten) < len(content): - content = rewritten - - # --- hard-limit gate (reject if still too large after rewrite) --- - size_err = _validate_memory_size(content) - if size_err: - return size_err - - scope_err = _validate_memory_scope(content, scope) - if scope_err: - return scope_err - - # --- persist --- - try: - apply_fn(content) - await commit_fn() - except Exception as e: - logger.exception("Failed to update %s: %s", label, e) - await rollback_fn() - return {"status": "error", "message": f"Failed to update {label}: {e}"} - - # --- build response --- - resp: dict[str, Any] = { - "status": "saved", - "message": f"{label.capitalize()} updated.", - } - - if content is not updated_memory: - resp["notice"] = "Memory was automatically rewritten to fit within limits." - - diff_warnings = _validate_diff(old_memory, content) - if diff_warnings: - resp["diff_warnings"] = diff_warnings - - format_warnings = _validate_bullet_format(content) - if format_warnings: - resp["format_warnings"] = format_warnings - - warning = _soft_warning(content) - if warning: - resp["warning"] = warning - - return resp - - -# --------------------------------------------------------------------------- -# Tool factories -# --------------------------------------------------------------------------- - def create_update_memory_tool( user_id: str | UUID, @@ -287,40 +30,22 @@ def create_update_memory_tool( async def update_memory(updated_memory: str) -> dict[str, Any]: """Update the user's personal memory document. - Your current memory is shown in in the system prompt. - When the user shares important long-term information (preferences, - facts, instructions, context), rewrite the memory document to include - the new information. Merge new facts with existing ones, update - contradictions, remove outdated entries, and keep it concise. - - Args: - updated_memory: The FULL updated markdown document (not a diff). + The current memory is shown in . Pass the FULL updated + markdown document, not a diff. """ try: - result = await db_session.execute(select(User).where(User.id == uid)) - user = result.scalars().first() - if not user: - return {"status": "error", "message": "User not found."} - - old_memory = user.memory_md - - return await _save_memory( - updated_memory=updated_memory, - old_memory=old_memory, + result = await save_memory( + scope=MemoryScope.USER, + target_id=uid, + content=updated_memory, + session=db_session, llm=llm, - apply_fn=lambda content: setattr(user, "memory_md", content), - commit_fn=db_session.commit, - rollback_fn=db_session.rollback, - label="memory", - scope="user", ) + return result.to_dict() except Exception as e: logger.exception("Failed to update user memory: %s", e) await db_session.rollback() - return { - "status": "error", - "message": f"Failed to update memory: {e}", - } + return {"status": "error", "message": f"Failed to update memory: {e}"} return update_memory @@ -334,36 +59,18 @@ def create_update_team_memory_tool( async def update_memory(updated_memory: str) -> dict[str, Any]: """Update the team's shared memory document for this search space. - Your current team memory is shown in in the system - prompt. When the team shares important long-term information - (decisions, conventions, key facts, priorities), rewrite the memory - document to include the new information. Merge new facts with - existing ones, update contradictions, remove outdated entries, and - keep it concise. - - Args: - updated_memory: The FULL updated markdown document (not a diff). + The current team memory is shown in . Pass the FULL updated + markdown document, not a diff. """ try: - result = await db_session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - space = result.scalars().first() - if not space: - return {"status": "error", "message": "Search space not found."} - - old_memory = space.shared_memory_md - - return await _save_memory( - updated_memory=updated_memory, - old_memory=old_memory, + result = await save_memory( + scope=MemoryScope.TEAM, + target_id=search_space_id, + content=updated_memory, + session=db_session, llm=llm, - apply_fn=lambda content: setattr(space, "shared_memory_md", content), - commit_fn=db_session.commit, - rollback_fn=db_session.rollback, - label="team memory", - scope="team", ) + return result.to_dict() except Exception as e: logger.exception("Failed to update team memory: %s", e) await db_session.rollback() @@ -373,3 +80,11 @@ def create_update_team_memory_tool( } return update_memory + + +__all__ = [ + "MEMORY_HARD_LIMIT", + "MEMORY_SOFT_LIMIT", + "create_update_memory_tool", + "create_update_team_memory_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/memory_extraction.py b/surfsense_backend/app/agents/new_chat/memory_extraction.py index e31774a7c..d44b58f7b 100644 --- a/surfsense_backend/app/agents/new_chat/memory_extraction.py +++ b/surfsense_backend/app/agents/new_chat/memory_extraction.py @@ -1,9 +1,4 @@ -"""Background memory extraction for the SurfSense agent. - -After each agent response, if the agent did not call ``update_memory`` during -the turn, this module can run a lightweight LLM call to decide whether the -latest message contains long-term information worth persisting. -""" +"""Background memory extraction for the SurfSense agent.""" from __future__ import annotations @@ -11,102 +6,11 @@ import logging from typing import Any from uuid import UUID -from langchain_core.messages import HumanMessage -from sqlalchemy import select - -from app.agents.new_chat.tools.update_memory import _save_memory -from app.db import SearchSpace, User, shielded_async_session -from app.utils.content_utils import extract_text_content +from app.db import User, shielded_async_session +from app.services.memory import MemoryScope, extract_and_save logger = logging.getLogger(__name__) -_MEMORY_EXTRACT_PROMPT = """\ -You are a memory extraction assistant. Analyze the user's message and decide \ -if it contains any long-term information worth persisting to memory. - -Worth remembering: preferences, background/identity, goals, projects, \ -instructions, tools/languages they use, decisions, expertise, workplace — \ -durable facts that will matter in future conversations. - -NOT worth remembering: greetings, one-off factual questions, session \ -logistics, ephemeral requests, follow-up clarifications with no new personal \ -info, things that only matter for the current task. - -If the message contains memorizable information, output the FULL updated \ -memory document with the new facts merged into the existing content. Follow \ -these rules: -- Every entry MUST be under a ## heading. Preserve existing headings; create new ones - freely. Keep heading names short (2-3 words) and natural. Do NOT include the user's - name in headings. -- Keep entries as single bullet points. Be descriptive but concise — include relevant - details and context rather than just a few words. -- Every bullet MUST use format: - (YYYY-MM-DD) [fact|pref|instr] text - [fact] = durable facts, [pref] = preferences, [instr] = standing instructions. -- Use the user's first name (from ) in entry text, not "the user". -- If a new fact contradicts an existing entry, update the existing entry. -- Do not duplicate information that is already present. - -If nothing is worth remembering, output exactly: NO_UPDATE - -{user_name} - - -{current_memory} - - - -{user_message} -""" - -_TEAM_MEMORY_EXTRACT_PROMPT = """\ -You are a team-memory extraction assistant. Analyze the latest message and \ -decide if it contains durable TEAM-level information worth persisting. - -Decision policy: -- Prioritize recall for durable team context, while avoiding personal-only facts. -- Do NOT require explicit consensus language. A direct team-level statement can - be stored if it is stable and broadly useful for future team chats. -- If evidence is weak or clearly tentative, output NO_UPDATE. - -Worth remembering (team-level only): -- Decisions and defaults that guide future team work -- Team conventions/standards (naming, review policy, coding norms) -- Stable org/project facts (locations, ownership, constraints) -- Long-lived architecture/process facts -- Ongoing priorities that are likely relevant beyond this turn - -NOT worth remembering: -- Personal preferences or biography of one person -- Questions, brainstorming, tentative ideas, or speculation -- One-off requests, status updates, TODOs, logistics for this session -- Information scoped only to a single ephemeral task - -If the message contains memorizable team information, output the FULL updated \ -team memory document with new facts merged into existing content. Follow rules: -- Every entry MUST be under a ## heading. Preserve existing headings; create new ones - freely. Keep heading names short (2-3 words) and natural. -- Keep entries as single bullet points. Be descriptive but concise — include relevant - details and context rather than just a few words. -- Every bullet MUST use format: - (YYYY-MM-DD) [fact] text - Team memory uses ONLY the [fact] marker. Never use [pref] or [instr]. -- If a new fact contradicts an existing entry, update the existing entry. -- Do not duplicate existing information. -- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored. - -If nothing is worth remembering, output exactly: NO_UPDATE - - -{current_memory} - - - -{author} - - - -{user_message} -""" - async def extract_and_save_memory( *, @@ -114,57 +18,31 @@ async def extract_and_save_memory( user_id: str | None, llm: Any, ) -> None: - """Background task: extract memorizable info and persist it. + """Fire-and-forget personal memory extraction. - Designed to be fire-and-forget — catches all exceptions internally. + The service uses structured output, so free-form ``NO_UPDATE`` text can no + longer be accidentally persisted as memory. """ if not user_id: return try: uid = UUID(user_id) if isinstance(user_id, str) else user_id - async with shielded_async_session() as session: - result = await session.execute(select(User).where(User.id == uid)) - user = result.scalars().first() - if not user: - return - - old_memory = user.memory_md - first_name = ( - user.display_name.strip().split()[0] - if user.display_name and user.display_name.strip() - else "The user" - ) - prompt = _MEMORY_EXTRACT_PROMPT.format( - current_memory=old_memory or "(empty)", + user = await session.get(User, uid) + actor_display_name = user.display_name if user else None + result = await extract_and_save( + scope=MemoryScope.USER, + target_id=uid, user_message=user_message, - user_name=first_name, - ) - response = await llm.ainvoke( - [HumanMessage(content=prompt)], - config={"tags": ["surfsense:internal", "memory-extraction"]}, - ) - text = extract_text_content(response.content).strip() - - if text == "NO_UPDATE" or not text: - logger.debug("Memory extraction: no update needed (user %s)", uid) - return - - save_result = await _save_memory( - updated_memory=text, - old_memory=old_memory, + actor_display_name=actor_display_name, + session=session, llm=llm, - apply_fn=lambda content: setattr(user, "memory_md", content), - commit_fn=session.commit, - rollback_fn=session.rollback, - label="memory", - scope="user", ) logger.info( "Background memory extraction for user %s: %s", uid, - save_result.get("status"), + result.status, ) except Exception: logger.exception("Background user memory extraction failed") @@ -177,56 +55,24 @@ async def extract_and_save_team_memory( llm: Any, author_display_name: str | None = None, ) -> None: - """Background task: extract team-level memory and persist it. - - Runs only for shared threads. Designed to be fire-and-forget and catches - exceptions internally. - """ + """Fire-and-forget team-level memory extraction.""" if not search_space_id: return try: async with shielded_async_session() as session: - result = await session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - space = result.scalars().first() - if not space: - return - - old_memory = space.shared_memory_md - prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format( - current_memory=old_memory or "(empty)", - author=author_display_name or "Unknown team member", + result = await extract_and_save( + scope=MemoryScope.TEAM, + target_id=search_space_id, user_message=user_message, - ) - response = await llm.ainvoke( - [HumanMessage(content=prompt)], - config={"tags": ["surfsense:internal", "team-memory-extraction"]}, - ) - text = extract_text_content(response.content).strip() - - if text == "NO_UPDATE" or not text: - logger.debug( - "Team memory extraction: no update needed (space %s)", - search_space_id, - ) - return - - save_result = await _save_memory( - updated_memory=text, - old_memory=old_memory, + actor_display_name=author_display_name, + session=session, llm=llm, - apply_fn=lambda content: setattr(space, "shared_memory_md", content), - commit_fn=session.commit, - rollback_fn=session.rollback, - label="team memory", - scope="team", ) logger.info( "Background team memory extraction for space %s: %s", search_space_id, - save_result.get("status"), + result.status, ) except Exception: logger.exception("Background team memory extraction failed") diff --git a/surfsense_backend/app/agents/new_chat/tools/update_memory.py b/surfsense_backend/app/agents/new_chat/tools/update_memory.py index 062668aac..78a65201b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/update_memory.py +++ b/surfsense_backend/app/agents/new_chat/tools/update_memory.py @@ -1,369 +1,53 @@ -"""Markdown-document memory tool for the SurfSense agent. - -Replaces the old row-per-fact save_memory / recall_memory tools with a single -update_memory tool that overwrites a freeform markdown TEXT column. The LLM -always sees the current memory in / tags injected -by MemoryInjectionMiddleware, so it passes the FULL updated document each time. - -Overflow handling: - - Soft limit (18K chars): a warning is returned telling the agent to - consolidate on the next update. - - Hard limit (25K chars): a forced LLM-driven rewrite compresses the document. - If it still exceeds the limit after rewriting, the save is rejected. - - Diff validation: warns when entire ``##`` sections are dropped or when the - document shrinks by more than 60%. -""" +"""Memory update tools backed by the canonical memory service.""" from __future__ import annotations import logging -import re -from typing import Any, Literal +from typing import Any from uuid import UUID -from langchain_core.messages import HumanMessage from langchain_core.tools import tool -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.db import SearchSpace, User, async_session_maker -from app.utils.content_utils import extract_text_content +from app.db import async_session_maker +from app.services.memory import MemoryScope, save_memory logger = logging.getLogger(__name__) -MEMORY_SOFT_LIMIT = 18_000 -MEMORY_HARD_LIMIT = 25_000 - -_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE) -_HEADING_NORMALIZE_RE = re.compile(r"\s+") - -_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]") -_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$") -_PERSONAL_ONLY_MARKERS = {"pref", "instr"} - - -# --------------------------------------------------------------------------- -# Diff validation -# --------------------------------------------------------------------------- - - -def _extract_headings(memory: str) -> set[str]: - """Return all ``## …`` heading texts (without the ``## `` prefix).""" - return set(_SECTION_HEADING_RE.findall(memory)) - - -def _normalize_heading(heading: str) -> str: - """Normalize heading text for robust scope checks.""" - return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower()) - - -def _validate_memory_scope( - content: str, scope: Literal["user", "team"] -) -> dict[str, Any] | None: - """Reject personal-only markers ([pref], [instr]) in team memory.""" - if scope != "team": - return None - - markers = set(_MARKER_RE.findall(content)) - leaked = sorted(markers & _PERSONAL_ONLY_MARKERS) - if leaked: - tags = ", ".join(f"[{m}]" for m in leaked) - return { - "status": "error", - "message": ( - f"Team memory cannot include personal markers: {tags}. " - "Use [fact] only in team memory." - ), - } - return None - - -def _validate_bullet_format(content: str) -> list[str]: - """Return warnings for bullet lines that don't match the required format. - - Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text`` - """ - warnings: list[str] = [] - for line in content.splitlines(): - stripped = line.strip() - if not stripped.startswith("- "): - continue - if not _BULLET_FORMAT_RE.match(stripped): - short = stripped[:80] + ("..." if len(stripped) > 80 else "") - warnings.append(f"Malformed bullet: {short}") - return warnings - - -def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]: - """Return a list of warning strings about suspicious changes.""" - if not old_memory: - return [] - - warnings: list[str] = [] - old_headings = _extract_headings(old_memory) - new_headings = _extract_headings(new_memory) - dropped = old_headings - new_headings - if dropped: - names = ", ".join(sorted(dropped)) - warnings.append( - f"Sections removed: {names}. " - "If unintentional, the user can restore from the settings page." - ) - - old_len = len(old_memory) - new_len = len(new_memory) - if old_len > 0 and new_len < old_len * 0.4: - warnings.append( - f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). " - "Possible data loss." - ) - return warnings - - -# --------------------------------------------------------------------------- -# Size validation & soft warning -# --------------------------------------------------------------------------- - - -def _validate_memory_size(content: str) -> dict[str, Any] | None: - """Return an error/warning dict if *content* is too large, else None.""" - length = len(content) - if length > MEMORY_HARD_LIMIT: - return { - "status": "error", - "message": ( - f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit " - f"({length:,} chars). Consolidate by merging related items, " - "removing outdated entries, and shortening descriptions. " - "Then call update_memory again." - ), - } - return None - - -def _soft_warning(content: str) -> str | None: - """Return a warning string if content exceeds the soft limit.""" - length = len(content) - if length > MEMORY_SOFT_LIMIT: - return ( - f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. " - "Consolidate by merging related items and removing less important " - "entries on your next update." - ) - return None - - -# --------------------------------------------------------------------------- -# Forced rewrite when memory exceeds the hard limit -# --------------------------------------------------------------------------- - -_FORCED_REWRITE_PROMPT = """\ -You are a memory curator. The following memory document exceeds the character \ -limit and must be shortened. - -RULES: -1. Rewrite the document to be under {target} characters. -2. Preserve existing ## headings. Every entry must remain under a heading. You may merge - or rename headings to consolidate, but keep names personal and descriptive. -3. Priority for keeping content: [instr] > [pref] > [fact]. -4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions. -5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text -6. Preserve the user's first name in entries — do not replace it with "the user". -7. Output ONLY the consolidated markdown — no explanations, no wrapping. - - -{content} -""" - - -async def _forced_rewrite(content: str, llm: Any) -> str | None: - """Use a focused LLM call to compress *content* under the hard limit. - - Returns the rewritten string, or ``None`` if the call fails. - """ - try: - prompt = _FORCED_REWRITE_PROMPT.format( - target=MEMORY_HARD_LIMIT, content=content - ) - response = await llm.ainvoke( - [HumanMessage(content=prompt)], - config={"tags": ["surfsense:internal"]}, - ) - text = extract_text_content(response.content).strip() - if not text: - logger.warning("Forced rewrite returned empty text; aborting rewrite") - return None - return text - except Exception: - logger.exception("Forced rewrite LLM call failed") - return None - - -# --------------------------------------------------------------------------- -# Shared save-and-respond logic -# --------------------------------------------------------------------------- - - -async def _save_memory( - *, - updated_memory: str, - old_memory: str | None, - llm: Any | None, - apply_fn, - commit_fn, - rollback_fn, - label: str, - scope: Literal["user", "team"], -) -> dict[str, Any]: - """Validate, optionally force-rewrite if over the hard limit, save, and - return a response dict. - - Parameters - ---------- - updated_memory : str - The new document the agent submitted. - old_memory : str | None - The previously persisted document (for diff checks). - llm : Any | None - LLM instance for forced rewrite (may be ``None``). - apply_fn : callable(str) -> None - Callback that sets the new memory on the ORM object. - commit_fn : coroutine - ``session.commit``. - rollback_fn : coroutine - ``session.rollback``. - label : str - Human label for log messages (e.g. "user memory", "team memory"). - """ - if not isinstance(updated_memory, str): - logger.warning( - "Refusing non-string memory payload (type=%s)", - type(updated_memory).__name__, - ) - return { - "status": "error", - "message": "Internal error: memory payload must be a string.", - } - - content = updated_memory - - # --- forced rewrite if over the hard limit --- - if len(content) > MEMORY_HARD_LIMIT and llm is not None: - rewritten = await _forced_rewrite(content, llm) - if rewritten is not None and len(rewritten) < len(content): - content = rewritten - - # --- hard-limit gate (reject if still too large after rewrite) --- - size_err = _validate_memory_size(content) - if size_err: - return size_err - - scope_err = _validate_memory_scope(content, scope) - if scope_err: - return scope_err - - # --- persist --- - try: - apply_fn(content) - await commit_fn() - except Exception as e: - logger.exception("Failed to update %s: %s", label, e) - await rollback_fn() - return {"status": "error", "message": f"Failed to update {label}: {e}"} - - # --- build response --- - resp: dict[str, Any] = { - "status": "saved", - "message": f"{label.capitalize()} updated.", - } - - if content is not updated_memory: - resp["notice"] = "Memory was automatically rewritten to fit within limits." - - diff_warnings = _validate_diff(old_memory, content) - if diff_warnings: - resp["diff_warnings"] = diff_warnings - - format_warnings = _validate_bullet_format(content) - if format_warnings: - resp["format_warnings"] = format_warnings - - warning = _soft_warning(content) - if warning: - resp["warning"] = warning - - return resp - - -# --------------------------------------------------------------------------- -# Tool factories -# --------------------------------------------------------------------------- - def create_update_memory_tool( user_id: str | UUID, db_session: AsyncSession, llm: Any | None = None, ): - """Factory function to create the user-memory update tool. + """Factory for the user-memory update tool. - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - The session's bound ``commit``/``rollback`` methods are captured at - call time, after ``async with`` has bound ``db_session`` locally. - - Args: - user_id: ID of the user whose memory document is being updated. - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - llm: Optional LLM for the forced-rewrite path. - - Returns: - Configured update_memory tool for the user-memory scope. + Uses a fresh short-lived session per call so compiled-agent caches never + retain a stale request-scoped session. """ - del db_session # per-call session — see docstring + del db_session uid = UUID(user_id) if isinstance(user_id, str) else user_id @tool async def update_memory(updated_memory: str) -> dict[str, Any]: """Update the user's personal memory document. - Your current memory is shown in in the system prompt. - When the user shares important long-term information (preferences, - facts, instructions, context), rewrite the memory document to include - the new information. Merge new facts with existing ones, update - contradictions, remove outdated entries, and keep it concise. - - Args: - updated_memory: The FULL updated markdown document (not a diff). + The current memory is shown in . Pass the FULL updated + markdown document, not a diff. """ try: async with async_session_maker() as db_session: - result = await db_session.execute(select(User).where(User.id == uid)) - user = result.scalars().first() - if not user: - return {"status": "error", "message": "User not found."} - - old_memory = user.memory_md - - return await _save_memory( - updated_memory=updated_memory, - old_memory=old_memory, + result = await save_memory( + scope=MemoryScope.USER, + target_id=uid, + content=updated_memory, + session=db_session, llm=llm, - apply_fn=lambda content: setattr(user, "memory_md", content), - commit_fn=db_session.commit, - rollback_fn=db_session.rollback, - label="memory", - scope="user", ) + return result.to_dict() except Exception as e: logger.exception("Failed to update user memory: %s", e) - return { - "status": "error", - "message": f"Failed to update memory: {e}", - } + return {"status": "error", "message": f"Failed to update memory: {e}"} return update_memory @@ -373,64 +57,26 @@ def create_update_team_memory_tool( db_session: AsyncSession, llm: Any | None = None, ): - """Factory function to create the team-memory update tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - The session's bound ``commit``/``rollback`` methods are captured at - call time, after ``async with`` has bound ``db_session`` locally. - - Args: - search_space_id: ID of the search space whose team memory is being - updated. - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - llm: Optional LLM for the forced-rewrite path. - - Returns: - Configured update_memory tool for the team-memory scope. - """ - del db_session # per-call session — see docstring + """Factory for the team-memory update tool.""" + del db_session @tool async def update_memory(updated_memory: str) -> dict[str, Any]: """Update the team's shared memory document for this search space. - Your current team memory is shown in in the system - prompt. When the team shares important long-term information - (decisions, conventions, key facts, priorities), rewrite the memory - document to include the new information. Merge new facts with - existing ones, update contradictions, remove outdated entries, and - keep it concise. - - Args: - updated_memory: The FULL updated markdown document (not a diff). + The current team memory is shown in . Pass the FULL updated + markdown document, not a diff. """ try: async with async_session_maker() as db_session: - result = await db_session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - space = result.scalars().first() - if not space: - return {"status": "error", "message": "Search space not found."} - - old_memory = space.shared_memory_md - - return await _save_memory( - updated_memory=updated_memory, - old_memory=old_memory, + result = await save_memory( + scope=MemoryScope.TEAM, + target_id=search_space_id, + content=updated_memory, + session=db_session, llm=llm, - apply_fn=lambda content: setattr( - space, "shared_memory_md", content - ), - commit_fn=db_session.commit, - rollback_fn=db_session.rollback, - label="team memory", - scope="team", ) + return result.to_dict() except Exception as e: logger.exception("Failed to update team memory: %s", e) return { @@ -439,3 +85,9 @@ def create_update_team_memory_tool( } return update_memory + + +__all__ = [ + "create_update_memory_tool", + "create_update_team_memory_tool", +] diff --git a/surfsense_backend/app/services/memory/__init__.py b/surfsense_backend/app/services/memory/__init__.py new file mode 100644 index 000000000..d72f45e1f --- /dev/null +++ b/surfsense_backend/app/services/memory/__init__.py @@ -0,0 +1,29 @@ +"""First-class memory service for user and team markdown memory.""" + +from .service import ( + MemoryScope, + SaveResult, + extract_and_save, + read_memory, + reset_memory, + save_memory, +) +from .validation import ( + MEMORY_HARD_LIMIT, + MEMORY_SOFT_LIMIT, + validate_bullet_format, + validate_memory_scope, +) + +__all__ = [ + "MEMORY_HARD_LIMIT", + "MEMORY_SOFT_LIMIT", + "MemoryScope", + "SaveResult", + "extract_and_save", + "read_memory", + "reset_memory", + "save_memory", + "validate_bullet_format", + "validate_memory_scope", +] diff --git a/surfsense_backend/app/services/memory/prompts.py b/surfsense_backend/app/services/memory/prompts.py new file mode 100644 index 000000000..fbf27fd08 --- /dev/null +++ b/surfsense_backend/app/services/memory/prompts.py @@ -0,0 +1,110 @@ +"""Prompts used by the memory service.""" + +FORCED_REWRITE_PROMPT = """\ +You are a memory curator. The following memory document exceeds the character \ +limit and must be shortened. + +RULES: +1. Rewrite the document to be under {target} characters. +2. Output Markdown only. Use clear `##` headings and concise bullet points. +3. New-format bullets should look like: `- YYYY-MM-DD: memory text`. +4. If the input contains legacy markers like `(YYYY-MM-DD) [fact]`, preserve the + information but remove the inline marker in the output. +5. Preserve durable instructions and preferences before generic facts when + compressing personal memory. +6. Preserve existing headings when useful; merge duplicate headings and bullets. +7. Output ONLY the consolidated markdown — no explanations, no wrapping. + + +{content} +""" + +USER_MEMORY_EXTRACT_PROMPT = """\ +You are a memory extraction assistant. Analyze the user's message and decide \ +if it contains any long-term information worth persisting to personal memory. + +Worth remembering: preferences, background/identity, goals, projects, \ +instructions, tools/languages they use, decisions, expertise, workplace — \ +durable facts that will matter in future conversations. + +NOT worth remembering: greetings, one-off factual questions, session \ +logistics, ephemeral requests, follow-up clarifications with no new personal \ +info, things that only matter for the current task. + +If there is nothing durable to remember, choose `action = no_update`. + +If the message contains memorizable information, choose `action = save` and \ +return the FULL updated memory document with the new information merged into \ +existing content. + +FORMAT RULES FOR `updated_memory`: +- Markdown only. +- Every entry should be under a `##` heading. +- Recommended headings: `## Facts`, `## Preferences`, `## Instructions`. +- New bullets should use: `- YYYY-MM-DD: memory text`. +- If current memory uses legacy `(YYYY-MM-DD) [fact|pref|instr]` markers, + preserve the information but write the updated document in the new + heading-based format. +- Use the user's first name from `` when helpful, not "the user". +- Do not duplicate existing information. + +{user_name} + + +{current_memory} + + + +{user_message} +""" + +TEAM_MEMORY_EXTRACT_PROMPT = """\ +You are a team-memory extraction assistant. Analyze the latest message and \ +decide if it contains durable TEAM-level information worth persisting. + +Decision policy: +- Prioritize recall for durable team context, while avoiding personal-only facts. +- Do NOT require explicit consensus language. A direct team-level statement can + be stored if it is stable and broadly useful for future team chats. +- If evidence is weak or clearly tentative, choose `action = no_update`. + +Worth remembering (team-level only): +- Decisions and defaults that guide future team work +- Team conventions/standards (naming, review policy, coding norms) +- Stable org/project facts (locations, ownership, constraints) +- Long-lived architecture/process facts +- Ongoing priorities that are likely relevant beyond this turn + +NOT worth remembering: +- Personal preferences or biography of one person +- Questions, brainstorming, tentative ideas, or speculation +- One-off requests, status updates, TODOs, logistics for this session +- Information scoped only to a single ephemeral task + +If the message contains memorizable team information, choose `action = save` \ +and return the FULL updated team memory document with new facts merged into \ +existing content. + +FORMAT RULES FOR `updated_memory`: +- Markdown only. +- Every entry should be under a `##` heading. +- Recommended headings: `## Product Decisions`, `## Engineering Conventions`, + `## Project Facts`, `## Open Questions`. +- New bullets should use: `- YYYY-MM-DD: memory text`. +- If current memory uses legacy `(YYYY-MM-DD) [fact]` markers, preserve the + information but write the updated document in the new heading-based format. +- Do not create personal headings such as `## Preferences`, `## Instructions`, + or `## Personal Notes`. +- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored. + + +{current_memory} + + + +{author} + + + +{user_message} +""" diff --git a/surfsense_backend/app/services/memory/rewrite.py b/surfsense_backend/app/services/memory/rewrite.py new file mode 100644 index 000000000..270904ce7 --- /dev/null +++ b/surfsense_backend/app/services/memory/rewrite.py @@ -0,0 +1,35 @@ +"""LLM-backed memory rewrite helpers.""" + +from __future__ import annotations + +import logging +from typing import Any + +from langchain_core.messages import HumanMessage + +from app.services.memory.prompts import FORCED_REWRITE_PROMPT +from app.services.memory.validation import MEMORY_HARD_LIMIT +from app.utils.content_utils import extract_text_content + +logger = logging.getLogger(__name__) + + +async def forced_rewrite(content: str, llm: Any) -> str | None: + """Use a focused LLM call to compress memory under the hard limit.""" + try: + prompt = FORCED_REWRITE_PROMPT.format( + target=MEMORY_HARD_LIMIT, + content=content, + ) + response = await llm.ainvoke( + [HumanMessage(content=prompt)], + config={"tags": ["surfsense:internal", "memory-rewrite"]}, + ) + text = extract_text_content(response.content).strip() + if not text: + logger.warning("Forced memory rewrite returned empty text") + return None + return text + except Exception: + logger.exception("Forced memory rewrite LLM call failed") + return None diff --git a/surfsense_backend/app/services/memory/schemas.py b/surfsense_backend/app/services/memory/schemas.py new file mode 100644 index 000000000..9b40ee5b1 --- /dev/null +++ b/surfsense_backend/app/services/memory/schemas.py @@ -0,0 +1,23 @@ +"""Structured output schemas for memory extraction.""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field + + +class MemoryExtractionDecision(BaseModel): + """Structured extraction result; avoids string sentinel parsing.""" + + action: Literal["no_update", "save"] = Field( + description="Choose no_update when nothing durable should be saved; choose save otherwise." + ) + reason: str | None = Field( + default=None, + description="Short reason for no_update, or brief summary of the memory update.", + ) + updated_memory: str | None = Field( + default=None, + description="The full updated markdown memory document when action is save.", + ) diff --git a/surfsense_backend/app/services/memory/service.py b/surfsense_backend/app/services/memory/service.py new file mode 100644 index 000000000..85459c28c --- /dev/null +++ b/surfsense_backend/app/services/memory/service.py @@ -0,0 +1,300 @@ +"""Canonical read/write/reset/extract service for markdown memory.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from enum import StrEnum +from typing import Any, Literal +from uuid import UUID + +from langchain_core.messages import HumanMessage +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import SearchSpace, User +from app.services.memory.prompts import ( + TEAM_MEMORY_EXTRACT_PROMPT, + USER_MEMORY_EXTRACT_PROMPT, +) +from app.services.memory.rewrite import forced_rewrite +from app.services.memory.schemas import MemoryExtractionDecision +from app.services.memory.validation import ( + MEMORY_HARD_LIMIT, + soft_limit_warning, + strip_preamble_to_first_heading, + validate_bullet_format, + validate_diff, + validate_heading_sanity, + validate_memory_scope, + validate_memory_size, +) + +logger = logging.getLogger(__name__) + + +class MemoryScope(StrEnum): + USER = "user" + TEAM = "team" + + +@dataclass(frozen=True) +class SaveResult: + status: Literal["saved", "error", "no_op"] + message: str + memory_md: str = "" + warnings: list[str] = field(default_factory=list) + diff_warnings: list[str] = field(default_factory=list) + format_warnings: list[str] = field(default_factory=list) + notice: str | None = None + + def to_dict(self) -> dict[str, Any]: + data: dict[str, Any] = { + "status": self.status, + "message": self.message, + "memory_md": self.memory_md, + } + if self.notice: + data["notice"] = self.notice + if self.warnings: + data["warnings"] = self.warnings + if len(self.warnings) == 1: + data["warning"] = self.warnings[0] + if self.diff_warnings: + data["diff_warnings"] = self.diff_warnings + if self.format_warnings: + data["format_warnings"] = self.format_warnings + return data + + +class MemoryRead(BaseModel): + memory_md: str + + +def _normalize_scope(scope: MemoryScope | str) -> MemoryScope: + return scope if isinstance(scope, MemoryScope) else MemoryScope(scope) + + +def _normalize_user_id(target_id: str | UUID) -> UUID: + return UUID(target_id) if isinstance(target_id, str) else target_id + + +async def _load_target( + *, + scope: MemoryScope | str, + target_id: str | int | UUID, + session: AsyncSession, +) -> User | SearchSpace | None: + normalized = _normalize_scope(scope) + if normalized is MemoryScope.USER: + result = await session.execute( + select(User).where(User.id == _normalize_user_id(target_id)) # type: ignore[arg-type] + ) + return result.scalars().first() + result = await session.execute(select(SearchSpace).where(SearchSpace.id == int(target_id))) + return result.scalars().first() + + +def _get_memory(target: User | SearchSpace, scope: MemoryScope) -> str: + if scope is MemoryScope.USER: + return getattr(target, "memory_md", None) or "" + return getattr(target, "shared_memory_md", None) or "" + + +def _set_memory(target: User | SearchSpace, scope: MemoryScope, content: str) -> None: + if scope is MemoryScope.USER: + target.memory_md = content + else: + target.shared_memory_md = content + + +async def read_memory( + *, + scope: MemoryScope | str, + target_id: str | int | UUID, + session: AsyncSession, +) -> str: + normalized = _normalize_scope(scope) + target = await _load_target(scope=normalized, target_id=target_id, session=session) + if target is None: + return "" + return _get_memory(target, normalized) + + +async def save_memory( + *, + scope: MemoryScope | str, + target_id: str | int | UUID, + content: str, + session: AsyncSession, + llm: Any | None = None, +) -> SaveResult: + normalized = _normalize_scope(scope) + if not isinstance(content, str): + return SaveResult( + status="error", + message="Internal error: memory payload must be a string.", + ) + + target = await _load_target(scope=normalized, target_id=target_id, session=session) + if target is None: + return SaveResult( + status="error", + message="User not found." if normalized is MemoryScope.USER else "Search space not found.", + ) + + old_memory = _get_memory(target, normalized) + next_content = strip_preamble_to_first_heading(content.strip()) + notice: str | None = None + warnings: list[str] = [] + + if len(next_content) > MEMORY_HARD_LIMIT and llm is not None: + rewritten = await forced_rewrite(next_content, llm) + if rewritten is not None and len(rewritten) < len(next_content): + next_content = strip_preamble_to_first_heading(rewritten) + notice = "Memory was automatically rewritten to fit within limits." + + for validation in ( + validate_memory_size(next_content), + validate_heading_sanity(next_content), + ): + if validation: + return SaveResult( + status="error", + message=validation["message"], + memory_md=old_memory, + ) + + scope_error, scope_warnings = validate_memory_scope( + next_content, + normalized.value, + old_memory=old_memory, + ) + warnings.extend(scope_warnings) + if scope_error: + return SaveResult( + status="error", + message=scope_error["message"], + memory_md=old_memory, + warnings=warnings, + ) + + try: + _set_memory(target, normalized, next_content) + session.add(target) + await session.commit() + except Exception as e: + logger.exception("Failed to update %s memory: %s", normalized.value, e) + await session.rollback() + return SaveResult( + status="error", + message=f"Failed to update {normalized.value} memory: {e}", + memory_md=old_memory, + ) + + diff_warnings = validate_diff(old_memory, next_content) + format_warnings = validate_bullet_format(next_content) + warning = soft_limit_warning(next_content) + if warning: + warnings.append(warning) + + return SaveResult( + status="saved", + message=( + "Memory updated." + if normalized is MemoryScope.USER + else "Team memory updated." + ), + memory_md=next_content, + warnings=warnings, + diff_warnings=diff_warnings, + format_warnings=format_warnings, + notice=notice, + ) + + +async def reset_memory( + *, + scope: MemoryScope | str, + target_id: str | int | UUID, + session: AsyncSession, +) -> SaveResult: + return await save_memory( + scope=scope, + target_id=target_id, + content="", + session=session, + llm=None, + ) + + +async def extract_and_save( + *, + scope: MemoryScope | str, + target_id: str | int | UUID, + user_message: str, + actor_display_name: str | None, + session: AsyncSession, + llm: Any, +) -> SaveResult: + normalized = _normalize_scope(scope) + current_memory = await read_memory( + scope=normalized, + target_id=target_id, + session=session, + ) + + if normalized is MemoryScope.USER: + first_name = ( + actor_display_name.strip().split()[0] + if actor_display_name and actor_display_name.strip() + else "The user" + ) + prompt = USER_MEMORY_EXTRACT_PROMPT.format( + current_memory=current_memory or "(empty)", + user_message=user_message, + user_name=first_name, + ) + else: + prompt = TEAM_MEMORY_EXTRACT_PROMPT.format( + current_memory=current_memory or "(empty)", + author=actor_display_name or "Unknown team member", + user_message=user_message, + ) + + try: + structured = llm.with_structured_output(MemoryExtractionDecision) + decision = await structured.ainvoke( + [HumanMessage(content=prompt)], + config={"tags": ["surfsense:internal", "memory-extraction"]}, + ) + except Exception: + logger.exception("Structured memory extraction failed") + return SaveResult( + status="error", + message="Structured memory extraction failed.", + memory_md=current_memory, + ) + + if decision.action == "no_update": + return SaveResult( + status="no_op", + message=decision.reason or "No durable memory to persist.", + memory_md=current_memory, + ) + + if not decision.updated_memory: + return SaveResult( + status="error", + message="Structured memory extraction chose save without updated_memory.", + memory_md=current_memory, + ) + + return await save_memory( + scope=normalized, + target_id=target_id, + content=decision.updated_memory, + session=session, + llm=llm, + ) diff --git a/surfsense_backend/app/services/memory/validation.py b/surfsense_backend/app/services/memory/validation.py new file mode 100644 index 000000000..0e856943b --- /dev/null +++ b/surfsense_backend/app/services/memory/validation.py @@ -0,0 +1,158 @@ +"""Validation helpers for markdown-backed memory.""" + +from __future__ import annotations + +import re +from typing import Literal + +MEMORY_SOFT_LIMIT = 18_000 +MEMORY_HARD_LIMIT = 25_000 + +_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE) +_HEADING_LINE_RE = re.compile(r"^##\s+\S+", re.MULTILINE) +_HEADING_NORMALIZE_RE = re.compile(r"[^a-z0-9]+") +_LEGACY_BULLET_RE = re.compile(r"^-\s+\(\d{4}-\d{2}-\d{2}\)\s+\[(fact|pref|instr)\]\s+.+$") +_NEW_BULLET_RE = re.compile(r"^-\s+\d{4}-\d{2}-\d{2}:\s+.+$") + +_FORBIDDEN_TEAM_HEADINGS = { + "preferences", + "instructions", + "personal notes", + "personal instructions", +} + + +def has_markdown_heading(content: str) -> bool: + return bool(_HEADING_LINE_RE.search(content)) + + +def strip_preamble_to_first_heading(content: str) -> str: + """Drop model preamble before the first ``##`` heading, if one exists.""" + match = _HEADING_LINE_RE.search(content) + if not match: + return content.strip() + return content[match.start() :].strip() + + +def extract_headings(memory: str | None) -> set[str]: + if not memory: + return set() + return {_normalize_heading(h) for h in _SECTION_HEADING_RE.findall(memory)} + + +def _normalize_heading(heading: str) -> str: + return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower()).strip() + + +def validate_memory_size(content: str) -> dict[str, str] | None: + length = len(content) + if length > MEMORY_HARD_LIMIT: + return { + "status": "error", + "message": ( + f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit " + f"({length:,} chars). Consolidate by merging related items, " + "removing outdated entries, and shortening descriptions." + ), + } + return None + + +def validate_heading_sanity(content: str) -> dict[str, str] | None: + """Block long prose blobs without headings unless they are legacy bullets.""" + stripped = content.strip() + if not stripped: + return None + if has_markdown_heading(stripped): + return None + if len(stripped) <= 40: + return None + if any(_LEGACY_BULLET_RE.match(line.strip()) for line in stripped.splitlines()): + return None + return { + "status": "error", + "message": "Memory must be markdown with at least one ## heading.", + } + + +def validate_memory_scope( + content: str, + scope: Literal["user", "team"], + *, + old_memory: str | None = None, +) -> tuple[dict[str, str] | None, list[str]]: + """Reject new personal headings in team memory, grandfather existing ones.""" + if scope != "team": + return None, [] + + old_forbidden = extract_headings(old_memory) & _FORBIDDEN_TEAM_HEADINGS + new_forbidden = extract_headings(content) & _FORBIDDEN_TEAM_HEADINGS + introduced = sorted(new_forbidden - old_forbidden) + grandfathered = sorted(new_forbidden & old_forbidden) + + warnings: list[str] = [] + if grandfathered: + warnings.append( + "Team memory contains legacy personal headings: " + + ", ".join(grandfathered) + + ". Please consolidate them into team-safe headings." + ) + if introduced: + return ( + { + "status": "error", + "message": ( + "Team memory cannot introduce personal headings: " + + ", ".join(introduced) + + ". Use team-safe headings instead." + ), + }, + warnings, + ) + return None, warnings + + +def validate_bullet_format(content: str) -> list[str]: + warnings: list[str] = [] + for line in content.splitlines(): + stripped = line.strip() + if not stripped.startswith("- "): + continue + if _NEW_BULLET_RE.match(stripped) or _LEGACY_BULLET_RE.match(stripped): + continue + short = stripped[:80] + ("..." if len(stripped) > 80 else "") + warnings.append(f"Non-standard memory bullet: {short}") + return warnings + + +def validate_diff(old_memory: str | None, new_memory: str) -> list[str]: + if not old_memory: + return [] + + warnings: list[str] = [] + old_headings = extract_headings(old_memory) + new_headings = extract_headings(new_memory) + dropped = old_headings - new_headings + if dropped: + names = ", ".join(sorted(dropped)) + warnings.append( + f"Sections removed: {names}. If unintentional, restore from the settings page." + ) + + old_len = len(old_memory) + new_len = len(new_memory) + if old_len > 0 and new_len < old_len * 0.4: + warnings.append( + f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). Possible data loss." + ) + return warnings + + +def soft_limit_warning(content: str) -> str | None: + length = len(content) + if length > MEMORY_SOFT_LIMIT: + return ( + f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. " + "Consolidate by merging related items and removing less important entries." + ) + return None diff --git a/surfsense_backend/tests/unit/services/test_memory_service.py b/surfsense_backend/tests/unit/services/test_memory_service.py new file mode 100644 index 000000000..c16e34062 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_memory_service.py @@ -0,0 +1,204 @@ +"""Unit tests for the first-class memory service.""" + +from types import SimpleNamespace + +import pytest + +from app.services.memory import ( + MemoryScope, + extract_and_save, + reset_memory, + save_memory, +) +from app.services.memory.schemas import MemoryExtractionDecision + +pytestmark = pytest.mark.unit + + +class _FakeSession: + def __init__(self) -> None: + self.commit_calls = 0 + self.rollback_calls = 0 + self.added = [] + + def add(self, obj) -> None: + self.added.append(obj) + + async def commit(self) -> None: + self.commit_calls += 1 + + async def rollback(self) -> None: + self.rollback_calls += 1 + + +class _StructuredLLM: + def __init__(self, decision: MemoryExtractionDecision) -> None: + self.decision = decision + + def with_structured_output(self, _schema): + return self + + async def ainvoke(self, *_args, **_kwargs): + return self.decision + + +@pytest.mark.asyncio +async def test_save_memory_saves_heading_based_memory(monkeypatch) -> None: + target = SimpleNamespace(memory_md="") + session = _FakeSession() + + async def fake_load_target(**_kwargs): + return target + + monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target) + + result = await save_memory( + scope=MemoryScope.USER, + target_id="00000000-0000-0000-0000-000000000000", + content="## Facts\n- 2026-05-19: Anish works on SurfSense\n", + session=session, + ) + + assert result.status == "saved" + assert target.memory_md.startswith("## Facts") + assert session.commit_calls == 1 + + +@pytest.mark.asyncio +async def test_save_memory_accepts_legacy_marker_payload(monkeypatch) -> None: + target = SimpleNamespace(memory_md="") + session = _FakeSession() + + async def fake_load_target(**_kwargs): + return target + + monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target) + + result = await save_memory( + scope=MemoryScope.USER, + target_id="00000000-0000-0000-0000-000000000000", + content="- (2026-05-19) [fact] Legacy marker memory\n", + session=session, + ) + + assert result.status == "saved" + assert "[fact]" in target.memory_md + + +@pytest.mark.asyncio +async def test_save_memory_rejects_long_no_heading_payload(monkeypatch) -> None: + target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n") + session = _FakeSession() + + async def fake_load_target(**_kwargs): + return target + + monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target) + + result = await save_memory( + scope=MemoryScope.USER, + target_id="00000000-0000-0000-0000-000000000000", + content="reasoning text before NO_UPDATE should not become saved memory", + session=session, + ) + + assert result.status == "error" + assert session.commit_calls == 0 + assert target.memory_md.startswith("## Facts") + + +@pytest.mark.asyncio +async def test_save_memory_grandfathers_existing_team_personal_heading(monkeypatch) -> None: + content = "## Preferences\n- 2026-05-19: Existing legacy heading\n" + target = SimpleNamespace(shared_memory_md=content) + session = _FakeSession() + + async def fake_load_target(**_kwargs): + return target + + monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target) + + result = await save_memory( + scope=MemoryScope.TEAM, + target_id=1, + content=content, + session=session, + ) + + assert result.status == "saved" + assert result.warnings + assert session.commit_calls == 1 + + +@pytest.mark.asyncio +async def test_reset_memory_clears_memory(monkeypatch) -> None: + target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n") + session = _FakeSession() + + async def fake_load_target(**_kwargs): + return target + + monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target) + + result = await reset_memory( + scope=MemoryScope.USER, + target_id="00000000-0000-0000-0000-000000000000", + session=session, + ) + + assert result.status == "saved" + assert target.memory_md == "" + + +@pytest.mark.asyncio +async def test_extract_and_save_no_update_does_not_commit(monkeypatch) -> None: + target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n") + session = _FakeSession() + + async def fake_load_target(**_kwargs): + return target + + monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target) + + result = await extract_and_save( + scope=MemoryScope.USER, + target_id="00000000-0000-0000-0000-000000000000", + user_message="hello", + actor_display_name="Anish", + session=session, + llm=_StructuredLLM( + MemoryExtractionDecision(action="no_update", reason="Greeting only") + ), + ) + + assert result.status == "no_op" + assert session.commit_calls == 0 + + +@pytest.mark.asyncio +async def test_extract_and_save_persists_structured_update(monkeypatch) -> None: + target = SimpleNamespace(memory_md="") + session = _FakeSession() + + async def fake_load_target(**_kwargs): + return target + + monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target) + + result = await extract_and_save( + scope=MemoryScope.USER, + target_id="00000000-0000-0000-0000-000000000000", + user_message="I work on SurfSense", + actor_display_name="Anish", + session=session, + llm=_StructuredLLM( + MemoryExtractionDecision( + action="save", + updated_memory="## Facts\n- 2026-05-19: Anish works on SurfSense\n", + ) + ), + ) + + assert result.status == "saved" + assert "SurfSense" in target.memory_md + assert session.commit_calls == 1