mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-24 21:38:09 +02:00
feat: implement team memory extraction and validation mechanisms, enhancing memory management by enforcing scope restrictions and improving memory persistence for shared threads
This commit is contained in:
parent
33626d4f91
commit
a0883d2ab6
8 changed files with 322 additions and 49 deletions
|
|
@ -1,11 +1,8 @@
|
||||||
"""Background memory extraction for the SurfSense agent.
|
"""Background memory extraction for the SurfSense agent.
|
||||||
|
|
||||||
After each agent response, if the agent did not call ``update_memory`` during
|
After each agent response, if the agent did not call ``update_memory`` during
|
||||||
the turn, this module runs a lightweight LLM call to decide whether the user's
|
the turn, this module can run a lightweight LLM call to decide whether the
|
||||||
message contains any long-term information worth persisting.
|
latest message contains long-term information worth persisting.
|
||||||
|
|
||||||
Only user (personal) memory is handled here — team memory relies on explicit
|
|
||||||
agent calls.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -18,7 +15,7 @@ from langchain_core.messages import HumanMessage
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.agents.new_chat.tools.update_memory import _save_memory
|
from app.agents.new_chat.tools.update_memory import _save_memory
|
||||||
from app.db import User, shielded_async_session
|
from app.db import SearchSpace, User, shielded_async_session
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -55,6 +52,51 @@ If nothing is worth remembering, output exactly: NO_UPDATE
|
||||||
{user_message}
|
{user_message}
|
||||||
</user_message>"""
|
</user_message>"""
|
||||||
|
|
||||||
|
# Prompt template for background team-memory extraction. It is filled with
# ``str.format`` and expects exactly three fields: ``current_memory``,
# ``author`` and ``user_message``. The standard section names listed near the
# end must match the ones the team ``update_memory`` tool instructions use
# ("## Team decisions", "## Conventions", "## Key facts",
# "## Current priorities"); otherwise the two writers of the same document
# create near-duplicate sections.
_TEAM_MEMORY_EXTRACT_PROMPT = """\
You are a team-memory extraction assistant. Analyze the latest message and \
decide if it contains durable TEAM-level information worth persisting.

High-precision rule: if uncertain, output NO_UPDATE.

Worth remembering (team-level only):
- Explicit decisions (e.g. "we decided to use X")
- Team conventions/standards (naming, review policy, coding norms)
- Long-lived architecture/process facts
- Stable project constraints, owners, recurring schedules

NOT worth remembering:
- Personal preferences or biography of one person
- Questions, brainstorming, tentative ideas, or speculation
- One-off requests, status updates, TODOs, logistics for this session
- Anything not clearly adopted by the team

If the message contains memorizable team information, output the FULL updated \
team memory document with new facts merged into existing content. Follow rules:
- Use the same ## section structure as the existing memory.
- Keep entries as single concise bullet points (under 120 chars each).
- Every bullet MUST start with a (YYYY-MM-DD) date prefix.
- If a new fact contradicts an existing entry, update the existing entry.
- Do not duplicate existing information.
- NEVER use personal sections like "## About the user", "## Preferences", \
or "## Instructions".
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
- Standard sections: "## Team decisions", "## Conventions", \
"## Key facts", "## Current priorities"

If nothing is worth remembering, output exactly: NO_UPDATE

<current_team_memory>
{current_memory}
</current_team_memory>

<latest_message_author>
{author}
</latest_message_author>

<latest_message>
{user_message}
</latest_message>"""
|
||||||
|
|
||||||
|
|
||||||
async def extract_and_save_memory(
|
async def extract_and_save_memory(
|
||||||
*,
|
*,
|
||||||
|
|
@ -105,6 +147,7 @@ async def extract_and_save_memory(
|
||||||
commit_fn=session.commit,
|
commit_fn=session.commit,
|
||||||
rollback_fn=session.rollback,
|
rollback_fn=session.rollback,
|
||||||
label="memory",
|
label="memory",
|
||||||
|
scope="user",
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Background memory extraction for user %s: %s",
|
"Background memory extraction for user %s: %s",
|
||||||
|
|
@ -113,3 +156,69 @@ async def extract_and_save_memory(
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Background user memory extraction failed")
|
logger.exception("Background user memory extraction failed")
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_and_save_team_memory(
|
||||||
|
*,
|
||||||
|
user_message: str,
|
||||||
|
search_space_id: int | None,
|
||||||
|
llm: Any,
|
||||||
|
author_display_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Background task: extract team-level memory and persist it.
|
||||||
|
|
||||||
|
Runs only for shared threads. Designed to be fire-and-forget and catches
|
||||||
|
exceptions internally.
|
||||||
|
"""
|
||||||
|
if not search_space_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||||
|
)
|
||||||
|
space = result.scalars().first()
|
||||||
|
if not space:
|
||||||
|
return
|
||||||
|
|
||||||
|
old_memory = space.shared_memory_md
|
||||||
|
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
|
||||||
|
current_memory=old_memory or "(empty)",
|
||||||
|
author=author_display_name or "Unknown team member",
|
||||||
|
user_message=user_message,
|
||||||
|
)
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[HumanMessage(content=prompt)],
|
||||||
|
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
|
||||||
|
)
|
||||||
|
text = (
|
||||||
|
response.content
|
||||||
|
if isinstance(response.content, str)
|
||||||
|
else str(response.content)
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
if text == "NO_UPDATE" or not text:
|
||||||
|
logger.debug(
|
||||||
|
"Team memory extraction: no update needed (space %s)",
|
||||||
|
search_space_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
save_result = await _save_memory(
|
||||||
|
updated_memory=text,
|
||||||
|
old_memory=old_memory,
|
||||||
|
llm=llm,
|
||||||
|
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||||
|
commit_fn=session.commit,
|
||||||
|
rollback_fn=session.rollback,
|
||||||
|
label="team memory",
|
||||||
|
scope="team",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Background team memory extraction for space %s: %s",
|
||||||
|
search_space_id,
|
||||||
|
save_result.get("status"),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Background team memory extraction failed")
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
"""Memory injection middleware for the SurfSense agent.
|
"""Memory injection middleware for the SurfSense agent.
|
||||||
|
|
||||||
Loads the user's personal memory (User.memory_md) and, for shared threads,
|
Injects memory markdown into the system prompt on every turn:
|
||||||
the team memory (SearchSpace.shared_memory_md) from the database and injects
|
- Private threads: only personal memory (<user_memory>)
|
||||||
them into the system prompt as <user_memory> / <team_memory> XML blocks on
|
- Shared threads: only team memory (<team_memory>)
|
||||||
every turn. This ensures the LLM always has the full memory context without
|
|
||||||
requiring a tool call.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -58,7 +56,25 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
memory_blocks: list[str] = []
|
memory_blocks: list[str] = []
|
||||||
|
|
||||||
async with shielded_async_session() as session:
|
async with shielded_async_session() as session:
|
||||||
if self.user_id is not None:
|
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
||||||
|
team_memory = await self._load_team_memory(session)
|
||||||
|
if team_memory:
|
||||||
|
chars = len(team_memory)
|
||||||
|
memory_blocks.append(
|
||||||
|
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||||
|
f"{team_memory}\n"
|
||||||
|
f"</team_memory>"
|
||||||
|
)
|
||||||
|
if chars > MEMORY_SOFT_LIMIT:
|
||||||
|
memory_blocks.append(
|
||||||
|
f"<memory_warning>Team memory is at "
|
||||||
|
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||||
|
f"the hard limit. On your next update_memory call, consolidate "
|
||||||
|
f"by merging duplicates, removing outdated entries, and "
|
||||||
|
f"shortening descriptions before adding anything new."
|
||||||
|
f"</memory_warning>"
|
||||||
|
)
|
||||||
|
elif self.user_id is not None:
|
||||||
user_memory, display_name = await self._load_user_memory(session)
|
user_memory, display_name = await self._load_user_memory(session)
|
||||||
if display_name:
|
if display_name:
|
||||||
first_name = display_name.split()[0]
|
first_name = display_name.split()[0]
|
||||||
|
|
@ -80,25 +96,6 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
f"</memory_warning>"
|
f"</memory_warning>"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
|
||||||
team_memory = await self._load_team_memory(session)
|
|
||||||
if team_memory:
|
|
||||||
chars = len(team_memory)
|
|
||||||
memory_blocks.append(
|
|
||||||
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
|
||||||
f"{team_memory}\n"
|
|
||||||
f"</team_memory>"
|
|
||||||
)
|
|
||||||
if chars > MEMORY_SOFT_LIMIT:
|
|
||||||
memory_blocks.append(
|
|
||||||
f"<memory_warning>Team memory is at "
|
|
||||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
|
||||||
f"the hard limit. On your next update_memory call, consolidate "
|
|
||||||
f"by merging duplicates, removing outdated entries, and "
|
|
||||||
f"shortening descriptions before adding anything new."
|
|
||||||
f"</memory_warning>"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not memory_blocks:
|
if not memory_blocks:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -284,9 +284,13 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
||||||
- Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated.
|
- Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated.
|
||||||
- Keep it concise and well under the character limit shown in <user_memory>.
|
- Keep it concise and well under the character limit shown in <user_memory>.
|
||||||
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
|
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
|
||||||
## About the user — role, background, company
|
## About the user
|
||||||
## Preferences — languages, tools, frameworks, response style
|
## Preferences
|
||||||
## Instructions — standing instructions, things to always/never do
|
## Instructions
|
||||||
|
- Section guidance:
|
||||||
|
* About the user: role, background, company, durable identity context
|
||||||
|
* Preferences: languages, tools, frameworks, response style preferences
|
||||||
|
* Instructions: standing instructions, things to always/never do
|
||||||
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
|
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
|
||||||
- During consolidation, prioritize keeping: identity/instructions > preferences.
|
- During consolidation, prioritize keeping: identity/instructions > preferences.
|
||||||
""",
|
""",
|
||||||
|
|
@ -295,6 +299,8 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
||||||
- Your current team memory is already in <team_memory> in your context. The `chars`
|
- Your current team memory is already in <team_memory> in your context. The `chars`
|
||||||
and `limit` attributes show current usage and the maximum allowed size.
|
and `limit` attributes show current usage and the maximum allowed size.
|
||||||
- This is the team's curated long-term memory — decisions, conventions, key facts.
|
- This is the team's curated long-term memory — decisions, conventions, key facts.
|
||||||
|
- NEVER store personal memory in team memory (e.g. personal bio, individual
|
||||||
|
preferences, or user-only standing instructions).
|
||||||
- Call update_memory when:
|
- Call update_memory when:
|
||||||
* A team member explicitly asks to remember or forget something
|
* A team member explicitly asks to remember or forget something
|
||||||
* The conversation surfaces durable team decisions, conventions, or facts
|
* The conversation surfaces durable team decisions, conventions, or facts
|
||||||
|
|
@ -308,10 +314,15 @@ _MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
||||||
- Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated.
|
- Every bullet MUST start with a (YYYY-MM-DD) date prefix indicating when it was recorded or last updated.
|
||||||
- Keep it concise and well under the character limit shown in <team_memory>.
|
- Keep it concise and well under the character limit shown in <team_memory>.
|
||||||
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
|
- You MUST organize memory using these standard sections (add new `##` sections only if none of the standard ones fit):
|
||||||
## Team decisions — agreed-upon choices with rationale
|
## Team decisions
|
||||||
## Conventions — coding standards, tools, processes, naming patterns
|
## Conventions
|
||||||
## Key facts — where things are, how things work, team structure
|
## Key facts
|
||||||
## Current priorities — active projects, deadlines, blockers
|
## Current priorities
|
||||||
|
- Section guidance:
|
||||||
|
* Team decisions: agreed choices and durable technical/product decisions
|
||||||
|
* Conventions: coding standards, tools, processes, naming patterns
|
||||||
|
* Key facts: stable facts about org/team/system setup
|
||||||
|
* Current priorities: active projects, near-term goals, important blockers
|
||||||
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
|
- Each entry MUST be a single bullet point. Keep entries concise (aim for under 120 chars each).
|
||||||
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
|
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
|
||||||
""",
|
""",
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
@ -34,6 +34,15 @@ MEMORY_SOFT_LIMIT = 18_000
|
||||||
MEMORY_HARD_LIMIT = 25_000
|
MEMORY_HARD_LIMIT = 25_000
|
||||||
|
|
||||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
||||||
|
# Collapses any run of whitespace to a single space when normalizing headings.
_HEADING_NORMALIZE_RE = re.compile(r"\s+")

# Normalized (lower-case, single-spaced) "##" headings reserved for one memory
# scope. Frozen so that no code path can mutate these shared lookup tables.
_USER_ONLY_HEADINGS = frozenset({"about the user", "preferences", "instructions"})
_TEAM_ONLY_HEADINGS = frozenset(
    {
        "team decisions",
        "conventions",
        "key facts",
        "current priorities",
    }
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -46,6 +55,45 @@ def _extract_headings(memory: str) -> set[str]:
|
||||||
return set(_SECTION_HEADING_RE.findall(memory))
|
return set(_SECTION_HEADING_RE.findall(memory))
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_heading(heading: str) -> str:
    """Normalize heading text for robust scope checks.

    The result is a comparison key only and is never written back to memory.
    Besides lower-casing and collapsing whitespace, it drops decoration that
    leaves a heading's meaning unchanged, so that ``**Preferences**``,
    ``Instructions:`` and ``About the user ##`` all map to the plain heading
    name.
    """
    # Leading emphasis/code markers, then trailing markers, a trailing colon
    # and an optional closing "#" sequence.
    text = heading.strip().lstrip("*_` \t").rstrip("#*_`: \t")
    # ``split()`` with no argument splits on any whitespace run and ignores
    # leading/trailing whitespace, so this collapses and trims in one step.
    return " ".join(text.lower().split())
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_memory_scope(
    content: str, scope: Literal["user", "team"]
) -> dict[str, Any] | None:
    """Reject cross-scope headings (user sections in team memory and vice versa).

    Returns ``None`` when ``content`` has no heading reserved for the other
    scope, otherwise an error dict with ``status`` and ``message`` keys naming
    the offending (normalized) headings in sorted order.
    """
    found: set[str] = set()
    for raw_heading in _extract_headings(content):
        found.add(_normalize_heading(raw_heading))

    # Pick the forbidden set and the message wording for the target scope;
    # anything other than "team" is treated as user scope.
    if scope == "team":
        forbidden = _USER_ONLY_HEADINGS
        prefix = "Team memory cannot include personal sections: "
        suffix = ". Use team sections only."
    else:
        forbidden = _TEAM_ONLY_HEADINGS
        prefix = "User memory cannot include team sections: "
        suffix = ". Use personal sections only."

    offending = sorted(found.intersection(forbidden))
    if not offending:
        return None

    return {
        "status": "error",
        "message": prefix + ", ".join(offending) + suffix,
    }
|
||||||
|
|
||||||
|
|
||||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||||
"""Return a list of warning strings about suspicious changes."""
|
"""Return a list of warning strings about suspicious changes."""
|
||||||
if not old_memory:
|
if not old_memory:
|
||||||
|
|
@ -166,6 +214,7 @@ async def _save_memory(
|
||||||
commit_fn,
|
commit_fn,
|
||||||
rollback_fn,
|
rollback_fn,
|
||||||
label: str,
|
label: str,
|
||||||
|
scope: Literal["user", "team"],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
||||||
return a response dict.
|
return a response dict.
|
||||||
|
|
@ -200,6 +249,10 @@ async def _save_memory(
|
||||||
if size_err:
|
if size_err:
|
||||||
return size_err
|
return size_err
|
||||||
|
|
||||||
|
scope_err = _validate_memory_scope(content, scope)
|
||||||
|
if scope_err:
|
||||||
|
return scope_err
|
||||||
|
|
||||||
# --- persist ---
|
# --- persist ---
|
||||||
try:
|
try:
|
||||||
apply_fn(content)
|
apply_fn(content)
|
||||||
|
|
@ -270,6 +323,7 @@ def create_update_memory_tool(
|
||||||
commit_fn=db_session.commit,
|
commit_fn=db_session.commit,
|
||||||
rollback_fn=db_session.rollback,
|
rollback_fn=db_session.rollback,
|
||||||
label="memory",
|
label="memory",
|
||||||
|
scope="user",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to update user memory: %s", e)
|
logger.exception("Failed to update user memory: %s", e)
|
||||||
|
|
@ -319,6 +373,7 @@ def create_update_team_memory_tool(
|
||||||
commit_fn=db_session.commit,
|
commit_fn=db_session.commit,
|
||||||
rollback_fn=db_session.rollback,
|
rollback_fn=db_session.rollback,
|
||||||
label="team memory",
|
label="team memory",
|
||||||
|
scope="team",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to update team memory: %s", e)
|
logger.exception("Failed to update team memory: %s", e)
|
||||||
|
|
|
||||||
|
|
@ -132,6 +132,7 @@ async def edit_user_memory(
|
||||||
commit_fn=session.commit,
|
commit_fn=session.commit,
|
||||||
rollback_fn=session.rollback,
|
rollback_fn=session.rollback,
|
||||||
label="memory",
|
label="memory",
|
||||||
|
scope="user",
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.get("status") == "error":
|
if result.get("status") == "error":
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,9 @@ RULES:
|
||||||
2. If the instruction asks to remove something, remove the matching entry.
|
2. If the instruction asks to remove something, remove the matching entry.
|
||||||
3. If the instruction asks to change something, update the matching entry.
|
3. If the instruction asks to change something, update the matching entry.
|
||||||
4. Preserve the existing ## section structure and all other entries.
|
4. Preserve the existing ## section structure and all other entries.
|
||||||
5. Output ONLY the updated markdown — no explanations, no wrapping.
|
5. NEVER use personal sections like "## About the user", "## Preferences", or
|
||||||
|
"## Instructions". Team memory must stay team-scoped.
|
||||||
|
6. Output ONLY the updated markdown — no explanations, no wrapping.
|
||||||
|
|
||||||
<current_memory>
|
<current_memory>
|
||||||
{current_memory}
|
{current_memory}
|
||||||
|
|
@ -372,6 +374,7 @@ async def edit_team_memory(
|
||||||
commit_fn=session.commit,
|
commit_fn=session.commit,
|
||||||
rollback_fn=session.rollback,
|
rollback_fn=session.rollback,
|
||||||
label="team memory",
|
label="team memory",
|
||||||
|
scope="team",
|
||||||
)
|
)
|
||||||
|
|
||||||
if save_result.get("status") == "error":
|
if save_result.get("status") == "error":
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,10 @@ from app.agents.new_chat.llm_config import (
|
||||||
load_agent_config,
|
load_agent_config,
|
||||||
load_llm_config_from_yaml,
|
load_llm_config_from_yaml,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.memory_extraction import extract_and_save_memory
|
from app.agents.new_chat.memory_extraction import (
|
||||||
|
extract_and_save_memory,
|
||||||
|
extract_and_save_team_memory,
|
||||||
|
)
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ChatVisibility,
|
ChatVisibility,
|
||||||
NewChatMessage,
|
NewChatMessage,
|
||||||
|
|
@ -1545,15 +1548,26 @@ async def stream_new_chat(
|
||||||
chat_id, generated_title
|
chat_id, generated_title
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fire background memory extraction if the agent didn't handle it
|
# Fire background memory extraction if the agent didn't handle it.
|
||||||
if not stream_result.agent_called_update_memory and user_id:
|
# Shared threads write to team memory; private threads write to user memory.
|
||||||
asyncio.create_task(
|
if not stream_result.agent_called_update_memory:
|
||||||
extract_and_save_memory(
|
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||||
user_message=user_query,
|
asyncio.create_task(
|
||||||
user_id=user_id,
|
extract_and_save_team_memory(
|
||||||
llm=llm,
|
user_message=user_query,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
llm=llm,
|
||||||
|
author_display_name=current_user_display_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif user_id:
|
||||||
|
asyncio.create_task(
|
||||||
|
extract_and_save_memory(
|
||||||
|
user_message=user_query,
|
||||||
|
user_id=user_id,
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Finish the step and message
|
# Finish the step and message
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,83 @@
|
||||||
|
"""Unit tests for memory scope validation."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.update_memory import _save_memory, _validate_memory_scope
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _Recorder:
    """Test double standing in for the apply/commit/rollback callbacks.

    Records the last applied content and counts commit and rollback calls.
    """

    def __init__(self) -> None:
        self.commit_calls = 0
        self.rollback_calls = 0
        # Stays None until ``apply`` is called.
        self.applied_content: str | None = None

    def apply(self, content: str) -> None:
        """Remember the content that would have been written."""
        self.applied_content = content

    async def commit(self) -> None:
        """Count one commit."""
        self.commit_calls = self.commit_calls + 1

    async def rollback(self) -> None:
        """Count one rollback."""
        self.rollback_calls = self.rollback_calls + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_memory_scope_rejects_user_sections_in_team_scope() -> None:
    """A personal section inside team-scoped content yields an error dict."""
    personal_doc = "## About the user\n- (2026-04-10) Student studying DSA\n"

    outcome = _validate_memory_scope(personal_doc, "team")

    assert outcome is not None
    assert outcome["status"] == "error"
    assert "personal sections" in outcome["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_memory_scope_rejects_team_sections_in_user_scope() -> None:
    """A team section inside user-scoped content yields an error dict."""
    team_doc = "## Team decisions\n- (2026-04-10) Python-first backend policy\n"

    outcome = _validate_memory_scope(team_doc, "user")

    assert outcome is not None
    assert outcome["status"] == "error"
    assert "team sections" in outcome["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_memory_scope_normalizes_heading_case_and_spacing() -> None:
    """Mixed case and irregular spacing must not bypass the scope check."""
    # Runs of spaces between the words and after the heading exercise the
    # whitespace-collapsing part of normalization, not just lower-casing.
    content = "##   About   The   User  \n- (2026-04-10) Student\n"
    result = _validate_memory_scope(content, "team")
    assert result is not None
    assert result["status"] == "error"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_save_memory_blocks_cross_scope_write_before_commit() -> None:
    """A cross-scope document is rejected before it is applied or committed."""
    sink = _Recorder()
    save_kwargs = {
        "updated_memory": "## About the user\n- (2026-04-10) Student\n",
        "old_memory": None,
        "llm": None,
        "apply_fn": sink.apply,
        "commit_fn": sink.commit,
        "rollback_fn": sink.rollback,
        "label": "team memory",
        "scope": "team",
    }

    outcome = await _save_memory(**save_kwargs)

    assert outcome["status"] == "error"
    assert sink.commit_calls == 0
    assert sink.applied_content is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_save_memory_allows_valid_scope_and_commits() -> None:
    """A correctly scoped team document is applied and committed exactly once."""
    sink = _Recorder()
    team_doc = "## Team decisions\n- (2026-04-10) Python-first backend policy\n"
    save_kwargs = {
        "updated_memory": team_doc,
        "old_memory": None,
        "llm": None,
        "apply_fn": sink.apply,
        "commit_fn": sink.commit,
        "rollback_fn": sink.rollback,
        "label": "team memory",
        "scope": "team",
    }

    outcome = await _save_memory(**save_kwargs)

    assert outcome["status"] == "saved"
    assert sink.commit_calls == 1
    assert sink.applied_content == team_doc
|
||||||
Loading…
Add table
Add a link
Reference in a new issue