feat: add memory document model and parsing functionality for markdown handling

This commit is contained in:
Anish Sarkar 2026-05-20 13:20:05 +05:30
parent fe07de3f9c
commit a0ff86e0e8
5 changed files with 241 additions and 37 deletions

View file

@ -0,0 +1,200 @@
"""Memory-specific markdown document model and canonical renderer.
This intentionally parses only SurfSense memory's small markdown contract:
``##`` sections with dated bullet items. Unrecognized lines inside a section
are preserved so user edits are not lost, while legacy marker bullets are
normalized on render. Text before the first heading or first parseable bullet
is dropped.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import date
# Heading given to bullets that appear before any explicit ``##`` heading.
DEFAULT_LEGACY_SECTION = "Memory"
# Marker tags accepted in the legacy ``- (YYYY-MM-DD) [marker] text`` bullet form.
LEGACY_MARKERS = frozenset({"fact", "pref", "instr"})
@dataclass(frozen=True)
class MemoryBullet:
    """A dated memory entry, rendered canonically as ``- YYYY-MM-DD: text``."""

    # Calendar date attached to the entry.
    entry_date: date
    # Entry body without the date prefix or any legacy marker.
    text: str
@dataclass(frozen=True)
class MemoryRawLine:
    """A line that is not a recognised bullet, kept verbatim for rendering."""

    # The line's content; rendered back exactly as stored.
    text: str
# Any entry a section can hold: a parsed bullet or an unrecognised raw line.
MemoryLine = MemoryBullet | MemoryRawLine
@dataclass(frozen=True)
class MemorySection:
    """One ``##`` section: a heading plus the lines that belong to it."""

    # Heading text without the leading ``## `` prefix.
    heading: str
    # Bullets and raw lines, in the order they should be rendered.
    lines: list[MemoryLine] = field(default_factory=list)
    # False when the heading was synthesised for heading-less bullets.
    explicit_heading: bool = True
@dataclass(frozen=True)
class MemoryDocument:
    """Ordered collection of memory sections."""

    # Sections in document order.
    sections: list[MemorySection] = field(default_factory=list)

    @property
    def has_explicit_heading(self) -> bool:
        """Return True when at least one section has an explicit heading."""
        for section in self.sections:
            if section.explicit_heading:
                return True
        return False
def is_section_heading(line: str) -> bool:
    """Return True when *line* is a ``## `` heading with non-blank text.

    The line must begin with two hashes followed by a space, and something
    other than whitespace must follow that prefix.
    """
    if not line.startswith("## "):
        return False
    return line[3:].strip() != ""
def heading_text(line: str) -> str:
    """Return the title of a heading line with the ``## `` prefix removed.

    The first three characters are dropped unconditionally, so the caller is
    expected to have confirmed that *line* is a section heading.
    """
    title = line[3:]
    return title.strip()
def normalize_heading(heading: str) -> str:
    """Return a comparison key for *heading*.

    The heading is lower-cased, each run of non-alphanumeric characters is
    collapsed into a single space, and leading/trailing separators are
    removed.
    """
    lowered = heading.strip().lower()
    # Turn every separator into a space, then let split/join collapse runs.
    spaced = "".join(ch if ch.isalnum() else " " for ch in lowered)
    return " ".join(spaced.split())
def parse_bullet_line(line: str) -> MemoryBullet | None:
    """Parse a ``- `` bullet line in canonical or legacy form.

    Surrounding whitespace is ignored. Returns the parsed bullet, or ``None``
    when the line is not a ``- `` bullet or matches neither supported format.
    """
    candidate = line.strip()
    if not candidate.startswith("- "):
        return None
    payload = candidate[2:]
    # Canonical format takes precedence; the legacy marker form is the fallback.
    for parser in (_parse_canonical_bullet, _parse_legacy_bullet):
        bullet = parser(payload)
        if bullet is not None:
            return bullet
    return None
def _parse_canonical_bullet(body: str) -> MemoryBullet | None:
    """Parse the canonical ``YYYY-MM-DD: text`` bullet body.

    Args:
        body: Bullet content with the leading ``- `` already removed.

    Returns:
        The parsed bullet, or ``None`` when the body does not follow the
        layout, the date is not a real calendar date, or the text is empty.
    """
    # Shortest valid body: 10-character date + ": " + one text character.
    if len(body) < 13 or body[10:12] != ": ":
        return None
    stamp = body[:10]
    # ``date.fromisoformat`` accepts other 10-character ISO 8601 forms on
    # Python 3.11+ (e.g. the week date ``2026-W01-1``). Pin the plain
    # ``YYYY-MM-DD`` shape first so such bullets are not silently rewritten
    # to a different-looking date on render.
    digits = stamp[:4] + stamp[5:7] + stamp[8:]
    if stamp[4] != "-" or stamp[7] != "-":
        return None
    if not (digits.isascii() and digits.isdigit()):
        return None
    try:
        entry_date = date.fromisoformat(stamp)
    except ValueError:
        # Right shape, impossible date (e.g. 2026-02-30).
        return None
    text = body[12:].strip()
    if not text:
        return None
    return MemoryBullet(entry_date=entry_date, text=text)
def _parse_legacy_bullet(body: str) -> MemoryBullet | None:
    """Parse the legacy ``(YYYY-MM-DD) [marker] text`` bullet body.

    Args:
        body: Bullet content with the leading ``- `` already removed.

    Returns:
        A bullet carrying the date and text (the marker is discarded), or
        ``None`` when the layout, date, marker, or text is not acceptable.
    """
    # Fixed-position layout: "(" + 10-character date + ") [" puts the marker
    # at index 14. The length test is a cheap early rejection; the shortest
    # valid body is 21 characters.
    if len(body) < 20 or not body.startswith("(") or body[11:14] != ") [":
        return None
    stamp = body[1:11]
    # ``date.fromisoformat`` accepts other 10-character ISO 8601 forms on
    # Python 3.11+ (e.g. the week date ``2026-W01-1``); require the plain
    # ``YYYY-MM-DD`` shape before delegating calendar validation to it.
    digits = stamp[:4] + stamp[5:7] + stamp[8:]
    if stamp[4] != "-" or stamp[7] != "-":
        return None
    if not (digits.isascii() and digits.isdigit()):
        return None
    try:
        entry_date = date.fromisoformat(stamp)
    except ValueError:
        return None
    marker_end = body.find("] ", 14)
    if marker_end == -1:
        return None
    if body[14:marker_end] not in LEGACY_MARKERS:
        return None
    text = body[marker_end + 2 :].strip()
    if not text:
        return None
    return MemoryBullet(entry_date=entry_date, text=text)
def parse_memory_document(content: str | None) -> MemoryDocument:
    """Parse memory markdown into a ``MemoryDocument``.

    Each ``## `` heading opens a new explicit section. Bullets that appear
    before any heading open an implicit section named
    ``DEFAULT_LEGACY_SECTION``. Inside a section, lines that are not
    recognised bullets are kept as ``MemoryRawLine`` entries (blank lines
    included). Unparseable lines that appear before the first heading or
    bullet are dropped.

    Args:
        content: Raw markdown; ``None`` or an empty string yields an empty
            document.

    Returns:
        The parsed document with sections in source order.
    """
    if not content:
        return MemoryDocument()

    sections: list[MemorySection] = []
    # Line list of the section currently being filled; None until one opens.
    # The section object is appended as soon as it opens and shares this list.
    open_lines: list[MemoryLine] | None = None

    for raw_line in content.strip().splitlines():
        line = raw_line.rstrip()

        if is_section_heading(line):
            open_lines = []
            sections.append(
                MemorySection(
                    heading=heading_text(line),
                    lines=open_lines,
                    explicit_heading=True,
                )
            )
            continue

        bullet = parse_bullet_line(line)

        if open_lines is None:
            # Nothing is open yet: only a parseable bullet can start the
            # implicit section; any other line is preamble and is skipped.
            if bullet is None:
                continue
            open_lines = [bullet]
            sections.append(
                MemorySection(
                    heading=DEFAULT_LEGACY_SECTION,
                    lines=open_lines,
                    explicit_heading=False,
                )
            )
            continue

        if bullet is None:
            open_lines.append(MemoryRawLine(text=line))
        else:
            open_lines.append(bullet)

    return MemoryDocument(sections=sections)
def render_memory_document(document: MemoryDocument) -> str:
    """Render *document* as canonical memory markdown.

    Every section becomes a ``## heading`` line followed by its entries.
    Bullets are written as ``- YYYY-MM-DD: text`` and raw lines are emitted
    unchanged. Each section is stripped of surrounding whitespace and
    sections are separated by one blank line.
    """
    blocks: list[str] = []
    for section in document.sections:
        body = [
            f"- {entry.entry_date.isoformat()}: {entry.text}"
            if isinstance(entry, MemoryBullet)
            else entry.text
            for entry in section.lines
        ]
        block = "\n".join([f"## {section.heading}", *body]).strip()
        if block:
            blocks.append(block)
    return "\n\n".join(blocks).strip()
def extract_headings(memory: str | None) -> set[str]:
    """Return the normalized headings of every explicit section in *memory*.

    Sections whose heading was synthesised for heading-less bullets are
    excluded.
    """
    headings: set[str] = set()
    for section in parse_memory_document(memory).sections:
        if section.explicit_heading:
            headings.add(normalize_heading(section.heading))
    return headings
def has_explicit_heading(content: str) -> bool:
    """Return True when *content* contains at least one ``## `` section heading."""
    document = parse_memory_document(content)
    return document.has_explicit_heading
def nonstandard_bullets(content: str) -> list[str]:
    """Return one warning for each ``- `` bullet in an unsupported format.

    Lines that do not start with ``- `` (after stripping) are ignored. The
    offending line is quoted in the warning, truncated to 80 characters with
    a trailing ``...`` when longer.
    """
    messages: list[str] = []
    for raw in content.splitlines():
        candidate = raw.strip()
        if candidate.startswith("- ") and parse_bullet_line(candidate) is None:
            preview = candidate[:80]
            if len(candidate) > 80:
                preview += "..."
            messages.append(f"Non-standard memory bullet: {preview}")
    return messages

View file

@ -13,6 +13,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User from app.db import SearchSpace, User
from app.services.memory.document import parse_memory_document, render_memory_document
from app.services.memory.prompts import ( from app.services.memory.prompts import (
TEAM_MEMORY_EXTRACT_PROMPT, TEAM_MEMORY_EXTRACT_PROMPT,
USER_MEMORY_EXTRACT_PROMPT, USER_MEMORY_EXTRACT_PROMPT,
@ -184,6 +185,8 @@ async def save_memory(
warnings=warnings, warnings=warnings,
) )
next_content = render_memory_document(parse_memory_document(next_content))
try: try:
_set_memory(target, normalized, next_content) _set_memory(target, normalized, next_content)
session.add(target) session.add(target)

View file

@ -2,20 +2,18 @@
from __future__ import annotations from __future__ import annotations
import re
from typing import Literal from typing import Literal
from app.services.memory.document import (
extract_headings,
has_explicit_heading,
nonstandard_bullets,
parse_memory_document,
)
MEMORY_SOFT_LIMIT = 18_000 MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000 MEMORY_HARD_LIMIT = 25_000
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
_HEADING_LINE_RE = re.compile(r"^##\s+\S+", re.MULTILINE)
_HEADING_NORMALIZE_RE = re.compile(r"[^a-z0-9]+")
_LEGACY_BULLET_RE = re.compile(
r"^-\s+\(\d{4}-\d{2}-\d{2}\)\s+\[(fact|pref|instr)\]\s+.+$"
)
_NEW_BULLET_RE = re.compile(r"^-\s+\d{4}-\d{2}-\d{2}:\s+.+$")
_FORBIDDEN_TEAM_HEADINGS = { _FORBIDDEN_TEAM_HEADINGS = {
"preferences", "preferences",
"instructions", "instructions",
@ -25,25 +23,16 @@ _FORBIDDEN_TEAM_HEADINGS = {
def has_markdown_heading(content: str) -> bool: def has_markdown_heading(content: str) -> bool:
return bool(_HEADING_LINE_RE.search(content)) return has_explicit_heading(content)
def strip_preamble_to_first_heading(content: str) -> str: def strip_preamble_to_first_heading(content: str) -> str:
"""Drop model preamble before the first ``##`` heading, if one exists.""" """Drop model preamble before the first ``##`` heading, if one exists."""
match = _HEADING_LINE_RE.search(content) lines = content.splitlines()
if not match: for index, line in enumerate(lines):
if line.startswith("## ") and line[3:].strip():
return "\n".join(lines[index:]).strip()
return content.strip() return content.strip()
return content[match.start() :].strip()
def extract_headings(memory: str | None) -> set[str]:
if not memory:
return set()
return {_normalize_heading(h) for h in _SECTION_HEADING_RE.findall(memory)}
def _normalize_heading(heading: str) -> str:
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower()).strip()
def validate_memory_size(content: str) -> dict[str, str] | None: def validate_memory_size(content: str) -> dict[str, str] | None:
@ -69,7 +58,7 @@ def validate_heading_sanity(content: str) -> dict[str, str] | None:
return None return None
if len(stripped) <= 40: if len(stripped) <= 40:
return None return None
if any(_LEGACY_BULLET_RE.match(line.strip()) for line in stripped.splitlines()): if parse_memory_document(stripped).sections:
return None return None
return { return {
"status": "error", "status": "error",
@ -115,16 +104,7 @@ def validate_memory_scope(
def validate_bullet_format(content: str) -> list[str]: def validate_bullet_format(content: str) -> list[str]:
warnings: list[str] = [] return nonstandard_bullets(content)
for line in content.splitlines():
stripped = line.strip()
if not stripped.startswith("- "):
continue
if _NEW_BULLET_RE.match(stripped) or _LEGACY_BULLET_RE.match(stripped):
continue
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
warnings.append(f"Non-standard memory bullet: {short}")
return warnings
def validate_diff(old_memory: str | None, new_memory: str) -> list[str]: def validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
@ -138,7 +118,7 @@ def validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
if dropped: if dropped:
names = ", ".join(sorted(dropped)) names = ", ".join(sorted(dropped))
warnings.append( warnings.append(
f"Sections removed: {names}. If unintentional, restore from the settings page." f"Sections removed: {names}. If unintentional, restore them from the memory document."
) )
old_len = len(old_memory) old_len = len(old_memory)

View file

@ -64,6 +64,27 @@ def test_validate_bullet_format_warns_on_nonstandard_bullet() -> None:
assert "Non-standard memory bullet" in warnings[0] assert "Non-standard memory bullet" in warnings[0]
# Checks that a legacy ``(date) [marker]`` bullet saved without a heading is
# stored in canonical form under the default "Memory" section.
@pytest.mark.asyncio
async def test_save_memory_normalizes_legacy_marker_bullets(monkeypatch) -> None:
# Minimal stand-in object exposing only a ``memory_md`` attribute.
target = type("Target", (), {"memory_md": ""})()
session = _FakeSession()
async def fake_load_target(**_kwargs):
return target
# Patch the loader so save_memory presumably receives the stub target
# instead of hitting the database -- confirm against the service module.
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
result = await save_memory(
scope=MemoryScope.USER,
target_id="00000000-0000-0000-0000-000000000000",
content="- (2026-04-10) [fact] Legacy fact is preserved\n",
session=session,
)
assert result.status == "saved"
# The marker is dropped and the bullet is re-rendered as ``- date: text``.
assert target.memory_md == "## Memory\n- 2026-04-10: Legacy fact is preserved"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_save_memory_blocks_new_personal_heading_in_team_before_commit( async def test_save_memory_blocks_new_personal_heading_in_team_before_commit(
monkeypatch, monkeypatch,

View file

@ -82,7 +82,7 @@ async def test_save_memory_accepts_legacy_marker_payload(monkeypatch) -> None:
) )
assert result.status == "saved" assert result.status == "saved"
assert "[fact]" in target.memory_md assert target.memory_md == "## Memory\n- 2026-05-19: Legacy marker memory"
@pytest.mark.asyncio @pytest.mark.asyncio