feat: implement team memory extraction and validation mechanisms, enhancing memory management by enforcing scope restrictions and improving memory persistence for shared threads

This commit is contained in:
Anish Sarkar 2026-04-10 01:54:00 +05:30
parent 33626d4f91
commit a0883d2ab6
8 changed files with 322 additions and 49 deletions

View file

@ -1,11 +1,8 @@
"""Background memory extraction for the SurfSense agent. """Background memory extraction for the SurfSense agent.
After each agent response, if the agent did not call ``update_memory`` during After each agent response, if the agent did not call ``update_memory`` during
the turn, this module runs a lightweight LLM call to decide whether the user's the turn, this module can run a lightweight LLM call to decide whether the
message contains any long-term information worth persisting. latest message contains long-term information worth persisting.
Only user (personal) memory is handled here team memory relies on explicit
agent calls.
""" """
from __future__ import annotations from __future__ import annotations
@ -18,7 +15,7 @@ from langchain_core.messages import HumanMessage
from sqlalchemy import select from sqlalchemy import select
from app.agents.new_chat.tools.update_memory import _save_memory from app.agents.new_chat.tools.update_memory import _save_memory
from app.db import User, shielded_async_session from app.db import SearchSpace, User, shielded_async_session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,6 +52,51 @@ If nothing is worth remembering, output exactly: NO_UPDATE
{user_message} {user_message}
</user_message>""" </user_message>"""
_TEAM_MEMORY_EXTRACT_PROMPT = """\
You are a team-memory extraction assistant. Analyze the latest message and \
decide if it contains durable TEAM-level information worth persisting.
High-precision rule: if uncertain, output NO_UPDATE.
Worth remembering (team-level only):
- Explicit decisions (e.g. "we decided to use X")
- Team conventions/standards (naming, review policy, coding norms)
- Long-lived architecture/process facts
- Stable project constraints, owners, recurring schedules
NOT worth remembering:
- Personal preferences or biography of one person
- Questions, brainstorming, tentative ideas, or speculation
- One-off requests, status updates, TODOs, logistics for this session
- Anything not clearly adopted by the team
If the message contains memorizable team information, output the FULL updated \
team memory document with new facts merged into existing content. Follow rules:
- Use the same ## section structure as the existing memory.
- Keep entries as single concise bullet points (under 120 chars each).
- Every bullet MUST start with a (YYYY-MM-DD) date prefix.
- If a new fact contradicts an existing entry, update the existing entry.
- Do not duplicate existing information.
- NEVER use personal sections like "## About the user", "## Preferences", \
or "## Instructions".
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
- Standard sections: "## Team decisions", "## Team conventions", \
"## Key facts", "## Current priorities"
If nothing is worth remembering, output exactly: NO_UPDATE
<current_team_memory>
{current_memory}
</current_team_memory>
<latest_message_author>
{author}
</latest_message_author>
<latest_message>
{user_message}
</latest_message>"""
async def extract_and_save_memory( async def extract_and_save_memory(
*, *,
@ -105,6 +147,7 @@ async def extract_and_save_memory(
commit_fn=session.commit, commit_fn=session.commit,
rollback_fn=session.rollback, rollback_fn=session.rollback,
label="memory", label="memory",
scope="user",
) )
logger.info( logger.info(
"Background memory extraction for user %s: %s", "Background memory extraction for user %s: %s",
@ -113,3 +156,69 @@ async def extract_and_save_memory(
) )
except Exception: except Exception:
logger.exception("Background user memory extraction failed") logger.exception("Background user memory extraction failed")
async def extract_and_save_team_memory(
*,
user_message: str,
search_space_id: int | None,
llm: Any,
author_display_name: str | None = None,
) -> None:
"""Background task: extract team-level memory and persist it.
Runs only for shared threads. Designed to be fire-and-forget and catches
exceptions internally.
"""
if not search_space_id:
return
try:
async with shielded_async_session() as session:
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return
old_memory = space.shared_memory_md
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
current_memory=old_memory or "(empty)",
author=author_display_name or "Unknown team member",
user_message=user_message,
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
)
text = (
response.content
if isinstance(response.content, str)
else str(response.content)
).strip()
if text == "NO_UPDATE" or not text:
logger.debug(
"Team memory extraction: no update needed (space %s)",
search_space_id,
)
return
save_result = await _save_memory(
updated_memory=text,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="team memory",
scope="team",
)
logger.info(
"Background team memory extraction for space %s: %s",
search_space_id,
save_result.get("status"),
)
except Exception:
logger.exception("Background team memory extraction failed")

View file

@ -1,10 +1,8 @@
"""Memory injection middleware for the SurfSense agent. """Memory injection middleware for the SurfSense agent.
Loads the user's personal memory (User.memory_md) and, for shared threads, Injects memory markdown into the system prompt on every turn:
the team memory (SearchSpace.shared_memory_md) from the database and injects - Private threads: only personal memory (<user_memory>)
them into the system prompt as <user_memory> / <team_memory> XML blocks on - Shared threads: only team memory (<team_memory>)
every turn. This ensures the LLM always has the full memory context without
requiring a tool call.
""" """
from __future__ import annotations from __future__ import annotations
@ -58,7 +56,25 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
memory_blocks: list[str] = [] memory_blocks: list[str] = []
async with shielded_async_session() as session: async with shielded_async_session() as session:
if self.user_id is not None: if self.visibility == ChatVisibility.SEARCH_SPACE:
team_memory = await self._load_team_memory(session)
if team_memory:
chars = len(team_memory)
memory_blocks.append(
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
f"{team_memory}\n"
f"</team_memory>"
)
if chars > MEMORY_SOFT_LIMIT:
memory_blocks.append(
f"<memory_warning>Team memory is at "
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
f"the hard limit. On your next update_memory call, consolidate "
f"by merging duplicates, removing outdated entries, and "
f"shortening descriptions before adding anything new."
f"</memory_warning>"
)
elif self.user_id is not None:
user_memory, display_name = await self._load_user_memory(session) user_memory, display_name = await self._load_user_memory(session)
if display_name: if display_name:
first_name = display_name.split()[0] first_name = display_name.split()[0]
@ -80,25 +96,6 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
f"</memory_warning>" f"</memory_warning>"
) )
if self.visibility == ChatVisibility.SEARCH_SPACE:
team_memory = await self._load_team_memory(session)
if team_memory:
chars = len(team_memory)
memory_blocks.append(
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
f"{team_memory}\n"
f"</team_memory>"
)
if chars > MEMORY_SOFT_LIMIT:
memory_blocks.append(
f"<memory_warning>Team memory is at "
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
f"the hard limit. On your next update_memory call, consolidate "
f"by merging duplicates, removing outdated entries, and "
f"shortening descriptions before adding anything new."
f"</memory_warning>"
)
if not memory_blocks: if not memory_blocks:
return None return None

View file

@ -284,9 +284,13 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
- Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated. - Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated.
- Keep it concise and well under the character limit shown in <user_memory>. - Keep it concise and well under the character limit shown in <user_memory>.
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit): - You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
## About the user — role, background, company ## About the user
## Preferences — languages, tools, frameworks, response style ## Preferences
## Instructions — standing instructions, things to always/never do ## Instructions
- Section guidance:
* About the user: role, background, company, durable identity context
* Preferences: languages, tools, frameworks, response style preferences
* Instructions: standing instructions, things to always/never do
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each). - Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
- During consolidation, prioritize keeping: identity/instructions > preferences. - During consolidation, prioritize keeping: identity/instructions > preferences.
""", """,
@ -295,6 +299,8 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
- Your current team memory is already in <team_memory> in your context. The `chars` - Your current team memory is already in <team_memory> in your context. The `chars`
and `limit` attributes show current usage and the maximum allowed size. and `limit` attributes show current usage and the maximum allowed size.
- This is the team's curated long-term memory — decisions, conventions, key facts. - This is the team's curated long-term memory — decisions, conventions, key facts.
- NEVER store personal memory in team memory (e.g. personal bio, individual
preferences, or user-only standing instructions).
- Call update_memory when: - Call update_memory when:
* A team member explicitly asks to remember or forget something * A team member explicitly asks to remember or forget something
* The conversation surfaces durable team decisions, conventions, or facts * The conversation surfaces durable team decisions, conventions, or facts
@ -308,10 +314,15 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
- Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated. - Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated.
- Keep it concise and well under the character limit shown in <team_memory>. - Keep it concise and well under the character limit shown in <team_memory>.
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit): - You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
## Team decisions — agreed-upon choices with rationale ## Team decisions
## Conventions — coding standards, tools, processes, naming patterns ## Conventions
## Key facts — where things are, how things work, team structure ## Key facts
## Current priorities — active projects, deadlines, blockers ## Current priorities
- Section guidance:
* Team decisions: agreed choices and durable technical/product decisions
* Conventions: coding standards, tools, processes, naming patterns
* Key facts: stable facts about org/team/system setup
* Current priorities: active projects, near-term goals, important blockers
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each). - Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities. - During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
""", """,

View file

@ -18,7 +18,7 @@ from __future__ import annotations
import logging import logging
import re import re
from typing import Any from typing import Any, Literal
from uuid import UUID from uuid import UUID
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
@ -34,6 +34,15 @@ MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000 MEMORY_HARD_LIMIT = 25_000
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE) _SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
_USER_ONLY_HEADINGS = {"about the user", "preferences", "instructions"}
_TEAM_ONLY_HEADINGS = {
"team decisions",
"conventions",
"key facts",
"current priorities",
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -46,6 +55,45 @@ def _extract_headings(memory: str) -> set[str]:
return set(_SECTION_HEADING_RE.findall(memory)) return set(_SECTION_HEADING_RE.findall(memory))
def _normalize_heading(heading: str) -> str:
"""Normalize heading text for robust scope checks."""
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
def _validate_memory_scope(
content: str, scope: Literal["user", "team"]
) -> dict[str, Any] | None:
"""Reject cross-scope headings (user sections in team memory and vice versa)."""
headings = {_normalize_heading(h) for h in _extract_headings(content)}
if not headings:
return None
if scope == "team":
leaked = sorted(headings & _USER_ONLY_HEADINGS)
if leaked:
return {
"status": "error",
"message": (
"Team memory cannot include personal sections: "
+ ", ".join(leaked)
+ ". Use team sections only."
),
}
return None
leaked = sorted(headings & _TEAM_ONLY_HEADINGS)
if leaked:
return {
"status": "error",
"message": (
"User memory cannot include team sections: "
+ ", ".join(leaked)
+ ". Use personal sections only."
),
}
return None
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]: def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
"""Return a list of warning strings about suspicious changes.""" """Return a list of warning strings about suspicious changes."""
if not old_memory: if not old_memory:
@ -166,6 +214,7 @@ async def _save_memory(
commit_fn, commit_fn,
rollback_fn, rollback_fn,
label: str, label: str,
scope: Literal["user", "team"],
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Validate, optionally force-rewrite if over the hard limit, save, and """Validate, optionally force-rewrite if over the hard limit, save, and
return a response dict. return a response dict.
@ -200,6 +249,10 @@ async def _save_memory(
if size_err: if size_err:
return size_err return size_err
scope_err = _validate_memory_scope(content, scope)
if scope_err:
return scope_err
# --- persist --- # --- persist ---
try: try:
apply_fn(content) apply_fn(content)
@ -270,6 +323,7 @@ def create_update_memory_tool(
commit_fn=db_session.commit, commit_fn=db_session.commit,
rollback_fn=db_session.rollback, rollback_fn=db_session.rollback,
label="memory", label="memory",
scope="user",
) )
except Exception as e: except Exception as e:
logger.exception("Failed to update user memory: %s", e) logger.exception("Failed to update user memory: %s", e)
@ -319,6 +373,7 @@ def create_update_team_memory_tool(
commit_fn=db_session.commit, commit_fn=db_session.commit,
rollback_fn=db_session.rollback, rollback_fn=db_session.rollback,
label="team memory", label="team memory",
scope="team",
) )
except Exception as e: except Exception as e:
logger.exception("Failed to update team memory: %s", e) logger.exception("Failed to update team memory: %s", e)

View file

@ -132,6 +132,7 @@ async def edit_user_memory(
commit_fn=session.commit, commit_fn=session.commit,
rollback_fn=session.rollback, rollback_fn=session.rollback,
label="memory", label="memory",
scope="user",
) )
if result.get("status") == "error": if result.get("status") == "error":

View file

@ -56,7 +56,9 @@ RULES:
2. If the instruction asks to remove something, remove the matching entry. 2. If the instruction asks to remove something, remove the matching entry.
3. If the instruction asks to change something, update the matching entry. 3. If the instruction asks to change something, update the matching entry.
4. Preserve the existing ## section structure and all other entries. 4. Preserve the existing ## section structure and all other entries.
5. Output ONLY the updated markdown no explanations, no wrapping. 5. NEVER use personal sections like "## About the user", "## Preferences", or
"## Instructions". Team memory must stay team-scoped.
6. Output ONLY the updated markdown no explanations, no wrapping.
<current_memory> <current_memory>
{current_memory} {current_memory}
@ -372,6 +374,7 @@ async def edit_team_memory(
commit_fn=session.commit, commit_fn=session.commit,
rollback_fn=session.rollback, rollback_fn=session.rollback,
label="team memory", label="team memory",
scope="team",
) )
if save_result.get("status") == "error": if save_result.get("status") == "error":

View file

@ -37,7 +37,10 @@ from app.agents.new_chat.llm_config import (
load_agent_config, load_agent_config,
load_llm_config_from_yaml, load_llm_config_from_yaml,
) )
from app.agents.new_chat.memory_extraction import extract_and_save_memory from app.agents.new_chat.memory_extraction import (
extract_and_save_memory,
extract_and_save_team_memory,
)
from app.db import ( from app.db import (
ChatVisibility, ChatVisibility,
NewChatMessage, NewChatMessage,
@ -1545,15 +1548,26 @@ async def stream_new_chat(
chat_id, generated_title chat_id, generated_title
) )
# Fire background memory extraction if the agent didn't handle it # Fire background memory extraction if the agent didn't handle it.
if not stream_result.agent_called_update_memory and user_id: # Shared threads write to team memory; private threads write to user memory.
asyncio.create_task( if not stream_result.agent_called_update_memory:
extract_and_save_memory( if visibility == ChatVisibility.SEARCH_SPACE:
user_message=user_query, asyncio.create_task(
user_id=user_id, extract_and_save_team_memory(
llm=llm, user_message=user_query,
search_space_id=search_space_id,
llm=llm,
author_display_name=current_user_display_name,
)
)
elif user_id:
asyncio.create_task(
extract_and_save_memory(
user_message=user_query,
user_id=user_id,
llm=llm,
)
) )
)
# Finish the step and message # Finish the step and message
yield streaming_service.format_finish_step() yield streaming_service.format_finish_step()

View file

@ -0,0 +1,83 @@
"""Unit tests for memory scope validation."""
import pytest
from app.agents.new_chat.tools.update_memory import _save_memory, _validate_memory_scope
pytestmark = pytest.mark.unit
class _Recorder:
def __init__(self) -> None:
self.applied_content: str | None = None
self.commit_calls = 0
self.rollback_calls = 0
def apply(self, content: str) -> None:
self.applied_content = content
async def commit(self) -> None:
self.commit_calls += 1
async def rollback(self) -> None:
self.rollback_calls += 1
def test_validate_memory_scope_rejects_user_sections_in_team_scope() -> None:
content = "## About the user\n- (2026-04-10) Student studying DSA\n"
result = _validate_memory_scope(content, "team")
assert result is not None
assert result["status"] == "error"
assert "personal sections" in result["message"]
def test_validate_memory_scope_rejects_team_sections_in_user_scope() -> None:
content = "## Team decisions\n- (2026-04-10) Python-first backend policy\n"
result = _validate_memory_scope(content, "user")
assert result is not None
assert result["status"] == "error"
assert "team sections" in result["message"]
def test_validate_memory_scope_normalizes_heading_case_and_spacing() -> None:
content = "## About The User \n- (2026-04-10) Student\n"
result = _validate_memory_scope(content, "team")
assert result is not None
assert result["status"] == "error"
@pytest.mark.asyncio
async def test_save_memory_blocks_cross_scope_write_before_commit() -> None:
recorder = _Recorder()
result = await _save_memory(
updated_memory="## About the user\n- (2026-04-10) Student\n",
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="team memory",
scope="team",
)
assert result["status"] == "error"
assert recorder.commit_calls == 0
assert recorder.applied_content is None
@pytest.mark.asyncio
async def test_save_memory_allows_valid_scope_and_commits() -> None:
recorder = _Recorder()
content = "## Team decisions\n- (2026-04-10) Python-first backend policy\n"
result = await _save_memory(
updated_memory=content,
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="team memory",
scope="team",
)
assert result["status"] == "saved"
assert recorder.commit_calls == 1
assert recorder.applied_content == content