Merge upstream/dev into feat/kb-export-and-folder-upload

This commit is contained in:
CREDO23 2026-04-11 10:28:40 +02:00
commit c30cc08771
61 changed files with 2670 additions and 1474 deletions

View file

@ -0,0 +1,38 @@
"""Add memory_md columns to user and searchspaces tables
Revision ID: 121
Revises: 120
Changes:
1. Add memory_md TEXT column to user table (personal memory)
2. Add shared_memory_md TEXT column to searchspaces table (team memory)
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "121"
down_revision: str | None = "120"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("memory_md", sa.Text(), nullable=True, server_default=""),
)
op.add_column(
"searchspaces",
sa.Column("shared_memory_md", sa.Text(), nullable=True, server_default=""),
)
def downgrade() -> None:
op.drop_column("searchspaces", "shared_memory_md")
op.drop_column("user", "memory_md")

View file

@ -0,0 +1,247 @@
"""Migrate row-per-fact memories to markdown, then drop legacy tables
Revision ID: 122
Revises: 121
Converts user_memories rows into per-user markdown documents stored in
user.memory_md, and shared_memories rows into per-search-space markdown
stored in searchspaces.shared_memory_md. Then drops the old tables and
the memorycategory enum.
The markdown format matches the new memory system:
## Heading
- (YYYY-MM-DD) [fact|pref|instr] memory text
"""
from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Sequence
from uuid import UUID
import sqlalchemy as sa
from sqlalchemy import inspect as sa_inspect
from alembic import op
from app.config import config
logger = logging.getLogger(__name__)
revision: str = "122"
down_revision: str | None = "121"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
EMBEDDING_DIM = config.embedding_model_instance.dimension
_CATEGORY_TO_MARKER = {
"fact": "fact",
"context": "fact",
"preference": "pref",
"instruction": "instr",
}
_CATEGORY_HEADING = {
"fact": "Facts",
"preference": "Preferences",
"instruction": "Instructions",
"context": "Context",
}
_HEADING_ORDER = ["fact", "preference", "instruction", "context"]
def _build_markdown(rows: list[tuple]) -> str:
"""Build a markdown document from (memory_text, category, created_at) rows."""
by_category: dict[str, list[str]] = defaultdict(list)
for memory_text, category, created_at in rows:
cat = str(category)
marker = _CATEGORY_TO_MARKER.get(cat, "fact")
date_str = created_at.strftime("%Y-%m-%d")
clean_text = str(memory_text).replace("\n", " ").strip()
bullet = f"- ({date_str}) [{marker}] {clean_text}"
by_category[cat].append(bullet)
sections: list[str] = []
for cat in _HEADING_ORDER:
if cat in by_category:
heading = _CATEGORY_HEADING[cat]
sections.append(f"## {heading}")
sections.extend(by_category[cat])
sections.append("")
return "\n".join(sections).strip() + "\n"
def _migrate_user_memories(conn: sa.engine.Connection) -> None:
"""Convert user_memories rows → user.memory_md grouped by user_id."""
rows = conn.execute(
sa.text(
"SELECT user_id, memory_text, category::text, created_at "
"FROM user_memories ORDER BY created_at"
)
).fetchall()
if not rows:
logger.info("user_memories is empty, skipping data migration.")
return
by_user: dict[UUID, list[tuple]] = defaultdict(list)
for user_id, memory_text, category, created_at in rows:
by_user[user_id].append((memory_text, category, created_at))
migrated = 0
for uid, user_rows in by_user.items():
existing = conn.execute(
sa.text('SELECT memory_md FROM "user" WHERE id = :uid'),
{"uid": uid},
).scalar()
if existing and existing.strip():
logger.info("User %s already has memory_md, skipping.", uid)
continue
markdown = _build_markdown(user_rows)
conn.execute(
sa.text('UPDATE "user" SET memory_md = :md WHERE id = :uid'),
{"md": markdown, "uid": uid},
)
migrated += 1
logger.info("Migrated user_memories for %d user(s).", migrated)
def _migrate_shared_memories(conn: sa.engine.Connection) -> None:
"""Convert shared_memories rows → searchspaces.shared_memory_md."""
rows = conn.execute(
sa.text(
"SELECT search_space_id, memory_text, category::text, created_at "
"FROM shared_memories ORDER BY created_at"
)
).fetchall()
if not rows:
logger.info("shared_memories is empty, skipping data migration.")
return
by_space: dict[int, list[tuple]] = defaultdict(list)
for search_space_id, memory_text, category, created_at in rows:
by_space[search_space_id].append((memory_text, category, created_at))
migrated = 0
for space_id, space_rows in by_space.items():
existing = conn.execute(
sa.text("SELECT shared_memory_md FROM searchspaces WHERE id = :sid"),
{"sid": space_id},
).scalar()
if existing and existing.strip():
logger.info(
"Search space %s already has shared_memory_md, skipping.", space_id
)
continue
markdown = _build_markdown(space_rows)
conn.execute(
sa.text("UPDATE searchspaces SET shared_memory_md = :md WHERE id = :sid"),
{"md": markdown, "sid": space_id},
)
migrated += 1
logger.info("Migrated shared_memories for %d search space(s).", migrated)
def upgrade() -> None:
conn = op.get_bind()
inspector = sa_inspect(conn)
tables = inspector.get_table_names()
if "user_memories" in tables:
_migrate_user_memories(conn)
if "shared_memories" in tables:
_migrate_shared_memories(conn)
op.execute("DROP TABLE IF EXISTS shared_memories CASCADE;")
op.execute("DROP TABLE IF EXISTS user_memories CASCADE;")
op.execute("DROP TYPE IF EXISTS memorycategory;")
def downgrade() -> None:
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'memorycategory') THEN
CREATE TYPE memorycategory AS ENUM (
'preference',
'fact',
'instruction',
'context'
);
END IF;
END$$;
"""
)
op.execute(
f"""
CREATE TABLE IF NOT EXISTS user_memories (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
search_space_id INTEGER REFERENCES searchspaces(id) ON DELETE CASCADE,
memory_text TEXT NOT NULL,
category memorycategory NOT NULL DEFAULT 'fact',
embedding vector({EMBEDDING_DIM}),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
"""
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_user_memories_user_id ON user_memories(user_id);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_user_memories_search_space_id ON user_memories(search_space_id);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_user_memories_updated_at ON user_memories(updated_at);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_user_memories_category ON user_memories(category);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_user_memories_user_search_space ON user_memories(user_id, search_space_id);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS user_memories_vector_index ON user_memories USING hnsw (embedding public.vector_cosine_ops);"
)
op.execute(
f"""
CREATE TABLE IF NOT EXISTS shared_memories (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
created_by_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
memory_text TEXT NOT NULL,
category memorycategory NOT NULL DEFAULT 'fact',
embedding vector({EMBEDDING_DIM})
);
"""
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_shared_memories_search_space_id ON shared_memories(search_space_id);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_shared_memories_updated_at ON shared_memories(updated_at);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_shared_memories_created_by_id ON shared_memories(created_by_id);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS shared_memories_vector_index ON shared_memories USING hnsw (embedding public.vector_cosine_ops);"
)

View file

@ -38,6 +38,7 @@ from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.middleware import (
DedupHITLToolCallsMiddleware,
KnowledgeBaseSearchMiddleware,
MemoryInjectionMiddleware,
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.system_prompt import (
@ -168,8 +169,7 @@ async def create_surfsense_deep_agent(
- generate_podcast: Generate audio podcasts from content
- generate_image: Generate images from text descriptions using AI models
- scrape_webpage: Extract content from webpages
- save_memory: Store facts/preferences about the user
- recall_memory: Retrieve relevant user memories
- update_memory: Update the user's personal or team memory document
The agent also includes TodoListMiddleware by default (via create_deep_agent) which provides:
- write_todos: Create and update planning/todo lists for complex tasks
@ -281,6 +281,7 @@ async def create_surfsense_deep_agent(
"available_connectors": available_connectors,
"available_document_types": available_document_types,
"max_input_tokens": _max_input_tokens,
"llm": llm,
}
# Disable Notion action tools if no Notion connector is configured
@ -425,9 +426,16 @@ async def create_surfsense_deep_agent(
)
# -- Build the middleware stack (mirrors create_deep_agent internals) ------
_memory_middleware = MemoryInjectionMiddleware(
user_id=user_id,
search_space_id=search_space_id,
thread_visibility=visibility,
)
# General-purpose subagent middleware
gp_middleware = [
TodoListMiddleware(),
_memory_middleware,
SurfSenseFilesystemMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
@ -447,6 +455,7 @@ async def create_surfsense_deep_agent(
# Main agent middleware
deepagent_middleware = [
TodoListMiddleware(),
_memory_middleware,
KnowledgeBaseSearchMiddleware(
llm=llm,
search_space_id=search_space_id,

View file

@ -0,0 +1,239 @@
"""Background memory extraction for the SurfSense agent.
After each agent response, if the agent did not call ``update_memory`` during
the turn, this module can run a lightweight LLM call to decide whether the
latest message contains long-term information worth persisting.
"""
from __future__ import annotations
import logging
from typing import Any
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy import select
from app.agents.new_chat.tools.update_memory import _save_memory
from app.db import SearchSpace, User, shielded_async_session
logger = logging.getLogger(__name__)
_MEMORY_EXTRACT_PROMPT = """\
You are a memory extraction assistant. Analyze the user's message and decide \
if it contains any long-term information worth persisting to memory.
Worth remembering: preferences, background/identity, goals, projects, \
instructions, tools/languages they use, decisions, expertise, workplace \
durable facts that will matter in future conversations.
NOT worth remembering: greetings, one-off factual questions, session \
logistics, ephemeral requests, follow-up clarifications with no new personal \
info, things that only matter for the current task.
If the message contains memorizable information, output the FULL updated \
memory document with the new facts merged into the existing content. Follow \
these rules:
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
freely. Keep heading names short (2-3 words) and natural. Do NOT include the user's
name in headings.
- Keep entries as single bullet points. Be descriptive but concise include relevant
details and context rather than just a few words.
- Every bullet MUST use format: - (YYYY-MM-DD) [fact|pref|instr] text
[fact] = durable facts, [pref] = preferences, [instr] = standing instructions.
- Use the user's first name (from <user_name>) in entry text, not "the user".
- If a new fact contradicts an existing entry, update the existing entry.
- Do not duplicate information that is already present.
If nothing is worth remembering, output exactly: NO_UPDATE
<user_name>{user_name}</user_name>
<current_memory>
{current_memory}
</current_memory>
<user_message>
{user_message}
</user_message>"""
_TEAM_MEMORY_EXTRACT_PROMPT = """\
You are a team-memory extraction assistant. Analyze the latest message and \
decide if it contains durable TEAM-level information worth persisting.
Decision policy:
- Prioritize recall for durable team context, while avoiding personal-only facts.
- Do NOT require explicit consensus language. A direct team-level statement can
be stored if it is stable and broadly useful for future team chats.
- If evidence is weak or clearly tentative, output NO_UPDATE.
Worth remembering (team-level only):
- Decisions and defaults that guide future team work
- Team conventions/standards (naming, review policy, coding norms)
- Stable org/project facts (locations, ownership, constraints)
- Long-lived architecture/process facts
- Ongoing priorities that are likely relevant beyond this turn
NOT worth remembering:
- Personal preferences or biography of one person
- Questions, brainstorming, tentative ideas, or speculation
- One-off requests, status updates, TODOs, logistics for this session
- Information scoped only to a single ephemeral task
If the message contains memorizable team information, output the FULL updated \
team memory document with new facts merged into existing content. Follow rules:
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
freely. Keep heading names short (2-3 words) and natural.
- Keep entries as single bullet points. Be descriptive but concise include relevant
details and context rather than just a few words.
- Every bullet MUST use format: - (YYYY-MM-DD) [fact] text
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr].
- If a new fact contradicts an existing entry, update the existing entry.
- Do not duplicate existing information.
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
If nothing is worth remembering, output exactly: NO_UPDATE
<current_team_memory>
{current_memory}
</current_team_memory>
<latest_message_author>
{author}
</latest_message_author>
<latest_message>
{user_message}
</latest_message>"""
async def extract_and_save_memory(
*,
user_message: str,
user_id: str | None,
llm: Any,
) -> None:
"""Background task: extract memorizable info and persist it.
Designed to be fire-and-forget catches all exceptions internally.
"""
if not user_id:
return
try:
uid = UUID(user_id) if isinstance(user_id, str) else user_id
async with shielded_async_session() as session:
result = await session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return
old_memory = user.memory_md
first_name = (
user.display_name.strip().split()[0]
if user.display_name and user.display_name.strip()
else "The user"
)
prompt = _MEMORY_EXTRACT_PROMPT.format(
current_memory=old_memory or "(empty)",
user_message=user_message,
user_name=first_name,
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-extraction"]},
)
text = (
response.content
if isinstance(response.content, str)
else str(response.content)
).strip()
if text == "NO_UPDATE" or not text:
logger.debug("Memory extraction: no update needed (user %s)", uid)
return
save_result = await _save_memory(
updated_memory=text,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="memory",
scope="user",
)
logger.info(
"Background memory extraction for user %s: %s",
uid,
save_result.get("status"),
)
except Exception:
logger.exception("Background user memory extraction failed")
async def extract_and_save_team_memory(
*,
user_message: str,
search_space_id: int | None,
llm: Any,
author_display_name: str | None = None,
) -> None:
"""Background task: extract team-level memory and persist it.
Runs only for shared threads. Designed to be fire-and-forget and catches
exceptions internally.
"""
if not search_space_id:
return
try:
async with shielded_async_session() as session:
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return
old_memory = space.shared_memory_md
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
current_memory=old_memory or "(empty)",
author=author_display_name or "Unknown team member",
user_message=user_message,
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
)
text = (
response.content
if isinstance(response.content, str)
else str(response.content)
).strip()
if text == "NO_UPDATE" or not text:
logger.debug(
"Team memory extraction: no update needed (space %s)",
search_space_id,
)
return
save_result = await _save_memory(
updated_memory=text,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="team memory",
scope="team",
)
logger.info(
"Background team memory extraction for space %s: %s",
search_space_id,
save_result.get("status"),
)
except Exception:
logger.exception("Background team memory extraction failed")

View file

@ -9,9 +9,13 @@ from app.agents.new_chat.middleware.filesystem import (
from app.agents.new_chat.middleware.knowledge_search import (
KnowledgeBaseSearchMiddleware,
)
from app.agents.new_chat.middleware.memory_injection import (
MemoryInjectionMiddleware,
)
__all__ = [
"DedupHITLToolCallsMiddleware",
"KnowledgeBaseSearchMiddleware",
"MemoryInjectionMiddleware",
"SurfSenseFilesystemMiddleware",
]

View file

@ -0,0 +1,138 @@
"""Memory injection middleware for the SurfSense agent.
Injects memory markdown into the system prompt on every turn:
- Private threads: only personal memory (<user_memory>)
- Shared threads: only team memory (<team_memory>)
"""
from __future__ import annotations
import logging
from typing import Any
from uuid import UUID
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.runtime import Runtime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
logger = logging.getLogger(__name__)
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Injects memory markdown into the conversation on every turn."""
tools = ()
def __init__(
self,
*,
user_id: str | UUID | None,
search_space_id: int,
thread_visibility: ChatVisibility | None = None,
) -> None:
self.user_id = UUID(user_id) if isinstance(user_id, str) else user_id
self.search_space_id = search_space_id
self.visibility = thread_visibility or ChatVisibility.PRIVATE
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
messages = state.get("messages") or []
if not messages:
return None
last_message = messages[-1]
if not isinstance(last_message, HumanMessage):
return None
memory_blocks: list[str] = []
async with shielded_async_session() as session:
if self.visibility == ChatVisibility.SEARCH_SPACE:
team_memory = await self._load_team_memory(session)
if team_memory:
chars = len(team_memory)
memory_blocks.append(
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
f"{team_memory}\n"
f"</team_memory>"
)
if chars > MEMORY_SOFT_LIMIT:
memory_blocks.append(
f"<memory_warning>Team memory is at "
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
f"the hard limit. On your next update_memory call, consolidate "
f"by merging duplicates, removing outdated entries, and "
f"shortening descriptions before adding anything new."
f"</memory_warning>"
)
elif self.user_id is not None:
user_memory, display_name = await self._load_user_memory(session)
if display_name and display_name.strip():
first_name = display_name.strip().split()[0]
memory_blocks.append(f"<user_name>{first_name}</user_name>")
if user_memory:
chars = len(user_memory)
memory_blocks.append(
f'<user_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
f"{user_memory}\n"
f"</user_memory>"
)
if chars > MEMORY_SOFT_LIMIT:
memory_blocks.append(
f"<memory_warning>Your personal memory is at "
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
f"the hard limit. On your next update_memory call, consolidate "
f"by merging duplicates, removing outdated entries, and "
f"shortening descriptions before adding anything new."
f"</memory_warning>"
)
if not memory_blocks:
return None
memory_text = "\n\n".join(memory_blocks)
memory_msg = SystemMessage(content=memory_text)
new_messages = list(messages)
insert_idx = 1 if len(new_messages) > 1 else 0
new_messages.insert(insert_idx, memory_msg)
return {"messages": new_messages}
async def _load_user_memory(
self, session: AsyncSession
) -> tuple[str | None, str | None]:
"""Return (memory_content, display_name)."""
try:
result = await session.execute(
select(User.memory_md, User.display_name).where(User.id == self.user_id)
)
row = result.one_or_none()
if row is None:
return None, None
return row.memory_md or None, row.display_name
except Exception:
logger.exception("Failed to load user memory")
return None, None
async def _load_team_memory(self, session: AsyncSession) -> str | None:
try:
result = await session.execute(
select(SearchSpace.shared_memory_md).where(
SearchSpace.id == self.search_space_id
)
)
row = result.scalar_one_or_none()
return row if row else None
except Exception:
logger.exception("Failed to load team memory")
return None

View file

@ -40,6 +40,13 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
</knowledge_base_only_policy>
<memory_protocol>
IMPORTANT After understanding each user message, ALWAYS check: does this message
reveal durable facts about the user (role, interests, preferences, projects,
background, or standing instructions)? If yes, you MUST call update_memory
alongside your normal response do not defer this to a later turn.
</memory_protocol>
</system_instruction>
"""
@ -71,6 +78,13 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
</knowledge_base_only_policy>
<memory_protocol>
IMPORTANT After understanding each user message, ALWAYS check: does this message
reveal durable facts about the team (decisions, conventions, architecture, processes,
or key facts)? If yes, you MUST call update_memory alongside your normal response
do not defer this to a later turn.
</memory_protocol>
</system_instruction>
"""
@ -248,115 +262,97 @@ _TOOL_INSTRUCTIONS["web_search"] = """
"""
# Memory tool instructions have private and shared variants.
# We store them keyed as "save_memory" / "recall_memory" with sub-keys.
# We store them keyed as "update_memory" with sub-keys.
_MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
"save_memory": {
"update_memory": {
"private": """
- save_memory: Save facts, preferences, or context for personalized responses.
- Use this when the user explicitly or implicitly shares information worth remembering.
- Trigger scenarios:
* User says "remember this", "keep this in mind", "note that", or similar
* User shares personal preferences (e.g., "I prefer Python over JavaScript")
* User shares facts about themselves (e.g., "I'm a senior developer at Company X")
* User gives standing instructions (e.g., "always respond in bullet points")
* User shares project context (e.g., "I'm working on migrating our codebase to TypeScript")
- update_memory: Update your personal memory document about the user.
- Your current memory is already in <user_memory> in your context. The `chars` and
`limit` attributes show your current usage and the maximum allowed size.
- This is your curated long-term memory the distilled essence of what you know about
the user, not raw conversation logs.
- Call update_memory when:
* The user explicitly asks to remember or forget something
* The user shares durable facts or preferences that will matter in future conversations
- The user's first name is provided in <user_name>. Use it in memory entries
instead of "the user" (e.g. "{name} works at..." not "The user works at...").
Do not store the name itself as a separate memory entry.
- Do not store short-lived or ephemeral info: one-off questions, greetings,
session logistics, or things that only matter for the current task.
- Args:
- content: The fact/preference to remember. Phrase it clearly:
* "User prefers dark mode for all interfaces"
* "User is a senior Python developer"
* "User wants responses in bullet point format"
* "User is working on project called ProjectX"
- category: Type of memory:
* "preference": User preferences (coding style, tools, formats)
* "fact": Facts about the user (role, expertise, background)
* "instruction": Standing instructions (response format, communication style)
* "context": Current context (ongoing projects, goals, challenges)
- Returns: Confirmation of saved memory
- IMPORTANT: Only save information that would be genuinely useful for future conversations.
Don't save trivial or temporary information.
- updated_memory: The FULL updated markdown document (not a diff).
Merge new facts with existing ones, update contradictions, remove outdated entries.
Treat every update as a curation pass consolidate, don't just append.
- Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text
Markers:
[fact] durable facts (role, background, projects, tools, expertise)
[pref] preferences (response style, languages, formats, tools)
[instr] standing instructions (always/never do, response rules)
- Keep it concise and well under the character limit shown in <user_memory>.
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
natural. Do NOT include the user's name in headings. Organize by context — e.g.
who they are, what they're focused on, how they prefer things. Create, split, or
merge headings freely as the memory grows.
- Each entry MUST be a single bullet point. Be descriptive but concise include relevant
details and context rather than just a few words.
- During consolidation, prioritize keeping: [instr] > [pref] > [fact].
""",
"shared": """
- save_memory: Save a fact, preference, or context to the team's shared memory for future reference.
- Use this when the user or a team member says "remember this", "keep this in mind", or similar in this shared chat.
- Use when the team agrees on something to remember (e.g., decisions, conventions).
- Someone shares a preference or fact that should be visible to the whole team.
- The saved information will be available in future shared conversations in this space.
- update_memory: Update the team's shared memory document for this search space.
- Your current team memory is already in <team_memory> in your context. The `chars`
and `limit` attributes show current usage and the maximum allowed size.
- This is the team's curated long-term memory — decisions, conventions, key facts.
- NEVER store personal memory in team memory (e.g. personal bio, individual
preferences, or user-only standing instructions).
- Call update_memory when:
* A team member explicitly asks to remember or forget something
* The conversation surfaces durable team decisions, conventions, or facts
that will matter in future conversations
- Do not store short-lived or ephemeral info: one-off questions, greetings,
session logistics, or things that only matter for the current task.
- Args:
- content: The fact/preference/context to remember. Phrase it clearly, e.g. "API keys are stored in Vault", "The team prefers weekly demos on Fridays"
- category: Type of memory. One of:
* "preference": Team or workspace preferences
* "fact": Facts the team agreed on (e.g., processes, locations)
* "instruction": Standing instructions for the team
* "context": Current context (e.g., ongoing projects, goals)
- Returns: Confirmation of saved memory; returned context may include who added it (added_by).
- IMPORTANT: Only save information that would be genuinely useful for future team conversations in this space.
""",
},
"recall_memory": {
"private": """
- recall_memory: Retrieve relevant memories about the user for personalized responses.
- Use this to access stored information about the user.
- Trigger scenarios:
* You need user context to give a better, more personalized answer
* User references something they mentioned before
* User asks "what do you know about me?" or similar
* Personalization would significantly improve response quality
* Before making recommendations that should consider user preferences
- Args:
- query: Optional search query to find specific memories (e.g., "programming preferences")
- category: Optional filter by category ("preference", "fact", "instruction", "context")
- top_k: Number of memories to retrieve (default: 5)
- Returns: Relevant memories formatted as context
- IMPORTANT: Use the recalled memories naturally in your response without explicitly
stating "Based on your memory..." - integrate the context seamlessly.
""",
"shared": """
- recall_memory: Recall relevant team memories for this space to provide contextual responses.
- Use when you need team context to answer (e.g., "where do we store X?", "what did we decide about Y?").
- Use when someone asks about something the team agreed to remember.
- Use when team preferences or conventions would improve the response.
- Args:
- query: Optional search query to find specific memories. If not provided, returns the most recent memories.
- category: Optional filter by category ("preference", "fact", "instruction", "context")
- top_k: Number of memories to retrieve (default: 5, max: 20)
- Returns: Relevant team memories and formatted context (may include added_by). Integrate naturally without saying "Based on team memory...".
- updated_memory: The FULL updated markdown document (not a diff).
Merge new facts with existing ones, update contradictions, remove outdated entries.
Treat every update as a curation pass consolidate, don't just append.
- Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory.
- Keep it concise and well under the character limit shown in <team_memory>.
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
natural. Organize by context e.g. what the team decided, current architecture,
active processes. Create, split, or merge headings freely as the memory grows.
- Each entry MUST be a single bullet point. Be descriptive but concise include relevant
details and context rather than just a few words.
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
""",
},
}
_MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = {
"save_memory": {
"update_memory": {
"private": """
- User: "Remember that I prefer TypeScript over JavaScript"
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
- User: "I'm a data scientist working on ML pipelines"
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
- User: "Always give me code examples in Python"
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
- <user_name>Alex</user_name>, <user_memory> is empty. User: "I'm a space enthusiast, explain astrophage to me"
- The user casually shared a durable fact. Use their first name in the entry, short neutral heading:
update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n")
- User: "Remember that I prefer concise answers over detailed explanations"
- Durable preference. Merge with existing memory, add a new heading:
update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n\\n## Response style\\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\\n")
- User: "I actually moved to Tokyo last month"
- Updated fact, date prefix reflects when recorded:
update_memory(updated_memory="## Interests & background\\n...\\n\\n## Personal context\\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\\n...")
- User: "I'm a freelance photographer working on a nature documentary"
- Durable background info under a fitting heading:
update_memory(updated_memory="...\\n\\n## Current focus\\n- (2025-03-15) [fact] Alex is a freelance photographer\\n- (2025-03-15) [fact] Alex is working on a nature documentary\\n")
- User: "Always respond in bullet points"
- Standing instruction:
update_memory(updated_memory="...\\n\\n## Response style\\n- (2025-03-15) [instr] Always respond to Alex in bullet points\\n")
""",
"shared": """
- User: "Remember that API keys are stored in Vault"
- Call: `save_memory(content="API keys are stored in Vault", category="fact")`
- User: "Let's remember that the team prefers weekly demos on Fridays"
- Call: `save_memory(content="The team prefers weekly demos on Fridays", category="preference")`
""",
},
"recall_memory": {
"private": """
- User: "What programming language should I use for this project?"
- First recall: `recall_memory(query="programming language preferences")`
- Then provide a personalized recommendation based on their preferences
- User: "What do you know about me?"
- Call: `recall_memory(top_k=10)`
- Then summarize the stored memories
""",
"shared": """
- User: "What did we decide about the release date?"
- First recall: `recall_memory(query="release date decision")`
- Then answer based on the team memories
- User: "Where do we document onboarding?"
- Call: `recall_memory(query="onboarding documentation")`
- Then answer using the recalled team context
- User: "Let's remember that we decided to do weekly standup meetings on Mondays"
- Durable team decision:
update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\\n...")
- User: "Our office is in downtown Seattle, 5th floor"
- Durable team fact:
update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\\n...")
""",
},
}
@ -456,8 +452,7 @@ _ALL_TOOL_NAMES_ORDERED = [
"generate_report",
"generate_image",
"scrape_webpage",
"save_memory",
"recall_memory",
"update_memory",
]

View file

@ -10,8 +10,7 @@ Available tools:
- generate_video_presentation: Generate video presentations with slides and narration
- generate_image: Generate images from text descriptions using AI models
- scrape_webpage: Extract content from webpages
- save_memory: Store facts/preferences about the user
- recall_memory: Retrieve relevant user memories
- update_memory: Update the user's / team's memory document
"""
# Registry exports
@ -33,7 +32,7 @@ from .registry import (
)
from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool
from .user_memory import create_recall_memory_tool, create_save_memory_tool
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
from .video_presentation import create_generate_video_presentation_tool
__all__ = [
@ -47,10 +46,10 @@ __all__ = [
"create_generate_image_tool",
"create_generate_podcast_tool",
"create_generate_video_presentation_tool",
"create_recall_memory_tool",
"create_save_memory_tool",
"create_scrape_webpage_tool",
"create_search_surfsense_docs_tool",
"create_update_memory_tool",
"create_update_team_memory_tool",
"format_documents_for_context",
"get_all_tool_names",
"get_default_enabled_tools",

View file

@ -94,11 +94,7 @@ from .podcast import create_generate_podcast_tool
from .report import create_generate_report_tool
from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool
from .shared_memory import (
create_recall_shared_memory_tool,
create_save_shared_memory_tool,
)
from .user_memory import create_recall_memory_tool, create_save_memory_tool
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
from .video_presentation import create_generate_video_presentation_tool
from .web_search import create_web_search_tool
@ -214,42 +210,31 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
requires=["db_session"],
),
# =========================================================================
# USER MEMORY TOOLS - private or team store by thread_visibility
# MEMORY TOOL - single update_memory, private or team by thread_visibility
# =========================================================================
ToolDefinition(
name="save_memory",
description="Save facts, preferences, or context for personalized or team responses",
name="update_memory",
description="Save important long-term facts, preferences, and instructions to the (personal or team) memory",
factory=lambda deps: (
create_save_shared_memory_tool(
create_update_team_memory_tool(
search_space_id=deps["search_space_id"],
created_by_id=deps["user_id"],
db_session=deps["db_session"],
llm=deps.get("llm"),
)
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
else create_save_memory_tool(
else create_update_memory_tool(
user_id=deps["user_id"],
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
llm=deps.get("llm"),
)
),
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
),
ToolDefinition(
name="recall_memory",
description="Recall relevant memories (personal or team) for context",
factory=lambda deps: (
create_recall_shared_memory_tool(
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
)
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
else create_recall_memory_tool(
user_id=deps["user_id"],
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
)
),
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
requires=[
"user_id",
"search_space_id",
"db_session",
"thread_visibility",
"llm",
],
),
# =========================================================================
# LINEAR TOOLS - create, update, delete issues

View file

@ -1,281 +0,0 @@
"""Shared (team) memory backend for search-space-scoped AI context."""
import asyncio
import logging
from typing import Any
from uuid import UUID
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import MemoryCategory, SharedMemory, User
from app.utils.document_converters import embed_text
logger = logging.getLogger(__name__)
DEFAULT_RECALL_TOP_K = 5
MAX_MEMORIES_PER_SEARCH_SPACE = 250
async def get_shared_memory_count(
db_session: AsyncSession,
search_space_id: int,
) -> int:
result = await db_session.execute(
select(SharedMemory).where(SharedMemory.search_space_id == search_space_id)
)
return len(result.scalars().all())
async def delete_oldest_shared_memory(
db_session: AsyncSession,
search_space_id: int,
) -> None:
result = await db_session.execute(
select(SharedMemory)
.where(SharedMemory.search_space_id == search_space_id)
.order_by(SharedMemory.updated_at.asc())
.limit(1)
)
oldest = result.scalars().first()
if oldest:
await db_session.delete(oldest)
await db_session.commit()
def _to_uuid(value: str | UUID) -> UUID:
if isinstance(value, UUID):
return value
return UUID(value)
async def save_shared_memory(
db_session: AsyncSession,
search_space_id: int,
created_by_id: str | UUID,
content: str,
category: str = "fact",
) -> dict[str, Any]:
category = category.lower() if category else "fact"
valid = ["preference", "fact", "instruction", "context"]
if category not in valid:
category = "fact"
try:
count = await get_shared_memory_count(db_session, search_space_id)
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
await delete_oldest_shared_memory(db_session, search_space_id)
embedding = await asyncio.to_thread(embed_text, content)
row = SharedMemory(
search_space_id=search_space_id,
created_by_id=_to_uuid(created_by_id),
memory_text=content,
category=MemoryCategory(category),
embedding=embedding,
)
db_session.add(row)
await db_session.commit()
await db_session.refresh(row)
return {
"status": "saved",
"memory_id": row.id,
"memory_text": content,
"category": category,
"message": f"I'll remember: {content}",
}
except Exception as e:
logger.exception("Failed to save shared memory: %s", e)
await db_session.rollback()
return {
"status": "error",
"error": str(e),
"message": "Failed to save memory. Please try again.",
}
async def recall_shared_memory(
db_session: AsyncSession,
search_space_id: int,
query: str | None = None,
category: str | None = None,
top_k: int = DEFAULT_RECALL_TOP_K,
) -> dict[str, Any]:
top_k = min(max(top_k, 1), 20)
try:
valid_categories = ["preference", "fact", "instruction", "context"]
stmt = select(SharedMemory).where(
SharedMemory.search_space_id == search_space_id
)
if category and category in valid_categories:
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
if query:
query_embedding = await asyncio.to_thread(embed_text, query)
stmt = stmt.order_by(
SharedMemory.embedding.op("<=>")(query_embedding)
).limit(top_k)
else:
stmt = stmt.order_by(SharedMemory.updated_at.desc()).limit(top_k)
result = await db_session.execute(stmt)
rows = result.scalars().all()
memory_list = [
{
"id": m.id,
"memory_text": m.memory_text,
"category": m.category.value if m.category else "unknown",
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
"created_by_id": str(m.created_by_id) if m.created_by_id else None,
}
for m in rows
]
created_by_ids = list(
{m["created_by_id"] for m in memory_list if m["created_by_id"]}
)
created_by_map: dict[str, str] = {}
if created_by_ids:
uuids = [UUID(uid) for uid in created_by_ids]
users_result = await db_session.execute(
select(User).where(User.id.in_(uuids))
)
for u in users_result.scalars().all():
created_by_map[str(u.id)] = u.display_name or "A team member"
formatted_context = format_shared_memories_for_context(
memory_list, created_by_map
)
return {
"status": "success",
"count": len(memory_list),
"memories": memory_list,
"formatted_context": formatted_context,
}
except Exception as e:
logger.exception("Failed to recall shared memory: %s", e)
await db_session.rollback()
return {
"status": "error",
"error": str(e),
"memories": [],
"formatted_context": "Failed to recall memories.",
}
def format_shared_memories_for_context(
memories: list[dict[str, Any]],
created_by_map: dict[str, str] | None = None,
) -> str:
if not memories:
return "No relevant team memories found."
created_by_map = created_by_map or {}
parts = ["<team_memories>"]
for memory in memories:
category = memory.get("category", "unknown")
text = memory.get("memory_text", "")
updated = memory.get("updated_at", "")
created_by_id = memory.get("created_by_id")
added_by = (
created_by_map.get(str(created_by_id), "A team member")
if created_by_id is not None
else "A team member"
)
parts.append(
f" <memory category='{category}' updated='{updated}' added_by='{added_by}'>{text}</memory>"
)
parts.append("</team_memories>")
return "\n".join(parts)
def create_save_shared_memory_tool(
search_space_id: int,
created_by_id: str | UUID,
db_session: AsyncSession,
):
"""
Factory function to create the save_memory tool for shared (team) chats.
Args:
search_space_id: The search space ID
created_by_id: The user ID of the person adding the memory
db_session: Database session for executing queries
Returns:
A configured tool function for saving team memories
"""
@tool
async def save_memory(
content: str,
category: str = "fact",
) -> dict[str, Any]:
"""
Save a fact, preference, or context to the team's shared memory for future reference.
Use this tool when:
- User or a team member says "remember this", "keep this in mind", or similar in this shared chat
- The team agrees on something to remember (e.g., decisions, conventions, where things live)
- Someone shares a preference or fact that should be visible to the whole team
The saved information will be available in future shared conversations in this space.
Args:
content: The fact/preference/context to remember.
Phrase it clearly, e.g., "API keys are stored in Vault",
"The team prefers weekly demos on Fridays"
category: Type of memory. One of:
- "preference": Team or workspace preferences
- "fact": Facts the team agreed on (e.g., processes, locations)
- "instruction": Standing instructions for the team
- "context": Current context (e.g., ongoing projects, goals)
Returns:
A dictionary with the save status and memory details
"""
return await save_shared_memory(
db_session, search_space_id, created_by_id, content, category
)
return save_memory
def create_recall_shared_memory_tool(
search_space_id: int,
db_session: AsyncSession,
):
"""
Factory function to create the recall_memory tool for shared (team) chats.
Args:
search_space_id: The search space ID
db_session: Database session for executing queries
Returns:
A configured tool function for recalling team memories
"""
@tool
async def recall_memory(
query: str | None = None,
category: str | None = None,
top_k: int = DEFAULT_RECALL_TOP_K,
) -> dict[str, Any]:
"""
Recall relevant team memories for this space to provide contextual responses.
Use this tool when:
- You need team context to answer (e.g., "where do we store X?", "what did we decide about Y?")
- Someone asks about something the team agreed to remember
- Team preferences or conventions would improve the response
Args:
query: Optional search query to find specific memories.
If not provided, returns the most recent memories.
category: Optional category filter. One of:
"preference", "fact", "instruction", "context"
top_k: Number of memories to retrieve (default: 5, max: 20)
Returns:
A dictionary containing relevant memories and formatted context
"""
return await recall_shared_memory(
db_session, search_space_id, query, category, top_k
)
return recall_memory

View file

@ -0,0 +1,389 @@
"""Markdown-document memory tool for the SurfSense agent.
Replaces the old row-per-fact save_memory / recall_memory tools with a single
update_memory tool that overwrites a freeform markdown TEXT column. The LLM
always sees the current memory in <user_memory> / <team_memory> tags injected
by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
Overflow handling:
- Soft limit (18K chars): a warning is returned telling the agent to
consolidate on the next update.
- Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
If it still exceeds the limit after rewriting, the save is rejected.
- Diff validation: warns when entire ``##`` sections are dropped or when the
document shrinks by more than 60%.
"""
from __future__ import annotations
import logging
import re
from typing import Any, Literal
from uuid import UUID
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User
logger = logging.getLogger(__name__)
MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
# ---------------------------------------------------------------------------
# Diff validation
# ---------------------------------------------------------------------------
def _extract_headings(memory: str) -> set[str]:
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
return set(_SECTION_HEADING_RE.findall(memory))
def _normalize_heading(heading: str) -> str:
"""Normalize heading text for robust scope checks."""
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
def _validate_memory_scope(
content: str, scope: Literal["user", "team"]
) -> dict[str, Any] | None:
"""Reject personal-only markers ([pref], [instr]) in team memory."""
if scope != "team":
return None
markers = set(_MARKER_RE.findall(content))
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
if leaked:
tags = ", ".join(f"[{m}]" for m in leaked)
return {
"status": "error",
"message": (
f"Team memory cannot include personal markers: {tags}. "
"Use [fact] only in team memory."
),
}
return None
def _validate_bullet_format(content: str) -> list[str]:
"""Return warnings for bullet lines that don't match the required format.
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
"""
warnings: list[str] = []
for line in content.splitlines():
stripped = line.strip()
if not stripped.startswith("- "):
continue
if not _BULLET_FORMAT_RE.match(stripped):
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
warnings.append(f"Malformed bullet: {short}")
return warnings
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
"""Return a list of warning strings about suspicious changes."""
if not old_memory:
return []
warnings: list[str] = []
old_headings = _extract_headings(old_memory)
new_headings = _extract_headings(new_memory)
dropped = old_headings - new_headings
if dropped:
names = ", ".join(sorted(dropped))
warnings.append(
f"Sections removed: {names}. "
"If unintentional, the user can restore from the settings page."
)
old_len = len(old_memory)
new_len = len(new_memory)
if old_len > 0 and new_len < old_len * 0.4:
warnings.append(
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
"Possible data loss."
)
return warnings
# ---------------------------------------------------------------------------
# Size validation & soft warning
# ---------------------------------------------------------------------------
def _validate_memory_size(content: str) -> dict[str, Any] | None:
"""Return an error/warning dict if *content* is too large, else None."""
length = len(content)
if length > MEMORY_HARD_LIMIT:
return {
"status": "error",
"message": (
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
f"({length:,} chars). Consolidate by merging related items, "
"removing outdated entries, and shortening descriptions. "
"Then call update_memory again."
),
}
return None
def _soft_warning(content: str) -> str | None:
"""Return a warning string if content exceeds the soft limit."""
length = len(content)
if length > MEMORY_SOFT_LIMIT:
return (
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
"Consolidate by merging related items and removing less important "
"entries on your next update."
)
return None
# ---------------------------------------------------------------------------
# Forced rewrite when memory exceeds the hard limit
# ---------------------------------------------------------------------------
_FORCED_REWRITE_PROMPT = """\
You are a memory curator. The following memory document exceeds the character \
limit and must be shortened.
RULES:
1. Rewrite the document to be under {target} characters.
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
or rename headings to consolidate, but keep names personal and descriptive.
3. Priority for keeping content: [instr] > [pref] > [fact].
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
6. Preserve the user's first name in entries — do not replace it with "the user".
7. Output ONLY the consolidated markdown no explanations, no wrapping.
<memory_document>
{content}
</memory_document>"""
async def _forced_rewrite(content: str, llm: Any) -> str | None:
"""Use a focused LLM call to compress *content* under the hard limit.
Returns the rewritten string, or ``None`` if the call fails.
"""
try:
prompt = _FORCED_REWRITE_PROMPT.format(
target=MEMORY_HARD_LIMIT, content=content
)
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
)
text = (
response.content
if isinstance(response.content, str)
else str(response.content)
)
return text.strip()
except Exception:
logger.exception("Forced rewrite LLM call failed")
return None
# ---------------------------------------------------------------------------
# Shared save-and-respond logic
# ---------------------------------------------------------------------------
async def _save_memory(
*,
updated_memory: str,
old_memory: str | None,
llm: Any | None,
apply_fn,
commit_fn,
rollback_fn,
label: str,
scope: Literal["user", "team"],
) -> dict[str, Any]:
"""Validate, optionally force-rewrite if over the hard limit, save, and
return a response dict.
Parameters
----------
updated_memory : str
The new document the agent submitted.
old_memory : str | None
The previously persisted document (for diff checks).
llm : Any | None
LLM instance for forced rewrite (may be ``None``).
apply_fn : callable(str) -> None
Callback that sets the new memory on the ORM object.
commit_fn : coroutine
``session.commit``.
rollback_fn : coroutine
``session.rollback``.
label : str
Human label for log messages (e.g. "user memory", "team memory").
"""
content = updated_memory
# --- forced rewrite if over the hard limit ---
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
rewritten = await _forced_rewrite(content, llm)
if rewritten is not None and len(rewritten) < len(content):
content = rewritten
# --- hard-limit gate (reject if still too large after rewrite) ---
size_err = _validate_memory_size(content)
if size_err:
return size_err
scope_err = _validate_memory_scope(content, scope)
if scope_err:
return scope_err
# --- persist ---
try:
apply_fn(content)
await commit_fn()
except Exception as e:
logger.exception("Failed to update %s: %s", label, e)
await rollback_fn()
return {"status": "error", "message": f"Failed to update {label}: {e}"}
# --- build response ---
resp: dict[str, Any] = {
"status": "saved",
"message": f"{label.capitalize()} updated.",
}
if content is not updated_memory:
resp["notice"] = "Memory was automatically rewritten to fit within limits."
diff_warnings = _validate_diff(old_memory, content)
if diff_warnings:
resp["diff_warnings"] = diff_warnings
format_warnings = _validate_bullet_format(content)
if format_warnings:
resp["format_warnings"] = format_warnings
warning = _soft_warning(content)
if warning:
resp["warning"] = warning
return resp
# ---------------------------------------------------------------------------
# Tool factories
# ---------------------------------------------------------------------------
def create_update_memory_tool(
user_id: str | UUID,
db_session: AsyncSession,
llm: Any | None = None,
):
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the user's personal memory document.
Your current memory is shown in <user_memory> in the system prompt.
When the user shares important long-term information (preferences,
facts, instructions, context), rewrite the memory document to include
the new information. Merge new facts with existing ones, update
contradictions, remove outdated entries, and keep it concise.
Args:
updated_memory: The FULL updated markdown document (not a diff).
"""
try:
result = await db_session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return {"status": "error", "message": "User not found."}
old_memory = user.memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="memory",
scope="user",
)
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
await db_session.rollback()
return {
"status": "error",
"message": f"Failed to update memory: {e}",
}
return update_memory
def create_update_team_memory_tool(
search_space_id: int,
db_session: AsyncSession,
llm: Any | None = None,
):
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space.
Your current team memory is shown in <team_memory> in the system
prompt. When the team shares important long-term information
(decisions, conventions, key facts, priorities), rewrite the memory
document to include the new information. Merge new facts with
existing ones, update contradictions, remove outdated entries, and
keep it concise.
Args:
updated_memory: The FULL updated markdown document (not a diff).
"""
try:
result = await db_session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return {"status": "error", "message": "Search space not found."}
old_memory = space.shared_memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="team memory",
scope="team",
)
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
await db_session.rollback()
return {
"status": "error",
"message": f"Failed to update team memory: {e}",
}
return update_memory

View file

@ -1,351 +0,0 @@
"""
User memory tools for the SurfSense agent.
This module provides tools for storing and retrieving user memories,
enabling personalized AI responses similar to Claude's memory feature.
Features:
- save_memory: Store facts, preferences, and context about the user
- recall_memory: Retrieve relevant memories using semantic search
"""
import asyncio
import logging
from typing import Any
from uuid import UUID
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import MemoryCategory, UserMemory
from app.utils.document_converters import embed_text
logger = logging.getLogger(__name__)
# =============================================================================
# Constants
# =============================================================================
# Default number of memories to retrieve
DEFAULT_RECALL_TOP_K = 5
# Maximum number of memories per user (to prevent unbounded growth)
MAX_MEMORIES_PER_USER = 100
# =============================================================================
# Helper Functions
# =============================================================================
def _to_uuid(user_id: str) -> UUID:
"""Convert a string user_id to a UUID object."""
if isinstance(user_id, UUID):
return user_id
return UUID(user_id)
async def get_user_memory_count(
db_session: AsyncSession,
user_id: str,
search_space_id: int | None = None,
) -> int:
"""Get the count of memories for a user."""
uuid_user_id = _to_uuid(user_id)
query = select(UserMemory).where(UserMemory.user_id == uuid_user_id)
if search_space_id is not None:
query = query.where(
(UserMemory.search_space_id == search_space_id)
| (UserMemory.search_space_id.is_(None))
)
result = await db_session.execute(query)
return len(result.scalars().all())
async def delete_oldest_memory(
db_session: AsyncSession,
user_id: str,
search_space_id: int | None = None,
) -> None:
"""Delete the oldest memory for a user to make room for new ones."""
uuid_user_id = _to_uuid(user_id)
query = (
select(UserMemory)
.where(UserMemory.user_id == uuid_user_id)
.order_by(UserMemory.updated_at.asc())
.limit(1)
)
if search_space_id is not None:
query = query.where(
(UserMemory.search_space_id == search_space_id)
| (UserMemory.search_space_id.is_(None))
)
result = await db_session.execute(query)
oldest_memory = result.scalars().first()
if oldest_memory:
await db_session.delete(oldest_memory)
await db_session.commit()
def format_memories_for_context(memories: list[dict[str, Any]]) -> str:
"""Format retrieved memories into a readable context string for the LLM."""
if not memories:
return "No relevant memories found for this user."
parts = ["<user_memories>"]
for memory in memories:
category = memory.get("category", "unknown")
text = memory.get("memory_text", "")
updated = memory.get("updated_at", "")
parts.append(
f" <memory category='{category}' updated='{updated}'>{text}</memory>"
)
parts.append("</user_memories>")
return "\n".join(parts)
# =============================================================================
# Tool Factory Functions
# =============================================================================
def create_save_memory_tool(
user_id: str,
search_space_id: int,
db_session: AsyncSession,
):
"""
Factory function to create the save_memory tool.
Args:
user_id: The user's UUID
search_space_id: The search space ID (for space-specific memories)
db_session: Database session for executing queries
Returns:
A configured tool function for saving user memories
"""
@tool
async def save_memory(
content: str,
category: str = "fact",
) -> dict[str, Any]:
"""
Save a fact, preference, or context about the user for future reference.
Use this tool when:
- User explicitly says "remember this", "keep this in mind", or similar
- User shares personal preferences (e.g., "I prefer Python over JavaScript")
- User shares important facts about themselves (name, role, interests, projects)
- User gives standing instructions (e.g., "always respond in bullet points")
- User shares relevant context (e.g., "I'm working on project X")
The saved information will be available in future conversations to provide
more personalized and contextual responses.
Args:
content: The fact/preference/context to remember.
Phrase it clearly, e.g., "User prefers dark mode",
"User is a senior Python developer", "User is working on an AI project"
category: Type of memory. One of:
- "preference": User preferences (e.g., coding style, tools, formats)
- "fact": Facts about the user (e.g., name, role, expertise)
- "instruction": Standing instructions (e.g., response format preferences)
- "context": Current context (e.g., ongoing projects, goals)
Returns:
A dictionary with the save status and memory details
"""
# Normalize and validate category (LLMs may send uppercase)
category = category.lower() if category else "fact"
valid_categories = ["preference", "fact", "instruction", "context"]
if category not in valid_categories:
category = "fact"
try:
# Convert user_id to UUID
uuid_user_id = _to_uuid(user_id)
# Check if we've hit the memory limit
memory_count = await get_user_memory_count(
db_session, user_id, search_space_id
)
if memory_count >= MAX_MEMORIES_PER_USER:
# Delete oldest memory to make room
await delete_oldest_memory(db_session, user_id, search_space_id)
embedding = await asyncio.to_thread(embed_text, content)
# Create new memory using ORM
# The pgvector Vector column type handles embedding conversion automatically
new_memory = UserMemory(
user_id=uuid_user_id,
search_space_id=search_space_id,
memory_text=content,
category=MemoryCategory(category), # Convert string to enum
embedding=embedding, # Pass embedding directly (list or numpy array)
)
db_session.add(new_memory)
await db_session.commit()
await db_session.refresh(new_memory)
return {
"status": "saved",
"memory_id": new_memory.id,
"memory_text": content,
"category": category,
"message": f"I'll remember: {content}",
}
except Exception as e:
logger.exception(f"Failed to save memory for user {user_id}: {e}")
# Rollback the session to clear any failed transaction state
await db_session.rollback()
return {
"status": "error",
"error": str(e),
"message": "Failed to save memory. Please try again.",
}
return save_memory
def create_recall_memory_tool(
user_id: str,
search_space_id: int,
db_session: AsyncSession,
):
"""
Factory function to create the recall_memory tool.
Args:
user_id: The user's UUID
search_space_id: The search space ID
db_session: Database session for executing queries
Returns:
A configured tool function for recalling user memories
"""
@tool
async def recall_memory(
query: str | None = None,
category: str | None = None,
top_k: int = DEFAULT_RECALL_TOP_K,
) -> dict[str, Any]:
"""
Recall relevant memories about the user to provide personalized responses.
Use this tool when:
- You need user context to give a better, more personalized answer
- User asks about their preferences or past information they shared
- User references something they told you before
- Personalization would significantly improve the response quality
- User asks "what do you know about me?" or similar
Args:
query: Optional search query to find specific memories.
If not provided, returns the most recent memories.
Example: "programming preferences", "current projects"
category: Optional category filter. One of:
"preference", "fact", "instruction", "context"
If not provided, searches all categories.
top_k: Number of memories to retrieve (default: 5, max: 20)
Returns:
A dictionary containing relevant memories and formatted context
"""
top_k = min(max(top_k, 1), 20) # Clamp between 1 and 20
try:
# Convert user_id to UUID
uuid_user_id = _to_uuid(user_id)
if query:
query_embedding = await asyncio.to_thread(embed_text, query)
# Build query with vector similarity
stmt = (
select(UserMemory)
.where(UserMemory.user_id == uuid_user_id)
.where(
(UserMemory.search_space_id == search_space_id)
| (UserMemory.search_space_id.is_(None))
)
)
# Add category filter if specified
if category and category in [
"preference",
"fact",
"instruction",
"context",
]:
stmt = stmt.where(UserMemory.category == MemoryCategory(category))
# Order by vector similarity
stmt = stmt.order_by(
UserMemory.embedding.op("<=>")(query_embedding)
).limit(top_k)
else:
# No query - return most recent memories
stmt = (
select(UserMemory)
.where(UserMemory.user_id == uuid_user_id)
.where(
(UserMemory.search_space_id == search_space_id)
| (UserMemory.search_space_id.is_(None))
)
)
# Add category filter if specified
if category and category in [
"preference",
"fact",
"instruction",
"context",
]:
stmt = stmt.where(UserMemory.category == MemoryCategory(category))
stmt = stmt.order_by(UserMemory.updated_at.desc()).limit(top_k)
result = await db_session.execute(stmt)
memories = result.scalars().all()
# Format memories for response
memory_list = [
{
"id": m.id,
"memory_text": m.memory_text,
"category": m.category.value if m.category else "unknown",
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
}
for m in memories
]
formatted_context = format_memories_for_context(memory_list)
return {
"status": "success",
"count": len(memory_list),
"memories": memory_list,
"formatted_context": formatted_context,
}
except Exception as e:
logger.exception(f"Failed to recall memories for user {user_id}: {e}")
await db_session.rollback()
return {
"status": "error",
"error": str(e),
"memories": [],
"formatted_context": "Failed to recall memories.",
}
return recall_memory

View file

@ -861,99 +861,6 @@ class ChatSessionState(BaseModel):
ai_responding_to_user = relationship("User")
class MemoryCategory(StrEnum):
"""Categories for user memories."""
# Using lowercase keys to match PostgreSQL enum values
preference = "preference" # User preferences (e.g., "prefers dark mode")
fact = "fact" # Facts about the user (e.g., "is a Python developer")
instruction = (
"instruction" # Standing instructions (e.g., "always respond in bullet points")
)
context = "context" # Contextual information (e.g., "working on project X")
class UserMemory(BaseModel, TimestampMixin):
"""
Private memory: facts, preferences, context per user per search space.
Used only for private chats (not shared/team chats).
"""
__tablename__ = "user_memories"
user_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# Optional association with a search space (if memory is space-specific)
search_space_id = Column(
Integer,
ForeignKey("searchspaces.id", ondelete="CASCADE"),
nullable=True,
index=True,
)
# The actual memory content
memory_text = Column(Text, nullable=False)
# Category for organization and filtering
category = Column(
SQLAlchemyEnum(MemoryCategory),
nullable=False,
default=MemoryCategory.fact,
)
# Vector embedding for semantic search
embedding = Column(Vector(config.embedding_model_instance.dimension))
# Track when memory was last updated
updated_at = Column(
TIMESTAMP(timezone=True),
nullable=False,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
index=True,
)
# Relationships
user = relationship("User", back_populates="memories")
search_space = relationship("SearchSpace", back_populates="user_memories")
class SharedMemory(BaseModel, TimestampMixin):
__tablename__ = "shared_memories"
search_space_id = Column(
Integer,
ForeignKey("searchspaces.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
created_by_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
memory_text = Column(Text, nullable=False)
category = Column(
SQLAlchemyEnum(MemoryCategory),
nullable=False,
default=MemoryCategory.fact,
)
embedding = Column(Vector(config.embedding_model_instance.dimension))
updated_at = Column(
TIMESTAMP(timezone=True),
nullable=False,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
index=True,
)
search_space = relationship("SearchSpace", back_populates="shared_memories")
created_by = relationship("User")
class Folder(BaseModel, TimestampMixin):
__tablename__ = "folders"
@ -1394,6 +1301,8 @@ class SearchSpace(BaseModel, TimestampMixin):
Text, nullable=True, default=""
) # User's custom instructions
shared_memory_md = Column(Text, nullable=True, server_default="")
# Search space-level LLM preferences (shared by all members)
# Note: ID values:
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
@ -1516,20 +1425,6 @@ class SearchSpace(BaseModel, TimestampMixin):
cascade="all, delete-orphan",
)
# User memories associated with this search space
user_memories = relationship(
"UserMemory",
back_populates="search_space",
order_by="UserMemory.updated_at.desc()",
cascade="all, delete-orphan",
)
shared_memories = relationship(
"SharedMemory",
back_populates="search_space",
order_by="SharedMemory.updated_at.desc()",
cascade="all, delete-orphan",
)
class SearchSourceConnector(BaseModel, TimestampMixin):
__tablename__ = "search_source_connectors"
@ -2037,14 +1932,6 @@ if config.AUTH_TYPE == "GOOGLE":
passive_deletes=True,
)
# User memories for personalized AI responses
memories = relationship(
"UserMemory",
back_populates="user",
order_by="UserMemory.updated_at.desc()",
cascade="all, delete-orphan",
)
# Incentive tasks completed by this user
incentive_tasks = relationship(
"UserIncentiveTask",
@ -2072,6 +1959,8 @@ if config.AUTH_TYPE == "GOOGLE":
last_login = Column(TIMESTAMP(timezone=True), nullable=True)
memory_md = Column(Text, nullable=True, server_default="")
# Refresh tokens for this user
refresh_tokens = relationship(
"RefreshToken",
@ -2157,14 +2046,6 @@ else:
passive_deletes=True,
)
# User memories for personalized AI responses
memories = relationship(
"UserMemory",
back_populates="user",
order_by="UserMemory.updated_at.desc()",
cascade="all, delete-orphan",
)
# Incentive tasks completed by this user
incentive_tasks = relationship(
"UserIncentiveTask",
@ -2192,6 +2073,8 @@ else:
last_login = Column(TIMESTAMP(timezone=True), nullable=True)
memory_md = Column(Text, nullable=True, server_default="")
# Refresh tokens for this user
refresh_tokens = relationship(
"RefreshToken",

View file

@ -36,7 +36,7 @@ async def parse_with_azure_doc_intelligence(file_path: str) -> str:
async with client:
with open(file_path, "rb") as f:
poller = await client.begin_analyze_document(
"prebuilt-read",
"prebuilt-layout",
body=f,
output_content_format=DocumentContentFormat.MARKDOWN,
)

View file

@ -30,6 +30,7 @@ from .jira_add_connector_route import router as jira_add_connector_router
from .linear_add_connector_route import router as linear_add_connector_router
from .logs_routes import router as logs_router
from .luma_add_connector_route import router as luma_add_connector_router
from .memory_routes import router as memory_router
from .model_list_routes import router as model_list_router
from .new_chat_routes import router as new_chat_router
from .new_llm_config_routes import router as new_llm_config_router
@ -100,4 +101,5 @@ router.include_router(incentive_tasks_router) # Incentive tasks for earning fre
router.include_router(stripe_router) # Stripe checkout for additional page packs
router.include_router(youtube_router) # YouTube playlist resolution
router.include_router(prompts_router)
router.include_router(memory_router) # User personal memory (memory.md style)
router.include_router(autocomplete_router) # Lightweight autocomplete with KB context

View file

@ -0,0 +1,153 @@
"""Routes for user memory management (personal memory.md)."""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.llm_config import (
create_chat_litellm_from_agent_config,
load_agent_llm_config_for_search_space,
)
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
from app.db import User, get_async_session
from app.users import current_active_user
logger = logging.getLogger(__name__)
router = APIRouter()
class MemoryRead(BaseModel):
memory_md: str
class MemoryUpdate(BaseModel):
memory_md: str
class MemoryEditRequest(BaseModel):
query: str
search_space_id: int
_MEMORY_EDIT_PROMPT = """\
You are a memory editor. The user wants to modify their memory document. \
Apply the user's instruction to the existing memory document and output the \
FULL updated document.
RULES:
1. If the instruction asks to add something, add it with format: \
- (YYYY-MM-DD) [fact|pref|instr] text, under an existing or new ## heading. \
Heading names should be personal and descriptive, not generic categories.
2. If the instruction asks to remove something, remove the matching entry.
3. If the instruction asks to change something, update the matching entry.
4. Preserve existing ## headings and all other entries.
5. Every bullet must include a marker: [fact], [pref], or [instr].
6. Use the user's first name (from <user_name>) in entries instead of "the user".
7. Output ONLY the updated markdown no explanations, no wrapping.
<user_name>{user_name}</user_name>
<current_memory>
{current_memory}
</current_memory>
<user_instruction>
{instruction}
</user_instruction>"""
@router.get("/users/me/memory", response_model=MemoryRead)
async def get_user_memory(
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
await session.refresh(user, ["memory_md"])
return MemoryRead(memory_md=user.memory_md or "")
@router.put("/users/me/memory", response_model=MemoryRead)
async def update_user_memory(
body: MemoryUpdate,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
if len(body.memory_md) > MEMORY_HARD_LIMIT:
raise HTTPException(
status_code=400,
detail=f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit ({len(body.memory_md):,} chars).",
)
user.memory_md = body.memory_md
session.add(user)
await session.commit()
await session.refresh(user, ["memory_md"])
return MemoryRead(memory_md=user.memory_md or "")
@router.post("/users/me/memory/edit", response_model=MemoryRead)
async def edit_user_memory(
body: MemoryEditRequest,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
"""Apply a natural language edit to the user's personal memory via LLM."""
agent_config = await load_agent_llm_config_for_search_space(
session, body.search_space_id
)
if not agent_config:
raise HTTPException(status_code=500, detail="No LLM configuration available.")
llm = create_chat_litellm_from_agent_config(agent_config)
if not llm:
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
await session.refresh(user, ["memory_md", "display_name"])
current_memory = user.memory_md or ""
first_name = (
user.display_name.strip().split()[0]
if user.display_name and user.display_name.strip()
else "The user"
)
prompt = _MEMORY_EDIT_PROMPT.format(
current_memory=current_memory or "(empty)",
instruction=body.query,
user_name=first_name,
)
try:
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-edit"]},
)
updated = (
response.content
if isinstance(response.content, str)
else str(response.content)
).strip()
except Exception as e:
logger.exception("Memory edit LLM call failed: %s", e)
raise HTTPException(status_code=500, detail="Memory edit failed.") from e
if not updated:
raise HTTPException(status_code=400, detail="LLM returned empty result.")
result = await _save_memory(
updated_memory=updated,
old_memory=current_memory,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="memory",
scope="user",
)
if result.get("status") == "error":
raise HTTPException(status_code=400, detail=result["message"])
await session.refresh(user, ["memory_md"])
return MemoryRead(memory_md=user.memory_md or "")

View file

@ -1,10 +1,17 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage
from pydantic import BaseModel as PydanticBaseModel
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.llm_config import (
create_chat_litellm_from_agent_config,
load_agent_llm_config_for_search_space,
)
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
from app.config import config
from app.db import (
ImageGenerationConfig,
@ -34,6 +41,34 @@ logger = logging.getLogger(__name__)
router = APIRouter()
class _TeamMemoryEditRequest(PydanticBaseModel):
query: str
_TEAM_MEMORY_EDIT_PROMPT = """\
You are a memory editor for a team workspace. The user wants to modify the \
team's shared memory document. Apply the user's instruction to the existing \
memory document and output the FULL updated document.
RULES:
1. If the instruction asks to add something, add it with format: \
- (YYYY-MM-DD) [fact] text, under an existing or new ## heading. \
Heading names should be descriptive, not generic categories.
2. If the instruction asks to remove something, remove the matching entry.
3. If the instruction asks to change something, update the matching entry.
4. Preserve existing ## headings and all other entries.
5. NEVER use [pref] or [instr] markers. Team memory uses [fact] only.
6. Output ONLY the updated markdown no explanations, no wrapping.
<current_memory>
{current_memory}
</current_memory>
<user_instruction>
{instruction}
</user_instruction>"""
async def create_default_roles_and_membership(
session: AsyncSession,
search_space_id: int,
@ -255,6 +290,16 @@ async def update_search_space(
raise HTTPException(status_code=404, detail="Search space not found")
update_data = search_space_update.model_dump(exclude_unset=True)
if (
"shared_memory_md" in update_data
and len(update_data["shared_memory_md"] or "") > MEMORY_HARD_LIMIT
):
raise HTTPException(
status_code=400,
detail=f"Team memory exceeds {MEMORY_HARD_LIMIT:,} character limit.",
)
for key, value in update_data.items():
setattr(db_search_space, key, value)
await session.commit()
@ -269,6 +314,76 @@ async def update_search_space(
) from e
@router.post(
"/searchspaces/{search_space_id}/memory/edit",
response_model=SearchSpaceRead,
)
async def edit_team_memory(
search_space_id: int,
body: _TeamMemoryEditRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Apply a natural language edit to the team memory via LLM."""
await check_search_space_access(session, user, search_space_id)
agent_config = await load_agent_llm_config_for_search_space(
session, search_space_id
)
if not agent_config:
raise HTTPException(status_code=500, detail="No LLM configuration available.")
llm = create_chat_litellm_from_agent_config(agent_config)
if not llm:
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
db_search_space = result.scalars().first()
if not db_search_space:
raise HTTPException(status_code=404, detail="Search space not found")
current_memory = db_search_space.shared_memory_md or ""
prompt = _TEAM_MEMORY_EDIT_PROMPT.format(
current_memory=current_memory or "(empty)",
instruction=body.query,
)
try:
response = await llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal", "memory-edit"]},
)
updated = (
response.content
if isinstance(response.content, str)
else str(response.content)
).strip()
except Exception as e:
logger.exception("Team memory edit LLM call failed: %s", e)
raise HTTPException(status_code=500, detail="Team memory edit failed.") from e
if not updated:
raise HTTPException(status_code=400, detail="LLM returned empty result.")
save_result = await _save_memory(
updated_memory=updated,
old_memory=current_memory,
llm=llm,
apply_fn=lambda content: setattr(db_search_space, "shared_memory_md", content),
commit_fn=session.commit,
rollback_fn=session.rollback,
label="team memory",
scope="team",
)
if save_result.get("status") == "error":
raise HTTPException(status_code=400, detail=save_result["message"])
await session.refresh(db_search_space)
return db_search_space
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
async def delete_search_space(
search_space_id: int,

View file

@ -21,6 +21,7 @@ class SearchSpaceUpdate(BaseModel):
description: str | None = None
citations_enabled: bool | None = None
qna_custom_instructions: str | None = None
shared_memory_md: str | None = None
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
@ -29,6 +30,7 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
user_id: uuid.UUID
citations_enabled: bool
qna_custom_instructions: str | None = None
shared_memory_md: str | None = None
model_config = ConfigDict(from_attributes=True)

View file

@ -37,6 +37,10 @@ from app.agents.new_chat.llm_config import (
load_agent_config,
load_llm_config_from_yaml,
)
from app.agents.new_chat.memory_extraction import (
extract_and_save_memory,
extract_and_save_team_memory,
)
from app.db import (
ChatVisibility,
NewChatMessage,
@ -59,8 +63,6 @@ from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_hea
_perf_log = get_perf_logger()
_background_tasks: set[asyncio.Task] = set()
def format_mentioned_surfsense_docs_as_context(
documents: list[SurfsenseDocsDocument],
@ -141,6 +143,7 @@ class StreamResult:
is_interrupted: bool = False
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
agent_called_update_memory: bool = False
async def _stream_agent_events(
@ -183,6 +186,7 @@ async def _stream_agent_events(
last_active_step_items: list[str] = initial_step_items or []
just_finished_tool: bool = False
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
called_update_memory: bool = False
def next_thinking_step_id() -> str:
nonlocal thinking_step_counter
@ -490,6 +494,9 @@ async def _stream_agent_events(
tool_name = event.get("name", "unknown_tool")
raw_output = event.get("data", {}).get("output", "")
if tool_name == "update_memory":
called_update_memory = True
if hasattr(raw_output, "content"):
content = raw_output.content
if isinstance(content, str):
@ -1111,6 +1118,7 @@ async def _stream_agent_events(
yield completion_event
result.accumulated_text = accumulated_text
result.agent_called_update_memory = called_update_memory
state = await agent.aget_state(config)
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
@ -1540,6 +1548,27 @@ async def stream_new_chat(
chat_id, generated_title
)
# Fire background memory extraction if the agent didn't handle it.
# Shared threads write to team memory; private threads write to user memory.
if not stream_result.agent_called_update_memory:
if visibility == ChatVisibility.SEARCH_SPACE:
asyncio.create_task(
extract_and_save_team_memory(
user_message=user_query,
search_space_id=search_space_id,
llm=llm,
author_display_name=current_user_display_name,
)
)
elif user_id:
asyncio.create_task(
extract_and_save_memory(
user_message=user_query,
user_id=user_id,
llm=llm,
)
)
# Finish the step and message
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()

View file

@ -0,0 +1,198 @@
"""Unit tests for memory scope validation and bullet format validation."""
import pytest
from app.agents.new_chat.tools.update_memory import (
_save_memory,
_validate_bullet_format,
_validate_memory_scope,
)
pytestmark = pytest.mark.unit
class _Recorder:
def __init__(self) -> None:
self.applied_content: str | None = None
self.commit_calls = 0
self.rollback_calls = 0
def apply(self, content: str) -> None:
self.applied_content = content
async def commit(self) -> None:
self.commit_calls += 1
async def rollback(self) -> None:
self.rollback_calls += 1
# ---------------------------------------------------------------------------
# _validate_memory_scope — marker-based
# ---------------------------------------------------------------------------
def test_validate_memory_scope_rejects_pref_marker_in_team_scope() -> None:
content = "- (2026-04-10) [pref] Prefers dark mode\n"
result = _validate_memory_scope(content, "team")
assert result is not None
assert result["status"] == "error"
assert "[pref]" in result["message"]
def test_validate_memory_scope_rejects_instr_marker_in_team_scope() -> None:
content = "- (2026-04-10) [instr] Always respond in Spanish\n"
result = _validate_memory_scope(content, "team")
assert result is not None
assert result["status"] == "error"
assert "[instr]" in result["message"]
def test_validate_memory_scope_rejects_both_personal_markers_in_team() -> None:
content = (
"- (2026-04-10) [pref] Prefers dark mode\n"
"- (2026-04-10) [instr] Always respond in Spanish\n"
)
result = _validate_memory_scope(content, "team")
assert result is not None
assert result["status"] == "error"
assert "[instr]" in result["message"]
assert "[pref]" in result["message"]
def test_validate_memory_scope_allows_fact_in_team_scope() -> None:
content = "- (2026-04-10) [fact] Office is in downtown Seattle\n"
result = _validate_memory_scope(content, "team")
assert result is None
def test_validate_memory_scope_allows_all_markers_in_user_scope() -> None:
content = (
"- (2026-04-10) [fact] Python developer\n"
"- (2026-04-10) [pref] Prefers concise answers\n"
"- (2026-04-10) [instr] Always use bullet points\n"
)
result = _validate_memory_scope(content, "user")
assert result is None
def test_validate_memory_scope_allows_any_heading_in_team() -> None:
content = "## Architecture\n- (2026-04-10) [fact] Uses PostgreSQL for persistence\n"
result = _validate_memory_scope(content, "team")
assert result is None
def test_validate_memory_scope_allows_any_heading_in_user() -> None:
content = "## My Projects\n- (2026-04-10) [fact] Working on SurfSense\n"
result = _validate_memory_scope(content, "user")
assert result is None
# ---------------------------------------------------------------------------
# _validate_bullet_format
# ---------------------------------------------------------------------------
def test_validate_bullet_format_passes_valid_bullets() -> None:
content = (
"## Work\n"
"- (2026-04-10) [fact] Senior Python developer\n"
"- (2026-04-10) [pref] Prefers dark mode\n"
"- (2026-04-10) [instr] Always respond in bullet points\n"
)
warnings = _validate_bullet_format(content)
assert warnings == []
def test_validate_bullet_format_warns_on_missing_marker() -> None:
content = "- (2026-04-10) Senior Python developer\n"
warnings = _validate_bullet_format(content)
assert len(warnings) == 1
assert "Malformed bullet" in warnings[0]
def test_validate_bullet_format_warns_on_missing_date() -> None:
content = "- [fact] Senior Python developer\n"
warnings = _validate_bullet_format(content)
assert len(warnings) == 1
assert "Malformed bullet" in warnings[0]
def test_validate_bullet_format_warns_on_unknown_marker() -> None:
content = "- (2026-04-10) [context] Working on project X\n"
warnings = _validate_bullet_format(content)
assert len(warnings) == 1
assert "Malformed bullet" in warnings[0]
def test_validate_bullet_format_ignores_non_bullet_lines() -> None:
content = "## Some Heading\nSome paragraph text\n"
warnings = _validate_bullet_format(content)
assert warnings == []
def test_validate_bullet_format_warns_on_old_format_without_marker() -> None:
content = "## About the user\n- (2026-04-10) Likes cats\n"
warnings = _validate_bullet_format(content)
assert len(warnings) == 1
# ---------------------------------------------------------------------------
# _save_memory — end-to-end with marker scope check
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_save_memory_blocks_pref_in_team_before_commit() -> None:
recorder = _Recorder()
result = await _save_memory(
updated_memory="- (2026-04-10) [pref] Prefers dark mode\n",
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="team memory",
scope="team",
)
assert result["status"] == "error"
assert recorder.commit_calls == 0
assert recorder.applied_content is None
@pytest.mark.asyncio
async def test_save_memory_allows_fact_in_team_and_commits() -> None:
recorder = _Recorder()
content = "- (2026-04-10) [fact] Weekly standup on Mondays\n"
result = await _save_memory(
updated_memory=content,
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="team memory",
scope="team",
)
assert result["status"] == "saved"
assert recorder.commit_calls == 1
assert recorder.applied_content == content
@pytest.mark.asyncio
async def test_save_memory_includes_format_warnings() -> None:
recorder = _Recorder()
content = "- (2026-04-10) Missing marker text\n"
result = await _save_memory(
updated_memory=content,
old_memory=None,
llm=None,
apply_fn=recorder.apply,
commit_fn=recorder.commit,
rollback_fn=recorder.rollback,
label="memory",
scope="user",
)
assert result["status"] == "saved"
assert "format_warnings" in result
assert len(result["format_warnings"]) == 1