SurfSense/surfsense_backend/app/services/memory/service.py

247 lines
7.1 KiB
Python

"""Canonical read/write/reset/extract service for markdown memory."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any, Literal
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User
from app.services.memory.document import parse_memory_document, render_memory_document
from app.services.memory.rewrite import forced_rewrite
from app.services.memory.schemas import MemoryLimits
from app.services.memory.validation import (
MEMORY_HARD_LIMIT,
MEMORY_SOFT_LIMIT,
soft_limit_warning,
strip_preamble_to_first_heading,
validate_bullet_format,
validate_diff,
validate_heading_sanity,
validate_memory_scope,
validate_memory_size,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Upper-cased payloads that mean "leave the stored memory untouched".
# ``save_memory`` compares ``content.upper()`` against this set.
_NO_UPDATE_SENTINELS = frozenset(
    ("NO_UPDATE", "NO UPDATE", "NO_CHANGE", "NO CHANGE")
)
class MemoryScope(StrEnum):
USER = "user"
TEAM = "team"
@dataclass(frozen=True)
class SaveResult:
status: Literal["saved", "error", "no_op"]
message: str
memory_md: str = ""
warnings: list[str] = field(default_factory=list)
diff_warnings: list[str] = field(default_factory=list)
format_warnings: list[str] = field(default_factory=list)
notice: str | None = None
def to_dict(self) -> dict[str, Any]:
data: dict[str, Any] = {
"status": self.status,
"message": self.message,
"memory_md": self.memory_md,
}
if self.notice:
data["notice"] = self.notice
if self.warnings:
data["warnings"] = self.warnings
if len(self.warnings) == 1:
data["warning"] = self.warnings[0]
if self.diff_warnings:
data["diff_warnings"] = self.diff_warnings
if self.format_warnings:
data["format_warnings"] = self.format_warnings
return data
def memory_limits() -> MemoryLimits:
    """Return the configured soft and hard memory limits as ``MemoryLimits``."""
    limits = MemoryLimits(soft=MEMORY_SOFT_LIMIT, hard=MEMORY_HARD_LIMIT)
    return limits
def _normalize_scope(scope: MemoryScope | str) -> MemoryScope:
    """Coerce *scope* to a ``MemoryScope`` member.

    An existing member is returned unchanged; any other value is passed to
    the ``MemoryScope`` constructor, which raises ``ValueError`` for values
    that are not valid members.
    """
    if isinstance(scope, MemoryScope):
        return scope
    return MemoryScope(scope)
def _normalize_user_id(target_id: str | UUID) -> UUID:
return UUID(target_id) if isinstance(target_id, str) else target_id
async def _load_target(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    session: AsyncSession,
) -> User | SearchSpace | None:
    """Fetch the row that owns the memory for *scope*.

    For the user scope the ``User`` whose id equals *target_id* (parsed as a
    UUID when given as a string) is looked up; otherwise the ``SearchSpace``
    whose id equals ``int(target_id)`` is looked up.

    Returns:
        The first matching row, or ``None`` when nothing matches.
    """
    if _normalize_scope(scope) is MemoryScope.USER:
        user_id = _normalize_user_id(target_id)  # type: ignore[arg-type]
        query = select(User).where(User.id == user_id)
    else:
        query = select(SearchSpace).where(SearchSpace.id == int(target_id))

    rows = await session.execute(query)
    return rows.scalars().first()
def _get_memory(target: User | SearchSpace, scope: MemoryScope) -> str:
    """Read the memory markdown for *scope* from *target*.

    Reads ``memory_md`` for the user scope and ``shared_memory_md``
    otherwise. A missing attribute or a falsy value yields ``""``.
    """
    attribute = "memory_md" if scope is MemoryScope.USER else "shared_memory_md"
    return getattr(target, attribute, None) or ""
def _set_memory(target: User | SearchSpace, scope: MemoryScope, content: str) -> None:
    """Assign *content* to the memory attribute of *target* selected by *scope*.

    Writes ``memory_md`` for the user scope and ``shared_memory_md``
    otherwise. Nothing is persisted here; the caller commits.
    """
    if scope is not MemoryScope.USER:
        target.shared_memory_md = content
        return
    target.memory_md = content
async def read_memory(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    session: AsyncSession,
) -> str:
    """Return the stored memory markdown for the given scope and target.

    Returns:
        The memory text, or ``""`` when the target row does not exist or has
        no memory stored.
    """
    resolved_scope = _normalize_scope(scope)
    owner = await _load_target(
        scope=resolved_scope,
        target_id=target_id,
        session=session,
    )
    return "" if owner is None else _get_memory(owner, resolved_scope)
async def save_memory(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    content: str,
    session: AsyncSession,
    llm: Any | None = None,
) -> SaveResult:
    """Validate *content* and persist it as the memory for the given target.

    Processing order:

    1. Resolve the scope, reject non-string payloads, and load the target row.
    2. Strip the payload and pass it through
       ``strip_preamble_to_first_heading``.
    3. Return a ``no_op`` result when the payload is a no-update sentinel.
    4. When the payload exceeds ``MEMORY_HARD_LIMIT`` and *llm* is given,
       attempt ``forced_rewrite`` and adopt the result only if it is shorter.
    5. Run size, heading and scope validation; any error aborts the save.
    6. Round-trip the text through ``parse_memory_document`` /
       ``render_memory_document``, store it on the target and commit.
    7. Collect diff, bullet-format and soft-limit warnings for the response.

    Returns:
        A ``SaveResult`` with status ``"saved"`` (carrying the stored text),
        ``"no_op"``, or ``"error"``. Error and no-op results produced after
        the target was loaded carry the previous memory in ``memory_md``.
        Exceptions raised while storing/committing are logged, the session is
        rolled back, and an ``"error"`` result is returned.
    """
    resolved_scope = _normalize_scope(scope)
    is_user_scope = resolved_scope is MemoryScope.USER

    if not isinstance(content, str):
        return SaveResult(
            status="error",
            message="Internal error: memory payload must be a string.",
        )

    owner = await _load_target(
        scope=resolved_scope,
        target_id=target_id,
        session=session,
    )
    if owner is None:
        not_found = "User not found." if is_user_scope else "Search space not found."
        return SaveResult(status="error", message=not_found)

    previous = _get_memory(owner, resolved_scope)
    candidate = strip_preamble_to_first_heading(content.strip())

    # An explicit "nothing to change" payload leaves the stored memory as is.
    if candidate.upper() in _NO_UPDATE_SENTINELS:
        return SaveResult(
            status="no_op",
            message="No memory update requested.",
            memory_md=previous,
        )

    # Oversized payloads get one rewrite attempt when an LLM is available;
    # the rewrite is only adopted if it actually came back shorter.
    notice: str | None = None
    if llm is not None and len(candidate) > MEMORY_HARD_LIMIT:
        rewritten = await forced_rewrite(candidate, llm)
        if rewritten is not None and len(rewritten) < len(candidate):
            candidate = strip_preamble_to_first_heading(rewritten)
            notice = "Memory was automatically rewritten to fit within limits."

    # Both checks are evaluated up front; the size error takes precedence.
    size_error = validate_memory_size(candidate)
    heading_error = validate_heading_sanity(candidate)
    blocking_error = size_error or heading_error
    if blocking_error:
        return SaveResult(
            status="error",
            message=blocking_error["message"],
            memory_md=previous,
        )

    scope_error, scope_warnings = validate_memory_scope(
        candidate,
        resolved_scope.value,
        old_memory=previous,
    )
    warnings: list[str] = list(scope_warnings)
    if scope_error:
        return SaveResult(
            status="error",
            message=scope_error["message"],
            memory_md=previous,
            warnings=warnings,
        )

    # Store the parsed-and-re-rendered form of the document.
    candidate = render_memory_document(parse_memory_document(candidate))

    try:
        _set_memory(owner, resolved_scope, candidate)
        session.add(owner)
        await session.commit()
    except Exception as e:
        logger.exception("Failed to update %s memory: %s", resolved_scope.value, e)
        await session.rollback()
        return SaveResult(
            status="error",
            message=f"Failed to update {resolved_scope.value} memory: {e}",
            memory_md=previous,
        )

    # Advisory checks run after the commit; they never block the save.
    diff_warnings = validate_diff(previous, candidate)
    format_warnings = validate_bullet_format(candidate)
    soft_warning = soft_limit_warning(candidate)
    if soft_warning:
        warnings.append(soft_warning)

    success_message = "Memory updated." if is_user_scope else "Team memory updated."
    return SaveResult(
        status="saved",
        message=success_message,
        memory_md=candidate,
        warnings=warnings,
        diff_warnings=diff_warnings,
        format_warnings=format_warnings,
        notice=notice,
    )
async def reset_memory(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    session: AsyncSession,
) -> SaveResult:
    """Reset memory by submitting an empty document through ``save_memory``.

    No LLM is supplied, so no automatic rewrite is attempted.

    Returns:
        The ``SaveResult`` produced by ``save_memory``.
    """
    outcome = await save_memory(
        scope=scope,
        target_id=target_id,
        content="",
        session=session,
        llm=None,
    )
    return outcome