mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
Merge pull request #799 from CREDO23/sur-152-impr-split-private-and-shared-memory
[Feat] Split private vs shared chat memory and add team prompt/attribution
This commit is contained in:
commit
3f0c9c35f7
11 changed files with 664 additions and 86 deletions
|
|
@ -0,0 +1,77 @@
|
|||
"""Add shared_memories table (SUR-152)."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
from app.config import config
|
||||
|
||||
revision: str = "96"
|
||||
down_revision: str | None = "95"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
EMBEDDING_DIM = config.embedding_model_instance.dimension
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_name = 'shared_memories'
|
||||
) THEN
|
||||
CREATE TABLE shared_memories (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||
created_by_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
memory_text TEXT NOT NULL,
|
||||
category memorycategory NOT NULL DEFAULT 'fact',
|
||||
embedding vector({EMBEDDING_DIM})
|
||||
);
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_indexes
|
||||
WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_search_space_id'
|
||||
) THEN
|
||||
CREATE INDEX ix_shared_memories_search_space_id ON shared_memories(search_space_id);
|
||||
END IF;
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_indexes
|
||||
WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_updated_at'
|
||||
) THEN
|
||||
CREATE INDEX ix_shared_memories_updated_at ON shared_memories(updated_at);
|
||||
END IF;
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_indexes
|
||||
WHERE tablename = 'shared_memories' AND indexname = 'ix_shared_memories_created_by_id'
|
||||
) THEN
|
||||
CREATE INDEX ix_shared_memories_created_by_id ON shared_memories(created_by_id);
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS shared_memories_vector_index
|
||||
ON shared_memories USING hnsw (embedding public.vector_cosine_ops);
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS shared_memories_vector_index;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_shared_memories_created_by_id;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_shared_memories_updated_at;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_shared_memories_search_space_id;")
|
||||
op.execute("DROP TABLE IF EXISTS shared_memories CASCADE;")
|
||||
|
|
@ -22,6 +22,7 @@ from app.agents.new_chat.system_prompt import (
|
|||
build_surfsense_system_prompt,
|
||||
)
|
||||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -126,6 +127,7 @@ async def create_surfsense_deep_agent(
|
|||
disabled_tools: list[str] | None = None,
|
||||
additional_tools: Sequence[BaseTool] | None = None,
|
||||
firecrawl_api_key: str | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
):
|
||||
"""
|
||||
Create a SurfSense deep agent with configurable tools and prompts.
|
||||
|
|
@ -228,14 +230,15 @@ async def create_surfsense_deep_agent(
|
|||
logging.warning(f"Failed to discover available connectors/document types: {e}")
|
||||
|
||||
# Build dependencies dict for the tools registry
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
dependencies = {
|
||||
"search_space_id": search_space_id,
|
||||
"db_session": db_session,
|
||||
"connector_service": connector_service,
|
||||
"firecrawl_api_key": firecrawl_api_key,
|
||||
"user_id": user_id, # Required for memory tools
|
||||
"thread_id": thread_id, # For podcast tool
|
||||
# Dynamic connector/document type discovery for knowledge base tool
|
||||
"user_id": user_id,
|
||||
"thread_id": thread_id,
|
||||
"thread_visibility": visibility,
|
||||
"available_connectors": available_connectors,
|
||||
"available_document_types": available_document_types,
|
||||
}
|
||||
|
|
@ -255,10 +258,12 @@ async def create_surfsense_deep_agent(
|
|||
custom_system_instructions=agent_config.system_instructions,
|
||||
use_default_system_instructions=agent_config.use_default_system_instructions,
|
||||
citations_enabled=agent_config.citations_enabled,
|
||||
thread_visibility=thread_visibility,
|
||||
)
|
||||
else:
|
||||
# Use default prompt (with citations enabled)
|
||||
system_prompt = build_surfsense_system_prompt()
|
||||
system_prompt = build_surfsense_system_prompt(
|
||||
thread_visibility=thread_visibility,
|
||||
)
|
||||
|
||||
# Create the deep agent with system prompt and checkpointer
|
||||
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ The prompt is composed of three parts:
|
|||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.db import ChatVisibility
|
||||
|
||||
# Default system instructions - can be overridden via NewLLMConfig.system_instructions
|
||||
SURFSENSE_SYSTEM_INSTRUCTIONS = """
|
||||
<system_instruction>
|
||||
|
|
@ -22,7 +24,34 @@ Today's date (UTC): {resolved_today}
|
|||
</system_instruction>
|
||||
"""
|
||||
|
||||
SURFSENSE_TOOLS_INSTRUCTIONS = """
|
||||
# Default system instructions for shared (team) threads: team context + message format for attribution
|
||||
_SYSTEM_INSTRUCTIONS_SHARED = """
|
||||
<system_instruction>
|
||||
You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base.
|
||||
|
||||
In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers.
|
||||
|
||||
Today's date (UTC): {resolved_today}
|
||||
|
||||
</system_instruction>
|
||||
"""
|
||||
|
||||
|
||||
def _get_system_instructions(
|
||||
thread_visibility: ChatVisibility | None = None, today: datetime | None = None
|
||||
) -> str:
|
||||
"""Build system instructions based on thread visibility (private vs shared)."""
|
||||
|
||||
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||
return _SYSTEM_INSTRUCTIONS_SHARED.format(resolved_today=resolved_today)
|
||||
else:
|
||||
return SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today)
|
||||
|
||||
|
||||
# Tools 0-6 (common to both private and shared prompts)
|
||||
_TOOLS_INSTRUCTIONS_COMMON = """
|
||||
<tools>
|
||||
You have access to the following tools:
|
||||
|
||||
|
|
@ -138,7 +167,11 @@ You have access to the following tools:
|
|||
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
|
||||
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
||||
|
||||
7. save_memory: Save facts, preferences, or context about the user for personalized responses.
|
||||
"""
|
||||
|
||||
# Private (user) memory: tools 7-8 + memory-specific examples
|
||||
_TOOLS_INSTRUCTIONS_MEMORY_PRIVATE = """
|
||||
7. save_memory: Save facts, preferences, or context for personalized responses.
|
||||
- Use this when the user explicitly or implicitly shares information worth remembering.
|
||||
- Trigger scenarios:
|
||||
* User says "remember this", "keep this in mind", "note that", or similar
|
||||
|
|
@ -178,6 +211,75 @@ You have access to the following tools:
|
|||
stating "Based on your memory..." - integrate the context seamlessly.
|
||||
</tools>
|
||||
<tool_call_examples>
|
||||
- User: "Remember that I prefer TypeScript over JavaScript"
|
||||
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
|
||||
|
||||
- User: "I'm a data scientist working on ML pipelines"
|
||||
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
|
||||
|
||||
- User: "Always give me code examples in Python"
|
||||
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
|
||||
|
||||
- User: "What programming language should I use for this project?"
|
||||
- First recall: `recall_memory(query="programming language preferences")`
|
||||
- Then provide a personalized recommendation based on their preferences
|
||||
|
||||
- User: "What do you know about me?"
|
||||
- Call: `recall_memory(top_k=10)`
|
||||
- Then summarize the stored memories
|
||||
|
||||
"""
|
||||
|
||||
# Shared (team) memory: tools 7-8 + team memory examples
|
||||
_TOOLS_INSTRUCTIONS_MEMORY_SHARED = """
|
||||
7. save_memory: Save a fact, preference, or context to the team's shared memory for future reference.
|
||||
- Use this when the user or a team member says "remember this", "keep this in mind", or similar in this shared chat.
|
||||
- Use when the team agrees on something to remember (e.g., decisions, conventions).
|
||||
- Someone shares a preference or fact that should be visible to the whole team.
|
||||
- The saved information will be available in future shared conversations in this space.
|
||||
- Args:
|
||||
- content: The fact/preference/context to remember. Phrase it clearly, e.g. "API keys are stored in Vault", "The team prefers weekly demos on Fridays"
|
||||
- category: Type of memory. One of:
|
||||
* "preference": Team or workspace preferences
|
||||
* "fact": Facts the team agreed on (e.g., processes, locations)
|
||||
* "instruction": Standing instructions for the team
|
||||
* "context": Current context (e.g., ongoing projects, goals)
|
||||
- Returns: Confirmation of saved memory; returned context may include who added it (added_by).
|
||||
- IMPORTANT: Only save information that would be genuinely useful for future team conversations in this space.
|
||||
|
||||
8. recall_memory: Recall relevant team memories for this space to provide contextual responses.
|
||||
- Use when you need team context to answer (e.g., "where do we store X?", "what did we decide about Y?").
|
||||
- Use when someone asks about something the team agreed to remember.
|
||||
- Use when team preferences or conventions would improve the response.
|
||||
- Args:
|
||||
- query: Optional search query to find specific memories. If not provided, returns the most recent memories.
|
||||
- category: Optional filter by category ("preference", "fact", "instruction", "context")
|
||||
- top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||
- Returns: Relevant team memories and formatted context (may include added_by). Integrate naturally without saying "Based on team memory...".
|
||||
</tools>
|
||||
<tool_call_examples>
|
||||
- User: "Remember that API keys are stored in Vault"
|
||||
- Call: `save_memory(content="API keys are stored in Vault", category="fact")`
|
||||
|
||||
- User: "Let's remember that the team prefers weekly demos on Fridays"
|
||||
- Call: `save_memory(content="The team prefers weekly demos on Fridays", category="preference")`
|
||||
|
||||
- User: "What did we decide about the release date?"
|
||||
- First recall: `recall_memory(query="release date decision")`
|
||||
- Then answer based on the team memories
|
||||
|
||||
- User: "Where do we document onboarding?"
|
||||
- Call: `recall_memory(query="onboarding documentation")`
|
||||
- Then answer using the recalled team context
|
||||
|
||||
- User: "What have we agreed to remember about our deployment process?"
|
||||
- Call: `recall_memory(query="deployment process", top_k=10)`
|
||||
- Then summarize the relevant team memories
|
||||
|
||||
"""
|
||||
|
||||
# Examples shared by both private and shared prompts (knowledge base, docs, podcast, links, images, etc.)
|
||||
_TOOLS_INSTRUCTIONS_EXAMPLES_COMMON = """
|
||||
- User: "What time is the team meeting today?"
|
||||
- Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.)
|
||||
- DO NOT limit to just calendar - the info might be in notes!
|
||||
|
|
@ -209,23 +311,6 @@ You have access to the following tools:
|
|||
- User: "What's in my Obsidian vault about project ideas?"
|
||||
- Call: `search_knowledge_base(query="project ideas", connectors_to_search=["OBSIDIAN_CONNECTOR"])`
|
||||
|
||||
- User: "Remember that I prefer TypeScript over JavaScript"
|
||||
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
|
||||
|
||||
- User: "I'm a data scientist working on ML pipelines"
|
||||
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
|
||||
|
||||
- User: "Always give me code examples in Python"
|
||||
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
|
||||
|
||||
- User: "What programming language should I use for this project?"
|
||||
- First recall: `recall_memory(query="programming language preferences")`
|
||||
- Then provide a personalized recommendation based on their preferences
|
||||
|
||||
- User: "What do you know about me?"
|
||||
- Call: `recall_memory(top_k=10)`
|
||||
- Then summarize the stored memories
|
||||
|
||||
- User: "Give me a podcast about AI trends based on what we discussed"
|
||||
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
|
||||
|
||||
|
|
@ -315,6 +400,31 @@ You have access to the following tools:
|
|||
</tool_call_examples>
|
||||
"""
|
||||
|
||||
# Reassemble so existing callers see no change (same full prompt)
|
||||
SURFSENSE_TOOLS_INSTRUCTIONS = (
|
||||
_TOOLS_INSTRUCTIONS_COMMON
|
||||
+ _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE
|
||||
+ _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON
|
||||
)
|
||||
|
||||
|
||||
def _get_tools_instructions(thread_visibility: ChatVisibility | None = None) -> str:
|
||||
"""Build tools instructions based on thread visibility (private vs shared).
|
||||
|
||||
For private chats: use user-focused memory wording and examples.
|
||||
For shared chats: use team memory wording and examples.
|
||||
"""
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
memory_block = (
|
||||
_TOOLS_INSTRUCTIONS_MEMORY_SHARED
|
||||
if visibility == ChatVisibility.SEARCH_SPACE
|
||||
else _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE
|
||||
)
|
||||
return (
|
||||
_TOOLS_INSTRUCTIONS_COMMON + memory_block + _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON
|
||||
)
|
||||
|
||||
|
||||
SURFSENSE_CITATION_INSTRUCTIONS = """
|
||||
<citation_instructions>
|
||||
CRITICAL CITATION REQUIREMENTS:
|
||||
|
|
@ -413,6 +523,7 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
|||
|
||||
def build_surfsense_system_prompt(
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build the SurfSense system prompt with default settings.
|
||||
|
|
@ -424,17 +535,17 @@ def build_surfsense_system_prompt(
|
|||
|
||||
Args:
|
||||
today: Optional datetime for today's date (defaults to current UTC date)
|
||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
"""
|
||||
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
|
||||
|
||||
return (
|
||||
SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today)
|
||||
+ SURFSENSE_TOOLS_INSTRUCTIONS
|
||||
+ SURFSENSE_CITATION_INSTRUCTIONS
|
||||
)
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
system_instructions = _get_system_instructions(visibility, today)
|
||||
tools_instructions = _get_tools_instructions(visibility)
|
||||
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
|
||||
return system_instructions + tools_instructions + citation_instructions
|
||||
|
||||
|
||||
def build_configurable_system_prompt(
|
||||
|
|
@ -442,6 +553,7 @@ def build_configurable_system_prompt(
|
|||
use_default_system_instructions: bool = True,
|
||||
citations_enabled: bool = True,
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||
|
|
@ -460,6 +572,7 @@ def build_configurable_system_prompt(
|
|||
citations_enabled: Whether to include citation instructions (True) or
|
||||
anti-citation instructions (False).
|
||||
today: Optional datetime for today's date (defaults to current UTC date)
|
||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
|
|
@ -473,16 +586,14 @@ def build_configurable_system_prompt(
|
|||
resolved_today=resolved_today
|
||||
)
|
||||
elif use_default_system_instructions:
|
||||
# Use default instructions
|
||||
system_instructions = SURFSENSE_SYSTEM_INSTRUCTIONS.format(
|
||||
resolved_today=resolved_today
|
||||
)
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
system_instructions = _get_system_instructions(visibility, today)
|
||||
else:
|
||||
# No system instructions (edge case)
|
||||
system_instructions = ""
|
||||
|
||||
# Tools instructions are always included
|
||||
tools_instructions = SURFSENSE_TOOLS_INSTRUCTIONS
|
||||
# Tools instructions: conditional on thread_visibility (private vs shared memory wording)
|
||||
tools_instructions = _get_tools_instructions(thread_visibility)
|
||||
|
||||
# Citation instructions based on toggle
|
||||
citation_instructions = (
|
||||
|
|
|
|||
|
|
@ -51,8 +51,14 @@ from .mcp_tool import load_mcp_tools
|
|||
from .podcast import create_generate_podcast_tool
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
from .shared_memory import (
|
||||
create_recall_shared_memory_tool,
|
||||
create_save_shared_memory_tool,
|
||||
)
|
||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||
|
||||
from app.db import ChatVisibility
|
||||
|
||||
# =============================================================================
|
||||
# Tool Definition
|
||||
# =============================================================================
|
||||
|
|
@ -156,29 +162,42 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
requires=["db_session"],
|
||||
),
|
||||
# =========================================================================
|
||||
# USER MEMORY TOOLS - Claude-like memory feature
|
||||
# USER MEMORY TOOLS - private or team store by thread_visibility
|
||||
# =========================================================================
|
||||
# Save memory tool - stores facts/preferences about the user
|
||||
ToolDefinition(
|
||||
name="save_memory",
|
||||
description="Save facts, preferences, or context about the user for personalized responses",
|
||||
factory=lambda deps: create_save_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
description="Save facts, preferences, or context for personalized or team responses",
|
||||
factory=lambda deps: (
|
||||
create_save_shared_memory_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
created_by_id=deps["user_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||
else create_save_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
),
|
||||
requires=["user_id", "search_space_id", "db_session"],
|
||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||
),
|
||||
# Recall memory tool - retrieves relevant user memories
|
||||
ToolDefinition(
|
||||
name="recall_memory",
|
||||
description="Recall user memories for personalized and contextual responses",
|
||||
factory=lambda deps: create_recall_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
description="Recall relevant memories (personal or team) for context",
|
||||
factory=lambda deps: (
|
||||
create_recall_shared_memory_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||
else create_recall_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
),
|
||||
requires=["user_id", "search_space_id", "db_session"],
|
||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||
),
|
||||
# =========================================================================
|
||||
# ADD YOUR CUSTOM TOOLS BELOW
|
||||
|
|
|
|||
278
surfsense_backend/app/agents/new_chat/tools/shared_memory.py
Normal file
278
surfsense_backend/app/agents/new_chat/tools/shared_memory.py
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
"""Shared (team) memory backend for search-space-scoped AI context."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import MemoryCategory, SharedMemory, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RECALL_TOP_K = 5
|
||||
MAX_MEMORIES_PER_SEARCH_SPACE = 250
|
||||
|
||||
|
||||
async def get_shared_memory_count(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> int:
|
||||
result = await db_session.execute(
|
||||
select(SharedMemory).where(SharedMemory.search_space_id == search_space_id)
|
||||
)
|
||||
return len(result.scalars().all())
|
||||
|
||||
|
||||
async def delete_oldest_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> None:
|
||||
result = await db_session.execute(
|
||||
select(SharedMemory)
|
||||
.where(SharedMemory.search_space_id == search_space_id)
|
||||
.order_by(SharedMemory.updated_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
oldest = result.scalars().first()
|
||||
if oldest:
|
||||
await db_session.delete(oldest)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
def _to_uuid(value: str | UUID) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
return UUID(value)
|
||||
|
||||
|
||||
async def save_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
created_by_id: str | UUID,
|
||||
content: str,
|
||||
category: str = "fact",
|
||||
) -> dict[str, Any]:
|
||||
category = category.lower() if category else "fact"
|
||||
valid = ["preference", "fact", "instruction", "context"]
|
||||
if category not in valid:
|
||||
category = "fact"
|
||||
try:
|
||||
count = await get_shared_memory_count(db_session, search_space_id)
|
||||
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
|
||||
await delete_oldest_shared_memory(db_session, search_space_id)
|
||||
embedding = config.embedding_model_instance.embed(content)
|
||||
row = SharedMemory(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=_to_uuid(created_by_id),
|
||||
memory_text=content,
|
||||
category=MemoryCategory(category),
|
||||
embedding=embedding,
|
||||
)
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(row)
|
||||
return {
|
||||
"status": "saved",
|
||||
"memory_id": row.id,
|
||||
"memory_text": content,
|
||||
"category": category,
|
||||
"message": f"I'll remember: {content}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save shared memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": "Failed to save memory. Please try again.",
|
||||
}
|
||||
|
||||
|
||||
async def recall_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||
) -> dict[str, Any]:
|
||||
top_k = min(max(top_k, 1), 20)
|
||||
try:
|
||||
valid_categories = ["preference", "fact", "instruction", "context"]
|
||||
stmt = select(SharedMemory).where(
|
||||
SharedMemory.search_space_id == search_space_id
|
||||
)
|
||||
if category and category in valid_categories:
|
||||
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
|
||||
if query:
|
||||
query_embedding = config.embedding_model_instance.embed(query)
|
||||
stmt = stmt.order_by(
|
||||
SharedMemory.embedding.op("<=>")(query_embedding)
|
||||
).limit(top_k)
|
||||
else:
|
||||
stmt = stmt.order_by(SharedMemory.updated_at.desc()).limit(top_k)
|
||||
result = await db_session.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
memory_list = [
|
||||
{
|
||||
"id": m.id,
|
||||
"memory_text": m.memory_text,
|
||||
"category": m.category.value if m.category else "unknown",
|
||||
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
|
||||
"created_by_id": str(m.created_by_id) if m.created_by_id else None,
|
||||
}
|
||||
for m in rows
|
||||
]
|
||||
created_by_ids = list({m["created_by_id"] for m in memory_list if m["created_by_id"]})
|
||||
created_by_map: dict[str, str] = {}
|
||||
if created_by_ids:
|
||||
uuids = [UUID(uid) for uid in created_by_ids]
|
||||
users_result = await db_session.execute(
|
||||
select(User).where(User.id.in_(uuids))
|
||||
)
|
||||
for u in users_result.scalars().all():
|
||||
created_by_map[str(u.id)] = u.display_name or "A team member"
|
||||
formatted_context = format_shared_memories_for_context(
|
||||
memory_list, created_by_map
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(memory_list),
|
||||
"memories": memory_list,
|
||||
"formatted_context": formatted_context,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to recall shared memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"memories": [],
|
||||
"formatted_context": "Failed to recall memories.",
|
||||
}
|
||||
|
||||
|
||||
def format_shared_memories_for_context(
|
||||
memories: list[dict[str, Any]],
|
||||
created_by_map: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
if not memories:
|
||||
return "No relevant team memories found."
|
||||
created_by_map = created_by_map or {}
|
||||
parts = ["<team_memories>"]
|
||||
for memory in memories:
|
||||
category = memory.get("category", "unknown")
|
||||
text = memory.get("memory_text", "")
|
||||
updated = memory.get("updated_at", "")
|
||||
created_by_id = memory.get("created_by_id")
|
||||
added_by = (
|
||||
created_by_map.get(str(created_by_id), "A team member")
|
||||
if created_by_id is not None
|
||||
else "A team member"
|
||||
)
|
||||
parts.append(
|
||||
f" <memory category='{category}' updated='{updated}' added_by='{added_by}'>{text}</memory>"
|
||||
)
|
||||
parts.append("</team_memories>")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def create_save_shared_memory_tool(
|
||||
search_space_id: int,
|
||||
created_by_id: str | UUID,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the save_memory tool for shared (team) chats.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID
|
||||
created_by_id: The user ID of the person adding the memory
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for saving team memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def save_memory(
|
||||
content: str,
|
||||
category: str = "fact",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Save a fact, preference, or context to the team's shared memory for future reference.
|
||||
|
||||
Use this tool when:
|
||||
- User or a team member says "remember this", "keep this in mind", or similar in this shared chat
|
||||
- The team agrees on something to remember (e.g., decisions, conventions, where things live)
|
||||
- Someone shares a preference or fact that should be visible to the whole team
|
||||
|
||||
The saved information will be available in future shared conversations in this space.
|
||||
|
||||
Args:
|
||||
content: The fact/preference/context to remember.
|
||||
Phrase it clearly, e.g., "API keys are stored in Vault",
|
||||
"The team prefers weekly demos on Fridays"
|
||||
category: Type of memory. One of:
|
||||
- "preference": Team or workspace preferences
|
||||
- "fact": Facts the team agreed on (e.g., processes, locations)
|
||||
- "instruction": Standing instructions for the team
|
||||
- "context": Current context (e.g., ongoing projects, goals)
|
||||
|
||||
Returns:
|
||||
A dictionary with the save status and memory details
|
||||
"""
|
||||
return await save_shared_memory(
|
||||
db_session, search_space_id, created_by_id, content, category
|
||||
)
|
||||
|
||||
return save_memory
|
||||
|
||||
|
||||
def create_recall_shared_memory_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the recall_memory tool for shared (team) chats.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for recalling team memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def recall_memory(
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Recall relevant team memories for this space to provide contextual responses.
|
||||
|
||||
Use this tool when:
|
||||
- You need team context to answer (e.g., "where do we store X?", "what did we decide about Y?")
|
||||
- Someone asks about something the team agreed to remember
|
||||
- Team preferences or conventions would improve the response
|
||||
|
||||
Args:
|
||||
query: Optional search query to find specific memories.
|
||||
If not provided, returns the most recent memories.
|
||||
category: Optional category filter. One of:
|
||||
"preference", "fact", "instruction", "context"
|
||||
top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||
|
||||
Returns:
|
||||
A dictionary containing relevant memories and formatted context
|
||||
"""
|
||||
return await recall_shared_memory(
|
||||
db_session, search_space_id, query, category, top_k
|
||||
)
|
||||
|
||||
return recall_memory
|
||||
|
|
@ -801,9 +801,8 @@ class MemoryCategory(str, Enum):
|
|||
|
||||
class UserMemory(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Stores facts, preferences, and context about users for personalized AI responses.
|
||||
Similar to Claude's memory feature - enables the AI to remember user information
|
||||
across conversations.
|
||||
Private memory: facts, preferences, context per user per search space.
|
||||
Used only for private chats (not shared/team chats).
|
||||
"""
|
||||
|
||||
__tablename__ = "user_memories"
|
||||
|
|
@ -847,6 +846,40 @@ class UserMemory(BaseModel, TimestampMixin):
|
|||
search_space = relationship("SearchSpace", back_populates="user_memories")
|
||||
|
||||
|
||||
class SharedMemory(BaseModel, TimestampMixin):
|
||||
__tablename__ = "shared_memories"
|
||||
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
created_by_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
memory_text = Column(Text, nullable=False)
|
||||
category = Column(
|
||||
SQLAlchemyEnum(MemoryCategory),
|
||||
nullable=False,
|
||||
default=MemoryCategory.fact,
|
||||
)
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
search_space = relationship("SearchSpace", back_populates="shared_memories")
|
||||
created_by = relationship("User")
|
||||
|
||||
|
||||
class Document(BaseModel, TimestampMixin):
|
||||
__tablename__ = "documents"
|
||||
|
||||
|
|
@ -1209,6 +1242,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="UserMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
shared_memories = relationship(
|
||||
"SharedMemory",
|
||||
back_populates="search_space",
|
||||
order_by="SharedMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||
|
|
@ -1258,7 +1297,7 @@ class NewLLMConfig(BaseModel, TimestampMixin):
|
|||
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||
- Citation toggle (enable/disable citation instructions)
|
||||
|
||||
Note: SURFSENSE_TOOLS_INSTRUCTIONS is always used and not configurable.
|
||||
Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory).
|
||||
"""
|
||||
|
||||
__tablename__ = "new_llm_configs"
|
||||
|
|
|
|||
|
|
@ -1045,12 +1045,14 @@ async def handle_new_chat(
|
|||
search_space_id=request.search_space_id,
|
||||
chat_id=request.chat_id,
|
||||
session=session,
|
||||
user_id=str(user.id), # Pass user ID for memory tools and session state
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
attachments=request.attachments,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
current_user_display_name=user.display_name or "A team member",
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -1281,6 +1283,8 @@ async def regenerate_response(
|
|||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
checkpoint_id=target_checkpoint_id,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
current_user_display_name=user.display_name or "A team member",
|
||||
):
|
||||
yield chunk
|
||||
# If we get here, streaming completed successfully
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from app.agents.new_chat.llm_config import (
|
|||
load_agent_config,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.db import Document, SurfsenseDocsDocument
|
||||
from app.db import ChatVisibility, Document, SurfsenseDocsDocument
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.schemas.new_chat import ChatAttachment
|
||||
from app.services.chat_session_state_service import (
|
||||
|
|
@ -208,6 +208,8 @@ async def stream_new_chat(
|
|||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
needs_history_bootstrap: bool = False,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
current_user_display_name: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream chat responses from the new SurfSense deep agent.
|
||||
|
|
@ -295,17 +297,18 @@ async def stream_new_chat(
|
|||
# Get the PostgreSQL checkpointer for persistent conversation memory
|
||||
checkpointer = await get_checkpointer()
|
||||
|
||||
# Create the deep agent with checkpointer and configurable prompts
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id, # Pass user ID for memory tools
|
||||
thread_id=chat_id, # Pass chat ID for podcast association
|
||||
agent_config=agent_config, # Pass prompt configuration
|
||||
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
)
|
||||
|
||||
# Build input with message history
|
||||
|
|
@ -313,7 +316,9 @@ async def stream_new_chat(
|
|||
|
||||
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
|
||||
if needs_history_bootstrap:
|
||||
langchain_messages = await bootstrap_history_from_db(session, chat_id)
|
||||
langchain_messages = await bootstrap_history_from_db(
|
||||
session, chat_id, thread_visibility=visibility
|
||||
)
|
||||
|
||||
# Clear the flag so we don't bootstrap again on next message
|
||||
from app.db import NewChatThread
|
||||
|
|
@ -376,6 +381,9 @@ async def stream_new_chat(
|
|||
context = "\n\n".join(context_parts)
|
||||
final_query = f"{context}\n\n<user_query>{user_query}</user_query>"
|
||||
|
||||
if visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
|
||||
final_query = f"**[{current_user_display_name}]:** {final_query}"
|
||||
|
||||
# if messages:
|
||||
# # Convert frontend messages to LangChain format
|
||||
# for msg in messages:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ These utilities help extract and transform content for different use cases.
|
|||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
||||
def extract_text_content(content: str | dict | list) -> str:
|
||||
|
|
@ -38,6 +39,7 @@ def extract_text_content(content: str | dict | list) -> str:
|
|||
async def bootstrap_history_from_db(
|
||||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
thread_visibility: "ChatVisibility | None" = None,
|
||||
) -> list[HumanMessage | AIMessage]:
|
||||
"""
|
||||
Load message history from database and convert to LangChain format.
|
||||
|
|
@ -45,20 +47,28 @@ async def bootstrap_history_from_db(
|
|||
Used for cloned chats where the LangGraph checkpointer has no state,
|
||||
but we have messages in the database that should be used as context.
|
||||
|
||||
When thread_visibility is SEARCH_SPACE, user messages are prefixed with
|
||||
the author's display name so the LLM sees who said what.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
thread_id: The chat thread ID
|
||||
thread_visibility: When SEARCH_SPACE, user messages get author prefix
|
||||
|
||||
Returns:
|
||||
List of LangChain messages (HumanMessage/AIMessage)
|
||||
"""
|
||||
from app.db import NewChatMessage
|
||||
from app.db import ChatVisibility, NewChatMessage
|
||||
|
||||
result = await session.execute(
|
||||
is_shared = thread_visibility == ChatVisibility.SEARCH_SPACE
|
||||
stmt = (
|
||||
select(NewChatMessage)
|
||||
.filter(NewChatMessage.thread_id == thread_id)
|
||||
.order_by(NewChatMessage.created_at)
|
||||
)
|
||||
if is_shared:
|
||||
stmt = stmt.options(selectinload(NewChatMessage.author))
|
||||
result = await session.execute(stmt)
|
||||
db_messages = result.scalars().all()
|
||||
|
||||
langchain_messages: list[HumanMessage | AIMessage] = []
|
||||
|
|
@ -68,6 +78,11 @@ async def bootstrap_history_from_db(
|
|||
if not text_content:
|
||||
continue
|
||||
if msg.role == "user":
|
||||
if is_shared:
|
||||
author_name = (
|
||||
(msg.author.display_name if msg.author else None) or "A team member"
|
||||
)
|
||||
text_content = f"**[{author_name}]:** {text_content}"
|
||||
langchain_messages.append(HumanMessage(content=text_content))
|
||||
elif msg.role == "assistant":
|
||||
langchain_messages.append(AIMessage(content=text_content))
|
||||
|
|
|
|||
|
|
@ -1,16 +1,14 @@
|
|||
import { atomWithQuery } from "jotai-tanstack-query";
|
||||
import { userApiService } from "@/lib/apis/user-api.service";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import { getBearerToken, isPublicRoute } from "@/lib/auth-utils";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
|
||||
export const currentUserAtom = atomWithQuery(() => {
|
||||
const pathname = typeof window !== "undefined" ? window.location.pathname : null;
|
||||
return {
|
||||
queryKey: cacheKeys.user.current(),
|
||||
staleTime: 5 * 60 * 1000, // 5 minutes
|
||||
// Only fetch user data when a bearer token is present
|
||||
enabled: !!getBearerToken(),
|
||||
queryFn: async () => {
|
||||
return userApiService.getMe();
|
||||
},
|
||||
enabled: !!getBearerToken() && pathname !== null && !isPublicRoute(pathname),
|
||||
queryFn: async () => userApiService.getMe(),
|
||||
};
|
||||
});
|
||||
|
|
|
|||
|
|
@ -10,28 +10,53 @@ const REFRESH_TOKEN_KEY = "surfsense_refresh_token";
|
|||
let isRefreshing = false;
|
||||
let refreshPromise: Promise<string | null> | null = null;
|
||||
|
||||
/** Path prefixes for routes that do not require auth (no current-user fetch, no redirect on 401) */
|
||||
const PUBLIC_ROUTE_PREFIXES = [
|
||||
"/login",
|
||||
"/register",
|
||||
"/auth",
|
||||
"/docs",
|
||||
"/public",
|
||||
"/invite",
|
||||
"/contact",
|
||||
"/pricing",
|
||||
"/privacy",
|
||||
"/terms",
|
||||
"/changelog",
|
||||
];
|
||||
|
||||
/**
|
||||
* Saves the current path and redirects to login page
|
||||
* Call this when a 401 response is received
|
||||
* Returns true if the pathname is a public route where we should not run auth checks
|
||||
* or redirect to login on 401.
|
||||
*/
|
||||
export function isPublicRoute(pathname: string): boolean {
|
||||
if (pathname === "/" || pathname === "") return true;
|
||||
return PUBLIC_ROUTE_PREFIXES.some((prefix) => pathname.startsWith(prefix));
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears tokens and optionally redirects to login.
|
||||
* Call this when a 401 response is received.
|
||||
* Only redirects when the current route is protected; on public routes we just clear tokens.
|
||||
*/
|
||||
export function handleUnauthorized(): void {
|
||||
if (typeof window === "undefined") return;
|
||||
|
||||
// Save the current path (including search params and hash) for redirect after login
|
||||
const currentPath = window.location.pathname + window.location.search + window.location.hash;
|
||||
const pathname = window.location.pathname;
|
||||
|
||||
// Don't save auth-related paths
|
||||
const excludedPaths = ["/auth", "/auth/callback", "/"];
|
||||
if (!excludedPaths.includes(window.location.pathname)) {
|
||||
localStorage.setItem(REDIRECT_PATH_KEY, currentPath);
|
||||
}
|
||||
|
||||
// Clear both tokens
|
||||
// Always clear tokens
|
||||
localStorage.removeItem(BEARER_TOKEN_KEY);
|
||||
localStorage.removeItem(REFRESH_TOKEN_KEY);
|
||||
|
||||
// Redirect to home page (which has login options)
|
||||
window.location.href = "/login";
|
||||
// Only redirect on protected routes; stay on public pages (e.g. /docs)
|
||||
if (!isPublicRoute(pathname)) {
|
||||
const currentPath = pathname + window.location.search + window.location.hash;
|
||||
const excludedPaths = ["/auth", "/auth/callback", "/"];
|
||||
if (!excludedPaths.includes(pathname)) {
|
||||
localStorage.setItem(REDIRECT_PATH_KEY, currentPath);
|
||||
}
|
||||
window.location.href = "/login";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -179,7 +204,6 @@ export function getAuthHeaders(additionalHeaders?: Record<string, string>): Reco
|
|||
/**
|
||||
* Attempts to refresh the access token using the stored refresh token.
|
||||
* Returns the new access token if successful, null otherwise.
|
||||
* Exported for use by API services.
|
||||
*/
|
||||
export async function refreshAccessToken(): Promise<string | null> {
|
||||
// If already refreshing, wait for that request to complete
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue