mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-27 19:25:15 +02:00
feat: add memory document model and parsing functionality for markdown handling
This commit is contained in:
parent
fe07de3f9c
commit
a0ff86e0e8
5 changed files with 241 additions and 37 deletions
200
surfsense_backend/app/services/memory/document.py
Normal file
200
surfsense_backend/app/services/memory/document.py
Normal file
|
|
@ -0,0 +1,200 @@
|
||||||
|
"""Memory-specific markdown document model and canonical renderer.
|
||||||
|
|
||||||
|
This intentionally parses only SurfSense memory's small markdown contract:
|
||||||
|
``##`` sections with dated bullet items. Unknown lines are preserved so user
|
||||||
|
edits are not lost, while legacy marker bullets are normalized on render.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
# Heading given to the implicit section that is opened when a bullet appears
# before any explicit ``##`` heading.
DEFAULT_LEGACY_SECTION = "Memory"

# Bracketed markers recognised in legacy ``(YYYY-MM-DD) [marker] text`` bullets.
LEGACY_MARKERS = frozenset(("fact", "pref", "instr"))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class MemoryBullet:
    """A dated memory entry parsed from a ``- `` bullet line."""

    # Calendar date attached to the entry.
    entry_date: date
    # Entry text; the bullet parsers strip surrounding whitespace before storing it.
    text: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class MemoryRawLine:
    """A section line that is not a recognised bullet, kept verbatim."""

    # The line exactly as it should be written back on render.
    text: str
|
||||||
|
|
||||||
|
|
||||||
|
# A line inside a section: either a parsed dated bullet or verbatim text.
MemoryLine = MemoryBullet | MemoryRawLine
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class MemorySection:
    """One ``##`` section of a memory document and the lines beneath it."""

    # Heading text without the leading ``## `` prefix.
    heading: str
    # Lines of the section in document order (bullets and raw lines).
    lines: list[MemoryLine] = field(default_factory=list)
    # False when the parser synthesised the heading for bullets that appeared
    # before any ``##`` heading; True for headings present in the input.
    explicit_heading: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class MemoryDocument:
    """Ordered collection of the sections that make up a memory document."""

    # Sections in the order they appear in the document.
    sections: list[MemorySection] = field(default_factory=list)

    @property
    def has_explicit_heading(self) -> bool:
        """Return True when at least one section is flagged ``explicit_heading``."""
        for section in self.sections:
            if section.explicit_heading:
                return True
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_section_heading(line: str) -> bool:
    """Return True when *line* is a ``## `` heading with a non-blank title.

    The prefix must be exactly two hashes followed by a space at the very start
    of the line; ``###`` headings and indented headings do not match.
    """
    if not line.startswith("## "):
        return False
    return line[3:].strip() != ""
|
||||||
|
|
||||||
|
|
||||||
|
def heading_text(line: str) -> str:
    """Return the title of a ``## `` heading line, whitespace-trimmed.

    The first three characters are dropped unconditionally; the prefix itself
    is not verified here.
    """
    prefix_length = len("## ")
    title = line[prefix_length:]
    return title.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_heading(heading: str) -> str:
    """Return a comparison key for *heading*.

    The heading is lower-cased, every run of non-alphanumeric characters is
    collapsed to a single space, and leading/trailing separators are dropped.
    """
    # Turn each non-alphanumeric character into a space, then let ``split``
    # collapse the runs and discard the leading/trailing ones.
    spaced = "".join(
        char if char.isalnum() else " " for char in heading.strip().lower()
    )
    return " ".join(spaced.split())
|
||||||
|
|
||||||
|
|
||||||
|
def parse_bullet_line(line: str) -> MemoryBullet | None:
    """Parse a ``- `` bullet line into a :class:`MemoryBullet`.

    Surrounding whitespace on *line* is ignored. The canonical form is tried
    first and the legacy form second. Returns ``None`` when the line is not a
    bullet or matches neither form.
    """
    candidate = line.strip()
    if candidate[:2] != "- ":
        return None

    payload = candidate[2:]
    bullet = _parse_canonical_bullet(payload)
    if bullet is None:
        bullet = _parse_legacy_bullet(payload)
    return bullet
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_canonical_bullet(body: str) -> MemoryBullet | None:
|
||||||
|
if len(body) < 13 or body[10:12] != ": ":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
entry_date = date.fromisoformat(body[:10])
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
text = body[12:].strip()
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
return MemoryBullet(entry_date=entry_date, text=text)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_legacy_bullet(body: str) -> MemoryBullet | None:
|
||||||
|
if len(body) < 20 or not body.startswith("("):
|
||||||
|
return None
|
||||||
|
if len(body) < 14 or body[11:14] != ") [":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
entry_date = date.fromisoformat(body[1:11])
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
marker_end = body.find("] ", 14)
|
||||||
|
if marker_end == -1:
|
||||||
|
return None
|
||||||
|
marker = body[14:marker_end]
|
||||||
|
if marker not in LEGACY_MARKERS:
|
||||||
|
return None
|
||||||
|
|
||||||
|
text = body[marker_end + 2 :].strip()
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
return MemoryBullet(entry_date=entry_date, text=text)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_memory_document(content: str | None) -> MemoryDocument:
    """Parse memory markdown into a :class:`MemoryDocument`.

    Rules applied line by line, after trimming the whole input and each line's
    trailing whitespace:

    * A ``## `` heading closes the open section, if any, and opens a new
      explicit one.
    * Before any section is open, non-bullet lines are dropped. The first
      recognised bullet opens an implicit ``DEFAULT_LEGACY_SECTION`` section
      with ``explicit_heading=False``.
    * Inside a section, recognised bullets become ``MemoryBullet`` entries and
      every other line is kept as a ``MemoryRawLine``.

    ``None`` or empty *content* yields an empty document.
    """
    if not content:
        return MemoryDocument()

    parsed: list[MemorySection] = []
    heading: str | None = None  # None means no section has been opened yet
    explicit = True
    body: list[MemoryLine] = []

    for raw_line in content.strip().splitlines():
        line = raw_line.rstrip()

        if is_section_heading(line):
            # Close the section that was being collected before starting anew.
            if heading is not None:
                parsed.append(
                    MemorySection(
                        heading=heading,
                        lines=body,
                        explicit_heading=explicit,
                    )
                )
            heading, explicit, body = heading_text(line), True, []
            continue

        bullet = parse_bullet_line(line)
        if heading is not None:
            body.append(MemoryRawLine(text=line) if bullet is None else bullet)
        elif bullet is not None:
            # Heading-less bullet: open the implicit legacy section.
            heading, explicit, body = DEFAULT_LEGACY_SECTION, False, [bullet]
        # Otherwise: preamble text before any section or bullet is discarded.

    if heading is not None:
        parsed.append(
            MemorySection(
                heading=heading,
                lines=body,
                explicit_heading=explicit,
            )
        )
    return MemoryDocument(sections=parsed)
|
||||||
|
|
||||||
|
|
||||||
|
def render_memory_document(document: MemoryDocument) -> str:
    """Render *document* back to markdown in canonical form.

    Each section becomes a ``## heading`` line followed by its lines. Bullets
    are written as ``- YYYY-MM-DD: text`` and raw lines are emitted unchanged.
    Sections are separated by one blank line and the result is stripped of
    leading and trailing whitespace.
    """
    blocks: list[str] = []
    for section in document.sections:
        parts = [f"## {section.heading}"]
        parts.extend(
            f"- {line.entry_date.isoformat()}: {line.text}"
            if isinstance(line, MemoryBullet)
            else line.text
            for line in section.lines
        )
        blocks.append("\n".join(parts).strip())

    non_empty = [block for block in blocks if block]
    return "\n\n".join(non_empty).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_headings(memory: str | None) -> set[str]:
    """Return the normalized headings of all explicit sections in *memory*.

    Sections whose heading was synthesised by the parser
    (``explicit_heading`` is False) are excluded.
    """
    headings: set[str] = set()
    for section in parse_memory_document(memory).sections:
        if section.explicit_heading:
            headings.add(normalize_heading(section.heading))
    return headings
|
||||||
|
|
||||||
|
|
||||||
|
def has_explicit_heading(content: str) -> bool:
    """Return True when parsing *content* yields at least one explicit section."""
    document = parse_memory_document(content)
    return document.has_explicit_heading
|
||||||
|
|
||||||
|
|
||||||
|
def nonstandard_bullets(content: str) -> list[str]:
    """Return one warning per ``- `` bullet that matches no known bullet form.

    Lines are compared after stripping surrounding whitespace. Each warning
    quotes at most the first 80 characters of the bullet, followed by ``...``
    when the bullet is longer.
    """
    warnings: list[str] = []
    for candidate in (raw.strip() for raw in content.splitlines()):
        if not candidate.startswith("- "):
            continue
        if parse_bullet_line(candidate) is not None:
            continue

        preview = candidate[:80]
        if len(candidate) > 80:
            preview += "..."
        warnings.append(f"Non-standard memory bullet: {preview}")
    return warnings
|
||||||
|
|
@ -13,6 +13,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import SearchSpace, User
|
from app.db import SearchSpace, User
|
||||||
|
from app.services.memory.document import parse_memory_document, render_memory_document
|
||||||
from app.services.memory.prompts import (
|
from app.services.memory.prompts import (
|
||||||
TEAM_MEMORY_EXTRACT_PROMPT,
|
TEAM_MEMORY_EXTRACT_PROMPT,
|
||||||
USER_MEMORY_EXTRACT_PROMPT,
|
USER_MEMORY_EXTRACT_PROMPT,
|
||||||
|
|
@ -184,6 +185,8 @@ async def save_memory(
|
||||||
warnings=warnings,
|
warnings=warnings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
next_content = render_memory_document(parse_memory_document(next_content))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_set_memory(target, normalized, next_content)
|
_set_memory(target, normalized, next_content)
|
||||||
session.add(target)
|
session.add(target)
|
||||||
|
|
|
||||||
|
|
@ -2,20 +2,18 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
from app.services.memory.document import (
|
||||||
|
extract_headings,
|
||||||
|
has_explicit_heading,
|
||||||
|
nonstandard_bullets,
|
||||||
|
parse_memory_document,
|
||||||
|
)
|
||||||
|
|
||||||
MEMORY_SOFT_LIMIT = 18_000
|
MEMORY_SOFT_LIMIT = 18_000
|
||||||
MEMORY_HARD_LIMIT = 25_000
|
MEMORY_HARD_LIMIT = 25_000
|
||||||
|
|
||||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
|
||||||
_HEADING_LINE_RE = re.compile(r"^##\s+\S+", re.MULTILINE)
|
|
||||||
_HEADING_NORMALIZE_RE = re.compile(r"[^a-z0-9]+")
|
|
||||||
_LEGACY_BULLET_RE = re.compile(
|
|
||||||
r"^-\s+\(\d{4}-\d{2}-\d{2}\)\s+\[(fact|pref|instr)\]\s+.+$"
|
|
||||||
)
|
|
||||||
_NEW_BULLET_RE = re.compile(r"^-\s+\d{4}-\d{2}-\d{2}:\s+.+$")
|
|
||||||
|
|
||||||
_FORBIDDEN_TEAM_HEADINGS = {
|
_FORBIDDEN_TEAM_HEADINGS = {
|
||||||
"preferences",
|
"preferences",
|
||||||
"instructions",
|
"instructions",
|
||||||
|
|
@ -25,25 +23,16 @@ _FORBIDDEN_TEAM_HEADINGS = {
|
||||||
|
|
||||||
|
|
||||||
def has_markdown_heading(content: str) -> bool:
|
def has_markdown_heading(content: str) -> bool:
|
||||||
return bool(_HEADING_LINE_RE.search(content))
|
return has_explicit_heading(content)
|
||||||
|
|
||||||
|
|
||||||
def strip_preamble_to_first_heading(content: str) -> str:
|
def strip_preamble_to_first_heading(content: str) -> str:
|
||||||
"""Drop model preamble before the first ``##`` heading, if one exists."""
|
"""Drop model preamble before the first ``##`` heading, if one exists."""
|
||||||
match = _HEADING_LINE_RE.search(content)
|
lines = content.splitlines()
|
||||||
if not match:
|
for index, line in enumerate(lines):
|
||||||
return content.strip()
|
if line.startswith("## ") and line[3:].strip():
|
||||||
return content[match.start() :].strip()
|
return "\n".join(lines[index:]).strip()
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
def extract_headings(memory: str | None) -> set[str]:
|
|
||||||
if not memory:
|
|
||||||
return set()
|
|
||||||
return {_normalize_heading(h) for h in _SECTION_HEADING_RE.findall(memory)}
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_heading(heading: str) -> str:
|
|
||||||
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower()).strip()
|
|
||||||
|
|
||||||
|
|
||||||
def validate_memory_size(content: str) -> dict[str, str] | None:
|
def validate_memory_size(content: str) -> dict[str, str] | None:
|
||||||
|
|
@ -69,7 +58,7 @@ def validate_heading_sanity(content: str) -> dict[str, str] | None:
|
||||||
return None
|
return None
|
||||||
if len(stripped) <= 40:
|
if len(stripped) <= 40:
|
||||||
return None
|
return None
|
||||||
if any(_LEGACY_BULLET_RE.match(line.strip()) for line in stripped.splitlines()):
|
if parse_memory_document(stripped).sections:
|
||||||
return None
|
return None
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
|
|
@ -115,16 +104,7 @@ def validate_memory_scope(
|
||||||
|
|
||||||
|
|
||||||
def validate_bullet_format(content: str) -> list[str]:
|
def validate_bullet_format(content: str) -> list[str]:
|
||||||
warnings: list[str] = []
|
return nonstandard_bullets(content)
|
||||||
for line in content.splitlines():
|
|
||||||
stripped = line.strip()
|
|
||||||
if not stripped.startswith("- "):
|
|
||||||
continue
|
|
||||||
if _NEW_BULLET_RE.match(stripped) or _LEGACY_BULLET_RE.match(stripped):
|
|
||||||
continue
|
|
||||||
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
|
||||||
warnings.append(f"Non-standard memory bullet: {short}")
|
|
||||||
return warnings
|
|
||||||
|
|
||||||
|
|
||||||
def validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
def validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||||
|
|
@ -138,7 +118,7 @@ def validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||||
if dropped:
|
if dropped:
|
||||||
names = ", ".join(sorted(dropped))
|
names = ", ".join(sorted(dropped))
|
||||||
warnings.append(
|
warnings.append(
|
||||||
f"Sections removed: {names}. If unintentional, restore from the settings page."
|
f"Sections removed: {names}. If unintentional, restore them from the memory document."
|
||||||
)
|
)
|
||||||
|
|
||||||
old_len = len(old_memory)
|
old_len = len(old_memory)
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,27 @@ def test_validate_bullet_format_warns_on_nonstandard_bullet() -> None:
|
||||||
assert "Non-standard memory bullet" in warnings[0]
|
assert "Non-standard memory bullet" in warnings[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_normalizes_legacy_marker_bullets(monkeypatch) -> None:
|
||||||
|
target = type("Target", (), {"memory_md": ""})()
|
||||||
|
session = _FakeSession()
|
||||||
|
|
||||||
|
async def fake_load_target(**_kwargs):
|
||||||
|
return target
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.memory.service._load_target", fake_load_target)
|
||||||
|
|
||||||
|
result = await save_memory(
|
||||||
|
scope=MemoryScope.USER,
|
||||||
|
target_id="00000000-0000-0000-0000-000000000000",
|
||||||
|
content="- (2026-04-10) [fact] Legacy fact is preserved\n",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "saved"
|
||||||
|
assert target.memory_md == "## Memory\n- 2026-04-10: Legacy fact is preserved"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_save_memory_blocks_new_personal_heading_in_team_before_commit(
|
async def test_save_memory_blocks_new_personal_heading_in_team_before_commit(
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ async def test_save_memory_accepts_legacy_marker_payload(monkeypatch) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.status == "saved"
|
assert result.status == "saved"
|
||||||
assert "[fact]" in target.memory_md
|
assert target.memory_md == "## Memory\n- 2026-05-19: Legacy marker memory"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue