def upgrade() -> None:
    """Create the ``shared_memories`` table and its indexes if they are missing.

    Every statement is guarded by an existence check (``IF NOT EXISTS``), so
    running the migration against a database that already contains some of
    these objects does not raise.
    """
    # Table: the vector column width is EMBEDDING_DIM, read from the app config
    # at import time.
    # NOTE(review): the ``memorycategory`` enum type is referenced but not
    # created here - presumably an earlier migration defines it; confirm.
    op.execute(
        f"""
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'shared_memories'
            ) THEN
                CREATE TABLE shared_memories (
                    id SERIAL PRIMARY KEY,
                    created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                    updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                    search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
                    created_by_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
                    memory_text TEXT NOT NULL,
                    category memorycategory NOT NULL DEFAULT 'fact',
                    embedding vector({EMBEDDING_DIM})
                );
            END IF;
        END$$;
        """
    )
    # B-tree indexes on the three lookup columns; each one is created only when
    # pg_indexes has no entry with that name for this table.
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_search_space_id'
            ) THEN
                CREATE INDEX ix_shared_memories_search_space_id ON shared_memories(search_space_id);
            END IF;
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_updated_at'
            ) THEN
                CREATE INDEX ix_shared_memories_updated_at ON shared_memories(updated_at);
            END IF;
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_created_by_id'
            ) THEN
                CREATE INDEX ix_shared_memories_created_by_id ON shared_memories(created_by_id);
            END IF;
        END$$;
        """
    )
    # HNSW index with the cosine operator class on the embedding column.
    op.execute(
        """
        CREATE INDEX IF NOT EXISTS shared_memories_vector_index
        ON shared_memories USING hnsw (embedding public.vector_cosine_ops);
        """
    )


def downgrade() -> None:
    """Drop the ``shared_memories`` indexes and then the table itself.

    All statements use ``IF EXISTS`` so a partially applied upgrade can still
    be rolled back.
    """
    op.execute("DROP INDEX IF EXISTS shared_memories_vector_index;")
    op.execute("DROP INDEX IF EXISTS ix_shared_memories_created_by_id;")
    op.execute("DROP INDEX IF EXISTS ix_shared_memories_updated_at;")
    op.execute("DROP INDEX IF EXISTS ix_shared_memories_search_space_id;")
    op.execute("DROP TABLE IF EXISTS shared_memories CASCADE;")
@@ -226,16 +228,17 @@ async def create_surfsense_deep_agent( import logging logging.warning(f"Failed to discover available connectors/document types: {e}") - + # Build dependencies dict for the tools registry + visibility = thread_visibility or ChatVisibility.PRIVATE dependencies = { "search_space_id": search_space_id, "db_session": db_session, "connector_service": connector_service, "firecrawl_api_key": firecrawl_api_key, - "user_id": user_id, # Required for memory tools - "thread_id": thread_id, # For podcast tool - # Dynamic connector/document type discovery for knowledge base tool + "user_id": user_id, + "thread_id": thread_id, + "thread_visibility": visibility, "available_connectors": available_connectors, "available_document_types": available_document_types, } @@ -255,10 +258,12 @@ async def create_surfsense_deep_agent( custom_system_instructions=agent_config.system_instructions, use_default_system_instructions=agent_config.use_default_system_instructions, citations_enabled=agent_config.citations_enabled, + thread_visibility=thread_visibility, ) else: - # Use default prompt (with citations enabled) - system_prompt = build_surfsense_system_prompt() + system_prompt = build_surfsense_system_prompt( + thread_visibility=thread_visibility, + ) # Create the deep agent with system prompt and checkpointer # Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index 01c762197..a3520dad6 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -12,6 +12,8 @@ The prompt is composed of three parts: from datetime import UTC, datetime +from app.db import ChatVisibility + # Default system instructions - can be overridden via NewLLMConfig.system_instructions SURFSENSE_SYSTEM_INSTRUCTIONS = """ @@ -22,7 +24,34 @@ Today's date (UTC): {resolved_today} 
def _get_system_instructions(
    thread_visibility: ChatVisibility | None = None, today: datetime | None = None
) -> str:
    """Return the default system instructions for a private or a shared thread.

    Args:
        thread_visibility: Visibility of the chat thread. ``None`` is treated
            as ``ChatVisibility.PRIVATE``.
        today: Reference datetime for the date shown in the prompt; the
            current UTC time is used when omitted.

    Returns:
        The shared (team) template when the thread is visible to the whole
        search space, otherwise the private template, with the UTC date
        filled in as an ISO string.
    """
    reference_time = today or datetime.now(UTC)
    date_stamp = reference_time.astimezone(UTC).date().isoformat()

    effective_visibility = thread_visibility or ChatVisibility.PRIVATE
    template = (
        _SYSTEM_INSTRUCTIONS_SHARED
        if effective_visibility == ChatVisibility.SEARCH_SPACE
        else SURFSENSE_SYSTEM_INSTRUCTIONS
    )
    return template.format(resolved_today=date_stamp)
- Trigger scenarios: * User says "remember this", "keep this in mind", "note that", or similar @@ -178,6 +211,75 @@ You have access to the following tools: stating "Based on your memory..." - integrate the context seamlessly. +- User: "Remember that I prefer TypeScript over JavaScript" + - Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")` + +- User: "I'm a data scientist working on ML pipelines" + - Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")` + +- User: "Always give me code examples in Python" + - Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")` + +- User: "What programming language should I use for this project?" + - First recall: `recall_memory(query="programming language preferences")` + - Then provide a personalized recommendation based on their preferences + +- User: "What do you know about me?" + - Call: `recall_memory(top_k=10)` + - Then summarize the stored memories + +""" + +# Shared (team) memory: tools 7-8 + team memory examples +_TOOLS_INSTRUCTIONS_MEMORY_SHARED = """ +7. save_memory: Save a fact, preference, or context to the team's shared memory for future reference. + - Use this when the user or a team member says "remember this", "keep this in mind", or similar in this shared chat. + - Use when the team agrees on something to remember (e.g., decisions, conventions). + - Someone shares a preference or fact that should be visible to the whole team. + - The saved information will be available in future shared conversations in this space. + - Args: + - content: The fact/preference/context to remember. Phrase it clearly, e.g. "API keys are stored in Vault", "The team prefers weekly demos on Fridays" + - category: Type of memory. 
One of: + * "preference": Team or workspace preferences + * "fact": Facts the team agreed on (e.g., processes, locations) + * "instruction": Standing instructions for the team + * "context": Current context (e.g., ongoing projects, goals) + - Returns: Confirmation of saved memory; returned context may include who added it (added_by). + - IMPORTANT: Only save information that would be genuinely useful for future team conversations in this space. + +8. recall_memory: Recall relevant team memories for this space to provide contextual responses. + - Use when you need team context to answer (e.g., "where do we store X?", "what did we decide about Y?"). + - Use when someone asks about something the team agreed to remember. + - Use when team preferences or conventions would improve the response. + - Args: + - query: Optional search query to find specific memories. If not provided, returns the most recent memories. + - category: Optional filter by category ("preference", "fact", "instruction", "context") + - top_k: Number of memories to retrieve (default: 5, max: 20) + - Returns: Relevant team memories and formatted context (may include added_by). Integrate naturally without saying "Based on team memory...". + + +- User: "Remember that API keys are stored in Vault" + - Call: `save_memory(content="API keys are stored in Vault", category="fact")` + +- User: "Let's remember that the team prefers weekly demos on Fridays" + - Call: `save_memory(content="The team prefers weekly demos on Fridays", category="preference")` + +- User: "What did we decide about the release date?" + - First recall: `recall_memory(query="release date decision")` + - Then answer based on the team memories + +- User: "Where do we document onboarding?" + - Call: `recall_memory(query="onboarding documentation")` + - Then answer using the recalled team context + +- User: "What have we agreed to remember about our deployment process?" 
+ - Call: `recall_memory(query="deployment process", top_k=10)` + - Then summarize the relevant team memories + +""" + +# Examples shared by both private and shared prompts (knowledge base, docs, podcast, links, images, etc.) +_TOOLS_INSTRUCTIONS_EXAMPLES_COMMON = """ - User: "What time is the team meeting today?" - Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.) - DO NOT limit to just calendar - the info might be in notes! @@ -209,23 +311,6 @@ You have access to the following tools: - User: "What's in my Obsidian vault about project ideas?" - Call: `search_knowledge_base(query="project ideas", connectors_to_search=["OBSIDIAN_CONNECTOR"])` -- User: "Remember that I prefer TypeScript over JavaScript" - - Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")` - -- User: "I'm a data scientist working on ML pipelines" - - Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")` - -- User: "Always give me code examples in Python" - - Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")` - -- User: "What programming language should I use for this project?" - - First recall: `recall_memory(query="programming language preferences")` - - Then provide a personalized recommendation based on their preferences - -- User: "What do you know about me?" 
def _get_tools_instructions(thread_visibility: ChatVisibility | None = None) -> str:
    """Assemble the tools section of the prompt for the given thread visibility.

    The common tool descriptions and the common examples are always included;
    only the memory section in between differs: team wording for threads
    shared with the search space, user wording otherwise (including when
    ``thread_visibility`` is ``None``).
    """
    effective_visibility = thread_visibility or ChatVisibility.PRIVATE
    if effective_visibility == ChatVisibility.SEARCH_SPACE:
        memory_section = _TOOLS_INSTRUCTIONS_MEMORY_SHARED
    else:
        memory_section = _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE

    sections = (
        _TOOLS_INSTRUCTIONS_COMMON,
        memory_section,
        _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON,
    )
    return "".join(sections)
Returns: Complete system prompt string """ - resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat() - return ( - SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today) - + SURFSENSE_TOOLS_INSTRUCTIONS - + SURFSENSE_CITATION_INSTRUCTIONS - ) + visibility = thread_visibility or ChatVisibility.PRIVATE + system_instructions = _get_system_instructions(visibility, today) + tools_instructions = _get_tools_instructions(visibility) + citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS + return system_instructions + tools_instructions + citation_instructions def build_configurable_system_prompt( @@ -442,6 +553,7 @@ def build_configurable_system_prompt( use_default_system_instructions: bool = True, citations_enabled: bool = True, today: datetime | None = None, + thread_visibility: ChatVisibility | None = None, ) -> str: """ Build a configurable SurfSense system prompt based on NewLLMConfig settings. @@ -460,6 +572,7 @@ def build_configurable_system_prompt( citations_enabled: Whether to include citation instructions (True) or anti-citation instructions (False). today: Optional datetime for today's date (defaults to current UTC date) + thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. 
Returns: Complete system prompt string @@ -473,16 +586,14 @@ def build_configurable_system_prompt( resolved_today=resolved_today ) elif use_default_system_instructions: - # Use default instructions - system_instructions = SURFSENSE_SYSTEM_INSTRUCTIONS.format( - resolved_today=resolved_today - ) + visibility = thread_visibility or ChatVisibility.PRIVATE + system_instructions = _get_system_instructions(visibility, today) else: # No system instructions (edge case) system_instructions = "" - # Tools instructions are always included - tools_instructions = SURFSENSE_TOOLS_INSTRUCTIONS + # Tools instructions: conditional on thread_visibility (private vs shared memory wording) + tools_instructions = _get_tools_instructions(thread_visibility) # Citation instructions based on toggle citation_instructions = ( diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index 2cf43c973..30201e8df 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -51,8 +51,14 @@ from .mcp_tool import load_mcp_tools from .podcast import create_generate_podcast_tool from .scrape_webpage import create_scrape_webpage_tool from .search_surfsense_docs import create_search_surfsense_docs_tool +from .shared_memory import ( + create_recall_shared_memory_tool, + create_save_shared_memory_tool, +) from .user_memory import create_recall_memory_tool, create_save_memory_tool +from app.db import ChatVisibility + # ============================================================================= # Tool Definition # ============================================================================= @@ -156,29 +162,42 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ requires=["db_session"], ), # ========================================================================= - # USER MEMORY TOOLS - Claude-like memory feature + # USER MEMORY TOOLS - private or team store by 
thread_visibility # ========================================================================= - # Save memory tool - stores facts/preferences about the user ToolDefinition( name="save_memory", - description="Save facts, preferences, or context about the user for personalized responses", - factory=lambda deps: create_save_memory_tool( - user_id=deps["user_id"], - search_space_id=deps["search_space_id"], - db_session=deps["db_session"], + description="Save facts, preferences, or context for personalized or team responses", + factory=lambda deps: ( + create_save_shared_memory_tool( + search_space_id=deps["search_space_id"], + created_by_id=deps["user_id"], + db_session=deps["db_session"], + ) + if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE + else create_save_memory_tool( + user_id=deps["user_id"], + search_space_id=deps["search_space_id"], + db_session=deps["db_session"], + ) ), - requires=["user_id", "search_space_id", "db_session"], + requires=["user_id", "search_space_id", "db_session", "thread_visibility"], ), - # Recall memory tool - retrieves relevant user memories ToolDefinition( name="recall_memory", - description="Recall user memories for personalized and contextual responses", - factory=lambda deps: create_recall_memory_tool( - user_id=deps["user_id"], - search_space_id=deps["search_space_id"], - db_session=deps["db_session"], + description="Recall relevant memories (personal or team) for context", + factory=lambda deps: ( + create_recall_shared_memory_tool( + search_space_id=deps["search_space_id"], + db_session=deps["db_session"], + ) + if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE + else create_recall_memory_tool( + user_id=deps["user_id"], + search_space_id=deps["search_space_id"], + db_session=deps["db_session"], + ) ), - requires=["user_id", "search_space_id", "db_session"], + requires=["user_id", "search_space_id", "db_session", "thread_visibility"], ), # 
async def get_shared_memory_count(
    db_session: AsyncSession,
    search_space_id: int,
) -> int:
    """Return the number of shared memories stored for a search space.

    The count is computed by the database (``SELECT COUNT(*)``) instead of
    loading every row - including its embedding vector - into Python just to
    take ``len()`` of the result.
    """
    # Local import: this module's top-level sqlalchemy import only brings in
    # ``select``.
    from sqlalchemy import func

    result = await db_session.execute(
        select(func.count())
        .select_from(SharedMemory)
        .where(SharedMemory.search_space_id == search_space_id)
    )
    return result.scalar_one()


async def delete_oldest_shared_memory(
    db_session: AsyncSession,
    search_space_id: int,
) -> None:
    """Delete the least recently updated shared memory of a search space.

    Does nothing when the search space has no shared memories. The deletion
    is committed immediately.
    """
    result = await db_session.execute(
        select(SharedMemory)
        .where(SharedMemory.search_space_id == search_space_id)
        .order_by(SharedMemory.updated_at.asc())
        .limit(1)
    )
    oldest = result.scalars().first()
    if oldest:
        await db_session.delete(oldest)
        await db_session.commit()


def _to_uuid(value: str | UUID) -> UUID:
    """Return ``value`` as a ``UUID``, parsing it when given as a string.

    Raises:
        ValueError: If ``value`` is a string that is not a valid UUID.
    """
    if isinstance(value, UUID):
        return value
    return UUID(value)


async def save_shared_memory(
    db_session: AsyncSession,
    search_space_id: int,
    created_by_id: str | UUID,
    content: str,
    category: str = "fact",
) -> dict[str, Any]:
    """Store a team memory for a search space.

    When the search space already holds ``MAX_MEMORIES_PER_SEARCH_SPACE``
    memories, the least recently updated one is evicted first.

    Args:
        db_session: Async session used for all queries and the commit.
        search_space_id: Search space the memory belongs to.
        created_by_id: ID of the user adding the memory (string or ``UUID``).
        content: Text of the memory; it is also what gets embedded.
        category: One of "preference", "fact", "instruction", "context"
            (case-insensitive). Anything else falls back to "fact".

    Returns:
        On success a dict with ``status="saved"``, ``memory_id``,
        ``memory_text``, ``category`` and ``message``. On any failure the
        session is rolled back and a dict with ``status="error"``, ``error``
        and ``message`` is returned; exceptions are not propagated.
    """
    category = category.lower() if category else "fact"
    if category not in ("preference", "fact", "instruction", "context"):
        category = "fact"
    try:
        # Run the steps that can fail on bad input (UUID parsing, embedding)
        # BEFORE evicting anything: the eviction commits on its own, so a
        # failure after it would lose an old memory without storing a new one.
        author_id = _to_uuid(created_by_id)
        embedding = config.embedding_model_instance.embed(content)

        count = await get_shared_memory_count(db_session, search_space_id)
        if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
            await delete_oldest_shared_memory(db_session, search_space_id)

        row = SharedMemory(
            search_space_id=search_space_id,
            created_by_id=author_id,
            memory_text=content,
            category=MemoryCategory(category),
            embedding=embedding,
        )
        db_session.add(row)
        await db_session.commit()
        # Refresh so the database-generated primary key is available below.
        await db_session.refresh(row)
        return {
            "status": "saved",
            "memory_id": row.id,
            "memory_text": content,
            "category": category,
            "message": f"I'll remember: {content}",
        }
    except Exception as e:
        # Deliberate best-effort boundary: the tool reports the failure to the
        # agent as data instead of raising.
        logger.exception("Failed to save shared memory: %s", e)
        await db_session.rollback()
        return {
            "status": "error",
            "error": str(e),
            "message": "Failed to save memory. Please try again.",
        }
rows + ] + created_by_ids = list({m["created_by_id"] for m in memory_list if m["created_by_id"]}) + created_by_map: dict[str, str] = {} + if created_by_ids: + uuids = [UUID(uid) for uid in created_by_ids] + users_result = await db_session.execute( + select(User).where(User.id.in_(uuids)) + ) + for u in users_result.scalars().all(): + created_by_map[str(u.id)] = u.display_name or "A team member" + formatted_context = format_shared_memories_for_context( + memory_list, created_by_map + ) + return { + "status": "success", + "count": len(memory_list), + "memories": memory_list, + "formatted_context": formatted_context, + } + except Exception as e: + logger.exception("Failed to recall shared memory: %s", e) + await db_session.rollback() + return { + "status": "error", + "error": str(e), + "memories": [], + "formatted_context": "Failed to recall memories.", + } + + +def format_shared_memories_for_context( + memories: list[dict[str, Any]], + created_by_map: dict[str, str] | None = None, +) -> str: + if not memories: + return "No relevant team memories found." + created_by_map = created_by_map or {} + parts = [""] + for memory in memories: + category = memory.get("category", "unknown") + text = memory.get("memory_text", "") + updated = memory.get("updated_at", "") + created_by_id = memory.get("created_by_id") + added_by = ( + created_by_map.get(str(created_by_id), "A team member") + if created_by_id is not None + else "A team member" + ) + parts.append( + f" {text}" + ) + parts.append("") + return "\n".join(parts) + + +def create_save_shared_memory_tool( + search_space_id: int, + created_by_id: str | UUID, + db_session: AsyncSession, +): + """ + Factory function to create the save_memory tool for shared (team) chats. 
def create_recall_shared_memory_tool(
    search_space_id: int,
    db_session: AsyncSession,
):
    """
    Factory function to create the recall_memory tool for shared (team) chats.

    The returned tool is a closure over ``search_space_id`` and
    ``db_session``: the agent only supplies ``query``, ``category`` and
    ``top_k``, and every call is forwarded to ``recall_shared_memory`` for
    this one search space.

    Args:
        search_space_id: The search space ID
        db_session: Database session for executing queries

    Returns:
        A configured tool function for recalling team memories
    """

    # NOTE(review): the docstring of the inner function is presumably what
    # langchain's @tool exposes to the model as the tool description - treat
    # it as runtime text rather than as ordinary documentation; confirm
    # against the langchain_core.tools documentation.
    @tool
    async def recall_memory(
        query: str | None = None,
        category: str | None = None,
        top_k: int = DEFAULT_RECALL_TOP_K,
    ) -> dict[str, Any]:
        """
        Recall relevant team memories for this space to provide contextual responses.

        Use this tool when:
        - You need team context to answer (e.g., "where do we store X?", "what did we decide about Y?")
        - Someone asks about something the team agreed to remember
        - Team preferences or conventions would improve the response

        Args:
            query: Optional search query to find specific memories.
                   If not provided, returns the most recent memories.
            category: Optional category filter. One of:
                     "preference", "fact", "instruction", "context"
            top_k: Number of memories to retrieve (default: 5, max: 20)

        Returns:
            A dictionary containing relevant memories and formatted context
        """
        return await recall_shared_memory(
            db_session, search_space_id, query, category, top_k
        )

    return recall_memory
class SharedMemory(BaseModel, TimestampMixin):
    """
    Team memory: facts, preferences, instructions and context stored per
    search space, together with the user who added each entry.

    Rows are removed when either the owning search space or the creating
    user is deleted (both foreign keys use ON DELETE CASCADE).
    """

    __tablename__ = "shared_memories"

    # Search space the memory belongs to.
    search_space_id = Column(
        Integer,
        ForeignKey("searchspaces.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # User who added the memory.
    created_by_id = Column(
        UUID(as_uuid=True),
        ForeignKey("user.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # The remembered text itself.
    memory_text = Column(Text, nullable=False)
    # Kind of memory; defaults to MemoryCategory.fact when not given.
    category = Column(
        SQLAlchemyEnum(MemoryCategory),
        nullable=False,
        default=MemoryCategory.fact,
    )
    # Embedding vector sized to the configured embedding model's dimension.
    embedding = Column(Vector(config.embedding_model_instance.dimension))
    # Set to the current UTC time on insert and again on every update.
    # NOTE(review): presumably overrides a column of the same name coming
    # from TimestampMixin - confirm.
    updated_at = Column(
        TIMESTAMP(timezone=True),
        nullable=False,
        default=lambda: datetime.now(UTC),
        onupdate=lambda: datetime.now(UTC),
        index=True,
    )

    search_space = relationship("SearchSpace", back_populates="shared_memories")
    # One-directional: no back_populates is declared for the User side here.
    created_by = relationship("User")
""" __tablename__ = "new_llm_configs" diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 06e929997..6d5268a8d 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1045,12 +1045,14 @@ async def handle_new_chat( search_space_id=request.search_space_id, chat_id=request.chat_id, session=session, - user_id=str(user.id), # Pass user ID for memory tools and session state + user_id=str(user.id), llm_config_id=llm_config_id, attachments=request.attachments, mentioned_document_ids=request.mentioned_document_ids, mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, needs_history_bootstrap=thread.needs_history_bootstrap, + thread_visibility=thread.visibility, + current_user_display_name=user.display_name or "A team member", ), media_type="text/event-stream", headers={ @@ -1281,6 +1283,8 @@ async def regenerate_response( mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, checkpoint_id=target_checkpoint_id, needs_history_bootstrap=thread.needs_history_bootstrap, + thread_visibility=thread.visibility, + current_user_display_name=user.display_name or "A team member", ): yield chunk # If we get here, streaming completed successfully diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 685f77e39..af5a2b0df 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -26,7 +26,7 @@ from app.agents.new_chat.llm_config import ( load_agent_config, load_llm_config_from_yaml, ) -from app.db import Document, SurfsenseDocsDocument +from app.db import ChatVisibility, Document, SurfsenseDocsDocument from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE from app.schemas.new_chat import ChatAttachment from app.services.chat_session_state_service import ( @@ -208,6 +208,8 @@ async def stream_new_chat( 
mentioned_surfsense_doc_ids: list[int] | None = None, checkpoint_id: str | None = None, needs_history_bootstrap: bool = False, + thread_visibility: ChatVisibility | None = None, + current_user_display_name: str | None = None, ) -> AsyncGenerator[str, None]: """ Stream chat responses from the new SurfSense deep agent. @@ -295,17 +297,18 @@ async def stream_new_chat( # Get the PostgreSQL checkpointer for persistent conversation memory checkpointer = await get_checkpointer() - # Create the deep agent with checkpointer and configurable prompts + visibility = thread_visibility or ChatVisibility.PRIVATE agent = await create_surfsense_deep_agent( llm=llm, search_space_id=search_space_id, db_session=session, connector_service=connector_service, checkpointer=checkpointer, - user_id=user_id, # Pass user ID for memory tools - thread_id=chat_id, # Pass chat ID for podcast association - agent_config=agent_config, # Pass prompt configuration - firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, ) # Build input with message history @@ -313,7 +316,9 @@ async def stream_new_chat( # Bootstrap history for cloned chats (no LangGraph checkpoint exists yet) if needs_history_bootstrap: - langchain_messages = await bootstrap_history_from_db(session, chat_id) + langchain_messages = await bootstrap_history_from_db( + session, chat_id, thread_visibility=visibility + ) # Clear the flag so we don't bootstrap again on next message from app.db import NewChatThread @@ -376,6 +381,9 @@ async def stream_new_chat( context = "\n\n".join(context_parts) final_query = f"{context}\n\n{user_query}" + if visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name: + final_query = f"**[{current_user_display_name}]:** {final_query}" + # if messages: # # Convert frontend messages to LangChain format # for msg in messages: diff --git 
a/surfsense_backend/app/utils/content_utils.py b/surfsense_backend/app/utils/content_utils.py index d2342b79e..9a417075d 100644 --- a/surfsense_backend/app/utils/content_utils.py +++ b/surfsense_backend/app/utils/content_utils.py @@ -12,6 +12,7 @@ These utilities help extract and transform content for different use cases. from langchain_core.messages import AIMessage, HumanMessage from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload def extract_text_content(content: str | dict | list) -> str: @@ -38,6 +39,7 @@ def extract_text_content(content: str | dict | list) -> str: async def bootstrap_history_from_db( session: AsyncSession, thread_id: int, + thread_visibility: "ChatVisibility | None" = None, ) -> list[HumanMessage | AIMessage]: """ Load message history from database and convert to LangChain format. @@ -45,20 +47,28 @@ async def bootstrap_history_from_db( Used for cloned chats where the LangGraph checkpointer has no state, but we have messages in the database that should be used as context. + When thread_visibility is SEARCH_SPACE, user messages are prefixed with + the author's display name so the LLM sees who said what. 
+ Args: session: Database session thread_id: The chat thread ID + thread_visibility: When SEARCH_SPACE, user messages get author prefix Returns: List of LangChain messages (HumanMessage/AIMessage) """ - from app.db import NewChatMessage + from app.db import ChatVisibility, NewChatMessage - result = await session.execute( + is_shared = thread_visibility == ChatVisibility.SEARCH_SPACE + stmt = ( select(NewChatMessage) .filter(NewChatMessage.thread_id == thread_id) .order_by(NewChatMessage.created_at) ) + if is_shared: + stmt = stmt.options(selectinload(NewChatMessage.author)) + result = await session.execute(stmt) db_messages = result.scalars().all() langchain_messages: list[HumanMessage | AIMessage] = [] @@ -68,6 +78,11 @@ async def bootstrap_history_from_db( if not text_content: continue if msg.role == "user": + if is_shared: + author_name = ( + (msg.author.display_name if msg.author else None) or "A team member" + ) + text_content = f"**[{author_name}]:** {text_content}" langchain_messages.append(HumanMessage(content=text_content)) elif msg.role == "assistant": langchain_messages.append(AIMessage(content=text_content)) diff --git a/surfsense_web/atoms/user/user-query.atoms.ts b/surfsense_web/atoms/user/user-query.atoms.ts index cd9ec6c87..6b436d7a0 100644 --- a/surfsense_web/atoms/user/user-query.atoms.ts +++ b/surfsense_web/atoms/user/user-query.atoms.ts @@ -1,16 +1,14 @@ import { atomWithQuery } from "jotai-tanstack-query"; import { userApiService } from "@/lib/apis/user-api.service"; -import { getBearerToken } from "@/lib/auth-utils"; +import { getBearerToken, isPublicRoute } from "@/lib/auth-utils"; import { cacheKeys } from "@/lib/query-client/cache-keys"; export const currentUserAtom = atomWithQuery(() => { + const pathname = typeof window !== "undefined" ? 
/** Path prefixes for routes that do not require auth (no current-user fetch, no redirect on 401) */
const PUBLIC_ROUTE_PREFIXES = [
	"/login",
	"/register",
	"/auth",
	"/docs",
	"/public",
	"/invite",
	"/contact",
	"/pricing",
	"/privacy",
	"/terms",
	"/changelog",
];

/**
 * Returns true if the pathname is a public route where we should not run auth checks
 * or redirect to login on 401.
 *
 * The root path ("/" or "") is always public. Any other path is public only when it
 * equals one of PUBLIC_ROUTE_PREFIXES or continues it with a "/" — i.e. the prefix
 * must end on a path-segment boundary. "/docs" and "/docs/intro" are public, while
 * "/docs-internal" or "/authors" are not, so a protected page whose name merely
 * starts with the same letters is never mistaken for a public one.
 *
 * @param pathname - URL pathname without query string or hash
 * @returns whether the pathname belongs to a public route
 */
export function isPublicRoute(pathname: string): boolean {
	if (pathname === "/" || pathname === "") return true;
	return PUBLIC_ROUTE_PREFIXES.some(
		(prefix) => pathname === prefix || pathname.startsWith(`${prefix}/`)
	);
}

/**
 * Clears tokens and optionally redirects to login.
 * Call this when a 401 response is received.
 * Only redirects when the current route is protected; on public routes we just clear tokens.
 *
 * Does nothing when `window` is undefined (e.g. during server-side rendering).
 */
export function handleUnauthorized(): void {
	if (typeof window === "undefined") return;

	const pathname = window.location.pathname;

	// Always clear tokens
	localStorage.removeItem(BEARER_TOKEN_KEY);
	localStorage.removeItem(REFRESH_TOKEN_KEY);

	// Stay on public pages (e.g. /docs); only protected routes are redirected.
	if (isPublicRoute(pathname)) return;

	// Remember where the user was (including search params and hash) so they can be
	// sent back after login. No separate exclusion list is needed here: "/" and
	// everything under "/auth" are public routes and have already returned above.
	const currentPath = pathname + window.location.search + window.location.hash;
	localStorage.setItem(REDIRECT_PATH_KEY, currentPath);

	window.location.href = "/login";
}