mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
feat: implement team memory extraction and validation mechanisms, enhancing memory management by enforcing scope restrictions and improving memory persistence for shared threads
This commit is contained in:
parent
33626d4f91
commit
a0883d2ab6
8 changed files with 322 additions and 49 deletions
|
|
@ -1,11 +1,8 @@
|
|||
"""Background memory extraction for the SurfSense agent.
|
||||
|
||||
After each agent response, if the agent did not call ``update_memory`` during
|
||||
the turn, this module runs a lightweight LLM call to decide whether the user's
|
||||
message contains any long-term information worth persisting.
|
||||
|
||||
Only user (personal) memory is handled here — team memory relies on explicit
|
||||
agent calls.
|
||||
the turn, this module can run a lightweight LLM call to decide whether the
|
||||
latest message contains long-term information worth persisting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -18,7 +15,7 @@ from langchain_core.messages import HumanMessage
|
|||
from sqlalchemy import select
|
||||
|
||||
from app.agents.new_chat.tools.update_memory import _save_memory
|
||||
from app.db import User, shielded_async_session
|
||||
from app.db import SearchSpace, User, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -55,6 +52,51 @@ If nothing is worth remembering, output exactly: NO_UPDATE
|
|||
{user_message}
|
||||
</user_message>"""
|
||||
|
||||
_TEAM_MEMORY_EXTRACT_PROMPT = """\
|
||||
You are a team-memory extraction assistant. Analyze the latest message and \
|
||||
decide if it contains durable TEAM-level information worth persisting.
|
||||
|
||||
High-precision rule: if uncertain, output NO_UPDATE.
|
||||
|
||||
Worth remembering (team-level only):
|
||||
- Explicit decisions (e.g. "we decided to use X")
|
||||
- Team conventions/standards (naming, review policy, coding norms)
|
||||
- Long-lived architecture/process facts
|
||||
- Stable project constraints, owners, recurring schedules
|
||||
|
||||
NOT worth remembering:
|
||||
- Personal preferences or biography of one person
|
||||
- Questions, brainstorming, tentative ideas, or speculation
|
||||
- One-off requests, status updates, TODOs, logistics for this session
|
||||
- Anything not clearly adopted by the team
|
||||
|
||||
If the message contains memorizable team information, output the FULL updated \
|
||||
team memory document with new facts merged into existing content. Follow rules:
|
||||
- Use the same ## section structure as the existing memory.
|
||||
- Keep entries as single concise bullet points (under 120 chars each).
|
||||
- Every bullet MUST start with a (YYYY-MM-DD) date prefix.
|
||||
- If a new fact contradicts an existing entry, update the existing entry.
|
||||
- Do not duplicate existing information.
|
||||
- NEVER use personal sections like "## About the user", "## Preferences", \
|
||||
or "## Instructions".
|
||||
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
|
||||
- Standard sections: "## Team decisions", "## Team conventions", \
|
||||
"## Key facts", "## Current priorities"
|
||||
|
||||
If nothing is worth remembering, output exactly: NO_UPDATE
|
||||
|
||||
<current_team_memory>
|
||||
{current_memory}
|
||||
</current_team_memory>
|
||||
|
||||
<latest_message_author>
|
||||
{author}
|
||||
</latest_message_author>
|
||||
|
||||
<latest_message>
|
||||
{user_message}
|
||||
</latest_message>"""
|
||||
|
||||
|
||||
async def extract_and_save_memory(
|
||||
*,
|
||||
|
|
@ -105,6 +147,7 @@ async def extract_and_save_memory(
|
|||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
logger.info(
|
||||
"Background memory extraction for user %s: %s",
|
||||
|
|
@ -113,3 +156,69 @@ async def extract_and_save_memory(
|
|||
)
|
||||
except Exception:
|
||||
logger.exception("Background user memory extraction failed")
|
||||
|
||||
|
||||
async def extract_and_save_team_memory(
|
||||
*,
|
||||
user_message: str,
|
||||
search_space_id: int | None,
|
||||
llm: Any,
|
||||
author_display_name: str | None = None,
|
||||
) -> None:
|
||||
"""Background task: extract team-level memory and persist it.
|
||||
|
||||
Runs only for shared threads. Designed to be fire-and-forget and catches
|
||||
exceptions internally.
|
||||
"""
|
||||
if not search_space_id:
|
||||
return
|
||||
|
||||
try:
|
||||
async with shielded_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return
|
||||
|
||||
old_memory = space.shared_memory_md
|
||||
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
|
||||
current_memory=old_memory or "(empty)",
|
||||
author=author_display_name or "Unknown team member",
|
||||
user_message=user_message,
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
|
||||
)
|
||||
text = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
).strip()
|
||||
|
||||
if text == "NO_UPDATE" or not text:
|
||||
logger.debug(
|
||||
"Team memory extraction: no update needed (space %s)",
|
||||
search_space_id,
|
||||
)
|
||||
return
|
||||
|
||||
save_result = await _save_memory(
|
||||
updated_memory=text,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
logger.info(
|
||||
"Background team memory extraction for space %s: %s",
|
||||
search_space_id,
|
||||
save_result.get("status"),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Background team memory extraction failed")
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
"""Memory injection middleware for the SurfSense agent.
|
||||
|
||||
Loads the user's personal memory (User.memory_md) and, for shared threads,
|
||||
the team memory (SearchSpace.shared_memory_md) from the database and injects
|
||||
them into the system prompt as <user_memory> / <team_memory> XML blocks on
|
||||
every turn. This ensures the LLM always has the full memory context without
|
||||
requiring a tool call.
|
||||
Injects memory markdown into the system prompt on every turn:
|
||||
- Private threads: only personal memory (<user_memory>)
|
||||
- Shared threads: only team memory (<team_memory>)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -58,7 +56,25 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
memory_blocks: list[str] = []
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
if self.user_id is not None:
|
||||
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
||||
team_memory = await self._load_team_memory(session)
|
||||
if team_memory:
|
||||
chars = len(team_memory)
|
||||
memory_blocks.append(
|
||||
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||
f"{team_memory}\n"
|
||||
f"</team_memory>"
|
||||
)
|
||||
if chars > MEMORY_SOFT_LIMIT:
|
||||
memory_blocks.append(
|
||||
f"<memory_warning>Team memory is at "
|
||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||
f"the hard limit. On your next update_memory call, consolidate "
|
||||
f"by merging duplicates, removing outdated entries, and "
|
||||
f"shortening descriptions before adding anything new."
|
||||
f"</memory_warning>"
|
||||
)
|
||||
elif self.user_id is not None:
|
||||
user_memory, display_name = await self._load_user_memory(session)
|
||||
if display_name:
|
||||
first_name = display_name.split()[0]
|
||||
|
|
@ -80,25 +96,6 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
f"</memory_warning>"
|
||||
)
|
||||
|
||||
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
||||
team_memory = await self._load_team_memory(session)
|
||||
if team_memory:
|
||||
chars = len(team_memory)
|
||||
memory_blocks.append(
|
||||
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||
f"{team_memory}\n"
|
||||
f"</team_memory>"
|
||||
)
|
||||
if chars > MEMORY_SOFT_LIMIT:
|
||||
memory_blocks.append(
|
||||
f"<memory_warning>Team memory is at "
|
||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||
f"the hard limit. On your next update_memory call, consolidate "
|
||||
f"by merging duplicates, removing outdated entries, and "
|
||||
f"shortening descriptions before adding anything new."
|
||||
f"</memory_warning>"
|
||||
)
|
||||
|
||||
if not memory_blocks:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -284,9 +284,13 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
|||
- Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated.
|
||||
- Keep it concise and well under the character limit shown in <user_memory>.
|
||||
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
|
||||
## About the user — role, background, company
|
||||
## Preferences — languages, tools, frameworks, response style
|
||||
## Instructions — standing instructions, things to always/never do
|
||||
## About the user
|
||||
## Preferences
|
||||
## Instructions
|
||||
- Section guidance:
|
||||
* About the user: role, background, company, durable identity context
|
||||
* Preferences: languages, tools, frameworks, response style preferences
|
||||
* Instructions: standing instructions, things to always/never do
|
||||
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
|
||||
- During consolidation, prioritize keeping: identity/instructions > preferences.
|
||||
""",
|
||||
|
|
@ -295,6 +299,8 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
|||
- Your current team memory is already in <team_memory> in your context. The `chars`
|
||||
and `limit` attributes show current usage and the maximum allowed size.
|
||||
- This is the team's curated long-term memory — decisions, conventions, key facts.
|
||||
- NEVER store personal memory in team memory (e.g. personal bio, individual
|
||||
preferences, or user-only standing instructions).
|
||||
- Call update_memory when:
|
||||
* A team member explicitly asks to remember or forget something
|
||||
* The conversation surfaces durable team decisions, conventions, or facts
|
||||
|
|
@ -308,10 +314,15 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
|||
- Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated.
|
||||
- Keep it concise and well under the character limit shown in <team_memory>.
|
||||
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
|
||||
## Team decisions — agreed-upon choices with rationale
|
||||
## Conventions — coding standards, tools, processes, naming patterns
|
||||
## Key facts — where things are, how things work, team structure
|
||||
## Current priorities — active projects, deadlines, blockers
|
||||
## Team decisions
|
||||
## Conventions
|
||||
## Key facts
|
||||
## Current priorities
|
||||
- Section guidance:
|
||||
* Team decisions: agreed choices and durable technical/product decisions
|
||||
* Conventions: coding standards, tools, processes, naming patterns
|
||||
* Key facts: stable facts about org/team/system setup
|
||||
* Current priorities: active projects, near-term goals, important blockers
|
||||
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
|
||||
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
|
||||
""",
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
|
@ -34,6 +34,15 @@ MEMORY_SOFT_LIMIT = 18_000
|
|||
MEMORY_HARD_LIMIT = 25_000
|
||||
|
||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
||||
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
|
||||
|
||||
_USER_ONLY_HEADINGS = {"about the user", "preferences", "instructions"}
|
||||
_TEAM_ONLY_HEADINGS = {
|
||||
"team decisions",
|
||||
"conventions",
|
||||
"key facts",
|
||||
"current priorities",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -46,6 +55,45 @@ def _extract_headings(memory: str) -> set[str]:
|
|||
return set(_SECTION_HEADING_RE.findall(memory))
|
||||
|
||||
|
||||
def _normalize_heading(heading: str) -> str:
|
||||
"""Normalize heading text for robust scope checks."""
|
||||
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
|
||||
|
||||
|
||||
def _validate_memory_scope(
|
||||
content: str, scope: Literal["user", "team"]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Reject cross-scope headings (user sections in team memory and vice versa)."""
|
||||
headings = {_normalize_heading(h) for h in _extract_headings(content)}
|
||||
if not headings:
|
||||
return None
|
||||
|
||||
if scope == "team":
|
||||
leaked = sorted(headings & _USER_ONLY_HEADINGS)
|
||||
if leaked:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Team memory cannot include personal sections: "
|
||||
+ ", ".join(leaked)
|
||||
+ ". Use team sections only."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
leaked = sorted(headings & _TEAM_ONLY_HEADINGS)
|
||||
if leaked:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"User memory cannot include team sections: "
|
||||
+ ", ".join(leaked)
|
||||
+ ". Use personal sections only."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||
"""Return a list of warning strings about suspicious changes."""
|
||||
if not old_memory:
|
||||
|
|
@ -166,6 +214,7 @@ async def _save_memory(
|
|||
commit_fn,
|
||||
rollback_fn,
|
||||
label: str,
|
||||
scope: Literal["user", "team"],
|
||||
) -> dict[str, Any]:
|
||||
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
||||
return a response dict.
|
||||
|
|
@ -200,6 +249,10 @@ async def _save_memory(
|
|||
if size_err:
|
||||
return size_err
|
||||
|
||||
scope_err = _validate_memory_scope(content, scope)
|
||||
if scope_err:
|
||||
return scope_err
|
||||
|
||||
# --- persist ---
|
||||
try:
|
||||
apply_fn(content)
|
||||
|
|
@ -270,6 +323,7 @@ def create_update_memory_tool(
|
|||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update user memory: %s", e)
|
||||
|
|
@ -319,6 +373,7 @@ def create_update_team_memory_tool(
|
|||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update team memory: %s", e)
|
||||
|
|
|
|||
|
|
@ -132,6 +132,7 @@ async def edit_user_memory(
|
|||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
|
||||
if result.get("status") == "error":
|
||||
|
|
|
|||
|
|
@ -56,7 +56,9 @@ RULES:
|
|||
2. If the instruction asks to remove something, remove the matching entry.
|
||||
3. If the instruction asks to change something, update the matching entry.
|
||||
4. Preserve the existing ## section structure and all other entries.
|
||||
5. Output ONLY the updated markdown — no explanations, no wrapping.
|
||||
5. NEVER use personal sections like "## About the user", "## Preferences", or
|
||||
"## Instructions". Team memory must stay team-scoped.
|
||||
6. Output ONLY the updated markdown — no explanations, no wrapping.
|
||||
|
||||
<current_memory>
|
||||
{current_memory}
|
||||
|
|
@ -372,6 +374,7 @@ async def edit_team_memory(
|
|||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
|
||||
if save_result.get("status") == "error":
|
||||
|
|
|
|||
|
|
@ -37,7 +37,10 @@ from app.agents.new_chat.llm_config import (
|
|||
load_agent_config,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.agents.new_chat.memory_extraction import extract_and_save_memory
|
||||
from app.agents.new_chat.memory_extraction import (
|
||||
extract_and_save_memory,
|
||||
extract_and_save_team_memory,
|
||||
)
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
|
|
@ -1545,15 +1548,26 @@ async def stream_new_chat(
|
|||
chat_id, generated_title
|
||||
)
|
||||
|
||||
# Fire background memory extraction if the agent didn't handle it
|
||||
if not stream_result.agent_called_update_memory and user_id:
|
||||
asyncio.create_task(
|
||||
extract_and_save_memory(
|
||||
user_message=user_query,
|
||||
user_id=user_id,
|
||||
llm=llm,
|
||||
# Fire background memory extraction if the agent didn't handle it.
|
||||
# Shared threads write to team memory; private threads write to user memory.
|
||||
if not stream_result.agent_called_update_memory:
|
||||
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||
asyncio.create_task(
|
||||
extract_and_save_team_memory(
|
||||
user_message=user_query,
|
||||
search_space_id=search_space_id,
|
||||
llm=llm,
|
||||
author_display_name=current_user_display_name,
|
||||
)
|
||||
)
|
||||
elif user_id:
|
||||
asyncio.create_task(
|
||||
extract_and_save_memory(
|
||||
user_message=user_query,
|
||||
user_id=user_id,
|
||||
llm=llm,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Finish the step and message
|
||||
yield streaming_service.format_finish_step()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,83 @@
|
|||
"""Unit tests for memory scope validation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.tools.update_memory import _save_memory, _validate_memory_scope
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _Recorder:
|
||||
def __init__(self) -> None:
|
||||
self.applied_content: str | None = None
|
||||
self.commit_calls = 0
|
||||
self.rollback_calls = 0
|
||||
|
||||
def apply(self, content: str) -> None:
|
||||
self.applied_content = content
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commit_calls += 1
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rollback_calls += 1
|
||||
|
||||
|
||||
def test_validate_memory_scope_rejects_user_sections_in_team_scope() -> None:
|
||||
content = "## About the user\n- (2026-04-10) Student studying DSA\n"
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is not None
|
||||
assert result["status"] == "error"
|
||||
assert "personal sections" in result["message"]
|
||||
|
||||
|
||||
def test_validate_memory_scope_rejects_team_sections_in_user_scope() -> None:
|
||||
content = "## Team decisions\n- (2026-04-10) Python-first backend policy\n"
|
||||
result = _validate_memory_scope(content, "user")
|
||||
assert result is not None
|
||||
assert result["status"] == "error"
|
||||
assert "team sections" in result["message"]
|
||||
|
||||
|
||||
def test_validate_memory_scope_normalizes_heading_case_and_spacing() -> None:
|
||||
content = "## About The User \n- (2026-04-10) Student\n"
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is not None
|
||||
assert result["status"] == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_blocks_cross_scope_write_before_commit() -> None:
|
||||
recorder = _Recorder()
|
||||
result = await _save_memory(
|
||||
updated_memory="## About the user\n- (2026-04-10) Student\n",
|
||||
old_memory=None,
|
||||
llm=None,
|
||||
apply_fn=recorder.apply,
|
||||
commit_fn=recorder.commit,
|
||||
rollback_fn=recorder.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
assert result["status"] == "error"
|
||||
assert recorder.commit_calls == 0
|
||||
assert recorder.applied_content is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_allows_valid_scope_and_commits() -> None:
|
||||
recorder = _Recorder()
|
||||
content = "## Team decisions\n- (2026-04-10) Python-first backend policy\n"
|
||||
result = await _save_memory(
|
||||
updated_memory=content,
|
||||
old_memory=None,
|
||||
llm=None,
|
||||
apply_fn=recorder.apply,
|
||||
commit_fn=recorder.commit,
|
||||
rollback_fn=recorder.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
assert result["status"] == "saved"
|
||||
assert recorder.commit_calls == 1
|
||||
assert recorder.applied_content == content
|
||||
Loading…
Add table
Add a link
Reference in a new issue