mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-11 16:52:38 +02:00
Merge pull request #1200 from AnishSarkar22/refactor/persistent-memory
refactor: persistent memory
This commit is contained in:
commit
b96dc49c8a
52 changed files with 2622 additions and 1415 deletions
|
|
@ -0,0 +1,38 @@
|
||||||
|
"""Add memory_md columns to user and searchspaces tables
|
||||||
|
|
||||||
|
Revision ID: 121
|
||||||
|
Revises: 120
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
1. Add memory_md TEXT column to user table (personal memory)
|
||||||
|
2. Add shared_memory_md TEXT column to searchspaces table (team memory)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "121"
|
||||||
|
down_revision: str | None = "120"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"user",
|
||||||
|
sa.Column("memory_md", sa.Text(), nullable=True, server_default=""),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"searchspaces",
|
||||||
|
sa.Column("shared_memory_md", sa.Text(), nullable=True, server_default=""),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("searchspaces", "shared_memory_md")
|
||||||
|
op.drop_column("user", "memory_md")
|
||||||
|
|
@ -0,0 +1,247 @@
|
||||||
|
"""Migrate row-per-fact memories to markdown, then drop legacy tables
|
||||||
|
|
||||||
|
Revision ID: 122
|
||||||
|
Revises: 121
|
||||||
|
|
||||||
|
Converts user_memories rows into per-user markdown documents stored in
|
||||||
|
user.memory_md, and shared_memories rows into per-search-space markdown
|
||||||
|
stored in searchspaces.shared_memory_md. Then drops the old tables and
|
||||||
|
the memorycategory enum.
|
||||||
|
|
||||||
|
The markdown format matches the new memory system:
|
||||||
|
## Heading
|
||||||
|
- (YYYY-MM-DD) [fact|pref|instr] memory text
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import inspect as sa_inspect
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
revision: str = "122"
|
||||||
|
down_revision: str | None = "121"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
EMBEDDING_DIM = config.embedding_model_instance.dimension
|
||||||
|
|
||||||
|
_CATEGORY_TO_MARKER = {
|
||||||
|
"fact": "fact",
|
||||||
|
"context": "fact",
|
||||||
|
"preference": "pref",
|
||||||
|
"instruction": "instr",
|
||||||
|
}
|
||||||
|
|
||||||
|
_CATEGORY_HEADING = {
|
||||||
|
"fact": "Facts",
|
||||||
|
"preference": "Preferences",
|
||||||
|
"instruction": "Instructions",
|
||||||
|
"context": "Context",
|
||||||
|
}
|
||||||
|
|
||||||
|
_HEADING_ORDER = ["fact", "preference", "instruction", "context"]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_markdown(rows: list[tuple]) -> str:
|
||||||
|
"""Build a markdown document from (memory_text, category, created_at) rows."""
|
||||||
|
by_category: dict[str, list[str]] = defaultdict(list)
|
||||||
|
|
||||||
|
for memory_text, category, created_at in rows:
|
||||||
|
cat = str(category)
|
||||||
|
marker = _CATEGORY_TO_MARKER.get(cat, "fact")
|
||||||
|
date_str = created_at.strftime("%Y-%m-%d")
|
||||||
|
clean_text = str(memory_text).replace("\n", " ").strip()
|
||||||
|
bullet = f"- ({date_str}) [{marker}] {clean_text}"
|
||||||
|
by_category[cat].append(bullet)
|
||||||
|
|
||||||
|
sections: list[str] = []
|
||||||
|
for cat in _HEADING_ORDER:
|
||||||
|
if cat in by_category:
|
||||||
|
heading = _CATEGORY_HEADING[cat]
|
||||||
|
sections.append(f"## {heading}")
|
||||||
|
sections.extend(by_category[cat])
|
||||||
|
sections.append("")
|
||||||
|
|
||||||
|
return "\n".join(sections).strip() + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_user_memories(conn: sa.engine.Connection) -> None:
|
||||||
|
"""Convert user_memories rows → user.memory_md grouped by user_id."""
|
||||||
|
rows = conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT user_id, memory_text, category::text, created_at "
|
||||||
|
"FROM user_memories ORDER BY created_at"
|
||||||
|
)
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
logger.info("user_memories is empty, skipping data migration.")
|
||||||
|
return
|
||||||
|
|
||||||
|
by_user: dict[UUID, list[tuple]] = defaultdict(list)
|
||||||
|
for user_id, memory_text, category, created_at in rows:
|
||||||
|
by_user[user_id].append((memory_text, category, created_at))
|
||||||
|
|
||||||
|
migrated = 0
|
||||||
|
for uid, user_rows in by_user.items():
|
||||||
|
existing = conn.execute(
|
||||||
|
sa.text('SELECT memory_md FROM "user" WHERE id = :uid'),
|
||||||
|
{"uid": uid},
|
||||||
|
).scalar()
|
||||||
|
|
||||||
|
if existing and existing.strip():
|
||||||
|
logger.info("User %s already has memory_md, skipping.", uid)
|
||||||
|
continue
|
||||||
|
|
||||||
|
markdown = _build_markdown(user_rows)
|
||||||
|
conn.execute(
|
||||||
|
sa.text('UPDATE "user" SET memory_md = :md WHERE id = :uid'),
|
||||||
|
{"md": markdown, "uid": uid},
|
||||||
|
)
|
||||||
|
migrated += 1
|
||||||
|
|
||||||
|
logger.info("Migrated user_memories for %d user(s).", migrated)
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_shared_memories(conn: sa.engine.Connection) -> None:
|
||||||
|
"""Convert shared_memories rows → searchspaces.shared_memory_md."""
|
||||||
|
rows = conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT search_space_id, memory_text, category::text, created_at "
|
||||||
|
"FROM shared_memories ORDER BY created_at"
|
||||||
|
)
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
logger.info("shared_memories is empty, skipping data migration.")
|
||||||
|
return
|
||||||
|
|
||||||
|
by_space: dict[int, list[tuple]] = defaultdict(list)
|
||||||
|
for search_space_id, memory_text, category, created_at in rows:
|
||||||
|
by_space[search_space_id].append((memory_text, category, created_at))
|
||||||
|
|
||||||
|
migrated = 0
|
||||||
|
for space_id, space_rows in by_space.items():
|
||||||
|
existing = conn.execute(
|
||||||
|
sa.text("SELECT shared_memory_md FROM searchspaces WHERE id = :sid"),
|
||||||
|
{"sid": space_id},
|
||||||
|
).scalar()
|
||||||
|
|
||||||
|
if existing and existing.strip():
|
||||||
|
logger.info(
|
||||||
|
"Search space %s already has shared_memory_md, skipping.", space_id
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
markdown = _build_markdown(space_rows)
|
||||||
|
conn.execute(
|
||||||
|
sa.text("UPDATE searchspaces SET shared_memory_md = :md WHERE id = :sid"),
|
||||||
|
{"md": markdown, "sid": space_id},
|
||||||
|
)
|
||||||
|
migrated += 1
|
||||||
|
|
||||||
|
logger.info("Migrated shared_memories for %d search space(s).", migrated)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
conn = op.get_bind()
|
||||||
|
inspector = sa_inspect(conn)
|
||||||
|
tables = inspector.get_table_names()
|
||||||
|
|
||||||
|
if "user_memories" in tables:
|
||||||
|
_migrate_user_memories(conn)
|
||||||
|
|
||||||
|
if "shared_memories" in tables:
|
||||||
|
_migrate_shared_memories(conn)
|
||||||
|
|
||||||
|
op.execute("DROP TABLE IF EXISTS shared_memories CASCADE;")
|
||||||
|
op.execute("DROP TABLE IF EXISTS user_memories CASCADE;")
|
||||||
|
op.execute("DROP TYPE IF EXISTS memorycategory;")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'memorycategory') THEN
|
||||||
|
CREATE TYPE memorycategory AS ENUM (
|
||||||
|
'preference',
|
||||||
|
'fact',
|
||||||
|
'instruction',
|
||||||
|
'context'
|
||||||
|
);
|
||||||
|
END IF;
|
||||||
|
END$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS user_memories (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||||
|
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
search_space_id INTEGER REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||||
|
memory_text TEXT NOT NULL,
|
||||||
|
category memorycategory NOT NULL DEFAULT 'fact',
|
||||||
|
embedding vector({EMBEDDING_DIM}),
|
||||||
|
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_user_memories_user_id ON user_memories(user_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_user_memories_search_space_id ON user_memories(search_space_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_user_memories_updated_at ON user_memories(updated_at);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_user_memories_category ON user_memories(category);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_user_memories_user_search_space ON user_memories(user_id, search_space_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS user_memories_vector_index ON user_memories USING hnsw (embedding public.vector_cosine_ops);"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS shared_memories (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||||
|
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||||
|
created_by_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
memory_text TEXT NOT NULL,
|
||||||
|
category memorycategory NOT NULL DEFAULT 'fact',
|
||||||
|
embedding vector({EMBEDDING_DIM})
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_shared_memories_search_space_id ON shared_memories(search_space_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_shared_memories_updated_at ON shared_memories(updated_at);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_shared_memories_created_by_id ON shared_memories(created_by_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS shared_memories_vector_index ON shared_memories USING hnsw (embedding public.vector_cosine_ops);"
|
||||||
|
)
|
||||||
|
|
@ -38,6 +38,7 @@ from app.agents.new_chat.llm_config import AgentConfig
|
||||||
from app.agents.new_chat.middleware import (
|
from app.agents.new_chat.middleware import (
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
KnowledgeBaseSearchMiddleware,
|
KnowledgeBaseSearchMiddleware,
|
||||||
|
MemoryInjectionMiddleware,
|
||||||
SurfSenseFilesystemMiddleware,
|
SurfSenseFilesystemMiddleware,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.system_prompt import (
|
from app.agents.new_chat.system_prompt import (
|
||||||
|
|
@ -168,8 +169,7 @@ async def create_surfsense_deep_agent(
|
||||||
- generate_podcast: Generate audio podcasts from content
|
- generate_podcast: Generate audio podcasts from content
|
||||||
- generate_image: Generate images from text descriptions using AI models
|
- generate_image: Generate images from text descriptions using AI models
|
||||||
- scrape_webpage: Extract content from webpages
|
- scrape_webpage: Extract content from webpages
|
||||||
- save_memory: Store facts/preferences about the user
|
- update_memory: Update the user's personal or team memory document
|
||||||
- recall_memory: Retrieve relevant user memories
|
|
||||||
|
|
||||||
The agent also includes TodoListMiddleware by default (via create_deep_agent) which provides:
|
The agent also includes TodoListMiddleware by default (via create_deep_agent) which provides:
|
||||||
- write_todos: Create and update planning/todo lists for complex tasks
|
- write_todos: Create and update planning/todo lists for complex tasks
|
||||||
|
|
@ -281,6 +281,7 @@ async def create_surfsense_deep_agent(
|
||||||
"available_connectors": available_connectors,
|
"available_connectors": available_connectors,
|
||||||
"available_document_types": available_document_types,
|
"available_document_types": available_document_types,
|
||||||
"max_input_tokens": _max_input_tokens,
|
"max_input_tokens": _max_input_tokens,
|
||||||
|
"llm": llm,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Disable Notion action tools if no Notion connector is configured
|
# Disable Notion action tools if no Notion connector is configured
|
||||||
|
|
@ -425,9 +426,16 @@ async def create_surfsense_deep_agent(
|
||||||
)
|
)
|
||||||
|
|
||||||
# -- Build the middleware stack (mirrors create_deep_agent internals) ------
|
# -- Build the middleware stack (mirrors create_deep_agent internals) ------
|
||||||
|
_memory_middleware = MemoryInjectionMiddleware(
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
thread_visibility=visibility,
|
||||||
|
)
|
||||||
|
|
||||||
# General-purpose subagent middleware
|
# General-purpose subagent middleware
|
||||||
gp_middleware = [
|
gp_middleware = [
|
||||||
TodoListMiddleware(),
|
TodoListMiddleware(),
|
||||||
|
_memory_middleware,
|
||||||
SurfSenseFilesystemMiddleware(
|
SurfSenseFilesystemMiddleware(
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
created_by_id=user_id,
|
created_by_id=user_id,
|
||||||
|
|
@ -447,6 +455,7 @@ async def create_surfsense_deep_agent(
|
||||||
# Main agent middleware
|
# Main agent middleware
|
||||||
deepagent_middleware = [
|
deepagent_middleware = [
|
||||||
TodoListMiddleware(),
|
TodoListMiddleware(),
|
||||||
|
_memory_middleware,
|
||||||
KnowledgeBaseSearchMiddleware(
|
KnowledgeBaseSearchMiddleware(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
|
|
||||||
239
surfsense_backend/app/agents/new_chat/memory_extraction.py
Normal file
239
surfsense_backend/app/agents/new_chat/memory_extraction.py
Normal file
|
|
@ -0,0 +1,239 @@
|
||||||
|
"""Background memory extraction for the SurfSense agent.
|
||||||
|
|
||||||
|
After each agent response, if the agent did not call ``update_memory`` during
|
||||||
|
the turn, this module can run a lightweight LLM call to decide whether the
|
||||||
|
latest message contains long-term information worth persisting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.update_memory import _save_memory
|
||||||
|
from app.db import SearchSpace, User, shielded_async_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_MEMORY_EXTRACT_PROMPT = """\
|
||||||
|
You are a memory extraction assistant. Analyze the user's message and decide \
|
||||||
|
if it contains any long-term information worth persisting to memory.
|
||||||
|
|
||||||
|
Worth remembering: preferences, background/identity, goals, projects, \
|
||||||
|
instructions, tools/languages they use, decisions, expertise, workplace — \
|
||||||
|
durable facts that will matter in future conversations.
|
||||||
|
|
||||||
|
NOT worth remembering: greetings, one-off factual questions, session \
|
||||||
|
logistics, ephemeral requests, follow-up clarifications with no new personal \
|
||||||
|
info, things that only matter for the current task.
|
||||||
|
|
||||||
|
If the message contains memorizable information, output the FULL updated \
|
||||||
|
memory document with the new facts merged into the existing content. Follow \
|
||||||
|
these rules:
|
||||||
|
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
|
||||||
|
freely. Keep heading names short (2-3 words) and natural. Do NOT include the user's
|
||||||
|
name in headings.
|
||||||
|
- Keep entries as single bullet points. Be descriptive but concise — include relevant
|
||||||
|
details and context rather than just a few words.
|
||||||
|
- Every bullet MUST use format: - (YYYY-MM-DD) [fact|pref|instr] text
|
||||||
|
[fact] = durable facts, [pref] = preferences, [instr] = standing instructions.
|
||||||
|
- Use the user's first name (from <user_name>) in entry text, not "the user".
|
||||||
|
- If a new fact contradicts an existing entry, update the existing entry.
|
||||||
|
- Do not duplicate information that is already present.
|
||||||
|
|
||||||
|
If nothing is worth remembering, output exactly: NO_UPDATE
|
||||||
|
|
||||||
|
<user_name>{user_name}</user_name>
|
||||||
|
|
||||||
|
<current_memory>
|
||||||
|
{current_memory}
|
||||||
|
</current_memory>
|
||||||
|
|
||||||
|
<user_message>
|
||||||
|
{user_message}
|
||||||
|
</user_message>"""
|
||||||
|
|
||||||
|
_TEAM_MEMORY_EXTRACT_PROMPT = """\
|
||||||
|
You are a team-memory extraction assistant. Analyze the latest message and \
|
||||||
|
decide if it contains durable TEAM-level information worth persisting.
|
||||||
|
|
||||||
|
Decision policy:
|
||||||
|
- Prioritize recall for durable team context, while avoiding personal-only facts.
|
||||||
|
- Do NOT require explicit consensus language. A direct team-level statement can
|
||||||
|
be stored if it is stable and broadly useful for future team chats.
|
||||||
|
- If evidence is weak or clearly tentative, output NO_UPDATE.
|
||||||
|
|
||||||
|
Worth remembering (team-level only):
|
||||||
|
- Decisions and defaults that guide future team work
|
||||||
|
- Team conventions/standards (naming, review policy, coding norms)
|
||||||
|
- Stable org/project facts (locations, ownership, constraints)
|
||||||
|
- Long-lived architecture/process facts
|
||||||
|
- Ongoing priorities that are likely relevant beyond this turn
|
||||||
|
|
||||||
|
NOT worth remembering:
|
||||||
|
- Personal preferences or biography of one person
|
||||||
|
- Questions, brainstorming, tentative ideas, or speculation
|
||||||
|
- One-off requests, status updates, TODOs, logistics for this session
|
||||||
|
- Information scoped only to a single ephemeral task
|
||||||
|
|
||||||
|
If the message contains memorizable team information, output the FULL updated \
|
||||||
|
team memory document with new facts merged into existing content. Follow rules:
|
||||||
|
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
|
||||||
|
freely. Keep heading names short (2-3 words) and natural.
|
||||||
|
- Keep entries as single bullet points. Be descriptive but concise — include relevant
|
||||||
|
details and context rather than just a few words.
|
||||||
|
- Every bullet MUST use format: - (YYYY-MM-DD) [fact] text
|
||||||
|
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr].
|
||||||
|
- If a new fact contradicts an existing entry, update the existing entry.
|
||||||
|
- Do not duplicate existing information.
|
||||||
|
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
|
||||||
|
|
||||||
|
If nothing is worth remembering, output exactly: NO_UPDATE
|
||||||
|
|
||||||
|
<current_team_memory>
|
||||||
|
{current_memory}
|
||||||
|
</current_team_memory>
|
||||||
|
|
||||||
|
<latest_message_author>
|
||||||
|
{author}
|
||||||
|
</latest_message_author>
|
||||||
|
|
||||||
|
<latest_message>
|
||||||
|
{user_message}
|
||||||
|
</latest_message>"""
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_and_save_memory(
|
||||||
|
*,
|
||||||
|
user_message: str,
|
||||||
|
user_id: str | None,
|
||||||
|
llm: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Background task: extract memorizable info and persist it.
|
||||||
|
|
||||||
|
Designed to be fire-and-forget — catches all exceptions internally.
|
||||||
|
"""
|
||||||
|
if not user_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
result = await session.execute(select(User).where(User.id == uid))
|
||||||
|
user = result.scalars().first()
|
||||||
|
if not user:
|
||||||
|
return
|
||||||
|
|
||||||
|
old_memory = user.memory_md
|
||||||
|
first_name = (
|
||||||
|
user.display_name.strip().split()[0]
|
||||||
|
if user.display_name and user.display_name.strip()
|
||||||
|
else "The user"
|
||||||
|
)
|
||||||
|
prompt = _MEMORY_EXTRACT_PROMPT.format(
|
||||||
|
current_memory=old_memory or "(empty)",
|
||||||
|
user_message=user_message,
|
||||||
|
user_name=first_name,
|
||||||
|
)
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[HumanMessage(content=prompt)],
|
||||||
|
config={"tags": ["surfsense:internal", "memory-extraction"]},
|
||||||
|
)
|
||||||
|
text = (
|
||||||
|
response.content
|
||||||
|
if isinstance(response.content, str)
|
||||||
|
else str(response.content)
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
if text == "NO_UPDATE" or not text:
|
||||||
|
logger.debug("Memory extraction: no update needed (user %s)", uid)
|
||||||
|
return
|
||||||
|
|
||||||
|
save_result = await _save_memory(
|
||||||
|
updated_memory=text,
|
||||||
|
old_memory=old_memory,
|
||||||
|
llm=llm,
|
||||||
|
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||||
|
commit_fn=session.commit,
|
||||||
|
rollback_fn=session.rollback,
|
||||||
|
label="memory",
|
||||||
|
scope="user",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Background memory extraction for user %s: %s",
|
||||||
|
uid,
|
||||||
|
save_result.get("status"),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Background user memory extraction failed")
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_and_save_team_memory(
|
||||||
|
*,
|
||||||
|
user_message: str,
|
||||||
|
search_space_id: int | None,
|
||||||
|
llm: Any,
|
||||||
|
author_display_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Background task: extract team-level memory and persist it.
|
||||||
|
|
||||||
|
Runs only for shared threads. Designed to be fire-and-forget and catches
|
||||||
|
exceptions internally.
|
||||||
|
"""
|
||||||
|
if not search_space_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||||
|
)
|
||||||
|
space = result.scalars().first()
|
||||||
|
if not space:
|
||||||
|
return
|
||||||
|
|
||||||
|
old_memory = space.shared_memory_md
|
||||||
|
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
|
||||||
|
current_memory=old_memory or "(empty)",
|
||||||
|
author=author_display_name or "Unknown team member",
|
||||||
|
user_message=user_message,
|
||||||
|
)
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[HumanMessage(content=prompt)],
|
||||||
|
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
|
||||||
|
)
|
||||||
|
text = (
|
||||||
|
response.content
|
||||||
|
if isinstance(response.content, str)
|
||||||
|
else str(response.content)
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
if text == "NO_UPDATE" or not text:
|
||||||
|
logger.debug(
|
||||||
|
"Team memory extraction: no update needed (space %s)",
|
||||||
|
search_space_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
save_result = await _save_memory(
|
||||||
|
updated_memory=text,
|
||||||
|
old_memory=old_memory,
|
||||||
|
llm=llm,
|
||||||
|
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||||
|
commit_fn=session.commit,
|
||||||
|
rollback_fn=session.rollback,
|
||||||
|
label="team memory",
|
||||||
|
scope="team",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Background team memory extraction for space %s: %s",
|
||||||
|
search_space_id,
|
||||||
|
save_result.get("status"),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Background team memory extraction failed")
|
||||||
|
|
@ -9,9 +9,13 @@ from app.agents.new_chat.middleware.filesystem import (
|
||||||
from app.agents.new_chat.middleware.knowledge_search import (
|
from app.agents.new_chat.middleware.knowledge_search import (
|
||||||
KnowledgeBaseSearchMiddleware,
|
KnowledgeBaseSearchMiddleware,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.memory_injection import (
|
||||||
|
MemoryInjectionMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DedupHITLToolCallsMiddleware",
|
"DedupHITLToolCallsMiddleware",
|
||||||
"KnowledgeBaseSearchMiddleware",
|
"KnowledgeBaseSearchMiddleware",
|
||||||
|
"MemoryInjectionMiddleware",
|
||||||
"SurfSenseFilesystemMiddleware",
|
"SurfSenseFilesystemMiddleware",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,138 @@
|
||||||
|
"""Memory injection middleware for the SurfSense agent.
|
||||||
|
|
||||||
|
Injects memory markdown into the system prompt on every turn:
|
||||||
|
- Private threads: only personal memory (<user_memory>)
|
||||||
|
- Shared threads: only team memory (<team_memory>)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
|
||||||
|
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
"""Injects memory markdown into the conversation on every turn."""
|
||||||
|
|
||||||
|
tools = ()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_id: str | UUID | None,
|
||||||
|
search_space_id: int,
|
||||||
|
thread_visibility: ChatVisibility | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.user_id = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
self.search_space_id = search_space_id
|
||||||
|
self.visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
|
|
||||||
|
async def abefore_agent( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
state: AgentState,
|
||||||
|
runtime: Runtime[Any],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
del runtime
|
||||||
|
messages = state.get("messages") or []
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
last_message = messages[-1]
|
||||||
|
if not isinstance(last_message, HumanMessage):
|
||||||
|
return None
|
||||||
|
|
||||||
|
memory_blocks: list[str] = []
|
||||||
|
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
||||||
|
team_memory = await self._load_team_memory(session)
|
||||||
|
if team_memory:
|
||||||
|
chars = len(team_memory)
|
||||||
|
memory_blocks.append(
|
||||||
|
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||||
|
f"{team_memory}\n"
|
||||||
|
f"</team_memory>"
|
||||||
|
)
|
||||||
|
if chars > MEMORY_SOFT_LIMIT:
|
||||||
|
memory_blocks.append(
|
||||||
|
f"<memory_warning>Team memory is at "
|
||||||
|
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||||
|
f"the hard limit. On your next update_memory call, consolidate "
|
||||||
|
f"by merging duplicates, removing outdated entries, and "
|
||||||
|
f"shortening descriptions before adding anything new."
|
||||||
|
f"</memory_warning>"
|
||||||
|
)
|
||||||
|
elif self.user_id is not None:
|
||||||
|
user_memory, display_name = await self._load_user_memory(session)
|
||||||
|
if display_name and display_name.strip():
|
||||||
|
first_name = display_name.strip().split()[0]
|
||||||
|
memory_blocks.append(f"<user_name>{first_name}</user_name>")
|
||||||
|
if user_memory:
|
||||||
|
chars = len(user_memory)
|
||||||
|
memory_blocks.append(
|
||||||
|
f'<user_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||||
|
f"{user_memory}\n"
|
||||||
|
f"</user_memory>"
|
||||||
|
)
|
||||||
|
if chars > MEMORY_SOFT_LIMIT:
|
||||||
|
memory_blocks.append(
|
||||||
|
f"<memory_warning>Your personal memory is at "
|
||||||
|
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||||
|
f"the hard limit. On your next update_memory call, consolidate "
|
||||||
|
f"by merging duplicates, removing outdated entries, and "
|
||||||
|
f"shortening descriptions before adding anything new."
|
||||||
|
f"</memory_warning>"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not memory_blocks:
|
||||||
|
return None
|
||||||
|
|
||||||
|
memory_text = "\n\n".join(memory_blocks)
|
||||||
|
memory_msg = SystemMessage(content=memory_text)
|
||||||
|
|
||||||
|
new_messages = list(messages)
|
||||||
|
insert_idx = 1 if len(new_messages) > 1 else 0
|
||||||
|
new_messages.insert(insert_idx, memory_msg)
|
||||||
|
|
||||||
|
return {"messages": new_messages}
|
||||||
|
|
||||||
|
async def _load_user_memory(
|
||||||
|
self, session: AsyncSession
|
||||||
|
) -> tuple[str | None, str | None]:
|
||||||
|
"""Return (memory_content, display_name)."""
|
||||||
|
try:
|
||||||
|
result = await session.execute(
|
||||||
|
select(User.memory_md, User.display_name).where(User.id == self.user_id)
|
||||||
|
)
|
||||||
|
row = result.one_or_none()
|
||||||
|
if row is None:
|
||||||
|
return None, None
|
||||||
|
return row.memory_md or None, row.display_name
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to load user memory")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
async def _load_team_memory(self, session: AsyncSession) -> str | None:
|
||||||
|
try:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSpace.shared_memory_md).where(
|
||||||
|
SearchSpace.id == self.search_space_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
return row if row else None
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to load team memory")
|
||||||
|
return None
|
||||||
|
|
@ -40,6 +40,13 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
||||||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||||
</knowledge_base_only_policy>
|
</knowledge_base_only_policy>
|
||||||
|
|
||||||
|
<memory_protocol>
|
||||||
|
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||||
|
reveal durable facts about the user (role, interests, preferences, projects,
|
||||||
|
background, or standing instructions)? If yes, you MUST call update_memory
|
||||||
|
alongside your normal response — do not defer this to a later turn.
|
||||||
|
</memory_protocol>
|
||||||
|
|
||||||
</system_instruction>
|
</system_instruction>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -71,6 +78,13 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
||||||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||||
</knowledge_base_only_policy>
|
</knowledge_base_only_policy>
|
||||||
|
|
||||||
|
<memory_protocol>
|
||||||
|
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||||
|
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
||||||
|
or key facts)? If yes, you MUST call update_memory alongside your normal response —
|
||||||
|
do not defer this to a later turn.
|
||||||
|
</memory_protocol>
|
||||||
|
|
||||||
</system_instruction>
|
</system_instruction>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -248,115 +262,97 @@ _TOOL_INSTRUCTIONS["web_search"] = """
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Memory tool instructions have private and shared variants.
|
# Memory tool instructions have private and shared variants.
|
||||||
# We store them keyed as "save_memory" / "recall_memory" with sub-keys.
|
# We store them keyed as "update_memory" with sub-keys.
|
||||||
_MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
_MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
||||||
"save_memory": {
|
"update_memory": {
|
||||||
"private": """
|
"private": """
|
||||||
- save_memory: Save facts, preferences, or context for personalized responses.
|
- update_memory: Update your personal memory document about the user.
|
||||||
- Use this when the user explicitly or implicitly shares information worth remembering.
|
- Your current memory is already in <user_memory> in your context. The `chars` and
|
||||||
- Trigger scenarios:
|
`limit` attributes show your current usage and the maximum allowed size.
|
||||||
* User says "remember this", "keep this in mind", "note that", or similar
|
- This is your curated long-term memory — the distilled essence of what you know about
|
||||||
* User shares personal preferences (e.g., "I prefer Python over JavaScript")
|
the user, not raw conversation logs.
|
||||||
* User shares facts about themselves (e.g., "I'm a senior developer at Company X")
|
- Call update_memory when:
|
||||||
* User gives standing instructions (e.g., "always respond in bullet points")
|
* The user explicitly asks to remember or forget something
|
||||||
* User shares project context (e.g., "I'm working on migrating our codebase to TypeScript")
|
* The user shares durable facts or preferences that will matter in future conversations
|
||||||
|
- The user's first name is provided in <user_name>. Use it in memory entries
|
||||||
|
instead of "the user" (e.g. "{name} works at..." not "The user works at...").
|
||||||
|
Do not store the name itself as a separate memory entry.
|
||||||
|
- Do not store short-lived or ephemeral info: one-off questions, greetings,
|
||||||
|
session logistics, or things that only matter for the current task.
|
||||||
- Args:
|
- Args:
|
||||||
- content: The fact/preference to remember. Phrase it clearly:
|
- updated_memory: The FULL updated markdown document (not a diff).
|
||||||
* "User prefers dark mode for all interfaces"
|
Merge new facts with existing ones, update contradictions, remove outdated entries.
|
||||||
* "User is a senior Python developer"
|
Treat every update as a curation pass — consolidate, don't just append.
|
||||||
* "User wants responses in bullet point format"
|
- Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text
|
||||||
* "User is working on project called ProjectX"
|
Markers:
|
||||||
- category: Type of memory:
|
[fact] — durable facts (role, background, projects, tools, expertise)
|
||||||
* "preference": User preferences (coding style, tools, formats)
|
[pref] — preferences (response style, languages, formats, tools)
|
||||||
* "fact": Facts about the user (role, expertise, background)
|
[instr] — standing instructions (always/never do, response rules)
|
||||||
* "instruction": Standing instructions (response format, communication style)
|
- Keep it concise and well under the character limit shown in <user_memory>.
|
||||||
* "context": Current context (ongoing projects, goals, challenges)
|
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
|
||||||
- Returns: Confirmation of saved memory
|
natural. Do NOT include the user's name in headings. Organize by context — e.g.
|
||||||
- IMPORTANT: Only save information that would be genuinely useful for future conversations.
|
who they are, what they're focused on, how they prefer things. Create, split, or
|
||||||
Don't save trivial or temporary information.
|
merge headings freely as the memory grows.
|
||||||
|
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
|
||||||
|
details and context rather than just a few words.
|
||||||
|
- During consolidation, prioritize keeping: [instr] > [pref] > [fact].
|
||||||
""",
|
""",
|
||||||
"shared": """
|
"shared": """
|
||||||
- save_memory: Save a fact, preference, or context to the team's shared memory for future reference.
|
- update_memory: Update the team's shared memory document for this search space.
|
||||||
- Use this when the user or a team member says "remember this", "keep this in mind", or similar in this shared chat.
|
- Your current team memory is already in <team_memory> in your context. The `chars`
|
||||||
- Use when the team agrees on something to remember (e.g., decisions, conventions).
|
and `limit` attributes show current usage and the maximum allowed size.
|
||||||
- Someone shares a preference or fact that should be visible to the whole team.
|
- This is the team's curated long-term memory — decisions, conventions, key facts.
|
||||||
- The saved information will be available in future shared conversations in this space.
|
- NEVER store personal memory in team memory (e.g. personal bio, individual
|
||||||
|
preferences, or user-only standing instructions).
|
||||||
|
- Call update_memory when:
|
||||||
|
* A team member explicitly asks to remember or forget something
|
||||||
|
* The conversation surfaces durable team decisions, conventions, or facts
|
||||||
|
that will matter in future conversations
|
||||||
|
- Do not store short-lived or ephemeral info: one-off questions, greetings,
|
||||||
|
session logistics, or things that only matter for the current task.
|
||||||
- Args:
|
- Args:
|
||||||
- content: The fact/preference/context to remember. Phrase it clearly, e.g. "API keys are stored in Vault", "The team prefers weekly demos on Fridays"
|
- updated_memory: The FULL updated markdown document (not a diff).
|
||||||
- category: Type of memory. One of:
|
Merge new facts with existing ones, update contradictions, remove outdated entries.
|
||||||
* "preference": Team or workspace preferences
|
Treat every update as a curation pass — consolidate, don't just append.
|
||||||
* "fact": Facts the team agreed on (e.g., processes, locations)
|
- Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text
|
||||||
* "instruction": Standing instructions for the team
|
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory.
|
||||||
* "context": Current context (e.g., ongoing projects, goals)
|
- Keep it concise and well under the character limit shown in <team_memory>.
|
||||||
- Returns: Confirmation of saved memory; returned context may include who added it (added_by).
|
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
|
||||||
- IMPORTANT: Only save information that would be genuinely useful for future team conversations in this space.
|
natural. Organize by context — e.g. what the team decided, current architecture,
|
||||||
""",
|
active processes. Create, split, or merge headings freely as the memory grows.
|
||||||
},
|
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
|
||||||
"recall_memory": {
|
details and context rather than just a few words.
|
||||||
"private": """
|
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
|
||||||
- recall_memory: Retrieve relevant memories about the user for personalized responses.
|
|
||||||
- Use this to access stored information about the user.
|
|
||||||
- Trigger scenarios:
|
|
||||||
* You need user context to give a better, more personalized answer
|
|
||||||
* User references something they mentioned before
|
|
||||||
* User asks "what do you know about me?" or similar
|
|
||||||
* Personalization would significantly improve response quality
|
|
||||||
* Before making recommendations that should consider user preferences
|
|
||||||
- Args:
|
|
||||||
- query: Optional search query to find specific memories (e.g., "programming preferences")
|
|
||||||
- category: Optional filter by category ("preference", "fact", "instruction", "context")
|
|
||||||
- top_k: Number of memories to retrieve (default: 5)
|
|
||||||
- Returns: Relevant memories formatted as context
|
|
||||||
- IMPORTANT: Use the recalled memories naturally in your response without explicitly
|
|
||||||
stating "Based on your memory..." - integrate the context seamlessly.
|
|
||||||
""",
|
|
||||||
"shared": """
|
|
||||||
- recall_memory: Recall relevant team memories for this space to provide contextual responses.
|
|
||||||
- Use when you need team context to answer (e.g., "where do we store X?", "what did we decide about Y?").
|
|
||||||
- Use when someone asks about something the team agreed to remember.
|
|
||||||
- Use when team preferences or conventions would improve the response.
|
|
||||||
- Args:
|
|
||||||
- query: Optional search query to find specific memories. If not provided, returns the most recent memories.
|
|
||||||
- category: Optional filter by category ("preference", "fact", "instruction", "context")
|
|
||||||
- top_k: Number of memories to retrieve (default: 5, max: 20)
|
|
||||||
- Returns: Relevant team memories and formatted context (may include added_by). Integrate naturally without saying "Based on team memory...".
|
|
||||||
""",
|
""",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = {
|
_MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = {
|
||||||
"save_memory": {
|
"update_memory": {
|
||||||
"private": """
|
"private": """
|
||||||
- User: "Remember that I prefer TypeScript over JavaScript"
|
- <user_name>Alex</user_name>, <user_memory> is empty. User: "I'm a space enthusiast, explain astrophage to me"
|
||||||
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
|
- The user casually shared a durable fact. Use their first name in the entry, short neutral heading:
|
||||||
- User: "I'm a data scientist working on ML pipelines"
|
update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n")
|
||||||
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
|
- User: "Remember that I prefer concise answers over detailed explanations"
|
||||||
- User: "Always give me code examples in Python"
|
- Durable preference. Merge with existing memory, add a new heading:
|
||||||
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
|
update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n\\n## Response style\\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\\n")
|
||||||
|
- User: "I actually moved to Tokyo last month"
|
||||||
|
- Updated fact, date prefix reflects when recorded:
|
||||||
|
update_memory(updated_memory="## Interests & background\\n...\\n\\n## Personal context\\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\\n...")
|
||||||
|
- User: "I'm a freelance photographer working on a nature documentary"
|
||||||
|
- Durable background info under a fitting heading:
|
||||||
|
update_memory(updated_memory="...\\n\\n## Current focus\\n- (2025-03-15) [fact] Alex is a freelance photographer\\n- (2025-03-15) [fact] Alex is working on a nature documentary\\n")
|
||||||
|
- User: "Always respond in bullet points"
|
||||||
|
- Standing instruction:
|
||||||
|
update_memory(updated_memory="...\\n\\n## Response style\\n- (2025-03-15) [instr] Always respond to Alex in bullet points\\n")
|
||||||
""",
|
""",
|
||||||
"shared": """
|
"shared": """
|
||||||
- User: "Remember that API keys are stored in Vault"
|
- User: "Let's remember that we decided to do weekly standup meetings on Mondays"
|
||||||
- Call: `save_memory(content="API keys are stored in Vault", category="fact")`
|
- Durable team decision:
|
||||||
- User: "Let's remember that the team prefers weekly demos on Fridays"
|
update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\\n...")
|
||||||
- Call: `save_memory(content="The team prefers weekly demos on Fridays", category="preference")`
|
- User: "Our office is in downtown Seattle, 5th floor"
|
||||||
""",
|
- Durable team fact:
|
||||||
},
|
update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\\n...")
|
||||||
"recall_memory": {
|
|
||||||
"private": """
|
|
||||||
- User: "What programming language should I use for this project?"
|
|
||||||
- First recall: `recall_memory(query="programming language preferences")`
|
|
||||||
- Then provide a personalized recommendation based on their preferences
|
|
||||||
- User: "What do you know about me?"
|
|
||||||
- Call: `recall_memory(top_k=10)`
|
|
||||||
- Then summarize the stored memories
|
|
||||||
""",
|
|
||||||
"shared": """
|
|
||||||
- User: "What did we decide about the release date?"
|
|
||||||
- First recall: `recall_memory(query="release date decision")`
|
|
||||||
- Then answer based on the team memories
|
|
||||||
- User: "Where do we document onboarding?"
|
|
||||||
- Call: `recall_memory(query="onboarding documentation")`
|
|
||||||
- Then answer using the recalled team context
|
|
||||||
""",
|
""",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -456,8 +452,7 @@ _ALL_TOOL_NAMES_ORDERED = [
|
||||||
"generate_report",
|
"generate_report",
|
||||||
"generate_image",
|
"generate_image",
|
||||||
"scrape_webpage",
|
"scrape_webpage",
|
||||||
"save_memory",
|
"update_memory",
|
||||||
"recall_memory",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,7 @@ Available tools:
|
||||||
- generate_video_presentation: Generate video presentations with slides and narration
|
- generate_video_presentation: Generate video presentations with slides and narration
|
||||||
- generate_image: Generate images from text descriptions using AI models
|
- generate_image: Generate images from text descriptions using AI models
|
||||||
- scrape_webpage: Extract content from webpages
|
- scrape_webpage: Extract content from webpages
|
||||||
- save_memory: Store facts/preferences about the user
|
- update_memory: Update the user's / team's memory document
|
||||||
- recall_memory: Retrieve relevant user memories
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Registry exports
|
# Registry exports
|
||||||
|
|
@ -33,7 +32,7 @@ from .registry import (
|
||||||
)
|
)
|
||||||
from .scrape_webpage import create_scrape_webpage_tool
|
from .scrape_webpage import create_scrape_webpage_tool
|
||||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
||||||
from .video_presentation import create_generate_video_presentation_tool
|
from .video_presentation import create_generate_video_presentation_tool
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -47,10 +46,10 @@ __all__ = [
|
||||||
"create_generate_image_tool",
|
"create_generate_image_tool",
|
||||||
"create_generate_podcast_tool",
|
"create_generate_podcast_tool",
|
||||||
"create_generate_video_presentation_tool",
|
"create_generate_video_presentation_tool",
|
||||||
"create_recall_memory_tool",
|
|
||||||
"create_save_memory_tool",
|
|
||||||
"create_scrape_webpage_tool",
|
"create_scrape_webpage_tool",
|
||||||
"create_search_surfsense_docs_tool",
|
"create_search_surfsense_docs_tool",
|
||||||
|
"create_update_memory_tool",
|
||||||
|
"create_update_team_memory_tool",
|
||||||
"format_documents_for_context",
|
"format_documents_for_context",
|
||||||
"get_all_tool_names",
|
"get_all_tool_names",
|
||||||
"get_default_enabled_tools",
|
"get_default_enabled_tools",
|
||||||
|
|
|
||||||
|
|
@ -94,11 +94,7 @@ from .podcast import create_generate_podcast_tool
|
||||||
from .report import create_generate_report_tool
|
from .report import create_generate_report_tool
|
||||||
from .scrape_webpage import create_scrape_webpage_tool
|
from .scrape_webpage import create_scrape_webpage_tool
|
||||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||||
from .shared_memory import (
|
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
||||||
create_recall_shared_memory_tool,
|
|
||||||
create_save_shared_memory_tool,
|
|
||||||
)
|
|
||||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
|
||||||
from .video_presentation import create_generate_video_presentation_tool
|
from .video_presentation import create_generate_video_presentation_tool
|
||||||
from .web_search import create_web_search_tool
|
from .web_search import create_web_search_tool
|
||||||
|
|
||||||
|
|
@ -214,42 +210,31 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
requires=["db_session"],
|
requires=["db_session"],
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# USER MEMORY TOOLS - private or team store by thread_visibility
|
# MEMORY TOOL - single update_memory, private or team by thread_visibility
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="save_memory",
|
name="update_memory",
|
||||||
description="Save facts, preferences, or context for personalized or team responses",
|
description="Save important long-term facts, preferences, and instructions to the (personal or team) memory",
|
||||||
factory=lambda deps: (
|
factory=lambda deps: (
|
||||||
create_save_shared_memory_tool(
|
create_update_team_memory_tool(
|
||||||
search_space_id=deps["search_space_id"],
|
search_space_id=deps["search_space_id"],
|
||||||
created_by_id=deps["user_id"],
|
|
||||||
db_session=deps["db_session"],
|
db_session=deps["db_session"],
|
||||||
|
llm=deps.get("llm"),
|
||||||
)
|
)
|
||||||
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||||
else create_save_memory_tool(
|
else create_update_memory_tool(
|
||||||
user_id=deps["user_id"],
|
user_id=deps["user_id"],
|
||||||
search_space_id=deps["search_space_id"],
|
|
||||||
db_session=deps["db_session"],
|
db_session=deps["db_session"],
|
||||||
|
llm=deps.get("llm"),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
requires=[
|
||||||
),
|
"user_id",
|
||||||
ToolDefinition(
|
"search_space_id",
|
||||||
name="recall_memory",
|
"db_session",
|
||||||
description="Recall relevant memories (personal or team) for context",
|
"thread_visibility",
|
||||||
factory=lambda deps: (
|
"llm",
|
||||||
create_recall_shared_memory_tool(
|
],
|
||||||
search_space_id=deps["search_space_id"],
|
|
||||||
db_session=deps["db_session"],
|
|
||||||
)
|
|
||||||
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
|
||||||
else create_recall_memory_tool(
|
|
||||||
user_id=deps["user_id"],
|
|
||||||
search_space_id=deps["search_space_id"],
|
|
||||||
db_session=deps["db_session"],
|
|
||||||
)
|
|
||||||
),
|
|
||||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# LINEAR TOOLS - create, update, delete issues
|
# LINEAR TOOLS - create, update, delete issues
|
||||||
|
|
|
||||||
|
|
@ -1,281 +0,0 @@
|
||||||
"""Shared (team) memory backend for search-space-scoped AI context."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.db import MemoryCategory, SharedMemory, User
|
|
||||||
from app.utils.document_converters import embed_text
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
DEFAULT_RECALL_TOP_K = 5
|
|
||||||
MAX_MEMORIES_PER_SEARCH_SPACE = 250
|
|
||||||
|
|
||||||
|
|
||||||
async def get_shared_memory_count(
|
|
||||||
db_session: AsyncSession,
|
|
||||||
search_space_id: int,
|
|
||||||
) -> int:
|
|
||||||
result = await db_session.execute(
|
|
||||||
select(SharedMemory).where(SharedMemory.search_space_id == search_space_id)
|
|
||||||
)
|
|
||||||
return len(result.scalars().all())
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_oldest_shared_memory(
|
|
||||||
db_session: AsyncSession,
|
|
||||||
search_space_id: int,
|
|
||||||
) -> None:
|
|
||||||
result = await db_session.execute(
|
|
||||||
select(SharedMemory)
|
|
||||||
.where(SharedMemory.search_space_id == search_space_id)
|
|
||||||
.order_by(SharedMemory.updated_at.asc())
|
|
||||||
.limit(1)
|
|
||||||
)
|
|
||||||
oldest = result.scalars().first()
|
|
||||||
if oldest:
|
|
||||||
await db_session.delete(oldest)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def _to_uuid(value: str | UUID) -> UUID:
|
|
||||||
if isinstance(value, UUID):
|
|
||||||
return value
|
|
||||||
return UUID(value)
|
|
||||||
|
|
||||||
|
|
||||||
async def save_shared_memory(
|
|
||||||
db_session: AsyncSession,
|
|
||||||
search_space_id: int,
|
|
||||||
created_by_id: str | UUID,
|
|
||||||
content: str,
|
|
||||||
category: str = "fact",
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
category = category.lower() if category else "fact"
|
|
||||||
valid = ["preference", "fact", "instruction", "context"]
|
|
||||||
if category not in valid:
|
|
||||||
category = "fact"
|
|
||||||
try:
|
|
||||||
count = await get_shared_memory_count(db_session, search_space_id)
|
|
||||||
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
|
|
||||||
await delete_oldest_shared_memory(db_session, search_space_id)
|
|
||||||
embedding = await asyncio.to_thread(embed_text, content)
|
|
||||||
row = SharedMemory(
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
created_by_id=_to_uuid(created_by_id),
|
|
||||||
memory_text=content,
|
|
||||||
category=MemoryCategory(category),
|
|
||||||
embedding=embedding,
|
|
||||||
)
|
|
||||||
db_session.add(row)
|
|
||||||
await db_session.commit()
|
|
||||||
await db_session.refresh(row)
|
|
||||||
return {
|
|
||||||
"status": "saved",
|
|
||||||
"memory_id": row.id,
|
|
||||||
"memory_text": content,
|
|
||||||
"category": category,
|
|
||||||
"message": f"I'll remember: {content}",
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Failed to save shared memory: %s", e)
|
|
||||||
await db_session.rollback()
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"error": str(e),
|
|
||||||
"message": "Failed to save memory. Please try again.",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def recall_shared_memory(
|
|
||||||
db_session: AsyncSession,
|
|
||||||
search_space_id: int,
|
|
||||||
query: str | None = None,
|
|
||||||
category: str | None = None,
|
|
||||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
top_k = min(max(top_k, 1), 20)
|
|
||||||
try:
|
|
||||||
valid_categories = ["preference", "fact", "instruction", "context"]
|
|
||||||
stmt = select(SharedMemory).where(
|
|
||||||
SharedMemory.search_space_id == search_space_id
|
|
||||||
)
|
|
||||||
if category and category in valid_categories:
|
|
||||||
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
|
|
||||||
if query:
|
|
||||||
query_embedding = await asyncio.to_thread(embed_text, query)
|
|
||||||
stmt = stmt.order_by(
|
|
||||||
SharedMemory.embedding.op("<=>")(query_embedding)
|
|
||||||
).limit(top_k)
|
|
||||||
else:
|
|
||||||
stmt = stmt.order_by(SharedMemory.updated_at.desc()).limit(top_k)
|
|
||||||
result = await db_session.execute(stmt)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
memory_list = [
|
|
||||||
{
|
|
||||||
"id": m.id,
|
|
||||||
"memory_text": m.memory_text,
|
|
||||||
"category": m.category.value if m.category else "unknown",
|
|
||||||
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
|
|
||||||
"created_by_id": str(m.created_by_id) if m.created_by_id else None,
|
|
||||||
}
|
|
||||||
for m in rows
|
|
||||||
]
|
|
||||||
created_by_ids = list(
|
|
||||||
{m["created_by_id"] for m in memory_list if m["created_by_id"]}
|
|
||||||
)
|
|
||||||
created_by_map: dict[str, str] = {}
|
|
||||||
if created_by_ids:
|
|
||||||
uuids = [UUID(uid) for uid in created_by_ids]
|
|
||||||
users_result = await db_session.execute(
|
|
||||||
select(User).where(User.id.in_(uuids))
|
|
||||||
)
|
|
||||||
for u in users_result.scalars().all():
|
|
||||||
created_by_map[str(u.id)] = u.display_name or "A team member"
|
|
||||||
formatted_context = format_shared_memories_for_context(
|
|
||||||
memory_list, created_by_map
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"count": len(memory_list),
|
|
||||||
"memories": memory_list,
|
|
||||||
"formatted_context": formatted_context,
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Failed to recall shared memory: %s", e)
|
|
||||||
await db_session.rollback()
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"error": str(e),
|
|
||||||
"memories": [],
|
|
||||||
"formatted_context": "Failed to recall memories.",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def format_shared_memories_for_context(
|
|
||||||
memories: list[dict[str, Any]],
|
|
||||||
created_by_map: dict[str, str] | None = None,
|
|
||||||
) -> str:
|
|
||||||
if not memories:
|
|
||||||
return "No relevant team memories found."
|
|
||||||
created_by_map = created_by_map or {}
|
|
||||||
parts = ["<team_memories>"]
|
|
||||||
for memory in memories:
|
|
||||||
category = memory.get("category", "unknown")
|
|
||||||
text = memory.get("memory_text", "")
|
|
||||||
updated = memory.get("updated_at", "")
|
|
||||||
created_by_id = memory.get("created_by_id")
|
|
||||||
added_by = (
|
|
||||||
created_by_map.get(str(created_by_id), "A team member")
|
|
||||||
if created_by_id is not None
|
|
||||||
else "A team member"
|
|
||||||
)
|
|
||||||
parts.append(
|
|
||||||
f" <memory category='{category}' updated='{updated}' added_by='{added_by}'>{text}</memory>"
|
|
||||||
)
|
|
||||||
parts.append("</team_memories>")
|
|
||||||
return "\n".join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
def create_save_shared_memory_tool(
|
|
||||||
search_space_id: int,
|
|
||||||
created_by_id: str | UUID,
|
|
||||||
db_session: AsyncSession,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Factory function to create the save_memory tool for shared (team) chats.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
search_space_id: The search space ID
|
|
||||||
created_by_id: The user ID of the person adding the memory
|
|
||||||
db_session: Database session for executing queries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A configured tool function for saving team memories
|
|
||||||
"""
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def save_memory(
|
|
||||||
content: str,
|
|
||||||
category: str = "fact",
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Save a fact, preference, or context to the team's shared memory for future reference.
|
|
||||||
|
|
||||||
Use this tool when:
|
|
||||||
- User or a team member says "remember this", "keep this in mind", or similar in this shared chat
|
|
||||||
- The team agrees on something to remember (e.g., decisions, conventions, where things live)
|
|
||||||
- Someone shares a preference or fact that should be visible to the whole team
|
|
||||||
|
|
||||||
The saved information will be available in future shared conversations in this space.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: The fact/preference/context to remember.
|
|
||||||
Phrase it clearly, e.g., "API keys are stored in Vault",
|
|
||||||
"The team prefers weekly demos on Fridays"
|
|
||||||
category: Type of memory. One of:
|
|
||||||
- "preference": Team or workspace preferences
|
|
||||||
- "fact": Facts the team agreed on (e.g., processes, locations)
|
|
||||||
- "instruction": Standing instructions for the team
|
|
||||||
- "context": Current context (e.g., ongoing projects, goals)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary with the save status and memory details
|
|
||||||
"""
|
|
||||||
return await save_shared_memory(
|
|
||||||
db_session, search_space_id, created_by_id, content, category
|
|
||||||
)
|
|
||||||
|
|
||||||
return save_memory
|
|
||||||
|
|
||||||
|
|
||||||
def create_recall_shared_memory_tool(
|
|
||||||
search_space_id: int,
|
|
||||||
db_session: AsyncSession,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Factory function to create the recall_memory tool for shared (team) chats.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
search_space_id: The search space ID
|
|
||||||
db_session: Database session for executing queries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A configured tool function for recalling team memories
|
|
||||||
"""
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def recall_memory(
|
|
||||||
query: str | None = None,
|
|
||||||
category: str | None = None,
|
|
||||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Recall relevant team memories for this space to provide contextual responses.
|
|
||||||
|
|
||||||
Use this tool when:
|
|
||||||
- You need team context to answer (e.g., "where do we store X?", "what did we decide about Y?")
|
|
||||||
- Someone asks about something the team agreed to remember
|
|
||||||
- Team preferences or conventions would improve the response
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Optional search query to find specific memories.
|
|
||||||
If not provided, returns the most recent memories.
|
|
||||||
category: Optional category filter. One of:
|
|
||||||
"preference", "fact", "instruction", "context"
|
|
||||||
top_k: Number of memories to retrieve (default: 5, max: 20)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary containing relevant memories and formatted context
|
|
||||||
"""
|
|
||||||
return await recall_shared_memory(
|
|
||||||
db_session, search_space_id, query, category, top_k
|
|
||||||
)
|
|
||||||
|
|
||||||
return recall_memory
|
|
||||||
389
surfsense_backend/app/agents/new_chat/tools/update_memory.py
Normal file
389
surfsense_backend/app/agents/new_chat/tools/update_memory.py
Normal file
|
|
@ -0,0 +1,389 @@
|
||||||
|
"""Markdown-document memory tool for the SurfSense agent.
|
||||||
|
|
||||||
|
Replaces the old row-per-fact save_memory / recall_memory tools with a single
|
||||||
|
update_memory tool that overwrites a freeform markdown TEXT column. The LLM
|
||||||
|
always sees the current memory in <user_memory> / <team_memory> tags injected
|
||||||
|
by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
|
||||||
|
|
||||||
|
Overflow handling:
|
||||||
|
- Soft limit (18K chars): a warning is returned telling the agent to
|
||||||
|
consolidate on the next update.
|
||||||
|
- Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
|
||||||
|
If it still exceeds the limit after rewriting, the save is rejected.
|
||||||
|
- Diff validation: warns when entire ``##`` sections are dropped or when the
|
||||||
|
document shrinks by more than 60%.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Any, Literal
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import SearchSpace, User
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MEMORY_SOFT_LIMIT = 18_000
|
||||||
|
MEMORY_HARD_LIMIT = 25_000
|
||||||
|
|
||||||
|
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
||||||
|
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
|
||||||
|
|
||||||
|
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
|
||||||
|
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
|
||||||
|
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Diff validation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_headings(memory: str) -> set[str]:
|
||||||
|
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
|
||||||
|
return set(_SECTION_HEADING_RE.findall(memory))
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_heading(heading: str) -> str:
|
||||||
|
"""Normalize heading text for robust scope checks."""
|
||||||
|
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_memory_scope(
|
||||||
|
content: str, scope: Literal["user", "team"]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Reject personal-only markers ([pref], [instr]) in team memory."""
|
||||||
|
if scope != "team":
|
||||||
|
return None
|
||||||
|
|
||||||
|
markers = set(_MARKER_RE.findall(content))
|
||||||
|
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
|
||||||
|
if leaked:
|
||||||
|
tags = ", ".join(f"[{m}]" for m in leaked)
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": (
|
||||||
|
f"Team memory cannot include personal markers: {tags}. "
|
||||||
|
"Use [fact] only in team memory."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_bullet_format(content: str) -> list[str]:
|
||||||
|
"""Return warnings for bullet lines that don't match the required format.
|
||||||
|
|
||||||
|
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
|
||||||
|
"""
|
||||||
|
warnings: list[str] = []
|
||||||
|
for line in content.splitlines():
|
||||||
|
stripped = line.strip()
|
||||||
|
if not stripped.startswith("- "):
|
||||||
|
continue
|
||||||
|
if not _BULLET_FORMAT_RE.match(stripped):
|
||||||
|
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
||||||
|
warnings.append(f"Malformed bullet: {short}")
|
||||||
|
return warnings
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||||
|
"""Return a list of warning strings about suspicious changes."""
|
||||||
|
if not old_memory:
|
||||||
|
return []
|
||||||
|
|
||||||
|
warnings: list[str] = []
|
||||||
|
old_headings = _extract_headings(old_memory)
|
||||||
|
new_headings = _extract_headings(new_memory)
|
||||||
|
dropped = old_headings - new_headings
|
||||||
|
if dropped:
|
||||||
|
names = ", ".join(sorted(dropped))
|
||||||
|
warnings.append(
|
||||||
|
f"Sections removed: {names}. "
|
||||||
|
"If unintentional, the user can restore from the settings page."
|
||||||
|
)
|
||||||
|
|
||||||
|
old_len = len(old_memory)
|
||||||
|
new_len = len(new_memory)
|
||||||
|
if old_len > 0 and new_len < old_len * 0.4:
|
||||||
|
warnings.append(
|
||||||
|
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
|
||||||
|
"Possible data loss."
|
||||||
|
)
|
||||||
|
return warnings
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Size validation & soft warning
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_memory_size(content: str) -> dict[str, Any] | None:
|
||||||
|
"""Return an error/warning dict if *content* is too large, else None."""
|
||||||
|
length = len(content)
|
||||||
|
if length > MEMORY_HARD_LIMIT:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": (
|
||||||
|
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
|
||||||
|
f"({length:,} chars). Consolidate by merging related items, "
|
||||||
|
"removing outdated entries, and shortening descriptions. "
|
||||||
|
"Then call update_memory again."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _soft_warning(content: str) -> str | None:
|
||||||
|
"""Return a warning string if content exceeds the soft limit."""
|
||||||
|
length = len(content)
|
||||||
|
if length > MEMORY_SOFT_LIMIT:
|
||||||
|
return (
|
||||||
|
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
|
||||||
|
"Consolidate by merging related items and removing less important "
|
||||||
|
"entries on your next update."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Forced rewrite when memory exceeds the hard limit
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_FORCED_REWRITE_PROMPT = """\
|
||||||
|
You are a memory curator. The following memory document exceeds the character \
|
||||||
|
limit and must be shortened.
|
||||||
|
|
||||||
|
RULES:
|
||||||
|
1. Rewrite the document to be under {target} characters.
|
||||||
|
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
|
||||||
|
or rename headings to consolidate, but keep names personal and descriptive.
|
||||||
|
3. Priority for keeping content: [instr] > [pref] > [fact].
|
||||||
|
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
|
||||||
|
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
|
||||||
|
6. Preserve the user's first name in entries — do not replace it with "the user".
|
||||||
|
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
|
||||||
|
|
||||||
|
<memory_document>
|
||||||
|
{content}
|
||||||
|
</memory_document>"""
|
||||||
|
|
||||||
|
|
||||||
|
async def _forced_rewrite(content: str, llm: Any) -> str | None:
|
||||||
|
"""Use a focused LLM call to compress *content* under the hard limit.
|
||||||
|
|
||||||
|
Returns the rewritten string, or ``None`` if the call fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
prompt = _FORCED_REWRITE_PROMPT.format(
|
||||||
|
target=MEMORY_HARD_LIMIT, content=content
|
||||||
|
)
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[HumanMessage(content=prompt)],
|
||||||
|
config={"tags": ["surfsense:internal"]},
|
||||||
|
)
|
||||||
|
text = (
|
||||||
|
response.content
|
||||||
|
if isinstance(response.content, str)
|
||||||
|
else str(response.content)
|
||||||
|
)
|
||||||
|
return text.strip()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Forced rewrite LLM call failed")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared save-and-respond logic
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _save_memory(
|
||||||
|
*,
|
||||||
|
updated_memory: str,
|
||||||
|
old_memory: str | None,
|
||||||
|
llm: Any | None,
|
||||||
|
apply_fn,
|
||||||
|
commit_fn,
|
||||||
|
rollback_fn,
|
||||||
|
label: str,
|
||||||
|
scope: Literal["user", "team"],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
||||||
|
return a response dict.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
updated_memory : str
|
||||||
|
The new document the agent submitted.
|
||||||
|
old_memory : str | None
|
||||||
|
The previously persisted document (for diff checks).
|
||||||
|
llm : Any | None
|
||||||
|
LLM instance for forced rewrite (may be ``None``).
|
||||||
|
apply_fn : callable(str) -> None
|
||||||
|
Callback that sets the new memory on the ORM object.
|
||||||
|
commit_fn : coroutine
|
||||||
|
``session.commit``.
|
||||||
|
rollback_fn : coroutine
|
||||||
|
``session.rollback``.
|
||||||
|
label : str
|
||||||
|
Human label for log messages (e.g. "user memory", "team memory").
|
||||||
|
"""
|
||||||
|
content = updated_memory
|
||||||
|
|
||||||
|
# --- forced rewrite if over the hard limit ---
|
||||||
|
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
|
||||||
|
rewritten = await _forced_rewrite(content, llm)
|
||||||
|
if rewritten is not None and len(rewritten) < len(content):
|
||||||
|
content = rewritten
|
||||||
|
|
||||||
|
# --- hard-limit gate (reject if still too large after rewrite) ---
|
||||||
|
size_err = _validate_memory_size(content)
|
||||||
|
if size_err:
|
||||||
|
return size_err
|
||||||
|
|
||||||
|
scope_err = _validate_memory_scope(content, scope)
|
||||||
|
if scope_err:
|
||||||
|
return scope_err
|
||||||
|
|
||||||
|
# --- persist ---
|
||||||
|
try:
|
||||||
|
apply_fn(content)
|
||||||
|
await commit_fn()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update %s: %s", label, e)
|
||||||
|
await rollback_fn()
|
||||||
|
return {"status": "error", "message": f"Failed to update {label}: {e}"}
|
||||||
|
|
||||||
|
# --- build response ---
|
||||||
|
resp: dict[str, Any] = {
|
||||||
|
"status": "saved",
|
||||||
|
"message": f"{label.capitalize()} updated.",
|
||||||
|
}
|
||||||
|
|
||||||
|
if content is not updated_memory:
|
||||||
|
resp["notice"] = "Memory was automatically rewritten to fit within limits."
|
||||||
|
|
||||||
|
diff_warnings = _validate_diff(old_memory, content)
|
||||||
|
if diff_warnings:
|
||||||
|
resp["diff_warnings"] = diff_warnings
|
||||||
|
|
||||||
|
format_warnings = _validate_bullet_format(content)
|
||||||
|
if format_warnings:
|
||||||
|
resp["format_warnings"] = format_warnings
|
||||||
|
|
||||||
|
warning = _soft_warning(content)
|
||||||
|
if warning:
|
||||||
|
resp["warning"] = warning
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tool factories
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def create_update_memory_tool(
|
||||||
|
user_id: str | UUID,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
llm: Any | None = None,
|
||||||
|
):
|
||||||
|
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||||
|
"""Update the user's personal memory document.
|
||||||
|
|
||||||
|
Your current memory is shown in <user_memory> in the system prompt.
|
||||||
|
When the user shares important long-term information (preferences,
|
||||||
|
facts, instructions, context), rewrite the memory document to include
|
||||||
|
the new information. Merge new facts with existing ones, update
|
||||||
|
contradictions, remove outdated entries, and keep it concise.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
updated_memory: The FULL updated markdown document (not a diff).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await db_session.execute(select(User).where(User.id == uid))
|
||||||
|
user = result.scalars().first()
|
||||||
|
if not user:
|
||||||
|
return {"status": "error", "message": "User not found."}
|
||||||
|
|
||||||
|
old_memory = user.memory_md
|
||||||
|
|
||||||
|
return await _save_memory(
|
||||||
|
updated_memory=updated_memory,
|
||||||
|
old_memory=old_memory,
|
||||||
|
llm=llm,
|
||||||
|
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||||
|
commit_fn=db_session.commit,
|
||||||
|
rollback_fn=db_session.rollback,
|
||||||
|
label="memory",
|
||||||
|
scope="user",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update user memory: %s", e)
|
||||||
|
await db_session.rollback()
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"Failed to update memory: {e}",
|
||||||
|
}
|
||||||
|
|
||||||
|
return update_memory
|
||||||
|
|
||||||
|
|
||||||
|
def create_update_team_memory_tool(
|
||||||
|
search_space_id: int,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
llm: Any | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||||
|
"""Update the team's shared memory document for this search space.
|
||||||
|
|
||||||
|
Your current team memory is shown in <team_memory> in the system
|
||||||
|
prompt. When the team shares important long-term information
|
||||||
|
(decisions, conventions, key facts, priorities), rewrite the memory
|
||||||
|
document to include the new information. Merge new facts with
|
||||||
|
existing ones, update contradictions, remove outdated entries, and
|
||||||
|
keep it concise.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
updated_memory: The FULL updated markdown document (not a diff).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||||
|
)
|
||||||
|
space = result.scalars().first()
|
||||||
|
if not space:
|
||||||
|
return {"status": "error", "message": "Search space not found."}
|
||||||
|
|
||||||
|
old_memory = space.shared_memory_md
|
||||||
|
|
||||||
|
return await _save_memory(
|
||||||
|
updated_memory=updated_memory,
|
||||||
|
old_memory=old_memory,
|
||||||
|
llm=llm,
|
||||||
|
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||||
|
commit_fn=db_session.commit,
|
||||||
|
rollback_fn=db_session.rollback,
|
||||||
|
label="team memory",
|
||||||
|
scope="team",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update team memory: %s", e)
|
||||||
|
await db_session.rollback()
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"Failed to update team memory: {e}",
|
||||||
|
}
|
||||||
|
|
||||||
|
return update_memory
|
||||||
|
|
@ -1,351 +0,0 @@
|
||||||
"""
|
|
||||||
User memory tools for the SurfSense agent.
|
|
||||||
|
|
||||||
This module provides tools for storing and retrieving user memories,
|
|
||||||
enabling personalized AI responses similar to Claude's memory feature.
|
|
||||||
|
|
||||||
Features:
|
|
||||||
- save_memory: Store facts, preferences, and context about the user
|
|
||||||
- recall_memory: Retrieve relevant memories using semantic search
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.db import MemoryCategory, UserMemory
|
|
||||||
from app.utils.document_converters import embed_text
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Constants
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
# Default number of memories to retrieve
|
|
||||||
DEFAULT_RECALL_TOP_K = 5
|
|
||||||
|
|
||||||
# Maximum number of memories per user (to prevent unbounded growth)
|
|
||||||
MAX_MEMORIES_PER_USER = 100
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Helper Functions
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def _to_uuid(user_id: str) -> UUID:
|
|
||||||
"""Convert a string user_id to a UUID object."""
|
|
||||||
if isinstance(user_id, UUID):
|
|
||||||
return user_id
|
|
||||||
return UUID(user_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_memory_count(
|
|
||||||
db_session: AsyncSession,
|
|
||||||
user_id: str,
|
|
||||||
search_space_id: int | None = None,
|
|
||||||
) -> int:
|
|
||||||
"""Get the count of memories for a user."""
|
|
||||||
uuid_user_id = _to_uuid(user_id)
|
|
||||||
query = select(UserMemory).where(UserMemory.user_id == uuid_user_id)
|
|
||||||
if search_space_id is not None:
|
|
||||||
query = query.where(
|
|
||||||
(UserMemory.search_space_id == search_space_id)
|
|
||||||
| (UserMemory.search_space_id.is_(None))
|
|
||||||
)
|
|
||||||
result = await db_session.execute(query)
|
|
||||||
return len(result.scalars().all())
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_oldest_memory(
|
|
||||||
db_session: AsyncSession,
|
|
||||||
user_id: str,
|
|
||||||
search_space_id: int | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Delete the oldest memory for a user to make room for new ones."""
|
|
||||||
uuid_user_id = _to_uuid(user_id)
|
|
||||||
query = (
|
|
||||||
select(UserMemory)
|
|
||||||
.where(UserMemory.user_id == uuid_user_id)
|
|
||||||
.order_by(UserMemory.updated_at.asc())
|
|
||||||
.limit(1)
|
|
||||||
)
|
|
||||||
if search_space_id is not None:
|
|
||||||
query = query.where(
|
|
||||||
(UserMemory.search_space_id == search_space_id)
|
|
||||||
| (UserMemory.search_space_id.is_(None))
|
|
||||||
)
|
|
||||||
result = await db_session.execute(query)
|
|
||||||
oldest_memory = result.scalars().first()
|
|
||||||
if oldest_memory:
|
|
||||||
await db_session.delete(oldest_memory)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def format_memories_for_context(memories: list[dict[str, Any]]) -> str:
|
|
||||||
"""Format retrieved memories into a readable context string for the LLM."""
|
|
||||||
if not memories:
|
|
||||||
return "No relevant memories found for this user."
|
|
||||||
|
|
||||||
parts = ["<user_memories>"]
|
|
||||||
for memory in memories:
|
|
||||||
category = memory.get("category", "unknown")
|
|
||||||
text = memory.get("memory_text", "")
|
|
||||||
updated = memory.get("updated_at", "")
|
|
||||||
parts.append(
|
|
||||||
f" <memory category='{category}' updated='{updated}'>{text}</memory>"
|
|
||||||
)
|
|
||||||
parts.append("</user_memories>")
|
|
||||||
|
|
||||||
return "\n".join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Tool Factory Functions
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def create_save_memory_tool(
|
|
||||||
user_id: str,
|
|
||||||
search_space_id: int,
|
|
||||||
db_session: AsyncSession,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Factory function to create the save_memory tool.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The user's UUID
|
|
||||||
search_space_id: The search space ID (for space-specific memories)
|
|
||||||
db_session: Database session for executing queries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A configured tool function for saving user memories
|
|
||||||
"""
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def save_memory(
|
|
||||||
content: str,
|
|
||||||
category: str = "fact",
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Save a fact, preference, or context about the user for future reference.
|
|
||||||
|
|
||||||
Use this tool when:
|
|
||||||
- User explicitly says "remember this", "keep this in mind", or similar
|
|
||||||
- User shares personal preferences (e.g., "I prefer Python over JavaScript")
|
|
||||||
- User shares important facts about themselves (name, role, interests, projects)
|
|
||||||
- User gives standing instructions (e.g., "always respond in bullet points")
|
|
||||||
- User shares relevant context (e.g., "I'm working on project X")
|
|
||||||
|
|
||||||
The saved information will be available in future conversations to provide
|
|
||||||
more personalized and contextual responses.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: The fact/preference/context to remember.
|
|
||||||
Phrase it clearly, e.g., "User prefers dark mode",
|
|
||||||
"User is a senior Python developer", "User is working on an AI project"
|
|
||||||
category: Type of memory. One of:
|
|
||||||
- "preference": User preferences (e.g., coding style, tools, formats)
|
|
||||||
- "fact": Facts about the user (e.g., name, role, expertise)
|
|
||||||
- "instruction": Standing instructions (e.g., response format preferences)
|
|
||||||
- "context": Current context (e.g., ongoing projects, goals)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary with the save status and memory details
|
|
||||||
"""
|
|
||||||
# Normalize and validate category (LLMs may send uppercase)
|
|
||||||
category = category.lower() if category else "fact"
|
|
||||||
valid_categories = ["preference", "fact", "instruction", "context"]
|
|
||||||
if category not in valid_categories:
|
|
||||||
category = "fact"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Convert user_id to UUID
|
|
||||||
uuid_user_id = _to_uuid(user_id)
|
|
||||||
|
|
||||||
# Check if we've hit the memory limit
|
|
||||||
memory_count = await get_user_memory_count(
|
|
||||||
db_session, user_id, search_space_id
|
|
||||||
)
|
|
||||||
if memory_count >= MAX_MEMORIES_PER_USER:
|
|
||||||
# Delete oldest memory to make room
|
|
||||||
await delete_oldest_memory(db_session, user_id, search_space_id)
|
|
||||||
|
|
||||||
embedding = await asyncio.to_thread(embed_text, content)
|
|
||||||
|
|
||||||
# Create new memory using ORM
|
|
||||||
# The pgvector Vector column type handles embedding conversion automatically
|
|
||||||
new_memory = UserMemory(
|
|
||||||
user_id=uuid_user_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
memory_text=content,
|
|
||||||
category=MemoryCategory(category), # Convert string to enum
|
|
||||||
embedding=embedding, # Pass embedding directly (list or numpy array)
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(new_memory)
|
|
||||||
await db_session.commit()
|
|
||||||
await db_session.refresh(new_memory)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "saved",
|
|
||||||
"memory_id": new_memory.id,
|
|
||||||
"memory_text": content,
|
|
||||||
"category": category,
|
|
||||||
"message": f"I'll remember: {content}",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to save memory for user {user_id}: {e}")
|
|
||||||
# Rollback the session to clear any failed transaction state
|
|
||||||
await db_session.rollback()
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"error": str(e),
|
|
||||||
"message": "Failed to save memory. Please try again.",
|
|
||||||
}
|
|
||||||
|
|
||||||
return save_memory
|
|
||||||
|
|
||||||
|
|
||||||
def create_recall_memory_tool(
|
|
||||||
user_id: str,
|
|
||||||
search_space_id: int,
|
|
||||||
db_session: AsyncSession,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Factory function to create the recall_memory tool.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The user's UUID
|
|
||||||
search_space_id: The search space ID
|
|
||||||
db_session: Database session for executing queries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A configured tool function for recalling user memories
|
|
||||||
"""
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def recall_memory(
|
|
||||||
query: str | None = None,
|
|
||||||
category: str | None = None,
|
|
||||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Recall relevant memories about the user to provide personalized responses.
|
|
||||||
|
|
||||||
Use this tool when:
|
|
||||||
- You need user context to give a better, more personalized answer
|
|
||||||
- User asks about their preferences or past information they shared
|
|
||||||
- User references something they told you before
|
|
||||||
- Personalization would significantly improve the response quality
|
|
||||||
- User asks "what do you know about me?" or similar
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Optional search query to find specific memories.
|
|
||||||
If not provided, returns the most recent memories.
|
|
||||||
Example: "programming preferences", "current projects"
|
|
||||||
category: Optional category filter. One of:
|
|
||||||
"preference", "fact", "instruction", "context"
|
|
||||||
If not provided, searches all categories.
|
|
||||||
top_k: Number of memories to retrieve (default: 5, max: 20)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary containing relevant memories and formatted context
|
|
||||||
"""
|
|
||||||
top_k = min(max(top_k, 1), 20) # Clamp between 1 and 20
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Convert user_id to UUID
|
|
||||||
uuid_user_id = _to_uuid(user_id)
|
|
||||||
|
|
||||||
if query:
|
|
||||||
query_embedding = await asyncio.to_thread(embed_text, query)
|
|
||||||
|
|
||||||
# Build query with vector similarity
|
|
||||||
stmt = (
|
|
||||||
select(UserMemory)
|
|
||||||
.where(UserMemory.user_id == uuid_user_id)
|
|
||||||
.where(
|
|
||||||
(UserMemory.search_space_id == search_space_id)
|
|
||||||
| (UserMemory.search_space_id.is_(None))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add category filter if specified
|
|
||||||
if category and category in [
|
|
||||||
"preference",
|
|
||||||
"fact",
|
|
||||||
"instruction",
|
|
||||||
"context",
|
|
||||||
]:
|
|
||||||
stmt = stmt.where(UserMemory.category == MemoryCategory(category))
|
|
||||||
|
|
||||||
# Order by vector similarity
|
|
||||||
stmt = stmt.order_by(
|
|
||||||
UserMemory.embedding.op("<=>")(query_embedding)
|
|
||||||
).limit(top_k)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# No query - return most recent memories
|
|
||||||
stmt = (
|
|
||||||
select(UserMemory)
|
|
||||||
.where(UserMemory.user_id == uuid_user_id)
|
|
||||||
.where(
|
|
||||||
(UserMemory.search_space_id == search_space_id)
|
|
||||||
| (UserMemory.search_space_id.is_(None))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add category filter if specified
|
|
||||||
if category and category in [
|
|
||||||
"preference",
|
|
||||||
"fact",
|
|
||||||
"instruction",
|
|
||||||
"context",
|
|
||||||
]:
|
|
||||||
stmt = stmt.where(UserMemory.category == MemoryCategory(category))
|
|
||||||
|
|
||||||
stmt = stmt.order_by(UserMemory.updated_at.desc()).limit(top_k)
|
|
||||||
|
|
||||||
result = await db_session.execute(stmt)
|
|
||||||
memories = result.scalars().all()
|
|
||||||
|
|
||||||
# Format memories for response
|
|
||||||
memory_list = [
|
|
||||||
{
|
|
||||||
"id": m.id,
|
|
||||||
"memory_text": m.memory_text,
|
|
||||||
"category": m.category.value if m.category else "unknown",
|
|
||||||
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
|
|
||||||
}
|
|
||||||
for m in memories
|
|
||||||
]
|
|
||||||
|
|
||||||
formatted_context = format_memories_for_context(memory_list)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"count": len(memory_list),
|
|
||||||
"memories": memory_list,
|
|
||||||
"formatted_context": formatted_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to recall memories for user {user_id}: {e}")
|
|
||||||
await db_session.rollback()
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"error": str(e),
|
|
||||||
"memories": [],
|
|
||||||
"formatted_context": "Failed to recall memories.",
|
|
||||||
}
|
|
||||||
|
|
||||||
return recall_memory
|
|
||||||
|
|
@ -861,99 +861,6 @@ class ChatSessionState(BaseModel):
|
||||||
ai_responding_to_user = relationship("User")
|
ai_responding_to_user = relationship("User")
|
||||||
|
|
||||||
|
|
||||||
class MemoryCategory(StrEnum):
|
|
||||||
"""Categories for user memories."""
|
|
||||||
|
|
||||||
# Using lowercase keys to match PostgreSQL enum values
|
|
||||||
preference = "preference" # User preferences (e.g., "prefers dark mode")
|
|
||||||
fact = "fact" # Facts about the user (e.g., "is a Python developer")
|
|
||||||
instruction = (
|
|
||||||
"instruction" # Standing instructions (e.g., "always respond in bullet points")
|
|
||||||
)
|
|
||||||
context = "context" # Contextual information (e.g., "working on project X")
|
|
||||||
|
|
||||||
|
|
||||||
class UserMemory(BaseModel, TimestampMixin):
|
|
||||||
"""
|
|
||||||
Private memory: facts, preferences, context per user per search space.
|
|
||||||
Used only for private chats (not shared/team chats).
|
|
||||||
"""
|
|
||||||
|
|
||||||
__tablename__ = "user_memories"
|
|
||||||
|
|
||||||
user_id = Column(
|
|
||||||
UUID(as_uuid=True),
|
|
||||||
ForeignKey("user.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
index=True,
|
|
||||||
)
|
|
||||||
# Optional association with a search space (if memory is space-specific)
|
|
||||||
search_space_id = Column(
|
|
||||||
Integer,
|
|
||||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
|
||||||
nullable=True,
|
|
||||||
index=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# The actual memory content
|
|
||||||
memory_text = Column(Text, nullable=False)
|
|
||||||
# Category for organization and filtering
|
|
||||||
category = Column(
|
|
||||||
SQLAlchemyEnum(MemoryCategory),
|
|
||||||
nullable=False,
|
|
||||||
default=MemoryCategory.fact,
|
|
||||||
)
|
|
||||||
# Vector embedding for semantic search
|
|
||||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
|
||||||
|
|
||||||
# Track when memory was last updated
|
|
||||||
updated_at = Column(
|
|
||||||
TIMESTAMP(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
default=lambda: datetime.now(UTC),
|
|
||||||
onupdate=lambda: datetime.now(UTC),
|
|
||||||
index=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Relationships
|
|
||||||
user = relationship("User", back_populates="memories")
|
|
||||||
search_space = relationship("SearchSpace", back_populates="user_memories")
|
|
||||||
|
|
||||||
|
|
||||||
class SharedMemory(BaseModel, TimestampMixin):
|
|
||||||
__tablename__ = "shared_memories"
|
|
||||||
|
|
||||||
search_space_id = Column(
|
|
||||||
Integer,
|
|
||||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
index=True,
|
|
||||||
)
|
|
||||||
created_by_id = Column(
|
|
||||||
UUID(as_uuid=True),
|
|
||||||
ForeignKey("user.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
index=True,
|
|
||||||
)
|
|
||||||
memory_text = Column(Text, nullable=False)
|
|
||||||
category = Column(
|
|
||||||
SQLAlchemyEnum(MemoryCategory),
|
|
||||||
nullable=False,
|
|
||||||
default=MemoryCategory.fact,
|
|
||||||
)
|
|
||||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
|
||||||
updated_at = Column(
|
|
||||||
TIMESTAMP(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
default=lambda: datetime.now(UTC),
|
|
||||||
onupdate=lambda: datetime.now(UTC),
|
|
||||||
index=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
search_space = relationship("SearchSpace", back_populates="shared_memories")
|
|
||||||
created_by = relationship("User")
|
|
||||||
|
|
||||||
|
|
||||||
class Folder(BaseModel, TimestampMixin):
|
class Folder(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "folders"
|
__tablename__ = "folders"
|
||||||
|
|
||||||
|
|
@ -1394,6 +1301,8 @@ class SearchSpace(BaseModel, TimestampMixin):
|
||||||
Text, nullable=True, default=""
|
Text, nullable=True, default=""
|
||||||
) # User's custom instructions
|
) # User's custom instructions
|
||||||
|
|
||||||
|
shared_memory_md = Column(Text, nullable=True, server_default="")
|
||||||
|
|
||||||
# Search space-level LLM preferences (shared by all members)
|
# Search space-level LLM preferences (shared by all members)
|
||||||
# Note: ID values:
|
# Note: ID values:
|
||||||
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
|
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
|
||||||
|
|
@ -1516,20 +1425,6 @@ class SearchSpace(BaseModel, TimestampMixin):
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
# User memories associated with this search space
|
|
||||||
user_memories = relationship(
|
|
||||||
"UserMemory",
|
|
||||||
back_populates="search_space",
|
|
||||||
order_by="UserMemory.updated_at.desc()",
|
|
||||||
cascade="all, delete-orphan",
|
|
||||||
)
|
|
||||||
shared_memories = relationship(
|
|
||||||
"SharedMemory",
|
|
||||||
back_populates="search_space",
|
|
||||||
order_by="SharedMemory.updated_at.desc()",
|
|
||||||
cascade="all, delete-orphan",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "search_source_connectors"
|
__tablename__ = "search_source_connectors"
|
||||||
|
|
@ -2030,14 +1925,6 @@ if config.AUTH_TYPE == "GOOGLE":
|
||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# User memories for personalized AI responses
|
|
||||||
memories = relationship(
|
|
||||||
"UserMemory",
|
|
||||||
back_populates="user",
|
|
||||||
order_by="UserMemory.updated_at.desc()",
|
|
||||||
cascade="all, delete-orphan",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Incentive tasks completed by this user
|
# Incentive tasks completed by this user
|
||||||
incentive_tasks = relationship(
|
incentive_tasks = relationship(
|
||||||
"UserIncentiveTask",
|
"UserIncentiveTask",
|
||||||
|
|
@ -2065,6 +1952,8 @@ if config.AUTH_TYPE == "GOOGLE":
|
||||||
|
|
||||||
last_login = Column(TIMESTAMP(timezone=True), nullable=True)
|
last_login = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
memory_md = Column(Text, nullable=True, server_default="")
|
||||||
|
|
||||||
# Refresh tokens for this user
|
# Refresh tokens for this user
|
||||||
refresh_tokens = relationship(
|
refresh_tokens = relationship(
|
||||||
"RefreshToken",
|
"RefreshToken",
|
||||||
|
|
@ -2150,14 +2039,6 @@ else:
|
||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# User memories for personalized AI responses
|
|
||||||
memories = relationship(
|
|
||||||
"UserMemory",
|
|
||||||
back_populates="user",
|
|
||||||
order_by="UserMemory.updated_at.desc()",
|
|
||||||
cascade="all, delete-orphan",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Incentive tasks completed by this user
|
# Incentive tasks completed by this user
|
||||||
incentive_tasks = relationship(
|
incentive_tasks = relationship(
|
||||||
"UserIncentiveTask",
|
"UserIncentiveTask",
|
||||||
|
|
@ -2185,6 +2066,8 @@ else:
|
||||||
|
|
||||||
last_login = Column(TIMESTAMP(timezone=True), nullable=True)
|
last_login = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
memory_md = Column(Text, nullable=True, server_default="")
|
||||||
|
|
||||||
# Refresh tokens for this user
|
# Refresh tokens for this user
|
||||||
refresh_tokens = relationship(
|
refresh_tokens = relationship(
|
||||||
"RefreshToken",
|
"RefreshToken",
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from .jira_add_connector_route import router as jira_add_connector_router
|
||||||
from .linear_add_connector_route import router as linear_add_connector_router
|
from .linear_add_connector_route import router as linear_add_connector_router
|
||||||
from .logs_routes import router as logs_router
|
from .logs_routes import router as logs_router
|
||||||
from .luma_add_connector_route import router as luma_add_connector_router
|
from .luma_add_connector_route import router as luma_add_connector_router
|
||||||
|
from .memory_routes import router as memory_router
|
||||||
from .model_list_routes import router as model_list_router
|
from .model_list_routes import router as model_list_router
|
||||||
from .new_chat_routes import router as new_chat_router
|
from .new_chat_routes import router as new_chat_router
|
||||||
from .new_llm_config_routes import router as new_llm_config_router
|
from .new_llm_config_routes import router as new_llm_config_router
|
||||||
|
|
@ -98,4 +99,5 @@ router.include_router(incentive_tasks_router) # Incentive tasks for earning fre
|
||||||
router.include_router(stripe_router) # Stripe checkout for additional page packs
|
router.include_router(stripe_router) # Stripe checkout for additional page packs
|
||||||
router.include_router(youtube_router) # YouTube playlist resolution
|
router.include_router(youtube_router) # YouTube playlist resolution
|
||||||
router.include_router(prompts_router)
|
router.include_router(prompts_router)
|
||||||
|
router.include_router(memory_router) # User personal memory (memory.md style)
|
||||||
router.include_router(autocomplete_router) # Lightweight autocomplete with KB context
|
router.include_router(autocomplete_router) # Lightweight autocomplete with KB context
|
||||||
|
|
|
||||||
153
surfsense_backend/app/routes/memory_routes.py
Normal file
153
surfsense_backend/app/routes/memory_routes.py
Normal file
|
|
@ -0,0 +1,153 @@
|
||||||
|
"""Routes for user memory management (personal memory.md)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import (
|
||||||
|
create_chat_litellm_from_agent_config,
|
||||||
|
load_agent_llm_config_for_search_space,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
|
||||||
|
from app.db import User, get_async_session
|
||||||
|
from app.users import current_active_user
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRead(BaseModel):
|
||||||
|
memory_md: str
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryUpdate(BaseModel):
|
||||||
|
memory_md: str
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryEditRequest(BaseModel):
|
||||||
|
query: str
|
||||||
|
search_space_id: int
|
||||||
|
|
||||||
|
|
||||||
|
_MEMORY_EDIT_PROMPT = """\
|
||||||
|
You are a memory editor. The user wants to modify their memory document. \
|
||||||
|
Apply the user's instruction to the existing memory document and output the \
|
||||||
|
FULL updated document.
|
||||||
|
|
||||||
|
RULES:
|
||||||
|
1. If the instruction asks to add something, add it with format: \
|
||||||
|
- (YYYY-MM-DD) [fact|pref|instr] text, under an existing or new ## heading. \
|
||||||
|
Heading names should be personal and descriptive, not generic categories.
|
||||||
|
2. If the instruction asks to remove something, remove the matching entry.
|
||||||
|
3. If the instruction asks to change something, update the matching entry.
|
||||||
|
4. Preserve existing ## headings and all other entries.
|
||||||
|
5. Every bullet must include a marker: [fact], [pref], or [instr].
|
||||||
|
6. Use the user's first name (from <user_name>) in entries instead of "the user".
|
||||||
|
7. Output ONLY the updated markdown — no explanations, no wrapping.
|
||||||
|
|
||||||
|
<user_name>{user_name}</user_name>
|
||||||
|
|
||||||
|
<current_memory>
|
||||||
|
{current_memory}
|
||||||
|
</current_memory>
|
||||||
|
|
||||||
|
<user_instruction>
|
||||||
|
{instruction}
|
||||||
|
</user_instruction>"""
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/users/me/memory", response_model=MemoryRead)
|
||||||
|
async def get_user_memory(
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
):
|
||||||
|
await session.refresh(user, ["memory_md"])
|
||||||
|
return MemoryRead(memory_md=user.memory_md or "")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/users/me/memory", response_model=MemoryRead)
|
||||||
|
async def update_user_memory(
|
||||||
|
body: MemoryUpdate,
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
):
|
||||||
|
if len(body.memory_md) > MEMORY_HARD_LIMIT:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit ({len(body.memory_md):,} chars).",
|
||||||
|
)
|
||||||
|
user.memory_md = body.memory_md
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user, ["memory_md"])
|
||||||
|
return MemoryRead(memory_md=user.memory_md or "")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/users/me/memory/edit", response_model=MemoryRead)
|
||||||
|
async def edit_user_memory(
|
||||||
|
body: MemoryEditRequest,
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
):
|
||||||
|
"""Apply a natural language edit to the user's personal memory via LLM."""
|
||||||
|
agent_config = await load_agent_llm_config_for_search_space(
|
||||||
|
session, body.search_space_id
|
||||||
|
)
|
||||||
|
if not agent_config:
|
||||||
|
raise HTTPException(status_code=500, detail="No LLM configuration available.")
|
||||||
|
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||||
|
if not llm:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
|
||||||
|
|
||||||
|
await session.refresh(user, ["memory_md", "display_name"])
|
||||||
|
current_memory = user.memory_md or ""
|
||||||
|
first_name = (
|
||||||
|
user.display_name.strip().split()[0]
|
||||||
|
if user.display_name and user.display_name.strip()
|
||||||
|
else "The user"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = _MEMORY_EDIT_PROMPT.format(
|
||||||
|
current_memory=current_memory or "(empty)",
|
||||||
|
instruction=body.query,
|
||||||
|
user_name=first_name,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[HumanMessage(content=prompt)],
|
||||||
|
config={"tags": ["surfsense:internal", "memory-edit"]},
|
||||||
|
)
|
||||||
|
updated = (
|
||||||
|
response.content
|
||||||
|
if isinstance(response.content, str)
|
||||||
|
else str(response.content)
|
||||||
|
).strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Memory edit LLM call failed: %s", e)
|
||||||
|
raise HTTPException(status_code=500, detail="Memory edit failed.") from e
|
||||||
|
|
||||||
|
if not updated:
|
||||||
|
raise HTTPException(status_code=400, detail="LLM returned empty result.")
|
||||||
|
|
||||||
|
result = await _save_memory(
|
||||||
|
updated_memory=updated,
|
||||||
|
old_memory=current_memory,
|
||||||
|
llm=llm,
|
||||||
|
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||||
|
commit_fn=session.commit,
|
||||||
|
rollback_fn=session.rollback,
|
||||||
|
label="memory",
|
||||||
|
scope="user",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.get("status") == "error":
|
||||||
|
raise HTTPException(status_code=400, detail=result["message"])
|
||||||
|
|
||||||
|
await session.refresh(user, ["memory_md"])
|
||||||
|
return MemoryRead(memory_md=user.memory_md or "")
|
||||||
|
|
@ -1,10 +1,17 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import (
|
||||||
|
create_chat_litellm_from_agent_config,
|
||||||
|
load_agent_llm_config_for_search_space,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ImageGenerationConfig,
|
ImageGenerationConfig,
|
||||||
|
|
@ -34,6 +41,34 @@ logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class _TeamMemoryEditRequest(PydanticBaseModel):
|
||||||
|
query: str
|
||||||
|
|
||||||
|
|
||||||
|
_TEAM_MEMORY_EDIT_PROMPT = """\
|
||||||
|
You are a memory editor for a team workspace. The user wants to modify the \
|
||||||
|
team's shared memory document. Apply the user's instruction to the existing \
|
||||||
|
memory document and output the FULL updated document.
|
||||||
|
|
||||||
|
RULES:
|
||||||
|
1. If the instruction asks to add something, add it with format: \
|
||||||
|
- (YYYY-MM-DD) [fact] text, under an existing or new ## heading. \
|
||||||
|
Heading names should be descriptive, not generic categories.
|
||||||
|
2. If the instruction asks to remove something, remove the matching entry.
|
||||||
|
3. If the instruction asks to change something, update the matching entry.
|
||||||
|
4. Preserve existing ## headings and all other entries.
|
||||||
|
5. NEVER use [pref] or [instr] markers. Team memory uses [fact] only.
|
||||||
|
6. Output ONLY the updated markdown — no explanations, no wrapping.
|
||||||
|
|
||||||
|
<current_memory>
|
||||||
|
{current_memory}
|
||||||
|
</current_memory>
|
||||||
|
|
||||||
|
<user_instruction>
|
||||||
|
{instruction}
|
||||||
|
</user_instruction>"""
|
||||||
|
|
||||||
|
|
||||||
async def create_default_roles_and_membership(
|
async def create_default_roles_and_membership(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
@ -255,6 +290,16 @@ async def update_search_space(
|
||||||
raise HTTPException(status_code=404, detail="Search space not found")
|
raise HTTPException(status_code=404, detail="Search space not found")
|
||||||
|
|
||||||
update_data = search_space_update.model_dump(exclude_unset=True)
|
update_data = search_space_update.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
if (
|
||||||
|
"shared_memory_md" in update_data
|
||||||
|
and len(update_data["shared_memory_md"] or "") > MEMORY_HARD_LIMIT
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Team memory exceeds {MEMORY_HARD_LIMIT:,} character limit.",
|
||||||
|
)
|
||||||
|
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(db_search_space, key, value)
|
setattr(db_search_space, key, value)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
@ -269,6 +314,76 @@ async def update_search_space(
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/searchspaces/{search_space_id}/memory/edit",
|
||||||
|
response_model=SearchSpaceRead,
|
||||||
|
)
|
||||||
|
async def edit_team_memory(
|
||||||
|
search_space_id: int,
|
||||||
|
body: _TeamMemoryEditRequest,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""Apply a natural language edit to the team memory via LLM."""
|
||||||
|
await check_search_space_access(session, user, search_space_id)
|
||||||
|
|
||||||
|
agent_config = await load_agent_llm_config_for_search_space(
|
||||||
|
session, search_space_id
|
||||||
|
)
|
||||||
|
if not agent_config:
|
||||||
|
raise HTTPException(status_code=500, detail="No LLM configuration available.")
|
||||||
|
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||||
|
if not llm:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||||
|
)
|
||||||
|
db_search_space = result.scalars().first()
|
||||||
|
if not db_search_space:
|
||||||
|
raise HTTPException(status_code=404, detail="Search space not found")
|
||||||
|
|
||||||
|
current_memory = db_search_space.shared_memory_md or ""
|
||||||
|
|
||||||
|
prompt = _TEAM_MEMORY_EDIT_PROMPT.format(
|
||||||
|
current_memory=current_memory or "(empty)",
|
||||||
|
instruction=body.query,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[HumanMessage(content=prompt)],
|
||||||
|
config={"tags": ["surfsense:internal", "memory-edit"]},
|
||||||
|
)
|
||||||
|
updated = (
|
||||||
|
response.content
|
||||||
|
if isinstance(response.content, str)
|
||||||
|
else str(response.content)
|
||||||
|
).strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Team memory edit LLM call failed: %s", e)
|
||||||
|
raise HTTPException(status_code=500, detail="Team memory edit failed.") from e
|
||||||
|
|
||||||
|
if not updated:
|
||||||
|
raise HTTPException(status_code=400, detail="LLM returned empty result.")
|
||||||
|
|
||||||
|
save_result = await _save_memory(
|
||||||
|
updated_memory=updated,
|
||||||
|
old_memory=current_memory,
|
||||||
|
llm=llm,
|
||||||
|
apply_fn=lambda content: setattr(db_search_space, "shared_memory_md", content),
|
||||||
|
commit_fn=session.commit,
|
||||||
|
rollback_fn=session.rollback,
|
||||||
|
label="team memory",
|
||||||
|
scope="team",
|
||||||
|
)
|
||||||
|
|
||||||
|
if save_result.get("status") == "error":
|
||||||
|
raise HTTPException(status_code=400, detail=save_result["message"])
|
||||||
|
|
||||||
|
await session.refresh(db_search_space)
|
||||||
|
return db_search_space
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
|
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
|
||||||
async def delete_search_space(
|
async def delete_search_space(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ class SearchSpaceUpdate(BaseModel):
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
citations_enabled: bool | None = None
|
citations_enabled: bool | None = None
|
||||||
qna_custom_instructions: str | None = None
|
qna_custom_instructions: str | None = None
|
||||||
|
shared_memory_md: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
||||||
|
|
@ -29,6 +30,7 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
||||||
user_id: uuid.UUID
|
user_id: uuid.UUID
|
||||||
citations_enabled: bool
|
citations_enabled: bool
|
||||||
qna_custom_instructions: str | None = None
|
qna_custom_instructions: str | None = None
|
||||||
|
shared_memory_md: str | None = None
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,10 @@ from app.agents.new_chat.llm_config import (
|
||||||
load_agent_config,
|
load_agent_config,
|
||||||
load_llm_config_from_yaml,
|
load_llm_config_from_yaml,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.memory_extraction import (
|
||||||
|
extract_and_save_memory,
|
||||||
|
extract_and_save_team_memory,
|
||||||
|
)
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ChatVisibility,
|
ChatVisibility,
|
||||||
NewChatMessage,
|
NewChatMessage,
|
||||||
|
|
@ -59,8 +63,6 @@ from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_hea
|
||||||
|
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
_background_tasks: set[asyncio.Task] = set()
|
|
||||||
|
|
||||||
|
|
||||||
def format_mentioned_surfsense_docs_as_context(
|
def format_mentioned_surfsense_docs_as_context(
|
||||||
documents: list[SurfsenseDocsDocument],
|
documents: list[SurfsenseDocsDocument],
|
||||||
|
|
@ -141,6 +143,7 @@ class StreamResult:
|
||||||
is_interrupted: bool = False
|
is_interrupted: bool = False
|
||||||
interrupt_value: dict[str, Any] | None = None
|
interrupt_value: dict[str, Any] | None = None
|
||||||
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
|
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
|
||||||
|
agent_called_update_memory: bool = False
|
||||||
|
|
||||||
|
|
||||||
async def _stream_agent_events(
|
async def _stream_agent_events(
|
||||||
|
|
@ -183,6 +186,7 @@ async def _stream_agent_events(
|
||||||
last_active_step_items: list[str] = initial_step_items or []
|
last_active_step_items: list[str] = initial_step_items or []
|
||||||
just_finished_tool: bool = False
|
just_finished_tool: bool = False
|
||||||
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
|
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
|
||||||
|
called_update_memory: bool = False
|
||||||
|
|
||||||
def next_thinking_step_id() -> str:
|
def next_thinking_step_id() -> str:
|
||||||
nonlocal thinking_step_counter
|
nonlocal thinking_step_counter
|
||||||
|
|
@ -490,6 +494,9 @@ async def _stream_agent_events(
|
||||||
tool_name = event.get("name", "unknown_tool")
|
tool_name = event.get("name", "unknown_tool")
|
||||||
raw_output = event.get("data", {}).get("output", "")
|
raw_output = event.get("data", {}).get("output", "")
|
||||||
|
|
||||||
|
if tool_name == "update_memory":
|
||||||
|
called_update_memory = True
|
||||||
|
|
||||||
if hasattr(raw_output, "content"):
|
if hasattr(raw_output, "content"):
|
||||||
content = raw_output.content
|
content = raw_output.content
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
|
|
@ -1111,6 +1118,7 @@ async def _stream_agent_events(
|
||||||
yield completion_event
|
yield completion_event
|
||||||
|
|
||||||
result.accumulated_text = accumulated_text
|
result.accumulated_text = accumulated_text
|
||||||
|
result.agent_called_update_memory = called_update_memory
|
||||||
|
|
||||||
state = await agent.aget_state(config)
|
state = await agent.aget_state(config)
|
||||||
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
|
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
|
||||||
|
|
@ -1540,6 +1548,27 @@ async def stream_new_chat(
|
||||||
chat_id, generated_title
|
chat_id, generated_title
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Fire background memory extraction if the agent didn't handle it.
|
||||||
|
# Shared threads write to team memory; private threads write to user memory.
|
||||||
|
if not stream_result.agent_called_update_memory:
|
||||||
|
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||||
|
asyncio.create_task(
|
||||||
|
extract_and_save_team_memory(
|
||||||
|
user_message=user_query,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
llm=llm,
|
||||||
|
author_display_name=current_user_display_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif user_id:
|
||||||
|
asyncio.create_task(
|
||||||
|
extract_and_save_memory(
|
||||||
|
user_message=user_query,
|
||||||
|
user_id=user_id,
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Finish the step and message
|
# Finish the step and message
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,198 @@
|
||||||
|
"""Unit tests for memory scope validation and bullet format validation."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.update_memory import (
|
||||||
|
_save_memory,
|
||||||
|
_validate_bullet_format,
|
||||||
|
_validate_memory_scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _Recorder:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.applied_content: str | None = None
|
||||||
|
self.commit_calls = 0
|
||||||
|
self.rollback_calls = 0
|
||||||
|
|
||||||
|
def apply(self, content: str) -> None:
|
||||||
|
self.applied_content = content
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
self.commit_calls += 1
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
self.rollback_calls += 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _validate_memory_scope — marker-based
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_memory_scope_rejects_pref_marker_in_team_scope() -> None:
|
||||||
|
content = "- (2026-04-10) [pref] Prefers dark mode\n"
|
||||||
|
result = _validate_memory_scope(content, "team")
|
||||||
|
assert result is not None
|
||||||
|
assert result["status"] == "error"
|
||||||
|
assert "[pref]" in result["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_memory_scope_rejects_instr_marker_in_team_scope() -> None:
|
||||||
|
content = "- (2026-04-10) [instr] Always respond in Spanish\n"
|
||||||
|
result = _validate_memory_scope(content, "team")
|
||||||
|
assert result is not None
|
||||||
|
assert result["status"] == "error"
|
||||||
|
assert "[instr]" in result["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_memory_scope_rejects_both_personal_markers_in_team() -> None:
|
||||||
|
content = (
|
||||||
|
"- (2026-04-10) [pref] Prefers dark mode\n"
|
||||||
|
"- (2026-04-10) [instr] Always respond in Spanish\n"
|
||||||
|
)
|
||||||
|
result = _validate_memory_scope(content, "team")
|
||||||
|
assert result is not None
|
||||||
|
assert result["status"] == "error"
|
||||||
|
assert "[instr]" in result["message"]
|
||||||
|
assert "[pref]" in result["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_memory_scope_allows_fact_in_team_scope() -> None:
|
||||||
|
content = "- (2026-04-10) [fact] Office is in downtown Seattle\n"
|
||||||
|
result = _validate_memory_scope(content, "team")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_memory_scope_allows_all_markers_in_user_scope() -> None:
|
||||||
|
content = (
|
||||||
|
"- (2026-04-10) [fact] Python developer\n"
|
||||||
|
"- (2026-04-10) [pref] Prefers concise answers\n"
|
||||||
|
"- (2026-04-10) [instr] Always use bullet points\n"
|
||||||
|
)
|
||||||
|
result = _validate_memory_scope(content, "user")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_memory_scope_allows_any_heading_in_team() -> None:
|
||||||
|
content = "## Architecture\n- (2026-04-10) [fact] Uses PostgreSQL for persistence\n"
|
||||||
|
result = _validate_memory_scope(content, "team")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_memory_scope_allows_any_heading_in_user() -> None:
|
||||||
|
content = "## My Projects\n- (2026-04-10) [fact] Working on SurfSense\n"
|
||||||
|
result = _validate_memory_scope(content, "user")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _validate_bullet_format
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_bullet_format_passes_valid_bullets() -> None:
|
||||||
|
content = (
|
||||||
|
"## Work\n"
|
||||||
|
"- (2026-04-10) [fact] Senior Python developer\n"
|
||||||
|
"- (2026-04-10) [pref] Prefers dark mode\n"
|
||||||
|
"- (2026-04-10) [instr] Always respond in bullet points\n"
|
||||||
|
)
|
||||||
|
warnings = _validate_bullet_format(content)
|
||||||
|
assert warnings == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_bullet_format_warns_on_missing_marker() -> None:
|
||||||
|
content = "- (2026-04-10) Senior Python developer\n"
|
||||||
|
warnings = _validate_bullet_format(content)
|
||||||
|
assert len(warnings) == 1
|
||||||
|
assert "Malformed bullet" in warnings[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_bullet_format_warns_on_missing_date() -> None:
|
||||||
|
content = "- [fact] Senior Python developer\n"
|
||||||
|
warnings = _validate_bullet_format(content)
|
||||||
|
assert len(warnings) == 1
|
||||||
|
assert "Malformed bullet" in warnings[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_bullet_format_warns_on_unknown_marker() -> None:
|
||||||
|
content = "- (2026-04-10) [context] Working on project X\n"
|
||||||
|
warnings = _validate_bullet_format(content)
|
||||||
|
assert len(warnings) == 1
|
||||||
|
assert "Malformed bullet" in warnings[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_bullet_format_ignores_non_bullet_lines() -> None:
|
||||||
|
content = "## Some Heading\nSome paragraph text\n"
|
||||||
|
warnings = _validate_bullet_format(content)
|
||||||
|
assert warnings == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_bullet_format_warns_on_old_format_without_marker() -> None:
|
||||||
|
content = "## About the user\n- (2026-04-10) Likes cats\n"
|
||||||
|
warnings = _validate_bullet_format(content)
|
||||||
|
assert len(warnings) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _save_memory — end-to-end with marker scope check
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_blocks_pref_in_team_before_commit() -> None:
|
||||||
|
recorder = _Recorder()
|
||||||
|
result = await _save_memory(
|
||||||
|
updated_memory="- (2026-04-10) [pref] Prefers dark mode\n",
|
||||||
|
old_memory=None,
|
||||||
|
llm=None,
|
||||||
|
apply_fn=recorder.apply,
|
||||||
|
commit_fn=recorder.commit,
|
||||||
|
rollback_fn=recorder.rollback,
|
||||||
|
label="team memory",
|
||||||
|
scope="team",
|
||||||
|
)
|
||||||
|
assert result["status"] == "error"
|
||||||
|
assert recorder.commit_calls == 0
|
||||||
|
assert recorder.applied_content is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_allows_fact_in_team_and_commits() -> None:
|
||||||
|
recorder = _Recorder()
|
||||||
|
content = "- (2026-04-10) [fact] Weekly standup on Mondays\n"
|
||||||
|
result = await _save_memory(
|
||||||
|
updated_memory=content,
|
||||||
|
old_memory=None,
|
||||||
|
llm=None,
|
||||||
|
apply_fn=recorder.apply,
|
||||||
|
commit_fn=recorder.commit,
|
||||||
|
rollback_fn=recorder.rollback,
|
||||||
|
label="team memory",
|
||||||
|
scope="team",
|
||||||
|
)
|
||||||
|
assert result["status"] == "saved"
|
||||||
|
assert recorder.commit_calls == 1
|
||||||
|
assert recorder.applied_content == content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_memory_includes_format_warnings() -> None:
|
||||||
|
recorder = _Recorder()
|
||||||
|
content = "- (2026-04-10) Missing marker text\n"
|
||||||
|
result = await _save_memory(
|
||||||
|
updated_memory=content,
|
||||||
|
old_memory=None,
|
||||||
|
llm=None,
|
||||||
|
apply_fn=recorder.apply,
|
||||||
|
commit_fn=recorder.commit,
|
||||||
|
rollback_fn=recorder.rollback,
|
||||||
|
label="memory",
|
||||||
|
scope="user",
|
||||||
|
)
|
||||||
|
assert result["status"] == "saved"
|
||||||
|
assert "format_warnings" in result
|
||||||
|
assert len(result["format_warnings"]) == 1
|
||||||
|
|
@ -0,0 +1,291 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useAtomValue } from "jotai";
|
||||||
|
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pen } from "lucide-react";
|
||||||
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import { z } from "zod";
|
||||||
|
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||||
|
import { PlateEditor } from "@/components/editor/plate-editor";
|
||||||
|
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
DropdownMenu,
|
||||||
|
DropdownMenuContent,
|
||||||
|
DropdownMenuItem,
|
||||||
|
DropdownMenuTrigger,
|
||||||
|
} from "@/components/ui/dropdown-menu";
|
||||||
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
|
|
||||||
|
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||||
|
|
||||||
|
const MEMORY_HARD_LIMIT = 25_000;
|
||||||
|
|
||||||
|
const MemoryReadSchema = z.object({
|
||||||
|
memory_md: z.string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export function MemoryContent() {
|
||||||
|
const activeSearchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
||||||
|
const [memory, setMemory] = useState("");
|
||||||
|
const [loading, setLoading] = useState(true);
|
||||||
|
const [saving, setSaving] = useState(false);
|
||||||
|
const [editQuery, setEditQuery] = useState("");
|
||||||
|
const [editing, setEditing] = useState(false);
|
||||||
|
const [showInput, setShowInput] = useState(false);
|
||||||
|
const textareaRef = useRef<HTMLInputElement>(null);
|
||||||
|
const inputContainerRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
||||||
|
const fetchMemory = useCallback(async () => {
|
||||||
|
try {
|
||||||
|
setLoading(true);
|
||||||
|
const data = await baseApiService.get("/api/v1/users/me/memory", MemoryReadSchema);
|
||||||
|
setMemory(data.memory_md);
|
||||||
|
} catch {
|
||||||
|
toast.error("Failed to load memory");
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchMemory();
|
||||||
|
}, [fetchMemory]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!showInput) return;
|
||||||
|
|
||||||
|
const handlePointerDownOutside = (event: MouseEvent | TouchEvent) => {
|
||||||
|
const target = event.target;
|
||||||
|
if (!(target instanceof Node)) return;
|
||||||
|
if (inputContainerRef.current?.contains(target)) return;
|
||||||
|
|
||||||
|
setShowInput(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
document.addEventListener("mousedown", handlePointerDownOutside);
|
||||||
|
document.addEventListener("touchstart", handlePointerDownOutside, { passive: true });
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
document.removeEventListener("mousedown", handlePointerDownOutside);
|
||||||
|
document.removeEventListener("touchstart", handlePointerDownOutside);
|
||||||
|
};
|
||||||
|
}, [showInput]);
|
||||||
|
|
||||||
|
const handleClear = async () => {
|
||||||
|
try {
|
||||||
|
setSaving(true);
|
||||||
|
const data = await baseApiService.put("/api/v1/users/me/memory", MemoryReadSchema, {
|
||||||
|
body: { memory_md: "" },
|
||||||
|
});
|
||||||
|
setMemory(data.memory_md);
|
||||||
|
toast.success("Memory cleared");
|
||||||
|
} catch {
|
||||||
|
toast.error("Failed to clear memory");
|
||||||
|
} finally {
|
||||||
|
setSaving(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleEdit = async () => {
|
||||||
|
const query = editQuery.trim();
|
||||||
|
if (!query) return;
|
||||||
|
|
||||||
|
try {
|
||||||
|
setEditing(true);
|
||||||
|
const data = await baseApiService.post("/api/v1/users/me/memory/edit", MemoryReadSchema, {
|
||||||
|
body: { query, search_space_id: Number(activeSearchSpaceId) },
|
||||||
|
});
|
||||||
|
setMemory(data.memory_md);
|
||||||
|
setEditQuery("");
|
||||||
|
setShowInput(false);
|
||||||
|
toast.success("Memory updated");
|
||||||
|
} catch {
|
||||||
|
toast.error("Failed to edit memory");
|
||||||
|
} finally {
|
||||||
|
setEditing(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const openInput = () => {
|
||||||
|
setShowInput(true);
|
||||||
|
requestAnimationFrame(() => textareaRef.current?.focus());
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDownload = () => {
|
||||||
|
if (!memory) return;
|
||||||
|
try {
|
||||||
|
const blob = new Blob([memory], { type: "text/markdown;charset=utf-8" });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = document.createElement("a");
|
||||||
|
a.href = url;
|
||||||
|
a.download = "personal-memory.md";
|
||||||
|
document.body.appendChild(a);
|
||||||
|
a.click();
|
||||||
|
document.body.removeChild(a);
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
} catch {
|
||||||
|
toast.error("Failed to download memory");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleCopyMarkdown = async () => {
|
||||||
|
if (!memory) return;
|
||||||
|
try {
|
||||||
|
await navigator.clipboard.writeText(memory);
|
||||||
|
toast.success("Copied to clipboard");
|
||||||
|
} catch {
|
||||||
|
toast.error("Failed to copy memory");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||||
|
if (e.key === "Enter" && !e.shiftKey) {
|
||||||
|
e.preventDefault();
|
||||||
|
handleEdit();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const displayMemory = memory.replace(/\(\d{4}-\d{2}-\d{2}\)\s*\[(fact|pref|instr)\]\s*/g, "");
|
||||||
|
const charCount = memory.length;
|
||||||
|
|
||||||
|
const getCounterColor = () => {
|
||||||
|
if (charCount > MEMORY_HARD_LIMIT) return "text-red-500";
|
||||||
|
if (charCount > 15_000) return "text-orange-500";
|
||||||
|
if (charCount > 10_000) return "text-yellow-500";
|
||||||
|
return "text-muted-foreground";
|
||||||
|
};
|
||||||
|
|
||||||
|
if (loading) {
|
||||||
|
return (
|
||||||
|
<div className="flex items-center justify-center py-12">
|
||||||
|
<Spinner size="md" className="text-muted-foreground" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!memory) {
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col items-center justify-center py-16 text-center">
|
||||||
|
<h3 className="text-base font-medium text-foreground">What does SurfSense remember?</h3>
|
||||||
|
<p className="mt-2 max-w-sm text-sm text-muted-foreground">
|
||||||
|
Nothing yet. SurfSense picks up on your preferences and context as you chat.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-4">
|
||||||
|
<Alert className="bg-muted/50 py-3 md:py-4">
|
||||||
|
<Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
|
||||||
|
<AlertDescription className="text-xs md:text-sm">
|
||||||
|
<p>
|
||||||
|
SurfSense uses this personal memory to personalize your responses across all
|
||||||
|
conversations.
|
||||||
|
</p>
|
||||||
|
</AlertDescription>
|
||||||
|
</Alert>
|
||||||
|
|
||||||
|
<div className="relative h-[380px] rounded-lg border bg-background">
|
||||||
|
<div className="h-full overflow-y-auto scrollbar-thin">
|
||||||
|
<PlateEditor
|
||||||
|
markdown={displayMemory}
|
||||||
|
readOnly
|
||||||
|
preset="readonly"
|
||||||
|
variant="default"
|
||||||
|
editorVariant="none"
|
||||||
|
className="px-5 py-4 text-sm min-h-full"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{showInput ? (
|
||||||
|
<div className="absolute bottom-3 inset-x-3 z-10">
|
||||||
|
<div
|
||||||
|
ref={inputContainerRef}
|
||||||
|
className="relative flex h-[54px] items-center gap-2 rounded-[9999px] border bg-muted/60 backdrop-blur-sm pl-4 pr-1 shadow-sm"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
ref={textareaRef}
|
||||||
|
type="text"
|
||||||
|
value={editQuery}
|
||||||
|
onChange={(e) => setEditQuery(e.target.value)}
|
||||||
|
onKeyDown={handleKeyDown}
|
||||||
|
placeholder="Tell SurfSense what to remember or forget"
|
||||||
|
disabled={editing}
|
||||||
|
className="flex-1 bg-transparent text-sm outline-none placeholder:text-muted-foreground/70"
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
size="icon"
|
||||||
|
variant="ghost"
|
||||||
|
onClick={handleEdit}
|
||||||
|
disabled={editing || !editQuery.trim()}
|
||||||
|
className={`h-11 w-11 shrink-0 rounded-full ${
|
||||||
|
editing ? "" : "bg-muted-foreground/15 hover:bg-muted-foreground/20"
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{editing ? (
|
||||||
|
<Spinner size="sm" />
|
||||||
|
) : (
|
||||||
|
<ArrowUp className="!h-5 !w-5 text-foreground" strokeWidth={2.25} />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
size="icon"
|
||||||
|
variant="secondary"
|
||||||
|
onClick={openInput}
|
||||||
|
className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm"
|
||||||
|
>
|
||||||
|
<Pen className="!h-5 !w-5" />
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex items-center justify-between gap-2">
|
||||||
|
<span className={`text-xs shrink-0 ${getCounterColor()}`}>
|
||||||
|
{charCount.toLocaleString()} / {MEMORY_HARD_LIMIT.toLocaleString()}
|
||||||
|
<span className="hidden sm:inline"> characters</span>
|
||||||
|
<span className="sm:hidden"> chars</span>
|
||||||
|
{charCount > 15_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
|
||||||
|
{charCount > MEMORY_HARD_LIMIT && " - Exceeds limit"}
|
||||||
|
</span>
|
||||||
|
<div className="flex items-center gap-1.5 sm:gap-2">
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="destructive"
|
||||||
|
size="sm"
|
||||||
|
className="text-xs sm:text-sm"
|
||||||
|
onClick={handleClear}
|
||||||
|
disabled={saving || editing || !memory}
|
||||||
|
>
|
||||||
|
<span className="hidden sm:inline">Reset Memory</span>
|
||||||
|
<span className="sm:hidden">Reset</span>
|
||||||
|
</Button>
|
||||||
|
<DropdownMenu>
|
||||||
|
<DropdownMenuTrigger asChild>
|
||||||
|
<Button type="button" variant="secondary" size="sm" disabled={!memory}>
|
||||||
|
Export
|
||||||
|
<ChevronDown className="h-3 w-3 opacity-60" />
|
||||||
|
</Button>
|
||||||
|
</DropdownMenuTrigger>
|
||||||
|
<DropdownMenuContent align="end">
|
||||||
|
<DropdownMenuItem onClick={handleCopyMarkdown}>
|
||||||
|
<ClipboardCopy className="h-4 w-4 mr-2" />
|
||||||
|
Copy as Markdown
|
||||||
|
</DropdownMenuItem>
|
||||||
|
<DropdownMenuItem onClick={handleDownload}>
|
||||||
|
<Download className="h-4 w-4 mr-2" />
|
||||||
|
Download as Markdown
|
||||||
|
</DropdownMenuItem>
|
||||||
|
</DropdownMenuContent>
|
||||||
|
</DropdownMenu>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useQuery } from "@tanstack/react-query";
|
import { useQuery } from "@tanstack/react-query";
|
||||||
import { Receipt } from "lucide-react";
|
import { ReceiptText } from "lucide-react";
|
||||||
import { Badge } from "@/components/ui/badge";
|
import { Badge } from "@/components/ui/badge";
|
||||||
import { Spinner } from "@/components/ui/spinner";
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
import {
|
import {
|
||||||
|
|
@ -65,7 +65,7 @@ export function PurchaseHistoryContent() {
|
||||||
if (purchases.length === 0) {
|
if (purchases.length === 0) {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col items-center justify-center gap-2 py-16 text-center">
|
<div className="flex flex-col items-center justify-center gap-2 py-16 text-center">
|
||||||
<Receipt className="h-8 w-8 text-muted-foreground" />
|
<ReceiptText className="h-8 w-8 text-muted-foreground" />
|
||||||
<p className="text-sm font-medium">No purchases yet</p>
|
<p className="text-sm font-medium">No purchases yet</p>
|
||||||
<p className="text-xs text-muted-foreground">
|
<p className="text-xs text-muted-foreground">
|
||||||
Your page-pack purchases will appear here after checkout.
|
Your page-pack purchases will appear here after checkout.
|
||||||
|
|
|
||||||
|
|
@ -76,12 +76,8 @@ const GenerateImageToolUI = dynamic(
|
||||||
import("@/components/tool-ui/generate-image").then((m) => ({ default: m.GenerateImageToolUI })),
|
import("@/components/tool-ui/generate-image").then((m) => ({ default: m.GenerateImageToolUI })),
|
||||||
{ ssr: false }
|
{ ssr: false }
|
||||||
);
|
);
|
||||||
const SaveMemoryToolUI = dynamic(
|
const UpdateMemoryToolUI = dynamic(
|
||||||
() => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.SaveMemoryToolUI })),
|
() => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.UpdateMemoryToolUI })),
|
||||||
{ ssr: false }
|
|
||||||
);
|
|
||||||
const RecallMemoryToolUI = dynamic(
|
|
||||||
() => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.RecallMemoryToolUI })),
|
|
||||||
{ ssr: false }
|
{ ssr: false }
|
||||||
);
|
);
|
||||||
const SandboxExecuteToolUI = dynamic(
|
const SandboxExecuteToolUI = dynamic(
|
||||||
|
|
@ -386,8 +382,7 @@ const AssistantMessageInner: FC = () => {
|
||||||
generate_video_presentation: GenerateVideoPresentationToolUI,
|
generate_video_presentation: GenerateVideoPresentationToolUI,
|
||||||
display_image: GenerateImageToolUI,
|
display_image: GenerateImageToolUI,
|
||||||
generate_image: GenerateImageToolUI,
|
generate_image: GenerateImageToolUI,
|
||||||
save_memory: SaveMemoryToolUI,
|
update_memory: UpdateMemoryToolUI,
|
||||||
recall_memory: RecallMemoryToolUI,
|
|
||||||
execute: SandboxExecuteToolUI,
|
execute: SandboxExecuteToolUI,
|
||||||
create_notion_page: CreateNotionPageToolUI,
|
create_notion_page: CreateNotionPageToolUI,
|
||||||
update_notion_page: UpdateNotionPageToolUI,
|
update_notion_page: UpdateNotionPageToolUI,
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,7 @@ import {
|
||||||
DropdownMenuTrigger,
|
DropdownMenuTrigger,
|
||||||
} from "@/components/ui/dropdown-menu";
|
} from "@/components/ui/dropdown-menu";
|
||||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||||
|
import { Skeleton } from "@/components/ui/skeleton";
|
||||||
import { Switch } from "@/components/ui/switch";
|
import { Switch } from "@/components/ui/switch";
|
||||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||||
|
|
@ -92,7 +93,7 @@ import { useMediaQuery } from "@/hooks/use-media-query";
|
||||||
import { useElectronAPI } from "@/hooks/use-platform";
|
import { useElectronAPI } from "@/hooks/use-platform";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
const COMPOSER_PLACEHOLDER = "Ask anything · Type / for prompts · Type @ to mention docs";
|
const COMPOSER_PLACEHOLDER = "Ask anything, type / for prompts, type @ to mention docs";
|
||||||
|
|
||||||
export const Thread: FC = () => {
|
export const Thread: FC = () => {
|
||||||
return <ThreadContent />;
|
return <ThreadContent />;
|
||||||
|
|
@ -804,7 +805,7 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
||||||
const isDesktop = useMediaQuery("(min-width: 640px)");
|
const isDesktop = useMediaQuery("(min-width: 640px)");
|
||||||
const { openDialog: openUploadDialog } = useDocumentUploadDialog();
|
const { openDialog: openUploadDialog } = useDocumentUploadDialog();
|
||||||
const [toolsScrollPos, setToolsScrollPos] = useState<"top" | "middle" | "bottom">("top");
|
const [toolsScrollPos, setToolsScrollPos] = useState<"top" | "middle" | "bottom">("top");
|
||||||
const toolsRafRef = useRef<number>();
|
const toolsRafRef = useRef<number | undefined>(undefined);
|
||||||
const handleToolsScroll = useCallback((e: React.UIEvent<HTMLDivElement>) => {
|
const handleToolsScroll = useCallback((e: React.UIEvent<HTMLDivElement>) => {
|
||||||
const el = e.currentTarget;
|
const el = e.currentTarget;
|
||||||
if (toolsRafRef.current) return;
|
if (toolsRafRef.current) return;
|
||||||
|
|
@ -1021,8 +1022,23 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
{!filteredTools?.length && (
|
{!filteredTools?.length && (
|
||||||
<div className="px-4 py-6 text-center text-sm text-muted-foreground">
|
<div className="px-4 pt-3 pb-2">
|
||||||
Loading tools...
|
<Skeleton className="h-3 w-16 mb-2" />
|
||||||
|
{["t1", "t2", "t3", "t4"].map((k) => (
|
||||||
|
<div key={k} className="flex items-center gap-3 py-2">
|
||||||
|
<Skeleton className="size-4 rounded shrink-0" />
|
||||||
|
<Skeleton className="h-3.5 flex-1" />
|
||||||
|
<Skeleton className="h-5 w-9 rounded-full shrink-0" />
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
<Skeleton className="h-3 w-24 mt-3 mb-2" />
|
||||||
|
{["c1", "c2", "c3"].map((k) => (
|
||||||
|
<div key={k} className="flex items-center gap-3 py-2">
|
||||||
|
<Skeleton className="size-4 rounded shrink-0" />
|
||||||
|
<Skeleton className="h-3.5 flex-1" />
|
||||||
|
<Skeleton className="h-5 w-9 rounded-full shrink-0" />
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -1058,12 +1074,12 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
||||||
side="bottom"
|
side="bottom"
|
||||||
align="start"
|
align="start"
|
||||||
sideOffset={12}
|
sideOffset={12}
|
||||||
className="w-[calc(100vw-2rem)] max-w-56 sm:max-w-72 sm:w-72 p-0 select-none"
|
className="w-[calc(100vw-2rem)] max-w-48 sm:max-w-56 sm:w-56 p-0 select-none"
|
||||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||||
>
|
>
|
||||||
<div className="sr-only">Manage Tools</div>
|
<div className="sr-only">Manage Tools</div>
|
||||||
<div
|
<div
|
||||||
className="max-h-48 sm:max-h-64 overflow-y-auto overscroll-none py-0.5 sm:py-1"
|
className="max-h-44 sm:max-h-56 overflow-y-auto overscroll-none py-0.5"
|
||||||
onScroll={handleToolsScroll}
|
onScroll={handleToolsScroll}
|
||||||
style={{
|
style={{
|
||||||
maskImage: `linear-gradient(to bottom, ${toolsScrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${toolsScrollPos === "bottom" ? "black" : "transparent"})`,
|
maskImage: `linear-gradient(to bottom, ${toolsScrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${toolsScrollPos === "bottom" ? "black" : "transparent"})`,
|
||||||
|
|
@ -1074,22 +1090,22 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
||||||
.filter((g) => !g.connectorIcon)
|
.filter((g) => !g.connectorIcon)
|
||||||
.map((group) => (
|
.map((group) => (
|
||||||
<div key={group.label}>
|
<div key={group.label}>
|
||||||
<div className="px-2.5 sm:px-3 pt-2 pb-0.5 text-[10px] sm:text-xs text-muted-foreground/80 font-normal select-none">
|
<div className="px-2 sm:px-2.5 pt-1.5 pb-0.5 text-[9px] sm:text-[10px] text-muted-foreground/80 font-normal select-none">
|
||||||
{group.label}
|
{group.label}
|
||||||
</div>
|
</div>
|
||||||
{group.tools.map((tool) => {
|
{group.tools.map((tool) => {
|
||||||
const isDisabled = disabledToolsSet.has(tool.name);
|
const isDisabled = disabledToolsSet.has(tool.name);
|
||||||
const ToolIcon = getToolIcon(tool.name);
|
const ToolIcon = getToolIcon(tool.name);
|
||||||
const row = (
|
const row = (
|
||||||
<div className="flex w-full items-center gap-2 sm:gap-3 px-2.5 sm:px-3 py-1 sm:py-1.5 hover:bg-muted-foreground/10 transition-colors">
|
<div className="flex w-full items-center gap-1.5 sm:gap-2 px-2 sm:px-2.5 py-0.5 sm:py-1 hover:bg-muted-foreground/10 transition-colors">
|
||||||
<ToolIcon className="size-3.5 sm:size-4 shrink-0 text-muted-foreground" />
|
<ToolIcon className="size-3 sm:size-3.5 shrink-0 text-muted-foreground" />
|
||||||
<span className="flex-1 min-w-0 text-xs sm:text-sm font-medium truncate">
|
<span className="flex-1 min-w-0 text-[11px] sm:text-xs font-medium truncate">
|
||||||
{formatToolName(tool.name)}
|
{formatToolName(tool.name)}
|
||||||
</span>
|
</span>
|
||||||
<Switch
|
<Switch
|
||||||
checked={!isDisabled}
|
checked={!isDisabled}
|
||||||
onCheckedChange={() => toggleTool(tool.name)}
|
onCheckedChange={() => toggleTool(tool.name)}
|
||||||
className="shrink-0 scale-[0.6] sm:scale-75"
|
className="shrink-0 scale-50 sm:scale-[0.6]"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|
@ -1106,7 +1122,7 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
||||||
))}
|
))}
|
||||||
{groupedTools.some((g) => g.connectorIcon) && (
|
{groupedTools.some((g) => g.connectorIcon) && (
|
||||||
<div>
|
<div>
|
||||||
<div className="px-2.5 sm:px-3 pt-2 pb-0.5 text-[10px] sm:text-xs text-muted-foreground/80 font-normal select-none">
|
<div className="px-2 sm:px-2.5 pt-1.5 pb-0.5 text-[9px] sm:text-[10px] text-muted-foreground/80 font-normal select-none">
|
||||||
Connector Actions
|
Connector Actions
|
||||||
</div>
|
</div>
|
||||||
{groupedTools
|
{groupedTools
|
||||||
|
|
@ -1118,26 +1134,26 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
||||||
const allDisabled = toolNames.every((n) => disabledToolsSet.has(n));
|
const allDisabled = toolNames.every((n) => disabledToolsSet.has(n));
|
||||||
const groupDef = TOOL_GROUPS.find((g) => g.label === group.label);
|
const groupDef = TOOL_GROUPS.find((g) => g.label === group.label);
|
||||||
const row = (
|
const row = (
|
||||||
<div className="flex w-full items-center gap-2 sm:gap-3 px-2.5 sm:px-3 py-1 sm:py-1.5 hover:bg-muted-foreground/10 transition-colors">
|
<div className="flex w-full items-center gap-1.5 sm:gap-2 px-2 sm:px-2.5 py-0.5 sm:py-1 hover:bg-muted-foreground/10 transition-colors">
|
||||||
{iconInfo ? (
|
{iconInfo ? (
|
||||||
<Image
|
<Image
|
||||||
src={iconInfo.src}
|
src={iconInfo.src}
|
||||||
alt={iconInfo.alt}
|
alt={iconInfo.alt}
|
||||||
width={16}
|
width={14}
|
||||||
height={16}
|
height={14}
|
||||||
className="size-3.5 sm:size-4 shrink-0 select-none pointer-events-none"
|
className="size-3 sm:size-3.5 shrink-0 select-none pointer-events-none"
|
||||||
draggable={false}
|
draggable={false}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<Wrench className="size-3.5 sm:size-4 shrink-0 text-muted-foreground" />
|
<Wrench className="size-3 sm:size-3.5 shrink-0 text-muted-foreground" />
|
||||||
)}
|
)}
|
||||||
<span className="flex-1 min-w-0 text-xs sm:text-sm font-medium truncate">
|
<span className="flex-1 min-w-0 text-[11px] sm:text-xs font-medium truncate">
|
||||||
{group.label}
|
{group.label}
|
||||||
</span>
|
</span>
|
||||||
<Switch
|
<Switch
|
||||||
checked={!allDisabled}
|
checked={!allDisabled}
|
||||||
onCheckedChange={() => toggleToolGroup(toolNames)}
|
onCheckedChange={() => toggleToolGroup(toolNames)}
|
||||||
className="shrink-0 scale-[0.6] sm:scale-75"
|
className="shrink-0 scale-50 sm:scale-[0.6]"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|
@ -1158,8 +1174,23 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
{!filteredTools?.length && (
|
{!filteredTools?.length && (
|
||||||
<div className="px-3 py-4 text-center text-xs text-muted-foreground">
|
<div className="px-2 sm:px-2.5 pt-1.5 pb-1">
|
||||||
Loading tools...
|
<Skeleton className="h-2 w-12 mb-1.5" />
|
||||||
|
{["dt1", "dt2", "dt3", "dt4"].map((k) => (
|
||||||
|
<div key={k} className="flex items-center gap-1.5 sm:gap-2 py-0.5 sm:py-1">
|
||||||
|
<Skeleton className="size-3 sm:size-3.5 rounded shrink-0" />
|
||||||
|
<Skeleton className="h-2.5 sm:h-3 flex-1" />
|
||||||
|
<Skeleton className="h-3.5 sm:h-4 w-7 sm:w-8 rounded-full shrink-0" />
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
<Skeleton className="h-2 w-20 mt-2 mb-1.5" />
|
||||||
|
{["dc1", "dc2", "dc3"].map((k) => (
|
||||||
|
<div key={k} className="flex items-center gap-1.5 sm:gap-2 py-0.5 sm:py-1">
|
||||||
|
<Skeleton className="size-3 sm:size-3.5 rounded shrink-0" />
|
||||||
|
<Skeleton className="h-2.5 sm:h-3 flex-1" />
|
||||||
|
<Skeleton className="h-3.5 sm:h-4 w-7 sm:w-8 rounded-full shrink-0" />
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -1297,7 +1328,7 @@ const TOOL_GROUPS: ToolGroup[] = [
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
label: "Memory",
|
label: "Memory",
|
||||||
tools: ["save_memory", "recall_memory"],
|
tools: ["update_memory"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
label: "Gmail",
|
label: "Gmail",
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
import { MarkdownPlugin, remarkMdx } from "@platejs/markdown";
|
import { MarkdownPlugin, remarkMdx } from "@platejs/markdown";
|
||||||
import { slateToHtml } from "@slate-serializers/html";
|
import { slateToHtml } from "@slate-serializers/html";
|
||||||
import type { AnyPluginConfig, Descendant, Value } from "platejs";
|
import type { AnyPluginConfig, Descendant, Value } from "platejs";
|
||||||
import { createPlatePlugin, Key, Plate, usePlateEditor } from "platejs/react";
|
import { createPlatePlugin, Key, Plate, useEditorReadOnly, usePlateEditor } from "platejs/react";
|
||||||
import { useEffect, useMemo, useRef } from "react";
|
import { useEffect, useMemo, useRef } from "react";
|
||||||
import remarkGfm from "remark-gfm";
|
import remarkGfm from "remark-gfm";
|
||||||
import remarkMath from "remark-math";
|
import remarkMath from "remark-math";
|
||||||
|
|
@ -60,6 +60,24 @@ export interface PlateEditorProps {
|
||||||
extraPlugins?: AnyPluginConfig[];
|
extraPlugins?: AnyPluginConfig[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function PlateEditorContent({
|
||||||
|
editorVariant,
|
||||||
|
placeholder,
|
||||||
|
}: {
|
||||||
|
editorVariant: PlateEditorProps["editorVariant"];
|
||||||
|
placeholder?: string;
|
||||||
|
}) {
|
||||||
|
const isReadOnly = useEditorReadOnly();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Editor
|
||||||
|
variant={editorVariant}
|
||||||
|
placeholder={isReadOnly ? undefined : placeholder}
|
||||||
|
className="min-h-full"
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export function PlateEditor({
|
export function PlateEditor({
|
||||||
markdown,
|
markdown,
|
||||||
html,
|
html,
|
||||||
|
|
@ -188,7 +206,7 @@ export function PlateEditor({
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<EditorContainer variant={variant} className={className}>
|
<EditorContainer variant={variant} className={className}>
|
||||||
<Editor variant={editorVariant} placeholder={placeholder} />
|
<PlateEditorContent editorVariant={editorVariant} placeholder={placeholder} />
|
||||||
</EditorContainer>
|
</EditorContainer>
|
||||||
</Plate>
|
</Plate>
|
||||||
</EditorSaveContext.Provider>
|
</EditorSaveContext.Provider>
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import type { AnyPluginConfig } from "platejs";
|
import type { AnyPluginConfig } from "platejs";
|
||||||
|
import { TrailingBlockPlugin } from "platejs";
|
||||||
|
|
||||||
import { AutoformatKit } from "@/components/editor/plugins/autoformat-kit";
|
import { AutoformatKit } from "@/components/editor/plugins/autoformat-kit";
|
||||||
import { BasicNodesKit } from "@/components/editor/plugins/basic-nodes-kit";
|
import { BasicNodesKit } from "@/components/editor/plugins/basic-nodes-kit";
|
||||||
|
|
@ -36,6 +37,7 @@ export const fullPreset: AnyPluginConfig[] = [
|
||||||
...FloatingToolbarKit,
|
...FloatingToolbarKit,
|
||||||
...AutoformatKit,
|
...AutoformatKit,
|
||||||
...DndKit,
|
...DndKit,
|
||||||
|
TrailingBlockPlugin,
|
||||||
];
|
];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -48,8 +50,8 @@ export const minimalPreset: AnyPluginConfig[] = [
|
||||||
...ListKit,
|
...ListKit,
|
||||||
...CodeBlockKit,
|
...CodeBlockKit,
|
||||||
...LinkKit,
|
...LinkKit,
|
||||||
...FloatingToolbarKit,
|
|
||||||
...AutoformatKit,
|
...AutoformatKit,
|
||||||
|
TrailingBlockPlugin,
|
||||||
];
|
];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
import { ChevronDown, Download, Monitor } from "lucide-react";
|
import { ChevronDown, Download, Monitor } from "lucide-react";
|
||||||
import { AnimatePresence, motion } from "motion/react";
|
import { AnimatePresence, motion } from "motion/react";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from "react";
|
import React, { memo, useCallback, useEffect, useRef, useState } from "react";
|
||||||
import Balancer from "react-wrap-balancer";
|
import Balancer from "react-wrap-balancer";
|
||||||
import {
|
import {
|
||||||
DropdownMenu,
|
DropdownMenu,
|
||||||
|
|
@ -12,6 +12,11 @@ import {
|
||||||
} from "@/components/ui/dropdown-menu";
|
} from "@/components/ui/dropdown-menu";
|
||||||
import { ExpandedMediaOverlay, useExpandedMedia } from "@/components/ui/expanded-gif-overlay";
|
import { ExpandedMediaOverlay, useExpandedMedia } from "@/components/ui/expanded-gif-overlay";
|
||||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||||
|
import {
|
||||||
|
GITHUB_RELEASES_URL,
|
||||||
|
getAssetLabel,
|
||||||
|
usePrimaryDownload,
|
||||||
|
} from "@/lib/desktop-download-utils";
|
||||||
import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config";
|
import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config";
|
||||||
import { trackLoginAttempt } from "@/lib/posthog/events";
|
import { trackLoginAttempt } from "@/lib/posthog/events";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
@ -200,107 +205,8 @@ function GetStartedButton() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
type OSInfo = {
|
|
||||||
os: "macOS" | "Windows" | "Linux";
|
|
||||||
arch: "arm64" | "x64";
|
|
||||||
};
|
|
||||||
|
|
||||||
function useUserOS(): OSInfo {
|
|
||||||
const [info, setInfo] = useState<OSInfo>({ os: "macOS", arch: "arm64" });
|
|
||||||
useEffect(() => {
|
|
||||||
const ua = navigator.userAgent;
|
|
||||||
let os: OSInfo["os"] = "macOS";
|
|
||||||
let arch: OSInfo["arch"] = "x64";
|
|
||||||
|
|
||||||
if (/Windows/i.test(ua)) {
|
|
||||||
os = "Windows";
|
|
||||||
arch = "x64";
|
|
||||||
} else if (/Linux/i.test(ua)) {
|
|
||||||
os = "Linux";
|
|
||||||
arch = "x64";
|
|
||||||
} else {
|
|
||||||
os = "macOS";
|
|
||||||
arch = /Mac/.test(ua) && !/Intel/.test(ua) ? "arm64" : "arm64";
|
|
||||||
}
|
|
||||||
|
|
||||||
const uaData = (navigator as Navigator & { userAgentData?: { architecture?: string } })
|
|
||||||
.userAgentData;
|
|
||||||
if (uaData?.architecture === "arm") arch = "arm64";
|
|
||||||
else if (uaData?.architecture === "x86") arch = "x64";
|
|
||||||
|
|
||||||
setInfo({ os, arch });
|
|
||||||
}, []);
|
|
||||||
return info;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ReleaseAsset {
|
|
||||||
name: string;
|
|
||||||
url: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
function useLatestRelease() {
|
|
||||||
const [assets, setAssets] = useState<ReleaseAsset[]>([]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const controller = new AbortController();
|
|
||||||
fetch("https://api.github.com/repos/MODSetter/SurfSense/releases/latest", {
|
|
||||||
signal: controller.signal,
|
|
||||||
})
|
|
||||||
.then((r) => r.json())
|
|
||||||
.then((data) => {
|
|
||||||
if (data?.assets) {
|
|
||||||
setAssets(
|
|
||||||
data.assets
|
|
||||||
.filter((a: { name: string }) => /\.(exe|dmg|AppImage|deb)$/.test(a.name))
|
|
||||||
.map((a: { name: string; browser_download_url: string }) => ({
|
|
||||||
name: a.name,
|
|
||||||
url: a.browser_download_url,
|
|
||||||
}))
|
|
||||||
);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.catch(() => {});
|
|
||||||
return () => controller.abort();
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
return assets;
|
|
||||||
}
|
|
||||||
|
|
||||||
const ASSET_LABELS: Record<string, string> = {
|
|
||||||
".exe": "Windows (exe)",
|
|
||||||
"-arm64.dmg": "macOS Apple Silicon (dmg)",
|
|
||||||
"-x64.dmg": "macOS Intel (dmg)",
|
|
||||||
"-arm64.zip": "macOS Apple Silicon (zip)",
|
|
||||||
"-x64.zip": "macOS Intel (zip)",
|
|
||||||
".AppImage": "Linux (AppImage)",
|
|
||||||
".deb": "Linux (deb)",
|
|
||||||
};
|
|
||||||
|
|
||||||
function getAssetLabel(name: string): string {
|
|
||||||
for (const [suffix, label] of Object.entries(ASSET_LABELS)) {
|
|
||||||
if (name.endsWith(suffix)) return label;
|
|
||||||
}
|
|
||||||
return name;
|
|
||||||
}
|
|
||||||
|
|
||||||
function DownloadButton() {
|
function DownloadButton() {
|
||||||
const { os, arch } = useUserOS();
|
const { os, primary, alternatives } = usePrimaryDownload();
|
||||||
const assets = useLatestRelease();
|
|
||||||
|
|
||||||
const { primary, alternatives } = useMemo(() => {
|
|
||||||
if (assets.length === 0) return { primary: null, alternatives: [] };
|
|
||||||
|
|
||||||
const matchers: Record<string, (n: string) => boolean> = {
|
|
||||||
Windows: (n) => n.endsWith(".exe"),
|
|
||||||
macOS: (n) => n.endsWith(`-${arch}.dmg`),
|
|
||||||
Linux: (n) => n.endsWith(".AppImage"),
|
|
||||||
};
|
|
||||||
|
|
||||||
const match = matchers[os];
|
|
||||||
const primary = assets.find((a) => match(a.name)) ?? null;
|
|
||||||
const alternatives = assets.filter((a) => a !== primary);
|
|
||||||
return { primary, alternatives };
|
|
||||||
}, [assets, os, arch]);
|
|
||||||
|
|
||||||
const fallbackUrl = GITHUB_RELEASES_URL;
|
const fallbackUrl = GITHUB_RELEASES_URL;
|
||||||
|
|
||||||
|
|
@ -504,5 +410,3 @@ const TabVideo = memo(function TabVideo({ src }: { src: string }) {
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
const GITHUB_RELEASES_URL = "https://github.com/MODSetter/SurfSense/releases/latest";
|
|
||||||
|
|
|
||||||
|
|
@ -91,12 +91,12 @@ export function SidebarSlideOutPanel({
|
||||||
|
|
||||||
{/* Panel extending from sidebar's right edge, flush with the wrapper border */}
|
{/* Panel extending from sidebar's right edge, flush with the wrapper border */}
|
||||||
<motion.div
|
<motion.div
|
||||||
style={{ width, left: "100%", top: -1, bottom: -1 }}
|
initial={{ width: 0 }}
|
||||||
initial={{ x: -width }}
|
animate={{ width }}
|
||||||
animate={{ x: 0 }}
|
exit={{ width: 0 }}
|
||||||
exit={{ x: -width }}
|
|
||||||
transition={{ type: "tween", duration: 0.2, ease: [0.4, 0, 0.2, 1] }}
|
transition={{ type: "tween", duration: 0.2, ease: [0.4, 0, 0.2, 1] }}
|
||||||
className="absolute z-20 overflow-hidden"
|
className="absolute z-20 overflow-hidden"
|
||||||
|
style={{ left: "100%", top: -1, bottom: -1 }}
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
style={{ width }}
|
style={{ width }}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import {
|
import {
|
||||||
Check,
|
Check,
|
||||||
ChevronUp,
|
ChevronUp,
|
||||||
|
Download,
|
||||||
ExternalLink,
|
ExternalLink,
|
||||||
Info,
|
Info,
|
||||||
Languages,
|
Languages,
|
||||||
|
|
@ -29,6 +30,8 @@ import {
|
||||||
} from "@/components/ui/dropdown-menu";
|
} from "@/components/ui/dropdown-menu";
|
||||||
import { Spinner } from "@/components/ui/spinner";
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
import { useLocaleContext } from "@/contexts/LocaleContext";
|
import { useLocaleContext } from "@/contexts/LocaleContext";
|
||||||
|
import { usePlatform } from "@/hooks/use-platform";
|
||||||
|
import { GITHUB_RELEASES_URL, usePrimaryDownload } from "@/lib/desktop-download-utils";
|
||||||
import { APP_VERSION } from "@/lib/env-config";
|
import { APP_VERSION } from "@/lib/env-config";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import type { User } from "../../types/layout.types";
|
import type { User } from "../../types/layout.types";
|
||||||
|
|
@ -149,10 +152,13 @@ export function SidebarUserProfile({
|
||||||
}: SidebarUserProfileProps) {
|
}: SidebarUserProfileProps) {
|
||||||
const t = useTranslations("sidebar");
|
const t = useTranslations("sidebar");
|
||||||
const { locale, setLocale } = useLocaleContext();
|
const { locale, setLocale } = useLocaleContext();
|
||||||
|
const { isDesktop } = usePlatform();
|
||||||
|
const { os, primary } = usePrimaryDownload();
|
||||||
const [isLoggingOut, setIsLoggingOut] = useState(false);
|
const [isLoggingOut, setIsLoggingOut] = useState(false);
|
||||||
const bgColor = stringToColor(user.email);
|
const bgColor = stringToColor(user.email);
|
||||||
const initials = getInitials(user.email);
|
const initials = getInitials(user.email);
|
||||||
const displayName = user.name || user.email.split("@")[0];
|
const displayName = user.name || user.email.split("@")[0];
|
||||||
|
const downloadUrl = primary?.url ?? GITHUB_RELEASES_URL;
|
||||||
|
|
||||||
const handleLanguageChange = (newLocale: "en" | "es" | "pt" | "hi" | "zh") => {
|
const handleLanguageChange = (newLocale: "en" | "es" | "pt" | "hi" | "zh") => {
|
||||||
setLocale(newLocale);
|
setLocale(newLocale);
|
||||||
|
|
@ -294,6 +300,15 @@ export function SidebarUserProfile({
|
||||||
</DropdownMenuPortal>
|
</DropdownMenuPortal>
|
||||||
</DropdownMenuSub>
|
</DropdownMenuSub>
|
||||||
|
|
||||||
|
{!isDesktop && (
|
||||||
|
<DropdownMenuItem asChild className="font-medium">
|
||||||
|
<a href={downloadUrl} target="_blank" rel="noopener noreferrer">
|
||||||
|
<Download className="h-4 w-4" strokeWidth={2.5} />
|
||||||
|
{t("download_for_os", { os })}
|
||||||
|
</a>
|
||||||
|
</DropdownMenuItem>
|
||||||
|
)}
|
||||||
|
|
||||||
<DropdownMenuSeparator className="dark:bg-neutral-700" />
|
<DropdownMenuSeparator className="dark:bg-neutral-700" />
|
||||||
|
|
||||||
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
|
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
|
||||||
|
|
@ -439,6 +454,15 @@ export function SidebarUserProfile({
|
||||||
</DropdownMenuPortal>
|
</DropdownMenuPortal>
|
||||||
</DropdownMenuSub>
|
</DropdownMenuSub>
|
||||||
|
|
||||||
|
{!isDesktop && (
|
||||||
|
<DropdownMenuItem asChild className="font-medium">
|
||||||
|
<a href={downloadUrl} target="_blank" rel="noopener noreferrer">
|
||||||
|
<Download className="h-4 w-4" strokeWidth={2.5} />
|
||||||
|
{t("download_for_os", { os })}
|
||||||
|
</a>
|
||||||
|
</DropdownMenuItem>
|
||||||
|
)}
|
||||||
|
|
||||||
<DropdownMenuSeparator className="dark:bg-neutral-700" />
|
<DropdownMenuSeparator className="dark:bg-neutral-700" />
|
||||||
|
|
||||||
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
|
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
|
||||||
|
|
|
||||||
|
|
@ -18,14 +18,39 @@ import {
|
||||||
DropdownMenuItem,
|
DropdownMenuItem,
|
||||||
DropdownMenuTrigger,
|
DropdownMenuTrigger,
|
||||||
} from "@/components/ui/dropdown-menu";
|
} from "@/components/ui/dropdown-menu";
|
||||||
import { Skeleton } from "@/components/ui/skeleton";
|
|
||||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||||
|
|
||||||
|
function ReportPanelSkeleton() {
|
||||||
|
return (
|
||||||
|
<div className="space-y-6 p-6">
|
||||||
|
<div className="h-6 w-3/4 rounded-md bg-muted/60 animate-pulse" />
|
||||||
|
<div className="space-y-2.5">
|
||||||
|
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse" />
|
||||||
|
<div className="h-3 w-[95%] rounded-md bg-muted/60 animate-pulse [animation-delay:100ms]" />
|
||||||
|
<div className="h-3 w-[88%] rounded-md bg-muted/60 animate-pulse [animation-delay:200ms]" />
|
||||||
|
<div className="h-3 w-[60%] rounded-md bg-muted/60 animate-pulse [animation-delay:300ms]" />
|
||||||
|
</div>
|
||||||
|
<div className="h-5 w-2/5 rounded-md bg-muted/60 animate-pulse [animation-delay:400ms]" />
|
||||||
|
<div className="space-y-2.5">
|
||||||
|
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:500ms]" />
|
||||||
|
<div className="h-3 w-[92%] rounded-md bg-muted/60 animate-pulse [animation-delay:600ms]" />
|
||||||
|
<div className="h-3 w-[97%] rounded-md bg-muted/60 animate-pulse [animation-delay:700ms]" />
|
||||||
|
</div>
|
||||||
|
<div className="h-5 w-1/3 rounded-md bg-muted/60 animate-pulse [animation-delay:800ms]" />
|
||||||
|
<div className="space-y-2.5">
|
||||||
|
<div className="h-3 w-[90%] rounded-md bg-muted/60 animate-pulse [animation-delay:900ms]" />
|
||||||
|
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:1000ms]" />
|
||||||
|
<div className="h-3 w-[75%] rounded-md bg-muted/60 animate-pulse [animation-delay:1100ms]" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
const PlateEditor = dynamic(
|
const PlateEditor = dynamic(
|
||||||
() => import("@/components/editor/plate-editor").then((m) => ({ default: m.PlateEditor })),
|
() => import("@/components/editor/plate-editor").then((m) => ({ default: m.PlateEditor })),
|
||||||
{ ssr: false, loading: () => <Skeleton className="h-64 w-full" /> }
|
{ ssr: false, loading: () => <ReportPanelSkeleton /> }
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -59,46 +84,6 @@ const ReportContentResponseSchema = z.object({
|
||||||
type ReportContentResponse = z.infer<typeof ReportContentResponseSchema>;
|
type ReportContentResponse = z.infer<typeof ReportContentResponseSchema>;
|
||||||
type VersionInfo = z.infer<typeof VersionInfoSchema>;
|
type VersionInfo = z.infer<typeof VersionInfoSchema>;
|
||||||
|
|
||||||
/**
|
|
||||||
* Shimmer loading skeleton for report panel
|
|
||||||
*/
|
|
||||||
function ReportPanelSkeleton() {
|
|
||||||
return (
|
|
||||||
<div className="space-y-6 p-6">
|
|
||||||
{/* Title skeleton */}
|
|
||||||
<div className="h-6 w-3/4 rounded-md bg-muted/60 animate-pulse" />
|
|
||||||
|
|
||||||
{/* Paragraph 1 */}
|
|
||||||
<div className="space-y-2.5">
|
|
||||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse" />
|
|
||||||
<div className="h-3 w-[95%] rounded-md bg-muted/60 animate-pulse [animation-delay:100ms]" />
|
|
||||||
<div className="h-3 w-[88%] rounded-md bg-muted/60 animate-pulse [animation-delay:200ms]" />
|
|
||||||
<div className="h-3 w-[60%] rounded-md bg-muted/60 animate-pulse [animation-delay:300ms]" />
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Heading */}
|
|
||||||
<div className="h-5 w-2/5 rounded-md bg-muted/60 animate-pulse [animation-delay:400ms]" />
|
|
||||||
|
|
||||||
{/* Paragraph 2 */}
|
|
||||||
<div className="space-y-2.5">
|
|
||||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:500ms]" />
|
|
||||||
<div className="h-3 w-[92%] rounded-md bg-muted/60 animate-pulse [animation-delay:600ms]" />
|
|
||||||
<div className="h-3 w-[97%] rounded-md bg-muted/60 animate-pulse [animation-delay:700ms]" />
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Heading */}
|
|
||||||
<div className="h-5 w-1/3 rounded-md bg-muted/60 animate-pulse [animation-delay:800ms]" />
|
|
||||||
|
|
||||||
{/* Paragraph 3 */}
|
|
||||||
<div className="space-y-2.5">
|
|
||||||
<div className="h-3 w-[90%] rounded-md bg-muted/60 animate-pulse [animation-delay:900ms]" />
|
|
||||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:1000ms]" />
|
|
||||||
<div className="h-3 w-[75%] rounded-md bg-muted/60 animate-pulse [animation-delay:1100ms]" />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inner content component used by desktop panel, mobile drawer, and the layout right panel
|
* Inner content component used by desktop panel, mobile drawer, and the layout right panel
|
||||||
*/
|
*/
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
||||||
? "model"
|
? "model"
|
||||||
: "models"}
|
: "models"}
|
||||||
</span>{" "}
|
</span>{" "}
|
||||||
available from your administrator. Use the model selector to view and select them.
|
available from your administrator.
|
||||||
</p>
|
</p>
|
||||||
</AlertDescription>
|
</AlertDescription>
|
||||||
</Alert>
|
</Alert>
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ import {
|
||||||
FileText,
|
FileText,
|
||||||
ImageIcon,
|
ImageIcon,
|
||||||
RefreshCw,
|
RefreshCw,
|
||||||
Shuffle,
|
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
|
|
@ -44,7 +43,6 @@ import {
|
||||||
} from "@/components/ui/select";
|
} from "@/components/ui/select";
|
||||||
import { Skeleton } from "@/components/ui/skeleton";
|
import { Skeleton } from "@/components/ui/skeleton";
|
||||||
import { Spinner } from "@/components/ui/spinner";
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
import { getProviderIcon } from "@/lib/provider-icons";
|
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
const ROLE_DESCRIPTIONS = {
|
const ROLE_DESCRIPTIONS = {
|
||||||
|
|
@ -79,8 +77,8 @@ const ROLE_DESCRIPTIONS = {
|
||||||
icon: Eye,
|
icon: Eye,
|
||||||
title: "Vision LLM",
|
title: "Vision LLM",
|
||||||
description: "Vision-capable model for screenshot analysis and context extraction",
|
description: "Vision-capable model for screenshot analysis and context extraction",
|
||||||
color: "text-amber-600 dark:text-amber-400",
|
color: "text-muted-foreground",
|
||||||
bgColor: "bg-amber-500/10",
|
bgColor: "bg-muted",
|
||||||
prefKey: "vision_llm_config_id" as const,
|
prefKey: "vision_llm_config_id" as const,
|
||||||
configType: "vision" as const,
|
configType: "vision" as const,
|
||||||
},
|
},
|
||||||
|
|
@ -205,11 +203,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
const isAssignmentComplete =
|
|
||||||
allLLMConfigs.some((c) => c.id === assignments.agent_llm_id) &&
|
|
||||||
allLLMConfigs.some((c) => c.id === assignments.document_summary_llm_id) &&
|
|
||||||
allImageConfigs.some((c) => c.id === assignments.image_generation_config_id);
|
|
||||||
|
|
||||||
const isLoading =
|
const isLoading =
|
||||||
configsLoading ||
|
configsLoading ||
|
||||||
preferencesLoading ||
|
preferencesLoading ||
|
||||||
|
|
@ -231,7 +224,7 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
return (
|
return (
|
||||||
<div className="space-y-5 md:space-y-6">
|
<div className="space-y-5 md:space-y-6">
|
||||||
{/* Header actions */}
|
{/* Header actions */}
|
||||||
<div className="flex items-center justify-between">
|
<div className="flex items-center justify-start">
|
||||||
<Button
|
<Button
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
size="sm"
|
size="sm"
|
||||||
|
|
@ -239,15 +232,9 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
disabled={isLoading}
|
disabled={isLoading}
|
||||||
className="gap-2"
|
className="gap-2"
|
||||||
>
|
>
|
||||||
<RefreshCw className="h-3.5 w-3.5" />
|
<RefreshCw className={cn("h-3.5 w-3.5", isLoading && "animate-spin")} />
|
||||||
Refresh
|
Refresh
|
||||||
</Button>
|
</Button>
|
||||||
{isAssignmentComplete && !isLoading && !hasError && (
|
|
||||||
<Badge variant="outline" className="text-xs gap-1.5 text-muted-foreground">
|
|
||||||
<CircleCheck className="h-3 w-3" />
|
|
||||||
All roles assigned
|
|
||||||
</Badge>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Error Alert */}
|
{/* Error Alert */}
|
||||||
|
|
@ -343,8 +330,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
|
|
||||||
const assignedConfig = roleAllConfigs.find((config) => config.id === currentAssignment);
|
const assignedConfig = roleAllConfigs.find((config) => config.id === currentAssignment);
|
||||||
const isAssigned = !!assignedConfig;
|
const isAssigned = !!assignedConfig;
|
||||||
const isAutoMode =
|
|
||||||
assignedConfig && "is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode;
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div key={key}>
|
<div key={key}>
|
||||||
|
|
@ -389,7 +374,7 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
<SelectTrigger className="w-full h-9 md:h-10 text-xs md:text-sm">
|
<SelectTrigger className="w-full h-9 md:h-10 text-xs md:text-sm">
|
||||||
<SelectValue placeholder="Select a configuration" />
|
<SelectValue placeholder="Select a configuration" />
|
||||||
</SelectTrigger>
|
</SelectTrigger>
|
||||||
<SelectContent className="max-w-[calc(100vw-2rem)]">
|
<SelectContent className="max-w-[calc(100vw-2rem)] select-none">
|
||||||
<SelectItem
|
<SelectItem
|
||||||
value="unassigned"
|
value="unassigned"
|
||||||
className="text-xs md:text-sm py-1.5 md:py-2"
|
className="text-xs md:text-sm py-1.5 md:py-2"
|
||||||
|
|
@ -412,21 +397,9 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
className="text-xs md:text-sm py-1.5 md:py-2"
|
className="text-xs md:text-sm py-1.5 md:py-2"
|
||||||
>
|
>
|
||||||
<div className="flex items-center gap-1 md:gap-1.5 flex-wrap min-w-0">
|
<div className="flex items-center gap-1 md:gap-1.5 flex-wrap min-w-0">
|
||||||
{isAuto ? (
|
|
||||||
<Shuffle className="size-3 md:size-3.5 shrink-0 text-muted-foreground" />
|
|
||||||
) : (
|
|
||||||
getProviderIcon(config.provider, {
|
|
||||||
className: "size-3 md:size-3.5 shrink-0",
|
|
||||||
})
|
|
||||||
)}
|
|
||||||
<span className="truncate text-xs md:text-sm">
|
<span className="truncate text-xs md:text-sm">
|
||||||
{config.name}
|
{config.name}
|
||||||
</span>
|
</span>
|
||||||
{!isAuto && (
|
|
||||||
<span className="text-muted-foreground text-[10px] md:text-[11px] truncate">
|
|
||||||
({config.model_name})
|
|
||||||
</span>
|
|
||||||
)}
|
|
||||||
{isAuto && (
|
{isAuto && (
|
||||||
<Badge
|
<Badge
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
|
|
@ -455,15 +428,9 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
className="text-xs md:text-sm py-1.5 md:py-2"
|
className="text-xs md:text-sm py-1.5 md:py-2"
|
||||||
>
|
>
|
||||||
<div className="flex items-center gap-1 md:gap-1.5 flex-wrap min-w-0">
|
<div className="flex items-center gap-1 md:gap-1.5 flex-wrap min-w-0">
|
||||||
{getProviderIcon(config.provider, {
|
|
||||||
className: "size-3 md:size-3.5 shrink-0",
|
|
||||||
})}
|
|
||||||
<span className="truncate text-xs md:text-sm">
|
<span className="truncate text-xs md:text-sm">
|
||||||
{config.name}
|
{config.name}
|
||||||
</span>
|
</span>
|
||||||
<span className="text-muted-foreground text-[10px] md:text-[11px] truncate">
|
|
||||||
({config.model_name})
|
|
||||||
</span>
|
|
||||||
</div>
|
</div>
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
))}
|
))}
|
||||||
|
|
@ -472,63 +439,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
</SelectContent>
|
</SelectContent>
|
||||||
</Select>
|
</Select>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Assigned Config Summary */}
|
|
||||||
{assignedConfig && (
|
|
||||||
<div
|
|
||||||
className={cn(
|
|
||||||
"rounded-lg p-3 border",
|
|
||||||
isAutoMode
|
|
||||||
? "bg-violet-50 dark:bg-violet-900/10 border-violet-200/50 dark:border-violet-800/30"
|
|
||||||
: "bg-muted/40 border-border/50"
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
{isAutoMode ? (
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<Shuffle
|
|
||||||
className={cn(
|
|
||||||
"w-3.5 h-3.5 shrink-0 text-violet-600 dark:text-violet-400"
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
<div className="min-w-0">
|
|
||||||
<p className="text-xs font-medium text-violet-700 dark:text-violet-300">
|
|
||||||
Auto Mode
|
|
||||||
</p>
|
|
||||||
<p className="text-[10px] text-violet-600/70 dark:text-violet-400/70 mt-0.5">
|
|
||||||
Routes across all available providers
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<div className="flex items-start gap-2">
|
|
||||||
<IconComponent className="w-3.5 h-3.5 shrink-0 mt-0.5 text-muted-foreground" />
|
|
||||||
<div className="min-w-0 flex-1">
|
|
||||||
<div className="flex items-center gap-1.5 flex-wrap">
|
|
||||||
<span className="text-xs font-medium">{assignedConfig.name}</span>
|
|
||||||
{"is_global" in assignedConfig && assignedConfig.is_global && (
|
|
||||||
<Badge variant="secondary" className="text-[9px] px-1.5 py-0">
|
|
||||||
🌐 Global
|
|
||||||
</Badge>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
<div className="flex items-center gap-1.5 mt-1">
|
|
||||||
{getProviderIcon(assignedConfig.provider, {
|
|
||||||
className: "size-3 shrink-0",
|
|
||||||
})}
|
|
||||||
<code className="text-[10px] text-muted-foreground font-mono truncate">
|
|
||||||
{assignedConfig.model_name}
|
|
||||||
</code>
|
|
||||||
</div>
|
|
||||||
{assignedConfig.api_base && (
|
|
||||||
<p className="text-[10px] text-muted-foreground/60 mt-1 truncate">
|
|
||||||
{assignedConfig.api_base}
|
|
||||||
</p>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -196,7 +196,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
||||||
<span className="font-medium">
|
<span className="font-medium">
|
||||||
{globalConfigs.length} global {globalConfigs.length === 1 ? "model" : "models"}
|
{globalConfigs.length} global {globalConfigs.length === 1 ? "model" : "models"}
|
||||||
</span>{" "}
|
</span>{" "}
|
||||||
available from your administrator. Use the model selector to view and select them.
|
available from your administrator.
|
||||||
</p>
|
</p>
|
||||||
</AlertDescription>
|
</AlertDescription>
|
||||||
</Alert>
|
</Alert>
|
||||||
|
|
|
||||||
|
|
@ -113,11 +113,11 @@ export function MorePagesContent() {
|
||||||
{isLoading ? (
|
{isLoading ? (
|
||||||
<Card>
|
<Card>
|
||||||
<CardContent className="flex items-center gap-3 p-3">
|
<CardContent className="flex items-center gap-3 p-3">
|
||||||
<Skeleton className="h-8 w-8 rounded-full bg-muted" />
|
<Skeleton className="h-8 w-8 rounded-full" />
|
||||||
<div className="flex-1 space-y-2">
|
<div className="flex-1 space-y-2">
|
||||||
<Skeleton className="h-4 w-3/4 bg-muted" />
|
<Skeleton className="h-4 w-3/4" />
|
||||||
</div>
|
</div>
|
||||||
<Skeleton className="h-8 w-16 bg-muted" />
|
<Skeleton className="h-8 w-16" />
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
) : (
|
) : (
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,17 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useAtom } from "jotai";
|
import { useAtom } from "jotai";
|
||||||
import { Bot, Brain, Eye, FileText, Globe, ImageIcon, MessageSquare, Shield } from "lucide-react";
|
import {
|
||||||
|
BookText,
|
||||||
|
Bot,
|
||||||
|
Brain,
|
||||||
|
CircleUser,
|
||||||
|
Earth,
|
||||||
|
Eye,
|
||||||
|
ImageIcon,
|
||||||
|
ListChecks,
|
||||||
|
UserKey,
|
||||||
|
} from "lucide-react";
|
||||||
import dynamic from "next/dynamic";
|
import dynamic from "next/dynamic";
|
||||||
import { useTranslations } from "next-intl";
|
import { useTranslations } from "next-intl";
|
||||||
import type React from "react";
|
import type React from "react";
|
||||||
|
|
@ -59,6 +69,13 @@ const PublicChatSnapshotsManager = dynamic(
|
||||||
})),
|
})),
|
||||||
{ ssr: false }
|
{ ssr: false }
|
||||||
);
|
);
|
||||||
|
const TeamMemoryManager = dynamic(
|
||||||
|
() =>
|
||||||
|
import("@/components/settings/team-memory-manager").then((m) => ({
|
||||||
|
default: m.TeamMemoryManager,
|
||||||
|
})),
|
||||||
|
{ ssr: false }
|
||||||
|
);
|
||||||
|
|
||||||
interface SearchSpaceSettingsDialogProps {
|
interface SearchSpaceSettingsDialogProps {
|
||||||
searchSpaceId: number;
|
searchSpaceId: number;
|
||||||
|
|
@ -69,9 +86,9 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
||||||
const [state, setState] = useAtom(searchSpaceSettingsDialogAtom);
|
const [state, setState] = useAtom(searchSpaceSettingsDialogAtom);
|
||||||
|
|
||||||
const navItems = [
|
const navItems = [
|
||||||
{ value: "general", label: t("nav_general"), icon: <FileText className="h-4 w-4" /> },
|
{ value: "general", label: t("nav_general"), icon: <CircleUser className="h-4 w-4" /> },
|
||||||
|
{ value: "roles", label: t("nav_role_assignments"), icon: <ListChecks className="h-4 w-4" /> },
|
||||||
{ value: "models", label: t("nav_agent_configs"), icon: <Bot className="h-4 w-4" /> },
|
{ value: "models", label: t("nav_agent_configs"), icon: <Bot className="h-4 w-4" /> },
|
||||||
{ value: "roles", label: t("nav_role_assignments"), icon: <Brain className="h-4 w-4" /> },
|
|
||||||
{
|
{
|
||||||
value: "image-models",
|
value: "image-models",
|
||||||
label: t("nav_image_models"),
|
label: t("nav_image_models"),
|
||||||
|
|
@ -82,13 +99,18 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
||||||
label: t("nav_vision_models"),
|
label: t("nav_vision_models"),
|
||||||
icon: <Eye className="h-4 w-4" />,
|
icon: <Eye className="h-4 w-4" />,
|
||||||
},
|
},
|
||||||
{ value: "team-roles", label: t("nav_team_roles"), icon: <Shield className="h-4 w-4" /> },
|
{ value: "team-roles", label: t("nav_team_roles"), icon: <UserKey className="h-4 w-4" /> },
|
||||||
{
|
{
|
||||||
value: "prompts",
|
value: "prompts",
|
||||||
label: t("nav_system_instructions"),
|
label: t("nav_system_instructions"),
|
||||||
icon: <MessageSquare className="h-4 w-4" />,
|
icon: <BookText className="h-4 w-4" />,
|
||||||
},
|
},
|
||||||
{ value: "public-links", label: t("nav_public_links"), icon: <Globe className="h-4 w-4" /> },
|
{
|
||||||
|
value: "team-memory",
|
||||||
|
label: "Team Memory",
|
||||||
|
icon: <Brain className="h-4 w-4" />,
|
||||||
|
},
|
||||||
|
{ value: "public-links", label: t("nav_public_links"), icon: <Earth className="h-4 w-4" /> },
|
||||||
];
|
];
|
||||||
|
|
||||||
const content: Record<string, React.ReactNode> = {
|
const content: Record<string, React.ReactNode> = {
|
||||||
|
|
@ -99,6 +121,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
||||||
"vision-models": <VisionModelManager searchSpaceId={searchSpaceId} />,
|
"vision-models": <VisionModelManager searchSpaceId={searchSpaceId} />,
|
||||||
"team-roles": <RolesManager searchSpaceId={searchSpaceId} />,
|
"team-roles": <RolesManager searchSpaceId={searchSpaceId} />,
|
||||||
prompts: <PromptConfigManager searchSpaceId={searchSpaceId} />,
|
prompts: <PromptConfigManager searchSpaceId={searchSpaceId} />,
|
||||||
|
"team-memory": <TeamMemoryManager searchSpaceId={searchSpaceId} />,
|
||||||
"public-links": <PublicChatSnapshotsManager searchSpaceId={searchSpaceId} />,
|
"public-links": <PublicChatSnapshotsManager searchSpaceId={searchSpaceId} />,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
297
surfsense_web/components/settings/team-memory-manager.tsx
Normal file
297
surfsense_web/components/settings/team-memory-manager.tsx
Normal file
|
|
@ -0,0 +1,297 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
||||||
|
import { useAtomValue } from "jotai";
|
||||||
|
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pen } from "lucide-react";
|
||||||
|
import { useEffect, useRef, useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import { z } from "zod";
|
||||||
|
import { updateSearchSpaceMutationAtom } from "@/atoms/search-spaces/search-space-mutation.atoms";
|
||||||
|
import { PlateEditor } from "@/components/editor/plate-editor";
|
||||||
|
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
DropdownMenu,
|
||||||
|
DropdownMenuContent,
|
||||||
|
DropdownMenuItem,
|
||||||
|
DropdownMenuTrigger,
|
||||||
|
} from "@/components/ui/dropdown-menu";
|
||||||
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
|
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||||
|
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||||
|
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||||
|
|
||||||
|
const MEMORY_HARD_LIMIT = 25_000;
|
||||||
|
|
||||||
|
const SearchSpaceSchema = z
|
||||||
|
.object({
|
||||||
|
shared_memory_md: z.string().optional().default(""),
|
||||||
|
})
|
||||||
|
.passthrough();
|
||||||
|
|
||||||
|
interface TeamMemoryManagerProps {
|
||||||
|
searchSpaceId: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
const { data: searchSpace, isLoading: loading } = useQuery({
|
||||||
|
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()),
|
||||||
|
queryFn: () => searchSpacesApiService.getSearchSpace({ id: searchSpaceId }),
|
||||||
|
enabled: !!searchSpaceId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { mutateAsync: updateSearchSpace } = useAtomValue(updateSearchSpaceMutationAtom);
|
||||||
|
|
||||||
|
const [saving, setSaving] = useState(false);
|
||||||
|
const [editQuery, setEditQuery] = useState("");
|
||||||
|
const [editing, setEditing] = useState(false);
|
||||||
|
const [showInput, setShowInput] = useState(false);
|
||||||
|
const textareaRef = useRef<HTMLInputElement>(null);
|
||||||
|
const inputContainerRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
||||||
|
const memory = searchSpace?.shared_memory_md || "";
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!showInput) return;
|
||||||
|
|
||||||
|
const handlePointerDownOutside = (event: MouseEvent | TouchEvent) => {
|
||||||
|
const target = event.target;
|
||||||
|
if (!(target instanceof Node)) return;
|
||||||
|
if (inputContainerRef.current?.contains(target)) return;
|
||||||
|
|
||||||
|
setShowInput(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
document.addEventListener("mousedown", handlePointerDownOutside);
|
||||||
|
document.addEventListener("touchstart", handlePointerDownOutside, { passive: true });
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
document.removeEventListener("mousedown", handlePointerDownOutside);
|
||||||
|
document.removeEventListener("touchstart", handlePointerDownOutside);
|
||||||
|
};
|
||||||
|
}, [showInput]);
|
||||||
|
|
||||||
|
const handleClear = async () => {
|
||||||
|
try {
|
||||||
|
setSaving(true);
|
||||||
|
await updateSearchSpace({
|
||||||
|
id: searchSpaceId,
|
||||||
|
data: { shared_memory_md: "" },
|
||||||
|
});
|
||||||
|
toast.success("Team memory cleared");
|
||||||
|
} catch {
|
||||||
|
toast.error("Failed to clear team memory");
|
||||||
|
} finally {
|
||||||
|
setSaving(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleEdit = async () => {
|
||||||
|
const query = editQuery.trim();
|
||||||
|
if (!query) return;
|
||||||
|
|
||||||
|
try {
|
||||||
|
setEditing(true);
|
||||||
|
await baseApiService.post(
|
||||||
|
`/api/v1/searchspaces/${searchSpaceId}/memory/edit`,
|
||||||
|
SearchSpaceSchema,
|
||||||
|
{ body: { query } }
|
||||||
|
);
|
||||||
|
setEditQuery("");
|
||||||
|
setShowInput(false);
|
||||||
|
await queryClient.invalidateQueries({
|
||||||
|
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()),
|
||||||
|
});
|
||||||
|
toast.success("Team memory updated");
|
||||||
|
} catch {
|
||||||
|
toast.error("Failed to edit team memory");
|
||||||
|
} finally {
|
||||||
|
setEditing(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const openInput = () => {
|
||||||
|
setShowInput(true);
|
||||||
|
requestAnimationFrame(() => textareaRef.current?.focus());
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDownload = () => {
|
||||||
|
if (!memory) return;
|
||||||
|
try {
|
||||||
|
const blob = new Blob([memory], { type: "text/markdown;charset=utf-8" });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = document.createElement("a");
|
||||||
|
a.href = url;
|
||||||
|
a.download = "team-memory.md";
|
||||||
|
document.body.appendChild(a);
|
||||||
|
a.click();
|
||||||
|
document.body.removeChild(a);
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
} catch {
|
||||||
|
toast.error("Failed to download team memory");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleCopyMarkdown = async () => {
|
||||||
|
if (!memory) return;
|
||||||
|
try {
|
||||||
|
await navigator.clipboard.writeText(memory);
|
||||||
|
toast.success("Copied to clipboard");
|
||||||
|
} catch {
|
||||||
|
toast.error("Failed to copy team memory");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||||
|
if (e.key === "Enter" && !e.shiftKey) {
|
||||||
|
e.preventDefault();
|
||||||
|
handleEdit();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const displayMemory = memory.replace(/\(\d{4}-\d{2}-\d{2}\)\s*\[(fact|pref|instr)\]\s*/g, "");
|
||||||
|
const charCount = memory.length;
|
||||||
|
|
||||||
|
const getCounterColor = () => {
|
||||||
|
if (charCount > MEMORY_HARD_LIMIT) return "text-red-500";
|
||||||
|
if (charCount > 15_000) return "text-orange-500";
|
||||||
|
if (charCount > 10_000) return "text-yellow-500";
|
||||||
|
return "text-muted-foreground";
|
||||||
|
};
|
||||||
|
|
||||||
|
if (loading) {
|
||||||
|
return (
|
||||||
|
<div className="flex items-center justify-center py-12">
|
||||||
|
<Spinner size="md" className="text-muted-foreground" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!memory) {
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col items-center justify-center py-16 text-center">
|
||||||
|
<h3 className="text-base font-medium text-foreground">
|
||||||
|
What does SurfSense remember about your team?
|
||||||
|
</h3>
|
||||||
|
<p className="mt-2 max-w-sm text-sm text-muted-foreground">
|
||||||
|
Nothing yet. SurfSense picks up on team decisions and conventions as your team chats.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-4">
|
||||||
|
<Alert className="bg-muted/50 py-3 md:py-4">
|
||||||
|
<Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
|
||||||
|
<AlertDescription className="text-xs md:text-sm">
|
||||||
|
<p>
|
||||||
|
SurfSense uses this shared memory to provide team-wide context across all conversations
|
||||||
|
in this search space.
|
||||||
|
</p>
|
||||||
|
</AlertDescription>
|
||||||
|
</Alert>
|
||||||
|
|
||||||
|
<div className="relative h-[380px] rounded-lg border bg-background">
|
||||||
|
<div className="h-full overflow-y-auto scrollbar-thin">
|
||||||
|
<PlateEditor
|
||||||
|
markdown={displayMemory}
|
||||||
|
readOnly
|
||||||
|
preset="readonly"
|
||||||
|
variant="default"
|
||||||
|
editorVariant="none"
|
||||||
|
className="px-5 py-4 text-sm min-h-full"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{showInput ? (
|
||||||
|
<div className="absolute bottom-3 inset-x-3 z-10">
|
||||||
|
<div
|
||||||
|
ref={inputContainerRef}
|
||||||
|
className="relative flex h-[54px] items-center gap-2 rounded-[9999px] border bg-muted/60 backdrop-blur-sm pl-4 pr-1 shadow-sm"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
ref={textareaRef}
|
||||||
|
type="text"
|
||||||
|
value={editQuery}
|
||||||
|
onChange={(e) => setEditQuery(e.target.value)}
|
||||||
|
onKeyDown={handleKeyDown}
|
||||||
|
placeholder="Tell SurfSense what to remember or forget about your team"
|
||||||
|
disabled={editing}
|
||||||
|
className="flex-1 bg-transparent text-sm outline-none placeholder:text-muted-foreground/70"
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
size="icon"
|
||||||
|
variant="ghost"
|
||||||
|
onClick={handleEdit}
|
||||||
|
disabled={editing || !editQuery.trim()}
|
||||||
|
className={`h-11 w-11 shrink-0 rounded-full ${
|
||||||
|
editing ? "" : "bg-muted-foreground/15 hover:bg-muted-foreground/20"
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{editing ? (
|
||||||
|
<Spinner size="sm" />
|
||||||
|
) : (
|
||||||
|
<ArrowUp className="!h-5 !w-5 text-foreground" strokeWidth={2.25} />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
size="icon"
|
||||||
|
variant="secondary"
|
||||||
|
onClick={openInput}
|
||||||
|
className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm"
|
||||||
|
>
|
||||||
|
<Pen className="!h-5 !w-5" />
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex items-center justify-between gap-2">
|
||||||
|
<span className={`text-xs shrink-0 ${getCounterColor()}`}>
|
||||||
|
{charCount.toLocaleString()} / {MEMORY_HARD_LIMIT.toLocaleString()}
|
||||||
|
<span className="hidden sm:inline"> characters</span>
|
||||||
|
<span className="sm:hidden"> chars</span>
|
||||||
|
{charCount > 15_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
|
||||||
|
{charCount > MEMORY_HARD_LIMIT && " - Exceeds limit"}
|
||||||
|
</span>
|
||||||
|
<div className="flex items-center gap-1.5 sm:gap-2">
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="destructive"
|
||||||
|
size="sm"
|
||||||
|
className="text-xs sm:text-sm"
|
||||||
|
onClick={handleClear}
|
||||||
|
disabled={saving || editing || !memory}
|
||||||
|
>
|
||||||
|
<span className="hidden sm:inline">Reset Memory</span>
|
||||||
|
<span className="sm:hidden">Reset</span>
|
||||||
|
</Button>
|
||||||
|
<DropdownMenu>
|
||||||
|
<DropdownMenuTrigger asChild>
|
||||||
|
<Button type="button" variant="secondary" size="sm" disabled={!memory}>
|
||||||
|
Export
|
||||||
|
<ChevronDown className="h-3 w-3 opacity-60" />
|
||||||
|
</Button>
|
||||||
|
</DropdownMenuTrigger>
|
||||||
|
<DropdownMenuContent align="end">
|
||||||
|
<DropdownMenuItem onClick={handleCopyMarkdown}>
|
||||||
|
<ClipboardCopy className="h-4 w-4 mr-2" />
|
||||||
|
Copy as Markdown
|
||||||
|
</DropdownMenuItem>
|
||||||
|
<DropdownMenuItem onClick={handleDownload}>
|
||||||
|
<Download className="h-4 w-4 mr-2" />
|
||||||
|
Download as Markdown
|
||||||
|
</DropdownMenuItem>
|
||||||
|
</DropdownMenuContent>
|
||||||
|
</DropdownMenu>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useAtom } from "jotai";
|
import { useAtom } from "jotai";
|
||||||
import { Globe, KeyRound, Monitor, Receipt, Sparkles, User } from "lucide-react";
|
import { Brain, CircleUser, Globe, KeyRound, Monitor, ReceiptText, Sparkles } from "lucide-react";
|
||||||
import dynamic from "next/dynamic";
|
import dynamic from "next/dynamic";
|
||||||
import { useTranslations } from "next-intl";
|
import { useTranslations } from "next-intl";
|
||||||
import { useMemo } from "react";
|
import { useMemo } from "react";
|
||||||
|
|
@ -51,6 +51,13 @@ const DesktopContent = dynamic(
|
||||||
),
|
),
|
||||||
{ ssr: false }
|
{ ssr: false }
|
||||||
);
|
);
|
||||||
|
const MemoryContent = dynamic(
|
||||||
|
() =>
|
||||||
|
import("@/app/dashboard/[search_space_id]/user-settings/components/MemoryContent").then(
|
||||||
|
(m) => ({ default: m.MemoryContent })
|
||||||
|
),
|
||||||
|
{ ssr: false }
|
||||||
|
);
|
||||||
|
|
||||||
export function UserSettingsDialog() {
|
export function UserSettingsDialog() {
|
||||||
const t = useTranslations("userSettings");
|
const t = useTranslations("userSettings");
|
||||||
|
|
@ -59,7 +66,7 @@ export function UserSettingsDialog() {
|
||||||
|
|
||||||
const navItems = useMemo(
|
const navItems = useMemo(
|
||||||
() => [
|
() => [
|
||||||
{ value: "profile", label: t("profile_nav_label"), icon: <User className="h-4 w-4" /> },
|
{ value: "profile", label: t("profile_nav_label"), icon: <CircleUser className="h-4 w-4" /> },
|
||||||
{
|
{
|
||||||
value: "api-key",
|
value: "api-key",
|
||||||
label: t("api_key_nav_label"),
|
label: t("api_key_nav_label"),
|
||||||
|
|
@ -75,10 +82,15 @@ export function UserSettingsDialog() {
|
||||||
label: "Community Prompts",
|
label: "Community Prompts",
|
||||||
icon: <Globe className="h-4 w-4" />,
|
icon: <Globe className="h-4 w-4" />,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
value: "memory",
|
||||||
|
label: "Memory",
|
||||||
|
icon: <Brain className="h-4 w-4" />,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
value: "purchases",
|
value: "purchases",
|
||||||
label: "Purchase History",
|
label: "Purchase History",
|
||||||
icon: <Receipt className="h-4 w-4" />,
|
icon: <ReceiptText className="h-4 w-4" />,
|
||||||
},
|
},
|
||||||
...(isDesktop
|
...(isDesktop
|
||||||
? [{ value: "desktop", label: "Desktop", icon: <Monitor className="h-4 w-4" /> }]
|
? [{ value: "desktop", label: "Desktop", icon: <Monitor className="h-4 w-4" /> }]
|
||||||
|
|
@ -101,6 +113,7 @@ export function UserSettingsDialog() {
|
||||||
{state.initialTab === "api-key" && <ApiKeyContent />}
|
{state.initialTab === "api-key" && <ApiKeyContent />}
|
||||||
{state.initialTab === "prompts" && <PromptsContent />}
|
{state.initialTab === "prompts" && <PromptsContent />}
|
||||||
{state.initialTab === "community-prompts" && <CommunityPromptsContent />}
|
{state.initialTab === "community-prompts" && <CommunityPromptsContent />}
|
||||||
|
{state.initialTab === "memory" && <MemoryContent />}
|
||||||
{state.initialTab === "purchases" && <PurchaseHistoryContent />}
|
{state.initialTab === "purchases" && <PurchaseHistoryContent />}
|
||||||
{state.initialTab === "desktop" && <DesktopContent />}
|
{state.initialTab === "desktop" && <DesktopContent />}
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
||||||
? "model"
|
? "model"
|
||||||
: "models"}
|
: "models"}
|
||||||
</span>{" "}
|
</span>{" "}
|
||||||
available from your administrator. Use the model selector to view and select them.
|
available from your administrator.
|
||||||
</p>
|
</p>
|
||||||
</AlertDescription>
|
</AlertDescription>
|
||||||
</Alert>
|
</Alert>
|
||||||
|
|
|
||||||
|
|
@ -51,17 +51,11 @@ export {
|
||||||
SandboxExecuteToolUI,
|
SandboxExecuteToolUI,
|
||||||
} from "./sandbox-execute";
|
} from "./sandbox-execute";
|
||||||
export {
|
export {
|
||||||
type MemoryItem,
|
type UpdateMemoryArgs,
|
||||||
type RecallMemoryArgs,
|
UpdateMemoryArgsSchema,
|
||||||
RecallMemoryArgsSchema,
|
type UpdateMemoryResult,
|
||||||
type RecallMemoryResult,
|
UpdateMemoryResultSchema,
|
||||||
RecallMemoryResultSchema,
|
UpdateMemoryToolUI,
|
||||||
RecallMemoryToolUI,
|
|
||||||
type SaveMemoryArgs,
|
|
||||||
SaveMemoryArgsSchema,
|
|
||||||
type SaveMemoryResult,
|
|
||||||
SaveMemoryResultSchema,
|
|
||||||
SaveMemoryToolUI,
|
|
||||||
} from "./user-memory";
|
} from "./user-memory";
|
||||||
export { GenerateVideoPresentationToolUI } from "./video-presentation";
|
export { GenerateVideoPresentationToolUI } from "./video-presentation";
|
||||||
export { type WriteTodosData, WriteTodosSchema, WriteTodosToolUI } from "./write-todos";
|
export { type WriteTodosData, WriteTodosSchema, WriteTodosToolUI } from "./write-todos";
|
||||||
|
|
|
||||||
|
|
@ -1,100 +1,38 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
|
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
|
||||||
import { BrainIcon, CheckIcon, Loader2Icon, SearchIcon, XIcon } from "lucide-react";
|
import { AlertTriangleIcon, BrainIcon, CheckIcon, Loader2Icon, XIcon } from "lucide-react";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// Zod Schemas for save_memory tool
|
// Zod Schemas for update_memory tool
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
const SaveMemoryArgsSchema = z.object({
|
const UpdateMemoryArgsSchema = z.object({
|
||||||
content: z.string(),
|
updated_memory: z.string(),
|
||||||
category: z.string().default("fact"),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const SaveMemoryResultSchema = z.object({
|
const UpdateMemoryResultSchema = z.object({
|
||||||
status: z.enum(["saved", "error"]),
|
status: z.enum(["saved", "error"]),
|
||||||
memory_id: z.number().nullish(),
|
|
||||||
memory_text: z.string().nullish(),
|
|
||||||
category: z.string().nullish(),
|
|
||||||
message: z.string().nullish(),
|
message: z.string().nullish(),
|
||||||
error: z.string().nullish(),
|
warning: z.string().nullish(),
|
||||||
});
|
});
|
||||||
|
|
||||||
type SaveMemoryArgs = z.infer<typeof SaveMemoryArgsSchema>;
|
type UpdateMemoryArgs = z.infer<typeof UpdateMemoryArgsSchema>;
|
||||||
type SaveMemoryResult = z.infer<typeof SaveMemoryResultSchema>;
|
type UpdateMemoryResult = z.infer<typeof UpdateMemoryResultSchema>;
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// Zod Schemas for recall_memory tool
|
// Update Memory Tool UI
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
const RecallMemoryArgsSchema = z.object({
|
export const UpdateMemoryToolUI = ({
|
||||||
query: z.string().nullish(),
|
|
||||||
category: z.string().nullish(),
|
|
||||||
top_k: z.number().default(5),
|
|
||||||
});
|
|
||||||
|
|
||||||
const MemoryItemSchema = z.object({
|
|
||||||
id: z.number(),
|
|
||||||
memory_text: z.string(),
|
|
||||||
category: z.string(),
|
|
||||||
updated_at: z.string().nullish(),
|
|
||||||
});
|
|
||||||
|
|
||||||
const RecallMemoryResultSchema = z.object({
|
|
||||||
status: z.enum(["success", "error"]),
|
|
||||||
count: z.number().nullish(),
|
|
||||||
memories: z.array(MemoryItemSchema).nullish(),
|
|
||||||
formatted_context: z.string().nullish(),
|
|
||||||
error: z.string().nullish(),
|
|
||||||
});
|
|
||||||
|
|
||||||
type RecallMemoryArgs = z.infer<typeof RecallMemoryArgsSchema>;
|
|
||||||
type RecallMemoryResult = z.infer<typeof RecallMemoryResultSchema>;
|
|
||||||
type MemoryItem = z.infer<typeof MemoryItemSchema>;
|
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// Category badge colors
|
|
||||||
// ============================================================================
|
|
||||||
|
|
||||||
const categoryColors: Record<string, string> = {
|
|
||||||
preference: "bg-blue-500/10 text-blue-600 dark:text-blue-400",
|
|
||||||
fact: "bg-green-500/10 text-green-600 dark:text-green-400",
|
|
||||||
instruction: "bg-purple-500/10 text-purple-600 dark:text-purple-400",
|
|
||||||
context: "bg-orange-500/10 text-orange-600 dark:text-orange-400",
|
|
||||||
};
|
|
||||||
|
|
||||||
function CategoryBadge({ category }: { category: string }) {
|
|
||||||
const colorClass = categoryColors[category] || "bg-gray-500/10 text-gray-600 dark:text-gray-400";
|
|
||||||
return (
|
|
||||||
<span
|
|
||||||
className={`inline-flex items-center rounded-full px-2 py-0.5 text-xs font-medium ${colorClass}`}
|
|
||||||
>
|
|
||||||
{category}
|
|
||||||
</span>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// Save Memory Tool UI
|
|
||||||
// ============================================================================
|
|
||||||
|
|
||||||
export const SaveMemoryToolUI = ({
|
|
||||||
args,
|
|
||||||
result,
|
result,
|
||||||
status,
|
status,
|
||||||
}: ToolCallMessagePartProps<SaveMemoryArgs, SaveMemoryResult>) => {
|
}: ToolCallMessagePartProps<UpdateMemoryArgs, UpdateMemoryResult>) => {
|
||||||
const isRunning = status.type === "running" || status.type === "requires-action";
|
const isRunning = status.type === "running" || status.type === "requires-action";
|
||||||
const isComplete = status.type === "complete";
|
const isComplete = status.type === "complete";
|
||||||
const isError = result?.status === "error";
|
const isError = result?.status === "error";
|
||||||
|
|
||||||
// Parse args safely
|
|
||||||
const parsedArgs = SaveMemoryArgsSchema.safeParse(args);
|
|
||||||
const content = parsedArgs.success ? parsedArgs.data.content : "";
|
|
||||||
const category = parsedArgs.success ? parsedArgs.data.category : "fact";
|
|
||||||
|
|
||||||
// Loading state
|
|
||||||
if (isRunning) {
|
if (isRunning) {
|
||||||
return (
|
return (
|
||||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||||
|
|
@ -102,13 +40,12 @@ export const SaveMemoryToolUI = ({
|
||||||
<Loader2Icon className="size-4 animate-spin text-primary" />
|
<Loader2Icon className="size-4 animate-spin text-primary" />
|
||||||
</div>
|
</div>
|
||||||
<div className="flex-1">
|
<div className="flex-1">
|
||||||
<span className="text-sm text-muted-foreground">Saving to memory...</span>
|
<span className="text-sm text-muted-foreground">Updating memory...</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error state
|
|
||||||
if (isError) {
|
if (isError) {
|
||||||
return (
|
return (
|
||||||
<div className="my-3 flex items-center gap-3 rounded-lg border border-destructive/20 bg-destructive/5 px-4 py-3">
|
<div className="my-3 flex items-center gap-3 rounded-lg border border-destructive/20 bg-destructive/5 px-4 py-3">
|
||||||
|
|
@ -116,14 +53,13 @@ export const SaveMemoryToolUI = ({
|
||||||
<XIcon className="size-4 text-destructive" />
|
<XIcon className="size-4 text-destructive" />
|
||||||
</div>
|
</div>
|
||||||
<div className="flex-1">
|
<div className="flex-1">
|
||||||
<span className="text-sm text-destructive">Failed to save memory</span>
|
<span className="text-sm text-destructive">Failed to update memory</span>
|
||||||
{result?.error && <p className="mt-1 text-xs text-destructive/70">{result.error}</p>}
|
{result?.message && <p className="mt-1 text-xs text-destructive/70">{result.message}</p>}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Success state
|
|
||||||
if (isComplete && result?.status === "saved") {
|
if (isComplete && result?.status === "saved") {
|
||||||
return (
|
return (
|
||||||
<div className="my-3 flex items-center gap-3 rounded-lg border border-primary/20 bg-primary/5 px-4 py-3">
|
<div className="my-3 flex items-center gap-3 rounded-lg border border-primary/20 bg-primary/5 px-4 py-3">
|
||||||
|
|
@ -133,138 +69,19 @@ export const SaveMemoryToolUI = ({
|
||||||
<div className="flex-1 min-w-0">
|
<div className="flex-1 min-w-0">
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<CheckIcon className="size-3 text-green-500 shrink-0" />
|
<CheckIcon className="size-3 text-green-500 shrink-0" />
|
||||||
<span className="text-sm font-medium text-foreground">Memory saved</span>
|
<span className="text-sm font-medium text-foreground">Memory updated</span>
|
||||||
<CategoryBadge category={category} />
|
|
||||||
</div>
|
</div>
|
||||||
<p className="mt-1 truncate text-sm text-muted-foreground">{content}</p>
|
{result.warning && (
|
||||||
</div>
|
<div className="mt-1.5 flex items-start gap-1.5">
|
||||||
</div>
|
<AlertTriangleIcon className="size-3 text-yellow-500 shrink-0 mt-0.5" />
|
||||||
);
|
<p className="text-xs text-yellow-600 dark:text-yellow-400">{result.warning}</p>
|
||||||
}
|
|
||||||
|
|
||||||
// Default/incomplete state - show what's being saved
|
|
||||||
if (content) {
|
|
||||||
return (
|
|
||||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
|
||||||
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
|
||||||
<BrainIcon className="size-4 text-muted-foreground" />
|
|
||||||
</div>
|
|
||||||
<div className="flex-1 min-w-0">
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<span className="text-sm text-muted-foreground">Saving memory</span>
|
|
||||||
<CategoryBadge category={category} />
|
|
||||||
</div>
|
|
||||||
<p className="mt-1 truncate text-sm text-muted-foreground">{content}</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// Recall Memory Tool UI
|
|
||||||
// ============================================================================
|
|
||||||
|
|
||||||
export const RecallMemoryToolUI = ({
|
|
||||||
args,
|
|
||||||
result,
|
|
||||||
status,
|
|
||||||
}: ToolCallMessagePartProps<RecallMemoryArgs, RecallMemoryResult>) => {
|
|
||||||
const isRunning = status.type === "running" || status.type === "requires-action";
|
|
||||||
const isComplete = status.type === "complete";
|
|
||||||
const isError = result?.status === "error";
|
|
||||||
|
|
||||||
// Parse args safely
|
|
||||||
const parsedArgs = RecallMemoryArgsSchema.safeParse(args);
|
|
||||||
const query = parsedArgs.success ? parsedArgs.data.query : null;
|
|
||||||
|
|
||||||
// Loading state
|
|
||||||
if (isRunning) {
|
|
||||||
return (
|
|
||||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
|
||||||
<div className="flex size-8 items-center justify-center rounded-full bg-primary/10">
|
|
||||||
<Loader2Icon className="size-4 animate-spin text-primary" />
|
|
||||||
</div>
|
|
||||||
<div className="flex-1">
|
|
||||||
<span className="text-sm text-muted-foreground">
|
|
||||||
{query ? `Searching memories for "${query}"...` : "Recalling memories..."}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error state
|
|
||||||
if (isError) {
|
|
||||||
return (
|
|
||||||
<div className="my-3 flex items-center gap-3 rounded-lg border border-destructive/20 bg-destructive/5 px-4 py-3">
|
|
||||||
<div className="flex size-8 items-center justify-center rounded-full bg-destructive/10">
|
|
||||||
<XIcon className="size-4 text-destructive" />
|
|
||||||
</div>
|
|
||||||
<div className="flex-1">
|
|
||||||
<span className="text-sm text-destructive">Failed to recall memories</span>
|
|
||||||
{result?.error && <p className="mt-1 text-xs text-destructive/70">{result.error}</p>}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Success state with memories
|
|
||||||
if (isComplete && result?.status === "success") {
|
|
||||||
const memories = result.memories || [];
|
|
||||||
const count = result.count || 0;
|
|
||||||
|
|
||||||
if (count === 0) {
|
|
||||||
return (
|
|
||||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
|
||||||
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
|
||||||
<SearchIcon className="size-4 text-muted-foreground" />
|
|
||||||
</div>
|
|
||||||
<span className="text-sm text-muted-foreground">No memories found</span>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="my-3 rounded-lg border bg-card/60 px-4 py-3">
|
|
||||||
<div className="flex items-center gap-2 mb-2">
|
|
||||||
<BrainIcon className="size-4 text-primary" />
|
|
||||||
<span className="text-sm font-medium text-foreground">
|
|
||||||
Recalled {count} {count === 1 ? "memory" : "memories"}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
<div className="space-y-2">
|
|
||||||
{memories.slice(0, 5).map((memory: MemoryItem) => (
|
|
||||||
<div
|
|
||||||
key={memory.id}
|
|
||||||
className="flex items-start gap-2 rounded-md bg-muted/50 px-3 py-2"
|
|
||||||
>
|
|
||||||
<CategoryBadge category={memory.category} />
|
|
||||||
<span className="text-sm text-muted-foreground flex-1">{memory.memory_text}</span>
|
|
||||||
</div>
|
</div>
|
||||||
))}
|
|
||||||
{memories.length > 5 && (
|
|
||||||
<p className="text-xs text-muted-foreground">...and {memories.length - 5} more</p>
|
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default/incomplete state
|
|
||||||
if (query) {
|
|
||||||
return (
|
|
||||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
|
||||||
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
|
||||||
<SearchIcon className="size-4 text-muted-foreground" />
|
|
||||||
</div>
|
|
||||||
<span className="text-sm text-muted-foreground">Searching memories for "{query}"</span>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -273,13 +90,8 @@ export const RecallMemoryToolUI = ({
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
export {
|
export {
|
||||||
SaveMemoryArgsSchema,
|
UpdateMemoryArgsSchema,
|
||||||
SaveMemoryResultSchema,
|
UpdateMemoryResultSchema,
|
||||||
RecallMemoryArgsSchema,
|
type UpdateMemoryArgs,
|
||||||
RecallMemoryResultSchema,
|
type UpdateMemoryResult,
|
||||||
type SaveMemoryArgs,
|
|
||||||
type SaveMemoryResult,
|
|
||||||
type RecallMemoryArgs,
|
|
||||||
type RecallMemoryResult,
|
|
||||||
type MemoryItem,
|
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ function AlertDialogContent({
|
||||||
<AlertDialogPrimitive.Content
|
<AlertDialogPrimitive.Content
|
||||||
data-slot="alert-dialog-content"
|
data-slot="alert-dialog-content"
|
||||||
className={cn(
|
className={cn(
|
||||||
"bg-background dark:bg-neutral-900 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-xl p-6 shadow-2xl duration-200 sm:max-w-lg",
|
"bg-background dark:bg-neutral-900 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-xl p-6 shadow-2xl duration-200 sm:max-w-lg select-none",
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
{...props}
|
{...props}
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ const DialogContent = React.forwardRef<
|
||||||
<DialogPrimitive.Content
|
<DialogPrimitive.Content
|
||||||
ref={ref}
|
ref={ref}
|
||||||
className={cn(
|
className={cn(
|
||||||
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 bg-background dark:bg-neutral-900 p-6 shadow-2xl duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 rounded-xl focus:outline-none focus:ring-0 focus-visible:outline-none focus-visible:ring-0",
|
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 bg-background dark:bg-neutral-900 p-6 shadow-2xl duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 rounded-xl focus:outline-none focus:ring-0 focus-visible:outline-none focus-visible:ring-0 select-none",
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
{...props}
|
{...props}
|
||||||
|
|
|
||||||
|
|
@ -54,8 +54,8 @@ const editorVariants = cva(
|
||||||
cn(
|
cn(
|
||||||
"group/editor",
|
"group/editor",
|
||||||
"relative w-full cursor-text select-text overflow-x-hidden whitespace-pre-wrap break-words",
|
"relative w-full cursor-text select-text overflow-x-hidden whitespace-pre-wrap break-words",
|
||||||
"rounded-md ring-offset-background focus-visible:outline-none",
|
"rounded-none ring-offset-background focus-visible:outline-none",
|
||||||
"**:data-slate-placeholder:!top-1/2 **:data-slate-placeholder:-translate-y-1/2 placeholder:text-muted-foreground/80 **:data-slate-placeholder:text-muted-foreground/80 **:data-slate-placeholder:opacity-100!",
|
"placeholder:text-muted-foreground/80 **:data-slate-placeholder:text-muted-foreground/80 **:data-slate-placeholder:py-1",
|
||||||
"[&_strong]:font-bold"
|
"[&_strong]:font-bold"
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -65,8 +65,9 @@ export function FloatingToolbar({
|
||||||
{...rootProps}
|
{...rootProps}
|
||||||
ref={ref}
|
ref={ref}
|
||||||
className={cn(
|
className={cn(
|
||||||
"scrollbar-hide absolute z-50 overflow-x-auto whitespace-nowrap rounded-md border bg-popover p-1 opacity-100 shadow-md print:hidden",
|
"scrollbar-hide absolute z-50 overflow-x-auto whitespace-nowrap rounded-md border dark:border-neutral-700 bg-muted p-1 opacity-100 shadow-md print:hidden",
|
||||||
"max-w-[80vw]",
|
"max-w-[80vw]",
|
||||||
|
"[&_button:hover]:bg-neutral-200 dark:[&_button:hover]:bg-neutral-700",
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import type { PlateElementProps } from "platejs/react";
|
||||||
import { PlateElement } from "platejs/react";
|
import { PlateElement } from "platejs/react";
|
||||||
import * as React from "react";
|
import * as React from "react";
|
||||||
|
|
||||||
const headingVariants = cva("relative mb-1", {
|
const headingVariants = cva("relative mb-1 first:mt-0", {
|
||||||
variants: {
|
variants: {
|
||||||
variant: {
|
variant: {
|
||||||
h1: "mt-[1.6em] pb-1 font-bold font-heading text-4xl",
|
h1: "mt-[1.6em] pb-1 font-bold font-heading text-4xl",
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,7 @@ const TOOL_ICONS: Record<string, LucideIcon> = {
|
||||||
scrape_webpage: ScanLine,
|
scrape_webpage: ScanLine,
|
||||||
web_search: Globe,
|
web_search: Globe,
|
||||||
search_surfsense_docs: BookOpen,
|
search_surfsense_docs: BookOpen,
|
||||||
save_memory: Brain,
|
update_memory: Brain,
|
||||||
recall_memory: Brain,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export function getToolIcon(name: string): LucideIcon {
|
export function getToolIcon(name: string): LucideIcon {
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ export const searchSpace = z.object({
|
||||||
user_id: z.string(),
|
user_id: z.string(),
|
||||||
citations_enabled: z.boolean(),
|
citations_enabled: z.boolean(),
|
||||||
qna_custom_instructions: z.string().nullable(),
|
qna_custom_instructions: z.string().nullable(),
|
||||||
|
shared_memory_md: z.string().nullable().optional(),
|
||||||
member_count: z.number(),
|
member_count: z.number(),
|
||||||
is_owner: z.boolean(),
|
is_owner: z.boolean(),
|
||||||
});
|
});
|
||||||
|
|
@ -54,6 +55,7 @@ export const updateSearchSpaceRequest = z.object({
|
||||||
description: true,
|
description: true,
|
||||||
citations_enabled: true,
|
citations_enabled: true,
|
||||||
qna_custom_instructions: true,
|
qna_custom_instructions: true,
|
||||||
|
shared_memory_md: true,
|
||||||
})
|
})
|
||||||
.partial(),
|
.partial(),
|
||||||
});
|
});
|
||||||
|
|
|
||||||
108
surfsense_web/lib/desktop-download-utils.ts
Normal file
108
surfsense_web/lib/desktop-download-utils.ts
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
import { useEffect, useMemo, useState } from "react";
|
||||||
|
|
||||||
|
export type OSInfo = {
|
||||||
|
os: "macOS" | "Windows" | "Linux";
|
||||||
|
arch: "arm64" | "x64";
|
||||||
|
};
|
||||||
|
|
||||||
|
export function useUserOS(): OSInfo {
|
||||||
|
const [info, setInfo] = useState<OSInfo>({ os: "macOS", arch: "arm64" });
|
||||||
|
useEffect(() => {
|
||||||
|
const ua = navigator.userAgent;
|
||||||
|
let os: OSInfo["os"] = "macOS";
|
||||||
|
let arch: OSInfo["arch"] = "x64";
|
||||||
|
|
||||||
|
if (/Windows/i.test(ua)) {
|
||||||
|
os = "Windows";
|
||||||
|
arch = "x64";
|
||||||
|
} else if (/Linux/i.test(ua)) {
|
||||||
|
os = "Linux";
|
||||||
|
arch = "x64";
|
||||||
|
} else {
|
||||||
|
os = "macOS";
|
||||||
|
arch = /Mac/.test(ua) && !/Intel/.test(ua) ? "arm64" : "arm64";
|
||||||
|
}
|
||||||
|
|
||||||
|
const uaData = (navigator as Navigator & { userAgentData?: { architecture?: string } })
|
||||||
|
.userAgentData;
|
||||||
|
if (uaData?.architecture === "arm") arch = "arm64";
|
||||||
|
else if (uaData?.architecture === "x86") arch = "x64";
|
||||||
|
|
||||||
|
setInfo({ os, arch });
|
||||||
|
}, []);
|
||||||
|
return info;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ReleaseAsset {
|
||||||
|
name: string;
|
||||||
|
url: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useLatestRelease() {
|
||||||
|
const [assets, setAssets] = useState<ReleaseAsset[]>([]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const controller = new AbortController();
|
||||||
|
fetch("https://api.github.com/repos/MODSetter/SurfSense/releases/latest", {
|
||||||
|
signal: controller.signal,
|
||||||
|
})
|
||||||
|
.then((r) => r.json())
|
||||||
|
.then((data) => {
|
||||||
|
if (data?.assets) {
|
||||||
|
setAssets(
|
||||||
|
data.assets
|
||||||
|
.filter((a: { name: string }) => /\.(exe|dmg|AppImage|deb)$/.test(a.name))
|
||||||
|
.map((a: { name: string; browser_download_url: string }) => ({
|
||||||
|
name: a.name,
|
||||||
|
url: a.browser_download_url,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch(() => {});
|
||||||
|
return () => controller.abort();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return assets;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const ASSET_LABELS: Record<string, string> = {
|
||||||
|
".exe": "Windows (exe)",
|
||||||
|
"-arm64.dmg": "macOS Apple Silicon (dmg)",
|
||||||
|
"-x64.dmg": "macOS Intel (dmg)",
|
||||||
|
"-arm64.zip": "macOS Apple Silicon (zip)",
|
||||||
|
"-x64.zip": "macOS Intel (zip)",
|
||||||
|
".AppImage": "Linux (AppImage)",
|
||||||
|
".deb": "Linux (deb)",
|
||||||
|
};
|
||||||
|
|
||||||
|
export function getAssetLabel(name: string): string {
|
||||||
|
for (const [suffix, label] of Object.entries(ASSET_LABELS)) {
|
||||||
|
if (name.endsWith(suffix)) return label;
|
||||||
|
}
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const GITHUB_RELEASES_URL = "https://github.com/MODSetter/SurfSense/releases/latest";
|
||||||
|
|
||||||
|
export function usePrimaryDownload() {
|
||||||
|
const { os, arch } = useUserOS();
|
||||||
|
const assets = useLatestRelease();
|
||||||
|
|
||||||
|
const { primary, alternatives } = useMemo(() => {
|
||||||
|
if (assets.length === 0) return { primary: null, alternatives: [] };
|
||||||
|
|
||||||
|
const matchers: Record<string, (n: string) => boolean> = {
|
||||||
|
Windows: (n) => n.endsWith(".exe"),
|
||||||
|
macOS: (n) => n.endsWith(`-${arch}.dmg`),
|
||||||
|
Linux: (n) => n.endsWith(".AppImage"),
|
||||||
|
};
|
||||||
|
|
||||||
|
const match = matchers[os];
|
||||||
|
const primary = assets.find((a) => match(a.name)) ?? null;
|
||||||
|
const alternatives = assets.filter((a) => a !== primary);
|
||||||
|
return { primary, alternatives };
|
||||||
|
}, [assets, os, arch]);
|
||||||
|
|
||||||
|
return { os, arch, assets, primary, alternatives };
|
||||||
|
}
|
||||||
|
|
@ -693,6 +693,7 @@
|
||||||
"learn_more": "Learn more",
|
"learn_more": "Learn more",
|
||||||
"documentation": "Documentation",
|
"documentation": "Documentation",
|
||||||
"github": "GitHub",
|
"github": "GitHub",
|
||||||
|
"download_for_os": "Download for {os}",
|
||||||
"inbox": "Inbox",
|
"inbox": "Inbox",
|
||||||
"search_inbox": "Search inbox",
|
"search_inbox": "Search inbox",
|
||||||
"mark_all_read": "Mark all as read",
|
"mark_all_read": "Mark all as read",
|
||||||
|
|
|
||||||
|
|
@ -693,6 +693,7 @@
|
||||||
"learn_more": "Más información",
|
"learn_more": "Más información",
|
||||||
"documentation": "Documentación",
|
"documentation": "Documentación",
|
||||||
"github": "GitHub",
|
"github": "GitHub",
|
||||||
|
"download_for_os": "Descargar para {os}",
|
||||||
"inbox": "Bandeja de entrada",
|
"inbox": "Bandeja de entrada",
|
||||||
"search_inbox": "Buscar en bandeja de entrada",
|
"search_inbox": "Buscar en bandeja de entrada",
|
||||||
"mark_all_read": "Marcar todo como leído",
|
"mark_all_read": "Marcar todo como leído",
|
||||||
|
|
|
||||||
|
|
@ -693,6 +693,7 @@
|
||||||
"learn_more": "और जानें",
|
"learn_more": "और जानें",
|
||||||
"documentation": "दस्तावेज़ीकरण",
|
"documentation": "दस्तावेज़ीकरण",
|
||||||
"github": "GitHub",
|
"github": "GitHub",
|
||||||
|
"download_for_os": "{os} के लिए डाउनलोड करें",
|
||||||
"inbox": "इनबॉक्स",
|
"inbox": "इनबॉक्स",
|
||||||
"search_inbox": "इनबॉक्स में खोजें",
|
"search_inbox": "इनबॉक्स में खोजें",
|
||||||
"mark_all_read": "सभी पढ़ा हुआ चिह्नित करें",
|
"mark_all_read": "सभी पढ़ा हुआ चिह्नित करें",
|
||||||
|
|
|
||||||
|
|
@ -693,6 +693,7 @@
|
||||||
"learn_more": "Saiba mais",
|
"learn_more": "Saiba mais",
|
||||||
"documentation": "Documentação",
|
"documentation": "Documentação",
|
||||||
"github": "GitHub",
|
"github": "GitHub",
|
||||||
|
"download_for_os": "Baixar para {os}",
|
||||||
"inbox": "Caixa de entrada",
|
"inbox": "Caixa de entrada",
|
||||||
"search_inbox": "Pesquisar caixa de entrada",
|
"search_inbox": "Pesquisar caixa de entrada",
|
||||||
"mark_all_read": "Marcar tudo como lido",
|
"mark_all_read": "Marcar tudo como lido",
|
||||||
|
|
|
||||||
|
|
@ -677,6 +677,7 @@
|
||||||
"learn_more": "了解更多",
|
"learn_more": "了解更多",
|
||||||
"documentation": "文档",
|
"documentation": "文档",
|
||||||
"github": "GitHub",
|
"github": "GitHub",
|
||||||
|
"download_for_os": "下载 {os} 版本",
|
||||||
"inbox": "收件箱",
|
"inbox": "收件箱",
|
||||||
"search_inbox": "搜索收件箱",
|
"search_inbox": "搜索收件箱",
|
||||||
"mark_all_read": "全部标记为已读",
|
"mark_all_read": "全部标记为已读",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue