feat: implement team memory extraction and validation mechanisms, enhancing memory management by enforcing scope restrictions and improving memory persistence for shared threads

This commit is contained in:
Anish Sarkar 2026-04-10 01:54:00 +05:30
parent 33626d4f91
commit a0883d2ab6
8 changed files with 322 additions and 49 deletions

View file

@ -1,11 +1,8 @@
"""Background memory extraction for the SurfSense agent.
After each agent response, if the agent did not call ``update_memory`` during
the turn, this module runs a lightweight LLM call to decide whether the user's
message contains any long-term information worth persisting.
Only user (personal) memory is handled here team memory relies on explicit
agent calls.
the turn, this module can run a lightweight LLM call to decide whether the
latest message contains long-term information worth persisting.
"""
from __future__ import annotations
@ -18,7 +15,7 @@ from langchain_core.messages import HumanMessage
from sqlalchemy import select
from app.agents.new_chat.tools.update_memory import _save_memory
from app.db import User, shielded_async_session
from app.db import SearchSpace, User, shielded_async_session
logger = logging.getLogger(__name__)
@ -55,6 +52,51 @@ If nothing is worth remembering, output exactly: NO_UPDATE
{user_message}
</user_message>"""
_TEAM_MEMORY_EXTRACT_PROMPT = """\
You are a team-memory extraction assistant. Analyze the latest message and \
decide if it contains durable TEAM-level information worth persisting.
High-precision rule: if uncertain, output NO_UPDATE.
Worth remembering (team-level only):
- Explicit decisions (e.g. "we decided to use X")
- Team conventions/standards (naming, review policy, coding norms)
- Long-lived architecture/process facts
- Stable project constraints, owners, recurring schedules
NOT worth remembering:
- Personal preferences or biography of one person
- Questions, brainstorming, tentative ideas, or speculation
- One-off requests, status updates, TODOs, logistics for this session
- Anything not clearly adopted by the team
If the message contains memorizable team information, output the FULL updated \
team memory document with new facts merged into existing content. Follow rules:
- Use the same ## section structure as the existing memory.
- Keep entries as single concise bullet points (under 120 chars each).
- Every bullet MUST start with a (YYYY-MM-DD) date prefix.
- If a new fact contradicts an existing entry, update the existing entry.
- Do not duplicate existing information.
- NEVER use personal sections like "## About the user", "## Preferences", \
or "## Instructions".
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
- Standard sections: "## Team decisions", "## Team conventions", \
"## Key facts", "## Current priorities"
If nothing is worth remembering, output exactly: NO_UPDATE
<current_team_memory>
{current_memory}
</current_team_memory>
<latest_message_author>
{author}
</latest_message_author>
<latest_message>
{user_message}
</latest_message>"""
async def extract_and_save_memory(
*,
@ -105,6 +147,7 @@ async def extract_and_save_memory(
commit_fn=session.commit,
rollback_fn=session.rollback,
label="memory",
scope="user",
)
logger.info(
"Background memory extraction for user %s: %s",
@ -113,3 +156,69 @@ async def extract_and_save_memory(
)
except Exception:
logger.exception("Background user memory extraction failed")
async def extract_and_save_team_memory(
*,
user_message: str,
search_space_id: int | None,
llm: Any,
author_display_name: str | None = None,
) -> None:
"""Background task: extract team-level memory and persist it.
Runs only for shared threads. Designed to be fire-and-forget and catches
exceptions internally.
"""
if not search_space_id:
return
try:
async with shielded_async_session() as session:
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return
old_memory = space.shared_memory_md
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
current_memory=old_memory or "(empty)",
author=author_display_name or "Unknown team member",
user_message=user_message,
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
)
text = (
response.content
if isinstance(response.content, str)
else str(response.content)
).strip()
if text == "NO_UPDATE" or not text:
logger.debug(
"Team memory extraction: no update needed (space %s)",
search_space_id,
)
return
save_result = await _save_memory(
updated_memory=text,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="team memory",
scope="team",
)
logger.info(
"Background team memory extraction for space %s: %s",
search_space_id,
save_result.get("status"),
)
except Exception:
logger.exception("Background team memory extraction failed")

View file

@ -1,10 +1,8 @@
"""Memory injection middleware for the SurfSense agent.
Loads the user's personal memory (User.memory_md) and, for shared threads,
the team memory (SearchSpace.shared_memory_md) from the database and injects
them into the system prompt as <user_memory> / <team_memory> XML blocks on
every turn. This ensures the LLM always has the full memory context without
requiring a tool call.
Injects memory markdown into the system prompt on every turn:
- Private threads: only personal memory (<user_memory>)
- Shared threads: only team memory (<team_memory>)
"""
from __future__ import annotations
@ -58,7 +56,25 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
memory_blocks: list[str] = []
async with shielded_async_session() as session:
if self.user_id is not None:
if self.visibility == ChatVisibility.SEARCH_SPACE:
team_memory = await self._load_team_memory(session)
if team_memory:
chars = len(team_memory)
memory_blocks.append(
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
f"{team_memory}\n"
f"</team_memory>"
)
if chars > MEMORY_SOFT_LIMIT:
memory_blocks.append(
f"<memory_warning>Team memory is at "
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
f"the hard limit. On your next update_memory call, consolidate "
f"by merging duplicates, removing outdated entries, and "
f"shortening descriptions before adding anything new."
f"</memory_warning>"
)
elif self.user_id is not None:
user_memory, display_name = await self._load_user_memory(session)
if display_name:
first_name = display_name.split()[0]
@ -80,25 +96,6 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
f"</memory_warning>"
)
if self.visibility == ChatVisibility.SEARCH_SPACE:
team_memory = await self._load_team_memory(session)
if team_memory:
chars = len(team_memory)
memory_blocks.append(
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
f"{team_memory}\n"
f"</team_memory>"
)
if chars > MEMORY_SOFT_LIMIT:
memory_blocks.append(
f"<memory_warning>Team memory is at "
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
f"the hard limit. On your next update_memory call, consolidate "
f"by merging duplicates, removing outdated entries, and "
f"shortening descriptions before adding anything new."
f"</memory_warning>"
)
if not memory_blocks:
return None

View file

@ -284,9 +284,13 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
- Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated.
- Keep it concise and well under the character limit shown in <user_memory>.
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
## About the user — role, background, company
## Preferences — languages, tools, frameworks, response style
## Instructions — standing instructions, things to always/never do
## About the user
## Preferences
## Instructions
- Section guidance:
* About the user: role, background, company, durable identity context
* Preferences: languages, tools, frameworks, response style preferences
* Instructions: standing instructions, things to always/never do
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
- During consolidation, prioritize keeping: identity/instructions > preferences.
""",
@ -295,6 +299,8 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
- Your current team memory is already in <team_memory> in your context. The `chars`
and `limit` attributes show current usage and the maximum allowed size.
- This is the team's curated long-term memory — decisions, conventions, key facts.
- NEVER store personal memory in team memory (e.g. personal bio, individual
preferences, or user-only standing instructions).
- Call update_memory when:
* A team member explicitly asks to remember or forget something
* The conversation surfaces durable team decisions, conventions, or facts
@ -308,10 +314,15 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
- Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated.
- Keep it concise and well under the character limit shown in <team_memory>.
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
## Team decisions — agreed-upon choices with rationale
## Conventions — coding standards, tools, processes, naming patterns
## Key facts — where things are, how things work, team structure
## Current priorities — active projects, deadlines, blockers
## Team decisions
## Conventions
## Key facts
## Current priorities
- Section guidance:
* Team decisions: agreed choices and durable technical/product decisions
* Conventions: coding standards, tools, processes, naming patterns
* Key facts: stable facts about org/team/system setup
* Current priorities: active projects, near-term goals, important blockers
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
""",

View file

@ -18,7 +18,7 @@ from __future__ import annotations
import logging
import re
from typing import Any
from typing import Any, Literal
from uuid import UUID
from langchain_core.messages import HumanMessage
@ -34,6 +34,15 @@ MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
_USER_ONLY_HEADINGS = {"about the user", "preferences", "instructions"}
_TEAM_ONLY_HEADINGS = {
"team decisions",
"conventions",
"key facts",
"current priorities",
}
# ---------------------------------------------------------------------------
@ -46,6 +55,45 @@ def _extract_headings(memory: str) -> set[str]:
return set(_SECTION_HEADING_RE.findall(memory))
def _normalize_heading(heading: str) -> str:
"""Normalize heading text for robust scope checks."""
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
def _validate_memory_scope(
content: str, scope: Literal["user", "team"]
) -> dict[str, Any] | None:
"""Reject cross-scope headings (user sections in team memory and vice versa)."""
headings = {_normalize_heading(h) for h in _extract_headings(content)}
if not headings:
return None
if scope == "team":
leaked = sorted(headings & _USER_ONLY_HEADINGS)
if leaked:
return {
"status": "error",
"message": (
"Team memory cannot include personal sections: "
+ ", ".join(leaked)
+ ". Use team sections only."
),
}
return None
leaked = sorted(headings & _TEAM_ONLY_HEADINGS)
if leaked:
return {
"status": "error",
"message": (
"User memory cannot include team sections: "
+ ", ".join(leaked)
+ ". Use personal sections only."
),
}
return None
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
"""Return a list of warning strings about suspicious changes."""
if not old_memory:
@ -166,6 +214,7 @@ async def _save_memory(
commit_fn,
rollback_fn,
label: str,
scope: Literal["user", "team"],
) -> dict[str, Any]:
"""Validate, optionally force-rewrite if over the hard limit, save, and
return a response dict.
@ -200,6 +249,10 @@ async def _save_memory(
if size_err:
return size_err
scope_err = _validate_memory_scope(content, scope)
if scope_err:
return scope_err
# --- persist ---
try:
apply_fn(content)
@ -270,6 +323,7 @@ def create_update_memory_tool(
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="memory",
scope="user",
)
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
@ -319,6 +373,7 @@ def create_update_team_memory_tool(
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="team memory",
scope="team",
)
except Exception as e:
logger.exception("Failed to update team memory: %s", e)

View file

@ -132,6 +132,7 @@ async def edit_user_memory(
commit_fn=session.commit,
rollback_fn=session.rollback,
label="memory",
scope="user",
)
if result.get("status") == "error":

View file

@ -56,7 +56,9 @@ RULES:
2. If the instruction asks to remove something, remove the matching entry.
3. If the instruction asks to change something, update the matching entry.
4. Preserve the existing ## section structure and all other entries.
5. Output ONLY the updated markdown no explanations, no wrapping.
5. NEVER use personal sections like "## About the user", "## Preferences", or
"## Instructions". Team memory must stay team-scoped.
6. Output ONLY the updated markdown no explanations, no wrapping.
<current_memory>
{current_memory}
@ -372,6 +374,7 @@ async def edit_team_memory(
commit_fn=session.commit,
rollback_fn=session.rollback,
label="team memory",
scope="team",
)
if save_result.get("status") == "error":

View file

@ -37,7 +37,10 @@ from app.agents.new_chat.llm_config import (
load_agent_config,
load_llm_config_from_yaml,
)
from app.agents.new_chat.memory_extraction import extract_and_save_memory
from app.agents.new_chat.memory_extraction import (
extract_and_save_memory,
extract_and_save_team_memory,
)
from app.db import (
ChatVisibility,
NewChatMessage,
@ -1545,15 +1548,26 @@ async def stream_new_chat(
chat_id, generated_title
)
# Fire background memory extraction if the agent didn't handle it
if not stream_result.agent_called_update_memory and user_id:
asyncio.create_task(
extract_and_save_memory(
user_message=user_query,
user_id=user_id,
llm=llm,
# Fire background memory extraction if the agent didn't handle it.
# Shared threads write to team memory; private threads write to user memory.
if not stream_result.agent_called_update_memory:
if visibility == ChatVisibility.SEARCH_SPACE:
asyncio.create_task(
extract_and_save_team_memory(
user_message=user_query,
search_space_id=search_space_id,
llm=llm,
author_display_name=current_user_display_name,
)
)
elif user_id:
asyncio.create_task(
extract_and_save_memory(
user_message=user_query,
user_id=user_id,
llm=llm,
)
)
)
# Finish the step and message
yield streaming_service.format_finish_step()

View file

@ -0,0 +1,83 @@
"""Unit tests for memory scope validation."""
import pytest
from app.agents.new_chat.tools.update_memory import _save_memory, _validate_memory_scope
pytestmark = pytest.mark.unit
class _Recorder:
def __init__(self) -> None:
self.applied_content: str | None = None
self.commit_calls = 0
self.rollback_calls = 0
def apply(self, content: str) -> None:
self.applied_content = content
async def commit(self) -> None:
self.commit_calls += 1
async def rollback(self) -> None:
self.rollback_calls += 1
def test_validate_memory_scope_rejects_user_sections_in_team_scope() -> None:
content = "## About the user\n- (2026-04-10) Student studying DSA\n"
result = _validate_memory_scope(content, "team")
assert result is not None
assert result["status"] == "error"
assert "personal sections" in result["message"]
def test_validate_memory_scope_rejects_team_sections_in_user_scope() -> None:
content = "## Team decisions\n- (2026-04-10) Python-first backend policy\n"
result = _validate_memory_scope(content, "user")
assert result is not None
assert result["status"] == "error"
assert "team sections" in result["message"]
def test_validate_memory_scope_normalizes_heading_case_and_spacing() -> None:
content = "## About The User \n- (2026-04-10) Student\n"
result = _validate_memory_scope(content, "team")
assert result is not None
assert result["status"] == "error"
@pytest.mark.asyncio
async def test_save_memory_blocks_cross_scope_write_before_commit() -> None:
recorder = _Recorder()
result = await _save_memory(
updated_memory="## About the user\n- (2026-04-10) Student\n",
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="team memory",
scope="team",
)
assert result["status"] == "error"
assert recorder.commit_calls == 0
assert recorder.applied_content is None
@pytest.mark.asyncio
async def test_save_memory_allows_valid_scope_and_commits() -> None:
recorder = _Recorder()
content = "## Team decisions\n- (2026-04-10) Python-first backend policy\n"
result = await _save_memory(
updated_memory=content,
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="team memory",
scope="team",
)
assert result["status"] == "saved"
assert recorder.commit_calls == 1
assert recorder.applied_content == content