feat: implement team memory extraction and validation mechanisms, enhancing memory management by enforcing scope restrictions and improving memory persistence for shared threads

This commit is contained in:
Anish Sarkar 2026-04-10 01:54:00 +05:30
parent 33626d4f91
commit a0883d2ab6
8 changed files with 322 additions and 49 deletions

View file

@ -18,7 +18,7 @@ from __future__ import annotations
import logging
import re
from typing import Any
from typing import Any, Literal
from uuid import UUID
from langchain_core.messages import HumanMessage
@ -34,6 +34,15 @@ MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
_USER_ONLY_HEADINGS = {"about the user", "preferences", "instructions"}
_TEAM_ONLY_HEADINGS = {
"team decisions",
"conventions",
"key facts",
"current priorities",
}
# ---------------------------------------------------------------------------
@ -46,6 +55,45 @@ def _extract_headings(memory: str) -> set[str]:
return set(_SECTION_HEADING_RE.findall(memory))
def _normalize_heading(heading: str) -> str:
"""Normalize heading text for robust scope checks."""
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
def _validate_memory_scope(
content: str, scope: Literal["user", "team"]
) -> dict[str, Any] | None:
"""Reject cross-scope headings (user sections in team memory and vice versa)."""
headings = {_normalize_heading(h) for h in _extract_headings(content)}
if not headings:
return None
if scope == "team":
leaked = sorted(headings & _USER_ONLY_HEADINGS)
if leaked:
return {
"status": "error",
"message": (
"Team memory cannot include personal sections: "
+ ", ".join(leaked)
+ ". Use team sections only."
),
}
return None
leaked = sorted(headings & _TEAM_ONLY_HEADINGS)
if leaked:
return {
"status": "error",
"message": (
"User memory cannot include team sections: "
+ ", ".join(leaked)
+ ". Use personal sections only."
),
}
return None
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
"""Return a list of warning strings about suspicious changes."""
if not old_memory:
@ -166,6 +214,7 @@ async def _save_memory(
commit_fn,
rollback_fn,
label: str,
scope: Literal["user", "team"],
) -> dict[str, Any]:
"""Validate, optionally force-rewrite if over the hard limit, save, and
return a response dict.
@ -200,6 +249,10 @@ async def _save_memory(
if size_err:
return size_err
scope_err = _validate_memory_scope(content, scope)
if scope_err:
return scope_err
# --- persist ---
try:
apply_fn(content)
@ -270,6 +323,7 @@ def create_update_memory_tool(
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="memory",
scope="user",
)
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
@ -319,6 +373,7 @@ def create_update_team_memory_tool(
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="team memory",
scope="team",
)
except Exception as e:
logger.exception("Failed to update team memory: %s", e)