mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-15 18:25:18 +02:00
Merge pull request #1200 from AnishSarkar22/refactor/persistent-memory
refactor: persistent memory
This commit is contained in:
commit
b96dc49c8a
52 changed files with 2622 additions and 1415 deletions
|
|
@ -0,0 +1,38 @@
|
|||
"""Add memory_md columns to user and searchspaces tables
|
||||
|
||||
Revision ID: 121
|
||||
Revises: 120
|
||||
|
||||
Changes:
|
||||
1. Add memory_md TEXT column to user table (personal memory)
|
||||
2. Add shared_memory_md TEXT column to searchspaces table (team memory)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "121"
|
||||
down_revision: str | None = "120"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("memory_md", sa.Text(), nullable=True, server_default=""),
|
||||
)
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("shared_memory_md", sa.Text(), nullable=True, server_default=""),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("searchspaces", "shared_memory_md")
|
||||
op.drop_column("user", "memory_md")
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
"""Migrate row-per-fact memories to markdown, then drop legacy tables
|
||||
|
||||
Revision ID: 122
|
||||
Revises: 121
|
||||
|
||||
Converts user_memories rows into per-user markdown documents stored in
|
||||
user.memory_md, and shared_memories rows into per-search-space markdown
|
||||
stored in searchspaces.shared_memory_md. Then drops the old tables and
|
||||
the memorycategory enum.
|
||||
|
||||
The markdown format matches the new memory system:
|
||||
## Heading
|
||||
- (YYYY-MM-DD) [fact|pref|instr] memory text
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
from alembic import op
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
revision: str = "122"
|
||||
down_revision: str | None = "121"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
EMBEDDING_DIM = config.embedding_model_instance.dimension
|
||||
|
||||
_CATEGORY_TO_MARKER = {
|
||||
"fact": "fact",
|
||||
"context": "fact",
|
||||
"preference": "pref",
|
||||
"instruction": "instr",
|
||||
}
|
||||
|
||||
_CATEGORY_HEADING = {
|
||||
"fact": "Facts",
|
||||
"preference": "Preferences",
|
||||
"instruction": "Instructions",
|
||||
"context": "Context",
|
||||
}
|
||||
|
||||
_HEADING_ORDER = ["fact", "preference", "instruction", "context"]
|
||||
|
||||
|
||||
def _build_markdown(rows: list[tuple]) -> str:
|
||||
"""Build a markdown document from (memory_text, category, created_at) rows."""
|
||||
by_category: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for memory_text, category, created_at in rows:
|
||||
cat = str(category)
|
||||
marker = _CATEGORY_TO_MARKER.get(cat, "fact")
|
||||
date_str = created_at.strftime("%Y-%m-%d")
|
||||
clean_text = str(memory_text).replace("\n", " ").strip()
|
||||
bullet = f"- ({date_str}) [{marker}] {clean_text}"
|
||||
by_category[cat].append(bullet)
|
||||
|
||||
sections: list[str] = []
|
||||
for cat in _HEADING_ORDER:
|
||||
if cat in by_category:
|
||||
heading = _CATEGORY_HEADING[cat]
|
||||
sections.append(f"## {heading}")
|
||||
sections.extend(by_category[cat])
|
||||
sections.append("")
|
||||
|
||||
return "\n".join(sections).strip() + "\n"
|
||||
|
||||
|
||||
def _migrate_user_memories(conn: sa.engine.Connection) -> None:
|
||||
"""Convert user_memories rows → user.memory_md grouped by user_id."""
|
||||
rows = conn.execute(
|
||||
sa.text(
|
||||
"SELECT user_id, memory_text, category::text, created_at "
|
||||
"FROM user_memories ORDER BY created_at"
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
logger.info("user_memories is empty, skipping data migration.")
|
||||
return
|
||||
|
||||
by_user: dict[UUID, list[tuple]] = defaultdict(list)
|
||||
for user_id, memory_text, category, created_at in rows:
|
||||
by_user[user_id].append((memory_text, category, created_at))
|
||||
|
||||
migrated = 0
|
||||
for uid, user_rows in by_user.items():
|
||||
existing = conn.execute(
|
||||
sa.text('SELECT memory_md FROM "user" WHERE id = :uid'),
|
||||
{"uid": uid},
|
||||
).scalar()
|
||||
|
||||
if existing and existing.strip():
|
||||
logger.info("User %s already has memory_md, skipping.", uid)
|
||||
continue
|
||||
|
||||
markdown = _build_markdown(user_rows)
|
||||
conn.execute(
|
||||
sa.text('UPDATE "user" SET memory_md = :md WHERE id = :uid'),
|
||||
{"md": markdown, "uid": uid},
|
||||
)
|
||||
migrated += 1
|
||||
|
||||
logger.info("Migrated user_memories for %d user(s).", migrated)
|
||||
|
||||
|
||||
def _migrate_shared_memories(conn: sa.engine.Connection) -> None:
|
||||
"""Convert shared_memories rows → searchspaces.shared_memory_md."""
|
||||
rows = conn.execute(
|
||||
sa.text(
|
||||
"SELECT search_space_id, memory_text, category::text, created_at "
|
||||
"FROM shared_memories ORDER BY created_at"
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
logger.info("shared_memories is empty, skipping data migration.")
|
||||
return
|
||||
|
||||
by_space: dict[int, list[tuple]] = defaultdict(list)
|
||||
for search_space_id, memory_text, category, created_at in rows:
|
||||
by_space[search_space_id].append((memory_text, category, created_at))
|
||||
|
||||
migrated = 0
|
||||
for space_id, space_rows in by_space.items():
|
||||
existing = conn.execute(
|
||||
sa.text("SELECT shared_memory_md FROM searchspaces WHERE id = :sid"),
|
||||
{"sid": space_id},
|
||||
).scalar()
|
||||
|
||||
if existing and existing.strip():
|
||||
logger.info(
|
||||
"Search space %s already has shared_memory_md, skipping.", space_id
|
||||
)
|
||||
continue
|
||||
|
||||
markdown = _build_markdown(space_rows)
|
||||
conn.execute(
|
||||
sa.text("UPDATE searchspaces SET shared_memory_md = :md WHERE id = :sid"),
|
||||
{"md": markdown, "sid": space_id},
|
||||
)
|
||||
migrated += 1
|
||||
|
||||
logger.info("Migrated shared_memories for %d search space(s).", migrated)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa_inspect(conn)
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
if "user_memories" in tables:
|
||||
_migrate_user_memories(conn)
|
||||
|
||||
if "shared_memories" in tables:
|
||||
_migrate_shared_memories(conn)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS shared_memories CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS user_memories CASCADE;")
|
||||
op.execute("DROP TYPE IF EXISTS memorycategory;")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'memorycategory') THEN
|
||||
CREATE TYPE memorycategory AS ENUM (
|
||||
'preference',
|
||||
'fact',
|
||||
'instruction',
|
||||
'context'
|
||||
);
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS user_memories (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
search_space_id INTEGER REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||
memory_text TEXT NOT NULL,
|
||||
category memorycategory NOT NULL DEFAULT 'fact',
|
||||
embedding vector({EMBEDDING_DIM}),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_user_memories_user_id ON user_memories(user_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_user_memories_search_space_id ON user_memories(search_space_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_user_memories_updated_at ON user_memories(updated_at);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_user_memories_category ON user_memories(category);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_user_memories_user_search_space ON user_memories(user_id, search_space_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS user_memories_vector_index ON user_memories USING hnsw (embedding public.vector_cosine_ops);"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS shared_memories (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||
created_by_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
memory_text TEXT NOT NULL,
|
||||
category memorycategory NOT NULL DEFAULT 'fact',
|
||||
embedding vector({EMBEDDING_DIM})
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_shared_memories_search_space_id ON shared_memories(search_space_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_shared_memories_updated_at ON shared_memories(updated_at);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_shared_memories_created_by_id ON shared_memories(created_by_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS shared_memories_vector_index ON shared_memories USING hnsw (embedding public.vector_cosine_ops);"
|
||||
)
|
||||
|
|
@ -38,6 +38,7 @@ from app.agents.new_chat.llm_config import AgentConfig
|
|||
from app.agents.new_chat.middleware import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
MemoryInjectionMiddleware,
|
||||
SurfSenseFilesystemMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.system_prompt import (
|
||||
|
|
@ -168,8 +169,7 @@ async def create_surfsense_deep_agent(
|
|||
- generate_podcast: Generate audio podcasts from content
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
- scrape_webpage: Extract content from webpages
|
||||
- save_memory: Store facts/preferences about the user
|
||||
- recall_memory: Retrieve relevant user memories
|
||||
- update_memory: Update the user's personal or team memory document
|
||||
|
||||
The agent also includes TodoListMiddleware by default (via create_deep_agent) which provides:
|
||||
- write_todos: Create and update planning/todo lists for complex tasks
|
||||
|
|
@ -281,6 +281,7 @@ async def create_surfsense_deep_agent(
|
|||
"available_connectors": available_connectors,
|
||||
"available_document_types": available_document_types,
|
||||
"max_input_tokens": _max_input_tokens,
|
||||
"llm": llm,
|
||||
}
|
||||
|
||||
# Disable Notion action tools if no Notion connector is configured
|
||||
|
|
@ -425,9 +426,16 @@ async def create_surfsense_deep_agent(
|
|||
)
|
||||
|
||||
# -- Build the middleware stack (mirrors create_deep_agent internals) ------
|
||||
_memory_middleware = MemoryInjectionMiddleware(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
thread_visibility=visibility,
|
||||
)
|
||||
|
||||
# General-purpose subagent middleware
|
||||
gp_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
SurfSenseFilesystemMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
|
|
@ -447,6 +455,7 @@ async def create_surfsense_deep_agent(
|
|||
# Main agent middleware
|
||||
deepagent_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
KnowledgeBaseSearchMiddleware(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
|
|
|
|||
239
surfsense_backend/app/agents/new_chat/memory_extraction.py
Normal file
239
surfsense_backend/app/agents/new_chat/memory_extraction.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
"""Background memory extraction for the SurfSense agent.
|
||||
|
||||
After each agent response, if the agent did not call ``update_memory`` during
|
||||
the turn, this module can run a lightweight LLM call to decide whether the
|
||||
latest message contains long-term information worth persisting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.new_chat.tools.update_memory import _save_memory
|
||||
from app.db import SearchSpace, User, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MEMORY_EXTRACT_PROMPT = """\
|
||||
You are a memory extraction assistant. Analyze the user's message and decide \
|
||||
if it contains any long-term information worth persisting to memory.
|
||||
|
||||
Worth remembering: preferences, background/identity, goals, projects, \
|
||||
instructions, tools/languages they use, decisions, expertise, workplace — \
|
||||
durable facts that will matter in future conversations.
|
||||
|
||||
NOT worth remembering: greetings, one-off factual questions, session \
|
||||
logistics, ephemeral requests, follow-up clarifications with no new personal \
|
||||
info, things that only matter for the current task.
|
||||
|
||||
If the message contains memorizable information, output the FULL updated \
|
||||
memory document with the new facts merged into the existing content. Follow \
|
||||
these rules:
|
||||
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
|
||||
freely. Keep heading names short (2-3 words) and natural. Do NOT include the user's
|
||||
name in headings.
|
||||
- Keep entries as single bullet points. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- Every bullet MUST use format: - (YYYY-MM-DD) [fact|pref|instr] text
|
||||
[fact] = durable facts, [pref] = preferences, [instr] = standing instructions.
|
||||
- Use the user's first name (from <user_name>) in entry text, not "the user".
|
||||
- If a new fact contradicts an existing entry, update the existing entry.
|
||||
- Do not duplicate information that is already present.
|
||||
|
||||
If nothing is worth remembering, output exactly: NO_UPDATE
|
||||
|
||||
<user_name>{user_name}</user_name>
|
||||
|
||||
<current_memory>
|
||||
{current_memory}
|
||||
</current_memory>
|
||||
|
||||
<user_message>
|
||||
{user_message}
|
||||
</user_message>"""
|
||||
|
||||
_TEAM_MEMORY_EXTRACT_PROMPT = """\
|
||||
You are a team-memory extraction assistant. Analyze the latest message and \
|
||||
decide if it contains durable TEAM-level information worth persisting.
|
||||
|
||||
Decision policy:
|
||||
- Prioritize recall for durable team context, while avoiding personal-only facts.
|
||||
- Do NOT require explicit consensus language. A direct team-level statement can
|
||||
be stored if it is stable and broadly useful for future team chats.
|
||||
- If evidence is weak or clearly tentative, output NO_UPDATE.
|
||||
|
||||
Worth remembering (team-level only):
|
||||
- Decisions and defaults that guide future team work
|
||||
- Team conventions/standards (naming, review policy, coding norms)
|
||||
- Stable org/project facts (locations, ownership, constraints)
|
||||
- Long-lived architecture/process facts
|
||||
- Ongoing priorities that are likely relevant beyond this turn
|
||||
|
||||
NOT worth remembering:
|
||||
- Personal preferences or biography of one person
|
||||
- Questions, brainstorming, tentative ideas, or speculation
|
||||
- One-off requests, status updates, TODOs, logistics for this session
|
||||
- Information scoped only to a single ephemeral task
|
||||
|
||||
If the message contains memorizable team information, output the FULL updated \
|
||||
team memory document with new facts merged into existing content. Follow rules:
|
||||
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
|
||||
freely. Keep heading names short (2-3 words) and natural.
|
||||
- Keep entries as single bullet points. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- Every bullet MUST use format: - (YYYY-MM-DD) [fact] text
|
||||
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr].
|
||||
- If a new fact contradicts an existing entry, update the existing entry.
|
||||
- Do not duplicate existing information.
|
||||
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
|
||||
|
||||
If nothing is worth remembering, output exactly: NO_UPDATE
|
||||
|
||||
<current_team_memory>
|
||||
{current_memory}
|
||||
</current_team_memory>
|
||||
|
||||
<latest_message_author>
|
||||
{author}
|
||||
</latest_message_author>
|
||||
|
||||
<latest_message>
|
||||
{user_message}
|
||||
</latest_message>"""
|
||||
|
||||
|
||||
async def extract_and_save_memory(
|
||||
*,
|
||||
user_message: str,
|
||||
user_id: str | None,
|
||||
llm: Any,
|
||||
) -> None:
|
||||
"""Background task: extract memorizable info and persist it.
|
||||
|
||||
Designed to be fire-and-forget — catches all exceptions internally.
|
||||
"""
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
try:
|
||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
result = await session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return
|
||||
|
||||
old_memory = user.memory_md
|
||||
first_name = (
|
||||
user.display_name.strip().split()[0]
|
||||
if user.display_name and user.display_name.strip()
|
||||
else "The user"
|
||||
)
|
||||
prompt = _MEMORY_EXTRACT_PROMPT.format(
|
||||
current_memory=old_memory or "(empty)",
|
||||
user_message=user_message,
|
||||
user_name=first_name,
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "memory-extraction"]},
|
||||
)
|
||||
text = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
).strip()
|
||||
|
||||
if text == "NO_UPDATE" or not text:
|
||||
logger.debug("Memory extraction: no update needed (user %s)", uid)
|
||||
return
|
||||
|
||||
save_result = await _save_memory(
|
||||
updated_memory=text,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
logger.info(
|
||||
"Background memory extraction for user %s: %s",
|
||||
uid,
|
||||
save_result.get("status"),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Background user memory extraction failed")
|
||||
|
||||
|
||||
async def extract_and_save_team_memory(
|
||||
*,
|
||||
user_message: str,
|
||||
search_space_id: int | None,
|
||||
llm: Any,
|
||||
author_display_name: str | None = None,
|
||||
) -> None:
|
||||
"""Background task: extract team-level memory and persist it.
|
||||
|
||||
Runs only for shared threads. Designed to be fire-and-forget and catches
|
||||
exceptions internally.
|
||||
"""
|
||||
if not search_space_id:
|
||||
return
|
||||
|
||||
try:
|
||||
async with shielded_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return
|
||||
|
||||
old_memory = space.shared_memory_md
|
||||
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
|
||||
current_memory=old_memory or "(empty)",
|
||||
author=author_display_name or "Unknown team member",
|
||||
user_message=user_message,
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
|
||||
)
|
||||
text = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
).strip()
|
||||
|
||||
if text == "NO_UPDATE" or not text:
|
||||
logger.debug(
|
||||
"Team memory extraction: no update needed (space %s)",
|
||||
search_space_id,
|
||||
)
|
||||
return
|
||||
|
||||
save_result = await _save_memory(
|
||||
updated_memory=text,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
logger.info(
|
||||
"Background team memory extraction for space %s: %s",
|
||||
search_space_id,
|
||||
save_result.get("status"),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Background team memory extraction failed")
|
||||
|
|
@ -9,9 +9,13 @@ from app.agents.new_chat.middleware.filesystem import (
|
|||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.memory_injection import (
|
||||
MemoryInjectionMiddleware,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DedupHITLToolCallsMiddleware",
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"MemoryInjectionMiddleware",
|
||||
"SurfSenseFilesystemMiddleware",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,138 @@
|
|||
"""Memory injection middleware for the SurfSense agent.
|
||||
|
||||
Injects memory markdown into the system prompt on every turn:
|
||||
- Private threads: only personal memory (<user_memory>)
|
||||
- Shared threads: only team memory (<team_memory>)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
|
||||
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Injects memory markdown into the conversation on every turn."""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
user_id: str | UUID | None,
|
||||
search_space_id: int,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> None:
|
||||
self.user_id = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
self.search_space_id = search_space_id
|
||||
self.visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
if not isinstance(last_message, HumanMessage):
|
||||
return None
|
||||
|
||||
memory_blocks: list[str] = []
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
||||
team_memory = await self._load_team_memory(session)
|
||||
if team_memory:
|
||||
chars = len(team_memory)
|
||||
memory_blocks.append(
|
||||
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||
f"{team_memory}\n"
|
||||
f"</team_memory>"
|
||||
)
|
||||
if chars > MEMORY_SOFT_LIMIT:
|
||||
memory_blocks.append(
|
||||
f"<memory_warning>Team memory is at "
|
||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||
f"the hard limit. On your next update_memory call, consolidate "
|
||||
f"by merging duplicates, removing outdated entries, and "
|
||||
f"shortening descriptions before adding anything new."
|
||||
f"</memory_warning>"
|
||||
)
|
||||
elif self.user_id is not None:
|
||||
user_memory, display_name = await self._load_user_memory(session)
|
||||
if display_name and display_name.strip():
|
||||
first_name = display_name.strip().split()[0]
|
||||
memory_blocks.append(f"<user_name>{first_name}</user_name>")
|
||||
if user_memory:
|
||||
chars = len(user_memory)
|
||||
memory_blocks.append(
|
||||
f'<user_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||
f"{user_memory}\n"
|
||||
f"</user_memory>"
|
||||
)
|
||||
if chars > MEMORY_SOFT_LIMIT:
|
||||
memory_blocks.append(
|
||||
f"<memory_warning>Your personal memory is at "
|
||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||
f"the hard limit. On your next update_memory call, consolidate "
|
||||
f"by merging duplicates, removing outdated entries, and "
|
||||
f"shortening descriptions before adding anything new."
|
||||
f"</memory_warning>"
|
||||
)
|
||||
|
||||
if not memory_blocks:
|
||||
return None
|
||||
|
||||
memory_text = "\n\n".join(memory_blocks)
|
||||
memory_msg = SystemMessage(content=memory_text)
|
||||
|
||||
new_messages = list(messages)
|
||||
insert_idx = 1 if len(new_messages) > 1 else 0
|
||||
new_messages.insert(insert_idx, memory_msg)
|
||||
|
||||
return {"messages": new_messages}
|
||||
|
||||
async def _load_user_memory(
|
||||
self, session: AsyncSession
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Return (memory_content, display_name)."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(User.memory_md, User.display_name).where(User.id == self.user_id)
|
||||
)
|
||||
row = result.one_or_none()
|
||||
if row is None:
|
||||
return None, None
|
||||
return row.memory_md or None, row.display_name
|
||||
except Exception:
|
||||
logger.exception("Failed to load user memory")
|
||||
return None, None
|
||||
|
||||
async def _load_team_memory(self, session: AsyncSession) -> str | None:
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace.shared_memory_md).where(
|
||||
SearchSpace.id == self.search_space_id
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
return row if row else None
|
||||
except Exception:
|
||||
logger.exception("Failed to load team memory")
|
||||
return None
|
||||
|
|
@ -40,6 +40,13 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
|||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||
</knowledge_base_only_policy>
|
||||
|
||||
<memory_protocol>
|
||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||
reveal durable facts about the user (role, interests, preferences, projects,
|
||||
background, or standing instructions)? If yes, you MUST call update_memory
|
||||
alongside your normal response — do not defer this to a later turn.
|
||||
</memory_protocol>
|
||||
|
||||
</system_instruction>
|
||||
"""
|
||||
|
||||
|
|
@ -71,6 +78,13 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
|||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||
</knowledge_base_only_policy>
|
||||
|
||||
<memory_protocol>
|
||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
||||
or key facts)? If yes, you MUST call update_memory alongside your normal response —
|
||||
do not defer this to a later turn.
|
||||
</memory_protocol>
|
||||
|
||||
</system_instruction>
|
||||
"""
|
||||
|
||||
|
|
@ -248,115 +262,97 @@ _TOOL_INSTRUCTIONS["web_search"] = """
|
|||
"""
|
||||
|
||||
# Memory tool instructions have private and shared variants.
|
||||
# We store them keyed as "save_memory" / "recall_memory" with sub-keys.
|
||||
# We store them keyed as "update_memory" with sub-keys.
|
||||
_MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
||||
"save_memory": {
|
||||
"update_memory": {
|
||||
"private": """
|
||||
- save_memory: Save facts, preferences, or context for personalized responses.
|
||||
- Use this when the user explicitly or implicitly shares information worth remembering.
|
||||
- Trigger scenarios:
|
||||
* User says "remember this", "keep this in mind", "note that", or similar
|
||||
* User shares personal preferences (e.g., "I prefer Python over JavaScript")
|
||||
* User shares facts about themselves (e.g., "I'm a senior developer at Company X")
|
||||
* User gives standing instructions (e.g., "always respond in bullet points")
|
||||
* User shares project context (e.g., "I'm working on migrating our codebase to TypeScript")
|
||||
- update_memory: Update your personal memory document about the user.
|
||||
- Your current memory is already in <user_memory> in your context. The `chars` and
|
||||
`limit` attributes show your current usage and the maximum allowed size.
|
||||
- This is your curated long-term memory — the distilled essence of what you know about
|
||||
the user, not raw conversation logs.
|
||||
- Call update_memory when:
|
||||
* The user explicitly asks to remember or forget something
|
||||
* The user shares durable facts or preferences that will matter in future conversations
|
||||
- The user's first name is provided in <user_name>. Use it in memory entries
|
||||
instead of "the user" (e.g. "{name} works at..." not "The user works at...").
|
||||
Do not store the name itself as a separate memory entry.
|
||||
- Do not store short-lived or ephemeral info: one-off questions, greetings,
|
||||
session logistics, or things that only matter for the current task.
|
||||
- Args:
|
||||
- content: The fact/preference to remember. Phrase it clearly:
|
||||
* "User prefers dark mode for all interfaces"
|
||||
* "User is a senior Python developer"
|
||||
* "User wants responses in bullet point format"
|
||||
* "User is working on project called ProjectX"
|
||||
- category: Type of memory:
|
||||
* "preference": User preferences (coding style, tools, formats)
|
||||
* "fact": Facts about the user (role, expertise, background)
|
||||
* "instruction": Standing instructions (response format, communication style)
|
||||
* "context": Current context (ongoing projects, goals, challenges)
|
||||
- Returns: Confirmation of saved memory
|
||||
- IMPORTANT: Only save information that would be genuinely useful for future conversations.
|
||||
Don't save trivial or temporary information.
|
||||
- updated_memory: The FULL updated markdown document (not a diff).
|
||||
Merge new facts with existing ones, update contradictions, remove outdated entries.
|
||||
Treat every update as a curation pass — consolidate, don't just append.
|
||||
- Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text
|
||||
Markers:
|
||||
[fact] — durable facts (role, background, projects, tools, expertise)
|
||||
[pref] — preferences (response style, languages, formats, tools)
|
||||
[instr] — standing instructions (always/never do, response rules)
|
||||
- Keep it concise and well under the character limit shown in <user_memory>.
|
||||
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
|
||||
natural. Do NOT include the user's name in headings. Organize by context — e.g.
|
||||
who they are, what they're focused on, how they prefer things. Create, split, or
|
||||
merge headings freely as the memory grows.
|
||||
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- During consolidation, prioritize keeping: [instr] > [pref] > [fact].
|
||||
""",
|
||||
"shared": """
|
||||
- save_memory: Save a fact, preference, or context to the team's shared memory for future reference.
|
||||
- Use this when the user or a team member says "remember this", "keep this in mind", or similar in this shared chat.
|
||||
- Use when the team agrees on something to remember (e.g., decisions, conventions).
|
||||
- Someone shares a preference or fact that should be visible to the whole team.
|
||||
- The saved information will be available in future shared conversations in this space.
|
||||
- update_memory: Update the team's shared memory document for this search space.
|
||||
- Your current team memory is already in <team_memory> in your context. The `chars`
|
||||
and `limit` attributes show current usage and the maximum allowed size.
|
||||
- This is the team's curated long-term memory — decisions, conventions, key facts.
|
||||
- NEVER store personal memory in team memory (e.g. personal bio, individual
|
||||
preferences, or user-only standing instructions).
|
||||
- Call update_memory when:
|
||||
* A team member explicitly asks to remember or forget something
|
||||
* The conversation surfaces durable team decisions, conventions, or facts
|
||||
that will matter in future conversations
|
||||
- Do not store short-lived or ephemeral info: one-off questions, greetings,
|
||||
session logistics, or things that only matter for the current task.
|
||||
- Args:
|
||||
- content: The fact/preference/context to remember. Phrase it clearly, e.g. "API keys are stored in Vault", "The team prefers weekly demos on Fridays"
|
||||
- category: Type of memory. One of:
|
||||
* "preference": Team or workspace preferences
|
||||
* "fact": Facts the team agreed on (e.g., processes, locations)
|
||||
* "instruction": Standing instructions for the team
|
||||
* "context": Current context (e.g., ongoing projects, goals)
|
||||
- Returns: Confirmation of saved memory; returned context may include who added it (added_by).
|
||||
- IMPORTANT: Only save information that would be genuinely useful for future team conversations in this space.
|
||||
""",
|
||||
},
|
||||
"recall_memory": {
|
||||
"private": """
|
||||
- recall_memory: Retrieve relevant memories about the user for personalized responses.
|
||||
- Use this to access stored information about the user.
|
||||
- Trigger scenarios:
|
||||
* You need user context to give a better, more personalized answer
|
||||
* User references something they mentioned before
|
||||
* User asks "what do you know about me?" or similar
|
||||
* Personalization would significantly improve response quality
|
||||
* Before making recommendations that should consider user preferences
|
||||
- Args:
|
||||
- query: Optional search query to find specific memories (e.g., "programming preferences")
|
||||
- category: Optional filter by category ("preference", "fact", "instruction", "context")
|
||||
- top_k: Number of memories to retrieve (default: 5)
|
||||
- Returns: Relevant memories formatted as context
|
||||
- IMPORTANT: Use the recalled memories naturally in your response without explicitly
|
||||
stating "Based on your memory..." - integrate the context seamlessly.
|
||||
""",
|
||||
"shared": """
|
||||
- recall_memory: Recall relevant team memories for this space to provide contextual responses.
|
||||
- Use when you need team context to answer (e.g., "where do we store X?", "what did we decide about Y?").
|
||||
- Use when someone asks about something the team agreed to remember.
|
||||
- Use when team preferences or conventions would improve the response.
|
||||
- Args:
|
||||
- query: Optional search query to find specific memories. If not provided, returns the most recent memories.
|
||||
- category: Optional filter by category ("preference", "fact", "instruction", "context")
|
||||
- top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||
- Returns: Relevant team memories and formatted context (may include added_by). Integrate naturally without saying "Based on team memory...".
|
||||
- updated_memory: The FULL updated markdown document (not a diff).
|
||||
Merge new facts with existing ones, update contradictions, remove outdated entries.
|
||||
Treat every update as a curation pass — consolidate, don't just append.
|
||||
- Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text
|
||||
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory.
|
||||
- Keep it concise and well under the character limit shown in <team_memory>.
|
||||
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
|
||||
natural. Organize by context — e.g. what the team decided, current architecture,
|
||||
active processes. Create, split, or merge headings freely as the memory grows.
|
||||
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
|
||||
""",
|
||||
},
|
||||
}
|
||||
|
||||
_MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = {
|
||||
"save_memory": {
|
||||
"update_memory": {
|
||||
"private": """
|
||||
- User: "Remember that I prefer TypeScript over JavaScript"
|
||||
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
|
||||
- User: "I'm a data scientist working on ML pipelines"
|
||||
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
|
||||
- User: "Always give me code examples in Python"
|
||||
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
|
||||
- <user_name>Alex</user_name>, <user_memory> is empty. User: "I'm a space enthusiast, explain astrophage to me"
|
||||
- The user casually shared a durable fact. Use their first name in the entry, short neutral heading:
|
||||
update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n")
|
||||
- User: "Remember that I prefer concise answers over detailed explanations"
|
||||
- Durable preference. Merge with existing memory, add a new heading:
|
||||
update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n\\n## Response style\\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\\n")
|
||||
- User: "I actually moved to Tokyo last month"
|
||||
- Updated fact, date prefix reflects when recorded:
|
||||
update_memory(updated_memory="## Interests & background\\n...\\n\\n## Personal context\\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\\n...")
|
||||
- User: "I'm a freelance photographer working on a nature documentary"
|
||||
- Durable background info under a fitting heading:
|
||||
update_memory(updated_memory="...\\n\\n## Current focus\\n- (2025-03-15) [fact] Alex is a freelance photographer\\n- (2025-03-15) [fact] Alex is working on a nature documentary\\n")
|
||||
- User: "Always respond in bullet points"
|
||||
- Standing instruction:
|
||||
update_memory(updated_memory="...\\n\\n## Response style\\n- (2025-03-15) [instr] Always respond to Alex in bullet points\\n")
|
||||
""",
|
||||
"shared": """
|
||||
- User: "Remember that API keys are stored in Vault"
|
||||
- Call: `save_memory(content="API keys are stored in Vault", category="fact")`
|
||||
- User: "Let's remember that the team prefers weekly demos on Fridays"
|
||||
- Call: `save_memory(content="The team prefers weekly demos on Fridays", category="preference")`
|
||||
""",
|
||||
},
|
||||
"recall_memory": {
|
||||
"private": """
|
||||
- User: "What programming language should I use for this project?"
|
||||
- First recall: `recall_memory(query="programming language preferences")`
|
||||
- Then provide a personalized recommendation based on their preferences
|
||||
- User: "What do you know about me?"
|
||||
- Call: `recall_memory(top_k=10)`
|
||||
- Then summarize the stored memories
|
||||
""",
|
||||
"shared": """
|
||||
- User: "What did we decide about the release date?"
|
||||
- First recall: `recall_memory(query="release date decision")`
|
||||
- Then answer based on the team memories
|
||||
- User: "Where do we document onboarding?"
|
||||
- Call: `recall_memory(query="onboarding documentation")`
|
||||
- Then answer using the recalled team context
|
||||
- User: "Let's remember that we decided to do weekly standup meetings on Mondays"
|
||||
- Durable team decision:
|
||||
update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\\n...")
|
||||
- User: "Our office is in downtown Seattle, 5th floor"
|
||||
- Durable team fact:
|
||||
update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\\n...")
|
||||
""",
|
||||
},
|
||||
}
|
||||
|
|
@ -456,8 +452,7 @@ _ALL_TOOL_NAMES_ORDERED = [
|
|||
"generate_report",
|
||||
"generate_image",
|
||||
"scrape_webpage",
|
||||
"save_memory",
|
||||
"recall_memory",
|
||||
"update_memory",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ Available tools:
|
|||
- generate_video_presentation: Generate video presentations with slides and narration
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
- scrape_webpage: Extract content from webpages
|
||||
- save_memory: Store facts/preferences about the user
|
||||
- recall_memory: Retrieve relevant user memories
|
||||
- update_memory: Update the user's / team's memory document
|
||||
"""
|
||||
|
||||
# Registry exports
|
||||
|
|
@ -33,7 +32,7 @@ from .registry import (
|
|||
)
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
||||
from .video_presentation import create_generate_video_presentation_tool
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -47,10 +46,10 @@ __all__ = [
|
|||
"create_generate_image_tool",
|
||||
"create_generate_podcast_tool",
|
||||
"create_generate_video_presentation_tool",
|
||||
"create_recall_memory_tool",
|
||||
"create_save_memory_tool",
|
||||
"create_scrape_webpage_tool",
|
||||
"create_search_surfsense_docs_tool",
|
||||
"create_update_memory_tool",
|
||||
"create_update_team_memory_tool",
|
||||
"format_documents_for_context",
|
||||
"get_all_tool_names",
|
||||
"get_default_enabled_tools",
|
||||
|
|
|
|||
|
|
@ -94,11 +94,7 @@ from .podcast import create_generate_podcast_tool
|
|||
from .report import create_generate_report_tool
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
from .shared_memory import (
|
||||
create_recall_shared_memory_tool,
|
||||
create_save_shared_memory_tool,
|
||||
)
|
||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
||||
from .video_presentation import create_generate_video_presentation_tool
|
||||
from .web_search import create_web_search_tool
|
||||
|
||||
|
|
@ -214,42 +210,31 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
requires=["db_session"],
|
||||
),
|
||||
# =========================================================================
|
||||
# USER MEMORY TOOLS - private or team store by thread_visibility
|
||||
# MEMORY TOOL - single update_memory, private or team by thread_visibility
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="save_memory",
|
||||
description="Save facts, preferences, or context for personalized or team responses",
|
||||
name="update_memory",
|
||||
description="Save important long-term facts, preferences, and instructions to the (personal or team) memory",
|
||||
factory=lambda deps: (
|
||||
create_save_shared_memory_tool(
|
||||
create_update_team_memory_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
created_by_id=deps["user_id"],
|
||||
db_session=deps["db_session"],
|
||||
llm=deps.get("llm"),
|
||||
)
|
||||
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||
else create_save_memory_tool(
|
||||
else create_update_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
llm=deps.get("llm"),
|
||||
)
|
||||
),
|
||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="recall_memory",
|
||||
description="Recall relevant memories (personal or team) for context",
|
||||
factory=lambda deps: (
|
||||
create_recall_shared_memory_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||
else create_recall_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
),
|
||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||
requires=[
|
||||
"user_id",
|
||||
"search_space_id",
|
||||
"db_session",
|
||||
"thread_visibility",
|
||||
"llm",
|
||||
],
|
||||
),
|
||||
# =========================================================================
|
||||
# LINEAR TOOLS - create, update, delete issues
|
||||
|
|
|
|||
|
|
@ -1,281 +0,0 @@
|
|||
"""Shared (team) memory backend for search-space-scoped AI context."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import MemoryCategory, SharedMemory, User
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RECALL_TOP_K = 5
|
||||
MAX_MEMORIES_PER_SEARCH_SPACE = 250
|
||||
|
||||
|
||||
async def get_shared_memory_count(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> int:
|
||||
result = await db_session.execute(
|
||||
select(SharedMemory).where(SharedMemory.search_space_id == search_space_id)
|
||||
)
|
||||
return len(result.scalars().all())
|
||||
|
||||
|
||||
async def delete_oldest_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> None:
|
||||
result = await db_session.execute(
|
||||
select(SharedMemory)
|
||||
.where(SharedMemory.search_space_id == search_space_id)
|
||||
.order_by(SharedMemory.updated_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
oldest = result.scalars().first()
|
||||
if oldest:
|
||||
await db_session.delete(oldest)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
def _to_uuid(value: str | UUID) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
return UUID(value)
|
||||
|
||||
|
||||
async def save_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
created_by_id: str | UUID,
|
||||
content: str,
|
||||
category: str = "fact",
|
||||
) -> dict[str, Any]:
|
||||
category = category.lower() if category else "fact"
|
||||
valid = ["preference", "fact", "instruction", "context"]
|
||||
if category not in valid:
|
||||
category = "fact"
|
||||
try:
|
||||
count = await get_shared_memory_count(db_session, search_space_id)
|
||||
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
|
||||
await delete_oldest_shared_memory(db_session, search_space_id)
|
||||
embedding = await asyncio.to_thread(embed_text, content)
|
||||
row = SharedMemory(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=_to_uuid(created_by_id),
|
||||
memory_text=content,
|
||||
category=MemoryCategory(category),
|
||||
embedding=embedding,
|
||||
)
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(row)
|
||||
return {
|
||||
"status": "saved",
|
||||
"memory_id": row.id,
|
||||
"memory_text": content,
|
||||
"category": category,
|
||||
"message": f"I'll remember: {content}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save shared memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": "Failed to save memory. Please try again.",
|
||||
}
|
||||
|
||||
|
||||
async def recall_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||
) -> dict[str, Any]:
|
||||
top_k = min(max(top_k, 1), 20)
|
||||
try:
|
||||
valid_categories = ["preference", "fact", "instruction", "context"]
|
||||
stmt = select(SharedMemory).where(
|
||||
SharedMemory.search_space_id == search_space_id
|
||||
)
|
||||
if category and category in valid_categories:
|
||||
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
|
||||
if query:
|
||||
query_embedding = await asyncio.to_thread(embed_text, query)
|
||||
stmt = stmt.order_by(
|
||||
SharedMemory.embedding.op("<=>")(query_embedding)
|
||||
).limit(top_k)
|
||||
else:
|
||||
stmt = stmt.order_by(SharedMemory.updated_at.desc()).limit(top_k)
|
||||
result = await db_session.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
memory_list = [
|
||||
{
|
||||
"id": m.id,
|
||||
"memory_text": m.memory_text,
|
||||
"category": m.category.value if m.category else "unknown",
|
||||
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
|
||||
"created_by_id": str(m.created_by_id) if m.created_by_id else None,
|
||||
}
|
||||
for m in rows
|
||||
]
|
||||
created_by_ids = list(
|
||||
{m["created_by_id"] for m in memory_list if m["created_by_id"]}
|
||||
)
|
||||
created_by_map: dict[str, str] = {}
|
||||
if created_by_ids:
|
||||
uuids = [UUID(uid) for uid in created_by_ids]
|
||||
users_result = await db_session.execute(
|
||||
select(User).where(User.id.in_(uuids))
|
||||
)
|
||||
for u in users_result.scalars().all():
|
||||
created_by_map[str(u.id)] = u.display_name or "A team member"
|
||||
formatted_context = format_shared_memories_for_context(
|
||||
memory_list, created_by_map
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(memory_list),
|
||||
"memories": memory_list,
|
||||
"formatted_context": formatted_context,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to recall shared memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"memories": [],
|
||||
"formatted_context": "Failed to recall memories.",
|
||||
}
|
||||
|
||||
|
||||
def format_shared_memories_for_context(
|
||||
memories: list[dict[str, Any]],
|
||||
created_by_map: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
if not memories:
|
||||
return "No relevant team memories found."
|
||||
created_by_map = created_by_map or {}
|
||||
parts = ["<team_memories>"]
|
||||
for memory in memories:
|
||||
category = memory.get("category", "unknown")
|
||||
text = memory.get("memory_text", "")
|
||||
updated = memory.get("updated_at", "")
|
||||
created_by_id = memory.get("created_by_id")
|
||||
added_by = (
|
||||
created_by_map.get(str(created_by_id), "A team member")
|
||||
if created_by_id is not None
|
||||
else "A team member"
|
||||
)
|
||||
parts.append(
|
||||
f" <memory category='{category}' updated='{updated}' added_by='{added_by}'>{text}</memory>"
|
||||
)
|
||||
parts.append("</team_memories>")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def create_save_shared_memory_tool(
|
||||
search_space_id: int,
|
||||
created_by_id: str | UUID,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the save_memory tool for shared (team) chats.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID
|
||||
created_by_id: The user ID of the person adding the memory
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for saving team memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def save_memory(
|
||||
content: str,
|
||||
category: str = "fact",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Save a fact, preference, or context to the team's shared memory for future reference.
|
||||
|
||||
Use this tool when:
|
||||
- User or a team member says "remember this", "keep this in mind", or similar in this shared chat
|
||||
- The team agrees on something to remember (e.g., decisions, conventions, where things live)
|
||||
- Someone shares a preference or fact that should be visible to the whole team
|
||||
|
||||
The saved information will be available in future shared conversations in this space.
|
||||
|
||||
Args:
|
||||
content: The fact/preference/context to remember.
|
||||
Phrase it clearly, e.g., "API keys are stored in Vault",
|
||||
"The team prefers weekly demos on Fridays"
|
||||
category: Type of memory. One of:
|
||||
- "preference": Team or workspace preferences
|
||||
- "fact": Facts the team agreed on (e.g., processes, locations)
|
||||
- "instruction": Standing instructions for the team
|
||||
- "context": Current context (e.g., ongoing projects, goals)
|
||||
|
||||
Returns:
|
||||
A dictionary with the save status and memory details
|
||||
"""
|
||||
return await save_shared_memory(
|
||||
db_session, search_space_id, created_by_id, content, category
|
||||
)
|
||||
|
||||
return save_memory
|
||||
|
||||
|
||||
def create_recall_shared_memory_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the recall_memory tool for shared (team) chats.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for recalling team memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def recall_memory(
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Recall relevant team memories for this space to provide contextual responses.
|
||||
|
||||
Use this tool when:
|
||||
- You need team context to answer (e.g., "where do we store X?", "what did we decide about Y?")
|
||||
- Someone asks about something the team agreed to remember
|
||||
- Team preferences or conventions would improve the response
|
||||
|
||||
Args:
|
||||
query: Optional search query to find specific memories.
|
||||
If not provided, returns the most recent memories.
|
||||
category: Optional category filter. One of:
|
||||
"preference", "fact", "instruction", "context"
|
||||
top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||
|
||||
Returns:
|
||||
A dictionary containing relevant memories and formatted context
|
||||
"""
|
||||
return await recall_shared_memory(
|
||||
db_session, search_space_id, query, category, top_k
|
||||
)
|
||||
|
||||
return recall_memory
|
||||
389
surfsense_backend/app/agents/new_chat/tools/update_memory.py
Normal file
389
surfsense_backend/app/agents/new_chat/tools/update_memory.py
Normal file
|
|
@ -0,0 +1,389 @@
|
|||
"""Markdown-document memory tool for the SurfSense agent.
|
||||
|
||||
Replaces the old row-per-fact save_memory / recall_memory tools with a single
|
||||
update_memory tool that overwrites a freeform markdown TEXT column. The LLM
|
||||
always sees the current memory in <user_memory> / <team_memory> tags injected
|
||||
by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
|
||||
|
||||
Overflow handling:
|
||||
- Soft limit (18K chars): a warning is returned telling the agent to
|
||||
consolidate on the next update.
|
||||
- Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
|
||||
If it still exceeds the limit after rewriting, the save is rejected.
|
||||
- Diff validation: warns when entire ``##`` sections are dropped or when the
|
||||
document shrinks by more than 60%.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import SearchSpace, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_SOFT_LIMIT = 18_000
|
||||
MEMORY_HARD_LIMIT = 25_000
|
||||
|
||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
||||
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
|
||||
|
||||
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
|
||||
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
|
||||
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Diff validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_headings(memory: str) -> set[str]:
|
||||
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
|
||||
return set(_SECTION_HEADING_RE.findall(memory))
|
||||
|
||||
|
||||
def _normalize_heading(heading: str) -> str:
|
||||
"""Normalize heading text for robust scope checks."""
|
||||
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
|
||||
|
||||
|
||||
def _validate_memory_scope(
|
||||
content: str, scope: Literal["user", "team"]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Reject personal-only markers ([pref], [instr]) in team memory."""
|
||||
if scope != "team":
|
||||
return None
|
||||
|
||||
markers = set(_MARKER_RE.findall(content))
|
||||
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
|
||||
if leaked:
|
||||
tags = ", ".join(f"[{m}]" for m in leaked)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Team memory cannot include personal markers: {tags}. "
|
||||
"Use [fact] only in team memory."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _validate_bullet_format(content: str) -> list[str]:
|
||||
"""Return warnings for bullet lines that don't match the required format.
|
||||
|
||||
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
|
||||
"""
|
||||
warnings: list[str] = []
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped.startswith("- "):
|
||||
continue
|
||||
if not _BULLET_FORMAT_RE.match(stripped):
|
||||
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
||||
warnings.append(f"Malformed bullet: {short}")
|
||||
return warnings
|
||||
|
||||
|
||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||
"""Return a list of warning strings about suspicious changes."""
|
||||
if not old_memory:
|
||||
return []
|
||||
|
||||
warnings: list[str] = []
|
||||
old_headings = _extract_headings(old_memory)
|
||||
new_headings = _extract_headings(new_memory)
|
||||
dropped = old_headings - new_headings
|
||||
if dropped:
|
||||
names = ", ".join(sorted(dropped))
|
||||
warnings.append(
|
||||
f"Sections removed: {names}. "
|
||||
"If unintentional, the user can restore from the settings page."
|
||||
)
|
||||
|
||||
old_len = len(old_memory)
|
||||
new_len = len(new_memory)
|
||||
if old_len > 0 and new_len < old_len * 0.4:
|
||||
warnings.append(
|
||||
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
|
||||
"Possible data loss."
|
||||
)
|
||||
return warnings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Size validation & soft warning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _validate_memory_size(content: str) -> dict[str, Any] | None:
|
||||
"""Return an error/warning dict if *content* is too large, else None."""
|
||||
length = len(content)
|
||||
if length > MEMORY_HARD_LIMIT:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
|
||||
f"({length:,} chars). Consolidate by merging related items, "
|
||||
"removing outdated entries, and shortening descriptions. "
|
||||
"Then call update_memory again."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _soft_warning(content: str) -> str | None:
|
||||
"""Return a warning string if content exceeds the soft limit."""
|
||||
length = len(content)
|
||||
if length > MEMORY_SOFT_LIMIT:
|
||||
return (
|
||||
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
|
||||
"Consolidate by merging related items and removing less important "
|
||||
"entries on your next update."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forced rewrite when memory exceeds the hard limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FORCED_REWRITE_PROMPT = """\
|
||||
You are a memory curator. The following memory document exceeds the character \
|
||||
limit and must be shortened.
|
||||
|
||||
RULES:
|
||||
1. Rewrite the document to be under {target} characters.
|
||||
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
|
||||
or rename headings to consolidate, but keep names personal and descriptive.
|
||||
3. Priority for keeping content: [instr] > [pref] > [fact].
|
||||
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
|
||||
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
|
||||
6. Preserve the user's first name in entries — do not replace it with "the user".
|
||||
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
|
||||
|
||||
<memory_document>
|
||||
{content}
|
||||
</memory_document>"""
|
||||
|
||||
|
||||
async def _forced_rewrite(content: str, llm: Any) -> str | None:
|
||||
"""Use a focused LLM call to compress *content* under the hard limit.
|
||||
|
||||
Returns the rewritten string, or ``None`` if the call fails.
|
||||
"""
|
||||
try:
|
||||
prompt = _FORCED_REWRITE_PROMPT.format(
|
||||
target=MEMORY_HARD_LIMIT, content=content
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
text = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
)
|
||||
return text.strip()
|
||||
except Exception:
|
||||
logger.exception("Forced rewrite LLM call failed")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared save-and-respond logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _save_memory(
|
||||
*,
|
||||
updated_memory: str,
|
||||
old_memory: str | None,
|
||||
llm: Any | None,
|
||||
apply_fn,
|
||||
commit_fn,
|
||||
rollback_fn,
|
||||
label: str,
|
||||
scope: Literal["user", "team"],
|
||||
) -> dict[str, Any]:
|
||||
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
||||
return a response dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
updated_memory : str
|
||||
The new document the agent submitted.
|
||||
old_memory : str | None
|
||||
The previously persisted document (for diff checks).
|
||||
llm : Any | None
|
||||
LLM instance for forced rewrite (may be ``None``).
|
||||
apply_fn : callable(str) -> None
|
||||
Callback that sets the new memory on the ORM object.
|
||||
commit_fn : coroutine
|
||||
``session.commit``.
|
||||
rollback_fn : coroutine
|
||||
``session.rollback``.
|
||||
label : str
|
||||
Human label for log messages (e.g. "user memory", "team memory").
|
||||
"""
|
||||
content = updated_memory
|
||||
|
||||
# --- forced rewrite if over the hard limit ---
|
||||
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
|
||||
rewritten = await _forced_rewrite(content, llm)
|
||||
if rewritten is not None and len(rewritten) < len(content):
|
||||
content = rewritten
|
||||
|
||||
# --- hard-limit gate (reject if still too large after rewrite) ---
|
||||
size_err = _validate_memory_size(content)
|
||||
if size_err:
|
||||
return size_err
|
||||
|
||||
scope_err = _validate_memory_scope(content, scope)
|
||||
if scope_err:
|
||||
return scope_err
|
||||
|
||||
# --- persist ---
|
||||
try:
|
||||
apply_fn(content)
|
||||
await commit_fn()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update %s: %s", label, e)
|
||||
await rollback_fn()
|
||||
return {"status": "error", "message": f"Failed to update {label}: {e}"}
|
||||
|
||||
# --- build response ---
|
||||
resp: dict[str, Any] = {
|
||||
"status": "saved",
|
||||
"message": f"{label.capitalize()} updated.",
|
||||
}
|
||||
|
||||
if content is not updated_memory:
|
||||
resp["notice"] = "Memory was automatically rewritten to fit within limits."
|
||||
|
||||
diff_warnings = _validate_diff(old_memory, content)
|
||||
if diff_warnings:
|
||||
resp["diff_warnings"] = diff_warnings
|
||||
|
||||
format_warnings = _validate_bullet_format(content)
|
||||
if format_warnings:
|
||||
resp["format_warnings"] = format_warnings
|
||||
|
||||
warning = _soft_warning(content)
|
||||
if warning:
|
||||
resp["warning"] = warning
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_update_memory_tool(
|
||||
user_id: str | UUID,
|
||||
db_session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
):
|
||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
@tool
|
||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the user's personal memory document.
|
||||
|
||||
Your current memory is shown in <user_memory> in the system prompt.
|
||||
When the user shares important long-term information (preferences,
|
||||
facts, instructions, context), rewrite the memory document to include
|
||||
the new information. Merge new facts with existing ones, update
|
||||
contradictions, remove outdated entries, and keep it concise.
|
||||
|
||||
Args:
|
||||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return {"status": "error", "message": "User not found."}
|
||||
|
||||
old_memory = user.memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update user memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to update memory: {e}",
|
||||
}
|
||||
|
||||
return update_memory
|
||||
|
||||
|
||||
def create_update_team_memory_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the team's shared memory document for this search space.
|
||||
|
||||
Your current team memory is shown in <team_memory> in the system
|
||||
prompt. When the team shares important long-term information
|
||||
(decisions, conventions, key facts, priorities), rewrite the memory
|
||||
document to include the new information. Merge new facts with
|
||||
existing ones, update contradictions, remove outdated entries, and
|
||||
keep it concise.
|
||||
|
||||
Args:
|
||||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return {"status": "error", "message": "Search space not found."}
|
||||
|
||||
old_memory = space.shared_memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update team memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to update team memory: {e}",
|
||||
}
|
||||
|
||||
return update_memory
|
||||
|
|
@ -1,351 +0,0 @@
|
|||
"""
|
||||
User memory tools for the SurfSense agent.
|
||||
|
||||
This module provides tools for storing and retrieving user memories,
|
||||
enabling personalized AI responses similar to Claude's memory feature.
|
||||
|
||||
Features:
|
||||
- save_memory: Store facts, preferences, and context about the user
|
||||
- recall_memory: Retrieve relevant memories using semantic search
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import MemoryCategory, UserMemory
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
# Default number of memories to retrieve
|
||||
DEFAULT_RECALL_TOP_K = 5
|
||||
|
||||
# Maximum number of memories per user (to prevent unbounded growth)
|
||||
MAX_MEMORIES_PER_USER = 100
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _to_uuid(user_id: str) -> UUID:
|
||||
"""Convert a string user_id to a UUID object."""
|
||||
if isinstance(user_id, UUID):
|
||||
return user_id
|
||||
return UUID(user_id)
|
||||
|
||||
|
||||
async def get_user_memory_count(
|
||||
db_session: AsyncSession,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> int:
|
||||
"""Get the count of memories for a user."""
|
||||
uuid_user_id = _to_uuid(user_id)
|
||||
query = select(UserMemory).where(UserMemory.user_id == uuid_user_id)
|
||||
if search_space_id is not None:
|
||||
query = query.where(
|
||||
(UserMemory.search_space_id == search_space_id)
|
||||
| (UserMemory.search_space_id.is_(None))
|
||||
)
|
||||
result = await db_session.execute(query)
|
||||
return len(result.scalars().all())
|
||||
|
||||
|
||||
async def delete_oldest_memory(
|
||||
db_session: AsyncSession,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> None:
|
||||
"""Delete the oldest memory for a user to make room for new ones."""
|
||||
uuid_user_id = _to_uuid(user_id)
|
||||
query = (
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == uuid_user_id)
|
||||
.order_by(UserMemory.updated_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
if search_space_id is not None:
|
||||
query = query.where(
|
||||
(UserMemory.search_space_id == search_space_id)
|
||||
| (UserMemory.search_space_id.is_(None))
|
||||
)
|
||||
result = await db_session.execute(query)
|
||||
oldest_memory = result.scalars().first()
|
||||
if oldest_memory:
|
||||
await db_session.delete(oldest_memory)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
def format_memories_for_context(memories: list[dict[str, Any]]) -> str:
|
||||
"""Format retrieved memories into a readable context string for the LLM."""
|
||||
if not memories:
|
||||
return "No relevant memories found for this user."
|
||||
|
||||
parts = ["<user_memories>"]
|
||||
for memory in memories:
|
||||
category = memory.get("category", "unknown")
|
||||
text = memory.get("memory_text", "")
|
||||
updated = memory.get("updated_at", "")
|
||||
parts.append(
|
||||
f" <memory category='{category}' updated='{updated}'>{text}</memory>"
|
||||
)
|
||||
parts.append("</user_memories>")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tool Factory Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_save_memory_tool(
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the save_memory tool.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID
|
||||
search_space_id: The search space ID (for space-specific memories)
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for saving user memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def save_memory(
|
||||
content: str,
|
||||
category: str = "fact",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Save a fact, preference, or context about the user for future reference.
|
||||
|
||||
Use this tool when:
|
||||
- User explicitly says "remember this", "keep this in mind", or similar
|
||||
- User shares personal preferences (e.g., "I prefer Python over JavaScript")
|
||||
- User shares important facts about themselves (name, role, interests, projects)
|
||||
- User gives standing instructions (e.g., "always respond in bullet points")
|
||||
- User shares relevant context (e.g., "I'm working on project X")
|
||||
|
||||
The saved information will be available in future conversations to provide
|
||||
more personalized and contextual responses.
|
||||
|
||||
Args:
|
||||
content: The fact/preference/context to remember.
|
||||
Phrase it clearly, e.g., "User prefers dark mode",
|
||||
"User is a senior Python developer", "User is working on an AI project"
|
||||
category: Type of memory. One of:
|
||||
- "preference": User preferences (e.g., coding style, tools, formats)
|
||||
- "fact": Facts about the user (e.g., name, role, expertise)
|
||||
- "instruction": Standing instructions (e.g., response format preferences)
|
||||
- "context": Current context (e.g., ongoing projects, goals)
|
||||
|
||||
Returns:
|
||||
A dictionary with the save status and memory details
|
||||
"""
|
||||
# Normalize and validate category (LLMs may send uppercase)
|
||||
category = category.lower() if category else "fact"
|
||||
valid_categories = ["preference", "fact", "instruction", "context"]
|
||||
if category not in valid_categories:
|
||||
category = "fact"
|
||||
|
||||
try:
|
||||
# Convert user_id to UUID
|
||||
uuid_user_id = _to_uuid(user_id)
|
||||
|
||||
# Check if we've hit the memory limit
|
||||
memory_count = await get_user_memory_count(
|
||||
db_session, user_id, search_space_id
|
||||
)
|
||||
if memory_count >= MAX_MEMORIES_PER_USER:
|
||||
# Delete oldest memory to make room
|
||||
await delete_oldest_memory(db_session, user_id, search_space_id)
|
||||
|
||||
embedding = await asyncio.to_thread(embed_text, content)
|
||||
|
||||
# Create new memory using ORM
|
||||
# The pgvector Vector column type handles embedding conversion automatically
|
||||
new_memory = UserMemory(
|
||||
user_id=uuid_user_id,
|
||||
search_space_id=search_space_id,
|
||||
memory_text=content,
|
||||
category=MemoryCategory(category), # Convert string to enum
|
||||
embedding=embedding, # Pass embedding directly (list or numpy array)
|
||||
)
|
||||
|
||||
db_session.add(new_memory)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(new_memory)
|
||||
|
||||
return {
|
||||
"status": "saved",
|
||||
"memory_id": new_memory.id,
|
||||
"memory_text": content,
|
||||
"category": category,
|
||||
"message": f"I'll remember: {content}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to save memory for user {user_id}: {e}")
|
||||
# Rollback the session to clear any failed transaction state
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": "Failed to save memory. Please try again.",
|
||||
}
|
||||
|
||||
return save_memory
|
||||
|
||||
|
||||
def create_recall_memory_tool(
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the recall_memory tool.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID
|
||||
search_space_id: The search space ID
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for recalling user memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def recall_memory(
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Recall relevant memories about the user to provide personalized responses.
|
||||
|
||||
Use this tool when:
|
||||
- You need user context to give a better, more personalized answer
|
||||
- User asks about their preferences or past information they shared
|
||||
- User references something they told you before
|
||||
- Personalization would significantly improve the response quality
|
||||
- User asks "what do you know about me?" or similar
|
||||
|
||||
Args:
|
||||
query: Optional search query to find specific memories.
|
||||
If not provided, returns the most recent memories.
|
||||
Example: "programming preferences", "current projects"
|
||||
category: Optional category filter. One of:
|
||||
"preference", "fact", "instruction", "context"
|
||||
If not provided, searches all categories.
|
||||
top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||
|
||||
Returns:
|
||||
A dictionary containing relevant memories and formatted context
|
||||
"""
|
||||
top_k = min(max(top_k, 1), 20) # Clamp between 1 and 20
|
||||
|
||||
try:
|
||||
# Convert user_id to UUID
|
||||
uuid_user_id = _to_uuid(user_id)
|
||||
|
||||
if query:
|
||||
query_embedding = await asyncio.to_thread(embed_text, query)
|
||||
|
||||
# Build query with vector similarity
|
||||
stmt = (
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == uuid_user_id)
|
||||
.where(
|
||||
(UserMemory.search_space_id == search_space_id)
|
||||
| (UserMemory.search_space_id.is_(None))
|
||||
)
|
||||
)
|
||||
|
||||
# Add category filter if specified
|
||||
if category and category in [
|
||||
"preference",
|
||||
"fact",
|
||||
"instruction",
|
||||
"context",
|
||||
]:
|
||||
stmt = stmt.where(UserMemory.category == MemoryCategory(category))
|
||||
|
||||
# Order by vector similarity
|
||||
stmt = stmt.order_by(
|
||||
UserMemory.embedding.op("<=>")(query_embedding)
|
||||
).limit(top_k)
|
||||
|
||||
else:
|
||||
# No query - return most recent memories
|
||||
stmt = (
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == uuid_user_id)
|
||||
.where(
|
||||
(UserMemory.search_space_id == search_space_id)
|
||||
| (UserMemory.search_space_id.is_(None))
|
||||
)
|
||||
)
|
||||
|
||||
# Add category filter if specified
|
||||
if category and category in [
|
||||
"preference",
|
||||
"fact",
|
||||
"instruction",
|
||||
"context",
|
||||
]:
|
||||
stmt = stmt.where(UserMemory.category == MemoryCategory(category))
|
||||
|
||||
stmt = stmt.order_by(UserMemory.updated_at.desc()).limit(top_k)
|
||||
|
||||
result = await db_session.execute(stmt)
|
||||
memories = result.scalars().all()
|
||||
|
||||
# Format memories for response
|
||||
memory_list = [
|
||||
{
|
||||
"id": m.id,
|
||||
"memory_text": m.memory_text,
|
||||
"category": m.category.value if m.category else "unknown",
|
||||
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
|
||||
}
|
||||
for m in memories
|
||||
]
|
||||
|
||||
formatted_context = format_memories_for_context(memory_list)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(memory_list),
|
||||
"memories": memory_list,
|
||||
"formatted_context": formatted_context,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to recall memories for user {user_id}: {e}")
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"memories": [],
|
||||
"formatted_context": "Failed to recall memories.",
|
||||
}
|
||||
|
||||
return recall_memory
|
||||
|
|
@ -861,99 +861,6 @@ class ChatSessionState(BaseModel):
|
|||
ai_responding_to_user = relationship("User")
|
||||
|
||||
|
||||
class MemoryCategory(StrEnum):
|
||||
"""Categories for user memories."""
|
||||
|
||||
# Using lowercase keys to match PostgreSQL enum values
|
||||
preference = "preference" # User preferences (e.g., "prefers dark mode")
|
||||
fact = "fact" # Facts about the user (e.g., "is a Python developer")
|
||||
instruction = (
|
||||
"instruction" # Standing instructions (e.g., "always respond in bullet points")
|
||||
)
|
||||
context = "context" # Contextual information (e.g., "working on project X")
|
||||
|
||||
|
||||
class UserMemory(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Private memory: facts, preferences, context per user per search space.
|
||||
Used only for private chats (not shared/team chats).
|
||||
"""
|
||||
|
||||
__tablename__ = "user_memories"
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
# Optional association with a search space (if memory is space-specific)
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# The actual memory content
|
||||
memory_text = Column(Text, nullable=False)
|
||||
# Category for organization and filtering
|
||||
category = Column(
|
||||
SQLAlchemyEnum(MemoryCategory),
|
||||
nullable=False,
|
||||
default=MemoryCategory.fact,
|
||||
)
|
||||
# Vector embedding for semantic search
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
|
||||
# Track when memory was last updated
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="memories")
|
||||
search_space = relationship("SearchSpace", back_populates="user_memories")
|
||||
|
||||
|
||||
class SharedMemory(BaseModel, TimestampMixin):
|
||||
__tablename__ = "shared_memories"
|
||||
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
created_by_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
memory_text = Column(Text, nullable=False)
|
||||
category = Column(
|
||||
SQLAlchemyEnum(MemoryCategory),
|
||||
nullable=False,
|
||||
default=MemoryCategory.fact,
|
||||
)
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
search_space = relationship("SearchSpace", back_populates="shared_memories")
|
||||
created_by = relationship("User")
|
||||
|
||||
|
||||
class Folder(BaseModel, TimestampMixin):
|
||||
__tablename__ = "folders"
|
||||
|
||||
|
|
@ -1394,6 +1301,8 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
Text, nullable=True, default=""
|
||||
) # User's custom instructions
|
||||
|
||||
shared_memory_md = Column(Text, nullable=True, server_default="")
|
||||
|
||||
# Search space-level LLM preferences (shared by all members)
|
||||
# Note: ID values:
|
||||
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
|
||||
|
|
@ -1516,20 +1425,6 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# User memories associated with this search space
|
||||
user_memories = relationship(
|
||||
"UserMemory",
|
||||
back_populates="search_space",
|
||||
order_by="UserMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
shared_memories = relationship(
|
||||
"SharedMemory",
|
||||
back_populates="search_space",
|
||||
order_by="SharedMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||
__tablename__ = "search_source_connectors"
|
||||
|
|
@ -2030,14 +1925,6 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# User memories for personalized AI responses
|
||||
memories = relationship(
|
||||
"UserMemory",
|
||||
back_populates="user",
|
||||
order_by="UserMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Incentive tasks completed by this user
|
||||
incentive_tasks = relationship(
|
||||
"UserIncentiveTask",
|
||||
|
|
@ -2065,6 +1952,8 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
|
||||
last_login = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
memory_md = Column(Text, nullable=True, server_default="")
|
||||
|
||||
# Refresh tokens for this user
|
||||
refresh_tokens = relationship(
|
||||
"RefreshToken",
|
||||
|
|
@ -2150,14 +2039,6 @@ else:
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# User memories for personalized AI responses
|
||||
memories = relationship(
|
||||
"UserMemory",
|
||||
back_populates="user",
|
||||
order_by="UserMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Incentive tasks completed by this user
|
||||
incentive_tasks = relationship(
|
||||
"UserIncentiveTask",
|
||||
|
|
@ -2185,6 +2066,8 @@ else:
|
|||
|
||||
last_login = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
memory_md = Column(Text, nullable=True, server_default="")
|
||||
|
||||
# Refresh tokens for this user
|
||||
refresh_tokens = relationship(
|
||||
"RefreshToken",
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from .jira_add_connector_route import router as jira_add_connector_router
|
|||
from .linear_add_connector_route import router as linear_add_connector_router
|
||||
from .logs_routes import router as logs_router
|
||||
from .luma_add_connector_route import router as luma_add_connector_router
|
||||
from .memory_routes import router as memory_router
|
||||
from .model_list_routes import router as model_list_router
|
||||
from .new_chat_routes import router as new_chat_router
|
||||
from .new_llm_config_routes import router as new_llm_config_router
|
||||
|
|
@ -98,4 +99,5 @@ router.include_router(incentive_tasks_router) # Incentive tasks for earning fre
|
|||
router.include_router(stripe_router) # Stripe checkout for additional page packs
|
||||
router.include_router(youtube_router) # YouTube playlist resolution
|
||||
router.include_router(prompts_router)
|
||||
router.include_router(memory_router) # User personal memory (memory.md style)
|
||||
router.include_router(autocomplete_router) # Lightweight autocomplete with KB context
|
||||
|
|
|
|||
153
surfsense_backend/app/routes/memory_routes.py
Normal file
153
surfsense_backend/app/routes/memory_routes.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
"""Routes for user memory management (personal memory.md)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.llm_config import (
|
||||
create_chat_litellm_from_agent_config,
|
||||
load_agent_llm_config_for_search_space,
|
||||
)
|
||||
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
|
||||
from app.db import User, get_async_session
|
||||
from app.users import current_active_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class MemoryRead(BaseModel):
|
||||
memory_md: str
|
||||
|
||||
|
||||
class MemoryUpdate(BaseModel):
|
||||
memory_md: str
|
||||
|
||||
|
||||
class MemoryEditRequest(BaseModel):
|
||||
query: str
|
||||
search_space_id: int
|
||||
|
||||
|
||||
_MEMORY_EDIT_PROMPT = """\
|
||||
You are a memory editor. The user wants to modify their memory document. \
|
||||
Apply the user's instruction to the existing memory document and output the \
|
||||
FULL updated document.
|
||||
|
||||
RULES:
|
||||
1. If the instruction asks to add something, add it with format: \
|
||||
- (YYYY-MM-DD) [fact|pref|instr] text, under an existing or new ## heading. \
|
||||
Heading names should be personal and descriptive, not generic categories.
|
||||
2. If the instruction asks to remove something, remove the matching entry.
|
||||
3. If the instruction asks to change something, update the matching entry.
|
||||
4. Preserve existing ## headings and all other entries.
|
||||
5. Every bullet must include a marker: [fact], [pref], or [instr].
|
||||
6. Use the user's first name (from <user_name>) in entries instead of "the user".
|
||||
7. Output ONLY the updated markdown — no explanations, no wrapping.
|
||||
|
||||
<user_name>{user_name}</user_name>
|
||||
|
||||
<current_memory>
|
||||
{current_memory}
|
||||
</current_memory>
|
||||
|
||||
<user_instruction>
|
||||
{instruction}
|
||||
</user_instruction>"""
|
||||
|
||||
|
||||
@router.get("/users/me/memory", response_model=MemoryRead)
|
||||
async def get_user_memory(
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
await session.refresh(user, ["memory_md"])
|
||||
return MemoryRead(memory_md=user.memory_md or "")
|
||||
|
||||
|
||||
@router.put("/users/me/memory", response_model=MemoryRead)
|
||||
async def update_user_memory(
|
||||
body: MemoryUpdate,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if len(body.memory_md) > MEMORY_HARD_LIMIT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit ({len(body.memory_md):,} chars).",
|
||||
)
|
||||
user.memory_md = body.memory_md
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user, ["memory_md"])
|
||||
return MemoryRead(memory_md=user.memory_md or "")
|
||||
|
||||
|
||||
@router.post("/users/me/memory/edit", response_model=MemoryRead)
|
||||
async def edit_user_memory(
|
||||
body: MemoryEditRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Apply a natural language edit to the user's personal memory via LLM."""
|
||||
agent_config = await load_agent_llm_config_for_search_space(
|
||||
session, body.search_space_id
|
||||
)
|
||||
if not agent_config:
|
||||
raise HTTPException(status_code=500, detail="No LLM configuration available.")
|
||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||
if not llm:
|
||||
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
|
||||
|
||||
await session.refresh(user, ["memory_md", "display_name"])
|
||||
current_memory = user.memory_md or ""
|
||||
first_name = (
|
||||
user.display_name.strip().split()[0]
|
||||
if user.display_name and user.display_name.strip()
|
||||
else "The user"
|
||||
)
|
||||
|
||||
prompt = _MEMORY_EDIT_PROMPT.format(
|
||||
current_memory=current_memory or "(empty)",
|
||||
instruction=body.query,
|
||||
user_name=first_name,
|
||||
)
|
||||
try:
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "memory-edit"]},
|
||||
)
|
||||
updated = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
).strip()
|
||||
except Exception as e:
|
||||
logger.exception("Memory edit LLM call failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail="Memory edit failed.") from e
|
||||
|
||||
if not updated:
|
||||
raise HTTPException(status_code=400, detail="LLM returned empty result.")
|
||||
|
||||
result = await _save_memory(
|
||||
updated_memory=updated,
|
||||
old_memory=current_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
|
||||
if result.get("status") == "error":
|
||||
raise HTTPException(status_code=400, detail=result["message"])
|
||||
|
||||
await session.refresh(user, ["memory_md"])
|
||||
return MemoryRead(memory_md=user.memory_md or "")
|
||||
|
|
@ -1,10 +1,17 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.llm_config import (
|
||||
create_chat_litellm_from_agent_config,
|
||||
load_agent_llm_config_for_search_space,
|
||||
)
|
||||
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGenerationConfig,
|
||||
|
|
@ -34,6 +41,34 @@ logger = logging.getLogger(__name__)
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
class _TeamMemoryEditRequest(PydanticBaseModel):
|
||||
query: str
|
||||
|
||||
|
||||
_TEAM_MEMORY_EDIT_PROMPT = """\
|
||||
You are a memory editor for a team workspace. The user wants to modify the \
|
||||
team's shared memory document. Apply the user's instruction to the existing \
|
||||
memory document and output the FULL updated document.
|
||||
|
||||
RULES:
|
||||
1. If the instruction asks to add something, add it with format: \
|
||||
- (YYYY-MM-DD) [fact] text, under an existing or new ## heading. \
|
||||
Heading names should be descriptive, not generic categories.
|
||||
2. If the instruction asks to remove something, remove the matching entry.
|
||||
3. If the instruction asks to change something, update the matching entry.
|
||||
4. Preserve existing ## headings and all other entries.
|
||||
5. NEVER use [pref] or [instr] markers. Team memory uses [fact] only.
|
||||
6. Output ONLY the updated markdown — no explanations, no wrapping.
|
||||
|
||||
<current_memory>
|
||||
{current_memory}
|
||||
</current_memory>
|
||||
|
||||
<user_instruction>
|
||||
{instruction}
|
||||
</user_instruction>"""
|
||||
|
||||
|
||||
async def create_default_roles_and_membership(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
|
|
@ -255,6 +290,16 @@ async def update_search_space(
|
|||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
update_data = search_space_update.model_dump(exclude_unset=True)
|
||||
|
||||
if (
|
||||
"shared_memory_md" in update_data
|
||||
and len(update_data["shared_memory_md"] or "") > MEMORY_HARD_LIMIT
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Team memory exceeds {MEMORY_HARD_LIMIT:,} character limit.",
|
||||
)
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(db_search_space, key, value)
|
||||
await session.commit()
|
||||
|
|
@ -269,6 +314,76 @@ async def update_search_space(
|
|||
) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/searchspaces/{search_space_id}/memory/edit",
|
||||
response_model=SearchSpaceRead,
|
||||
)
|
||||
async def edit_team_memory(
|
||||
search_space_id: int,
|
||||
body: _TeamMemoryEditRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Apply a natural language edit to the team memory via LLM."""
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
|
||||
agent_config = await load_agent_llm_config_for_search_space(
|
||||
session, search_space_id
|
||||
)
|
||||
if not agent_config:
|
||||
raise HTTPException(status_code=500, detail="No LLM configuration available.")
|
||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||
if not llm:
|
||||
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
db_search_space = result.scalars().first()
|
||||
if not db_search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
current_memory = db_search_space.shared_memory_md or ""
|
||||
|
||||
prompt = _TEAM_MEMORY_EDIT_PROMPT.format(
|
||||
current_memory=current_memory or "(empty)",
|
||||
instruction=body.query,
|
||||
)
|
||||
try:
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "memory-edit"]},
|
||||
)
|
||||
updated = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
).strip()
|
||||
except Exception as e:
|
||||
logger.exception("Team memory edit LLM call failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail="Team memory edit failed.") from e
|
||||
|
||||
if not updated:
|
||||
raise HTTPException(status_code=400, detail="LLM returned empty result.")
|
||||
|
||||
save_result = await _save_memory(
|
||||
updated_memory=updated,
|
||||
old_memory=current_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(db_search_space, "shared_memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
|
||||
if save_result.get("status") == "error":
|
||||
raise HTTPException(status_code=400, detail=save_result["message"])
|
||||
|
||||
await session.refresh(db_search_space)
|
||||
return db_search_space
|
||||
|
||||
|
||||
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
|
||||
async def delete_search_space(
|
||||
search_space_id: int,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class SearchSpaceUpdate(BaseModel):
|
|||
description: str | None = None
|
||||
citations_enabled: bool | None = None
|
||||
qna_custom_instructions: str | None = None
|
||||
shared_memory_md: str | None = None
|
||||
|
||||
|
||||
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
||||
|
|
@ -29,6 +30,7 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
|||
user_id: uuid.UUID
|
||||
citations_enabled: bool
|
||||
qna_custom_instructions: str | None = None
|
||||
shared_memory_md: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,10 @@ from app.agents.new_chat.llm_config import (
|
|||
load_agent_config,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.agents.new_chat.memory_extraction import (
|
||||
extract_and_save_memory,
|
||||
extract_and_save_team_memory,
|
||||
)
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
|
|
@ -59,8 +63,6 @@ from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_hea
|
|||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def format_mentioned_surfsense_docs_as_context(
|
||||
documents: list[SurfsenseDocsDocument],
|
||||
|
|
@ -141,6 +143,7 @@ class StreamResult:
|
|||
is_interrupted: bool = False
|
||||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
|
||||
agent_called_update_memory: bool = False
|
||||
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
@ -183,6 +186,7 @@ async def _stream_agent_events(
|
|||
last_active_step_items: list[str] = initial_step_items or []
|
||||
just_finished_tool: bool = False
|
||||
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
|
||||
called_update_memory: bool = False
|
||||
|
||||
def next_thinking_step_id() -> str:
|
||||
nonlocal thinking_step_counter
|
||||
|
|
@ -490,6 +494,9 @@ async def _stream_agent_events(
|
|||
tool_name = event.get("name", "unknown_tool")
|
||||
raw_output = event.get("data", {}).get("output", "")
|
||||
|
||||
if tool_name == "update_memory":
|
||||
called_update_memory = True
|
||||
|
||||
if hasattr(raw_output, "content"):
|
||||
content = raw_output.content
|
||||
if isinstance(content, str):
|
||||
|
|
@ -1111,6 +1118,7 @@ async def _stream_agent_events(
|
|||
yield completion_event
|
||||
|
||||
result.accumulated_text = accumulated_text
|
||||
result.agent_called_update_memory = called_update_memory
|
||||
|
||||
state = await agent.aget_state(config)
|
||||
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
|
||||
|
|
@ -1540,6 +1548,27 @@ async def stream_new_chat(
|
|||
chat_id, generated_title
|
||||
)
|
||||
|
||||
# Fire background memory extraction if the agent didn't handle it.
|
||||
# Shared threads write to team memory; private threads write to user memory.
|
||||
if not stream_result.agent_called_update_memory:
|
||||
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||
asyncio.create_task(
|
||||
extract_and_save_team_memory(
|
||||
user_message=user_query,
|
||||
search_space_id=search_space_id,
|
||||
llm=llm,
|
||||
author_display_name=current_user_display_name,
|
||||
)
|
||||
)
|
||||
elif user_id:
|
||||
asyncio.create_task(
|
||||
extract_and_save_memory(
|
||||
user_message=user_query,
|
||||
user_id=user_id,
|
||||
llm=llm,
|
||||
)
|
||||
)
|
||||
|
||||
# Finish the step and message
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,198 @@
|
|||
"""Unit tests for memory scope validation and bullet format validation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.tools.update_memory import (
|
||||
_save_memory,
|
||||
_validate_bullet_format,
|
||||
_validate_memory_scope,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _Recorder:
|
||||
def __init__(self) -> None:
|
||||
self.applied_content: str | None = None
|
||||
self.commit_calls = 0
|
||||
self.rollback_calls = 0
|
||||
|
||||
def apply(self, content: str) -> None:
|
||||
self.applied_content = content
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commit_calls += 1
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rollback_calls += 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_memory_scope — marker-based
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_memory_scope_rejects_pref_marker_in_team_scope() -> None:
|
||||
content = "- (2026-04-10) [pref] Prefers dark mode\n"
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is not None
|
||||
assert result["status"] == "error"
|
||||
assert "[pref]" in result["message"]
|
||||
|
||||
|
||||
def test_validate_memory_scope_rejects_instr_marker_in_team_scope() -> None:
|
||||
content = "- (2026-04-10) [instr] Always respond in Spanish\n"
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is not None
|
||||
assert result["status"] == "error"
|
||||
assert "[instr]" in result["message"]
|
||||
|
||||
|
||||
def test_validate_memory_scope_rejects_both_personal_markers_in_team() -> None:
|
||||
content = (
|
||||
"- (2026-04-10) [pref] Prefers dark mode\n"
|
||||
"- (2026-04-10) [instr] Always respond in Spanish\n"
|
||||
)
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is not None
|
||||
assert result["status"] == "error"
|
||||
assert "[instr]" in result["message"]
|
||||
assert "[pref]" in result["message"]
|
||||
|
||||
|
||||
def test_validate_memory_scope_allows_fact_in_team_scope() -> None:
|
||||
content = "- (2026-04-10) [fact] Office is in downtown Seattle\n"
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_memory_scope_allows_all_markers_in_user_scope() -> None:
|
||||
content = (
|
||||
"- (2026-04-10) [fact] Python developer\n"
|
||||
"- (2026-04-10) [pref] Prefers concise answers\n"
|
||||
"- (2026-04-10) [instr] Always use bullet points\n"
|
||||
)
|
||||
result = _validate_memory_scope(content, "user")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_memory_scope_allows_any_heading_in_team() -> None:
|
||||
content = "## Architecture\n- (2026-04-10) [fact] Uses PostgreSQL for persistence\n"
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_memory_scope_allows_any_heading_in_user() -> None:
|
||||
content = "## My Projects\n- (2026-04-10) [fact] Working on SurfSense\n"
|
||||
result = _validate_memory_scope(content, "user")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_bullet_format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_bullet_format_passes_valid_bullets() -> None:
|
||||
content = (
|
||||
"## Work\n"
|
||||
"- (2026-04-10) [fact] Senior Python developer\n"
|
||||
"- (2026-04-10) [pref] Prefers dark mode\n"
|
||||
"- (2026-04-10) [instr] Always respond in bullet points\n"
|
||||
)
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert warnings == []
|
||||
|
||||
|
||||
def test_validate_bullet_format_warns_on_missing_marker() -> None:
|
||||
content = "- (2026-04-10) Senior Python developer\n"
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert len(warnings) == 1
|
||||
assert "Malformed bullet" in warnings[0]
|
||||
|
||||
|
||||
def test_validate_bullet_format_warns_on_missing_date() -> None:
|
||||
content = "- [fact] Senior Python developer\n"
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert len(warnings) == 1
|
||||
assert "Malformed bullet" in warnings[0]
|
||||
|
||||
|
||||
def test_validate_bullet_format_warns_on_unknown_marker() -> None:
|
||||
content = "- (2026-04-10) [context] Working on project X\n"
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert len(warnings) == 1
|
||||
assert "Malformed bullet" in warnings[0]
|
||||
|
||||
|
||||
def test_validate_bullet_format_ignores_non_bullet_lines() -> None:
|
||||
content = "## Some Heading\nSome paragraph text\n"
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert warnings == []
|
||||
|
||||
|
||||
def test_validate_bullet_format_warns_on_old_format_without_marker() -> None:
|
||||
content = "## About the user\n- (2026-04-10) Likes cats\n"
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert len(warnings) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _save_memory — end-to-end with marker scope check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_blocks_pref_in_team_before_commit() -> None:
|
||||
recorder = _Recorder()
|
||||
result = await _save_memory(
|
||||
updated_memory="- (2026-04-10) [pref] Prefers dark mode\n",
|
||||
old_memory=None,
|
||||
llm=None,
|
||||
apply_fn=recorder.apply,
|
||||
commit_fn=recorder.commit,
|
||||
rollback_fn=recorder.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
assert result["status"] == "error"
|
||||
assert recorder.commit_calls == 0
|
||||
assert recorder.applied_content is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_allows_fact_in_team_and_commits() -> None:
|
||||
recorder = _Recorder()
|
||||
content = "- (2026-04-10) [fact] Weekly standup on Mondays\n"
|
||||
result = await _save_memory(
|
||||
updated_memory=content,
|
||||
old_memory=None,
|
||||
llm=None,
|
||||
apply_fn=recorder.apply,
|
||||
commit_fn=recorder.commit,
|
||||
rollback_fn=recorder.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
assert result["status"] == "saved"
|
||||
assert recorder.commit_calls == 1
|
||||
assert recorder.applied_content == content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_includes_format_warnings() -> None:
|
||||
recorder = _Recorder()
|
||||
content = "- (2026-04-10) Missing marker text\n"
|
||||
result = await _save_memory(
|
||||
updated_memory=content,
|
||||
old_memory=None,
|
||||
llm=None,
|
||||
apply_fn=recorder.apply,
|
||||
commit_fn=recorder.commit,
|
||||
rollback_fn=recorder.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
assert result["status"] == "saved"
|
||||
assert "format_warnings" in result
|
||||
assert len(result["format_warnings"]) == 1
|
||||
|
|
@ -0,0 +1,291 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pen } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { z } from "zod";
|
||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import { PlateEditor } from "@/components/editor/plate-editor";
|
||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
|
||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||
|
||||
const MEMORY_HARD_LIMIT = 25_000;
|
||||
|
||||
const MemoryReadSchema = z.object({
|
||||
memory_md: z.string(),
|
||||
});
|
||||
|
||||
export function MemoryContent() {
|
||||
const activeSearchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
||||
const [memory, setMemory] = useState("");
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [editQuery, setEditQuery] = useState("");
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [showInput, setShowInput] = useState(false);
|
||||
const textareaRef = useRef<HTMLInputElement>(null);
|
||||
const inputContainerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const fetchMemory = useCallback(async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
const data = await baseApiService.get("/api/v1/users/me/memory", MemoryReadSchema);
|
||||
setMemory(data.memory_md);
|
||||
} catch {
|
||||
toast.error("Failed to load memory");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
fetchMemory();
|
||||
}, [fetchMemory]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!showInput) return;
|
||||
|
||||
const handlePointerDownOutside = (event: MouseEvent | TouchEvent) => {
|
||||
const target = event.target;
|
||||
if (!(target instanceof Node)) return;
|
||||
if (inputContainerRef.current?.contains(target)) return;
|
||||
|
||||
setShowInput(false);
|
||||
};
|
||||
|
||||
document.addEventListener("mousedown", handlePointerDownOutside);
|
||||
document.addEventListener("touchstart", handlePointerDownOutside, { passive: true });
|
||||
|
||||
return () => {
|
||||
document.removeEventListener("mousedown", handlePointerDownOutside);
|
||||
document.removeEventListener("touchstart", handlePointerDownOutside);
|
||||
};
|
||||
}, [showInput]);
|
||||
|
||||
const handleClear = async () => {
|
||||
try {
|
||||
setSaving(true);
|
||||
const data = await baseApiService.put("/api/v1/users/me/memory", MemoryReadSchema, {
|
||||
body: { memory_md: "" },
|
||||
});
|
||||
setMemory(data.memory_md);
|
||||
toast.success("Memory cleared");
|
||||
} catch {
|
||||
toast.error("Failed to clear memory");
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleEdit = async () => {
|
||||
const query = editQuery.trim();
|
||||
if (!query) return;
|
||||
|
||||
try {
|
||||
setEditing(true);
|
||||
const data = await baseApiService.post("/api/v1/users/me/memory/edit", MemoryReadSchema, {
|
||||
body: { query, search_space_id: Number(activeSearchSpaceId) },
|
||||
});
|
||||
setMemory(data.memory_md);
|
||||
setEditQuery("");
|
||||
setShowInput(false);
|
||||
toast.success("Memory updated");
|
||||
} catch {
|
||||
toast.error("Failed to edit memory");
|
||||
} finally {
|
||||
setEditing(false);
|
||||
}
|
||||
};
|
||||
|
||||
const openInput = () => {
|
||||
setShowInput(true);
|
||||
requestAnimationFrame(() => textareaRef.current?.focus());
|
||||
};
|
||||
|
||||
const handleDownload = () => {
|
||||
if (!memory) return;
|
||||
try {
|
||||
const blob = new Blob([memory], { type: "text/markdown;charset=utf-8" });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = "personal-memory.md";
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
} catch {
|
||||
toast.error("Failed to download memory");
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopyMarkdown = async () => {
|
||||
if (!memory) return;
|
||||
try {
|
||||
await navigator.clipboard.writeText(memory);
|
||||
toast.success("Copied to clipboard");
|
||||
} catch {
|
||||
toast.error("Failed to copy memory");
|
||||
}
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
handleEdit();
|
||||
}
|
||||
};
|
||||
|
||||
const displayMemory = memory.replace(/\(\d{4}-\d{2}-\d{2}\)\s*\[(fact|pref|instr)\]\s*/g, "");
|
||||
const charCount = memory.length;
|
||||
|
||||
const getCounterColor = () => {
|
||||
if (charCount > MEMORY_HARD_LIMIT) return "text-red-500";
|
||||
if (charCount > 15_000) return "text-orange-500";
|
||||
if (charCount > 10_000) return "text-yellow-500";
|
||||
return "text-muted-foreground";
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Spinner size="md" className="text-muted-foreground" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!memory) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-16 text-center">
|
||||
<h3 className="text-base font-medium text-foreground">What does SurfSense remember?</h3>
|
||||
<p className="mt-2 max-w-sm text-sm text-muted-foreground">
|
||||
Nothing yet. SurfSense picks up on your preferences and context as you chat.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<Alert className="bg-muted/50 py-3 md:py-4">
|
||||
<Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
|
||||
<AlertDescription className="text-xs md:text-sm">
|
||||
<p>
|
||||
SurfSense uses this personal memory to personalize your responses across all
|
||||
conversations.
|
||||
</p>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<div className="relative h-[380px] rounded-lg border bg-background">
|
||||
<div className="h-full overflow-y-auto scrollbar-thin">
|
||||
<PlateEditor
|
||||
markdown={displayMemory}
|
||||
readOnly
|
||||
preset="readonly"
|
||||
variant="default"
|
||||
editorVariant="none"
|
||||
className="px-5 py-4 text-sm min-h-full"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{showInput ? (
|
||||
<div className="absolute bottom-3 inset-x-3 z-10">
|
||||
<div
|
||||
ref={inputContainerRef}
|
||||
className="relative flex h-[54px] items-center gap-2 rounded-[9999px] border bg-muted/60 backdrop-blur-sm pl-4 pr-1 shadow-sm"
|
||||
>
|
||||
<input
|
||||
ref={textareaRef}
|
||||
type="text"
|
||||
value={editQuery}
|
||||
onChange={(e) => setEditQuery(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="Tell SurfSense what to remember or forget"
|
||||
disabled={editing}
|
||||
className="flex-1 bg-transparent text-sm outline-none placeholder:text-muted-foreground/70"
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
onClick={handleEdit}
|
||||
disabled={editing || !editQuery.trim()}
|
||||
className={`h-11 w-11 shrink-0 rounded-full ${
|
||||
editing ? "" : "bg-muted-foreground/15 hover:bg-muted-foreground/20"
|
||||
}`}
|
||||
>
|
||||
{editing ? (
|
||||
<Spinner size="sm" />
|
||||
) : (
|
||||
<ArrowUp className="!h-5 !w-5 text-foreground" strokeWidth={2.25} />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<Button
|
||||
type="button"
|
||||
size="icon"
|
||||
variant="secondary"
|
||||
onClick={openInput}
|
||||
className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm"
|
||||
>
|
||||
<Pen className="!h-5 !w-5" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className={`text-xs shrink-0 ${getCounterColor()}`}>
|
||||
{charCount.toLocaleString()} / {MEMORY_HARD_LIMIT.toLocaleString()}
|
||||
<span className="hidden sm:inline"> characters</span>
|
||||
<span className="sm:hidden"> chars</span>
|
||||
{charCount > 15_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
|
||||
{charCount > MEMORY_HARD_LIMIT && " - Exceeds limit"}
|
||||
</span>
|
||||
<div className="flex items-center gap-1.5 sm:gap-2">
|
||||
<Button
|
||||
type="button"
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
className="text-xs sm:text-sm"
|
||||
onClick={handleClear}
|
||||
disabled={saving || editing || !memory}
|
||||
>
|
||||
<span className="hidden sm:inline">Reset Memory</span>
|
||||
<span className="sm:hidden">Reset</span>
|
||||
</Button>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button type="button" variant="secondary" size="sm" disabled={!memory}>
|
||||
Export
|
||||
<ChevronDown className="h-3 w-3 opacity-60" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem onClick={handleCopyMarkdown}>
|
||||
<ClipboardCopy className="h-4 w-4 mr-2" />
|
||||
Copy as Markdown
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={handleDownload}>
|
||||
<Download className="h-4 w-4 mr-2" />
|
||||
Download as Markdown
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { Receipt } from "lucide-react";
|
||||
import { ReceiptText } from "lucide-react";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import {
|
||||
|
|
@ -65,7 +65,7 @@ export function PurchaseHistoryContent() {
|
|||
if (purchases.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center gap-2 py-16 text-center">
|
||||
<Receipt className="h-8 w-8 text-muted-foreground" />
|
||||
<ReceiptText className="h-8 w-8 text-muted-foreground" />
|
||||
<p className="text-sm font-medium">No purchases yet</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Your page-pack purchases will appear here after checkout.
|
||||
|
|
|
|||
|
|
@ -76,12 +76,8 @@ const GenerateImageToolUI = dynamic(
|
|||
import("@/components/tool-ui/generate-image").then((m) => ({ default: m.GenerateImageToolUI })),
|
||||
{ ssr: false }
|
||||
);
|
||||
const SaveMemoryToolUI = dynamic(
|
||||
() => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.SaveMemoryToolUI })),
|
||||
{ ssr: false }
|
||||
);
|
||||
const RecallMemoryToolUI = dynamic(
|
||||
() => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.RecallMemoryToolUI })),
|
||||
const UpdateMemoryToolUI = dynamic(
|
||||
() => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.UpdateMemoryToolUI })),
|
||||
{ ssr: false }
|
||||
);
|
||||
const SandboxExecuteToolUI = dynamic(
|
||||
|
|
@ -386,8 +382,7 @@ const AssistantMessageInner: FC = () => {
|
|||
generate_video_presentation: GenerateVideoPresentationToolUI,
|
||||
display_image: GenerateImageToolUI,
|
||||
generate_image: GenerateImageToolUI,
|
||||
save_memory: SaveMemoryToolUI,
|
||||
recall_memory: RecallMemoryToolUI,
|
||||
update_memory: UpdateMemoryToolUI,
|
||||
execute: SandboxExecuteToolUI,
|
||||
create_notion_page: CreateNotionPageToolUI,
|
||||
update_notion_page: UpdateNotionPageToolUI,
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ import {
|
|||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
|
|
@ -92,7 +93,7 @@ import { useMediaQuery } from "@/hooks/use-media-query";
|
|||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const COMPOSER_PLACEHOLDER = "Ask anything · Type / for prompts · Type @ to mention docs";
|
||||
const COMPOSER_PLACEHOLDER = "Ask anything, type / for prompts, type @ to mention docs";
|
||||
|
||||
export const Thread: FC = () => {
|
||||
return <ThreadContent />;
|
||||
|
|
@ -804,7 +805,7 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
const isDesktop = useMediaQuery("(min-width: 640px)");
|
||||
const { openDialog: openUploadDialog } = useDocumentUploadDialog();
|
||||
const [toolsScrollPos, setToolsScrollPos] = useState<"top" | "middle" | "bottom">("top");
|
||||
const toolsRafRef = useRef<number>();
|
||||
const toolsRafRef = useRef<number | undefined>(undefined);
|
||||
const handleToolsScroll = useCallback((e: React.UIEvent<HTMLDivElement>) => {
|
||||
const el = e.currentTarget;
|
||||
if (toolsRafRef.current) return;
|
||||
|
|
@ -1021,8 +1022,23 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
</div>
|
||||
)}
|
||||
{!filteredTools?.length && (
|
||||
<div className="px-4 py-6 text-center text-sm text-muted-foreground">
|
||||
Loading tools...
|
||||
<div className="px-4 pt-3 pb-2">
|
||||
<Skeleton className="h-3 w-16 mb-2" />
|
||||
{["t1", "t2", "t3", "t4"].map((k) => (
|
||||
<div key={k} className="flex items-center gap-3 py-2">
|
||||
<Skeleton className="size-4 rounded shrink-0" />
|
||||
<Skeleton className="h-3.5 flex-1" />
|
||||
<Skeleton className="h-5 w-9 rounded-full shrink-0" />
|
||||
</div>
|
||||
))}
|
||||
<Skeleton className="h-3 w-24 mt-3 mb-2" />
|
||||
{["c1", "c2", "c3"].map((k) => (
|
||||
<div key={k} className="flex items-center gap-3 py-2">
|
||||
<Skeleton className="size-4 rounded shrink-0" />
|
||||
<Skeleton className="h-3.5 flex-1" />
|
||||
<Skeleton className="h-5 w-9 rounded-full shrink-0" />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
|
@ -1058,12 +1074,12 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
side="bottom"
|
||||
align="start"
|
||||
sideOffset={12}
|
||||
className="w-[calc(100vw-2rem)] max-w-56 sm:max-w-72 sm:w-72 p-0 select-none"
|
||||
className="w-[calc(100vw-2rem)] max-w-48 sm:max-w-56 sm:w-56 p-0 select-none"
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
>
|
||||
<div className="sr-only">Manage Tools</div>
|
||||
<div
|
||||
className="max-h-48 sm:max-h-64 overflow-y-auto overscroll-none py-0.5 sm:py-1"
|
||||
className="max-h-44 sm:max-h-56 overflow-y-auto overscroll-none py-0.5"
|
||||
onScroll={handleToolsScroll}
|
||||
style={{
|
||||
maskImage: `linear-gradient(to bottom, ${toolsScrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${toolsScrollPos === "bottom" ? "black" : "transparent"})`,
|
||||
|
|
@ -1074,22 +1090,22 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
.filter((g) => !g.connectorIcon)
|
||||
.map((group) => (
|
||||
<div key={group.label}>
|
||||
<div className="px-2.5 sm:px-3 pt-2 pb-0.5 text-[10px] sm:text-xs text-muted-foreground/80 font-normal select-none">
|
||||
<div className="px-2 sm:px-2.5 pt-1.5 pb-0.5 text-[9px] sm:text-[10px] text-muted-foreground/80 font-normal select-none">
|
||||
{group.label}
|
||||
</div>
|
||||
{group.tools.map((tool) => {
|
||||
const isDisabled = disabledToolsSet.has(tool.name);
|
||||
const ToolIcon = getToolIcon(tool.name);
|
||||
const row = (
|
||||
<div className="flex w-full items-center gap-2 sm:gap-3 px-2.5 sm:px-3 py-1 sm:py-1.5 hover:bg-muted-foreground/10 transition-colors">
|
||||
<ToolIcon className="size-3.5 sm:size-4 shrink-0 text-muted-foreground" />
|
||||
<span className="flex-1 min-w-0 text-xs sm:text-sm font-medium truncate">
|
||||
<div className="flex w-full items-center gap-1.5 sm:gap-2 px-2 sm:px-2.5 py-0.5 sm:py-1 hover:bg-muted-foreground/10 transition-colors">
|
||||
<ToolIcon className="size-3 sm:size-3.5 shrink-0 text-muted-foreground" />
|
||||
<span className="flex-1 min-w-0 text-[11px] sm:text-xs font-medium truncate">
|
||||
{formatToolName(tool.name)}
|
||||
</span>
|
||||
<Switch
|
||||
checked={!isDisabled}
|
||||
onCheckedChange={() => toggleTool(tool.name)}
|
||||
className="shrink-0 scale-[0.6] sm:scale-75"
|
||||
className="shrink-0 scale-50 sm:scale-[0.6]"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
|
@ -1106,7 +1122,7 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
))}
|
||||
{groupedTools.some((g) => g.connectorIcon) && (
|
||||
<div>
|
||||
<div className="px-2.5 sm:px-3 pt-2 pb-0.5 text-[10px] sm:text-xs text-muted-foreground/80 font-normal select-none">
|
||||
<div className="px-2 sm:px-2.5 pt-1.5 pb-0.5 text-[9px] sm:text-[10px] text-muted-foreground/80 font-normal select-none">
|
||||
Connector Actions
|
||||
</div>
|
||||
{groupedTools
|
||||
|
|
@ -1118,26 +1134,26 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
const allDisabled = toolNames.every((n) => disabledToolsSet.has(n));
|
||||
const groupDef = TOOL_GROUPS.find((g) => g.label === group.label);
|
||||
const row = (
|
||||
<div className="flex w-full items-center gap-2 sm:gap-3 px-2.5 sm:px-3 py-1 sm:py-1.5 hover:bg-muted-foreground/10 transition-colors">
|
||||
<div className="flex w-full items-center gap-1.5 sm:gap-2 px-2 sm:px-2.5 py-0.5 sm:py-1 hover:bg-muted-foreground/10 transition-colors">
|
||||
{iconInfo ? (
|
||||
<Image
|
||||
src={iconInfo.src}
|
||||
alt={iconInfo.alt}
|
||||
width={16}
|
||||
height={16}
|
||||
className="size-3.5 sm:size-4 shrink-0 select-none pointer-events-none"
|
||||
width={14}
|
||||
height={14}
|
||||
className="size-3 sm:size-3.5 shrink-0 select-none pointer-events-none"
|
||||
draggable={false}
|
||||
/>
|
||||
) : (
|
||||
<Wrench className="size-3.5 sm:size-4 shrink-0 text-muted-foreground" />
|
||||
<Wrench className="size-3 sm:size-3.5 shrink-0 text-muted-foreground" />
|
||||
)}
|
||||
<span className="flex-1 min-w-0 text-xs sm:text-sm font-medium truncate">
|
||||
<span className="flex-1 min-w-0 text-[11px] sm:text-xs font-medium truncate">
|
||||
{group.label}
|
||||
</span>
|
||||
<Switch
|
||||
checked={!allDisabled}
|
||||
onCheckedChange={() => toggleToolGroup(toolNames)}
|
||||
className="shrink-0 scale-[0.6] sm:scale-75"
|
||||
className="shrink-0 scale-50 sm:scale-[0.6]"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
|
@ -1158,8 +1174,23 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
</div>
|
||||
)}
|
||||
{!filteredTools?.length && (
|
||||
<div className="px-3 py-4 text-center text-xs text-muted-foreground">
|
||||
Loading tools...
|
||||
<div className="px-2 sm:px-2.5 pt-1.5 pb-1">
|
||||
<Skeleton className="h-2 w-12 mb-1.5" />
|
||||
{["dt1", "dt2", "dt3", "dt4"].map((k) => (
|
||||
<div key={k} className="flex items-center gap-1.5 sm:gap-2 py-0.5 sm:py-1">
|
||||
<Skeleton className="size-3 sm:size-3.5 rounded shrink-0" />
|
||||
<Skeleton className="h-2.5 sm:h-3 flex-1" />
|
||||
<Skeleton className="h-3.5 sm:h-4 w-7 sm:w-8 rounded-full shrink-0" />
|
||||
</div>
|
||||
))}
|
||||
<Skeleton className="h-2 w-20 mt-2 mb-1.5" />
|
||||
{["dc1", "dc2", "dc3"].map((k) => (
|
||||
<div key={k} className="flex items-center gap-1.5 sm:gap-2 py-0.5 sm:py-1">
|
||||
<Skeleton className="size-3 sm:size-3.5 rounded shrink-0" />
|
||||
<Skeleton className="h-2.5 sm:h-3 flex-1" />
|
||||
<Skeleton className="h-3.5 sm:h-4 w-7 sm:w-8 rounded-full shrink-0" />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
|
@ -1297,7 +1328,7 @@ const TOOL_GROUPS: ToolGroup[] = [
|
|||
},
|
||||
{
|
||||
label: "Memory",
|
||||
tools: ["save_memory", "recall_memory"],
|
||||
tools: ["update_memory"],
|
||||
},
|
||||
{
|
||||
label: "Gmail",
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import { MarkdownPlugin, remarkMdx } from "@platejs/markdown";
|
||||
import { slateToHtml } from "@slate-serializers/html";
|
||||
import type { AnyPluginConfig, Descendant, Value } from "platejs";
|
||||
import { createPlatePlugin, Key, Plate, usePlateEditor } from "platejs/react";
|
||||
import { createPlatePlugin, Key, Plate, useEditorReadOnly, usePlateEditor } from "platejs/react";
|
||||
import { useEffect, useMemo, useRef } from "react";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
|
|
@ -60,6 +60,24 @@ export interface PlateEditorProps {
|
|||
extraPlugins?: AnyPluginConfig[];
|
||||
}
|
||||
|
||||
function PlateEditorContent({
|
||||
editorVariant,
|
||||
placeholder,
|
||||
}: {
|
||||
editorVariant: PlateEditorProps["editorVariant"];
|
||||
placeholder?: string;
|
||||
}) {
|
||||
const isReadOnly = useEditorReadOnly();
|
||||
|
||||
return (
|
||||
<Editor
|
||||
variant={editorVariant}
|
||||
placeholder={isReadOnly ? undefined : placeholder}
|
||||
className="min-h-full"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export function PlateEditor({
|
||||
markdown,
|
||||
html,
|
||||
|
|
@ -188,7 +206,7 @@ export function PlateEditor({
|
|||
}}
|
||||
>
|
||||
<EditorContainer variant={variant} className={className}>
|
||||
<Editor variant={editorVariant} placeholder={placeholder} />
|
||||
<PlateEditorContent editorVariant={editorVariant} placeholder={placeholder} />
|
||||
</EditorContainer>
|
||||
</Plate>
|
||||
</EditorSaveContext.Provider>
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import type { AnyPluginConfig } from "platejs";
|
||||
import { TrailingBlockPlugin } from "platejs";
|
||||
|
||||
import { AutoformatKit } from "@/components/editor/plugins/autoformat-kit";
|
||||
import { BasicNodesKit } from "@/components/editor/plugins/basic-nodes-kit";
|
||||
|
|
@ -36,6 +37,7 @@ export const fullPreset: AnyPluginConfig[] = [
|
|||
...FloatingToolbarKit,
|
||||
...AutoformatKit,
|
||||
...DndKit,
|
||||
TrailingBlockPlugin,
|
||||
];
|
||||
|
||||
/**
|
||||
|
|
@ -48,8 +50,8 @@ export const minimalPreset: AnyPluginConfig[] = [
|
|||
...ListKit,
|
||||
...CodeBlockKit,
|
||||
...LinkKit,
|
||||
...FloatingToolbarKit,
|
||||
...AutoformatKit,
|
||||
TrailingBlockPlugin,
|
||||
];
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import { ChevronDown, Download, Monitor } from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import Link from "next/link";
|
||||
import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import React, { memo, useCallback, useEffect, useRef, useState } from "react";
|
||||
import Balancer from "react-wrap-balancer";
|
||||
import {
|
||||
DropdownMenu,
|
||||
|
|
@ -12,6 +12,11 @@ import {
|
|||
} from "@/components/ui/dropdown-menu";
|
||||
import { ExpandedMediaOverlay, useExpandedMedia } from "@/components/ui/expanded-gif-overlay";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import {
|
||||
GITHUB_RELEASES_URL,
|
||||
getAssetLabel,
|
||||
usePrimaryDownload,
|
||||
} from "@/lib/desktop-download-utils";
|
||||
import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config";
|
||||
import { trackLoginAttempt } from "@/lib/posthog/events";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
|
@ -200,107 +205,8 @@ function GetStartedButton() {
|
|||
);
|
||||
}
|
||||
|
||||
type OSInfo = {
|
||||
os: "macOS" | "Windows" | "Linux";
|
||||
arch: "arm64" | "x64";
|
||||
};
|
||||
|
||||
function useUserOS(): OSInfo {
|
||||
const [info, setInfo] = useState<OSInfo>({ os: "macOS", arch: "arm64" });
|
||||
useEffect(() => {
|
||||
const ua = navigator.userAgent;
|
||||
let os: OSInfo["os"] = "macOS";
|
||||
let arch: OSInfo["arch"] = "x64";
|
||||
|
||||
if (/Windows/i.test(ua)) {
|
||||
os = "Windows";
|
||||
arch = "x64";
|
||||
} else if (/Linux/i.test(ua)) {
|
||||
os = "Linux";
|
||||
arch = "x64";
|
||||
} else {
|
||||
os = "macOS";
|
||||
arch = /Mac/.test(ua) && !/Intel/.test(ua) ? "arm64" : "arm64";
|
||||
}
|
||||
|
||||
const uaData = (navigator as Navigator & { userAgentData?: { architecture?: string } })
|
||||
.userAgentData;
|
||||
if (uaData?.architecture === "arm") arch = "arm64";
|
||||
else if (uaData?.architecture === "x86") arch = "x64";
|
||||
|
||||
setInfo({ os, arch });
|
||||
}, []);
|
||||
return info;
|
||||
}
|
||||
|
||||
interface ReleaseAsset {
|
||||
name: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
function useLatestRelease() {
|
||||
const [assets, setAssets] = useState<ReleaseAsset[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
fetch("https://api.github.com/repos/MODSetter/SurfSense/releases/latest", {
|
||||
signal: controller.signal,
|
||||
})
|
||||
.then((r) => r.json())
|
||||
.then((data) => {
|
||||
if (data?.assets) {
|
||||
setAssets(
|
||||
data.assets
|
||||
.filter((a: { name: string }) => /\.(exe|dmg|AppImage|deb)$/.test(a.name))
|
||||
.map((a: { name: string; browser_download_url: string }) => ({
|
||||
name: a.name,
|
||||
url: a.browser_download_url,
|
||||
}))
|
||||
);
|
||||
}
|
||||
})
|
||||
.catch(() => {});
|
||||
return () => controller.abort();
|
||||
}, []);
|
||||
|
||||
return assets;
|
||||
}
|
||||
|
||||
const ASSET_LABELS: Record<string, string> = {
|
||||
".exe": "Windows (exe)",
|
||||
"-arm64.dmg": "macOS Apple Silicon (dmg)",
|
||||
"-x64.dmg": "macOS Intel (dmg)",
|
||||
"-arm64.zip": "macOS Apple Silicon (zip)",
|
||||
"-x64.zip": "macOS Intel (zip)",
|
||||
".AppImage": "Linux (AppImage)",
|
||||
".deb": "Linux (deb)",
|
||||
};
|
||||
|
||||
function getAssetLabel(name: string): string {
|
||||
for (const [suffix, label] of Object.entries(ASSET_LABELS)) {
|
||||
if (name.endsWith(suffix)) return label;
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
function DownloadButton() {
|
||||
const { os, arch } = useUserOS();
|
||||
const assets = useLatestRelease();
|
||||
|
||||
const { primary, alternatives } = useMemo(() => {
|
||||
if (assets.length === 0) return { primary: null, alternatives: [] };
|
||||
|
||||
const matchers: Record<string, (n: string) => boolean> = {
|
||||
Windows: (n) => n.endsWith(".exe"),
|
||||
macOS: (n) => n.endsWith(`-${arch}.dmg`),
|
||||
Linux: (n) => n.endsWith(".AppImage"),
|
||||
};
|
||||
|
||||
const match = matchers[os];
|
||||
const primary = assets.find((a) => match(a.name)) ?? null;
|
||||
const alternatives = assets.filter((a) => a !== primary);
|
||||
return { primary, alternatives };
|
||||
}, [assets, os, arch]);
|
||||
const { os, primary, alternatives } = usePrimaryDownload();
|
||||
|
||||
const fallbackUrl = GITHUB_RELEASES_URL;
|
||||
|
||||
|
|
@ -504,5 +410,3 @@ const TabVideo = memo(function TabVideo({ src }: { src: string }) {
|
|||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
const GITHUB_RELEASES_URL = "https://github.com/MODSetter/SurfSense/releases/latest";
|
||||
|
|
|
|||
|
|
@ -91,12 +91,12 @@ export function SidebarSlideOutPanel({
|
|||
|
||||
{/* Panel extending from sidebar's right edge, flush with the wrapper border */}
|
||||
<motion.div
|
||||
style={{ width, left: "100%", top: -1, bottom: -1 }}
|
||||
initial={{ x: -width }}
|
||||
animate={{ x: 0 }}
|
||||
exit={{ x: -width }}
|
||||
initial={{ width: 0 }}
|
||||
animate={{ width }}
|
||||
exit={{ width: 0 }}
|
||||
transition={{ type: "tween", duration: 0.2, ease: [0.4, 0, 0.2, 1] }}
|
||||
className="absolute z-20 overflow-hidden"
|
||||
style={{ left: "100%", top: -1, bottom: -1 }}
|
||||
>
|
||||
<div
|
||||
style={{ width }}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import {
|
||||
Check,
|
||||
ChevronUp,
|
||||
Download,
|
||||
ExternalLink,
|
||||
Info,
|
||||
Languages,
|
||||
|
|
@ -29,6 +30,8 @@ import {
|
|||
} from "@/components/ui/dropdown-menu";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { useLocaleContext } from "@/contexts/LocaleContext";
|
||||
import { usePlatform } from "@/hooks/use-platform";
|
||||
import { GITHUB_RELEASES_URL, usePrimaryDownload } from "@/lib/desktop-download-utils";
|
||||
import { APP_VERSION } from "@/lib/env-config";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { User } from "../../types/layout.types";
|
||||
|
|
@ -149,10 +152,13 @@ export function SidebarUserProfile({
|
|||
}: SidebarUserProfileProps) {
|
||||
const t = useTranslations("sidebar");
|
||||
const { locale, setLocale } = useLocaleContext();
|
||||
const { isDesktop } = usePlatform();
|
||||
const { os, primary } = usePrimaryDownload();
|
||||
const [isLoggingOut, setIsLoggingOut] = useState(false);
|
||||
const bgColor = stringToColor(user.email);
|
||||
const initials = getInitials(user.email);
|
||||
const displayName = user.name || user.email.split("@")[0];
|
||||
const downloadUrl = primary?.url ?? GITHUB_RELEASES_URL;
|
||||
|
||||
const handleLanguageChange = (newLocale: "en" | "es" | "pt" | "hi" | "zh") => {
|
||||
setLocale(newLocale);
|
||||
|
|
@ -294,6 +300,15 @@ export function SidebarUserProfile({
|
|||
</DropdownMenuPortal>
|
||||
</DropdownMenuSub>
|
||||
|
||||
{!isDesktop && (
|
||||
<DropdownMenuItem asChild className="font-medium">
|
||||
<a href={downloadUrl} target="_blank" rel="noopener noreferrer">
|
||||
<Download className="h-4 w-4" strokeWidth={2.5} />
|
||||
{t("download_for_os", { os })}
|
||||
</a>
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
|
||||
<DropdownMenuSeparator className="dark:bg-neutral-700" />
|
||||
|
||||
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
|
||||
|
|
@ -439,6 +454,15 @@ export function SidebarUserProfile({
|
|||
</DropdownMenuPortal>
|
||||
</DropdownMenuSub>
|
||||
|
||||
{!isDesktop && (
|
||||
<DropdownMenuItem asChild className="font-medium">
|
||||
<a href={downloadUrl} target="_blank" rel="noopener noreferrer">
|
||||
<Download className="h-4 w-4" strokeWidth={2.5} />
|
||||
{t("download_for_os", { os })}
|
||||
</a>
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
|
||||
<DropdownMenuSeparator className="dark:bg-neutral-700" />
|
||||
|
||||
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
|
||||
|
|
|
|||
|
|
@ -18,14 +18,39 @@ import {
|
|||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||
|
||||
function ReportPanelSkeleton() {
|
||||
return (
|
||||
<div className="space-y-6 p-6">
|
||||
<div className="h-6 w-3/4 rounded-md bg-muted/60 animate-pulse" />
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse" />
|
||||
<div className="h-3 w-[95%] rounded-md bg-muted/60 animate-pulse [animation-delay:100ms]" />
|
||||
<div className="h-3 w-[88%] rounded-md bg-muted/60 animate-pulse [animation-delay:200ms]" />
|
||||
<div className="h-3 w-[60%] rounded-md bg-muted/60 animate-pulse [animation-delay:300ms]" />
|
||||
</div>
|
||||
<div className="h-5 w-2/5 rounded-md bg-muted/60 animate-pulse [animation-delay:400ms]" />
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:500ms]" />
|
||||
<div className="h-3 w-[92%] rounded-md bg-muted/60 animate-pulse [animation-delay:600ms]" />
|
||||
<div className="h-3 w-[97%] rounded-md bg-muted/60 animate-pulse [animation-delay:700ms]" />
|
||||
</div>
|
||||
<div className="h-5 w-1/3 rounded-md bg-muted/60 animate-pulse [animation-delay:800ms]" />
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-[90%] rounded-md bg-muted/60 animate-pulse [animation-delay:900ms]" />
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:1000ms]" />
|
||||
<div className="h-3 w-[75%] rounded-md bg-muted/60 animate-pulse [animation-delay:1100ms]" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const PlateEditor = dynamic(
|
||||
() => import("@/components/editor/plate-editor").then((m) => ({ default: m.PlateEditor })),
|
||||
{ ssr: false, loading: () => <Skeleton className="h-64 w-full" /> }
|
||||
{ ssr: false, loading: () => <ReportPanelSkeleton /> }
|
||||
);
|
||||
|
||||
/**
|
||||
|
|
@ -59,46 +84,6 @@ const ReportContentResponseSchema = z.object({
|
|||
type ReportContentResponse = z.infer<typeof ReportContentResponseSchema>;
|
||||
type VersionInfo = z.infer<typeof VersionInfoSchema>;
|
||||
|
||||
/**
|
||||
* Shimmer loading skeleton for report panel
|
||||
*/
|
||||
function ReportPanelSkeleton() {
|
||||
return (
|
||||
<div className="space-y-6 p-6">
|
||||
{/* Title skeleton */}
|
||||
<div className="h-6 w-3/4 rounded-md bg-muted/60 animate-pulse" />
|
||||
|
||||
{/* Paragraph 1 */}
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse" />
|
||||
<div className="h-3 w-[95%] rounded-md bg-muted/60 animate-pulse [animation-delay:100ms]" />
|
||||
<div className="h-3 w-[88%] rounded-md bg-muted/60 animate-pulse [animation-delay:200ms]" />
|
||||
<div className="h-3 w-[60%] rounded-md bg-muted/60 animate-pulse [animation-delay:300ms]" />
|
||||
</div>
|
||||
|
||||
{/* Heading */}
|
||||
<div className="h-5 w-2/5 rounded-md bg-muted/60 animate-pulse [animation-delay:400ms]" />
|
||||
|
||||
{/* Paragraph 2 */}
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:500ms]" />
|
||||
<div className="h-3 w-[92%] rounded-md bg-muted/60 animate-pulse [animation-delay:600ms]" />
|
||||
<div className="h-3 w-[97%] rounded-md bg-muted/60 animate-pulse [animation-delay:700ms]" />
|
||||
</div>
|
||||
|
||||
{/* Heading */}
|
||||
<div className="h-5 w-1/3 rounded-md bg-muted/60 animate-pulse [animation-delay:800ms]" />
|
||||
|
||||
{/* Paragraph 3 */}
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-[90%] rounded-md bg-muted/60 animate-pulse [animation-delay:900ms]" />
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:1000ms]" />
|
||||
<div className="h-3 w-[75%] rounded-md bg-muted/60 animate-pulse [animation-delay:1100ms]" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Inner content component used by desktop panel, mobile drawer, and the layout right panel
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -190,7 +190,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
? "model"
|
||||
: "models"}
|
||||
</span>{" "}
|
||||
available from your administrator. Use the model selector to view and select them.
|
||||
available from your administrator.
|
||||
</p>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import {
|
|||
FileText,
|
||||
ImageIcon,
|
||||
RefreshCw,
|
||||
Shuffle,
|
||||
} from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
|
@ -44,7 +43,6 @@ import {
|
|||
} from "@/components/ui/select";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { getProviderIcon } from "@/lib/provider-icons";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const ROLE_DESCRIPTIONS = {
|
||||
|
|
@ -79,8 +77,8 @@ const ROLE_DESCRIPTIONS = {
|
|||
icon: Eye,
|
||||
title: "Vision LLM",
|
||||
description: "Vision-capable model for screenshot analysis and context extraction",
|
||||
color: "text-amber-600 dark:text-amber-400",
|
||||
bgColor: "bg-amber-500/10",
|
||||
color: "text-muted-foreground",
|
||||
bgColor: "bg-muted",
|
||||
prefKey: "vision_llm_config_id" as const,
|
||||
configType: "vision" as const,
|
||||
},
|
||||
|
|
@ -205,11 +203,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
),
|
||||
];
|
||||
|
||||
const isAssignmentComplete =
|
||||
allLLMConfigs.some((c) => c.id === assignments.agent_llm_id) &&
|
||||
allLLMConfigs.some((c) => c.id === assignments.document_summary_llm_id) &&
|
||||
allImageConfigs.some((c) => c.id === assignments.image_generation_config_id);
|
||||
|
||||
const isLoading =
|
||||
configsLoading ||
|
||||
preferencesLoading ||
|
||||
|
|
@ -231,7 +224,7 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
return (
|
||||
<div className="space-y-5 md:space-y-6">
|
||||
{/* Header actions */}
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center justify-start">
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
|
|
@ -239,15 +232,9 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
disabled={isLoading}
|
||||
className="gap-2"
|
||||
>
|
||||
<RefreshCw className="h-3.5 w-3.5" />
|
||||
<RefreshCw className={cn("h-3.5 w-3.5", isLoading && "animate-spin")} />
|
||||
Refresh
|
||||
</Button>
|
||||
{isAssignmentComplete && !isLoading && !hasError && (
|
||||
<Badge variant="outline" className="text-xs gap-1.5 text-muted-foreground">
|
||||
<CircleCheck className="h-3 w-3" />
|
||||
All roles assigned
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Error Alert */}
|
||||
|
|
@ -343,8 +330,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
|
||||
const assignedConfig = roleAllConfigs.find((config) => config.id === currentAssignment);
|
||||
const isAssigned = !!assignedConfig;
|
||||
const isAutoMode =
|
||||
assignedConfig && "is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode;
|
||||
|
||||
return (
|
||||
<div key={key}>
|
||||
|
|
@ -389,7 +374,7 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
<SelectTrigger className="w-full h-9 md:h-10 text-xs md:text-sm">
|
||||
<SelectValue placeholder="Select a configuration" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="max-w-[calc(100vw-2rem)]">
|
||||
<SelectContent className="max-w-[calc(100vw-2rem)] select-none">
|
||||
<SelectItem
|
||||
value="unassigned"
|
||||
className="text-xs md:text-sm py-1.5 md:py-2"
|
||||
|
|
@ -412,21 +397,9 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
className="text-xs md:text-sm py-1.5 md:py-2"
|
||||
>
|
||||
<div className="flex items-center gap-1 md:gap-1.5 flex-wrap min-w-0">
|
||||
{isAuto ? (
|
||||
<Shuffle className="size-3 md:size-3.5 shrink-0 text-muted-foreground" />
|
||||
) : (
|
||||
getProviderIcon(config.provider, {
|
||||
className: "size-3 md:size-3.5 shrink-0",
|
||||
})
|
||||
)}
|
||||
<span className="truncate text-xs md:text-sm">
|
||||
{config.name}
|
||||
</span>
|
||||
{!isAuto && (
|
||||
<span className="text-muted-foreground text-[10px] md:text-[11px] truncate">
|
||||
({config.model_name})
|
||||
</span>
|
||||
)}
|
||||
{isAuto && (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
|
|
@ -455,15 +428,9 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
className="text-xs md:text-sm py-1.5 md:py-2"
|
||||
>
|
||||
<div className="flex items-center gap-1 md:gap-1.5 flex-wrap min-w-0">
|
||||
{getProviderIcon(config.provider, {
|
||||
className: "size-3 md:size-3.5 shrink-0",
|
||||
})}
|
||||
<span className="truncate text-xs md:text-sm">
|
||||
{config.name}
|
||||
</span>
|
||||
<span className="text-muted-foreground text-[10px] md:text-[11px] truncate">
|
||||
({config.model_name})
|
||||
</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
|
|
@ -472,63 +439,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
{/* Assigned Config Summary */}
|
||||
{assignedConfig && (
|
||||
<div
|
||||
className={cn(
|
||||
"rounded-lg p-3 border",
|
||||
isAutoMode
|
||||
? "bg-violet-50 dark:bg-violet-900/10 border-violet-200/50 dark:border-violet-800/30"
|
||||
: "bg-muted/40 border-border/50"
|
||||
)}
|
||||
>
|
||||
{isAutoMode ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<Shuffle
|
||||
className={cn(
|
||||
"w-3.5 h-3.5 shrink-0 text-violet-600 dark:text-violet-400"
|
||||
)}
|
||||
/>
|
||||
<div className="min-w-0">
|
||||
<p className="text-xs font-medium text-violet-700 dark:text-violet-300">
|
||||
Auto Mode
|
||||
</p>
|
||||
<p className="text-[10px] text-violet-600/70 dark:text-violet-400/70 mt-0.5">
|
||||
Routes across all available providers
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-start gap-2">
|
||||
<IconComponent className="w-3.5 h-3.5 shrink-0 mt-0.5 text-muted-foreground" />
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex items-center gap-1.5 flex-wrap">
|
||||
<span className="text-xs font-medium">{assignedConfig.name}</span>
|
||||
{"is_global" in assignedConfig && assignedConfig.is_global && (
|
||||
<Badge variant="secondary" className="text-[9px] px-1.5 py-0">
|
||||
🌐 Global
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-1.5 mt-1">
|
||||
{getProviderIcon(assignedConfig.provider, {
|
||||
className: "size-3 shrink-0",
|
||||
})}
|
||||
<code className="text-[10px] text-muted-foreground font-mono truncate">
|
||||
{assignedConfig.model_name}
|
||||
</code>
|
||||
</div>
|
||||
{assignedConfig.api_base && (
|
||||
<p className="text-[10px] text-muted-foreground/60 mt-1 truncate">
|
||||
{assignedConfig.api_base}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
<span className="font-medium">
|
||||
{globalConfigs.length} global {globalConfigs.length === 1 ? "model" : "models"}
|
||||
</span>{" "}
|
||||
available from your administrator. Use the model selector to view and select them.
|
||||
available from your administrator.
|
||||
</p>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
|
|
|||
|
|
@ -113,11 +113,11 @@ export function MorePagesContent() {
|
|||
{isLoading ? (
|
||||
<Card>
|
||||
<CardContent className="flex items-center gap-3 p-3">
|
||||
<Skeleton className="h-8 w-8 rounded-full bg-muted" />
|
||||
<Skeleton className="h-8 w-8 rounded-full" />
|
||||
<div className="flex-1 space-y-2">
|
||||
<Skeleton className="h-4 w-3/4 bg-muted" />
|
||||
<Skeleton className="h-4 w-3/4" />
|
||||
</div>
|
||||
<Skeleton className="h-8 w-16 bg-muted" />
|
||||
<Skeleton className="h-8 w-16" />
|
||||
</CardContent>
|
||||
</Card>
|
||||
) : (
|
||||
|
|
|
|||
|
|
@ -1,7 +1,17 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom } from "jotai";
|
||||
import { Bot, Brain, Eye, FileText, Globe, ImageIcon, MessageSquare, Shield } from "lucide-react";
|
||||
import {
|
||||
BookText,
|
||||
Bot,
|
||||
Brain,
|
||||
CircleUser,
|
||||
Earth,
|
||||
Eye,
|
||||
ImageIcon,
|
||||
ListChecks,
|
||||
UserKey,
|
||||
} from "lucide-react";
|
||||
import dynamic from "next/dynamic";
|
||||
import { useTranslations } from "next-intl";
|
||||
import type React from "react";
|
||||
|
|
@ -59,6 +69,13 @@ const PublicChatSnapshotsManager = dynamic(
|
|||
})),
|
||||
{ ssr: false }
|
||||
);
|
||||
const TeamMemoryManager = dynamic(
|
||||
() =>
|
||||
import("@/components/settings/team-memory-manager").then((m) => ({
|
||||
default: m.TeamMemoryManager,
|
||||
})),
|
||||
{ ssr: false }
|
||||
);
|
||||
|
||||
interface SearchSpaceSettingsDialogProps {
|
||||
searchSpaceId: number;
|
||||
|
|
@ -69,9 +86,9 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
|||
const [state, setState] = useAtom(searchSpaceSettingsDialogAtom);
|
||||
|
||||
const navItems = [
|
||||
{ value: "general", label: t("nav_general"), icon: <FileText className="h-4 w-4" /> },
|
||||
{ value: "general", label: t("nav_general"), icon: <CircleUser className="h-4 w-4" /> },
|
||||
{ value: "roles", label: t("nav_role_assignments"), icon: <ListChecks className="h-4 w-4" /> },
|
||||
{ value: "models", label: t("nav_agent_configs"), icon: <Bot className="h-4 w-4" /> },
|
||||
{ value: "roles", label: t("nav_role_assignments"), icon: <Brain className="h-4 w-4" /> },
|
||||
{
|
||||
value: "image-models",
|
||||
label: t("nav_image_models"),
|
||||
|
|
@ -82,13 +99,18 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
|||
label: t("nav_vision_models"),
|
||||
icon: <Eye className="h-4 w-4" />,
|
||||
},
|
||||
{ value: "team-roles", label: t("nav_team_roles"), icon: <Shield className="h-4 w-4" /> },
|
||||
{ value: "team-roles", label: t("nav_team_roles"), icon: <UserKey className="h-4 w-4" /> },
|
||||
{
|
||||
value: "prompts",
|
||||
label: t("nav_system_instructions"),
|
||||
icon: <MessageSquare className="h-4 w-4" />,
|
||||
icon: <BookText className="h-4 w-4" />,
|
||||
},
|
||||
{ value: "public-links", label: t("nav_public_links"), icon: <Globe className="h-4 w-4" /> },
|
||||
{
|
||||
value: "team-memory",
|
||||
label: "Team Memory",
|
||||
icon: <Brain className="h-4 w-4" />,
|
||||
},
|
||||
{ value: "public-links", label: t("nav_public_links"), icon: <Earth className="h-4 w-4" /> },
|
||||
];
|
||||
|
||||
const content: Record<string, React.ReactNode> = {
|
||||
|
|
@ -99,6 +121,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
|||
"vision-models": <VisionModelManager searchSpaceId={searchSpaceId} />,
|
||||
"team-roles": <RolesManager searchSpaceId={searchSpaceId} />,
|
||||
prompts: <PromptConfigManager searchSpaceId={searchSpaceId} />,
|
||||
"team-memory": <TeamMemoryManager searchSpaceId={searchSpaceId} />,
|
||||
"public-links": <PublicChatSnapshotsManager searchSpaceId={searchSpaceId} />,
|
||||
};
|
||||
|
||||
|
|
|
|||
297
surfsense_web/components/settings/team-memory-manager.tsx
Normal file
297
surfsense_web/components/settings/team-memory-manager.tsx
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pen } from "lucide-react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { z } from "zod";
|
||||
import { updateSearchSpaceMutationAtom } from "@/atoms/search-spaces/search-space-mutation.atoms";
|
||||
import { PlateEditor } from "@/components/editor/plate-editor";
|
||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
|
||||
const MEMORY_HARD_LIMIT = 25_000;
|
||||
|
||||
const SearchSpaceSchema = z
|
||||
.object({
|
||||
shared_memory_md: z.string().optional().default(""),
|
||||
})
|
||||
.passthrough();
|
||||
|
||||
interface TeamMemoryManagerProps {
|
||||
searchSpaceId: number;
|
||||
}
|
||||
|
||||
export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) {
|
||||
const queryClient = useQueryClient();
|
||||
const { data: searchSpace, isLoading: loading } = useQuery({
|
||||
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()),
|
||||
queryFn: () => searchSpacesApiService.getSearchSpace({ id: searchSpaceId }),
|
||||
enabled: !!searchSpaceId,
|
||||
});
|
||||
|
||||
const { mutateAsync: updateSearchSpace } = useAtomValue(updateSearchSpaceMutationAtom);
|
||||
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [editQuery, setEditQuery] = useState("");
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [showInput, setShowInput] = useState(false);
|
||||
const textareaRef = useRef<HTMLInputElement>(null);
|
||||
const inputContainerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const memory = searchSpace?.shared_memory_md || "";
|
||||
|
||||
useEffect(() => {
|
||||
if (!showInput) return;
|
||||
|
||||
const handlePointerDownOutside = (event: MouseEvent | TouchEvent) => {
|
||||
const target = event.target;
|
||||
if (!(target instanceof Node)) return;
|
||||
if (inputContainerRef.current?.contains(target)) return;
|
||||
|
||||
setShowInput(false);
|
||||
};
|
||||
|
||||
document.addEventListener("mousedown", handlePointerDownOutside);
|
||||
document.addEventListener("touchstart", handlePointerDownOutside, { passive: true });
|
||||
|
||||
return () => {
|
||||
document.removeEventListener("mousedown", handlePointerDownOutside);
|
||||
document.removeEventListener("touchstart", handlePointerDownOutside);
|
||||
};
|
||||
}, [showInput]);
|
||||
|
||||
const handleClear = async () => {
|
||||
try {
|
||||
setSaving(true);
|
||||
await updateSearchSpace({
|
||||
id: searchSpaceId,
|
||||
data: { shared_memory_md: "" },
|
||||
});
|
||||
toast.success("Team memory cleared");
|
||||
} catch {
|
||||
toast.error("Failed to clear team memory");
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleEdit = async () => {
|
||||
const query = editQuery.trim();
|
||||
if (!query) return;
|
||||
|
||||
try {
|
||||
setEditing(true);
|
||||
await baseApiService.post(
|
||||
`/api/v1/searchspaces/${searchSpaceId}/memory/edit`,
|
||||
SearchSpaceSchema,
|
||||
{ body: { query } }
|
||||
);
|
||||
setEditQuery("");
|
||||
setShowInput(false);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()),
|
||||
});
|
||||
toast.success("Team memory updated");
|
||||
} catch {
|
||||
toast.error("Failed to edit team memory");
|
||||
} finally {
|
||||
setEditing(false);
|
||||
}
|
||||
};
|
||||
|
||||
const openInput = () => {
|
||||
setShowInput(true);
|
||||
requestAnimationFrame(() => textareaRef.current?.focus());
|
||||
};
|
||||
|
||||
const handleDownload = () => {
|
||||
if (!memory) return;
|
||||
try {
|
||||
const blob = new Blob([memory], { type: "text/markdown;charset=utf-8" });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = "team-memory.md";
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
} catch {
|
||||
toast.error("Failed to download team memory");
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopyMarkdown = async () => {
|
||||
if (!memory) return;
|
||||
try {
|
||||
await navigator.clipboard.writeText(memory);
|
||||
toast.success("Copied to clipboard");
|
||||
} catch {
|
||||
toast.error("Failed to copy team memory");
|
||||
}
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
handleEdit();
|
||||
}
|
||||
};
|
||||
|
||||
const displayMemory = memory.replace(/\(\d{4}-\d{2}-\d{2}\)\s*\[(fact|pref|instr)\]\s*/g, "");
|
||||
const charCount = memory.length;
|
||||
|
||||
const getCounterColor = () => {
|
||||
if (charCount > MEMORY_HARD_LIMIT) return "text-red-500";
|
||||
if (charCount > 15_000) return "text-orange-500";
|
||||
if (charCount > 10_000) return "text-yellow-500";
|
||||
return "text-muted-foreground";
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Spinner size="md" className="text-muted-foreground" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!memory) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-16 text-center">
|
||||
<h3 className="text-base font-medium text-foreground">
|
||||
What does SurfSense remember about your team?
|
||||
</h3>
|
||||
<p className="mt-2 max-w-sm text-sm text-muted-foreground">
|
||||
Nothing yet. SurfSense picks up on team decisions and conventions as your team chats.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<Alert className="bg-muted/50 py-3 md:py-4">
|
||||
<Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
|
||||
<AlertDescription className="text-xs md:text-sm">
|
||||
<p>
|
||||
SurfSense uses this shared memory to provide team-wide context across all conversations
|
||||
in this search space.
|
||||
</p>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<div className="relative h-[380px] rounded-lg border bg-background">
|
||||
<div className="h-full overflow-y-auto scrollbar-thin">
|
||||
<PlateEditor
|
||||
markdown={displayMemory}
|
||||
readOnly
|
||||
preset="readonly"
|
||||
variant="default"
|
||||
editorVariant="none"
|
||||
className="px-5 py-4 text-sm min-h-full"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{showInput ? (
|
||||
<div className="absolute bottom-3 inset-x-3 z-10">
|
||||
<div
|
||||
ref={inputContainerRef}
|
||||
className="relative flex h-[54px] items-center gap-2 rounded-[9999px] border bg-muted/60 backdrop-blur-sm pl-4 pr-1 shadow-sm"
|
||||
>
|
||||
<input
|
||||
ref={textareaRef}
|
||||
type="text"
|
||||
value={editQuery}
|
||||
onChange={(e) => setEditQuery(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="Tell SurfSense what to remember or forget about your team"
|
||||
disabled={editing}
|
||||
className="flex-1 bg-transparent text-sm outline-none placeholder:text-muted-foreground/70"
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
onClick={handleEdit}
|
||||
disabled={editing || !editQuery.trim()}
|
||||
className={`h-11 w-11 shrink-0 rounded-full ${
|
||||
editing ? "" : "bg-muted-foreground/15 hover:bg-muted-foreground/20"
|
||||
}`}
|
||||
>
|
||||
{editing ? (
|
||||
<Spinner size="sm" />
|
||||
) : (
|
||||
<ArrowUp className="!h-5 !w-5 text-foreground" strokeWidth={2.25} />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<Button
|
||||
type="button"
|
||||
size="icon"
|
||||
variant="secondary"
|
||||
onClick={openInput}
|
||||
className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm"
|
||||
>
|
||||
<Pen className="!h-5 !w-5" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className={`text-xs shrink-0 ${getCounterColor()}`}>
|
||||
{charCount.toLocaleString()} / {MEMORY_HARD_LIMIT.toLocaleString()}
|
||||
<span className="hidden sm:inline"> characters</span>
|
||||
<span className="sm:hidden"> chars</span>
|
||||
{charCount > 15_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
|
||||
{charCount > MEMORY_HARD_LIMIT && " - Exceeds limit"}
|
||||
</span>
|
||||
<div className="flex items-center gap-1.5 sm:gap-2">
|
||||
<Button
|
||||
type="button"
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
className="text-xs sm:text-sm"
|
||||
onClick={handleClear}
|
||||
disabled={saving || editing || !memory}
|
||||
>
|
||||
<span className="hidden sm:inline">Reset Memory</span>
|
||||
<span className="sm:hidden">Reset</span>
|
||||
</Button>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button type="button" variant="secondary" size="sm" disabled={!memory}>
|
||||
Export
|
||||
<ChevronDown className="h-3 w-3 opacity-60" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem onClick={handleCopyMarkdown}>
|
||||
<ClipboardCopy className="h-4 w-4 mr-2" />
|
||||
Copy as Markdown
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={handleDownload}>
|
||||
<Download className="h-4 w-4 mr-2" />
|
||||
Download as Markdown
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom } from "jotai";
|
||||
import { Globe, KeyRound, Monitor, Receipt, Sparkles, User } from "lucide-react";
|
||||
import { Brain, CircleUser, Globe, KeyRound, Monitor, ReceiptText, Sparkles } from "lucide-react";
|
||||
import dynamic from "next/dynamic";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { useMemo } from "react";
|
||||
|
|
@ -51,6 +51,13 @@ const DesktopContent = dynamic(
|
|||
),
|
||||
{ ssr: false }
|
||||
);
|
||||
const MemoryContent = dynamic(
|
||||
() =>
|
||||
import("@/app/dashboard/[search_space_id]/user-settings/components/MemoryContent").then(
|
||||
(m) => ({ default: m.MemoryContent })
|
||||
),
|
||||
{ ssr: false }
|
||||
);
|
||||
|
||||
export function UserSettingsDialog() {
|
||||
const t = useTranslations("userSettings");
|
||||
|
|
@ -59,7 +66,7 @@ export function UserSettingsDialog() {
|
|||
|
||||
const navItems = useMemo(
|
||||
() => [
|
||||
{ value: "profile", label: t("profile_nav_label"), icon: <User className="h-4 w-4" /> },
|
||||
{ value: "profile", label: t("profile_nav_label"), icon: <CircleUser className="h-4 w-4" /> },
|
||||
{
|
||||
value: "api-key",
|
||||
label: t("api_key_nav_label"),
|
||||
|
|
@ -75,10 +82,15 @@ export function UserSettingsDialog() {
|
|||
label: "Community Prompts",
|
||||
icon: <Globe className="h-4 w-4" />,
|
||||
},
|
||||
{
|
||||
value: "memory",
|
||||
label: "Memory",
|
||||
icon: <Brain className="h-4 w-4" />,
|
||||
},
|
||||
{
|
||||
value: "purchases",
|
||||
label: "Purchase History",
|
||||
icon: <Receipt className="h-4 w-4" />,
|
||||
icon: <ReceiptText className="h-4 w-4" />,
|
||||
},
|
||||
...(isDesktop
|
||||
? [{ value: "desktop", label: "Desktop", icon: <Monitor className="h-4 w-4" /> }]
|
||||
|
|
@ -101,6 +113,7 @@ export function UserSettingsDialog() {
|
|||
{state.initialTab === "api-key" && <ApiKeyContent />}
|
||||
{state.initialTab === "prompts" && <PromptsContent />}
|
||||
{state.initialTab === "community-prompts" && <CommunityPromptsContent />}
|
||||
{state.initialTab === "memory" && <MemoryContent />}
|
||||
{state.initialTab === "purchases" && <PurchaseHistoryContent />}
|
||||
{state.initialTab === "desktop" && <DesktopContent />}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
? "model"
|
||||
: "models"}
|
||||
</span>{" "}
|
||||
available from your administrator. Use the model selector to view and select them.
|
||||
available from your administrator.
|
||||
</p>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
|
|
|||
|
|
@ -51,17 +51,11 @@ export {
|
|||
SandboxExecuteToolUI,
|
||||
} from "./sandbox-execute";
|
||||
export {
|
||||
type MemoryItem,
|
||||
type RecallMemoryArgs,
|
||||
RecallMemoryArgsSchema,
|
||||
type RecallMemoryResult,
|
||||
RecallMemoryResultSchema,
|
||||
RecallMemoryToolUI,
|
||||
type SaveMemoryArgs,
|
||||
SaveMemoryArgsSchema,
|
||||
type SaveMemoryResult,
|
||||
SaveMemoryResultSchema,
|
||||
SaveMemoryToolUI,
|
||||
type UpdateMemoryArgs,
|
||||
UpdateMemoryArgsSchema,
|
||||
type UpdateMemoryResult,
|
||||
UpdateMemoryResultSchema,
|
||||
UpdateMemoryToolUI,
|
||||
} from "./user-memory";
|
||||
export { GenerateVideoPresentationToolUI } from "./video-presentation";
|
||||
export { type WriteTodosData, WriteTodosSchema, WriteTodosToolUI } from "./write-todos";
|
||||
|
|
|
|||
|
|
@ -1,100 +1,38 @@
|
|||
"use client";
|
||||
|
||||
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
|
||||
import { BrainIcon, CheckIcon, Loader2Icon, SearchIcon, XIcon } from "lucide-react";
|
||||
import { AlertTriangleIcon, BrainIcon, CheckIcon, Loader2Icon, XIcon } from "lucide-react";
|
||||
import { z } from "zod";
|
||||
|
||||
// ============================================================================
|
||||
// Zod Schemas for save_memory tool
|
||||
// Zod Schemas for update_memory tool
|
||||
// ============================================================================
|
||||
|
||||
const SaveMemoryArgsSchema = z.object({
|
||||
content: z.string(),
|
||||
category: z.string().default("fact"),
|
||||
const UpdateMemoryArgsSchema = z.object({
|
||||
updated_memory: z.string(),
|
||||
});
|
||||
|
||||
const SaveMemoryResultSchema = z.object({
|
||||
const UpdateMemoryResultSchema = z.object({
|
||||
status: z.enum(["saved", "error"]),
|
||||
memory_id: z.number().nullish(),
|
||||
memory_text: z.string().nullish(),
|
||||
category: z.string().nullish(),
|
||||
message: z.string().nullish(),
|
||||
error: z.string().nullish(),
|
||||
warning: z.string().nullish(),
|
||||
});
|
||||
|
||||
type SaveMemoryArgs = z.infer<typeof SaveMemoryArgsSchema>;
|
||||
type SaveMemoryResult = z.infer<typeof SaveMemoryResultSchema>;
|
||||
type UpdateMemoryArgs = z.infer<typeof UpdateMemoryArgsSchema>;
|
||||
type UpdateMemoryResult = z.infer<typeof UpdateMemoryResultSchema>;
|
||||
|
||||
// ============================================================================
|
||||
// Zod Schemas for recall_memory tool
|
||||
// Update Memory Tool UI
|
||||
// ============================================================================
|
||||
|
||||
const RecallMemoryArgsSchema = z.object({
|
||||
query: z.string().nullish(),
|
||||
category: z.string().nullish(),
|
||||
top_k: z.number().default(5),
|
||||
});
|
||||
|
||||
const MemoryItemSchema = z.object({
|
||||
id: z.number(),
|
||||
memory_text: z.string(),
|
||||
category: z.string(),
|
||||
updated_at: z.string().nullish(),
|
||||
});
|
||||
|
||||
const RecallMemoryResultSchema = z.object({
|
||||
status: z.enum(["success", "error"]),
|
||||
count: z.number().nullish(),
|
||||
memories: z.array(MemoryItemSchema).nullish(),
|
||||
formatted_context: z.string().nullish(),
|
||||
error: z.string().nullish(),
|
||||
});
|
||||
|
||||
type RecallMemoryArgs = z.infer<typeof RecallMemoryArgsSchema>;
|
||||
type RecallMemoryResult = z.infer<typeof RecallMemoryResultSchema>;
|
||||
type MemoryItem = z.infer<typeof MemoryItemSchema>;
|
||||
|
||||
// ============================================================================
|
||||
// Category badge colors
|
||||
// ============================================================================
|
||||
|
||||
const categoryColors: Record<string, string> = {
|
||||
preference: "bg-blue-500/10 text-blue-600 dark:text-blue-400",
|
||||
fact: "bg-green-500/10 text-green-600 dark:text-green-400",
|
||||
instruction: "bg-purple-500/10 text-purple-600 dark:text-purple-400",
|
||||
context: "bg-orange-500/10 text-orange-600 dark:text-orange-400",
|
||||
};
|
||||
|
||||
function CategoryBadge({ category }: { category: string }) {
|
||||
const colorClass = categoryColors[category] || "bg-gray-500/10 text-gray-600 dark:text-gray-400";
|
||||
return (
|
||||
<span
|
||||
className={`inline-flex items-center rounded-full px-2 py-0.5 text-xs font-medium ${colorClass}`}
|
||||
>
|
||||
{category}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Save Memory Tool UI
|
||||
// ============================================================================
|
||||
|
||||
export const SaveMemoryToolUI = ({
|
||||
args,
|
||||
export const UpdateMemoryToolUI = ({
|
||||
result,
|
||||
status,
|
||||
}: ToolCallMessagePartProps<SaveMemoryArgs, SaveMemoryResult>) => {
|
||||
}: ToolCallMessagePartProps<UpdateMemoryArgs, UpdateMemoryResult>) => {
|
||||
const isRunning = status.type === "running" || status.type === "requires-action";
|
||||
const isComplete = status.type === "complete";
|
||||
const isError = result?.status === "error";
|
||||
|
||||
// Parse args safely
|
||||
const parsedArgs = SaveMemoryArgsSchema.safeParse(args);
|
||||
const content = parsedArgs.success ? parsedArgs.data.content : "";
|
||||
const category = parsedArgs.success ? parsedArgs.data.category : "fact";
|
||||
|
||||
// Loading state
|
||||
if (isRunning) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
|
|
@ -102,13 +40,12 @@ export const SaveMemoryToolUI = ({
|
|||
<Loader2Icon className="size-4 animate-spin text-primary" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-sm text-muted-foreground">Saving to memory...</span>
|
||||
<span className="text-sm text-muted-foreground">Updating memory...</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Error state
|
||||
if (isError) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border border-destructive/20 bg-destructive/5 px-4 py-3">
|
||||
|
|
@ -116,14 +53,13 @@ export const SaveMemoryToolUI = ({
|
|||
<XIcon className="size-4 text-destructive" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-sm text-destructive">Failed to save memory</span>
|
||||
{result?.error && <p className="mt-1 text-xs text-destructive/70">{result.error}</p>}
|
||||
<span className="text-sm text-destructive">Failed to update memory</span>
|
||||
{result?.message && <p className="mt-1 text-xs text-destructive/70">{result.message}</p>}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Success state
|
||||
if (isComplete && result?.status === "saved") {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border border-primary/20 bg-primary/5 px-4 py-3">
|
||||
|
|
@ -133,138 +69,19 @@ export const SaveMemoryToolUI = ({
|
|||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2">
|
||||
<CheckIcon className="size-3 text-green-500 shrink-0" />
|
||||
<span className="text-sm font-medium text-foreground">Memory saved</span>
|
||||
<CategoryBadge category={category} />
|
||||
<span className="text-sm font-medium text-foreground">Memory updated</span>
|
||||
</div>
|
||||
<p className="mt-1 truncate text-sm text-muted-foreground">{content}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Default/incomplete state - show what's being saved
|
||||
if (content) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
||||
<BrainIcon className="size-4 text-muted-foreground" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-muted-foreground">Saving memory</span>
|
||||
<CategoryBadge category={category} />
|
||||
</div>
|
||||
<p className="mt-1 truncate text-sm text-muted-foreground">{content}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Recall Memory Tool UI
|
||||
// ============================================================================
|
||||
|
||||
export const RecallMemoryToolUI = ({
|
||||
args,
|
||||
result,
|
||||
status,
|
||||
}: ToolCallMessagePartProps<RecallMemoryArgs, RecallMemoryResult>) => {
|
||||
const isRunning = status.type === "running" || status.type === "requires-action";
|
||||
const isComplete = status.type === "complete";
|
||||
const isError = result?.status === "error";
|
||||
|
||||
// Parse args safely
|
||||
const parsedArgs = RecallMemoryArgsSchema.safeParse(args);
|
||||
const query = parsedArgs.success ? parsedArgs.data.query : null;
|
||||
|
||||
// Loading state
|
||||
if (isRunning) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-primary/10">
|
||||
<Loader2Icon className="size-4 animate-spin text-primary" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-sm text-muted-foreground">
|
||||
{query ? `Searching memories for "${query}"...` : "Recalling memories..."}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Error state
|
||||
if (isError) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border border-destructive/20 bg-destructive/5 px-4 py-3">
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-destructive/10">
|
||||
<XIcon className="size-4 text-destructive" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-sm text-destructive">Failed to recall memories</span>
|
||||
{result?.error && <p className="mt-1 text-xs text-destructive/70">{result.error}</p>}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Success state with memories
|
||||
if (isComplete && result?.status === "success") {
|
||||
const memories = result.memories || [];
|
||||
const count = result.count || 0;
|
||||
|
||||
if (count === 0) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
||||
<SearchIcon className="size-4 text-muted-foreground" />
|
||||
</div>
|
||||
<span className="text-sm text-muted-foreground">No memories found</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="my-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<BrainIcon className="size-4 text-primary" />
|
||||
<span className="text-sm font-medium text-foreground">
|
||||
Recalled {count} {count === 1 ? "memory" : "memories"}
|
||||
</span>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
{memories.slice(0, 5).map((memory: MemoryItem) => (
|
||||
<div
|
||||
key={memory.id}
|
||||
className="flex items-start gap-2 rounded-md bg-muted/50 px-3 py-2"
|
||||
>
|
||||
<CategoryBadge category={memory.category} />
|
||||
<span className="text-sm text-muted-foreground flex-1">{memory.memory_text}</span>
|
||||
{result.warning && (
|
||||
<div className="mt-1.5 flex items-start gap-1.5">
|
||||
<AlertTriangleIcon className="size-3 text-yellow-500 shrink-0 mt-0.5" />
|
||||
<p className="text-xs text-yellow-600 dark:text-yellow-400">{result.warning}</p>
|
||||
</div>
|
||||
))}
|
||||
{memories.length > 5 && (
|
||||
<p className="text-xs text-muted-foreground">...and {memories.length - 5} more</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Default/incomplete state
|
||||
if (query) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
||||
<SearchIcon className="size-4 text-muted-foreground" />
|
||||
</div>
|
||||
<span className="text-sm text-muted-foreground">Searching memories for "{query}"</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
|
|
@ -273,13 +90,8 @@ export const RecallMemoryToolUI = ({
|
|||
// ============================================================================
|
||||
|
||||
export {
|
||||
SaveMemoryArgsSchema,
|
||||
SaveMemoryResultSchema,
|
||||
RecallMemoryArgsSchema,
|
||||
RecallMemoryResultSchema,
|
||||
type SaveMemoryArgs,
|
||||
type SaveMemoryResult,
|
||||
type RecallMemoryArgs,
|
||||
type RecallMemoryResult,
|
||||
type MemoryItem,
|
||||
UpdateMemoryArgsSchema,
|
||||
UpdateMemoryResultSchema,
|
||||
type UpdateMemoryArgs,
|
||||
type UpdateMemoryResult,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ function AlertDialogContent({
|
|||
<AlertDialogPrimitive.Content
|
||||
data-slot="alert-dialog-content"
|
||||
className={cn(
|
||||
"bg-background dark:bg-neutral-900 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-xl p-6 shadow-2xl duration-200 sm:max-w-lg",
|
||||
"bg-background dark:bg-neutral-900 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-xl p-6 shadow-2xl duration-200 sm:max-w-lg select-none",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ const DialogContent = React.forwardRef<
|
|||
<DialogPrimitive.Content
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 bg-background dark:bg-neutral-900 p-6 shadow-2xl duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 rounded-xl focus:outline-none focus:ring-0 focus-visible:outline-none focus-visible:ring-0",
|
||||
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 bg-background dark:bg-neutral-900 p-6 shadow-2xl duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 rounded-xl focus:outline-none focus:ring-0 focus-visible:outline-none focus-visible:ring-0 select-none",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
|
|
|
|||
|
|
@ -54,8 +54,8 @@ const editorVariants = cva(
|
|||
cn(
|
||||
"group/editor",
|
||||
"relative w-full cursor-text select-text overflow-x-hidden whitespace-pre-wrap break-words",
|
||||
"rounded-md ring-offset-background focus-visible:outline-none",
|
||||
"**:data-slate-placeholder:!top-1/2 **:data-slate-placeholder:-translate-y-1/2 placeholder:text-muted-foreground/80 **:data-slate-placeholder:text-muted-foreground/80 **:data-slate-placeholder:opacity-100!",
|
||||
"rounded-none ring-offset-background focus-visible:outline-none",
|
||||
"placeholder:text-muted-foreground/80 **:data-slate-placeholder:text-muted-foreground/80 **:data-slate-placeholder:py-1",
|
||||
"[&_strong]:font-bold"
|
||||
),
|
||||
{
|
||||
|
|
|
|||
|
|
@ -65,8 +65,9 @@ export function FloatingToolbar({
|
|||
{...rootProps}
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"scrollbar-hide absolute z-50 overflow-x-auto whitespace-nowrap rounded-md border bg-popover p-1 opacity-100 shadow-md print:hidden",
|
||||
"scrollbar-hide absolute z-50 overflow-x-auto whitespace-nowrap rounded-md border dark:border-neutral-700 bg-muted p-1 opacity-100 shadow-md print:hidden",
|
||||
"max-w-[80vw]",
|
||||
"[&_button:hover]:bg-neutral-200 dark:[&_button:hover]:bg-neutral-700",
|
||||
className
|
||||
)}
|
||||
>
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import type { PlateElementProps } from "platejs/react";
|
|||
import { PlateElement } from "platejs/react";
|
||||
import * as React from "react";
|
||||
|
||||
const headingVariants = cva("relative mb-1", {
|
||||
const headingVariants = cva("relative mb-1 first:mt-0", {
|
||||
variants: {
|
||||
variant: {
|
||||
h1: "mt-[1.6em] pb-1 font-bold font-heading text-4xl",
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ const TOOL_ICONS: Record<string, LucideIcon> = {
|
|||
scrape_webpage: ScanLine,
|
||||
web_search: Globe,
|
||||
search_surfsense_docs: BookOpen,
|
||||
save_memory: Brain,
|
||||
recall_memory: Brain,
|
||||
update_memory: Brain,
|
||||
};
|
||||
|
||||
export function getToolIcon(name: string): LucideIcon {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ export const searchSpace = z.object({
|
|||
user_id: z.string(),
|
||||
citations_enabled: z.boolean(),
|
||||
qna_custom_instructions: z.string().nullable(),
|
||||
shared_memory_md: z.string().nullable().optional(),
|
||||
member_count: z.number(),
|
||||
is_owner: z.boolean(),
|
||||
});
|
||||
|
|
@ -54,6 +55,7 @@ export const updateSearchSpaceRequest = z.object({
|
|||
description: true,
|
||||
citations_enabled: true,
|
||||
qna_custom_instructions: true,
|
||||
shared_memory_md: true,
|
||||
})
|
||||
.partial(),
|
||||
});
|
||||
|
|
|
|||
108
surfsense_web/lib/desktop-download-utils.ts
Normal file
108
surfsense_web/lib/desktop-download-utils.ts
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
import { useEffect, useMemo, useState } from "react";
|
||||
|
||||
export type OSInfo = {
|
||||
os: "macOS" | "Windows" | "Linux";
|
||||
arch: "arm64" | "x64";
|
||||
};
|
||||
|
||||
export function useUserOS(): OSInfo {
|
||||
const [info, setInfo] = useState<OSInfo>({ os: "macOS", arch: "arm64" });
|
||||
useEffect(() => {
|
||||
const ua = navigator.userAgent;
|
||||
let os: OSInfo["os"] = "macOS";
|
||||
let arch: OSInfo["arch"] = "x64";
|
||||
|
||||
if (/Windows/i.test(ua)) {
|
||||
os = "Windows";
|
||||
arch = "x64";
|
||||
} else if (/Linux/i.test(ua)) {
|
||||
os = "Linux";
|
||||
arch = "x64";
|
||||
} else {
|
||||
os = "macOS";
|
||||
arch = /Mac/.test(ua) && !/Intel/.test(ua) ? "arm64" : "arm64";
|
||||
}
|
||||
|
||||
const uaData = (navigator as Navigator & { userAgentData?: { architecture?: string } })
|
||||
.userAgentData;
|
||||
if (uaData?.architecture === "arm") arch = "arm64";
|
||||
else if (uaData?.architecture === "x86") arch = "x64";
|
||||
|
||||
setInfo({ os, arch });
|
||||
}, []);
|
||||
return info;
|
||||
}
|
||||
|
||||
export interface ReleaseAsset {
|
||||
name: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export function useLatestRelease() {
|
||||
const [assets, setAssets] = useState<ReleaseAsset[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
fetch("https://api.github.com/repos/MODSetter/SurfSense/releases/latest", {
|
||||
signal: controller.signal,
|
||||
})
|
||||
.then((r) => r.json())
|
||||
.then((data) => {
|
||||
if (data?.assets) {
|
||||
setAssets(
|
||||
data.assets
|
||||
.filter((a: { name: string }) => /\.(exe|dmg|AppImage|deb)$/.test(a.name))
|
||||
.map((a: { name: string; browser_download_url: string }) => ({
|
||||
name: a.name,
|
||||
url: a.browser_download_url,
|
||||
}))
|
||||
);
|
||||
}
|
||||
})
|
||||
.catch(() => {});
|
||||
return () => controller.abort();
|
||||
}, []);
|
||||
|
||||
return assets;
|
||||
}
|
||||
|
||||
export const ASSET_LABELS: Record<string, string> = {
|
||||
".exe": "Windows (exe)",
|
||||
"-arm64.dmg": "macOS Apple Silicon (dmg)",
|
||||
"-x64.dmg": "macOS Intel (dmg)",
|
||||
"-arm64.zip": "macOS Apple Silicon (zip)",
|
||||
"-x64.zip": "macOS Intel (zip)",
|
||||
".AppImage": "Linux (AppImage)",
|
||||
".deb": "Linux (deb)",
|
||||
};
|
||||
|
||||
export function getAssetLabel(name: string): string {
|
||||
for (const [suffix, label] of Object.entries(ASSET_LABELS)) {
|
||||
if (name.endsWith(suffix)) return label;
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
export const GITHUB_RELEASES_URL = "https://github.com/MODSetter/SurfSense/releases/latest";
|
||||
|
||||
export function usePrimaryDownload() {
|
||||
const { os, arch } = useUserOS();
|
||||
const assets = useLatestRelease();
|
||||
|
||||
const { primary, alternatives } = useMemo(() => {
|
||||
if (assets.length === 0) return { primary: null, alternatives: [] };
|
||||
|
||||
const matchers: Record<string, (n: string) => boolean> = {
|
||||
Windows: (n) => n.endsWith(".exe"),
|
||||
macOS: (n) => n.endsWith(`-${arch}.dmg`),
|
||||
Linux: (n) => n.endsWith(".AppImage"),
|
||||
};
|
||||
|
||||
const match = matchers[os];
|
||||
const primary = assets.find((a) => match(a.name)) ?? null;
|
||||
const alternatives = assets.filter((a) => a !== primary);
|
||||
return { primary, alternatives };
|
||||
}, [assets, os, arch]);
|
||||
|
||||
return { os, arch, assets, primary, alternatives };
|
||||
}
|
||||
|
|
@ -693,6 +693,7 @@
|
|||
"learn_more": "Learn more",
|
||||
"documentation": "Documentation",
|
||||
"github": "GitHub",
|
||||
"download_for_os": "Download for {os}",
|
||||
"inbox": "Inbox",
|
||||
"search_inbox": "Search inbox",
|
||||
"mark_all_read": "Mark all as read",
|
||||
|
|
|
|||
|
|
@ -693,6 +693,7 @@
|
|||
"learn_more": "Más información",
|
||||
"documentation": "Documentación",
|
||||
"github": "GitHub",
|
||||
"download_for_os": "Descargar para {os}",
|
||||
"inbox": "Bandeja de entrada",
|
||||
"search_inbox": "Buscar en bandeja de entrada",
|
||||
"mark_all_read": "Marcar todo como leído",
|
||||
|
|
|
|||
|
|
@ -693,6 +693,7 @@
|
|||
"learn_more": "और जानें",
|
||||
"documentation": "दस्तावेज़ीकरण",
|
||||
"github": "GitHub",
|
||||
"download_for_os": "{os} के लिए डाउनलोड करें",
|
||||
"inbox": "इनबॉक्स",
|
||||
"search_inbox": "इनबॉक्स में खोजें",
|
||||
"mark_all_read": "सभी पढ़ा हुआ चिह्नित करें",
|
||||
|
|
|
|||
|
|
@ -693,6 +693,7 @@
|
|||
"learn_more": "Saiba mais",
|
||||
"documentation": "Documentação",
|
||||
"github": "GitHub",
|
||||
"download_for_os": "Baixar para {os}",
|
||||
"inbox": "Caixa de entrada",
|
||||
"search_inbox": "Pesquisar caixa de entrada",
|
||||
"mark_all_read": "Marcar tudo como lido",
|
||||
|
|
|
|||
|
|
@ -677,6 +677,7 @@
|
|||
"learn_more": "了解更多",
|
||||
"documentation": "文档",
|
||||
"github": "GitHub",
|
||||
"download_for_os": "下载 {os} 版本",
|
||||
"inbox": "收件箱",
|
||||
"search_inbox": "搜索收件箱",
|
||||
"mark_all_read": "全部标记为已读",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue