mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
refactor: extract shared memory service
This commit is contained in:
parent
d66295aedd
commit
ceedd02353
10 changed files with 946 additions and 874 deletions
|
|
@ -1,280 +1,23 @@
|
|||
"""Overwrite one markdown memory document per user or team, with size and shrink guards."""
|
||||
"""Memory update tools backed by the canonical memory service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import SearchSpace, User
|
||||
from app.services.memory import (
|
||||
MEMORY_HARD_LIMIT,
|
||||
MEMORY_SOFT_LIMIT,
|
||||
MemoryScope,
|
||||
save_memory,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_SOFT_LIMIT = 18_000
|
||||
MEMORY_HARD_LIMIT = 25_000
|
||||
|
||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
||||
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
|
||||
|
||||
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
|
||||
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
|
||||
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Diff validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_headings(memory: str) -> set[str]:
|
||||
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
|
||||
return set(_SECTION_HEADING_RE.findall(memory))
|
||||
|
||||
|
||||
def _normalize_heading(heading: str) -> str:
|
||||
"""Normalize heading text for robust scope checks."""
|
||||
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
|
||||
|
||||
|
||||
def _validate_memory_scope(
|
||||
content: str, scope: Literal["user", "team"]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Reject personal-only markers ([pref], [instr]) in team memory."""
|
||||
if scope != "team":
|
||||
return None
|
||||
|
||||
markers = set(_MARKER_RE.findall(content))
|
||||
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
|
||||
if leaked:
|
||||
tags = ", ".join(f"[{m}]" for m in leaked)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Team memory cannot include personal markers: {tags}. "
|
||||
"Use [fact] only in team memory."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _validate_bullet_format(content: str) -> list[str]:
|
||||
"""Return warnings for bullet lines that don't match the required format.
|
||||
|
||||
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
|
||||
"""
|
||||
warnings: list[str] = []
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped.startswith("- "):
|
||||
continue
|
||||
if not _BULLET_FORMAT_RE.match(stripped):
|
||||
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
||||
warnings.append(f"Malformed bullet: {short}")
|
||||
return warnings
|
||||
|
||||
|
||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||
"""Return a list of warning strings about suspicious changes."""
|
||||
if not old_memory:
|
||||
return []
|
||||
|
||||
warnings: list[str] = []
|
||||
old_headings = _extract_headings(old_memory)
|
||||
new_headings = _extract_headings(new_memory)
|
||||
dropped = old_headings - new_headings
|
||||
if dropped:
|
||||
names = ", ".join(sorted(dropped))
|
||||
warnings.append(
|
||||
f"Sections removed: {names}. "
|
||||
"If unintentional, the user can restore from the settings page."
|
||||
)
|
||||
|
||||
old_len = len(old_memory)
|
||||
new_len = len(new_memory)
|
||||
if old_len > 0 and new_len < old_len * 0.4:
|
||||
warnings.append(
|
||||
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
|
||||
"Possible data loss."
|
||||
)
|
||||
return warnings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Size validation & soft warning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _validate_memory_size(content: str) -> dict[str, Any] | None:
|
||||
"""Return an error/warning dict if *content* is too large, else None."""
|
||||
length = len(content)
|
||||
if length > MEMORY_HARD_LIMIT:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
|
||||
f"({length:,} chars). Consolidate by merging related items, "
|
||||
"removing outdated entries, and shortening descriptions. "
|
||||
"Then call update_memory again."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _soft_warning(content: str) -> str | None:
|
||||
"""Return a warning string if content exceeds the soft limit."""
|
||||
length = len(content)
|
||||
if length > MEMORY_SOFT_LIMIT:
|
||||
return (
|
||||
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
|
||||
"Consolidate by merging related items and removing less important "
|
||||
"entries on your next update."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forced rewrite when memory exceeds the hard limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FORCED_REWRITE_PROMPT = """\
|
||||
You are a memory curator. The following memory document exceeds the character \
|
||||
limit and must be shortened.
|
||||
|
||||
RULES:
|
||||
1. Rewrite the document to be under {target} characters.
|
||||
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
|
||||
or rename headings to consolidate, but keep names personal and descriptive.
|
||||
3. Priority for keeping content: [instr] > [pref] > [fact].
|
||||
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
|
||||
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
|
||||
6. Preserve the user's first name in entries — do not replace it with "the user".
|
||||
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
|
||||
|
||||
<memory_document>
|
||||
{content}
|
||||
</memory_document>"""
|
||||
|
||||
|
||||
async def _forced_rewrite(content: str, llm: Any) -> str | None:
|
||||
"""Use a focused LLM call to compress *content* under the hard limit.
|
||||
|
||||
Returns the rewritten string, or ``None`` if the call fails.
|
||||
"""
|
||||
try:
|
||||
prompt = _FORCED_REWRITE_PROMPT.format(
|
||||
target=MEMORY_HARD_LIMIT, content=content
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
text = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
)
|
||||
return text.strip()
|
||||
except Exception:
|
||||
logger.exception("Forced rewrite LLM call failed")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared save-and-respond logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _save_memory(
|
||||
*,
|
||||
updated_memory: str,
|
||||
old_memory: str | None,
|
||||
llm: Any | None,
|
||||
apply_fn,
|
||||
commit_fn,
|
||||
rollback_fn,
|
||||
label: str,
|
||||
scope: Literal["user", "team"],
|
||||
) -> dict[str, Any]:
|
||||
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
||||
return a response dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
updated_memory : str
|
||||
The new document the agent submitted.
|
||||
old_memory : str | None
|
||||
The previously persisted document (for diff checks).
|
||||
llm : Any | None
|
||||
LLM instance for forced rewrite (may be ``None``).
|
||||
apply_fn : callable(str) -> None
|
||||
Callback that sets the new memory on the ORM object.
|
||||
commit_fn : coroutine
|
||||
``session.commit``.
|
||||
rollback_fn : coroutine
|
||||
``session.rollback``.
|
||||
label : str
|
||||
Human label for log messages (e.g. "user memory", "team memory").
|
||||
"""
|
||||
content = updated_memory
|
||||
|
||||
# --- forced rewrite if over the hard limit ---
|
||||
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
|
||||
rewritten = await _forced_rewrite(content, llm)
|
||||
if rewritten is not None and len(rewritten) < len(content):
|
||||
content = rewritten
|
||||
|
||||
# --- hard-limit gate (reject if still too large after rewrite) ---
|
||||
size_err = _validate_memory_size(content)
|
||||
if size_err:
|
||||
return size_err
|
||||
|
||||
scope_err = _validate_memory_scope(content, scope)
|
||||
if scope_err:
|
||||
return scope_err
|
||||
|
||||
# --- persist ---
|
||||
try:
|
||||
apply_fn(content)
|
||||
await commit_fn()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update %s: %s", label, e)
|
||||
await rollback_fn()
|
||||
return {"status": "error", "message": f"Failed to update {label}: {e}"}
|
||||
|
||||
# --- build response ---
|
||||
resp: dict[str, Any] = {
|
||||
"status": "saved",
|
||||
"message": f"{label.capitalize()} updated.",
|
||||
}
|
||||
|
||||
if content is not updated_memory:
|
||||
resp["notice"] = "Memory was automatically rewritten to fit within limits."
|
||||
|
||||
diff_warnings = _validate_diff(old_memory, content)
|
||||
if diff_warnings:
|
||||
resp["diff_warnings"] = diff_warnings
|
||||
|
||||
format_warnings = _validate_bullet_format(content)
|
||||
if format_warnings:
|
||||
resp["format_warnings"] = format_warnings
|
||||
|
||||
warning = _soft_warning(content)
|
||||
if warning:
|
||||
resp["warning"] = warning
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_update_memory_tool(
|
||||
user_id: str | UUID,
|
||||
|
|
@ -287,40 +30,22 @@ def create_update_memory_tool(
|
|||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the user's personal memory document.
|
||||
|
||||
Your current memory is shown in <user_memory> in the system prompt.
|
||||
When the user shares important long-term information (preferences,
|
||||
facts, instructions, context), rewrite the memory document to include
|
||||
the new information. Merge new facts with existing ones, update
|
||||
contradictions, remove outdated entries, and keep it concise.
|
||||
|
||||
Args:
|
||||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
The current memory is shown in <user_memory>. Pass the FULL updated
|
||||
markdown document, not a diff.
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return {"status": "error", "message": "User not found."}
|
||||
|
||||
old_memory = user.memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id=uid,
|
||||
content=updated_memory,
|
||||
session=db_session,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
return result.to_dict()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update user memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to update memory: {e}",
|
||||
}
|
||||
return {"status": "error", "message": f"Failed to update memory: {e}"}
|
||||
|
||||
return update_memory
|
||||
|
||||
|
|
@ -334,36 +59,18 @@ def create_update_team_memory_tool(
|
|||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the team's shared memory document for this search space.
|
||||
|
||||
Your current team memory is shown in <team_memory> in the system
|
||||
prompt. When the team shares important long-term information
|
||||
(decisions, conventions, key facts, priorities), rewrite the memory
|
||||
document to include the new information. Merge new facts with
|
||||
existing ones, update contradictions, remove outdated entries, and
|
||||
keep it concise.
|
||||
|
||||
Args:
|
||||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
The current team memory is shown in <team_memory>. Pass the FULL updated
|
||||
markdown document, not a diff.
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return {"status": "error", "message": "Search space not found."}
|
||||
|
||||
old_memory = space.shared_memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.TEAM,
|
||||
target_id=search_space_id,
|
||||
content=updated_memory,
|
||||
session=db_session,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
return result.to_dict()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update team memory: %s", e)
|
||||
await db_session.rollback()
|
||||
|
|
@ -373,3 +80,11 @@ def create_update_team_memory_tool(
|
|||
}
|
||||
|
||||
return update_memory
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MEMORY_HARD_LIMIT",
|
||||
"MEMORY_SOFT_LIMIT",
|
||||
"create_update_memory_tool",
|
||||
"create_update_team_memory_tool",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,4 @@
|
|||
"""Background memory extraction for the SurfSense agent.
|
||||
|
||||
After each agent response, if the agent did not call ``update_memory`` during
|
||||
the turn, this module can run a lightweight LLM call to decide whether the
|
||||
latest message contains long-term information worth persisting.
|
||||
"""
|
||||
"""Background memory extraction for the SurfSense agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -11,102 +6,11 @@ import logging
|
|||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.new_chat.tools.update_memory import _save_memory
|
||||
from app.db import SearchSpace, User, shielded_async_session
|
||||
from app.utils.content_utils import extract_text_content
|
||||
from app.db import User, shielded_async_session
|
||||
from app.services.memory import MemoryScope, extract_and_save
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MEMORY_EXTRACT_PROMPT = """\
|
||||
You are a memory extraction assistant. Analyze the user's message and decide \
|
||||
if it contains any long-term information worth persisting to memory.
|
||||
|
||||
Worth remembering: preferences, background/identity, goals, projects, \
|
||||
instructions, tools/languages they use, decisions, expertise, workplace — \
|
||||
durable facts that will matter in future conversations.
|
||||
|
||||
NOT worth remembering: greetings, one-off factual questions, session \
|
||||
logistics, ephemeral requests, follow-up clarifications with no new personal \
|
||||
info, things that only matter for the current task.
|
||||
|
||||
If the message contains memorizable information, output the FULL updated \
|
||||
memory document with the new facts merged into the existing content. Follow \
|
||||
these rules:
|
||||
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
|
||||
freely. Keep heading names short (2-3 words) and natural. Do NOT include the user's
|
||||
name in headings.
|
||||
- Keep entries as single bullet points. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- Every bullet MUST use format: - (YYYY-MM-DD) [fact|pref|instr] text
|
||||
[fact] = durable facts, [pref] = preferences, [instr] = standing instructions.
|
||||
- Use the user's first name (from <user_name>) in entry text, not "the user".
|
||||
- If a new fact contradicts an existing entry, update the existing entry.
|
||||
- Do not duplicate information that is already present.
|
||||
|
||||
If nothing is worth remembering, output exactly: NO_UPDATE
|
||||
|
||||
<user_name>{user_name}</user_name>
|
||||
|
||||
<current_memory>
|
||||
{current_memory}
|
||||
</current_memory>
|
||||
|
||||
<user_message>
|
||||
{user_message}
|
||||
</user_message>"""
|
||||
|
||||
_TEAM_MEMORY_EXTRACT_PROMPT = """\
|
||||
You are a team-memory extraction assistant. Analyze the latest message and \
|
||||
decide if it contains durable TEAM-level information worth persisting.
|
||||
|
||||
Decision policy:
|
||||
- Prioritize recall for durable team context, while avoiding personal-only facts.
|
||||
- Do NOT require explicit consensus language. A direct team-level statement can
|
||||
be stored if it is stable and broadly useful for future team chats.
|
||||
- If evidence is weak or clearly tentative, output NO_UPDATE.
|
||||
|
||||
Worth remembering (team-level only):
|
||||
- Decisions and defaults that guide future team work
|
||||
- Team conventions/standards (naming, review policy, coding norms)
|
||||
- Stable org/project facts (locations, ownership, constraints)
|
||||
- Long-lived architecture/process facts
|
||||
- Ongoing priorities that are likely relevant beyond this turn
|
||||
|
||||
NOT worth remembering:
|
||||
- Personal preferences or biography of one person
|
||||
- Questions, brainstorming, tentative ideas, or speculation
|
||||
- One-off requests, status updates, TODOs, logistics for this session
|
||||
- Information scoped only to a single ephemeral task
|
||||
|
||||
If the message contains memorizable team information, output the FULL updated \
|
||||
team memory document with new facts merged into existing content. Follow rules:
|
||||
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
|
||||
freely. Keep heading names short (2-3 words) and natural.
|
||||
- Keep entries as single bullet points. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- Every bullet MUST use format: - (YYYY-MM-DD) [fact] text
|
||||
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr].
|
||||
- If a new fact contradicts an existing entry, update the existing entry.
|
||||
- Do not duplicate existing information.
|
||||
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
|
||||
|
||||
If nothing is worth remembering, output exactly: NO_UPDATE
|
||||
|
||||
<current_team_memory>
|
||||
{current_memory}
|
||||
</current_team_memory>
|
||||
|
||||
<latest_message_author>
|
||||
{author}
|
||||
</latest_message_author>
|
||||
|
||||
<latest_message>
|
||||
{user_message}
|
||||
</latest_message>"""
|
||||
|
||||
|
||||
async def extract_and_save_memory(
|
||||
*,
|
||||
|
|
@ -114,57 +18,31 @@ async def extract_and_save_memory(
|
|||
user_id: str | None,
|
||||
llm: Any,
|
||||
) -> None:
|
||||
"""Background task: extract memorizable info and persist it.
|
||||
"""Fire-and-forget personal memory extraction.
|
||||
|
||||
Designed to be fire-and-forget — catches all exceptions internally.
|
||||
The service uses structured output, so free-form ``NO_UPDATE`` text can no
|
||||
longer be accidentally persisted as memory.
|
||||
"""
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
try:
|
||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
result = await session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return
|
||||
|
||||
old_memory = user.memory_md
|
||||
first_name = (
|
||||
user.display_name.strip().split()[0]
|
||||
if user.display_name and user.display_name.strip()
|
||||
else "The user"
|
||||
)
|
||||
prompt = _MEMORY_EXTRACT_PROMPT.format(
|
||||
current_memory=old_memory or "(empty)",
|
||||
user = await session.get(User, uid)
|
||||
actor_display_name = user.display_name if user else None
|
||||
result = await extract_and_save(
|
||||
scope=MemoryScope.USER,
|
||||
target_id=uid,
|
||||
user_message=user_message,
|
||||
user_name=first_name,
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "memory-extraction"]},
|
||||
)
|
||||
text = extract_text_content(response.content).strip()
|
||||
|
||||
if text == "NO_UPDATE" or not text:
|
||||
logger.debug("Memory extraction: no update needed (user %s)", uid)
|
||||
return
|
||||
|
||||
save_result = await _save_memory(
|
||||
updated_memory=text,
|
||||
old_memory=old_memory,
|
||||
actor_display_name=actor_display_name,
|
||||
session=session,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
logger.info(
|
||||
"Background memory extraction for user %s: %s",
|
||||
uid,
|
||||
save_result.get("status"),
|
||||
result.status,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Background user memory extraction failed")
|
||||
|
|
@ -177,56 +55,24 @@ async def extract_and_save_team_memory(
|
|||
llm: Any,
|
||||
author_display_name: str | None = None,
|
||||
) -> None:
|
||||
"""Background task: extract team-level memory and persist it.
|
||||
|
||||
Runs only for shared threads. Designed to be fire-and-forget and catches
|
||||
exceptions internally.
|
||||
"""
|
||||
"""Fire-and-forget team-level memory extraction."""
|
||||
if not search_space_id:
|
||||
return
|
||||
|
||||
try:
|
||||
async with shielded_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return
|
||||
|
||||
old_memory = space.shared_memory_md
|
||||
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
|
||||
current_memory=old_memory or "(empty)",
|
||||
author=author_display_name or "Unknown team member",
|
||||
result = await extract_and_save(
|
||||
scope=MemoryScope.TEAM,
|
||||
target_id=search_space_id,
|
||||
user_message=user_message,
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
|
||||
)
|
||||
text = extract_text_content(response.content).strip()
|
||||
|
||||
if text == "NO_UPDATE" or not text:
|
||||
logger.debug(
|
||||
"Team memory extraction: no update needed (space %s)",
|
||||
search_space_id,
|
||||
)
|
||||
return
|
||||
|
||||
save_result = await _save_memory(
|
||||
updated_memory=text,
|
||||
old_memory=old_memory,
|
||||
actor_display_name=author_display_name,
|
||||
session=session,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
logger.info(
|
||||
"Background team memory extraction for space %s: %s",
|
||||
search_space_id,
|
||||
save_result.get("status"),
|
||||
result.status,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Background team memory extraction failed")
|
||||
|
|
|
|||
|
|
@ -1,369 +1,53 @@
|
|||
"""Markdown-document memory tool for the SurfSense agent.
|
||||
|
||||
Replaces the old row-per-fact save_memory / recall_memory tools with a single
|
||||
update_memory tool that overwrites a freeform markdown TEXT column. The LLM
|
||||
always sees the current memory in <user_memory> / <team_memory> tags injected
|
||||
by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
|
||||
|
||||
Overflow handling:
|
||||
- Soft limit (18K chars): a warning is returned telling the agent to
|
||||
consolidate on the next update.
|
||||
- Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
|
||||
If it still exceeds the limit after rewriting, the save is rejected.
|
||||
- Diff validation: warns when entire ``##`` sections are dropped or when the
|
||||
document shrinks by more than 60%.
|
||||
"""
|
||||
"""Memory update tools backed by the canonical memory service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import SearchSpace, User, async_session_maker
|
||||
from app.utils.content_utils import extract_text_content
|
||||
from app.db import async_session_maker
|
||||
from app.services.memory import MemoryScope, save_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_SOFT_LIMIT = 18_000
|
||||
MEMORY_HARD_LIMIT = 25_000
|
||||
|
||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
||||
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
|
||||
|
||||
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
|
||||
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
|
||||
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Diff validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_headings(memory: str) -> set[str]:
|
||||
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
|
||||
return set(_SECTION_HEADING_RE.findall(memory))
|
||||
|
||||
|
||||
def _normalize_heading(heading: str) -> str:
|
||||
"""Normalize heading text for robust scope checks."""
|
||||
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
|
||||
|
||||
|
||||
def _validate_memory_scope(
|
||||
content: str, scope: Literal["user", "team"]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Reject personal-only markers ([pref], [instr]) in team memory."""
|
||||
if scope != "team":
|
||||
return None
|
||||
|
||||
markers = set(_MARKER_RE.findall(content))
|
||||
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
|
||||
if leaked:
|
||||
tags = ", ".join(f"[{m}]" for m in leaked)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Team memory cannot include personal markers: {tags}. "
|
||||
"Use [fact] only in team memory."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _validate_bullet_format(content: str) -> list[str]:
|
||||
"""Return warnings for bullet lines that don't match the required format.
|
||||
|
||||
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
|
||||
"""
|
||||
warnings: list[str] = []
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped.startswith("- "):
|
||||
continue
|
||||
if not _BULLET_FORMAT_RE.match(stripped):
|
||||
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
||||
warnings.append(f"Malformed bullet: {short}")
|
||||
return warnings
|
||||
|
||||
|
||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||
"""Return a list of warning strings about suspicious changes."""
|
||||
if not old_memory:
|
||||
return []
|
||||
|
||||
warnings: list[str] = []
|
||||
old_headings = _extract_headings(old_memory)
|
||||
new_headings = _extract_headings(new_memory)
|
||||
dropped = old_headings - new_headings
|
||||
if dropped:
|
||||
names = ", ".join(sorted(dropped))
|
||||
warnings.append(
|
||||
f"Sections removed: {names}. "
|
||||
"If unintentional, the user can restore from the settings page."
|
||||
)
|
||||
|
||||
old_len = len(old_memory)
|
||||
new_len = len(new_memory)
|
||||
if old_len > 0 and new_len < old_len * 0.4:
|
||||
warnings.append(
|
||||
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
|
||||
"Possible data loss."
|
||||
)
|
||||
return warnings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Size validation & soft warning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _validate_memory_size(content: str) -> dict[str, Any] | None:
|
||||
"""Return an error/warning dict if *content* is too large, else None."""
|
||||
length = len(content)
|
||||
if length > MEMORY_HARD_LIMIT:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
|
||||
f"({length:,} chars). Consolidate by merging related items, "
|
||||
"removing outdated entries, and shortening descriptions. "
|
||||
"Then call update_memory again."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _soft_warning(content: str) -> str | None:
|
||||
"""Return a warning string if content exceeds the soft limit."""
|
||||
length = len(content)
|
||||
if length > MEMORY_SOFT_LIMIT:
|
||||
return (
|
||||
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
|
||||
"Consolidate by merging related items and removing less important "
|
||||
"entries on your next update."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forced rewrite when memory exceeds the hard limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FORCED_REWRITE_PROMPT = """\
|
||||
You are a memory curator. The following memory document exceeds the character \
|
||||
limit and must be shortened.
|
||||
|
||||
RULES:
|
||||
1. Rewrite the document to be under {target} characters.
|
||||
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
|
||||
or rename headings to consolidate, but keep names personal and descriptive.
|
||||
3. Priority for keeping content: [instr] > [pref] > [fact].
|
||||
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
|
||||
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
|
||||
6. Preserve the user's first name in entries — do not replace it with "the user".
|
||||
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
|
||||
|
||||
<memory_document>
|
||||
{content}
|
||||
</memory_document>"""
|
||||
|
||||
|
||||
async def _forced_rewrite(content: str, llm: Any) -> str | None:
|
||||
"""Use a focused LLM call to compress *content* under the hard limit.
|
||||
|
||||
Returns the rewritten string, or ``None`` if the call fails.
|
||||
"""
|
||||
try:
|
||||
prompt = _FORCED_REWRITE_PROMPT.format(
|
||||
target=MEMORY_HARD_LIMIT, content=content
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
text = extract_text_content(response.content).strip()
|
||||
if not text:
|
||||
logger.warning("Forced rewrite returned empty text; aborting rewrite")
|
||||
return None
|
||||
return text
|
||||
except Exception:
|
||||
logger.exception("Forced rewrite LLM call failed")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared save-and-respond logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _save_memory(
|
||||
*,
|
||||
updated_memory: str,
|
||||
old_memory: str | None,
|
||||
llm: Any | None,
|
||||
apply_fn,
|
||||
commit_fn,
|
||||
rollback_fn,
|
||||
label: str,
|
||||
scope: Literal["user", "team"],
|
||||
) -> dict[str, Any]:
|
||||
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
||||
return a response dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
updated_memory : str
|
||||
The new document the agent submitted.
|
||||
old_memory : str | None
|
||||
The previously persisted document (for diff checks).
|
||||
llm : Any | None
|
||||
LLM instance for forced rewrite (may be ``None``).
|
||||
apply_fn : callable(str) -> None
|
||||
Callback that sets the new memory on the ORM object.
|
||||
commit_fn : coroutine
|
||||
``session.commit``.
|
||||
rollback_fn : coroutine
|
||||
``session.rollback``.
|
||||
label : str
|
||||
Human label for log messages (e.g. "user memory", "team memory").
|
||||
"""
|
||||
if not isinstance(updated_memory, str):
|
||||
logger.warning(
|
||||
"Refusing non-string memory payload (type=%s)",
|
||||
type(updated_memory).__name__,
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Internal error: memory payload must be a string.",
|
||||
}
|
||||
|
||||
content = updated_memory
|
||||
|
||||
# --- forced rewrite if over the hard limit ---
|
||||
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
|
||||
rewritten = await _forced_rewrite(content, llm)
|
||||
if rewritten is not None and len(rewritten) < len(content):
|
||||
content = rewritten
|
||||
|
||||
# --- hard-limit gate (reject if still too large after rewrite) ---
|
||||
size_err = _validate_memory_size(content)
|
||||
if size_err:
|
||||
return size_err
|
||||
|
||||
scope_err = _validate_memory_scope(content, scope)
|
||||
if scope_err:
|
||||
return scope_err
|
||||
|
||||
# --- persist ---
|
||||
try:
|
||||
apply_fn(content)
|
||||
await commit_fn()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update %s: %s", label, e)
|
||||
await rollback_fn()
|
||||
return {"status": "error", "message": f"Failed to update {label}: {e}"}
|
||||
|
||||
# --- build response ---
|
||||
resp: dict[str, Any] = {
|
||||
"status": "saved",
|
||||
"message": f"{label.capitalize()} updated.",
|
||||
}
|
||||
|
||||
if content is not updated_memory:
|
||||
resp["notice"] = "Memory was automatically rewritten to fit within limits."
|
||||
|
||||
diff_warnings = _validate_diff(old_memory, content)
|
||||
if diff_warnings:
|
||||
resp["diff_warnings"] = diff_warnings
|
||||
|
||||
format_warnings = _validate_bullet_format(content)
|
||||
if format_warnings:
|
||||
resp["format_warnings"] = format_warnings
|
||||
|
||||
warning = _soft_warning(content)
|
||||
if warning:
|
||||
resp["warning"] = warning
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_update_memory_tool(
|
||||
user_id: str | UUID,
|
||||
db_session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
):
|
||||
"""Factory function to create the user-memory update tool.
|
||||
"""Factory for the user-memory update tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
The session's bound ``commit``/``rollback`` methods are captured at
|
||||
call time, after ``async with`` has bound ``db_session`` locally.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user whose memory document is being updated.
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
llm: Optional LLM for the forced-rewrite path.
|
||||
|
||||
Returns:
|
||||
Configured update_memory tool for the user-memory scope.
|
||||
Uses a fresh short-lived session per call so compiled-agent caches never
|
||||
retain a stale request-scoped session.
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
del db_session
|
||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
@tool
|
||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the user's personal memory document.
|
||||
|
||||
Your current memory is shown in <user_memory> in the system prompt.
|
||||
When the user shares important long-term information (preferences,
|
||||
facts, instructions, context), rewrite the memory document to include
|
||||
the new information. Merge new facts with existing ones, update
|
||||
contradictions, remove outdated entries, and keep it concise.
|
||||
|
||||
Args:
|
||||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
The current memory is shown in <user_memory>. Pass the FULL updated
|
||||
markdown document, not a diff.
|
||||
"""
|
||||
try:
|
||||
async with async_session_maker() as db_session:
|
||||
result = await db_session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return {"status": "error", "message": "User not found."}
|
||||
|
||||
old_memory = user.memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id=uid,
|
||||
content=updated_memory,
|
||||
session=db_session,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
return result.to_dict()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update user memory: %s", e)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to update memory: {e}",
|
||||
}
|
||||
return {"status": "error", "message": f"Failed to update memory: {e}"}
|
||||
|
||||
return update_memory
|
||||
|
||||
|
|
@ -373,64 +57,26 @@ def create_update_team_memory_tool(
|
|||
db_session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
):
|
||||
"""Factory function to create the team-memory update tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
The session's bound ``commit``/``rollback`` methods are captured at
|
||||
call time, after ``async with`` has bound ``db_session`` locally.
|
||||
|
||||
Args:
|
||||
search_space_id: ID of the search space whose team memory is being
|
||||
updated.
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
llm: Optional LLM for the forced-rewrite path.
|
||||
|
||||
Returns:
|
||||
Configured update_memory tool for the team-memory scope.
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
"""Factory for the team-memory update tool."""
|
||||
del db_session
|
||||
|
||||
@tool
|
||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the team's shared memory document for this search space.
|
||||
|
||||
Your current team memory is shown in <team_memory> in the system
|
||||
prompt. When the team shares important long-term information
|
||||
(decisions, conventions, key facts, priorities), rewrite the memory
|
||||
document to include the new information. Merge new facts with
|
||||
existing ones, update contradictions, remove outdated entries, and
|
||||
keep it concise.
|
||||
|
||||
Args:
|
||||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
The current team memory is shown in <team_memory>. Pass the FULL updated
|
||||
markdown document, not a diff.
|
||||
"""
|
||||
try:
|
||||
async with async_session_maker() as db_session:
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return {"status": "error", "message": "Search space not found."}
|
||||
|
||||
old_memory = space.shared_memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.TEAM,
|
||||
target_id=search_space_id,
|
||||
content=updated_memory,
|
||||
session=db_session,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(
|
||||
space, "shared_memory_md", content
|
||||
),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
return result.to_dict()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update team memory: %s", e)
|
||||
return {
|
||||
|
|
@ -439,3 +85,9 @@ def create_update_team_memory_tool(
|
|||
}
|
||||
|
||||
return update_memory
|
||||
|
||||
|
||||
__all__ = [
|
||||
"create_update_memory_tool",
|
||||
"create_update_team_memory_tool",
|
||||
]
|
||||
|
|
|
|||
29
surfsense_backend/app/services/memory/__init__.py
Normal file
29
surfsense_backend/app/services/memory/__init__.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""First-class memory service for user and team markdown memory."""
|
||||
|
||||
from .service import (
|
||||
MemoryScope,
|
||||
SaveResult,
|
||||
extract_and_save,
|
||||
read_memory,
|
||||
reset_memory,
|
||||
save_memory,
|
||||
)
|
||||
from .validation import (
|
||||
MEMORY_HARD_LIMIT,
|
||||
MEMORY_SOFT_LIMIT,
|
||||
validate_bullet_format,
|
||||
validate_memory_scope,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MEMORY_HARD_LIMIT",
|
||||
"MEMORY_SOFT_LIMIT",
|
||||
"MemoryScope",
|
||||
"SaveResult",
|
||||
"extract_and_save",
|
||||
"read_memory",
|
||||
"reset_memory",
|
||||
"save_memory",
|
||||
"validate_bullet_format",
|
||||
"validate_memory_scope",
|
||||
]
|
||||
110
surfsense_backend/app/services/memory/prompts.py
Normal file
110
surfsense_backend/app/services/memory/prompts.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
"""Prompts used by the memory service."""
|
||||
|
||||
FORCED_REWRITE_PROMPT = """\
|
||||
You are a memory curator. The following memory document exceeds the character \
|
||||
limit and must be shortened.
|
||||
|
||||
RULES:
|
||||
1. Rewrite the document to be under {target} characters.
|
||||
2. Output Markdown only. Use clear `##` headings and concise bullet points.
|
||||
3. New-format bullets should look like: `- YYYY-MM-DD: memory text`.
|
||||
4. If the input contains legacy markers like `(YYYY-MM-DD) [fact]`, preserve the
|
||||
information but remove the inline marker in the output.
|
||||
5. Preserve durable instructions and preferences before generic facts when
|
||||
compressing personal memory.
|
||||
6. Preserve existing headings when useful; merge duplicate headings and bullets.
|
||||
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
|
||||
|
||||
<memory_document>
|
||||
{content}
|
||||
</memory_document>"""
|
||||
|
||||
USER_MEMORY_EXTRACT_PROMPT = """\
|
||||
You are a memory extraction assistant. Analyze the user's message and decide \
|
||||
if it contains any long-term information worth persisting to personal memory.
|
||||
|
||||
Worth remembering: preferences, background/identity, goals, projects, \
|
||||
instructions, tools/languages they use, decisions, expertise, workplace — \
|
||||
durable facts that will matter in future conversations.
|
||||
|
||||
NOT worth remembering: greetings, one-off factual questions, session \
|
||||
logistics, ephemeral requests, follow-up clarifications with no new personal \
|
||||
info, things that only matter for the current task.
|
||||
|
||||
If there is nothing durable to remember, choose `action = no_update`.
|
||||
|
||||
If the message contains memorizable information, choose `action = save` and \
|
||||
return the FULL updated memory document with the new information merged into \
|
||||
existing content.
|
||||
|
||||
FORMAT RULES FOR `updated_memory`:
|
||||
- Markdown only.
|
||||
- Every entry should be under a `##` heading.
|
||||
- Recommended headings: `## Facts`, `## Preferences`, `## Instructions`.
|
||||
- New bullets should use: `- YYYY-MM-DD: memory text`.
|
||||
- If current memory uses legacy `(YYYY-MM-DD) [fact|pref|instr]` markers,
|
||||
preserve the information but write the updated document in the new
|
||||
heading-based format.
|
||||
- Use the user's first name from `<user_name>` when helpful, not "the user".
|
||||
- Do not duplicate existing information.
|
||||
|
||||
<user_name>{user_name}</user_name>
|
||||
|
||||
<current_memory>
|
||||
{current_memory}
|
||||
</current_memory>
|
||||
|
||||
<user_message>
|
||||
{user_message}
|
||||
</user_message>"""
|
||||
|
||||
TEAM_MEMORY_EXTRACT_PROMPT = """\
|
||||
You are a team-memory extraction assistant. Analyze the latest message and \
|
||||
decide if it contains durable TEAM-level information worth persisting.
|
||||
|
||||
Decision policy:
|
||||
- Prioritize recall for durable team context, while avoiding personal-only facts.
|
||||
- Do NOT require explicit consensus language. A direct team-level statement can
|
||||
be stored if it is stable and broadly useful for future team chats.
|
||||
- If evidence is weak or clearly tentative, choose `action = no_update`.
|
||||
|
||||
Worth remembering (team-level only):
|
||||
- Decisions and defaults that guide future team work
|
||||
- Team conventions/standards (naming, review policy, coding norms)
|
||||
- Stable org/project facts (locations, ownership, constraints)
|
||||
- Long-lived architecture/process facts
|
||||
- Ongoing priorities that are likely relevant beyond this turn
|
||||
|
||||
NOT worth remembering:
|
||||
- Personal preferences or biography of one person
|
||||
- Questions, brainstorming, tentative ideas, or speculation
|
||||
- One-off requests, status updates, TODOs, logistics for this session
|
||||
- Information scoped only to a single ephemeral task
|
||||
|
||||
If the message contains memorizable team information, choose `action = save` \
|
||||
and return the FULL updated team memory document with new facts merged into \
|
||||
existing content.
|
||||
|
||||
FORMAT RULES FOR `updated_memory`:
|
||||
- Markdown only.
|
||||
- Every entry should be under a `##` heading.
|
||||
- Recommended headings: `## Product Decisions`, `## Engineering Conventions`,
|
||||
`## Project Facts`, `## Open Questions`.
|
||||
- New bullets should use: `- YYYY-MM-DD: memory text`.
|
||||
- If current memory uses legacy `(YYYY-MM-DD) [fact]` markers, preserve the
|
||||
information but write the updated document in the new heading-based format.
|
||||
- Do not create personal headings such as `## Preferences`, `## Instructions`,
|
||||
or `## Personal Notes`.
|
||||
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
|
||||
|
||||
<current_team_memory>
|
||||
{current_memory}
|
||||
</current_team_memory>
|
||||
|
||||
<latest_message_author>
|
||||
{author}
|
||||
</latest_message_author>
|
||||
|
||||
<latest_message>
|
||||
{user_message}
|
||||
</latest_message>"""
|
||||
35
surfsense_backend/app/services/memory/rewrite.py
Normal file
35
surfsense_backend/app/services/memory/rewrite.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""LLM-backed memory rewrite helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.services.memory.prompts import FORCED_REWRITE_PROMPT
|
||||
from app.services.memory.validation import MEMORY_HARD_LIMIT
|
||||
from app.utils.content_utils import extract_text_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def forced_rewrite(content: str, llm: Any) -> str | None:
|
||||
"""Use a focused LLM call to compress memory under the hard limit."""
|
||||
try:
|
||||
prompt = FORCED_REWRITE_PROMPT.format(
|
||||
target=MEMORY_HARD_LIMIT,
|
||||
content=content,
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "memory-rewrite"]},
|
||||
)
|
||||
text = extract_text_content(response.content).strip()
|
||||
if not text:
|
||||
logger.warning("Forced memory rewrite returned empty text")
|
||||
return None
|
||||
return text
|
||||
except Exception:
|
||||
logger.exception("Forced memory rewrite LLM call failed")
|
||||
return None
|
||||
23
surfsense_backend/app/services/memory/schemas.py
Normal file
23
surfsense_backend/app/services/memory/schemas.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""Structured output schemas for memory extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MemoryExtractionDecision(BaseModel):
|
||||
"""Structured extraction result; avoids string sentinel parsing."""
|
||||
|
||||
action: Literal["no_update", "save"] = Field(
|
||||
description="Choose no_update when nothing durable should be saved; choose save otherwise."
|
||||
)
|
||||
reason: str | None = Field(
|
||||
default=None,
|
||||
description="Short reason for no_update, or brief summary of the memory update.",
|
||||
)
|
||||
updated_memory: str | None = Field(
|
||||
default=None,
|
||||
description="The full updated markdown memory document when action is save.",
|
||||
)
|
||||
300
surfsense_backend/app/services/memory/service.py
Normal file
300
surfsense_backend/app/services/memory/service.py
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
"""Canonical read/write/reset/extract service for markdown memory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import SearchSpace, User
|
||||
from app.services.memory.prompts import (
|
||||
TEAM_MEMORY_EXTRACT_PROMPT,
|
||||
USER_MEMORY_EXTRACT_PROMPT,
|
||||
)
|
||||
from app.services.memory.rewrite import forced_rewrite
|
||||
from app.services.memory.schemas import MemoryExtractionDecision
|
||||
from app.services.memory.validation import (
|
||||
MEMORY_HARD_LIMIT,
|
||||
soft_limit_warning,
|
||||
strip_preamble_to_first_heading,
|
||||
validate_bullet_format,
|
||||
validate_diff,
|
||||
validate_heading_sanity,
|
||||
validate_memory_scope,
|
||||
validate_memory_size,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryScope(StrEnum):
|
||||
USER = "user"
|
||||
TEAM = "team"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SaveResult:
|
||||
status: Literal["saved", "error", "no_op"]
|
||||
message: str
|
||||
memory_md: str = ""
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
diff_warnings: list[str] = field(default_factory=list)
|
||||
format_warnings: list[str] = field(default_factory=list)
|
||||
notice: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
data: dict[str, Any] = {
|
||||
"status": self.status,
|
||||
"message": self.message,
|
||||
"memory_md": self.memory_md,
|
||||
}
|
||||
if self.notice:
|
||||
data["notice"] = self.notice
|
||||
if self.warnings:
|
||||
data["warnings"] = self.warnings
|
||||
if len(self.warnings) == 1:
|
||||
data["warning"] = self.warnings[0]
|
||||
if self.diff_warnings:
|
||||
data["diff_warnings"] = self.diff_warnings
|
||||
if self.format_warnings:
|
||||
data["format_warnings"] = self.format_warnings
|
||||
return data
|
||||
|
||||
|
||||
class MemoryRead(BaseModel):
|
||||
memory_md: str
|
||||
|
||||
|
||||
def _normalize_scope(scope: MemoryScope | str) -> MemoryScope:
|
||||
return scope if isinstance(scope, MemoryScope) else MemoryScope(scope)
|
||||
|
||||
|
||||
def _normalize_user_id(target_id: str | UUID) -> UUID:
|
||||
return UUID(target_id) if isinstance(target_id, str) else target_id
|
||||
|
||||
|
||||
async def _load_target(
|
||||
*,
|
||||
scope: MemoryScope | str,
|
||||
target_id: str | int | UUID,
|
||||
session: AsyncSession,
|
||||
) -> User | SearchSpace | None:
|
||||
normalized = _normalize_scope(scope)
|
||||
if normalized is MemoryScope.USER:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == _normalize_user_id(target_id)) # type: ignore[arg-type]
|
||||
)
|
||||
return result.scalars().first()
|
||||
result = await session.execute(select(SearchSpace).where(SearchSpace.id == int(target_id)))
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
def _get_memory(target: User | SearchSpace, scope: MemoryScope) -> str:
|
||||
if scope is MemoryScope.USER:
|
||||
return getattr(target, "memory_md", None) or ""
|
||||
return getattr(target, "shared_memory_md", None) or ""
|
||||
|
||||
|
||||
def _set_memory(target: User | SearchSpace, scope: MemoryScope, content: str) -> None:
|
||||
if scope is MemoryScope.USER:
|
||||
target.memory_md = content
|
||||
else:
|
||||
target.shared_memory_md = content
|
||||
|
||||
|
||||
async def read_memory(
|
||||
*,
|
||||
scope: MemoryScope | str,
|
||||
target_id: str | int | UUID,
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
normalized = _normalize_scope(scope)
|
||||
target = await _load_target(scope=normalized, target_id=target_id, session=session)
|
||||
if target is None:
|
||||
return ""
|
||||
return _get_memory(target, normalized)
|
||||
|
||||
|
||||
async def save_memory(
|
||||
*,
|
||||
scope: MemoryScope | str,
|
||||
target_id: str | int | UUID,
|
||||
content: str,
|
||||
session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
) -> SaveResult:
|
||||
normalized = _normalize_scope(scope)
|
||||
if not isinstance(content, str):
|
||||
return SaveResult(
|
||||
status="error",
|
||||
message="Internal error: memory payload must be a string.",
|
||||
)
|
||||
|
||||
target = await _load_target(scope=normalized, target_id=target_id, session=session)
|
||||
if target is None:
|
||||
return SaveResult(
|
||||
status="error",
|
||||
message="User not found." if normalized is MemoryScope.USER else "Search space not found.",
|
||||
)
|
||||
|
||||
old_memory = _get_memory(target, normalized)
|
||||
next_content = strip_preamble_to_first_heading(content.strip())
|
||||
notice: str | None = None
|
||||
warnings: list[str] = []
|
||||
|
||||
if len(next_content) > MEMORY_HARD_LIMIT and llm is not None:
|
||||
rewritten = await forced_rewrite(next_content, llm)
|
||||
if rewritten is not None and len(rewritten) < len(next_content):
|
||||
next_content = strip_preamble_to_first_heading(rewritten)
|
||||
notice = "Memory was automatically rewritten to fit within limits."
|
||||
|
||||
for validation in (
|
||||
validate_memory_size(next_content),
|
||||
validate_heading_sanity(next_content),
|
||||
):
|
||||
if validation:
|
||||
return SaveResult(
|
||||
status="error",
|
||||
message=validation["message"],
|
||||
memory_md=old_memory,
|
||||
)
|
||||
|
||||
scope_error, scope_warnings = validate_memory_scope(
|
||||
next_content,
|
||||
normalized.value,
|
||||
old_memory=old_memory,
|
||||
)
|
||||
warnings.extend(scope_warnings)
|
||||
if scope_error:
|
||||
return SaveResult(
|
||||
status="error",
|
||||
message=scope_error["message"],
|
||||
memory_md=old_memory,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
try:
|
||||
_set_memory(target, normalized, next_content)
|
||||
session.add(target)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update %s memory: %s", normalized.value, e)
|
||||
await session.rollback()
|
||||
return SaveResult(
|
||||
status="error",
|
||||
message=f"Failed to update {normalized.value} memory: {e}",
|
||||
memory_md=old_memory,
|
||||
)
|
||||
|
||||
diff_warnings = validate_diff(old_memory, next_content)
|
||||
format_warnings = validate_bullet_format(next_content)
|
||||
warning = soft_limit_warning(next_content)
|
||||
if warning:
|
||||
warnings.append(warning)
|
||||
|
||||
return SaveResult(
|
||||
status="saved",
|
||||
message=(
|
||||
"Memory updated."
|
||||
if normalized is MemoryScope.USER
|
||||
else "Team memory updated."
|
||||
),
|
||||
memory_md=next_content,
|
||||
warnings=warnings,
|
||||
diff_warnings=diff_warnings,
|
||||
format_warnings=format_warnings,
|
||||
notice=notice,
|
||||
)
|
||||
|
||||
|
||||
async def reset_memory(
|
||||
*,
|
||||
scope: MemoryScope | str,
|
||||
target_id: str | int | UUID,
|
||||
session: AsyncSession,
|
||||
) -> SaveResult:
|
||||
return await save_memory(
|
||||
scope=scope,
|
||||
target_id=target_id,
|
||||
content="",
|
||||
session=session,
|
||||
llm=None,
|
||||
)
|
||||
|
||||
|
||||
async def extract_and_save(
|
||||
*,
|
||||
scope: MemoryScope | str,
|
||||
target_id: str | int | UUID,
|
||||
user_message: str,
|
||||
actor_display_name: str | None,
|
||||
session: AsyncSession,
|
||||
llm: Any,
|
||||
) -> SaveResult:
|
||||
normalized = _normalize_scope(scope)
|
||||
current_memory = await read_memory(
|
||||
scope=normalized,
|
||||
target_id=target_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
if normalized is MemoryScope.USER:
|
||||
first_name = (
|
||||
actor_display_name.strip().split()[0]
|
||||
if actor_display_name and actor_display_name.strip()
|
||||
else "The user"
|
||||
)
|
||||
prompt = USER_MEMORY_EXTRACT_PROMPT.format(
|
||||
current_memory=current_memory or "(empty)",
|
||||
user_message=user_message,
|
||||
user_name=first_name,
|
||||
)
|
||||
else:
|
||||
prompt = TEAM_MEMORY_EXTRACT_PROMPT.format(
|
||||
current_memory=current_memory or "(empty)",
|
||||
author=actor_display_name or "Unknown team member",
|
||||
user_message=user_message,
|
||||
)
|
||||
|
||||
try:
|
||||
structured = llm.with_structured_output(MemoryExtractionDecision)
|
||||
decision = await structured.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "memory-extraction"]},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Structured memory extraction failed")
|
||||
return SaveResult(
|
||||
status="error",
|
||||
message="Structured memory extraction failed.",
|
||||
memory_md=current_memory,
|
||||
)
|
||||
|
||||
if decision.action == "no_update":
|
||||
return SaveResult(
|
||||
status="no_op",
|
||||
message=decision.reason or "No durable memory to persist.",
|
||||
memory_md=current_memory,
|
||||
)
|
||||
|
||||
if not decision.updated_memory:
|
||||
return SaveResult(
|
||||
status="error",
|
||||
message="Structured memory extraction chose save without updated_memory.",
|
||||
memory_md=current_memory,
|
||||
)
|
||||
|
||||
return await save_memory(
|
||||
scope=normalized,
|
||||
target_id=target_id,
|
||||
content=decision.updated_memory,
|
||||
session=session,
|
||||
llm=llm,
|
||||
)
|
||||
158
surfsense_backend/app/services/memory/validation.py
Normal file
158
surfsense_backend/app/services/memory/validation.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
"""Validation helpers for markdown-backed memory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
MEMORY_SOFT_LIMIT = 18_000
|
||||
MEMORY_HARD_LIMIT = 25_000
|
||||
|
||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
||||
_HEADING_LINE_RE = re.compile(r"^##\s+\S+", re.MULTILINE)
|
||||
_HEADING_NORMALIZE_RE = re.compile(r"[^a-z0-9]+")
|
||||
_LEGACY_BULLET_RE = re.compile(r"^-\s+\(\d{4}-\d{2}-\d{2}\)\s+\[(fact|pref|instr)\]\s+.+$")
|
||||
_NEW_BULLET_RE = re.compile(r"^-\s+\d{4}-\d{2}-\d{2}:\s+.+$")
|
||||
|
||||
_FORBIDDEN_TEAM_HEADINGS = {
|
||||
"preferences",
|
||||
"instructions",
|
||||
"personal notes",
|
||||
"personal instructions",
|
||||
}
|
||||
|
||||
|
||||
def has_markdown_heading(content: str) -> bool:
|
||||
return bool(_HEADING_LINE_RE.search(content))
|
||||
|
||||
|
||||
def strip_preamble_to_first_heading(content: str) -> str:
|
||||
"""Drop model preamble before the first ``##`` heading, if one exists."""
|
||||
match = _HEADING_LINE_RE.search(content)
|
||||
if not match:
|
||||
return content.strip()
|
||||
return content[match.start() :].strip()
|
||||
|
||||
|
||||
def extract_headings(memory: str | None) -> set[str]:
|
||||
if not memory:
|
||||
return set()
|
||||
return {_normalize_heading(h) for h in _SECTION_HEADING_RE.findall(memory)}
|
||||
|
||||
|
||||
def _normalize_heading(heading: str) -> str:
|
||||
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower()).strip()
|
||||
|
||||
|
||||
def validate_memory_size(content: str) -> dict[str, str] | None:
|
||||
length = len(content)
|
||||
if length > MEMORY_HARD_LIMIT:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
|
||||
f"({length:,} chars). Consolidate by merging related items, "
|
||||
"removing outdated entries, and shortening descriptions."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def validate_heading_sanity(content: str) -> dict[str, str] | None:
|
||||
"""Block long prose blobs without headings unless they are legacy bullets."""
|
||||
stripped = content.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
if has_markdown_heading(stripped):
|
||||
return None
|
||||
if len(stripped) <= 40:
|
||||
return None
|
||||
if any(_LEGACY_BULLET_RE.match(line.strip()) for line in stripped.splitlines()):
|
||||
return None
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Memory must be markdown with at least one ## heading.",
|
||||
}
|
||||
|
||||
|
||||
def validate_memory_scope(
|
||||
content: str,
|
||||
scope: Literal["user", "team"],
|
||||
*,
|
||||
old_memory: str | None = None,
|
||||
) -> tuple[dict[str, str] | None, list[str]]:
|
||||
"""Reject new personal headings in team memory, grandfather existing ones."""
|
||||
if scope != "team":
|
||||
return None, []
|
||||
|
||||
old_forbidden = extract_headings(old_memory) & _FORBIDDEN_TEAM_HEADINGS
|
||||
new_forbidden = extract_headings(content) & _FORBIDDEN_TEAM_HEADINGS
|
||||
introduced = sorted(new_forbidden - old_forbidden)
|
||||
grandfathered = sorted(new_forbidden & old_forbidden)
|
||||
|
||||
warnings: list[str] = []
|
||||
if grandfathered:
|
||||
warnings.append(
|
||||
"Team memory contains legacy personal headings: "
|
||||
+ ", ".join(grandfathered)
|
||||
+ ". Please consolidate them into team-safe headings."
|
||||
)
|
||||
if introduced:
|
||||
return (
|
||||
{
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Team memory cannot introduce personal headings: "
|
||||
+ ", ".join(introduced)
|
||||
+ ". Use team-safe headings instead."
|
||||
),
|
||||
},
|
||||
warnings,
|
||||
)
|
||||
return None, warnings
|
||||
|
||||
|
||||
def validate_bullet_format(content: str) -> list[str]:
|
||||
warnings: list[str] = []
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped.startswith("- "):
|
||||
continue
|
||||
if _NEW_BULLET_RE.match(stripped) or _LEGACY_BULLET_RE.match(stripped):
|
||||
continue
|
||||
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
||||
warnings.append(f"Non-standard memory bullet: {short}")
|
||||
return warnings
|
||||
|
||||
|
||||
def validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||
if not old_memory:
|
||||
return []
|
||||
|
||||
warnings: list[str] = []
|
||||
old_headings = extract_headings(old_memory)
|
||||
new_headings = extract_headings(new_memory)
|
||||
dropped = old_headings - new_headings
|
||||
if dropped:
|
||||
names = ", ".join(sorted(dropped))
|
||||
warnings.append(
|
||||
f"Sections removed: {names}. If unintentional, restore from the settings page."
|
||||
)
|
||||
|
||||
old_len = len(old_memory)
|
||||
new_len = len(new_memory)
|
||||
if old_len > 0 and new_len < old_len * 0.4:
|
||||
warnings.append(
|
||||
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). Possible data loss."
|
||||
)
|
||||
return warnings
|
||||
|
||||
|
||||
def soft_limit_warning(content: str) -> str | None:
|
||||
length = len(content)
|
||||
if length > MEMORY_SOFT_LIMIT:
|
||||
return (
|
||||
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
|
||||
"Consolidate by merging related items and removing less important entries."
|
||||
)
|
||||
return None
|
||||
204
surfsense_backend/tests/unit/services/test_memory_service.py
Normal file
204
surfsense_backend/tests/unit/services/test_memory_service.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
"""Unit tests for the first-class memory service."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory import (
|
||||
MemoryScope,
|
||||
extract_and_save,
|
||||
reset_memory,
|
||||
save_memory,
|
||||
)
|
||||
from app.services.memory.schemas import MemoryExtractionDecision
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.commit_calls = 0
|
||||
self.rollback_calls = 0
|
||||
self.added = []
|
||||
|
||||
def add(self, obj) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commit_calls += 1
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rollback_calls += 1
|
||||
|
||||
|
||||
class _StructuredLLM:
|
||||
def __init__(self, decision: MemoryExtractionDecision) -> None:
|
||||
self.decision = decision
|
||||
|
||||
def with_structured_output(self, _schema):
|
||||
return self
|
||||
|
||||
async def ainvoke(self, *_args, **_kwargs):
|
||||
return self.decision
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_saves_heading_based_memory(monkeypatch) -> None:
|
||||
target = SimpleNamespace(memory_md="")
|
||||
session = _FakeSession()
|
||||
|
||||
async def fake_load_target(**_kwargs):
|
||||
return target
|
||||
|
||||
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id="00000000-0000-0000-0000-000000000000",
|
||||
content="## Facts\n- 2026-05-19: Anish works on SurfSense\n",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert result.status == "saved"
|
||||
assert target.memory_md.startswith("## Facts")
|
||||
assert session.commit_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_accepts_legacy_marker_payload(monkeypatch) -> None:
|
||||
target = SimpleNamespace(memory_md="")
|
||||
session = _FakeSession()
|
||||
|
||||
async def fake_load_target(**_kwargs):
|
||||
return target
|
||||
|
||||
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id="00000000-0000-0000-0000-000000000000",
|
||||
content="- (2026-05-19) [fact] Legacy marker memory\n",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert result.status == "saved"
|
||||
assert "[fact]" in target.memory_md
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_rejects_long_no_heading_payload(monkeypatch) -> None:
|
||||
target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
|
||||
session = _FakeSession()
|
||||
|
||||
async def fake_load_target(**_kwargs):
|
||||
return target
|
||||
|
||||
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id="00000000-0000-0000-0000-000000000000",
|
||||
content="reasoning text before NO_UPDATE should not become saved memory",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert result.status == "error"
|
||||
assert session.commit_calls == 0
|
||||
assert target.memory_md.startswith("## Facts")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_grandfathers_existing_team_personal_heading(monkeypatch) -> None:
|
||||
content = "## Preferences\n- 2026-05-19: Existing legacy heading\n"
|
||||
target = SimpleNamespace(shared_memory_md=content)
|
||||
session = _FakeSession()
|
||||
|
||||
async def fake_load_target(**_kwargs):
|
||||
return target
|
||||
|
||||
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.TEAM,
|
||||
target_id=1,
|
||||
content=content,
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert result.status == "saved"
|
||||
assert result.warnings
|
||||
assert session.commit_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_memory_clears_memory(monkeypatch) -> None:
|
||||
target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
|
||||
session = _FakeSession()
|
||||
|
||||
async def fake_load_target(**_kwargs):
|
||||
return target
|
||||
|
||||
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||
|
||||
result = await reset_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id="00000000-0000-0000-0000-000000000000",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert result.status == "saved"
|
||||
assert target.memory_md == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_and_save_no_update_does_not_commit(monkeypatch) -> None:
|
||||
target = SimpleNamespace(memory_md="## Facts\n- 2026-05-19: Existing\n")
|
||||
session = _FakeSession()
|
||||
|
||||
async def fake_load_target(**_kwargs):
|
||||
return target
|
||||
|
||||
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||
|
||||
result = await extract_and_save(
|
||||
scope=MemoryScope.USER,
|
||||
target_id="00000000-0000-0000-0000-000000000000",
|
||||
user_message="hello",
|
||||
actor_display_name="Anish",
|
||||
session=session,
|
||||
llm=_StructuredLLM(
|
||||
MemoryExtractionDecision(action="no_update", reason="Greeting only")
|
||||
),
|
||||
)
|
||||
|
||||
assert result.status == "no_op"
|
||||
assert session.commit_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_and_save_persists_structured_update(monkeypatch) -> None:
|
||||
target = SimpleNamespace(memory_md="")
|
||||
session = _FakeSession()
|
||||
|
||||
async def fake_load_target(**_kwargs):
|
||||
return target
|
||||
|
||||
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||
|
||||
result = await extract_and_save(
|
||||
scope=MemoryScope.USER,
|
||||
target_id="00000000-0000-0000-0000-000000000000",
|
||||
user_message="I work on SurfSense",
|
||||
actor_display_name="Anish",
|
||||
session=session,
|
||||
llm=_StructuredLLM(
|
||||
MemoryExtractionDecision(
|
||||
action="save",
|
||||
updated_memory="## Facts\n- 2026-05-19: Anish works on SurfSense\n",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
assert result.status == "saved"
|
||||
assert "SurfSense" in target.memory_md
|
||||
assert session.commit_calls == 1
|
||||
Loading…
Add table
Add a link
Reference in a new issue