From dfa6005af52ec2b1d7e13438167d600b64cd9b8e Mon Sep 17 00:00:00 2001
From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com>
Date: Wed, 8 Apr 2026 23:54:29 +0530
Subject: [PATCH] feat: implement update_memory tool and routes for user memory management

---
 .../agents/new_chat/tools/update_memory.py    | 177 ++++++++++++++++++
 surfsense_backend/app/routes/memory_routes.py |  56 ++++++
 2 files changed, 233 insertions(+)
 create mode 100644 surfsense_backend/app/agents/new_chat/tools/update_memory.py
 create mode 100644 surfsense_backend/app/routes/memory_routes.py

diff --git a/surfsense_backend/app/agents/new_chat/tools/update_memory.py b/surfsense_backend/app/agents/new_chat/tools/update_memory.py
new file mode 100644
index 000000000..1bb51b94f
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/tools/update_memory.py
@@ -0,0 +1,177 @@
+"""Markdown-document memory tool for the SurfSense agent.
+
+Replaces the old row-per-fact save_memory / recall_memory tools with a single
+update_memory tool that overwrites a freeform markdown TEXT column. The LLM
+always sees the current memory in <user_memory> / <team_memory> tags injected
+by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+from uuid import UUID
+
+from langchain_core.tools import tool
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import SearchSpace, User
+
+logger = logging.getLogger(__name__)
+
+MEMORY_SOFT_LIMIT = 20_000
+MEMORY_HARD_LIMIT = 25_000
+
+
+def _validate_memory_size(content: str) -> dict[str, Any] | None:
+    """Return an error dict if *content* exceeds the hard limit, else None."""
+    length = len(content)
+    if length > MEMORY_HARD_LIMIT:
+        return {
+            "status": "error",
+            "message": (
+                f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
+                f"({length:,} chars). Consolidate by merging related items, "
+                "removing outdated entries, and shortening descriptions. "
+                "Then call update_memory again."
+            ),
+        }
+    return None
+
+
+def _soft_warning(content: str) -> str | None:
+    """Return a warning string if content exceeds the soft limit."""
+    length = len(content)
+    if length > MEMORY_SOFT_LIMIT:
+        return (
+            f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
+            "Consolidate by merging related items and removing less important "
+            "entries on your next update."
+        )
+    return None
+
+
+def create_update_memory_tool(
+    user_id: str | UUID,
+    db_session: AsyncSession,
+):
+    """Build the ``update_memory`` tool bound to a single user.
+
+    Args:
+        user_id: Owner of the memory; a ``str`` is parsed into a ``UUID``.
+        db_session: Async session used to load the user row and commit the change.
+
+    Returns:
+        The async ``update_memory`` tool.
+    """
+    uid = UUID(user_id) if isinstance(user_id, str) else user_id
+
+    @tool
+    async def update_memory(updated_memory: str) -> dict[str, Any]:
+        """Update the user's personal memory document.
+
+        Your current memory is shown in <user_memory> in the system prompt.
+        When the user shares important long-term information (preferences,
+        facts, instructions, context), rewrite the memory document to include
+        the new information. Merge new facts with existing ones, update
+        contradictions, remove outdated entries, and keep it concise.
+
+        Args:
+            updated_memory: The FULL updated markdown document (not a diff).
+        """
+        error = _validate_memory_size(updated_memory)
+        if error:
+            return error
+
+        try:
+            result = await db_session.execute(
+                select(User).where(User.id == uid)
+            )
+            user = result.scalars().first()
+            if not user:
+                return {"status": "error", "message": "User not found."}
+
+            user.memory_md = updated_memory
+            await db_session.commit()
+
+            resp: dict[str, Any] = {
+                "status": "saved",
+                "message": "Memory updated.",
+            }
+            warning = _soft_warning(updated_memory)
+            if warning:
+                resp["warning"] = warning
+            return resp
+        except Exception as e:
+            # Report the failure as a tool result instead of raising.
+            logger.exception("Failed to update user memory")
+            await db_session.rollback()
+            return {
+                "status": "error",
+                "message": f"Failed to update memory: {e}",
+            }
+
+    return update_memory
+
+
+def create_update_team_memory_tool(
+    search_space_id: int,
+    db_session: AsyncSession,
+):
+    """Build the ``update_memory`` tool bound to a single search space.
+
+    Args:
+        search_space_id: Id of the ``SearchSpace`` row whose shared memory is written.
+        db_session: Async session used to load the row and commit the change.
+
+    Returns:
+        The async ``update_memory`` tool.
+    """
+    @tool
+    async def update_memory(updated_memory: str) -> dict[str, Any]:
+        """Update the team's shared memory document for this search space.
+
+        Your current team memory is shown in <team_memory> in the system
+        prompt. When the team shares important long-term information
+        (decisions, conventions, key facts, priorities), rewrite the memory
+        document to include the new information. Merge new facts with
+        existing ones, update contradictions, remove outdated entries, and
+        keep it concise.
+
+        Args:
+            updated_memory: The FULL updated markdown document (not a diff).
+        """
+        error = _validate_memory_size(updated_memory)
+        if error:
+            return error
+
+        try:
+            result = await db_session.execute(
+                select(SearchSpace).where(SearchSpace.id == search_space_id)
+            )
+            space = result.scalars().first()
+            if not space:
+                return {"status": "error", "message": "Search space not found."}
+
+            space.shared_memory_md = updated_memory
+            await db_session.commit()
+
+            resp: dict[str, Any] = {
+                "status": "saved",
+                "message": "Team memory updated.",
+            }
+            warning = _soft_warning(updated_memory)
+            if warning:
+                resp["warning"] = warning
+            return resp
+        except Exception as e:
+            # Report the failure as a tool result instead of raising.
+            logger.exception("Failed to update team memory")
+            await db_session.rollback()
+            return {
+                "status": "error",
+                "message": f"Failed to update team memory: {e}",
+            }
+
+    return update_memory
diff --git a/surfsense_backend/app/routes/memory_routes.py b/surfsense_backend/app/routes/memory_routes.py
new file mode 100644
index 000000000..aa8b1be28
--- /dev/null
+++ b/surfsense_backend/app/routes/memory_routes.py
@@ -0,0 +1,56 @@
+"""Routes for user memory management (personal memory.md)."""
+
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT
+from app.db import User, get_async_session
+from app.users import current_active_user
+
+router = APIRouter()
+
+
+class MemoryRead(BaseModel):
+    """Response body carrying the user's memory markdown."""
+
+    memory_md: str
+
+
+class MemoryUpdate(BaseModel):
+    """Request body carrying the full replacement memory markdown."""
+
+    memory_md: str
+
+
+@router.get("/users/me/memory", response_model=MemoryRead)
+async def get_user_memory(
+    user: User = Depends(current_active_user),
+    session: AsyncSession = Depends(get_async_session),
+):
+    """Return the authenticated user's memory document (empty string if unset)."""
+    await session.refresh(user, ["memory_md"])
+    return MemoryRead(memory_md=user.memory_md or "")
+
+
+@router.put("/users/me/memory", response_model=MemoryRead)
+async def update_user_memory(
+    body: MemoryUpdate,
+    user: User = Depends(current_active_user),
+    session: AsyncSession = Depends(get_async_session),
+):
+    """Overwrite the authenticated user's memory document and return it.
+
+    Raises:
+        HTTPException: 400 when the body exceeds ``MEMORY_HARD_LIMIT`` characters.
+    """
+    if len(body.memory_md) > MEMORY_HARD_LIMIT:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit ({len(body.memory_md):,} chars).",
+        )
+    user.memory_md = body.memory_md
+    session.add(user)
+    await session.commit()
+    await session.refresh(user, ["memory_md"])
+    return MemoryRead(memory_md=user.memory_md or "")