diff --git a/surfsense_backend/alembic/versions/73_add_user_memories_table.py b/surfsense_backend/alembic/versions/73_add_user_memories_table.py new file mode 100644 index 000000000..40ecfd91b --- /dev/null +++ b/surfsense_backend/alembic/versions/73_add_user_memories_table.py @@ -0,0 +1,136 @@ +"""Add user_memories table for AI memory feature + +Revision ID: 73 +Revises: 72 +Create Date: 2026-01-20 + +This migration adds the user_memories table which enables Claude-like memory +functionality - allowing the AI to remember facts, preferences, and context +about users across conversations. +""" + +from collections.abc import Sequence + +from alembic import op + +from app.config import config + +# revision identifiers, used by Alembic. +revision: str = "73" +down_revision: str | None = "72" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +# Get embedding dimension from config +EMBEDDING_DIM = config.embedding_model_instance.dimension + + +def upgrade() -> None: + """Create user_memories table and MemoryCategory enum.""" + + # Create the MemoryCategory enum type + op.execute( + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'memorycategory') THEN + CREATE TYPE memorycategory AS ENUM ( + 'preference', + 'fact', + 'instruction', + 'context' + ); + END IF; + END$$; + """ + ) + + # Create user_memories table + op.execute( + f""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'user_memories' + ) THEN + CREATE TABLE user_memories ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + search_space_id INTEGER REFERENCES searchspaces(id) ON DELETE CASCADE, + memory_text TEXT NOT NULL, + category memorycategory NOT NULL DEFAULT 'fact', + embedding vector({EMBEDDING_DIM}), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + ); + END IF; + END$$; 
+ """ + ) + + # Create indexes for efficient querying + op.execute( + """ + DO $$ + BEGIN + -- Index on user_id for filtering memories by user + IF NOT EXISTS ( + SELECT 1 FROM pg_indexes + WHERE tablename = 'user_memories' AND indexname = 'ix_user_memories_user_id' + ) THEN + CREATE INDEX ix_user_memories_user_id ON user_memories(user_id); + END IF; + + -- Index on search_space_id for filtering memories by search space + IF NOT EXISTS ( + SELECT 1 FROM pg_indexes + WHERE tablename = 'user_memories' AND indexname = 'ix_user_memories_search_space_id' + ) THEN + CREATE INDEX ix_user_memories_search_space_id ON user_memories(search_space_id); + END IF; + + -- Index on updated_at for ordering by recency + IF NOT EXISTS ( + SELECT 1 FROM pg_indexes + WHERE tablename = 'user_memories' AND indexname = 'ix_user_memories_updated_at' + ) THEN + CREATE INDEX ix_user_memories_updated_at ON user_memories(updated_at); + END IF; + + -- Index on category for filtering by memory type + IF NOT EXISTS ( + SELECT 1 FROM pg_indexes + WHERE tablename = 'user_memories' AND indexname = 'ix_user_memories_category' + ) THEN + CREATE INDEX ix_user_memories_category ON user_memories(category); + END IF; + + -- Composite index for common query pattern (user + search space) + IF NOT EXISTS ( + SELECT 1 FROM pg_indexes + WHERE tablename = 'user_memories' AND indexname = 'ix_user_memories_user_search_space' + ) THEN + CREATE INDEX ix_user_memories_user_search_space ON user_memories(user_id, search_space_id); + END IF; + END$$; + """ + ) + + # Create vector index for semantic search + op.execute( + """ + CREATE INDEX IF NOT EXISTS user_memories_vector_index + ON user_memories USING hnsw (embedding public.vector_cosine_ops); + """ + ) + + +def downgrade() -> None: + """Drop user_memories table and MemoryCategory enum.""" + + # Drop the table + op.execute("DROP TABLE IF EXISTS user_memories CASCADE;") + + # Drop the enum type + op.execute("DROP TYPE IF EXISTS memorycategory;") diff --git 
a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 9675521f5..5bc6ac2e2 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -34,6 +34,7 @@ async def create_surfsense_deep_agent( db_session: AsyncSession, connector_service: ConnectorService, checkpointer: Checkpointer, + user_id: str | None = None, agent_config: AgentConfig | None = None, enabled_tools: list[str] | None = None, disabled_tools: list[str] | None = None, @@ -49,6 +50,8 @@ async def create_surfsense_deep_agent( - link_preview: Fetch rich previews for URLs - display_image: Display images in chat - scrape_webpage: Extract content from webpages + - save_memory: Store facts/preferences about the user + - recall_memory: Retrieve relevant user memories The agent also includes TodoListMiddleware by default (via create_deep_agent) which provides: - write_todos: Create and update planning/todo lists for complex tasks @@ -64,6 +67,7 @@ async def create_surfsense_deep_agent( connector_service: Initialized connector service for knowledge base search checkpointer: LangGraph checkpointer for conversation state persistence. Use AsyncPostgresSaver for production or MemorySaver for testing. + user_id: The current user's UUID string (required for memory tools) agent_config: Optional AgentConfig from NewLLMConfig for prompt configuration. If None, uses default system prompt with citations enabled. enabled_tools: Explicit list of tool names to enable. 
If None, all default tools @@ -118,6 +122,7 @@ async def create_surfsense_deep_agent( "db_session": db_session, "connector_service": connector_service, "firecrawl_api_key": firecrawl_api_key, + "user_id": user_id, # Required for memory tools } # Build tools using the async registry (includes MCP tools) diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index 76429a830..d8202a8b0 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -116,6 +116,45 @@ You have access to the following tools: * This makes your response more visual and engaging. * Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content. * Don't show every image - just the most relevant 1-3 images that enhance understanding. + +6. save_memory: Save facts, preferences, or context about the user for personalized responses. + - Use this when the user explicitly or implicitly shares information worth remembering. + - Trigger scenarios: + * User says "remember this", "keep this in mind", "note that", or similar + * User shares personal preferences (e.g., "I prefer Python over JavaScript") + * User shares facts about themselves (e.g., "I'm a senior developer at Company X") + * User gives standing instructions (e.g., "always respond in bullet points") + * User shares project context (e.g., "I'm working on migrating our codebase to TypeScript") + - Args: + - content: The fact/preference to remember. 
Phrase it clearly: + * "User prefers dark mode for all interfaces" + * "User is a senior Python developer" + * "User wants responses in bullet point format" + * "User is working on project called ProjectX" + - category: Type of memory: + * "preference": User preferences (coding style, tools, formats) + * "fact": Facts about the user (role, expertise, background) + * "instruction": Standing instructions (response format, communication style) + * "context": Current context (ongoing projects, goals, challenges) + - Returns: Confirmation of saved memory + - IMPORTANT: Only save information that would be genuinely useful for future conversations. + Don't save trivial or temporary information. + +7. recall_memory: Retrieve relevant memories about the user for personalized responses. + - Use this to access stored information about the user. + - Trigger scenarios: + * You need user context to give a better, more personalized answer + * User references something they mentioned before + * User asks "what do you know about me?" or similar + * Personalization would significantly improve response quality + * Before making recommendations that should consider user preferences + - Args: + - query: Optional search query to find specific memories (e.g., "programming preferences") + - category: Optional filter by category ("preference", "fact", "instruction", "context") + - top_k: Number of memories to retrieve (default: 5) + - Returns: Relevant memories formatted as context + - IMPORTANT: Use the recalled memories naturally in your response without explicitly + stating "Based on your memory..." - integrate the context seamlessly. - User: "How do I install SurfSense?" @@ -136,6 +175,23 @@ You have access to the following tools: - User: "What did I discuss on Slack last week about the React migration?" 
- Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")` +- User: "Remember that I prefer TypeScript over JavaScript" + - Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")` + +- User: "I'm a data scientist working on ML pipelines" + - Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")` + +- User: "Always give me code examples in Python" + - Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")` + +- User: "What programming language should I use for this project?" + - First recall: `recall_memory(query="programming language preferences")` + - Then provide a personalized recommendation based on their preferences + +- User: "What do you know about me?" + - Call: `recall_memory(top_k=10)` + - Then summarize the stored memories + - User: "Give me a podcast about AI trends based on what we discussed" - First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")` diff --git a/surfsense_backend/app/agents/new_chat/tools/__init__.py b/surfsense_backend/app/agents/new_chat/tools/__init__.py index b531d9b4d..acbdbcb3a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/__init__.py +++ b/surfsense_backend/app/agents/new_chat/tools/__init__.py @@ -11,6 +11,8 @@ Available tools: - link_preview: Fetch rich previews for URLs - display_image: Display images in chat - scrape_webpage: Extract content from webpages +- save_memory: Store facts/preferences about the user +- recall_memory: Retrieve relevant user memories """ # Registry exports @@ -33,6 +35,7 @@ from .registry import ( ) from .scrape_webpage import create_scrape_webpage_tool from .search_surfsense_docs import 
create_search_surfsense_docs_tool +from .user_memory import create_recall_memory_tool, create_save_memory_tool __all__ = [ # Registry @@ -43,6 +46,8 @@ __all__ = [ "create_display_image_tool", "create_generate_podcast_tool", "create_link_preview_tool", + "create_recall_memory_tool", + "create_save_memory_tool", "create_scrape_webpage_tool", "create_search_knowledge_base_tool", "create_search_surfsense_docs_tool", diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index 6873f864c..e4ce7a6b7 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -50,6 +50,7 @@ from .mcp_tool import load_mcp_tools from .podcast import create_generate_podcast_tool from .scrape_webpage import create_scrape_webpage_tool from .search_surfsense_docs import create_search_surfsense_docs_tool +from .user_memory import create_recall_memory_tool, create_save_memory_tool # ============================================================================= # Tool Definition @@ -138,6 +139,31 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ requires=["db_session"], ), # ========================================================================= + # USER MEMORY TOOLS - Claude-like memory feature + # ========================================================================= + # Save memory tool - stores facts/preferences about the user + ToolDefinition( + name="save_memory", + description="Save facts, preferences, or context about the user for personalized responses", + factory=lambda deps: create_save_memory_tool( + user_id=deps["user_id"], + search_space_id=deps["search_space_id"], + db_session=deps["db_session"], + ), + requires=["user_id", "search_space_id", "db_session"], + ), + # Recall memory tool - retrieves relevant user memories + ToolDefinition( + name="recall_memory", + description="Recall user memories for personalized and contextual responses", + 
factory=lambda deps: create_recall_memory_tool( + user_id=deps["user_id"], + search_space_id=deps["search_space_id"], + db_session=deps["db_session"], + ), + requires=["user_id", "search_space_id", "db_session"], + ), + # ========================================================================= # ADD YOUR CUSTOM TOOLS BELOW # ========================================================================= # Example: diff --git a/surfsense_backend/app/agents/new_chat/tools/user_memory.py b/surfsense_backend/app/agents/new_chat/tools/user_memory.py new file mode 100644 index 000000000..3cefa2b02 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/user_memory.py @@ -0,0 +1,339 @@ +""" +User memory tools for the SurfSense agent. + +This module provides tools for storing and retrieving user memories, +enabling personalized AI responses similar to Claude's memory feature. + +Features: +- save_memory: Store facts, preferences, and context about the user +- recall_memory: Retrieve relevant memories using semantic search +""" + +import logging +from datetime import UTC, datetime +from typing import Any +from uuid import UUID + +from langchain_core.tools import tool +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import MemoryCategory, UserMemory + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Constants +# ============================================================================= + +# Default number of memories to retrieve +DEFAULT_RECALL_TOP_K = 5 + +# Maximum number of memories per user (to prevent unbounded growth) +MAX_MEMORIES_PER_USER = 100 + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def _to_uuid(user_id: str) -> UUID: + """Convert a string user_id to a UUID 
object.""" + if isinstance(user_id, UUID): + return user_id + return UUID(user_id) + + +async def get_user_memory_count( + db_session: AsyncSession, + user_id: str, + search_space_id: int | None = None, +) -> int: + """Get the count of memories for a user.""" + uuid_user_id = _to_uuid(user_id) + query = select(UserMemory).where(UserMemory.user_id == uuid_user_id) + if search_space_id is not None: + query = query.where( + (UserMemory.search_space_id == search_space_id) + | (UserMemory.search_space_id.is_(None)) + ) + result = await db_session.execute(query) + return len(result.scalars().all()) + + +async def delete_oldest_memory( + db_session: AsyncSession, + user_id: str, + search_space_id: int | None = None, +) -> None: + """Delete the oldest memory for a user to make room for new ones.""" + uuid_user_id = _to_uuid(user_id) + query = ( + select(UserMemory) + .where(UserMemory.user_id == uuid_user_id) + .order_by(UserMemory.updated_at.asc()) + .limit(1) + ) + if search_space_id is not None: + query = query.where( + (UserMemory.search_space_id == search_space_id) + | (UserMemory.search_space_id.is_(None)) + ) + result = await db_session.execute(query) + oldest_memory = result.scalars().first() + if oldest_memory: + await db_session.delete(oldest_memory) + await db_session.commit() + + +def format_memories_for_context(memories: list[dict[str, Any]]) -> str: + """Format retrieved memories into a readable context string for the LLM.""" + if not memories: + return "No relevant memories found for this user." 
+ + parts = [""] + for memory in memories: + category = memory.get("category", "unknown") + text = memory.get("memory_text", "") + updated = memory.get("updated_at", "") + parts.append(f" {text}") + parts.append("") + + return "\n".join(parts) + + +# ============================================================================= +# Tool Factory Functions +# ============================================================================= + + +def create_save_memory_tool( + user_id: str, + search_space_id: int, + db_session: AsyncSession, +): + """ + Factory function to create the save_memory tool. + + Args: + user_id: The user's UUID + search_space_id: The search space ID (for space-specific memories) + db_session: Database session for executing queries + + Returns: + A configured tool function for saving user memories + """ + + @tool + async def save_memory( + content: str, + category: str = "fact", + ) -> dict[str, Any]: + """ + Save a fact, preference, or context about the user for future reference. + + Use this tool when: + - User explicitly says "remember this", "keep this in mind", or similar + - User shares personal preferences (e.g., "I prefer Python over JavaScript") + - User shares important facts about themselves (name, role, interests, projects) + - User gives standing instructions (e.g., "always respond in bullet points") + - User shares relevant context (e.g., "I'm working on project X") + + The saved information will be available in future conversations to provide + more personalized and contextual responses. + + Args: + content: The fact/preference/context to remember. + Phrase it clearly, e.g., "User prefers dark mode", + "User is a senior Python developer", "User is working on an AI project" + category: Type of memory. 
One of: + - "preference": User preferences (e.g., coding style, tools, formats) + - "fact": Facts about the user (e.g., name, role, expertise) + - "instruction": Standing instructions (e.g., response format preferences) + - "context": Current context (e.g., ongoing projects, goals) + + Returns: + A dictionary with the save status and memory details + """ + # Validate category + valid_categories = ["preference", "fact", "instruction", "context"] + if category not in valid_categories: + category = "fact" + + try: + # Convert user_id to UUID + uuid_user_id = _to_uuid(user_id) + + # Check if we've hit the memory limit + memory_count = await get_user_memory_count( + db_session, user_id, search_space_id + ) + if memory_count >= MAX_MEMORIES_PER_USER: + # Delete oldest memory to make room + await delete_oldest_memory(db_session, user_id, search_space_id) + + # Generate embedding for the memory + embedding = config.embedding_model_instance.embed(content) + + # Map string category to enum + category_enum = MemoryCategory(category) + + # Create new memory + new_memory = UserMemory( + user_id=uuid_user_id, + search_space_id=search_space_id, + memory_text=content, + category=category_enum, + embedding=embedding, + updated_at=datetime.now(UTC), + ) + + db_session.add(new_memory) + await db_session.commit() + await db_session.refresh(new_memory) + + return { + "status": "saved", + "memory_id": new_memory.id, + "memory_text": content, + "category": category, + "message": f"I'll remember: {content}", + } + + except Exception as e: + logger.exception(f"Failed to save memory for user {user_id}: {e}") + return { + "status": "error", + "error": str(e), + "message": "Failed to save memory. Please try again.", + } + + return save_memory + + +def create_recall_memory_tool( + user_id: str, + search_space_id: int, + db_session: AsyncSession, +): + """ + Factory function to create the recall_memory tool. 
+ + Args: + user_id: The user's UUID + search_space_id: The search space ID + db_session: Database session for executing queries + + Returns: + A configured tool function for recalling user memories + """ + + @tool + async def recall_memory( + query: str | None = None, + category: str | None = None, + top_k: int = DEFAULT_RECALL_TOP_K, + ) -> dict[str, Any]: + """ + Recall relevant memories about the user to provide personalized responses. + + Use this tool when: + - You need user context to give a better, more personalized answer + - User asks about their preferences or past information they shared + - User references something they told you before + - Personalization would significantly improve the response quality + - User asks "what do you know about me?" or similar + + Args: + query: Optional search query to find specific memories. + If not provided, returns the most recent memories. + Example: "programming preferences", "current projects" + category: Optional category filter. One of: + "preference", "fact", "instruction", "context" + If not provided, searches all categories. 
+ top_k: Number of memories to retrieve (default: 5, max: 20) + + Returns: + A dictionary containing relevant memories and formatted context + """ + top_k = min(max(top_k, 1), 20) # Clamp between 1 and 20 + + try: + # Convert user_id to UUID + uuid_user_id = _to_uuid(user_id) + + if query: + # Semantic search using embeddings + query_embedding = config.embedding_model_instance.embed(query) + + # Build query with vector similarity + stmt = ( + select(UserMemory) + .where(UserMemory.user_id == uuid_user_id) + .where( + (UserMemory.search_space_id == search_space_id) + | (UserMemory.search_space_id.is_(None)) + ) + ) + + # Add category filter if specified + if category and category in ["preference", "fact", "instruction", "context"]: + stmt = stmt.where(UserMemory.category == MemoryCategory(category)) + + # Order by vector similarity + stmt = stmt.order_by( + UserMemory.embedding.op("<=>")(query_embedding) + ).limit(top_k) + + else: + # No query - return most recent memories + stmt = ( + select(UserMemory) + .where(UserMemory.user_id == uuid_user_id) + .where( + (UserMemory.search_space_id == search_space_id) + | (UserMemory.search_space_id.is_(None)) + ) + ) + + # Add category filter if specified + if category and category in ["preference", "fact", "instruction", "context"]: + stmt = stmt.where(UserMemory.category == MemoryCategory(category)) + + stmt = stmt.order_by(UserMemory.updated_at.desc()).limit(top_k) + + result = await db_session.execute(stmt) + memories = result.scalars().all() + + # Format memories for response + memory_list = [ + { + "id": m.id, + "memory_text": m.memory_text, + "category": m.category.value if m.category else "unknown", + "updated_at": m.updated_at.isoformat() if m.updated_at else None, + } + for m in memories + ] + + formatted_context = format_memories_for_context(memory_list) + + return { + "status": "success", + "count": len(memory_list), + "memories": memory_list, + "formatted_context": formatted_context, + } + + except Exception as 
e: + return { + "status": "error", + "error": str(e), + "memories": [], + "formatted_context": "Failed to recall memories.", + } + + return recall_memory diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 38e27ecf2..c23b133e2 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -472,6 +472,63 @@ class ChatCommentMention(BaseModel, TimestampMixin): mentioned_user = relationship("User") +class MemoryCategory(str, Enum): + """Categories for user memories.""" + + PREFERENCE = "preference" # User preferences (e.g., "prefers dark mode") + FACT = "fact" # Facts about the user (e.g., "is a Python developer") + INSTRUCTION = "instruction" # Standing instructions (e.g., "always respond in bullet points") + CONTEXT = "context" # Contextual information (e.g., "working on project X") + + +class UserMemory(BaseModel, TimestampMixin): + """ + Stores facts, preferences, and context about users for personalized AI responses. + Similar to Claude's memory feature - enables the AI to remember user information + across conversations. 
+ """ + + __tablename__ = "user_memories" + + user_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + # Optional association with a search space (if memory is space-specific) + search_space_id = Column( + Integer, + ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=True, + index=True, + ) + + # The actual memory content + memory_text = Column(Text, nullable=False) + # Category for organization and filtering + category = Column( + SQLAlchemyEnum(MemoryCategory), + nullable=False, + default=MemoryCategory.FACT, + ) + # Vector embedding for semantic search + embedding = Column(Vector(config.embedding_model_instance.dimension)) + + # Track when memory was last updated + updated_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + index=True, + ) + + # Relationships + user = relationship("User", back_populates="memories") + search_space = relationship("SearchSpace", back_populates="user_memories") + + class Document(BaseModel, TimestampMixin): __tablename__ = "documents" @@ -659,6 +716,14 @@ class SearchSpace(BaseModel, TimestampMixin): cascade="all, delete-orphan", ) + # User memories associated with this search space + user_memories = relationship( + "UserMemory", + back_populates="search_space", + order_by="UserMemory.updated_at.desc()", + cascade="all, delete-orphan", + ) + class SearchSourceConnector(BaseModel, TimestampMixin): __tablename__ = "search_source_connectors" @@ -967,6 +1032,14 @@ if config.AUTH_TYPE == "GOOGLE": passive_deletes=True, ) + # User memories for personalized AI responses + memories = relationship( + "UserMemory", + back_populates="user", + order_by="UserMemory.updated_at.desc()", + cascade="all, delete-orphan", + ) + # Page usage tracking for ETL services pages_limit = Column( Integer, @@ -1010,6 +1083,14 @@ else: passive_deletes=True, ) + # User memories for personalized AI 
responses + memories = relationship( + "UserMemory", + back_populates="user", + order_by="UserMemory.updated_at.desc()", + cascade="all, delete-orphan", + ) + # Page usage tracking for ETL services pages_limit = Column( Integer, diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 8fddc55c4..4b8600fab 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -990,6 +990,7 @@ async def handle_new_chat( search_space_id=request.search_space_id, chat_id=request.chat_id, session=session, + user_id=str(user.id), # Pass user ID for memory tools llm_config_id=llm_config_id, attachments=request.attachments, mentioned_document_ids=request.mentioned_document_ids, diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 85a524108..7d2cf4172 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -149,6 +149,7 @@ async def stream_new_chat( search_space_id: int, chat_id: int, session: AsyncSession, + user_id: str | None = None, llm_config_id: int = -1, attachments: list[ChatAttachment] | None = None, mentioned_document_ids: list[int] | None = None, @@ -166,6 +167,7 @@ async def stream_new_chat( search_space_id: The search space ID chat_id: The chat ID (used as LangGraph thread_id for memory) session: The database session + user_id: The current user's UUID string (for memory tools) llm_config_id: The LLM configuration ID (default: -1 for first global config) messages: Optional chat history from frontend (list of ChatMessage) attachments: Optional attachments with extracted content @@ -243,6 +245,7 @@ async def stream_new_chat( db_session=session, connector_service=connector_service, checkpointer=checkpointer, + user_id=user_id, # Pass user ID for memory tools agent_config=agent_config, # Pass prompt configuration 
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured + ) diff --git a/surfsense_web/components/tool-ui/index.ts b/surfsense_web/components/tool-ui/index.ts index 68f790954..5b4ea0a34 100644 --- a/surfsense_web/components/tool-ui/index.ts +++ b/surfsense_web/components/tool-ui/index.ts @@ -77,4 +77,17 @@ export { ScrapeWebpageResultSchema, ScrapeWebpageToolUI, } from "./scrape-webpage"; +export { + type MemoryItem, + type RecallMemoryArgs, + RecallMemoryArgsSchema, + type RecallMemoryResult, + RecallMemoryResultSchema, + RecallMemoryToolUI, + type SaveMemoryArgs, + SaveMemoryArgsSchema, + type SaveMemoryResult, + SaveMemoryResultSchema, + SaveMemoryToolUI, +} from "./user-memory"; export { type WriteTodosData, WriteTodosSchema, WriteTodosToolUI } from "./write-todos"; diff --git a/surfsense_web/components/tool-ui/user-memory.tsx b/surfsense_web/components/tool-ui/user-memory.tsx new file mode 100644 index 000000000..6fd64d632 --- /dev/null +++ b/surfsense_web/components/tool-ui/user-memory.tsx @@ -0,0 +1,291 @@ +"use client"; + +import { makeAssistantToolUI } from "@assistant-ui/react"; +import { BrainIcon, CheckIcon, Loader2Icon, SearchIcon, XIcon } from "lucide-react"; +import { z } from "zod"; + +// ============================================================================ +// Zod Schemas for save_memory tool +// ============================================================================ + +const SaveMemoryArgsSchema = z.object({ + content: z.string(), + category: z.string().default("fact"), +}); + +const SaveMemoryResultSchema = z.object({ + status: z.enum(["saved", "error"]), + memory_id: z.number().nullish(), + memory_text: z.string().nullish(), + category: z.string().nullish(), + message: z.string().nullish(), + error: z.string().nullish(), +}); + +type SaveMemoryArgs = z.infer<typeof SaveMemoryArgsSchema>; +type SaveMemoryResult = z.infer<typeof SaveMemoryResultSchema>; + +// ============================================================================ +// Zod Schemas for recall_memory tool +// 
============================================================================ + +const RecallMemoryArgsSchema = z.object({ + query: z.string().nullish(), + category: z.string().nullish(), + top_k: z.number().default(5), +}); + +const MemoryItemSchema = z.object({ + id: z.number(), + memory_text: z.string(), + category: z.string(), + updated_at: z.string().nullish(), +}); + +const RecallMemoryResultSchema = z.object({ + status: z.enum(["success", "error"]), + count: z.number().nullish(), + memories: z.array(MemoryItemSchema).nullish(), + formatted_context: z.string().nullish(), + error: z.string().nullish(), +}); + +type RecallMemoryArgs = z.infer<typeof RecallMemoryArgsSchema>; +type RecallMemoryResult = z.infer<typeof RecallMemoryResultSchema>; +type MemoryItem = z.infer<typeof MemoryItemSchema>; + +// ============================================================================ +// Category badge colors +// ============================================================================ + +const categoryColors: Record<string, string> = { + preference: "bg-blue-500/10 text-blue-600 dark:text-blue-400", + fact: "bg-green-500/10 text-green-600 dark:text-green-400", + instruction: "bg-purple-500/10 text-purple-600 dark:text-purple-400", + context: "bg-orange-500/10 text-orange-600 dark:text-orange-400", +}; + +function CategoryBadge({ category }: { category: string }) { + const colorClass = categoryColors[category] || "bg-gray-500/10 text-gray-600 dark:text-gray-400"; + return ( + <span className={`rounded-full px-2 py-0.5 text-xs font-medium ${colorClass}`}> + {category} + </span> + ); +} + +// ============================================================================ +// Save Memory Tool UI +// ============================================================================ + +export const SaveMemoryToolUI = makeAssistantToolUI<SaveMemoryArgs, SaveMemoryResult>({ + toolName: "save_memory", + render: function SaveMemoryUI({ args, result, status }) { + const isRunning = status.type === "running" || status.type === "requires-action"; + const isComplete = status.type === "complete"; + const isError = result?.status === "error"; + + // Parse args safely + const parsedArgs = 
SaveMemoryArgsSchema.safeParse(args); + const content = parsedArgs.success ? parsedArgs.data.content : ""; + const category = parsedArgs.success ? parsedArgs.data.category : "fact"; + + // Loading state + if (isRunning) { + return ( +
+				<div className="my-2 flex items-center gap-2 rounded-lg border bg-muted/30 px-3 py-2">
+					<div className="flex size-6 items-center justify-center rounded-full bg-primary/10">
+						<Loader2Icon className="size-3.5 animate-spin text-primary" />
+					</div>
+					<span className="text-sm text-muted-foreground">Saving to memory...</span>
+				</div>
+ ); + } + + // Error state + if (isError) { + return ( +
+				<div className="my-2 flex items-start gap-2 rounded-lg border border-destructive/30 bg-destructive/5 px-3 py-2">
+					<div className="flex size-6 items-center justify-center rounded-full bg-destructive/10">
+						<XIcon className="size-3.5 text-destructive" />
+					</div>
+					<div className="flex flex-col gap-0.5">
+						<span className="text-sm font-medium text-destructive">Failed to save memory</span>
+						{result?.error && (
+							<span className="text-xs text-muted-foreground">{result.error}</span>
+						)}
+					</div>
+				</div>
+ ); + } + + // Success state + if (isComplete && result?.status === "saved") { + return ( +
+				<div className="my-2 flex items-start gap-2 rounded-lg border bg-muted/30 px-3 py-2">
+					<div className="flex size-6 items-center justify-center rounded-full bg-primary/10">
+						<BrainIcon className="size-3.5 text-primary" />
+					</div>
+					<div className="flex flex-col gap-1">
+						<div className="flex items-center gap-2">
+							<CheckIcon className="size-3.5 text-green-600" />
+							<span className="text-sm font-medium">Memory saved</span>
+							<CategoryBadge category={result.category ?? category} />
+						</div>
+						<span className="text-sm text-muted-foreground">{content}</span>
+					</div>
+				</div>
+ ); + } + + // Default/incomplete state - show what's being saved + if (content) { + return ( +
+				<div className="my-2 flex items-start gap-2 rounded-lg border bg-muted/30 px-3 py-2">
+					<div className="flex size-6 items-center justify-center rounded-full bg-primary/10">
+						<BrainIcon className="size-3.5 text-primary" />
+					</div>
+					<div className="flex flex-col gap-1">
+						<div className="flex items-center gap-2">
+							<span className="text-sm font-medium">Saving memory</span>
+							<CategoryBadge category={category} />
+						</div>
+						<span className="text-sm text-muted-foreground">{content}</span>
+					</div>
+				</div>
+ ); + } + + return null; + }, +}); + +// ============================================================================ +// Recall Memory Tool UI +// ============================================================================ + +export const RecallMemoryToolUI = makeAssistantToolUI<RecallMemoryArgs, RecallMemoryResult>({ + toolName: "recall_memory", + render: function RecallMemoryUI({ args, result, status }) { + const isRunning = status.type === "running" || status.type === "requires-action"; + const isComplete = status.type === "complete"; + const isError = result?.status === "error"; + + // Parse args safely + const parsedArgs = RecallMemoryArgsSchema.safeParse(args); + const query = parsedArgs.success ? parsedArgs.data.query : null; + + // Loading state + if (isRunning) { + return ( +
+				<div className="my-2 flex items-center gap-2 rounded-lg border bg-muted/30 px-3 py-2">
+					<div className="flex size-6 items-center justify-center rounded-full bg-primary/10">
+						<Loader2Icon className="size-3.5 animate-spin text-primary" />
+					</div>
+					<span className="text-sm text-muted-foreground">
+						{query ? `Searching memories for "${query}"...` : "Recalling memories..."}
+					</span>
+				</div>
+ ); + } + + // Error state + if (isError) { + return ( +
+				<div className="my-2 flex items-start gap-2 rounded-lg border border-destructive/30 bg-destructive/5 px-3 py-2">
+					<div className="flex size-6 items-center justify-center rounded-full bg-destructive/10">
+						<XIcon className="size-3.5 text-destructive" />
+					</div>
+					<div className="flex flex-col gap-0.5">
+						<span className="text-sm font-medium text-destructive">Failed to recall memories</span>
+						{result?.error && (
+							<span className="text-xs text-muted-foreground">{result.error}</span>
+						)}
+					</div>
+				</div>
+ ); + } + + // Success state with memories + if (isComplete && result?.status === "success") { + const memories = result.memories || []; + const count = result.count || 0; + + if (count === 0) { + return ( +
+					<div className="my-2 flex items-center gap-2 rounded-lg border bg-muted/30 px-3 py-2">
+						<BrainIcon className="size-3.5 text-muted-foreground" />
+						<span className="text-sm text-muted-foreground">No memories found</span>
+					</div>
+ ); + } + + return ( +
+				<div className="my-2 flex flex-col gap-2 rounded-lg border bg-muted/30 px-3 py-2">
+					<div className="flex items-center gap-2">
+						<BrainIcon className="size-3.5 text-primary" />
+						<span className="text-sm font-medium">
+							Recalled {count} {count === 1 ? "memory" : "memories"}
+						</span>
+					</div>
+					<div className="flex flex-col gap-1.5">
+						{memories.slice(0, 5).map((memory: MemoryItem) => (
+							<div key={memory.id} className="flex items-start gap-2">
+								<CategoryBadge category={memory.category} />
+								<span className="text-sm text-muted-foreground">{memory.memory_text}</span>
+							</div>
+						))}
+						{memories.length > 5 && (
+							<span className="text-xs text-muted-foreground">
+								...and {memories.length - 5} more
+							</span>
+						)}
+					</div>
+				</div>
+ ); + } + + // Default/incomplete state + if (query) { + return ( +
+				<div className="my-2 flex items-center gap-2 rounded-lg border bg-muted/30 px-3 py-2">
+					<SearchIcon className="size-3.5 text-muted-foreground" />
+					<span className="text-sm text-muted-foreground">
+						Searching memories for "{query}"
+					</span>
+				</div>
+ ); + } + + return null; + }, +}); + +// ============================================================================ +// Exports +// ============================================================================ + +export { + SaveMemoryArgsSchema, + SaveMemoryResultSchema, + RecallMemoryArgsSchema, + RecallMemoryResultSchema, + type SaveMemoryArgs, + type SaveMemoryResult, + type RecallMemoryArgs, + type RecallMemoryResult, + type MemoryItem, +};