mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
Merge pull request #799 from CREDO23/sur-152-impr-split-private-and-shared-memory
[Feat] Split private vs shared chat memory and add team prompt/attribution
This commit is contained in:
commit
3f0c9c35f7
11 changed files with 664 additions and 86 deletions
|
|
@ -0,0 +1,77 @@
|
||||||
|
"""Add shared_memories table (SUR-152)."""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
revision: str = "96"
|
||||||
|
down_revision: str | None = "95"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
EMBEDDING_DIM = config.embedding_model_instance.dimension
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
f"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT FROM information_schema.tables
|
||||||
|
WHERE table_name = 'shared_memories'
|
||||||
|
) THEN
|
||||||
|
CREATE TABLE shared_memories (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||||
|
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||||
|
created_by_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
memory_text TEXT NOT NULL,
|
||||||
|
category memorycategory NOT NULL DEFAULT 'fact',
|
||||||
|
embedding vector({EMBEDDING_DIM})
|
||||||
|
);
|
||||||
|
END IF;
|
||||||
|
END$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM pg_indexes
|
||||||
|
WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_search_space_id'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX ix_shared_memories_search_space_id ON shared_memories(search_space_id);
|
||||||
|
END IF;
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM pg_indexes
|
||||||
|
WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_updated_at'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX ix_shared_memories_updated_at ON shared_memories(updated_at);
|
||||||
|
END IF;
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM pg_indexes
|
||||||
|
WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_created_by_id'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX ix_shared_memories_created_by_id ON shared_memories(created_by_id);
|
||||||
|
END IF;
|
||||||
|
END$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE INDEX IF NOT EXISTS shared_memories_vector_index
|
||||||
|
ON shared_memories USING hnsw (embedding public.vector_cosine_ops);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DROP INDEX IF EXISTS shared_memories_vector_index;")
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_shared_memories_created_by_id;")
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_shared_memories_updated_at;")
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_shared_memories_search_space_id;")
|
||||||
|
op.execute("DROP TABLE IF EXISTS shared_memories CASCADE;")
|
||||||
|
|
@ -22,6 +22,7 @@ from app.agents.new_chat.system_prompt import (
|
||||||
build_surfsense_system_prompt,
|
build_surfsense_system_prompt,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.tools.registry import build_tools_async
|
from app.agents.new_chat.tools.registry import build_tools_async
|
||||||
|
from app.db import ChatVisibility
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -126,6 +127,7 @@ async def create_surfsense_deep_agent(
|
||||||
disabled_tools: list[str] | None = None,
|
disabled_tools: list[str] | None = None,
|
||||||
additional_tools: Sequence[BaseTool] | None = None,
|
additional_tools: Sequence[BaseTool] | None = None,
|
||||||
firecrawl_api_key: str | None = None,
|
firecrawl_api_key: str | None = None,
|
||||||
|
thread_visibility: ChatVisibility | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a SurfSense deep agent with configurable tools and prompts.
|
Create a SurfSense deep agent with configurable tools and prompts.
|
||||||
|
|
@ -228,14 +230,15 @@ async def create_surfsense_deep_agent(
|
||||||
logging.warning(f"Failed to discover available connectors/document types: {e}")
|
logging.warning(f"Failed to discover available connectors/document types: {e}")
|
||||||
|
|
||||||
# Build dependencies dict for the tools registry
|
# Build dependencies dict for the tools registry
|
||||||
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
dependencies = {
|
dependencies = {
|
||||||
"search_space_id": search_space_id,
|
"search_space_id": search_space_id,
|
||||||
"db_session": db_session,
|
"db_session": db_session,
|
||||||
"connector_service": connector_service,
|
"connector_service": connector_service,
|
||||||
"firecrawl_api_key": firecrawl_api_key,
|
"firecrawl_api_key": firecrawl_api_key,
|
||||||
"user_id": user_id, # Required for memory tools
|
"user_id": user_id,
|
||||||
"thread_id": thread_id, # For podcast tool
|
"thread_id": thread_id,
|
||||||
# Dynamic connector/document type discovery for knowledge base tool
|
"thread_visibility": visibility,
|
||||||
"available_connectors": available_connectors,
|
"available_connectors": available_connectors,
|
||||||
"available_document_types": available_document_types,
|
"available_document_types": available_document_types,
|
||||||
}
|
}
|
||||||
|
|
@ -255,10 +258,12 @@ async def create_surfsense_deep_agent(
|
||||||
custom_system_instructions=agent_config.system_instructions,
|
custom_system_instructions=agent_config.system_instructions,
|
||||||
use_default_system_instructions=agent_config.use_default_system_instructions,
|
use_default_system_instructions=agent_config.use_default_system_instructions,
|
||||||
citations_enabled=agent_config.citations_enabled,
|
citations_enabled=agent_config.citations_enabled,
|
||||||
|
thread_visibility=thread_visibility,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use default prompt (with citations enabled)
|
system_prompt = build_surfsense_system_prompt(
|
||||||
system_prompt = build_surfsense_system_prompt()
|
thread_visibility=thread_visibility,
|
||||||
|
)
|
||||||
|
|
||||||
# Create the deep agent with system prompt and checkpointer
|
# Create the deep agent with system prompt and checkpointer
|
||||||
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
|
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ The prompt is composed of three parts:
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
# Default system instructions - can be overridden via NewLLMConfig.system_instructions
|
# Default system instructions - can be overridden via NewLLMConfig.system_instructions
|
||||||
SURFSENSE_SYSTEM_INSTRUCTIONS = """
|
SURFSENSE_SYSTEM_INSTRUCTIONS = """
|
||||||
<system_instruction>
|
<system_instruction>
|
||||||
|
|
@ -22,7 +24,34 @@ Today's date (UTC): {resolved_today}
|
||||||
</system_instruction>
|
</system_instruction>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SURFSENSE_TOOLS_INSTRUCTIONS = """
|
# Default system instructions for shared (team) threads: team context + message format for attribution
|
||||||
|
_SYSTEM_INSTRUCTIONS_SHARED = """
|
||||||
|
<system_instruction>
|
||||||
|
You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base.
|
||||||
|
|
||||||
|
In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers.
|
||||||
|
|
||||||
|
Today's date (UTC): {resolved_today}
|
||||||
|
|
||||||
|
</system_instruction>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_system_instructions(
|
||||||
|
thread_visibility: ChatVisibility | None = None, today: datetime | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Build system instructions based on thread visibility (private vs shared)."""
|
||||||
|
|
||||||
|
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
|
||||||
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
|
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||||
|
return _SYSTEM_INSTRUCTIONS_SHARED.format(resolved_today=resolved_today)
|
||||||
|
else:
|
||||||
|
return SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today)
|
||||||
|
|
||||||
|
|
||||||
|
# Tools 0-6 (common to both private and shared prompts)
|
||||||
|
_TOOLS_INSTRUCTIONS_COMMON = """
|
||||||
<tools>
|
<tools>
|
||||||
You have access to the following tools:
|
You have access to the following tools:
|
||||||
|
|
||||||
|
|
@ -138,7 +167,11 @@ You have access to the following tools:
|
||||||
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
|
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
|
||||||
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
||||||
|
|
||||||
7. save_memory: Save facts, preferences, or context about the user for personalized responses.
|
"""
|
||||||
|
|
||||||
|
# Private (user) memory: tools 7-8 + memory-specific examples
|
||||||
|
_TOOLS_INSTRUCTIONS_MEMORY_PRIVATE = """
|
||||||
|
7. save_memory: Save facts, preferences, or context for personalized responses.
|
||||||
- Use this when the user explicitly or implicitly shares information worth remembering.
|
- Use this when the user explicitly or implicitly shares information worth remembering.
|
||||||
- Trigger scenarios:
|
- Trigger scenarios:
|
||||||
* User says "remember this", "keep this in mind", "note that", or similar
|
* User says "remember this", "keep this in mind", "note that", or similar
|
||||||
|
|
@ -178,6 +211,75 @@ You have access to the following tools:
|
||||||
stating "Based on your memory..." - integrate the context seamlessly.
|
stating "Based on your memory..." - integrate the context seamlessly.
|
||||||
</tools>
|
</tools>
|
||||||
<tool_call_examples>
|
<tool_call_examples>
|
||||||
|
- User: "Remember that I prefer TypeScript over JavaScript"
|
||||||
|
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
|
||||||
|
|
||||||
|
- User: "I'm a data scientist working on ML pipelines"
|
||||||
|
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
|
||||||
|
|
||||||
|
- User: "Always give me code examples in Python"
|
||||||
|
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
|
||||||
|
|
||||||
|
- User: "What programming language should I use for this project?"
|
||||||
|
- First recall: `recall_memory(query="programming language preferences")`
|
||||||
|
- Then provide a personalized recommendation based on their preferences
|
||||||
|
|
||||||
|
- User: "What do you know about me?"
|
||||||
|
- Call: `recall_memory(top_k=10)`
|
||||||
|
- Then summarize the stored memories
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Shared (team) memory: tools 7-8 + team memory examples
|
||||||
|
_TOOLS_INSTRUCTIONS_MEMORY_SHARED = """
|
||||||
|
7. save_memory: Save a fact, preference, or context to the team's shared memory for future reference.
|
||||||
|
- Use this when the user or a team member says "remember this", "keep this in mind", or similar in this shared chat.
|
||||||
|
- Use when the team agrees on something to remember (e.g., decisions, conventions).
|
||||||
|
- Someone shares a preference or fact that should be visible to the whole team.
|
||||||
|
- The saved information will be available in future shared conversations in this space.
|
||||||
|
- Args:
|
||||||
|
- content: The fact/preference/context to remember. Phrase it clearly, e.g. "API keys are stored in Vault", "The team prefers weekly demos on Fridays"
|
||||||
|
- category: Type of memory. One of:
|
||||||
|
* "preference": Team or workspace preferences
|
||||||
|
* "fact": Facts the team agreed on (e.g., processes, locations)
|
||||||
|
* "instruction": Standing instructions for the team
|
||||||
|
* "context": Current context (e.g., ongoing projects, goals)
|
||||||
|
- Returns: Confirmation of saved memory; returned context may include who added it (added_by).
|
||||||
|
- IMPORTANT: Only save information that would be genuinely useful for future team conversations in this space.
|
||||||
|
|
||||||
|
8. recall_memory: Recall relevant team memories for this space to provide contextual responses.
|
||||||
|
- Use when you need team context to answer (e.g., "where do we store X?", "what did we decide about Y?").
|
||||||
|
- Use when someone asks about something the team agreed to remember.
|
||||||
|
- Use when team preferences or conventions would improve the response.
|
||||||
|
- Args:
|
||||||
|
- query: Optional search query to find specific memories. If not provided, returns the most recent memories.
|
||||||
|
- category: Optional filter by category ("preference", "fact", "instruction", "context")
|
||||||
|
- top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||||
|
- Returns: Relevant team memories and formatted context (may include added_by). Integrate naturally without saying "Based on team memory...".
|
||||||
|
</tools>
|
||||||
|
<tool_call_examples>
|
||||||
|
- User: "Remember that API keys are stored in Vault"
|
||||||
|
- Call: `save_memory(content="API keys are stored in Vault", category="fact")`
|
||||||
|
|
||||||
|
- User: "Let's remember that the team prefers weekly demos on Fridays"
|
||||||
|
- Call: `save_memory(content="The team prefers weekly demos on Fridays", category="preference")`
|
||||||
|
|
||||||
|
- User: "What did we decide about the release date?"
|
||||||
|
- First recall: `recall_memory(query="release date decision")`
|
||||||
|
- Then answer based on the team memories
|
||||||
|
|
||||||
|
- User: "Where do we document onboarding?"
|
||||||
|
- Call: `recall_memory(query="onboarding documentation")`
|
||||||
|
- Then answer using the recalled team context
|
||||||
|
|
||||||
|
- User: "What have we agreed to remember about our deployment process?"
|
||||||
|
- Call: `recall_memory(query="deployment process", top_k=10)`
|
||||||
|
- Then summarize the relevant team memories
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Examples shared by both private and shared prompts (knowledge base, docs, podcast, links, images, etc.)
|
||||||
|
_TOOLS_INSTRUCTIONS_EXAMPLES_COMMON = """
|
||||||
- User: "What time is the team meeting today?"
|
- User: "What time is the team meeting today?"
|
||||||
- Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.)
|
- Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.)
|
||||||
- DO NOT limit to just calendar - the info might be in notes!
|
- DO NOT limit to just calendar - the info might be in notes!
|
||||||
|
|
@ -209,23 +311,6 @@ You have access to the following tools:
|
||||||
- User: "What's in my Obsidian vault about project ideas?"
|
- User: "What's in my Obsidian vault about project ideas?"
|
||||||
- Call: `search_knowledge_base(query="project ideas", connectors_to_search=["OBSIDIAN_CONNECTOR"])`
|
- Call: `search_knowledge_base(query="project ideas", connectors_to_search=["OBSIDIAN_CONNECTOR"])`
|
||||||
|
|
||||||
- User: "Remember that I prefer TypeScript over JavaScript"
|
|
||||||
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
|
|
||||||
|
|
||||||
- User: "I'm a data scientist working on ML pipelines"
|
|
||||||
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
|
|
||||||
|
|
||||||
- User: "Always give me code examples in Python"
|
|
||||||
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
|
|
||||||
|
|
||||||
- User: "What programming language should I use for this project?"
|
|
||||||
- First recall: `recall_memory(query="programming language preferences")`
|
|
||||||
- Then provide a personalized recommendation based on their preferences
|
|
||||||
|
|
||||||
- User: "What do you know about me?"
|
|
||||||
- Call: `recall_memory(top_k=10)`
|
|
||||||
- Then summarize the stored memories
|
|
||||||
|
|
||||||
- User: "Give me a podcast about AI trends based on what we discussed"
|
- User: "Give me a podcast about AI trends based on what we discussed"
|
||||||
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
|
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
|
||||||
|
|
||||||
|
|
@ -315,6 +400,31 @@ You have access to the following tools:
|
||||||
</tool_call_examples>
|
</tool_call_examples>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Reassemble so existing callers see no change (same full prompt)
|
||||||
|
SURFSENSE_TOOLS_INSTRUCTIONS = (
|
||||||
|
_TOOLS_INSTRUCTIONS_COMMON
|
||||||
|
+ _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE
|
||||||
|
+ _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tools_instructions(thread_visibility: ChatVisibility | None = None) -> str:
|
||||||
|
"""Build tools instructions based on thread visibility (private vs shared).
|
||||||
|
|
||||||
|
For private chats: use user-focused memory wording and examples.
|
||||||
|
For shared chats: use team memory wording and examples.
|
||||||
|
"""
|
||||||
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
|
memory_block = (
|
||||||
|
_TOOLS_INSTRUCTIONS_MEMORY_SHARED
|
||||||
|
if visibility == ChatVisibility.SEARCH_SPACE
|
||||||
|
else _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
_TOOLS_INSTRUCTIONS_COMMON + memory_block + _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
SURFSENSE_CITATION_INSTRUCTIONS = """
|
SURFSENSE_CITATION_INSTRUCTIONS = """
|
||||||
<citation_instructions>
|
<citation_instructions>
|
||||||
CRITICAL CITATION REQUIREMENTS:
|
CRITICAL CITATION REQUIREMENTS:
|
||||||
|
|
@ -413,6 +523,7 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
||||||
|
|
||||||
def build_surfsense_system_prompt(
|
def build_surfsense_system_prompt(
|
||||||
today: datetime | None = None,
|
today: datetime | None = None,
|
||||||
|
thread_visibility: ChatVisibility | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build the SurfSense system prompt with default settings.
|
Build the SurfSense system prompt with default settings.
|
||||||
|
|
@ -424,17 +535,17 @@ def build_surfsense_system_prompt(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
today: Optional datetime for today's date (defaults to current UTC date)
|
today: Optional datetime for today's date (defaults to current UTC date)
|
||||||
|
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Complete system prompt string
|
Complete system prompt string
|
||||||
"""
|
"""
|
||||||
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
|
|
||||||
|
|
||||||
return (
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today)
|
system_instructions = _get_system_instructions(visibility, today)
|
||||||
+ SURFSENSE_TOOLS_INSTRUCTIONS
|
tools_instructions = _get_tools_instructions(visibility)
|
||||||
+ SURFSENSE_CITATION_INSTRUCTIONS
|
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
|
||||||
)
|
return system_instructions + tools_instructions + citation_instructions
|
||||||
|
|
||||||
|
|
||||||
def build_configurable_system_prompt(
|
def build_configurable_system_prompt(
|
||||||
|
|
@ -442,6 +553,7 @@ def build_configurable_system_prompt(
|
||||||
use_default_system_instructions: bool = True,
|
use_default_system_instructions: bool = True,
|
||||||
citations_enabled: bool = True,
|
citations_enabled: bool = True,
|
||||||
today: datetime | None = None,
|
today: datetime | None = None,
|
||||||
|
thread_visibility: ChatVisibility | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||||
|
|
@ -460,6 +572,7 @@ def build_configurable_system_prompt(
|
||||||
citations_enabled: Whether to include citation instructions (True) or
|
citations_enabled: Whether to include citation instructions (True) or
|
||||||
anti-citation instructions (False).
|
anti-citation instructions (False).
|
||||||
today: Optional datetime for today's date (defaults to current UTC date)
|
today: Optional datetime for today's date (defaults to current UTC date)
|
||||||
|
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Complete system prompt string
|
Complete system prompt string
|
||||||
|
|
@ -473,16 +586,14 @@ def build_configurable_system_prompt(
|
||||||
resolved_today=resolved_today
|
resolved_today=resolved_today
|
||||||
)
|
)
|
||||||
elif use_default_system_instructions:
|
elif use_default_system_instructions:
|
||||||
# Use default instructions
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
system_instructions = SURFSENSE_SYSTEM_INSTRUCTIONS.format(
|
system_instructions = _get_system_instructions(visibility, today)
|
||||||
resolved_today=resolved_today
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# No system instructions (edge case)
|
# No system instructions (edge case)
|
||||||
system_instructions = ""
|
system_instructions = ""
|
||||||
|
|
||||||
# Tools instructions are always included
|
# Tools instructions: conditional on thread_visibility (private vs shared memory wording)
|
||||||
tools_instructions = SURFSENSE_TOOLS_INSTRUCTIONS
|
tools_instructions = _get_tools_instructions(thread_visibility)
|
||||||
|
|
||||||
# Citation instructions based on toggle
|
# Citation instructions based on toggle
|
||||||
citation_instructions = (
|
citation_instructions = (
|
||||||
|
|
|
||||||
|
|
@ -51,8 +51,14 @@ from .mcp_tool import load_mcp_tools
|
||||||
from .podcast import create_generate_podcast_tool
|
from .podcast import create_generate_podcast_tool
|
||||||
from .scrape_webpage import create_scrape_webpage_tool
|
from .scrape_webpage import create_scrape_webpage_tool
|
||||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||||
|
from .shared_memory import (
|
||||||
|
create_recall_shared_memory_tool,
|
||||||
|
create_save_shared_memory_tool,
|
||||||
|
)
|
||||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||||
|
|
||||||
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Tool Definition
|
# Tool Definition
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -156,29 +162,42 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
requires=["db_session"],
|
requires=["db_session"],
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# USER MEMORY TOOLS - Claude-like memory feature
|
# USER MEMORY TOOLS - private or team store by thread_visibility
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Save memory tool - stores facts/preferences about the user
|
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="save_memory",
|
name="save_memory",
|
||||||
description="Save facts, preferences, or context about the user for personalized responses",
|
description="Save facts, preferences, or context for personalized or team responses",
|
||||||
factory=lambda deps: create_save_memory_tool(
|
factory=lambda deps: (
|
||||||
|
create_save_shared_memory_tool(
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
created_by_id=deps["user_id"],
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
)
|
||||||
|
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||||
|
else create_save_memory_tool(
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
search_space_id=deps["search_space_id"],
|
search_space_id=deps["search_space_id"],
|
||||||
db_session=deps["db_session"],
|
db_session=deps["db_session"],
|
||||||
|
)
|
||||||
),
|
),
|
||||||
requires=["user_id", "search_space_id", "db_session"],
|
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||||
),
|
),
|
||||||
# Recall memory tool - retrieves relevant user memories
|
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="recall_memory",
|
name="recall_memory",
|
||||||
description="Recall user memories for personalized and contextual responses",
|
description="Recall relevant memories (personal or team) for context",
|
||||||
factory=lambda deps: create_recall_memory_tool(
|
factory=lambda deps: (
|
||||||
|
create_recall_shared_memory_tool(
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
)
|
||||||
|
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||||
|
else create_recall_memory_tool(
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
search_space_id=deps["search_space_id"],
|
search_space_id=deps["search_space_id"],
|
||||||
db_session=deps["db_session"],
|
db_session=deps["db_session"],
|
||||||
|
)
|
||||||
),
|
),
|
||||||
requires=["user_id", "search_space_id", "db_session"],
|
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# ADD YOUR CUSTOM TOOLS BELOW
|
# ADD YOUR CUSTOM TOOLS BELOW
|
||||||
|
|
|
||||||
278
surfsense_backend/app/agents/new_chat/tools/shared_memory.py
Normal file
278
surfsense_backend/app/agents/new_chat/tools/shared_memory.py
Normal file
|
|
@ -0,0 +1,278 @@
|
||||||
|
"""Shared (team) memory backend for search-space-scoped AI context."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import MemoryCategory, SharedMemory, User
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_RECALL_TOP_K = 5
|
||||||
|
MAX_MEMORIES_PER_SEARCH_SPACE = 250
|
||||||
|
|
||||||
|
|
||||||
|
async def get_shared_memory_count(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> int:
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SharedMemory).where(SharedMemory.search_space_id == search_space_id)
|
||||||
|
)
|
||||||
|
return len(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_oldest_shared_memory(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> None:
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SharedMemory)
|
||||||
|
.where(SharedMemory.search_space_id == search_space_id)
|
||||||
|
.order_by(SharedMemory.updated_at.asc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
oldest = result.scalars().first()
|
||||||
|
if oldest:
|
||||||
|
await db_session.delete(oldest)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def _to_uuid(value: str | UUID) -> UUID:
|
||||||
|
if isinstance(value, UUID):
|
||||||
|
return value
|
||||||
|
return UUID(value)
|
||||||
|
|
||||||
|
|
||||||
|
async def save_shared_memory(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
created_by_id: str | UUID,
|
||||||
|
content: str,
|
||||||
|
category: str = "fact",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
category = category.lower() if category else "fact"
|
||||||
|
valid = ["preference", "fact", "instruction", "context"]
|
||||||
|
if category not in valid:
|
||||||
|
category = "fact"
|
||||||
|
try:
|
||||||
|
count = await get_shared_memory_count(db_session, search_space_id)
|
||||||
|
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
|
||||||
|
await delete_oldest_shared_memory(db_session, search_space_id)
|
||||||
|
embedding = config.embedding_model_instance.embed(content)
|
||||||
|
row = SharedMemory(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=_to_uuid(created_by_id),
|
||||||
|
memory_text=content,
|
||||||
|
category=MemoryCategory(category),
|
||||||
|
embedding=embedding,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(row)
|
||||||
|
return {
|
||||||
|
"status": "saved",
|
||||||
|
"memory_id": row.id,
|
||||||
|
"memory_text": content,
|
||||||
|
"category": category,
|
||||||
|
"message": f"I'll remember: {content}",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to save shared memory: %s", e)
|
||||||
|
await db_session.rollback()
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"message": "Failed to save memory. Please try again.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def recall_shared_memory(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
query: str | None = None,
|
||||||
|
category: str | None = None,
|
||||||
|
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
top_k = min(max(top_k, 1), 20)
|
||||||
|
try:
|
||||||
|
valid_categories = ["preference", "fact", "instruction", "context"]
|
||||||
|
stmt = select(SharedMemory).where(
|
||||||
|
SharedMemory.search_space_id == search_space_id
|
||||||
|
)
|
||||||
|
if category and category in valid_categories:
|
||||||
|
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
|
||||||
|
if query:
|
||||||
|
query_embedding = config.embedding_model_instance.embed(query)
|
||||||
|
stmt = stmt.order_by(
|
||||||
|
SharedMemory.embedding.op("<=>")(query_embedding)
|
||||||
|
).limit(top_k)
|
||||||
|
else:
|
||||||
|
stmt = stmt.order_by(SharedMemory.updated_at.desc()).limit(top_k)
|
||||||
|
result = await db_session.execute(stmt)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
memory_list = [
|
||||||
|
{
|
||||||
|
"id": m.id,
|
||||||
|
"memory_text": m.memory_text,
|
||||||
|
"category": m.category.value if m.category else "unknown",
|
||||||
|
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
|
||||||
|
"created_by_id": str(m.created_by_id) if m.created_by_id else None,
|
||||||
|
}
|
||||||
|
for m in rows
|
||||||
|
]
|
||||||
|
created_by_ids = list({m["created_by_id"] for m in memory_list if m["created_by_id"]})
|
||||||
|
created_by_map: dict[str, str] = {}
|
||||||
|
if created_by_ids:
|
||||||
|
uuids = [UUID(uid) for uid in created_by_ids]
|
||||||
|
users_result = await db_session.execute(
|
||||||
|
select(User).where(User.id.in_(uuids))
|
||||||
|
)
|
||||||
|
for u in users_result.scalars().all():
|
||||||
|
created_by_map[str(u.id)] = u.display_name or "A team member"
|
||||||
|
formatted_context = format_shared_memories_for_context(
|
||||||
|
memory_list, created_by_map
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"count": len(memory_list),
|
||||||
|
"memories": memory_list,
|
||||||
|
"formatted_context": formatted_context,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to recall shared memory: %s", e)
|
||||||
|
await db_session.rollback()
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"memories": [],
|
||||||
|
"formatted_context": "Failed to recall memories.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def format_shared_memories_for_context(
|
||||||
|
memories: list[dict[str, Any]],
|
||||||
|
created_by_map: dict[str, str] | None = None,
|
||||||
|
) -> str:
|
||||||
|
if not memories:
|
||||||
|
return "No relevant team memories found."
|
||||||
|
created_by_map = created_by_map or {}
|
||||||
|
parts = ["<team_memories>"]
|
||||||
|
for memory in memories:
|
||||||
|
category = memory.get("category", "unknown")
|
||||||
|
text = memory.get("memory_text", "")
|
||||||
|
updated = memory.get("updated_at", "")
|
||||||
|
created_by_id = memory.get("created_by_id")
|
||||||
|
added_by = (
|
||||||
|
created_by_map.get(str(created_by_id), "A team member")
|
||||||
|
if created_by_id is not None
|
||||||
|
else "A team member"
|
||||||
|
)
|
||||||
|
parts.append(
|
||||||
|
f" <memory category='{category}' updated='{updated}' added_by='{added_by}'>{text}</memory>"
|
||||||
|
)
|
||||||
|
parts.append("</team_memories>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def create_save_shared_memory_tool(
|
||||||
|
search_space_id: int,
|
||||||
|
created_by_id: str | UUID,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the save_memory tool for shared (team) chats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_space_id: The search space ID
|
||||||
|
created_by_id: The user ID of the person adding the memory
|
||||||
|
db_session: Database session for executing queries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A configured tool function for saving team memories
|
||||||
|
"""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def save_memory(
|
||||||
|
content: str,
|
||||||
|
category: str = "fact",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Save a fact, preference, or context to the team's shared memory for future reference.
|
||||||
|
|
||||||
|
Use this tool when:
|
||||||
|
- User or a team member says "remember this", "keep this in mind", or similar in this shared chat
|
||||||
|
- The team agrees on something to remember (e.g., decisions, conventions, where things live)
|
||||||
|
- Someone shares a preference or fact that should be visible to the whole team
|
||||||
|
|
||||||
|
The saved information will be available in future shared conversations in this space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The fact/preference/context to remember.
|
||||||
|
Phrase it clearly, e.g., "API keys are stored in Vault",
|
||||||
|
"The team prefers weekly demos on Fridays"
|
||||||
|
category: Type of memory. One of:
|
||||||
|
- "preference": Team or workspace preferences
|
||||||
|
- "fact": Facts the team agreed on (e.g., processes, locations)
|
||||||
|
- "instruction": Standing instructions for the team
|
||||||
|
- "context": Current context (e.g., ongoing projects, goals)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the save status and memory details
|
||||||
|
"""
|
||||||
|
return await save_shared_memory(
|
||||||
|
db_session, search_space_id, created_by_id, content, category
|
||||||
|
)
|
||||||
|
|
||||||
|
return save_memory
|
||||||
|
|
||||||
|
|
||||||
|
def create_recall_shared_memory_tool(
|
||||||
|
search_space_id: int,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the recall_memory tool for shared (team) chats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_space_id: The search space ID
|
||||||
|
db_session: Database session for executing queries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A configured tool function for recalling team memories
|
||||||
|
"""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def recall_memory(
|
||||||
|
query: str | None = None,
|
||||||
|
category: str | None = None,
|
||||||
|
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Recall relevant team memories for this space to provide contextual responses.
|
||||||
|
|
||||||
|
Use this tool when:
|
||||||
|
- You need team context to answer (e.g., "where do we store X?", "what did we decide about Y?")
|
||||||
|
- Someone asks about something the team agreed to remember
|
||||||
|
- Team preferences or conventions would improve the response
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Optional search query to find specific memories.
|
||||||
|
If not provided, returns the most recent memories.
|
||||||
|
category: Optional category filter. One of:
|
||||||
|
"preference", "fact", "instruction", "context"
|
||||||
|
top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing relevant memories and formatted context
|
||||||
|
"""
|
||||||
|
return await recall_shared_memory(
|
||||||
|
db_session, search_space_id, query, category, top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
return recall_memory
|
||||||
|
|
@ -801,9 +801,8 @@ class MemoryCategory(str, Enum):
|
||||||
|
|
||||||
class UserMemory(BaseModel, TimestampMixin):
|
class UserMemory(BaseModel, TimestampMixin):
|
||||||
"""
|
"""
|
||||||
Stores facts, preferences, and context about users for personalized AI responses.
|
Private memory: facts, preferences, context per user per search space.
|
||||||
Similar to Claude's memory feature - enables the AI to remember user information
|
Used only for private chats (not shared/team chats).
|
||||||
across conversations.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "user_memories"
|
__tablename__ = "user_memories"
|
||||||
|
|
@ -847,6 +846,40 @@ class UserMemory(BaseModel, TimestampMixin):
|
||||||
search_space = relationship("SearchSpace", back_populates="user_memories")
|
search_space = relationship("SearchSpace", back_populates="user_memories")
|
||||||
|
|
||||||
|
|
||||||
|
class SharedMemory(BaseModel, TimestampMixin):
|
||||||
|
__tablename__ = "shared_memories"
|
||||||
|
|
||||||
|
search_space_id = Column(
|
||||||
|
Integer,
|
||||||
|
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
created_by_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("user.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
memory_text = Column(Text, nullable=False)
|
||||||
|
category = Column(
|
||||||
|
SQLAlchemyEnum(MemoryCategory),
|
||||||
|
nullable=False,
|
||||||
|
default=MemoryCategory.fact,
|
||||||
|
)
|
||||||
|
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||||
|
updated_at = Column(
|
||||||
|
TIMESTAMP(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
default=lambda: datetime.now(UTC),
|
||||||
|
onupdate=lambda: datetime.now(UTC),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
search_space = relationship("SearchSpace", back_populates="shared_memories")
|
||||||
|
created_by = relationship("User")
|
||||||
|
|
||||||
|
|
||||||
class Document(BaseModel, TimestampMixin):
|
class Document(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "documents"
|
__tablename__ = "documents"
|
||||||
|
|
||||||
|
|
@ -1209,6 +1242,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
||||||
order_by="UserMemory.updated_at.desc()",
|
order_by="UserMemory.updated_at.desc()",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
shared_memories = relationship(
|
||||||
|
"SharedMemory",
|
||||||
|
back_populates="search_space",
|
||||||
|
order_by="SharedMemory.updated_at.desc()",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||||
|
|
@ -1258,7 +1297,7 @@ class NewLLMConfig(BaseModel, TimestampMixin):
|
||||||
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||||
- Citation toggle (enable/disable citation instructions)
|
- Citation toggle (enable/disable citation instructions)
|
||||||
|
|
||||||
Note: SURFSENSE_TOOLS_INSTRUCTIONS is always used and not configurable.
|
Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "new_llm_configs"
|
__tablename__ = "new_llm_configs"
|
||||||
|
|
|
||||||
|
|
@ -1045,12 +1045,14 @@ async def handle_new_chat(
|
||||||
search_space_id=request.search_space_id,
|
search_space_id=request.search_space_id,
|
||||||
chat_id=request.chat_id,
|
chat_id=request.chat_id,
|
||||||
session=session,
|
session=session,
|
||||||
user_id=str(user.id), # Pass user ID for memory tools and session state
|
user_id=str(user.id),
|
||||||
llm_config_id=llm_config_id,
|
llm_config_id=llm_config_id,
|
||||||
attachments=request.attachments,
|
attachments=request.attachments,
|
||||||
mentioned_document_ids=request.mentioned_document_ids,
|
mentioned_document_ids=request.mentioned_document_ids,
|
||||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||||
|
thread_visibility=thread.visibility,
|
||||||
|
current_user_display_name=user.display_name or "A team member",
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
|
|
@ -1281,6 +1283,8 @@ async def regenerate_response(
|
||||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||||
checkpoint_id=target_checkpoint_id,
|
checkpoint_id=target_checkpoint_id,
|
||||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||||
|
thread_visibility=thread.visibility,
|
||||||
|
current_user_display_name=user.display_name or "A team member",
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
# If we get here, streaming completed successfully
|
# If we get here, streaming completed successfully
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from app.agents.new_chat.llm_config import (
|
||||||
load_agent_config,
|
load_agent_config,
|
||||||
load_llm_config_from_yaml,
|
load_llm_config_from_yaml,
|
||||||
)
|
)
|
||||||
from app.db import Document, SurfsenseDocsDocument
|
from app.db import ChatVisibility, Document, SurfsenseDocsDocument
|
||||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||||
from app.schemas.new_chat import ChatAttachment
|
from app.schemas.new_chat import ChatAttachment
|
||||||
from app.services.chat_session_state_service import (
|
from app.services.chat_session_state_service import (
|
||||||
|
|
@ -208,6 +208,8 @@ async def stream_new_chat(
|
||||||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||||
checkpoint_id: str | None = None,
|
checkpoint_id: str | None = None,
|
||||||
needs_history_bootstrap: bool = False,
|
needs_history_bootstrap: bool = False,
|
||||||
|
thread_visibility: ChatVisibility | None = None,
|
||||||
|
current_user_display_name: str | None = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Stream chat responses from the new SurfSense deep agent.
|
Stream chat responses from the new SurfSense deep agent.
|
||||||
|
|
@ -295,17 +297,18 @@ async def stream_new_chat(
|
||||||
# Get the PostgreSQL checkpointer for persistent conversation memory
|
# Get the PostgreSQL checkpointer for persistent conversation memory
|
||||||
checkpointer = await get_checkpointer()
|
checkpointer = await get_checkpointer()
|
||||||
|
|
||||||
# Create the deep agent with checkpointer and configurable prompts
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
agent = await create_surfsense_deep_agent(
|
agent = await create_surfsense_deep_agent(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
db_session=session,
|
db_session=session,
|
||||||
connector_service=connector_service,
|
connector_service=connector_service,
|
||||||
checkpointer=checkpointer,
|
checkpointer=checkpointer,
|
||||||
user_id=user_id, # Pass user ID for memory tools
|
user_id=user_id,
|
||||||
thread_id=chat_id, # Pass chat ID for podcast association
|
thread_id=chat_id,
|
||||||
agent_config=agent_config, # Pass prompt configuration
|
agent_config=agent_config,
|
||||||
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
|
thread_visibility=visibility,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build input with message history
|
# Build input with message history
|
||||||
|
|
@ -313,7 +316,9 @@ async def stream_new_chat(
|
||||||
|
|
||||||
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
|
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
|
||||||
if needs_history_bootstrap:
|
if needs_history_bootstrap:
|
||||||
langchain_messages = await bootstrap_history_from_db(session, chat_id)
|
langchain_messages = await bootstrap_history_from_db(
|
||||||
|
session, chat_id, thread_visibility=visibility
|
||||||
|
)
|
||||||
|
|
||||||
# Clear the flag so we don't bootstrap again on next message
|
# Clear the flag so we don't bootstrap again on next message
|
||||||
from app.db import NewChatThread
|
from app.db import NewChatThread
|
||||||
|
|
@ -376,6 +381,9 @@ async def stream_new_chat(
|
||||||
context = "\n\n".join(context_parts)
|
context = "\n\n".join(context_parts)
|
||||||
final_query = f"{context}\n\n<user_query>{user_query}</user_query>"
|
final_query = f"{context}\n\n<user_query>{user_query}</user_query>"
|
||||||
|
|
||||||
|
if visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
|
||||||
|
final_query = f"**[{current_user_display_name}]:** {final_query}"
|
||||||
|
|
||||||
# if messages:
|
# if messages:
|
||||||
# # Convert frontend messages to LangChain format
|
# # Convert frontend messages to LangChain format
|
||||||
# for msg in messages:
|
# for msg in messages:
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ These utilities help extract and transform content for different use cases.
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
|
||||||
def extract_text_content(content: str | dict | list) -> str:
|
def extract_text_content(content: str | dict | list) -> str:
|
||||||
|
|
@ -38,6 +39,7 @@ def extract_text_content(content: str | dict | list) -> str:
|
||||||
async def bootstrap_history_from_db(
|
async def bootstrap_history_from_db(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
thread_id: int,
|
thread_id: int,
|
||||||
|
thread_visibility: "ChatVisibility | None" = None,
|
||||||
) -> list[HumanMessage | AIMessage]:
|
) -> list[HumanMessage | AIMessage]:
|
||||||
"""
|
"""
|
||||||
Load message history from database and convert to LangChain format.
|
Load message history from database and convert to LangChain format.
|
||||||
|
|
@ -45,20 +47,28 @@ async def bootstrap_history_from_db(
|
||||||
Used for cloned chats where the LangGraph checkpointer has no state,
|
Used for cloned chats where the LangGraph checkpointer has no state,
|
||||||
but we have messages in the database that should be used as context.
|
but we have messages in the database that should be used as context.
|
||||||
|
|
||||||
|
When thread_visibility is SEARCH_SPACE, user messages are prefixed with
|
||||||
|
the author's display name so the LLM sees who said what.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: Database session
|
session: Database session
|
||||||
thread_id: The chat thread ID
|
thread_id: The chat thread ID
|
||||||
|
thread_visibility: When SEARCH_SPACE, user messages get author prefix
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of LangChain messages (HumanMessage/AIMessage)
|
List of LangChain messages (HumanMessage/AIMessage)
|
||||||
"""
|
"""
|
||||||
from app.db import NewChatMessage
|
from app.db import ChatVisibility, NewChatMessage
|
||||||
|
|
||||||
result = await session.execute(
|
is_shared = thread_visibility == ChatVisibility.SEARCH_SPACE
|
||||||
|
stmt = (
|
||||||
select(NewChatMessage)
|
select(NewChatMessage)
|
||||||
.filter(NewChatMessage.thread_id == thread_id)
|
.filter(NewChatMessage.thread_id == thread_id)
|
||||||
.order_by(NewChatMessage.created_at)
|
.order_by(NewChatMessage.created_at)
|
||||||
)
|
)
|
||||||
|
if is_shared:
|
||||||
|
stmt = stmt.options(selectinload(NewChatMessage.author))
|
||||||
|
result = await session.execute(stmt)
|
||||||
db_messages = result.scalars().all()
|
db_messages = result.scalars().all()
|
||||||
|
|
||||||
langchain_messages: list[HumanMessage | AIMessage] = []
|
langchain_messages: list[HumanMessage | AIMessage] = []
|
||||||
|
|
@ -68,6 +78,11 @@ async def bootstrap_history_from_db(
|
||||||
if not text_content:
|
if not text_content:
|
||||||
continue
|
continue
|
||||||
if msg.role == "user":
|
if msg.role == "user":
|
||||||
|
if is_shared:
|
||||||
|
author_name = (
|
||||||
|
(msg.author.display_name if msg.author else None) or "A team member"
|
||||||
|
)
|
||||||
|
text_content = f"**[{author_name}]:** {text_content}"
|
||||||
langchain_messages.append(HumanMessage(content=text_content))
|
langchain_messages.append(HumanMessage(content=text_content))
|
||||||
elif msg.role == "assistant":
|
elif msg.role == "assistant":
|
||||||
langchain_messages.append(AIMessage(content=text_content))
|
langchain_messages.append(AIMessage(content=text_content))
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,14 @@
|
||||||
import { atomWithQuery } from "jotai-tanstack-query";
|
import { atomWithQuery } from "jotai-tanstack-query";
|
||||||
import { userApiService } from "@/lib/apis/user-api.service";
|
import { userApiService } from "@/lib/apis/user-api.service";
|
||||||
import { getBearerToken } from "@/lib/auth-utils";
|
import { getBearerToken, isPublicRoute } from "@/lib/auth-utils";
|
||||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||||
|
|
||||||
export const currentUserAtom = atomWithQuery(() => {
|
export const currentUserAtom = atomWithQuery(() => {
|
||||||
|
const pathname = typeof window !== "undefined" ? window.location.pathname : null;
|
||||||
return {
|
return {
|
||||||
queryKey: cacheKeys.user.current(),
|
queryKey: cacheKeys.user.current(),
|
||||||
staleTime: 5 * 60 * 1000, // 5 minutes
|
staleTime: 5 * 60 * 1000, // 5 minutes
|
||||||
// Only fetch user data when a bearer token is present
|
enabled: !!getBearerToken() && pathname !== null && !isPublicRoute(pathname),
|
||||||
enabled: !!getBearerToken(),
|
queryFn: async () => userApiService.getMe(),
|
||||||
queryFn: async () => {
|
|
||||||
return userApiService.getMe();
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -10,28 +10,53 @@ const REFRESH_TOKEN_KEY = "surfsense_refresh_token";
|
||||||
let isRefreshing = false;
|
let isRefreshing = false;
|
||||||
let refreshPromise: Promise<string | null> | null = null;
|
let refreshPromise: Promise<string | null> | null = null;
|
||||||
|
|
||||||
|
/** Path prefixes for routes that do not require auth (no current-user fetch, no redirect on 401) */
|
||||||
|
const PUBLIC_ROUTE_PREFIXES = [
|
||||||
|
"/login",
|
||||||
|
"/register",
|
||||||
|
"/auth",
|
||||||
|
"/docs",
|
||||||
|
"/public",
|
||||||
|
"/invite",
|
||||||
|
"/contact",
|
||||||
|
"/pricing",
|
||||||
|
"/privacy",
|
||||||
|
"/terms",
|
||||||
|
"/changelog",
|
||||||
|
];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Saves the current path and redirects to login page
|
* Returns true if the pathname is a public route where we should not run auth checks
|
||||||
* Call this when a 401 response is received
|
* or redirect to login on 401.
|
||||||
|
*/
|
||||||
|
export function isPublicRoute(pathname: string): boolean {
|
||||||
|
if (pathname === "/" || pathname === "") return true;
|
||||||
|
return PUBLIC_ROUTE_PREFIXES.some((prefix) => pathname.startsWith(prefix));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clears tokens and optionally redirects to login.
|
||||||
|
* Call this when a 401 response is received.
|
||||||
|
* Only redirects when the current route is protected; on public routes we just clear tokens.
|
||||||
*/
|
*/
|
||||||
export function handleUnauthorized(): void {
|
export function handleUnauthorized(): void {
|
||||||
if (typeof window === "undefined") return;
|
if (typeof window === "undefined") return;
|
||||||
|
|
||||||
// Save the current path (including search params and hash) for redirect after login
|
const pathname = window.location.pathname;
|
||||||
const currentPath = window.location.pathname + window.location.search + window.location.hash;
|
|
||||||
|
|
||||||
// Don't save auth-related paths
|
// Always clear tokens
|
||||||
const excludedPaths = ["/auth", "/auth/callback", "/"];
|
|
||||||
if (!excludedPaths.includes(window.location.pathname)) {
|
|
||||||
localStorage.setItem(REDIRECT_PATH_KEY, currentPath);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear both tokens
|
|
||||||
localStorage.removeItem(BEARER_TOKEN_KEY);
|
localStorage.removeItem(BEARER_TOKEN_KEY);
|
||||||
localStorage.removeItem(REFRESH_TOKEN_KEY);
|
localStorage.removeItem(REFRESH_TOKEN_KEY);
|
||||||
|
|
||||||
// Redirect to home page (which has login options)
|
// Only redirect on protected routes; stay on public pages (e.g. /docs)
|
||||||
|
if (!isPublicRoute(pathname)) {
|
||||||
|
const currentPath = pathname + window.location.search + window.location.hash;
|
||||||
|
const excludedPaths = ["/auth", "/auth/callback", "/"];
|
||||||
|
if (!excludedPaths.includes(pathname)) {
|
||||||
|
localStorage.setItem(REDIRECT_PATH_KEY, currentPath);
|
||||||
|
}
|
||||||
window.location.href = "/login";
|
window.location.href = "/login";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -179,7 +204,6 @@ export function getAuthHeaders(additionalHeaders?: Record<string, string>): Reco
|
||||||
/**
|
/**
|
||||||
* Attempts to refresh the access token using the stored refresh token.
|
* Attempts to refresh the access token using the stored refresh token.
|
||||||
* Returns the new access token if successful, null otherwise.
|
* Returns the new access token if successful, null otherwise.
|
||||||
* Exported for use by API services.
|
|
||||||
*/
|
*/
|
||||||
export async function refreshAccessToken(): Promise<string | null> {
|
export async function refreshAccessToken(): Promise<string | null> {
|
||||||
// If already refreshing, wait for that request to complete
|
// If already refreshing, wait for that request to complete
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue