mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-15 18:25:18 +02:00
feat: implement team memory extraction and validation mechanisms, enhancing memory management by enforcing scope restrictions and improving memory persistence for shared threads
This commit is contained in:
parent
33626d4f91
commit
a0883d2ab6
8 changed files with 322 additions and 49 deletions
|
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
|
@ -34,6 +34,15 @@ MEMORY_SOFT_LIMIT = 18_000
|
|||
MEMORY_HARD_LIMIT = 25_000
|
||||
|
||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
||||
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
|
||||
|
||||
_USER_ONLY_HEADINGS = {"about the user", "preferences", "instructions"}
|
||||
_TEAM_ONLY_HEADINGS = {
|
||||
"team decisions",
|
||||
"conventions",
|
||||
"key facts",
|
||||
"current priorities",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -46,6 +55,45 @@ def _extract_headings(memory: str) -> set[str]:
|
|||
return set(_SECTION_HEADING_RE.findall(memory))
|
||||
|
||||
|
||||
def _normalize_heading(heading: str) -> str:
    """Return *heading* in canonical form for scope comparisons.

    Surrounding whitespace is trimmed, the text is lower-cased, and every
    internal run of whitespace is collapsed to a single space.
    """
    trimmed = heading.strip()
    lowered = trimmed.lower()
    return _HEADING_NORMALIZE_RE.sub(" ", lowered)
|
||||
|
||||
|
||||
def _validate_memory_scope(
    content: str, scope: Literal["user", "team"]
) -> dict[str, Any] | None:
    """Check that *content* has no section headings from the other scope.

    Team memory must not contain personal sections, and user memory must not
    contain team sections. Headings are normalized before comparison.

    Returns:
        ``None`` when no cross-scope heading is present (including when the
        content has no headings at all); otherwise an error dict with
        ``"status"`` and ``"message"`` keys, the message listing the
        offending headings in sorted order.
    """
    found = {_normalize_heading(raw) for raw in _extract_headings(content)}

    # Only "team" selects the team rules; every other value is handled as
    # user memory.
    if scope == "team":
        forbidden = _USER_ONLY_HEADINGS
        prefix = "Team memory cannot include personal sections: "
        suffix = ". Use team sections only."
    else:
        forbidden = _TEAM_ONLY_HEADINGS
        prefix = "User memory cannot include team sections: "
        suffix = ". Use personal sections only."

    offending = sorted(found.intersection(forbidden))
    if not offending:
        return None

    return {
        "status": "error",
        "message": prefix + ", ".join(offending) + suffix,
    }
|
||||
|
||||
|
||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||
"""Return a list of warning strings about suspicious changes."""
|
||||
if not old_memory:
|
||||
|
|
@ -166,6 +214,7 @@ async def _save_memory(
|
|||
commit_fn,
|
||||
rollback_fn,
|
||||
label: str,
|
||||
scope: Literal["user", "team"],
|
||||
) -> dict[str, Any]:
|
||||
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
||||
return a response dict.
|
||||
|
|
@ -200,6 +249,10 @@ async def _save_memory(
|
|||
if size_err:
|
||||
return size_err
|
||||
|
||||
scope_err = _validate_memory_scope(content, scope)
|
||||
if scope_err:
|
||||
return scope_err
|
||||
|
||||
# --- persist ---
|
||||
try:
|
||||
apply_fn(content)
|
||||
|
|
@ -270,6 +323,7 @@ def create_update_memory_tool(
|
|||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update user memory: %s", e)
|
||||
|
|
@ -319,6 +373,7 @@ def create_update_team_memory_tool(
|
|||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update team memory: %s", e)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue