refactor: extract shared memory service

This commit is contained in:
Anish Sarkar 2026-05-20 02:01:36 +05:30
parent d66295aedd
commit ceedd02353
10 changed files with 946 additions and 874 deletions

View file

@ -1,280 +1,23 @@
"""Overwrite one markdown memory document per user or team, with size and shrink guards."""
"""Memory update tools backed by the canonical memory service."""
from __future__ import annotations
import logging
import re
from typing import Any, Literal
from typing import Any
from uuid import UUID
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User
from app.services.memory import (
MEMORY_HARD_LIMIT,
MEMORY_SOFT_LIMIT,
MemoryScope,
save_memory,
)
logger = logging.getLogger(__name__)
MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
# ---------------------------------------------------------------------------
# Diff validation
# ---------------------------------------------------------------------------
def _extract_headings(memory: str) -> set[str]:
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
return set(_SECTION_HEADING_RE.findall(memory))
def _normalize_heading(heading: str) -> str:
"""Normalize heading text for robust scope checks."""
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
def _validate_memory_scope(
content: str, scope: Literal["user", "team"]
) -> dict[str, Any] | None:
"""Reject personal-only markers ([pref], [instr]) in team memory."""
if scope != "team":
return None
markers = set(_MARKER_RE.findall(content))
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
if leaked:
tags = ", ".join(f"[{m}]" for m in leaked)
return {
"status": "error",
"message": (
f"Team memory cannot include personal markers: {tags}. "
"Use [fact] only in team memory."
),
}
return None
def _validate_bullet_format(content: str) -> list[str]:
"""Return warnings for bullet lines that don't match the required format.
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
"""
warnings: list[str] = []
for line in content.splitlines():
stripped = line.strip()
if not stripped.startswith("- "):
continue
if not _BULLET_FORMAT_RE.match(stripped):
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
warnings.append(f"Malformed bullet: {short}")
return warnings
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
"""Return a list of warning strings about suspicious changes."""
if not old_memory:
return []
warnings: list[str] = []
old_headings = _extract_headings(old_memory)
new_headings = _extract_headings(new_memory)
dropped = old_headings - new_headings
if dropped:
names = ", ".join(sorted(dropped))
warnings.append(
f"Sections removed: {names}. "
"If unintentional, the user can restore from the settings page."
)
old_len = len(old_memory)
new_len = len(new_memory)
if old_len > 0 and new_len < old_len * 0.4:
warnings.append(
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
"Possible data loss."
)
return warnings
# ---------------------------------------------------------------------------
# Size validation & soft warning
# ---------------------------------------------------------------------------
def _validate_memory_size(content: str) -> dict[str, Any] | None:
"""Return an error/warning dict if *content* is too large, else None."""
length = len(content)
if length > MEMORY_HARD_LIMIT:
return {
"status": "error",
"message": (
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
f"({length:,} chars). Consolidate by merging related items, "
"removing outdated entries, and shortening descriptions. "
"Then call update_memory again."
),
}
return None
def _soft_warning(content: str) -> str | None:
"""Return a warning string if content exceeds the soft limit."""
length = len(content)
if length > MEMORY_SOFT_LIMIT:
return (
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
"Consolidate by merging related items and removing less important "
"entries on your next update."
)
return None
# ---------------------------------------------------------------------------
# Forced rewrite when memory exceeds the hard limit
# ---------------------------------------------------------------------------
_FORCED_REWRITE_PROMPT = """\
You are a memory curator. The following memory document exceeds the character \
limit and must be shortened.
RULES:
1. Rewrite the document to be under {target} characters.
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
or rename headings to consolidate, but keep names personal and descriptive.
3. Priority for keeping content: [instr] > [pref] > [fact].
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
6. Preserve the user's first name in entries — do not replace it with "the user".
7. Output ONLY the consolidated markdown no explanations, no wrapping.
<memory_document>
{content}
</memory_document>"""
async def _forced_rewrite(content: str, llm: Any) -> str | None:
"""Use a focused LLM call to compress *content* under the hard limit.
Returns the rewritten string, or ``None`` if the call fails.
"""
try:
prompt = _FORCED_REWRITE_PROMPT.format(
target=MEMORY_HARD_LIMIT, content=content
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
)
text = (
response.content
if isinstance(response.content, str)
else str(response.content)
)
return text.strip()
except Exception:
logger.exception("Forced rewrite LLM call failed")
return None
# ---------------------------------------------------------------------------
# Shared save-and-respond logic
# ---------------------------------------------------------------------------
async def _save_memory(
*,
updated_memory: str,
old_memory: str | None,
llm: Any | None,
apply_fn,
commit_fn,
rollback_fn,
label: str,
scope: Literal["user", "team"],
) -> dict[str, Any]:
"""Validate, optionally force-rewrite if over the hard limit, save, and
return a response dict.
Parameters
----------
updated_memory : str
The new document the agent submitted.
old_memory : str | None
The previously persisted document (for diff checks).
llm : Any | None
LLM instance for forced rewrite (may be ``None``).
apply_fn : callable(str) -> None
Callback that sets the new memory on the ORM object.
commit_fn : coroutine
``session.commit``.
rollback_fn : coroutine
``session.rollback``.
label : str
Human label for log messages (e.g. "user memory", "team memory").
"""
content = updated_memory
# --- forced rewrite if over the hard limit ---
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
rewritten = await _forced_rewrite(content, llm)
if rewritten is not None and len(rewritten) < len(content):
content = rewritten
# --- hard-limit gate (reject if still too large after rewrite) ---
size_err = _validate_memory_size(content)
if size_err:
return size_err
scope_err = _validate_memory_scope(content, scope)
if scope_err:
return scope_err
# --- persist ---
try:
apply_fn(content)
await commit_fn()
except Exception as e:
logger.exception("Failed to update %s: %s", label, e)
await rollback_fn()
return {"status": "error", "message": f"Failed to update {label}: {e}"}
# --- build response ---
resp: dict[str, Any] = {
"status": "saved",
"message": f"{label.capitalize()} updated.",
}
if content is not updated_memory:
resp["notice"] = "Memory was automatically rewritten to fit within limits."
diff_warnings = _validate_diff(old_memory, content)
if diff_warnings:
resp["diff_warnings"] = diff_warnings
format_warnings = _validate_bullet_format(content)
if format_warnings:
resp["format_warnings"] = format_warnings
warning = _soft_warning(content)
if warning:
resp["warning"] = warning
return resp
# ---------------------------------------------------------------------------
# Tool factories
# ---------------------------------------------------------------------------
def create_update_memory_tool(
user_id: str | UUID,
@ -287,40 +30,22 @@ def create_update_memory_tool(
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the user's personal memory document.
Your current memory is shown in <user_memory> in the system prompt.
When the user shares important long-term information (preferences,
facts, instructions, context), rewrite the memory document to include
the new information. Merge new facts with existing ones, update
contradictions, remove outdated entries, and keep it concise.
Args:
updated_memory: The FULL updated markdown document (not a diff).
The current memory is shown in <user_memory>. Pass the FULL updated
markdown document, not a diff.
"""
try:
result = await db_session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return {"status": "error", "message": "User not found."}
old_memory = user.memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
result = await save_memory(
scope=MemoryScope.USER,
target_id=uid,
content=updated_memory,
session=db_session,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="memory",
scope="user",
)
return result.to_dict()
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
await db_session.rollback()
return {
"status": "error",
"message": f"Failed to update memory: {e}",
}
return {"status": "error", "message": f"Failed to update memory: {e}"}
return update_memory
@ -334,36 +59,18 @@ def create_update_team_memory_tool(
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space.
Your current team memory is shown in <team_memory> in the system
prompt. When the team shares important long-term information
(decisions, conventions, key facts, priorities), rewrite the memory
document to include the new information. Merge new facts with
existing ones, update contradictions, remove outdated entries, and
keep it concise.
Args:
updated_memory: The FULL updated markdown document (not a diff).
The current team memory is shown in <team_memory>. Pass the FULL updated
markdown document, not a diff.
"""
try:
result = await db_session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return {"status": "error", "message": "Search space not found."}
old_memory = space.shared_memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
result = await save_memory(
scope=MemoryScope.TEAM,
target_id=search_space_id,
content=updated_memory,
session=db_session,
llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="team memory",
scope="team",
)
return result.to_dict()
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
await db_session.rollback()
@ -373,3 +80,11 @@ def create_update_team_memory_tool(
}
return update_memory
__all__ = [
"MEMORY_HARD_LIMIT",
"MEMORY_SOFT_LIMIT",
"create_update_memory_tool",
"create_update_team_memory_tool",
]

View file

@ -1,9 +1,4 @@
"""Background memory extraction for the SurfSense agent.
After each agent response, if the agent did not call ``update_memory`` during
the turn, this module can run a lightweight LLM call to decide whether the
latest message contains long-term information worth persisting.
"""
"""Background memory extraction for the SurfSense agent."""
from __future__ import annotations
@ -11,102 +6,11 @@ import logging
from typing import Any
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy import select
from app.agents.new_chat.tools.update_memory import _save_memory
from app.db import SearchSpace, User, shielded_async_session
from app.utils.content_utils import extract_text_content
from app.db import User, shielded_async_session
from app.services.memory import MemoryScope, extract_and_save
logger = logging.getLogger(__name__)
_MEMORY_EXTRACT_PROMPT = """\
You are a memory extraction assistant. Analyze the user's message and decide \
if it contains any long-term information worth persisting to memory.
Worth remembering: preferences, background/identity, goals, projects, \
instructions, tools/languages they use, decisions, expertise, workplace \
durable facts that will matter in future conversations.
NOT worth remembering: greetings, one-off factual questions, session \
logistics, ephemeral requests, follow-up clarifications with no new personal \
info, things that only matter for the current task.
If the message contains memorizable information, output the FULL updated \
memory document with the new facts merged into the existing content. Follow \
these rules:
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
freely. Keep heading names short (2-3 words) and natural. Do NOT include the user's
name in headings.
- Keep entries as single bullet points. Be descriptive but concise include relevant
details and context rather than just a few words.
- Every bullet MUST use format: - (YYYY-MM-DD) [fact|pref|instr] text
[fact] = durable facts, [pref] = preferences, [instr] = standing instructions.
- Use the user's first name (from <user_name>) in entry text, not "the user".
- If a new fact contradicts an existing entry, update the existing entry.
- Do not duplicate information that is already present.
If nothing is worth remembering, output exactly: NO_UPDATE
<user_name>{user_name}</user_name>
<current_memory>
{current_memory}
</current_memory>
<user_message>
{user_message}
</user_message>"""
_TEAM_MEMORY_EXTRACT_PROMPT = """\
You are a team-memory extraction assistant. Analyze the latest message and \
decide if it contains durable TEAM-level information worth persisting.
Decision policy:
- Prioritize recall for durable team context, while avoiding personal-only facts.
- Do NOT require explicit consensus language. A direct team-level statement can
be stored if it is stable and broadly useful for future team chats.
- If evidence is weak or clearly tentative, output NO_UPDATE.
Worth remembering (team-level only):
- Decisions and defaults that guide future team work
- Team conventions/standards (naming, review policy, coding norms)
- Stable org/project facts (locations, ownership, constraints)
- Long-lived architecture/process facts
- Ongoing priorities that are likely relevant beyond this turn
NOT worth remembering:
- Personal preferences or biography of one person
- Questions, brainstorming, tentative ideas, or speculation
- One-off requests, status updates, TODOs, logistics for this session
- Information scoped only to a single ephemeral task
If the message contains memorizable team information, output the FULL updated \
team memory document with new facts merged into existing content. Follow rules:
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
freely. Keep heading names short (2-3 words) and natural.
- Keep entries as single bullet points. Be descriptive but concise include relevant
details and context rather than just a few words.
- Every bullet MUST use format: - (YYYY-MM-DD) [fact] text
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr].
- If a new fact contradicts an existing entry, update the existing entry.
- Do not duplicate existing information.
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
If nothing is worth remembering, output exactly: NO_UPDATE
<current_team_memory>
{current_memory}
</current_team_memory>
<latest_message_author>
{author}
</latest_message_author>
<latest_message>
{user_message}
</latest_message>"""
async def extract_and_save_memory(
*,
@ -114,57 +18,31 @@ async def extract_and_save_memory(
user_id: str | None,
llm: Any,
) -> None:
"""Background task: extract memorizable info and persist it.
"""Fire-and-forget personal memory extraction.
Designed to be fire-and-forget catches all exceptions internally.
The service uses structured output, so free-form ``NO_UPDATE`` text can no
longer be accidentally persisted as memory.
"""
if not user_id:
return
try:
uid = UUID(user_id) if isinstance(user_id, str) else user_id
async with shielded_async_session() as session:
result = await session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return
old_memory = user.memory_md
first_name = (
user.display_name.strip().split()[0]
if user.display_name and user.display_name.strip()
else "The user"
)
prompt = _MEMORY_EXTRACT_PROMPT.format(
current_memory=old_memory or "(empty)",
user = await session.get(User, uid)
actor_display_name = user.display_name if user else None
result = await extract_and_save(
scope=MemoryScope.USER,
target_id=uid,
user_message=user_message,
user_name=first_name,
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-extraction"]},
)
text = extract_text_content(response.content).strip()
if text == "NO_UPDATE" or not text:
logger.debug("Memory extraction: no update needed (user %s)", uid)
return
save_result = await _save_memory(
updated_memory=text,
old_memory=old_memory,
actor_display_name=actor_display_name,
session=session,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="memory",
scope="user",
)
logger.info(
"Background memory extraction for user %s: %s",
uid,
save_result.get("status"),
result.status,
)
except Exception:
logger.exception("Background user memory extraction failed")
@ -177,56 +55,24 @@ async def extract_and_save_team_memory(
llm: Any,
author_display_name: str | None = None,
) -> None:
"""Background task: extract team-level memory and persist it.
Runs only for shared threads. Designed to be fire-and-forget and catches
exceptions internally.
"""
"""Fire-and-forget team-level memory extraction."""
if not search_space_id:
return
try:
async with shielded_async_session() as session:
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return
old_memory = space.shared_memory_md
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
current_memory=old_memory or "(empty)",
author=author_display_name or "Unknown team member",
result = await extract_and_save(
scope=MemoryScope.TEAM,
target_id=search_space_id,
user_message=user_message,
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
)
text = extract_text_content(response.content).strip()
if text == "NO_UPDATE" or not text:
logger.debug(
"Team memory extraction: no update needed (space %s)",
search_space_id,
)
return
save_result = await _save_memory(
updated_memory=text,
old_memory=old_memory,
actor_display_name=author_display_name,
session=session,
llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="team memory",
scope="team",
)
logger.info(
"Background team memory extraction for space %s: %s",
search_space_id,
save_result.get("status"),
result.status,
)
except Exception:
logger.exception("Background team memory extraction failed")

View file

@ -1,369 +1,53 @@
"""Markdown-document memory tool for the SurfSense agent.
Replaces the old row-per-fact save_memory / recall_memory tools with a single
update_memory tool that overwrites a freeform markdown TEXT column. The LLM
always sees the current memory in <user_memory> / <team_memory> tags injected
by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
Overflow handling:
- Soft limit (18K chars): a warning is returned telling the agent to
consolidate on the next update.
- Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
If it still exceeds the limit after rewriting, the save is rejected.
- Diff validation: warns when entire ``##`` sections are dropped or when the
document shrinks by more than 60%.
"""
"""Memory update tools backed by the canonical memory service."""
from __future__ import annotations
import logging
import re
from typing import Any, Literal
from typing import Any
from uuid import UUID
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User, async_session_maker
from app.utils.content_utils import extract_text_content
from app.db import async_session_maker
from app.services.memory import MemoryScope, save_memory
logger = logging.getLogger(__name__)
MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
# ---------------------------------------------------------------------------
# Diff validation
# ---------------------------------------------------------------------------
def _extract_headings(memory: str) -> set[str]:
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
return set(_SECTION_HEADING_RE.findall(memory))
def _normalize_heading(heading: str) -> str:
"""Normalize heading text for robust scope checks."""
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
def _validate_memory_scope(
content: str, scope: Literal["user", "team"]
) -> dict[str, Any] | None:
"""Reject personal-only markers ([pref], [instr]) in team memory."""
if scope != "team":
return None
markers = set(_MARKER_RE.findall(content))
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
if leaked:
tags = ", ".join(f"[{m}]" for m in leaked)
return {
"status": "error",
"message": (
f"Team memory cannot include personal markers: {tags}. "
"Use [fact] only in team memory."
),
}
return None
def _validate_bullet_format(content: str) -> list[str]:
"""Return warnings for bullet lines that don't match the required format.
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
"""
warnings: list[str] = []
for line in content.splitlines():
stripped = line.strip()
if not stripped.startswith("- "):
continue
if not _BULLET_FORMAT_RE.match(stripped):
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
warnings.append(f"Malformed bullet: {short}")
return warnings
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
"""Return a list of warning strings about suspicious changes."""
if not old_memory:
return []
warnings: list[str] = []
old_headings = _extract_headings(old_memory)
new_headings = _extract_headings(new_memory)
dropped = old_headings - new_headings
if dropped:
names = ", ".join(sorted(dropped))
warnings.append(
f"Sections removed: {names}. "
"If unintentional, the user can restore from the settings page."
)
old_len = len(old_memory)
new_len = len(new_memory)
if old_len > 0 and new_len < old_len * 0.4:
warnings.append(
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
"Possible data loss."
)
return warnings
# ---------------------------------------------------------------------------
# Size validation & soft warning
# ---------------------------------------------------------------------------
def _validate_memory_size(content: str) -> dict[str, Any] | None:
"""Return an error/warning dict if *content* is too large, else None."""
length = len(content)
if length > MEMORY_HARD_LIMIT:
return {
"status": "error",
"message": (
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
f"({length:,} chars). Consolidate by merging related items, "
"removing outdated entries, and shortening descriptions. "
"Then call update_memory again."
),
}
return None
def _soft_warning(content: str) -> str | None:
"""Return a warning string if content exceeds the soft limit."""
length = len(content)
if length > MEMORY_SOFT_LIMIT:
return (
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
"Consolidate by merging related items and removing less important "
"entries on your next update."
)
return None
# ---------------------------------------------------------------------------
# Forced rewrite when memory exceeds the hard limit
# ---------------------------------------------------------------------------
_FORCED_REWRITE_PROMPT = """\
You are a memory curator. The following memory document exceeds the character \
limit and must be shortened.
RULES:
1. Rewrite the document to be under {target} characters.
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
or rename headings to consolidate, but keep names personal and descriptive.
3. Priority for keeping content: [instr] > [pref] > [fact].
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
6. Preserve the user's first name in entries — do not replace it with "the user".
7. Output ONLY the consolidated markdown no explanations, no wrapping.
<memory_document>
{content}
</memory_document>"""
async def _forced_rewrite(content: str, llm: Any) -> str | None:
"""Use a focused LLM call to compress *content* under the hard limit.
Returns the rewritten string, or ``None`` if the call fails.
"""
try:
prompt = _FORCED_REWRITE_PROMPT.format(
target=MEMORY_HARD_LIMIT, content=content
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
)
text = extract_text_content(response.content).strip()
if not text:
logger.warning("Forced rewrite returned empty text; aborting rewrite")
return None
return text
except Exception:
logger.exception("Forced rewrite LLM call failed")
return None
# ---------------------------------------------------------------------------
# Shared save-and-respond logic
# ---------------------------------------------------------------------------
async def _save_memory(
*,
updated_memory: str,
old_memory: str | None,
llm: Any | None,
apply_fn,
commit_fn,
rollback_fn,
label: str,
scope: Literal["user", "team"],
) -> dict[str, Any]:
"""Validate, optionally force-rewrite if over the hard limit, save, and
return a response dict.
Parameters
----------
updated_memory : str
The new document the agent submitted.
old_memory : str | None
The previously persisted document (for diff checks).
llm : Any | None
LLM instance for forced rewrite (may be ``None``).
apply_fn : callable(str) -> None
Callback that sets the new memory on the ORM object.
commit_fn : coroutine
``session.commit``.
rollback_fn : coroutine
``session.rollback``.
label : str
Human label for log messages (e.g. "user memory", "team memory").
"""
if not isinstance(updated_memory, str):
logger.warning(
"Refusing non-string memory payload (type=%s)",
type(updated_memory).__name__,
)
return {
"status": "error",
"message": "Internal error: memory payload must be a string.",
}
content = updated_memory
# --- forced rewrite if over the hard limit ---
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
rewritten = await _forced_rewrite(content, llm)
if rewritten is not None and len(rewritten) < len(content):
content = rewritten
# --- hard-limit gate (reject if still too large after rewrite) ---
size_err = _validate_memory_size(content)
if size_err:
return size_err
scope_err = _validate_memory_scope(content, scope)
if scope_err:
return scope_err
# --- persist ---
try:
apply_fn(content)
await commit_fn()
except Exception as e:
logger.exception("Failed to update %s: %s", label, e)
await rollback_fn()
return {"status": "error", "message": f"Failed to update {label}: {e}"}
# --- build response ---
resp: dict[str, Any] = {
"status": "saved",
"message": f"{label.capitalize()} updated.",
}
if content is not updated_memory:
resp["notice"] = "Memory was automatically rewritten to fit within limits."
diff_warnings = _validate_diff(old_memory, content)
if diff_warnings:
resp["diff_warnings"] = diff_warnings
format_warnings = _validate_bullet_format(content)
if format_warnings:
resp["format_warnings"] = format_warnings
warning = _soft_warning(content)
if warning:
resp["warning"] = warning
return resp
# ---------------------------------------------------------------------------
# Tool factories
# ---------------------------------------------------------------------------
def create_update_memory_tool(
user_id: str | UUID,
db_session: AsyncSession,
llm: Any | None = None,
):
"""Factory function to create the user-memory update tool.
"""Factory for the user-memory update tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
The session's bound ``commit``/``rollback`` methods are captured at
call time, after ``async with`` has bound ``db_session`` locally.
Args:
user_id: ID of the user whose memory document is being updated.
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
llm: Optional LLM for the forced-rewrite path.
Returns:
Configured update_memory tool for the user-memory scope.
Uses a fresh short-lived session per call so compiled-agent caches never
retain a stale request-scoped session.
"""
del db_session # per-call session — see docstring
del db_session
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the user's personal memory document.
Your current memory is shown in <user_memory> in the system prompt.
When the user shares important long-term information (preferences,
facts, instructions, context), rewrite the memory document to include
the new information. Merge new facts with existing ones, update
contradictions, remove outdated entries, and keep it concise.
Args:
updated_memory: The FULL updated markdown document (not a diff).
The current memory is shown in <user_memory>. Pass the FULL updated
markdown document, not a diff.
"""
try:
async with async_session_maker() as db_session:
result = await db_session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return {"status": "error", "message": "User not found."}
old_memory = user.memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
result = await save_memory(
scope=MemoryScope.USER,
target_id=uid,
content=updated_memory,
session=db_session,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="memory",
scope="user",
)
return result.to_dict()
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
return {
"status": "error",
"message": f"Failed to update memory: {e}",
}
return {"status": "error", "message": f"Failed to update memory: {e}"}
return update_memory
@ -373,64 +57,26 @@ def create_update_team_memory_tool(
db_session: AsyncSession,
llm: Any | None = None,
):
"""Factory function to create the team-memory update tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
The session's bound ``commit``/``rollback`` methods are captured at
call time, after ``async with`` has bound ``db_session`` locally.
Args:
search_space_id: ID of the search space whose team memory is being
updated.
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
llm: Optional LLM for the forced-rewrite path.
Returns:
Configured update_memory tool for the team-memory scope.
"""
del db_session # per-call session — see docstring
"""Factory for the team-memory update tool."""
del db_session
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space.
Your current team memory is shown in <team_memory> in the system
prompt. When the team shares important long-term information
(decisions, conventions, key facts, priorities), rewrite the memory
document to include the new information. Merge new facts with
existing ones, update contradictions, remove outdated entries, and
keep it concise.
Args:
updated_memory: The FULL updated markdown document (not a diff).
The current team memory is shown in <team_memory>. Pass the FULL updated
markdown document, not a diff.
"""
try:
async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return {"status": "error", "message": "Search space not found."}
old_memory = space.shared_memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
result = await save_memory(
scope=MemoryScope.TEAM,
target_id=search_space_id,
content=updated_memory,
session=db_session,
llm=llm,
apply_fn=lambda content: setattr(
space, "shared_memory_md", content
),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="team memory",
scope="team",
)
return result.to_dict()
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
return {
@ -439,3 +85,9 @@ def create_update_team_memory_tool(
}
return update_memory
__all__ = [
"create_update_memory_tool",
"create_update_team_memory_tool",
]

View file

@ -0,0 +1,29 @@
"""First-class memory service for user and team markdown memory."""
from .service import (
MemoryScope,
SaveResult,
extract_and_save,
read_memory,
reset_memory,
save_memory,
)
from .validation import (
MEMORY_HARD_LIMIT,
MEMORY_SOFT_LIMIT,
validate_bullet_format,
validate_memory_scope,
)
__all__ = [
"MEMORY_HARD_LIMIT",
"MEMORY_SOFT_LIMIT",
"MemoryScope",
"SaveResult",
"extract_and_save",
"read_memory",
"reset_memory",
"save_memory",
"validate_bullet_format",
"validate_memory_scope",
]

View file

@ -0,0 +1,110 @@
"""Prompts used by the memory service."""
FORCED_REWRITE_PROMPT = """\
You are a memory curator. The following memory document exceeds the character \
limit and must be shortened.
RULES:
1. Rewrite the document to be under {target} characters.
2. Output Markdown only. Use clear `##` headings and concise bullet points.
3. New-format bullets should look like: `- YYYY-MM-DD: memory text`.
4. If the input contains legacy markers like `(YYYY-MM-DD) [fact]`, preserve the
information but remove the inline marker in the output.
5. Preserve durable instructions and preferences before generic facts when
compressing personal memory.
6. Preserve existing headings when useful; merge duplicate headings and bullets.
7. Output ONLY the consolidated markdown no explanations, no wrapping.
<memory_document>
{content}
</memory_document>"""
USER_MEMORY_EXTRACT_PROMPT = """\
You are a memory extraction assistant. Analyze the user's message and decide \
if it contains any long-term information worth persisting to personal memory.
Worth remembering: preferences, background/identity, goals, projects, \
instructions, tools/languages they use, decisions, expertise, workplace \
durable facts that will matter in future conversations.
NOT worth remembering: greetings, one-off factual questions, session \
logistics, ephemeral requests, follow-up clarifications with no new personal \
info, things that only matter for the current task.
If there is nothing durable to remember, choose `action = no_update`.
If the message contains memorizable information, choose `action = save` and \
return the FULL updated memory document with the new information merged into \
existing content.
FORMAT RULES FOR `updated_memory`:
- Markdown only.
- Every entry should be under a `##` heading.
- Recommended headings: `## Facts`, `## Preferences`, `## Instructions`.
- New bullets should use: `- YYYY-MM-DD: memory text`.
- If current memory uses legacy `(YYYY-MM-DD) [fact|pref|instr]` markers,
preserve the information but write the updated document in the new
heading-based format.
- Use the user's first name from `<user_name>` when helpful, not "the user".
- Do not duplicate existing information.
<user_name>{user_name}</user_name>
<current_memory>
{current_memory}
</current_memory>
<user_message>
{user_message}
</user_message>"""
TEAM_MEMORY_EXTRACT_PROMPT = """\
You are a team-memory extraction assistant. Analyze the latest message and \
decide if it contains durable TEAM-level information worth persisting.
Decision policy:
- Prioritize recall for durable team context, while avoiding personal-only facts.
- Do NOT require explicit consensus language. A direct team-level statement can
be stored if it is stable and broadly useful for future team chats.
- If evidence is weak or clearly tentative, choose `action = no_update`.
Worth remembering (team-level only):
- Decisions and defaults that guide future team work
- Team conventions/standards (naming, review policy, coding norms)
- Stable org/project facts (locations, ownership, constraints)
- Long-lived architecture/process facts
- Ongoing priorities that are likely relevant beyond this turn
NOT worth remembering:
- Personal preferences or biography of one person
- Questions, brainstorming, tentative ideas, or speculation
- One-off requests, status updates, TODOs, logistics for this session
- Information scoped only to a single ephemeral task
If the message contains memorizable team information, choose `action = save` \
and return the FULL updated team memory document with new facts merged into \
existing content.
FORMAT RULES FOR `updated_memory`:
- Markdown only.
- Every entry should be under a `##` heading.
- Recommended headings: `## Product Decisions`, `## Engineering Conventions`,
`## Project Facts`, `## Open Questions`.
- New bullets should use: `- YYYY-MM-DD: memory text`.
- If current memory uses legacy `(YYYY-MM-DD) [fact]` markers, preserve the
information but write the updated document in the new heading-based format.
- Do not create personal headings such as `## Preferences`, `## Instructions`,
or `## Personal Notes`.
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
<current_team_memory>
{current_memory}
</current_team_memory>
<latest_message_author>
{author}
</latest_message_author>
<latest_message>
{user_message}
</latest_message>"""

View file

@ -0,0 +1,35 @@
"""LLM-backed memory rewrite helpers."""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.messages import HumanMessage
from app.services.memory.prompts import FORCED_REWRITE_PROMPT
from app.services.memory.validation import MEMORY_HARD_LIMIT
from app.utils.content_utils import extract_text_content
logger = logging.getLogger(__name__)
async def forced_rewrite(content: str, llm: Any) -> str | None:
"""Use a focused LLM call to compress memory under the hard limit."""
try:
prompt = FORCED_REWRITE_PROMPT.format(
target=MEMORY_HARD_LIMIT,
content=content,
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-rewrite"]},
)
text = extract_text_content(response.content).strip()
if not text:
logger.warning("Forced memory rewrite returned empty text")
return None
return text
except Exception:
logger.exception("Forced memory rewrite LLM call failed")
return None

View file

@ -0,0 +1,23 @@
"""Structured output schemas for memory extraction."""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
class MemoryExtractionDecision(BaseModel):
"""Structured extraction result; avoids string sentinel parsing."""
action: Literal["no_update", "save"] = Field(
description="Choose no_update when nothing durable should be saved; choose save otherwise."
)
reason: str | None = Field(
default=None,
description="Short reason for no_update, or brief summary of the memory update.",
)
updated_memory: str | None = Field(
default=None,
description="The full updated markdown memory document when action is save.",
)

View file

@ -0,0 +1,300 @@
"""Canonical read/write/reset/extract service for markdown memory."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any, Literal
from uuid import UUID
from langchain_core.messages import HumanMessage
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User
from app.services.memory.prompts import (
TEAM_MEMORY_EXTRACT_PROMPT,
USER_MEMORY_EXTRACT_PROMPT,
)
from app.services.memory.rewrite import forced_rewrite
from app.services.memory.schemas import MemoryExtractionDecision
from app.services.memory.validation import (
MEMORY_HARD_LIMIT,
soft_limit_warning,
strip_preamble_to_first_heading,
validate_bullet_format,
validate_diff,
validate_heading_sanity,
validate_memory_scope,
validate_memory_size,
)
logger = logging.getLogger(__name__)
class MemoryScope(StrEnum):
USER = "user"
TEAM = "team"
@dataclass(frozen=True)
class SaveResult:
status: Literal["saved", "error", "no_op"]
message: str
memory_md: str = ""
warnings: list[str] = field(default_factory=list)
diff_warnings: list[str] = field(default_factory=list)
format_warnings: list[str] = field(default_factory=list)
notice: str | None = None
def to_dict(self) -> dict[str, Any]:
data: dict[str, Any] = {
"status": self.status,
"message": self.message,
"memory_md": self.memory_md,
}
if self.notice:
data["notice"] = self.notice
if self.warnings:
data["warnings"] = self.warnings
if len(self.warnings) == 1:
data["warning"] = self.warnings[0]
if self.diff_warnings:
data["diff_warnings"] = self.diff_warnings
if self.format_warnings:
data["format_warnings"] = self.format_warnings
return data
class MemoryRead(BaseModel):
memory_md: str
def _normalize_scope(scope: MemoryScope | str) -> MemoryScope:
return scope if isinstance(scope, MemoryScope) else MemoryScope(scope)
def _normalize_user_id(target_id: str | UUID) -> UUID:
return UUID(target_id) if isinstance(target_id, str) else target_id
async def _load_target(
*,
scope: MemoryScope | str,
target_id: str | int | UUID,
session: AsyncSession,
) -> User | SearchSpace | None:
normalized = _normalize_scope(scope)
if normalized is MemoryScope.USER:
result = await session.execute(
select(User).where(User.id == _normalize_user_id(target_id)) # type: ignore[arg-type]
)
return result.scalars().first()
result = await session.execute(select(SearchSpace).where(SearchSpace.id == int(target_id)))
return result.scalars().first()
def _get_memory(target: User | SearchSpace, scope: MemoryScope) -> str:
if scope is MemoryScope.USER:
return getattr(target, "memory_md", None) or ""
return getattr(target, "shared_memory_md", None) or ""
def _set_memory(target: User | SearchSpace, scope: MemoryScope, content: str) -> None:
if scope is MemoryScope.USER:
target.memory_md = content
else:
target.shared_memory_md = content
async def read_memory(
*,
scope: MemoryScope | str,
target_id: str | int | UUID,
session: AsyncSession,
) -> str:
normalized = _normalize_scope(scope)
target = await _load_target(scope=normalized, target_id=target_id, session=session)
if target is None:
return ""
return _get_memory(target, normalized)
async def save_memory(
*,
scope: MemoryScope | str,
target_id: str | int | UUID,
content: str,
session: AsyncSession,
llm: Any | None = None,
) -> SaveResult:
normalized = _normalize_scope(scope)
if not isinstance(content, str):
return SaveResult(
status="error",
message="Internal error: memory payload must be a string.",
)
target = await _load_target(scope=normalized, target_id=target_id, session=session)
if target is None:
return SaveResult(
status="error",
message="User not found." if normalized is MemoryScope.USER else "Search space not found.",
)
old_memory = _get_memory(target, normalized)
next_content = strip_preamble_to_first_heading(content.strip())
notice: str | None = None
warnings: list[str] = []
if len(next_content) > MEMORY_HARD_LIMIT and llm is not None:
rewritten = await forced_rewrite(next_content, llm)
if rewritten is not None and len(rewritten) < len(next_content):
next_content = strip_preamble_to_first_heading(rewritten)
notice = "Memory was automatically rewritten to fit within limits."
for validation in (
validate_memory_size(next_content),
validate_heading_sanity(next_content),
):
if validation:
return SaveResult(
status="error",
message=validation["message"],
memory_md=old_memory,
)
scope_error, scope_warnings = validate_memory_scope(
next_content,
normalized.value,
old_memory=old_memory,
)
warnings.extend(scope_warnings)
if scope_error:
return SaveResult(
status="error",
message=scope_error["message"],
memory_md=old_memory,
warnings=warnings,
)
try:
_set_memory(target, normalized, next_content)
session.add(target)
await session.commit()
except Exception as e:
logger.exception("Failed to update %s memory: %s", normalized.value, e)
await session.rollback()
return SaveResult(
status="error",
message=f"Failed to update {normalized.value} memory: {e}",
memory_md=old_memory,
)
diff_warnings = validate_diff(old_memory, next_content)
format_warnings = validate_bullet_format(next_content)
warning = soft_limit_warning(next_content)
if warning:
warnings.append(warning)
return SaveResult(
status="saved",
message=(
"Memory updated."
if normalized is MemoryScope.USER
else "Team memory updated."
),
memory_md=next_content,
warnings=warnings,
diff_warnings=diff_warnings,
format_warnings=format_warnings,
notice=notice,
)
async def reset_memory(
*,
scope: MemoryScope | str,
target_id: str | int | UUID,
session: AsyncSession,
) -> SaveResult:
return await save_memory(
scope=scope,
target_id=target_id,
content="",
session=session,
llm=None,
)
async def extract_and_save(
*,
scope: MemoryScope | str,
target_id: str | int | UUID,
user_message: str,
actor_display_name: str | None,
session: AsyncSession,
llm: Any,
) -> SaveResult:
normalized = _normalize_scope(scope)
current_memory = await read_memory(
scope=normalized,
target_id=target_id,
session=session,
)
if normalized is MemoryScope.USER:
first_name = (
actor_display_name.strip().split()[0]
if actor_display_name and actor_display_name.strip()
else "The user"
)
prompt = USER_MEMORY_EXTRACT_PROMPT.format(
current_memory=current_memory or "(empty)",
user_message=user_message,
user_name=first_name,
)
else:
prompt = TEAM_MEMORY_EXTRACT_PROMPT.format(
current_memory=current_memory or "(empty)",
author=actor_display_name or "Unknown team member",
user_message=user_message,
)
try:
structured = llm.with_structured_output(MemoryExtractionDecision)
decision = await structured.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-extraction"]},
)
except Exception:
logger.exception("Structured memory extraction failed")
return SaveResult(
status="error",
message="Structured memory extraction failed.",
memory_md=current_memory,
)
if decision.action == "no_update":
return SaveResult(
status="no_op",
message=decision.reason or "No durable memory to persist.",
memory_md=current_memory,
)
if not decision.updated_memory:
return SaveResult(
status="error",
message="Structured memory extraction chose save without updated_memory.",
memory_md=current_memory,
)
return await save_memory(
scope=normalized,
target_id=target_id,
content=decision.updated_memory,
session=session,
llm=llm,
)

View file

@ -0,0 +1,158 @@
"""Validation helpers for markdown-backed memory."""
from __future__ import annotations
import re
from typing import Literal
MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
_HEADING_LINE_RE = re.compile(r"^##\s+\S+", re.MULTILINE)
_HEADING_NORMALIZE_RE = re.compile(r"[^a-z0-9]+")
_LEGACY_BULLET_RE = re.compile(r"^-\s+\(\d{4}-\d{2}-\d{2}\)\s+\[(fact|pref|instr)\]\s+.+$")
_NEW_BULLET_RE = re.compile(r"^-\s+\d{4}-\d{2}-\d{2}:\s+.+$")
_FORBIDDEN_TEAM_HEADINGS = {
"preferences",
"instructions",
"personal notes",
"personal instructions",
}
def has_markdown_heading(content: str) -> bool:
return bool(_HEADING_LINE_RE.search(content))
def strip_preamble_to_first_heading(content: str) -> str:
"""Drop model preamble before the first ``##`` heading, if one exists."""
match = _HEADING_LINE_RE.search(content)
if not match:
return content.strip()
return content[match.start() :].strip()
def extract_headings(memory: str | None) -> set[str]:
if not memory:
return set()
return {_normalize_heading(h) for h in _SECTION_HEADING_RE.findall(memory)}
def _normalize_heading(heading: str) -> str:
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower()).strip()
def validate_memory_size(content: str) -> dict[str, str] | None:
length = len(content)
if length > MEMORY_HARD_LIMIT:
return {
"status": "error",
"message": (
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
f"({length:,} chars). Consolidate by merging related items, "
"removing outdated entries, and shortening descriptions."
),
}
return None
def validate_heading_sanity(content: str) -> dict[str, str] | None:
"""Block long prose blobs without headings unless they are legacy bullets."""
stripped = content.strip()
if not stripped:
return None
if has_markdown_heading(stripped):
return None
if len(stripped) <= 40:
return None
if any(_LEGACY_BULLET_RE.match(line.strip()) for line in stripped.splitlines()):
return None
return {
"status": "error",
"message": "Memory must be markdown with at least one ## heading.",
}
def validate_memory_scope(
content: str,
scope: Literal["user", "team"],
*,
old_memory: str | None = None,
) -> tuple[dict[str, str] | None, list[str]]:
"""Reject new personal headings in team memory, grandfather existing ones."""
if scope != "team":
return None, []
old_forbidden = extract_headings(old_memory) & _FORBIDDEN_TEAM_HEADINGS
new_forbidden = extract_headings(content) & _FORBIDDEN_TEAM_HEADINGS
introduced = sorted(new_forbidden - old_forbidden)
grandfathered = sorted(new_forbidden & old_forbidden)
warnings: list[str] = []
if grandfathered:
warnings.append(
"Team memory contains legacy personal headings: "
+ ", ".join(grandfathered)
+ ". Please consolidate them into team-safe headings."
)
if introduced:
return (
{
"status": "error",
"message": (
"Team memory cannot introduce personal headings: "
+ ", ".join(introduced)
+ ". Use team-safe headings instead."
),
},
warnings,
)
return None, warnings
def validate_bullet_format(content: str) -> list[str]:
warnings: list[str] = []
for line in content.splitlines():
stripped = line.strip()
if not stripped.startswith("- "):
continue
if _NEW_BULLET_RE.match(stripped) or _LEGACY_BULLET_RE.match(stripped):
continue
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
warnings.append(f"Non-standard memory bullet: {short}")
return warnings
def validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
if not old_memory:
return []
warnings: list[str] = []
old_headings = extract_headings(old_memory)
new_headings = extract_headings(new_memory)
dropped = old_headings - new_headings
if dropped:
names = ", ".join(sorted(dropped))
warnings.append(
f"Sections removed: {names}. If unintentional, restore from the settings page."
)
old_len = len(old_memory)
new_len = len(new_memory)
if old_len > 0 and new_len < old_len * 0.4:
warnings.append(
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). Possible data loss."
)
return warnings
def soft_limit_warning(content: str) -> str | None:
length = len(content)
if length > MEMORY_SOFT_LIMIT:
return (
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
"Consolidate by merging related items and removing less important entries."
)
return None

View file

@ -0,0 +1,204 @@
"""Unit tests for the first-class memory service."""
from types import SimpleNamespace
import pytest
from app.services.memory import (
MemoryScope,
extract_and_save,
reset_memory,
save_memory,
)
from app.services.memory.schemas import MemoryExtractionDecision
pytestmark = pytest.mark.unit
class _FakeSession:
def __init__(self) -> None:
self.commit_calls = 0
self.rollback_calls = 0
self.added = []
def add(self, obj) -> None:
self.added.append(obj)
async def commit(self) -> None:
self.commit_calls += 1
async def rollback(self) -> None:
self.rollback_calls += 1
class _StructuredLLM:
def __init__(self, decision: MemoryExtractionDecision) -> None:
self.decision = decision
def with_structured_output(self, _schema):
return self
async def ainvoke(self, *_args, **_kwargs):
return self.decision
@pytest.mark.asyncio
async def test_save_memory_saves_heading_based_memory(monkeypatch) -> None:
target = SimpleNamespace(memory_md="")
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await save_memory(
scope=MemoryScope.USER,
target_id="00000000-0000-0000-0000-000000000000",
content="## Facts\n- 2026-05-19: Anish works on SurfSense\n",
session=session,
)
assert result.status == "saved"
assert target.memory_md.startswith("## Facts")
assert session.commit_calls == 1
@pytest.mark.asyncio
async def test_save_memory_accepts_legacy_marker_payload(monkeypatch) -> None:
target = SimpleNamespace(memory_md="")
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await save_memory(
scope=MemoryScope.USER,
target_id="00000000-0000-0000-0000-000000000000",
content="- (2026-05-19) [fact] Legacy marker memory\n",
session=session,
)
assert result.status == "saved"
assert "[fact]" in target.memory_md
@pytest.mark.asyncio
async def test_save_memory_rejects_long_no_heading_payload(monkeypatch) -> None:
target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await save_memory(
scope=MemoryScope.USER,
target_id="00000000-0000-0000-0000-000000000000",
content="reasoning text before NO_UPDATE should not become saved memory",
session=session,
)
assert result.status == "error"
assert session.commit_calls == 0
assert target.memory_md.startswith("## Facts")
@pytest.mark.asyncio
async def test_save_memory_grandfathers_existing_team_personal_heading(monkeypatch) -> None:
content = "## Preferences\n- 2026-05-19: Existing legacy heading\n"
target = SimpleNamespace(shared_memory_md=content)
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await save_memory(
scope=MemoryScope.TEAM,
target_id=1,
content=content,
session=session,
)
assert result.status == "saved"
assert result.warnings
assert session.commit_calls == 1
@pytest.mark.asyncio
async def test_reset_memory_clears_memory(monkeypatch) -> None:
target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await reset_memory(
scope=MemoryScope.USER,
target_id="00000000-0000-0000-0000-000000000000",
session=session,
)
assert result.status == "saved"
assert target.memory_md == ""
@pytest.mark.asyncio
async def test_extract_and_save_no_update_does_not_commit(monkeypatch) -> None:
target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await extract_and_save(
scope=MemoryScope.USER,
target_id="00000000-0000-0000-0000-000000000000",
user_message="hello",
actor_display_name="Anish",
session=session,
llm=_StructuredLLM(
MemoryExtractionDecision(action="no_update", reason="Greeting only")
),
)
assert result.status == "no_op"
assert session.commit_calls == 0
@pytest.mark.asyncio
async def test_extract_and_save_persists_structured_update(monkeypatch) -> None:
target = SimpleNamespace(memory_md="")
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await extract_and_save(
scope=MemoryScope.USER,
target_id="00000000-0000-0000-0000-000000000000",
user_message="I work on SurfSense",
actor_display_name="Anish",
session=session,
llm=_StructuredLLM(
MemoryExtractionDecision(
action="save",
updated_memory="## Facts\n- 2026-05-19: Anish works on SurfSense\n",
)
),
)
assert result.status == "saved"
assert "SurfSense" in target.memory_md
assert session.commit_calls == 1