Merge pull request #799 from CREDO23/sur-152-impr-split-private-and-shared-memory

[Feat] Split private vs shared chat memory and add team prompt/attribution
This commit is contained in:
Rohan Verma 2026-02-09 15:03:54 -08:00 committed by GitHub
commit 3f0c9c35f7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 664 additions and 86 deletions

View file

@ -0,0 +1,77 @@
"""Add shared_memories table (SUR-152)."""
from collections.abc import Sequence
from alembic import op
from app.config import config
revision: str = "96"
down_revision: str | None = "95"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
EMBEDDING_DIM = config.embedding_model_instance.dimension
def upgrade() -> None:
op.execute(
f"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'shared_memories'
) THEN
CREATE TABLE shared_memories (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
created_by_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
memory_text TEXT NOT NULL,
category memorycategory NOT NULL DEFAULT 'fact',
embedding vector({EMBEDDING_DIM})
);
END IF;
END$$;
"""
)
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_indexes
WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_search_space_id'
) THEN
CREATE INDEX ix_shared_memories_search_space_id ON shared_memories(search_space_id);
END IF;
IF NOT EXISTS (
SELECT 1 FROM pg_indexes
WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_updated_at'
) THEN
CREATE INDEX ix_shared_memories_updated_at ON shared_memories(updated_at);
END IF;
IF NOT EXISTS (
SELECT 1 FROM pg_indexes
WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_created_by_id'
) THEN
CREATE INDEX ix_shared_memories_created_by_id ON shared_memories(created_by_id);
END IF;
END$$;
"""
)
op.execute(
"""
CREATE INDEX IF NOT EXISTS shared_memories_vector_index
ON shared_memories USING hnsw (embedding public.vector_cosine_ops);
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS shared_memories_vector_index;")
op.execute("DROP INDEX IF EXISTS ix_shared_memories_created_by_id;")
op.execute("DROP INDEX IF EXISTS ix_shared_memories_updated_at;")
op.execute("DROP INDEX IF EXISTS ix_shared_memories_search_space_id;")
op.execute("DROP TABLE IF EXISTS shared_memories CASCADE;")

View file

@ -22,6 +22,7 @@ from app.agents.new_chat.system_prompt import (
build_surfsense_system_prompt, build_surfsense_system_prompt,
) )
from app.agents.new_chat.tools.registry import build_tools_async from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
# ============================================================================= # =============================================================================
@ -126,6 +127,7 @@ async def create_surfsense_deep_agent(
disabled_tools: list[str] | None = None, disabled_tools: list[str] | None = None,
additional_tools: Sequence[BaseTool] | None = None, additional_tools: Sequence[BaseTool] | None = None,
firecrawl_api_key: str | None = None, firecrawl_api_key: str | None = None,
thread_visibility: ChatVisibility | None = None,
): ):
""" """
Create a SurfSense deep agent with configurable tools and prompts. Create a SurfSense deep agent with configurable tools and prompts.
@ -228,14 +230,15 @@ async def create_surfsense_deep_agent(
logging.warning(f"Failed to discover available connectors/document types: {e}") logging.warning(f"Failed to discover available connectors/document types: {e}")
# Build dependencies dict for the tools registry # Build dependencies dict for the tools registry
visibility = thread_visibility or ChatVisibility.PRIVATE
dependencies = { dependencies = {
"search_space_id": search_space_id, "search_space_id": search_space_id,
"db_session": db_session, "db_session": db_session,
"connector_service": connector_service, "connector_service": connector_service,
"firecrawl_api_key": firecrawl_api_key, "firecrawl_api_key": firecrawl_api_key,
"user_id": user_id, # Required for memory tools "user_id": user_id,
"thread_id": thread_id, # For podcast tool "thread_id": thread_id,
# Dynamic connector/document type discovery for knowledge base tool "thread_visibility": visibility,
"available_connectors": available_connectors, "available_connectors": available_connectors,
"available_document_types": available_document_types, "available_document_types": available_document_types,
} }
@ -255,10 +258,12 @@ async def create_surfsense_deep_agent(
custom_system_instructions=agent_config.system_instructions, custom_system_instructions=agent_config.system_instructions,
use_default_system_instructions=agent_config.use_default_system_instructions, use_default_system_instructions=agent_config.use_default_system_instructions,
citations_enabled=agent_config.citations_enabled, citations_enabled=agent_config.citations_enabled,
thread_visibility=thread_visibility,
) )
else: else:
# Use default prompt (with citations enabled) system_prompt = build_surfsense_system_prompt(
system_prompt = build_surfsense_system_prompt() thread_visibility=thread_visibility,
)
# Create the deep agent with system prompt and checkpointer # Create the deep agent with system prompt and checkpointer
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent # Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent

View file

@ -12,6 +12,8 @@ The prompt is composed of three parts:
from datetime import UTC, datetime from datetime import UTC, datetime
from app.db import ChatVisibility
# Default system instructions - can be overridden via NewLLMConfig.system_instructions # Default system instructions - can be overridden via NewLLMConfig.system_instructions
SURFSENSE_SYSTEM_INSTRUCTIONS = """ SURFSENSE_SYSTEM_INSTRUCTIONS = """
<system_instruction> <system_instruction>
@ -22,7 +24,34 @@ Today's date (UTC): {resolved_today}
</system_instruction> </system_instruction>
""" """
SURFSENSE_TOOLS_INSTRUCTIONS = """ # Default system instructions for shared (team) threads: team context + message format for attribution
_SYSTEM_INSTRUCTIONS_SHARED = """
<system_instruction>
You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base.
In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers.
Today's date (UTC): {resolved_today}
</system_instruction>
"""
def _get_system_instructions(
thread_visibility: ChatVisibility | None = None, today: datetime | None = None
) -> str:
"""Build system instructions based on thread visibility (private vs shared)."""
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
visibility = thread_visibility or ChatVisibility.PRIVATE
if visibility == ChatVisibility.SEARCH_SPACE:
return _SYSTEM_INSTRUCTIONS_SHARED.format(resolved_today=resolved_today)
else:
return SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today)
# Tools 0-6 (common to both private and shared prompts)
_TOOLS_INSTRUCTIONS_COMMON = """
<tools> <tools>
You have access to the following tools: You have access to the following tools:
@ -138,7 +167,11 @@ You have access to the following tools:
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content. * Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
* Don't show every image - just the most relevant 1-3 images that enhance understanding. * Don't show every image - just the most relevant 1-3 images that enhance understanding.
7. save_memory: Save facts, preferences, or context about the user for personalized responses. """
# Private (user) memory: tools 7-8 + memory-specific examples
_TOOLS_INSTRUCTIONS_MEMORY_PRIVATE = """
7. save_memory: Save facts, preferences, or context for personalized responses.
- Use this when the user explicitly or implicitly shares information worth remembering. - Use this when the user explicitly or implicitly shares information worth remembering.
- Trigger scenarios: - Trigger scenarios:
* User says "remember this", "keep this in mind", "note that", or similar * User says "remember this", "keep this in mind", "note that", or similar
@ -178,6 +211,75 @@ You have access to the following tools:
stating "Based on your memory..." - integrate the context seamlessly. stating "Based on your memory..." - integrate the context seamlessly.
</tools> </tools>
<tool_call_examples> <tool_call_examples>
- User: "Remember that I prefer TypeScript over JavaScript"
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
- User: "I'm a data scientist working on ML pipelines"
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
- User: "Always give me code examples in Python"
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
- User: "What programming language should I use for this project?"
- First recall: `recall_memory(query="programming language preferences")`
- Then provide a personalized recommendation based on their preferences
- User: "What do you know about me?"
- Call: `recall_memory(top_k=10)`
- Then summarize the stored memories
"""
# Shared (team) memory: tools 7-8 + team memory examples
_TOOLS_INSTRUCTIONS_MEMORY_SHARED = """
7. save_memory: Save a fact, preference, or context to the team's shared memory for future reference.
- Use this when the user or a team member says "remember this", "keep this in mind", or similar in this shared chat.
- Use when the team agrees on something to remember (e.g., decisions, conventions).
- Someone shares a preference or fact that should be visible to the whole team.
- The saved information will be available in future shared conversations in this space.
- Args:
- content: The fact/preference/context to remember. Phrase it clearly, e.g. "API keys are stored in Vault", "The team prefers weekly demos on Fridays"
- category: Type of memory. One of:
* "preference": Team or workspace preferences
* "fact": Facts the team agreed on (e.g., processes, locations)
* "instruction": Standing instructions for the team
* "context": Current context (e.g., ongoing projects, goals)
- Returns: Confirmation of saved memory; returned context may include who added it (added_by).
- IMPORTANT: Only save information that would be genuinely useful for future team conversations in this space.
8. recall_memory: Recall relevant team memories for this space to provide contextual responses.
- Use when you need team context to answer (e.g., "where do we store X?", "what did we decide about Y?").
- Use when someone asks about something the team agreed to remember.
- Use when team preferences or conventions would improve the response.
- Args:
- query: Optional search query to find specific memories. If not provided, returns the most recent memories.
- category: Optional filter by category ("preference", "fact", "instruction", "context")
- top_k: Number of memories to retrieve (default: 5, max: 20)
- Returns: Relevant team memories and formatted context (may include added_by). Integrate naturally without saying "Based on team memory...".
</tools>
<tool_call_examples>
- User: "Remember that API keys are stored in Vault"
- Call: `save_memory(content="API keys are stored in Vault", category="fact")`
- User: "Let's remember that the team prefers weekly demos on Fridays"
- Call: `save_memory(content="The team prefers weekly demos on Fridays", category="preference")`
- User: "What did we decide about the release date?"
- First recall: `recall_memory(query="release date decision")`
- Then answer based on the team memories
- User: "Where do we document onboarding?"
- Call: `recall_memory(query="onboarding documentation")`
- Then answer using the recalled team context
- User: "What have we agreed to remember about our deployment process?"
- Call: `recall_memory(query="deployment process", top_k=10)`
- Then summarize the relevant team memories
"""
# Examples shared by both private and shared prompts (knowledge base, docs, podcast, links, images, etc.)
_TOOLS_INSTRUCTIONS_EXAMPLES_COMMON = """
- User: "What time is the team meeting today?" - User: "What time is the team meeting today?"
- Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.) - Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.)
- DO NOT limit to just calendar - the info might be in notes! - DO NOT limit to just calendar - the info might be in notes!
@ -209,23 +311,6 @@ You have access to the following tools:
- User: "What's in my Obsidian vault about project ideas?" - User: "What's in my Obsidian vault about project ideas?"
- Call: `search_knowledge_base(query="project ideas", connectors_to_search=["OBSIDIAN_CONNECTOR"])` - Call: `search_knowledge_base(query="project ideas", connectors_to_search=["OBSIDIAN_CONNECTOR"])`
- User: "Remember that I prefer TypeScript over JavaScript"
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
- User: "I'm a data scientist working on ML pipelines"
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
- User: "Always give me code examples in Python"
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
- User: "What programming language should I use for this project?"
- First recall: `recall_memory(query="programming language preferences")`
- Then provide a personalized recommendation based on their preferences
- User: "What do you know about me?"
- Call: `recall_memory(top_k=10)`
- Then summarize the stored memories
- User: "Give me a podcast about AI trends based on what we discussed" - User: "Give me a podcast about AI trends based on what we discussed"
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")` - First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
@ -315,6 +400,31 @@ You have access to the following tools:
</tool_call_examples> </tool_call_examples>
""" """
# Reassemble so existing callers see no change (same full prompt)
SURFSENSE_TOOLS_INSTRUCTIONS = (
_TOOLS_INSTRUCTIONS_COMMON
+ _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE
+ _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON
)
def _get_tools_instructions(thread_visibility: ChatVisibility | None = None) -> str:
"""Build tools instructions based on thread visibility (private vs shared).
For private chats: use user-focused memory wording and examples.
For shared chats: use team memory wording and examples.
"""
visibility = thread_visibility or ChatVisibility.PRIVATE
memory_block = (
_TOOLS_INSTRUCTIONS_MEMORY_SHARED
if visibility == ChatVisibility.SEARCH_SPACE
else _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE
)
return (
_TOOLS_INSTRUCTIONS_COMMON + memory_block + _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON
)
SURFSENSE_CITATION_INSTRUCTIONS = """ SURFSENSE_CITATION_INSTRUCTIONS = """
<citation_instructions> <citation_instructions>
CRITICAL CITATION REQUIREMENTS: CRITICAL CITATION REQUIREMENTS:
@ -413,6 +523,7 @@ Your goal is to provide helpful, informative answers in a clean, readable format
def build_surfsense_system_prompt( def build_surfsense_system_prompt(
today: datetime | None = None, today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
) -> str: ) -> str:
""" """
Build the SurfSense system prompt with default settings. Build the SurfSense system prompt with default settings.
@ -424,17 +535,17 @@ def build_surfsense_system_prompt(
Args: Args:
today: Optional datetime for today's date (defaults to current UTC date) today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
Returns: Returns:
Complete system prompt string Complete system prompt string
""" """
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
return ( visibility = thread_visibility or ChatVisibility.PRIVATE
SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today) system_instructions = _get_system_instructions(visibility, today)
+ SURFSENSE_TOOLS_INSTRUCTIONS tools_instructions = _get_tools_instructions(visibility)
+ SURFSENSE_CITATION_INSTRUCTIONS citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
) return system_instructions + tools_instructions + citation_instructions
def build_configurable_system_prompt( def build_configurable_system_prompt(
@ -442,6 +553,7 @@ def build_configurable_system_prompt(
use_default_system_instructions: bool = True, use_default_system_instructions: bool = True,
citations_enabled: bool = True, citations_enabled: bool = True,
today: datetime | None = None, today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
) -> str: ) -> str:
""" """
Build a configurable SurfSense system prompt based on NewLLMConfig settings. Build a configurable SurfSense system prompt based on NewLLMConfig settings.
@ -460,6 +572,7 @@ def build_configurable_system_prompt(
citations_enabled: Whether to include citation instructions (True) or citations_enabled: Whether to include citation instructions (True) or
anti-citation instructions (False). anti-citation instructions (False).
today: Optional datetime for today's date (defaults to current UTC date) today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
Returns: Returns:
Complete system prompt string Complete system prompt string
@ -473,16 +586,14 @@ def build_configurable_system_prompt(
resolved_today=resolved_today resolved_today=resolved_today
) )
elif use_default_system_instructions: elif use_default_system_instructions:
# Use default instructions visibility = thread_visibility or ChatVisibility.PRIVATE
system_instructions = SURFSENSE_SYSTEM_INSTRUCTIONS.format( system_instructions = _get_system_instructions(visibility, today)
resolved_today=resolved_today
)
else: else:
# No system instructions (edge case) # No system instructions (edge case)
system_instructions = "" system_instructions = ""
# Tools instructions are always included # Tools instructions: conditional on thread_visibility (private vs shared memory wording)
tools_instructions = SURFSENSE_TOOLS_INSTRUCTIONS tools_instructions = _get_tools_instructions(thread_visibility)
# Citation instructions based on toggle # Citation instructions based on toggle
citation_instructions = ( citation_instructions = (

View file

@ -51,8 +51,14 @@ from .mcp_tool import load_mcp_tools
from .podcast import create_generate_podcast_tool from .podcast import create_generate_podcast_tool
from .scrape_webpage import create_scrape_webpage_tool from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool from .search_surfsense_docs import create_search_surfsense_docs_tool
from .shared_memory import (
create_recall_shared_memory_tool,
create_save_shared_memory_tool,
)
from .user_memory import create_recall_memory_tool, create_save_memory_tool from .user_memory import create_recall_memory_tool, create_save_memory_tool
from app.db import ChatVisibility
# ============================================================================= # =============================================================================
# Tool Definition # Tool Definition
# ============================================================================= # =============================================================================
@ -156,29 +162,42 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
requires=["db_session"], requires=["db_session"],
), ),
# ========================================================================= # =========================================================================
# USER MEMORY TOOLS - Claude-like memory feature # USER MEMORY TOOLS - private or team store by thread_visibility
# ========================================================================= # =========================================================================
# Save memory tool - stores facts/preferences about the user
ToolDefinition( ToolDefinition(
name="save_memory", name="save_memory",
description="Save facts, preferences, or context about the user for personalized responses", description="Save facts, preferences, or context for personalized or team responses",
factory=lambda deps: create_save_memory_tool( factory=lambda deps: (
user_id=deps["user_id"], create_save_shared_memory_tool(
search_space_id=deps["search_space_id"], search_space_id=deps["search_space_id"],
db_session=deps["db_session"], created_by_id=deps["user_id"],
db_session=deps["db_session"],
)
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
else create_save_memory_tool(
user_id=deps["user_id"],
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
)
), ),
requires=["user_id", "search_space_id", "db_session"], requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
), ),
# Recall memory tool - retrieves relevant user memories
ToolDefinition( ToolDefinition(
name="recall_memory", name="recall_memory",
description="Recall user memories for personalized and contextual responses", description="Recall relevant memories (personal or team) for context",
factory=lambda deps: create_recall_memory_tool( factory=lambda deps: (
user_id=deps["user_id"], create_recall_shared_memory_tool(
search_space_id=deps["search_space_id"], search_space_id=deps["search_space_id"],
db_session=deps["db_session"], db_session=deps["db_session"],
)
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
else create_recall_memory_tool(
user_id=deps["user_id"],
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
)
), ),
requires=["user_id", "search_space_id", "db_session"], requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
), ),
# ========================================================================= # =========================================================================
# ADD YOUR CUSTOM TOOLS BELOW # ADD YOUR CUSTOM TOOLS BELOW

View file

@ -0,0 +1,278 @@
"""Shared (team) memory backend for search-space-scoped AI context."""
import logging
from typing import Any
from uuid import UUID
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import MemoryCategory, SharedMemory, User
logger = logging.getLogger(__name__)
DEFAULT_RECALL_TOP_K = 5
MAX_MEMORIES_PER_SEARCH_SPACE = 250
async def get_shared_memory_count(
db_session: AsyncSession,
search_space_id: int,
) -> int:
result = await db_session.execute(
select(SharedMemory).where(SharedMemory.search_space_id == search_space_id)
)
return len(result.scalars().all())
async def delete_oldest_shared_memory(
db_session: AsyncSession,
search_space_id: int,
) -> None:
result = await db_session.execute(
select(SharedMemory)
.where(SharedMemory.search_space_id == search_space_id)
.order_by(SharedMemory.updated_at.asc())
.limit(1)
)
oldest = result.scalars().first()
if oldest:
await db_session.delete(oldest)
await db_session.commit()
def _to_uuid(value: str | UUID) -> UUID:
if isinstance(value, UUID):
return value
return UUID(value)
async def save_shared_memory(
db_session: AsyncSession,
search_space_id: int,
created_by_id: str | UUID,
content: str,
category: str = "fact",
) -> dict[str, Any]:
category = category.lower() if category else "fact"
valid = ["preference", "fact", "instruction", "context"]
if category not in valid:
category = "fact"
try:
count = await get_shared_memory_count(db_session, search_space_id)
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
await delete_oldest_shared_memory(db_session, search_space_id)
embedding = config.embedding_model_instance.embed(content)
row = SharedMemory(
search_space_id=search_space_id,
created_by_id=_to_uuid(created_by_id),
memory_text=content,
category=MemoryCategory(category),
embedding=embedding,
)
db_session.add(row)
await db_session.commit()
await db_session.refresh(row)
return {
"status": "saved",
"memory_id": row.id,
"memory_text": content,
"category": category,
"message": f"I'll remember: {content}",
}
except Exception as e:
logger.exception("Failed to save shared memory: %s", e)
await db_session.rollback()
return {
"status": "error",
"error": str(e),
"message": "Failed to save memory. Please try again.",
}
async def recall_shared_memory(
db_session: AsyncSession,
search_space_id: int,
query: str | None = None,
category: str | None = None,
top_k: int = DEFAULT_RECALL_TOP_K,
) -> dict[str, Any]:
top_k = min(max(top_k, 1), 20)
try:
valid_categories = ["preference", "fact", "instruction", "context"]
stmt = select(SharedMemory).where(
SharedMemory.search_space_id == search_space_id
)
if category and category in valid_categories:
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
if query:
query_embedding = config.embedding_model_instance.embed(query)
stmt = stmt.order_by(
SharedMemory.embedding.op("<=>")(query_embedding)
).limit(top_k)
else:
stmt = stmt.order_by(SharedMemory.updated_at.desc()).limit(top_k)
result = await db_session.execute(stmt)
rows = result.scalars().all()
memory_list = [
{
"id": m.id,
"memory_text": m.memory_text,
"category": m.category.value if m.category else "unknown",
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
"created_by_id": str(m.created_by_id) if m.created_by_id else None,
}
for m in rows
]
created_by_ids = list({m["created_by_id"] for m in memory_list if m["created_by_id"]})
created_by_map: dict[str, str] = {}
if created_by_ids:
uuids = [UUID(uid) for uid in created_by_ids]
users_result = await db_session.execute(
select(User).where(User.id.in_(uuids))
)
for u in users_result.scalars().all():
created_by_map[str(u.id)] = u.display_name or "A team member"
formatted_context = format_shared_memories_for_context(
memory_list, created_by_map
)
return {
"status": "success",
"count": len(memory_list),
"memories": memory_list,
"formatted_context": formatted_context,
}
except Exception as e:
logger.exception("Failed to recall shared memory: %s", e)
await db_session.rollback()
return {
"status": "error",
"error": str(e),
"memories": [],
"formatted_context": "Failed to recall memories.",
}
def format_shared_memories_for_context(
memories: list[dict[str, Any]],
created_by_map: dict[str, str] | None = None,
) -> str:
if not memories:
return "No relevant team memories found."
created_by_map = created_by_map or {}
parts = ["<team_memories>"]
for memory in memories:
category = memory.get("category", "unknown")
text = memory.get("memory_text", "")
updated = memory.get("updated_at", "")
created_by_id = memory.get("created_by_id")
added_by = (
created_by_map.get(str(created_by_id), "A team member")
if created_by_id is not None
else "A team member"
)
parts.append(
f" <memory category='{category}' updated='{updated}' added_by='{added_by}'>{text}</memory>"
)
parts.append("</team_memories>")
return "\n".join(parts)
def create_save_shared_memory_tool(
search_space_id: int,
created_by_id: str | UUID,
db_session: AsyncSession,
):
"""
Factory function to create the save_memory tool for shared (team) chats.
Args:
search_space_id: The search space ID
created_by_id: The user ID of the person adding the memory
db_session: Database session for executing queries
Returns:
A configured tool function for saving team memories
"""
@tool
async def save_memory(
content: str,
category: str = "fact",
) -> dict[str, Any]:
"""
Save a fact, preference, or context to the team's shared memory for future reference.
Use this tool when:
- User or a team member says "remember this", "keep this in mind", or similar in this shared chat
- The team agrees on something to remember (e.g., decisions, conventions, where things live)
- Someone shares a preference or fact that should be visible to the whole team
The saved information will be available in future shared conversations in this space.
Args:
content: The fact/preference/context to remember.
Phrase it clearly, e.g., "API keys are stored in Vault",
"The team prefers weekly demos on Fridays"
category: Type of memory. One of:
- "preference": Team or workspace preferences
- "fact": Facts the team agreed on (e.g., processes, locations)
- "instruction": Standing instructions for the team
- "context": Current context (e.g., ongoing projects, goals)
Returns:
A dictionary with the save status and memory details
"""
return await save_shared_memory(
db_session, search_space_id, created_by_id, content, category
)
return save_memory
def create_recall_shared_memory_tool(
search_space_id: int,
db_session: AsyncSession,
):
"""
Factory function to create the recall_memory tool for shared (team) chats.
Args:
search_space_id: The search space ID
db_session: Database session for executing queries
Returns:
A configured tool function for recalling team memories
"""
@tool
async def recall_memory(
query: str | None = None,
category: str | None = None,
top_k: int = DEFAULT_RECALL_TOP_K,
) -> dict[str, Any]:
"""
Recall relevant team memories for this space to provide contextual responses.
Use this tool when:
- You need team context to answer (e.g., "where do we store X?", "what did we decide about Y?")
- Someone asks about something the team agreed to remember
- Team preferences or conventions would improve the response
Args:
query: Optional search query to find specific memories.
If not provided, returns the most recent memories.
category: Optional category filter. One of:
"preference", "fact", "instruction", "context"
top_k: Number of memories to retrieve (default: 5, max: 20)
Returns:
A dictionary containing relevant memories and formatted context
"""
return await recall_shared_memory(
db_session, search_space_id, query, category, top_k
)
return recall_memory

View file

@ -801,9 +801,8 @@ class MemoryCategory(str, Enum):
class UserMemory(BaseModel, TimestampMixin): class UserMemory(BaseModel, TimestampMixin):
""" """
Stores facts, preferences, and context about users for personalized AI responses. Private memory: facts, preferences, context per user per search space.
Similar to Claude's memory feature - enables the AI to remember user information Used only for private chats (not shared/team chats).
across conversations.
""" """
__tablename__ = "user_memories" __tablename__ = "user_memories"
@ -847,6 +846,40 @@ class UserMemory(BaseModel, TimestampMixin):
search_space = relationship("SearchSpace", back_populates="user_memories") search_space = relationship("SearchSpace", back_populates="user_memories")
class SharedMemory(BaseModel, TimestampMixin):
__tablename__ = "shared_memories"
search_space_id = Column(
Integer,
ForeignKey("searchspaces.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
created_by_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
memory_text = Column(Text, nullable=False)
category = Column(
SQLAlchemyEnum(MemoryCategory),
nullable=False,
default=MemoryCategory.fact,
)
embedding = Column(Vector(config.embedding_model_instance.dimension))
updated_at = Column(
TIMESTAMP(timezone=True),
nullable=False,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
index=True,
)
search_space = relationship("SearchSpace", back_populates="shared_memories")
created_by = relationship("User")
class Document(BaseModel, TimestampMixin): class Document(BaseModel, TimestampMixin):
__tablename__ = "documents" __tablename__ = "documents"
@ -1209,6 +1242,12 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="UserMemory.updated_at.desc()", order_by="UserMemory.updated_at.desc()",
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
shared_memories = relationship(
"SharedMemory",
back_populates="search_space",
order_by="SharedMemory.updated_at.desc()",
cascade="all, delete-orphan",
)
class SearchSourceConnector(BaseModel, TimestampMixin): class SearchSourceConnector(BaseModel, TimestampMixin):
@ -1258,7 +1297,7 @@ class NewLLMConfig(BaseModel, TimestampMixin):
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) - Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
- Citation toggle (enable/disable citation instructions) - Citation toggle (enable/disable citation instructions)
Note: SURFSENSE_TOOLS_INSTRUCTIONS is always used and not configurable. Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory).
""" """
__tablename__ = "new_llm_configs" __tablename__ = "new_llm_configs"

View file

@ -1045,12 +1045,14 @@ async def handle_new_chat(
search_space_id=request.search_space_id, search_space_id=request.search_space_id,
chat_id=request.chat_id, chat_id=request.chat_id,
session=session, session=session,
user_id=str(user.id), # Pass user ID for memory tools and session state user_id=str(user.id),
llm_config_id=llm_config_id, llm_config_id=llm_config_id,
attachments=request.attachments, attachments=request.attachments,
mentioned_document_ids=request.mentioned_document_ids, mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
needs_history_bootstrap=thread.needs_history_bootstrap, needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
), ),
media_type="text/event-stream", media_type="text/event-stream",
headers={ headers={
@ -1281,6 +1283,8 @@ async def regenerate_response(
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
checkpoint_id=target_checkpoint_id, checkpoint_id=target_checkpoint_id,
needs_history_bootstrap=thread.needs_history_bootstrap, needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
): ):
yield chunk yield chunk
# If we get here, streaming completed successfully # If we get here, streaming completed successfully

View file

@ -26,7 +26,7 @@ from app.agents.new_chat.llm_config import (
load_agent_config, load_agent_config,
load_llm_config_from_yaml, load_llm_config_from_yaml,
) )
from app.db import Document, SurfsenseDocsDocument from app.db import ChatVisibility, Document, SurfsenseDocsDocument
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.schemas.new_chat import ChatAttachment from app.schemas.new_chat import ChatAttachment
from app.services.chat_session_state_service import ( from app.services.chat_session_state_service import (
@ -208,6 +208,8 @@ async def stream_new_chat(
mentioned_surfsense_doc_ids: list[int] | None = None, mentioned_surfsense_doc_ids: list[int] | None = None,
checkpoint_id: str | None = None, checkpoint_id: str | None = None,
needs_history_bootstrap: bool = False, needs_history_bootstrap: bool = False,
thread_visibility: ChatVisibility | None = None,
current_user_display_name: str | None = None,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
""" """
Stream chat responses from the new SurfSense deep agent. Stream chat responses from the new SurfSense deep agent.
@ -295,17 +297,18 @@ async def stream_new_chat(
# Get the PostgreSQL checkpointer for persistent conversation memory # Get the PostgreSQL checkpointer for persistent conversation memory
checkpointer = await get_checkpointer() checkpointer = await get_checkpointer()
# Create the deep agent with checkpointer and configurable prompts visibility = thread_visibility or ChatVisibility.PRIVATE
agent = await create_surfsense_deep_agent( agent = await create_surfsense_deep_agent(
llm=llm, llm=llm,
search_space_id=search_space_id, search_space_id=search_space_id,
db_session=session, db_session=session,
connector_service=connector_service, connector_service=connector_service,
checkpointer=checkpointer, checkpointer=checkpointer,
user_id=user_id, # Pass user ID for memory tools user_id=user_id,
thread_id=chat_id, # Pass chat ID for podcast association thread_id=chat_id,
agent_config=agent_config, # Pass prompt configuration agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
) )
# Build input with message history # Build input with message history
@ -313,7 +316,9 @@ async def stream_new_chat(
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet) # Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
if needs_history_bootstrap: if needs_history_bootstrap:
langchain_messages = await bootstrap_history_from_db(session, chat_id) langchain_messages = await bootstrap_history_from_db(
session, chat_id, thread_visibility=visibility
)
# Clear the flag so we don't bootstrap again on next message # Clear the flag so we don't bootstrap again on next message
from app.db import NewChatThread from app.db import NewChatThread
@ -376,6 +381,9 @@ async def stream_new_chat(
context = "\n\n".join(context_parts) context = "\n\n".join(context_parts)
final_query = f"{context}\n\n<user_query>{user_query}</user_query>" final_query = f"{context}\n\n<user_query>{user_query}</user_query>"
if visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
final_query = f"**[{current_user_display_name}]:** {final_query}"
# if messages: # if messages:
# # Convert frontend messages to LangChain format # # Convert frontend messages to LangChain format
# for msg in messages: # for msg in messages:

View file

@ -12,6 +12,7 @@ These utilities help extract and transform content for different use cases.
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
def extract_text_content(content: str | dict | list) -> str: def extract_text_content(content: str | dict | list) -> str:
@ -38,6 +39,7 @@ def extract_text_content(content: str | dict | list) -> str:
async def bootstrap_history_from_db( async def bootstrap_history_from_db(
session: AsyncSession, session: AsyncSession,
thread_id: int, thread_id: int,
thread_visibility: "ChatVisibility | None" = None,
) -> list[HumanMessage | AIMessage]: ) -> list[HumanMessage | AIMessage]:
""" """
Load message history from database and convert to LangChain format. Load message history from database and convert to LangChain format.
@ -45,20 +47,28 @@ async def bootstrap_history_from_db(
Used for cloned chats where the LangGraph checkpointer has no state, Used for cloned chats where the LangGraph checkpointer has no state,
but we have messages in the database that should be used as context. but we have messages in the database that should be used as context.
When thread_visibility is SEARCH_SPACE, user messages are prefixed with
the author's display name so the LLM sees who said what.
Args: Args:
session: Database session session: Database session
thread_id: The chat thread ID thread_id: The chat thread ID
thread_visibility: When SEARCH_SPACE, user messages get author prefix
Returns: Returns:
List of LangChain messages (HumanMessage/AIMessage) List of LangChain messages (HumanMessage/AIMessage)
""" """
from app.db import NewChatMessage from app.db import ChatVisibility, NewChatMessage
result = await session.execute( is_shared = thread_visibility == ChatVisibility.SEARCH_SPACE
stmt = (
select(NewChatMessage) select(NewChatMessage)
.filter(NewChatMessage.thread_id == thread_id) .filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at) .order_by(NewChatMessage.created_at)
) )
if is_shared:
stmt = stmt.options(selectinload(NewChatMessage.author))
result = await session.execute(stmt)
db_messages = result.scalars().all() db_messages = result.scalars().all()
langchain_messages: list[HumanMessage | AIMessage] = [] langchain_messages: list[HumanMessage | AIMessage] = []
@ -68,6 +78,11 @@ async def bootstrap_history_from_db(
if not text_content: if not text_content:
continue continue
if msg.role == "user": if msg.role == "user":
if is_shared:
author_name = (
(msg.author.display_name if msg.author else None) or "A team member"
)
text_content = f"**[{author_name}]:** {text_content}"
langchain_messages.append(HumanMessage(content=text_content)) langchain_messages.append(HumanMessage(content=text_content))
elif msg.role == "assistant": elif msg.role == "assistant":
langchain_messages.append(AIMessage(content=text_content)) langchain_messages.append(AIMessage(content=text_content))

View file

@ -1,16 +1,14 @@
import { atomWithQuery } from "jotai-tanstack-query"; import { atomWithQuery } from "jotai-tanstack-query";
import { userApiService } from "@/lib/apis/user-api.service"; import { userApiService } from "@/lib/apis/user-api.service";
import { getBearerToken } from "@/lib/auth-utils"; import { getBearerToken, isPublicRoute } from "@/lib/auth-utils";
import { cacheKeys } from "@/lib/query-client/cache-keys"; import { cacheKeys } from "@/lib/query-client/cache-keys";
export const currentUserAtom = atomWithQuery(() => { export const currentUserAtom = atomWithQuery(() => {
const pathname = typeof window !== "undefined" ? window.location.pathname : null;
return { return {
queryKey: cacheKeys.user.current(), queryKey: cacheKeys.user.current(),
staleTime: 5 * 60 * 1000, // 5 minutes staleTime: 5 * 60 * 1000, // 5 minutes
// Only fetch user data when a bearer token is present enabled: !!getBearerToken() && pathname !== null && !isPublicRoute(pathname),
enabled: !!getBearerToken(), queryFn: async () => userApiService.getMe(),
queryFn: async () => {
return userApiService.getMe();
},
}; };
}); });

View file

@ -10,28 +10,53 @@ const REFRESH_TOKEN_KEY = "surfsense_refresh_token";
let isRefreshing = false; let isRefreshing = false;
let refreshPromise: Promise<string | null> | null = null; let refreshPromise: Promise<string | null> | null = null;
/** Path prefixes for routes that do not require auth (no current-user fetch, no redirect on 401) */
const PUBLIC_ROUTE_PREFIXES = [
"/login",
"/register",
"/auth",
"/docs",
"/public",
"/invite",
"/contact",
"/pricing",
"/privacy",
"/terms",
"/changelog",
];
/** /**
* Saves the current path and redirects to login page * Returns true if the pathname is a public route where we should not run auth checks
* Call this when a 401 response is received * or redirect to login on 401.
*/
export function isPublicRoute(pathname: string): boolean {
if (pathname === "/" || pathname === "") return true;
return PUBLIC_ROUTE_PREFIXES.some((prefix) => pathname.startsWith(prefix));
}
/**
* Clears tokens and optionally redirects to login.
* Call this when a 401 response is received.
* Only redirects when the current route is protected; on public routes we just clear tokens.
*/ */
export function handleUnauthorized(): void { export function handleUnauthorized(): void {
if (typeof window === "undefined") return; if (typeof window === "undefined") return;
// Save the current path (including search params and hash) for redirect after login const pathname = window.location.pathname;
const currentPath = window.location.pathname + window.location.search + window.location.hash;
// Don't save auth-related paths // Always clear tokens
const excludedPaths = ["/auth", "/auth/callback", "/"];
if (!excludedPaths.includes(window.location.pathname)) {
localStorage.setItem(REDIRECT_PATH_KEY, currentPath);
}
// Clear both tokens
localStorage.removeItem(BEARER_TOKEN_KEY); localStorage.removeItem(BEARER_TOKEN_KEY);
localStorage.removeItem(REFRESH_TOKEN_KEY); localStorage.removeItem(REFRESH_TOKEN_KEY);
// Redirect to home page (which has login options) // Only redirect on protected routes; stay on public pages (e.g. /docs)
window.location.href = "/login"; if (!isPublicRoute(pathname)) {
const currentPath = pathname + window.location.search + window.location.hash;
const excludedPaths = ["/auth", "/auth/callback", "/"];
if (!excludedPaths.includes(pathname)) {
localStorage.setItem(REDIRECT_PATH_KEY, currentPath);
}
window.location.href = "/login";
}
} }
/** /**
@ -179,7 +204,6 @@ export function getAuthHeaders(additionalHeaders?: Record<string, string>): Reco
/** /**
* Attempts to refresh the access token using the stored refresh token. * Attempts to refresh the access token using the stored refresh token.
* Returns the new access token if successful, null otherwise. * Returns the new access token if successful, null otherwise.
* Exported for use by API services.
*/ */
export async function refreshAccessToken(): Promise<string | null> { export async function refreshAccessToken(): Promise<string | null> {
// If already refreshing, wait for that request to complete // If already refreshing, wait for that request to complete