2026-05-20 02:01:36 +05:30
|
|
|
"""Canonical read/write/reset/extract service for markdown memory."""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
from enum import StrEnum
|
|
|
|
|
from typing import Any, Literal
|
|
|
|
|
from uuid import UUID
|
|
|
|
|
|
|
|
|
|
from langchain_core.messages import HumanMessage
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
|
|
|
from app.db import SearchSpace, User
|
2026-05-20 13:20:05 +05:30
|
|
|
from app.services.memory.document import parse_memory_document, render_memory_document
|
2026-05-20 02:01:36 +05:30
|
|
|
from app.services.memory.prompts import (
|
|
|
|
|
TEAM_MEMORY_EXTRACT_PROMPT,
|
|
|
|
|
USER_MEMORY_EXTRACT_PROMPT,
|
|
|
|
|
)
|
|
|
|
|
from app.services.memory.rewrite import forced_rewrite
|
2026-05-20 03:17:05 +05:30
|
|
|
from app.services.memory.schemas import MemoryExtractionDecision, MemoryLimits
|
2026-05-20 02:01:36 +05:30
|
|
|
from app.services.memory.validation import (
|
|
|
|
|
MEMORY_HARD_LIMIT,
|
2026-05-20 03:17:05 +05:30
|
|
|
MEMORY_SOFT_LIMIT,
|
2026-05-20 02:01:36 +05:30
|
|
|
soft_limit_warning,
|
|
|
|
|
strip_preamble_to_first_heading,
|
|
|
|
|
validate_bullet_format,
|
|
|
|
|
validate_diff,
|
|
|
|
|
validate_heading_sanity,
|
|
|
|
|
validate_memory_scope,
|
|
|
|
|
validate_memory_size,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MemoryScope(StrEnum):
|
|
|
|
|
USER = "user"
|
|
|
|
|
TEAM = "team"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
|
|
class SaveResult:
|
|
|
|
|
status: Literal["saved", "error", "no_op"]
|
|
|
|
|
message: str
|
|
|
|
|
memory_md: str = ""
|
|
|
|
|
warnings: list[str] = field(default_factory=list)
|
|
|
|
|
diff_warnings: list[str] = field(default_factory=list)
|
|
|
|
|
format_warnings: list[str] = field(default_factory=list)
|
|
|
|
|
notice: str | None = None
|
|
|
|
|
|
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
|
|
|
data: dict[str, Any] = {
|
|
|
|
|
"status": self.status,
|
|
|
|
|
"message": self.message,
|
|
|
|
|
"memory_md": self.memory_md,
|
|
|
|
|
}
|
|
|
|
|
if self.notice:
|
|
|
|
|
data["notice"] = self.notice
|
|
|
|
|
if self.warnings:
|
|
|
|
|
data["warnings"] = self.warnings
|
|
|
|
|
if len(self.warnings) == 1:
|
|
|
|
|
data["warning"] = self.warnings[0]
|
|
|
|
|
if self.diff_warnings:
|
|
|
|
|
data["diff_warnings"] = self.diff_warnings
|
|
|
|
|
if self.format_warnings:
|
|
|
|
|
data["format_warnings"] = self.format_warnings
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
|
2026-05-20 03:17:05 +05:30
|
|
|
def memory_limits() -> MemoryLimits:
    """Return the soft and hard memory limits bundled as a ``MemoryLimits``."""
    limits = MemoryLimits(
        soft=MEMORY_SOFT_LIMIT,
        hard=MEMORY_HARD_LIMIT,
    )
    return limits
|
2026-05-20 02:01:36 +05:30
|
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize_scope(scope: MemoryScope | str) -> MemoryScope:
    """Coerce *scope* to a ``MemoryScope`` member.

    An existing member is returned unchanged; any other value is passed to
    the ``MemoryScope`` constructor, which raises ``ValueError`` for values
    that are not valid scope strings.
    """
    if isinstance(scope, MemoryScope):
        return scope
    return MemoryScope(scope)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize_user_id(target_id: str | UUID) -> UUID:
|
|
|
|
|
return UUID(target_id) if isinstance(target_id, str) else target_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _load_target(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    session: AsyncSession,
) -> User | SearchSpace | None:
    """Fetch the database row that owns the memory for *scope*.

    Args:
        scope: Memory scope (member or its string value).
        target_id: User id for the user scope; search-space id for any
            other scope (converted with ``int(...)``).
        session: Async SQLAlchemy session used to run the query.

    Returns:
        The first matching ``User`` or ``SearchSpace``, or ``None`` when no
        row matches.
    """
    memory_scope = _normalize_scope(scope)

    if memory_scope is MemoryScope.USER:
        statement = select(User).where(User.id == _normalize_user_id(target_id))  # type: ignore[arg-type]
    else:
        statement = select(SearchSpace).where(SearchSpace.id == int(target_id))

    rows = await session.execute(statement)
    return rows.scalars().first()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_memory(target: User | SearchSpace, scope: MemoryScope) -> str:
    """Read the stored memory markdown from *target* for the given *scope*.

    The user scope reads ``memory_md``; every other scope reads
    ``shared_memory_md``. A missing attribute or a falsy value (``None`` or
    an empty string) yields ``""``.
    """
    attribute = "memory_md" if scope is MemoryScope.USER else "shared_memory_md"
    stored = getattr(target, attribute, None)
    return stored if stored else ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _set_memory(target: User | SearchSpace, scope: MemoryScope, content: str) -> None:
    """Assign *content* to the memory attribute of *target* for *scope*.

    The user scope writes ``memory_md``; every other scope writes
    ``shared_memory_md``. Nothing is flushed or committed here.
    """
    if scope is MemoryScope.USER:
        target.memory_md = content
        return
    target.shared_memory_md = content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def read_memory(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    session: AsyncSession,
) -> str:
    """Return the stored memory markdown for the given scope and target.

    Args:
        scope: Memory scope (member or its string value).
        target_id: Identifier of the user or search space.
        session: Async SQLAlchemy session used for the lookup.

    Returns:
        The stored markdown, or ``""`` when the target does not exist or has
        no memory stored.
    """
    memory_scope = _normalize_scope(scope)
    owner = await _load_target(
        scope=memory_scope,
        target_id=target_id,
        session=session,
    )
    return "" if owner is None else _get_memory(owner, memory_scope)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def save_memory(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    content: str,
    session: AsyncSession,
    llm: Any | None = None,
) -> SaveResult:
    """Validate, normalise and persist a new memory document.

    Steps, in order:
      1. Reject non-string payloads and unknown targets.
      2. Strip surrounding whitespace and any preamble before the first
         heading.
      3. If the text is longer than ``MEMORY_HARD_LIMIT`` characters and an
         ``llm`` is supplied, call ``forced_rewrite``; its output is used
         only when it is not ``None`` and strictly shorter.
      4. Run the blocking size / heading validators, then scope validation.
      5. Round-trip the text through ``parse_memory_document`` /
         ``render_memory_document`` and commit it.
      6. Collect non-blocking diff, bullet-format and soft-limit warnings.

    Args:
        scope: Memory scope (member or its string value).
        target_id: Identifier of the user or search space.
        content: Proposed memory markdown.
        session: Async SQLAlchemy session; committed on success and rolled
            back when the write fails.
        llm: Optional model handed to ``forced_rewrite`` for oversized input.

    Returns:
        A ``SaveResult`` with status ``"saved"`` on success or ``"error"``
        otherwise. Errors raised after the target was loaded carry the
        previously stored memory in ``memory_md``.
    """
    memory_scope = _normalize_scope(scope)

    # Runtime guard: the annotation says ``str`` but callers are not trusted.
    if not isinstance(content, str):
        return SaveResult(
            status="error",
            message="Internal error: memory payload must be a string.",
        )

    target = await _load_target(
        scope=memory_scope,
        target_id=target_id,
        session=session,
    )
    if target is None:
        if memory_scope is MemoryScope.USER:
            missing_message = "User not found."
        else:
            missing_message = "Search space not found."
        return SaveResult(status="error", message=missing_message)

    previous_memory = _get_memory(target, memory_scope)
    candidate = strip_preamble_to_first_heading(content.strip())
    rewrite_notice: str | None = None
    collected_warnings: list[str] = []

    # Oversized input gets one rewrite attempt, but only when a model exists.
    exceeds_hard_limit = len(candidate) > MEMORY_HARD_LIMIT
    if exceeds_hard_limit and llm is not None:
        condensed = await forced_rewrite(candidate, llm)
        if condensed is not None and len(condensed) < len(candidate):
            candidate = strip_preamble_to_first_heading(condensed)
            rewrite_notice = "Memory was automatically rewritten to fit within limits."

    # Both blocking validators run up front; the first failure is reported.
    blocking_checks = [
        validate_memory_size(candidate),
        validate_heading_sanity(candidate),
    ]
    first_failure = next((check for check in blocking_checks if check), None)
    if first_failure is not None:
        return SaveResult(
            status="error",
            message=first_failure["message"],
            memory_md=previous_memory,
        )

    scope_error, scope_warnings = validate_memory_scope(
        candidate,
        memory_scope.value,
        old_memory=previous_memory,
    )
    collected_warnings.extend(scope_warnings)
    if scope_error:
        return SaveResult(
            status="error",
            message=scope_error["message"],
            memory_md=previous_memory,
            warnings=collected_warnings,
        )

    # Store the document in its rendered form rather than the raw input.
    candidate = render_memory_document(parse_memory_document(candidate))

    try:
        _set_memory(target, memory_scope, candidate)
        session.add(target)
        await session.commit()
    except Exception as exc:
        logger.exception("Failed to update %s memory: %s", memory_scope.value, exc)
        await session.rollback()
        return SaveResult(
            status="error",
            message=f"Failed to update {memory_scope.value} memory: {exc}",
            memory_md=previous_memory,
        )

    # Non-blocking checks are computed only after the write succeeded.
    diff_warnings = validate_diff(previous_memory, candidate)
    format_warnings = validate_bullet_format(candidate)
    size_warning = soft_limit_warning(candidate)
    if size_warning:
        collected_warnings.append(size_warning)

    if memory_scope is MemoryScope.USER:
        success_message = "Memory updated."
    else:
        success_message = "Team memory updated."

    return SaveResult(
        status="saved",
        message=success_message,
        memory_md=candidate,
        warnings=collected_warnings,
        diff_warnings=diff_warnings,
        format_warnings=format_warnings,
        notice=rewrite_notice,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def reset_memory(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    session: AsyncSession,
) -> SaveResult:
    """Clear the memory of a target by saving an empty document.

    Delegates to ``save_memory`` with ``content=""`` and ``llm=None`` and
    returns its ``SaveResult`` unchanged.
    """
    cleared = await save_memory(
        scope=scope,
        target_id=target_id,
        content="",
        session=session,
        llm=None,
    )
    return cleared
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def extract_and_save(
    *,
    scope: MemoryScope | str,
    target_id: str | int | UUID,
    user_message: str,
    actor_display_name: str | None,
    session: AsyncSession,
    llm: Any,
) -> SaveResult:
    """Ask the model whether *user_message* contains memory worth keeping and save it.

    The current memory is read, a scope-specific extraction prompt is built,
    and the model is invoked with ``MemoryExtractionDecision`` as structured
    output. Depending on the decision, the result is either a no-op, an
    error, or the outcome of ``save_memory`` with the model's updated memory.

    Args:
        scope: Memory scope (member or its string value).
        target_id: Identifier of the user or search space.
        user_message: Message to mine for durable memory.
        actor_display_name: Display name of the author. For the user scope
            only its first whitespace-separated token is used; blank or
            missing names fall back to a generic label.
        session: Async SQLAlchemy session used for reading and saving.
        llm: Chat model exposing ``with_structured_output``; also forwarded
            to ``save_memory``.

    Returns:
        ``SaveResult`` with status ``"no_op"`` when the model chose
        ``"no_update"``, ``"error"`` when extraction failed or produced no
        usable decision, otherwise whatever ``save_memory`` returns.
    """
    normalized = _normalize_scope(scope)
    current_memory = await read_memory(
        scope=normalized,
        target_id=target_id,
        session=session,
    )

    if normalized is MemoryScope.USER:
        # Only the first name is put into the user prompt.
        first_name = (
            actor_display_name.strip().split()[0]
            if actor_display_name and actor_display_name.strip()
            else "The user"
        )
        prompt = USER_MEMORY_EXTRACT_PROMPT.format(
            current_memory=current_memory or "(empty)",
            user_message=user_message,
            user_name=first_name,
        )
    else:
        prompt = TEAM_MEMORY_EXTRACT_PROMPT.format(
            current_memory=current_memory or "(empty)",
            author=actor_display_name or "Unknown team member",
            user_message=user_message,
        )

    try:
        structured = llm.with_structured_output(MemoryExtractionDecision)
        decision = await structured.ainvoke(
            [HumanMessage(content=prompt)],
            config={"tags": ["surfsense:internal", "memory-extraction"]},
        )
    except Exception:
        logger.exception("Structured memory extraction failed")
        return SaveResult(
            status="error",
            message="Structured memory extraction failed.",
            memory_md=current_memory,
        )

    if decision is None:
        # Tool-calling structured output can come back as None when the model
        # replies without calling the schema tool; treat that as a failed
        # extraction instead of crashing on attribute access below.
        logger.warning("Structured memory extraction returned no decision")
        return SaveResult(
            status="error",
            message="Structured memory extraction failed.",
            memory_md=current_memory,
        )

    if decision.action == "no_update":
        return SaveResult(
            status="no_op",
            message=decision.reason or "No durable memory to persist.",
            memory_md=current_memory,
        )

    # Any other action requires replacement text to be present.
    if not decision.updated_memory:
        return SaveResult(
            status="error",
            message="Structured memory extraction chose save without updated_memory.",
            memory_md=current_memory,
        )

    return await save_memory(
        scope=normalized,
        target_id=target_id,
        content=decision.updated_memory,
        session=session,
        llm=llm,
    )
|