mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
Merge upstream/dev into feat/kb-export-and-folder-upload
This commit is contained in:
commit
c30cc08771
61 changed files with 2670 additions and 1474 deletions
|
|
@ -0,0 +1,38 @@
|
|||
"""Add memory_md columns to user and searchspaces tables
|
||||
|
||||
Revision ID: 121
|
||||
Revises: 120
|
||||
|
||||
Changes:
|
||||
1. Add memory_md TEXT column to user table (personal memory)
|
||||
2. Add shared_memory_md TEXT column to searchspaces table (team memory)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "121"
|
||||
down_revision: str | None = "120"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("memory_md", sa.Text(), nullable=True, server_default=""),
|
||||
)
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("shared_memory_md", sa.Text(), nullable=True, server_default=""),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("searchspaces", "shared_memory_md")
|
||||
op.drop_column("user", "memory_md")
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
"""Migrate row-per-fact memories to markdown, then drop legacy tables
|
||||
|
||||
Revision ID: 122
|
||||
Revises: 121
|
||||
|
||||
Converts user_memories rows into per-user markdown documents stored in
|
||||
user.memory_md, and shared_memories rows into per-search-space markdown
|
||||
stored in searchspaces.shared_memory_md. Then drops the old tables and
|
||||
the memorycategory enum.
|
||||
|
||||
The markdown format matches the new memory system:
|
||||
## Heading
|
||||
- (YYYY-MM-DD) [fact|pref|instr] memory text
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
from alembic import op
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
revision: str = "122"
|
||||
down_revision: str | None = "121"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
EMBEDDING_DIM = config.embedding_model_instance.dimension
|
||||
|
||||
_CATEGORY_TO_MARKER = {
|
||||
"fact": "fact",
|
||||
"context": "fact",
|
||||
"preference": "pref",
|
||||
"instruction": "instr",
|
||||
}
|
||||
|
||||
_CATEGORY_HEADING = {
|
||||
"fact": "Facts",
|
||||
"preference": "Preferences",
|
||||
"instruction": "Instructions",
|
||||
"context": "Context",
|
||||
}
|
||||
|
||||
_HEADING_ORDER = ["fact", "preference", "instruction", "context"]
|
||||
|
||||
|
||||
def _build_markdown(rows: list[tuple]) -> str:
|
||||
"""Build a markdown document from (memory_text, category, created_at) rows."""
|
||||
by_category: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for memory_text, category, created_at in rows:
|
||||
cat = str(category)
|
||||
marker = _CATEGORY_TO_MARKER.get(cat, "fact")
|
||||
date_str = created_at.strftime("%Y-%m-%d")
|
||||
clean_text = str(memory_text).replace("\n", " ").strip()
|
||||
bullet = f"- ({date_str}) [{marker}] {clean_text}"
|
||||
by_category[cat].append(bullet)
|
||||
|
||||
sections: list[str] = []
|
||||
for cat in _HEADING_ORDER:
|
||||
if cat in by_category:
|
||||
heading = _CATEGORY_HEADING[cat]
|
||||
sections.append(f"## {heading}")
|
||||
sections.extend(by_category[cat])
|
||||
sections.append("")
|
||||
|
||||
return "\n".join(sections).strip() + "\n"
|
||||
|
||||
|
||||
def _migrate_user_memories(conn: sa.engine.Connection) -> None:
|
||||
"""Convert user_memories rows → user.memory_md grouped by user_id."""
|
||||
rows = conn.execute(
|
||||
sa.text(
|
||||
"SELECT user_id, memory_text, category::text, created_at "
|
||||
"FROM user_memories ORDER BY created_at"
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
logger.info("user_memories is empty, skipping data migration.")
|
||||
return
|
||||
|
||||
by_user: dict[UUID, list[tuple]] = defaultdict(list)
|
||||
for user_id, memory_text, category, created_at in rows:
|
||||
by_user[user_id].append((memory_text, category, created_at))
|
||||
|
||||
migrated = 0
|
||||
for uid, user_rows in by_user.items():
|
||||
existing = conn.execute(
|
||||
sa.text('SELECT memory_md FROM "user" WHERE id = :uid'),
|
||||
{"uid": uid},
|
||||
).scalar()
|
||||
|
||||
if existing and existing.strip():
|
||||
logger.info("User %s already has memory_md, skipping.", uid)
|
||||
continue
|
||||
|
||||
markdown = _build_markdown(user_rows)
|
||||
conn.execute(
|
||||
sa.text('UPDATE "user" SET memory_md = :md WHERE id = :uid'),
|
||||
{"md": markdown, "uid": uid},
|
||||
)
|
||||
migrated += 1
|
||||
|
||||
logger.info("Migrated user_memories for %d user(s).", migrated)
|
||||
|
||||
|
||||
def _migrate_shared_memories(conn: sa.engine.Connection) -> None:
|
||||
"""Convert shared_memories rows → searchspaces.shared_memory_md."""
|
||||
rows = conn.execute(
|
||||
sa.text(
|
||||
"SELECT search_space_id, memory_text, category::text, created_at "
|
||||
"FROM shared_memories ORDER BY created_at"
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
logger.info("shared_memories is empty, skipping data migration.")
|
||||
return
|
||||
|
||||
by_space: dict[int, list[tuple]] = defaultdict(list)
|
||||
for search_space_id, memory_text, category, created_at in rows:
|
||||
by_space[search_space_id].append((memory_text, category, created_at))
|
||||
|
||||
migrated = 0
|
||||
for space_id, space_rows in by_space.items():
|
||||
existing = conn.execute(
|
||||
sa.text("SELECT shared_memory_md FROM searchspaces WHERE id = :sid"),
|
||||
{"sid": space_id},
|
||||
).scalar()
|
||||
|
||||
if existing and existing.strip():
|
||||
logger.info(
|
||||
"Search space %s already has shared_memory_md, skipping.", space_id
|
||||
)
|
||||
continue
|
||||
|
||||
markdown = _build_markdown(space_rows)
|
||||
conn.execute(
|
||||
sa.text("UPDATE searchspaces SET shared_memory_md = :md WHERE id = :sid"),
|
||||
{"md": markdown, "sid": space_id},
|
||||
)
|
||||
migrated += 1
|
||||
|
||||
logger.info("Migrated shared_memories for %d search space(s).", migrated)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa_inspect(conn)
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
if "user_memories" in tables:
|
||||
_migrate_user_memories(conn)
|
||||
|
||||
if "shared_memories" in tables:
|
||||
_migrate_shared_memories(conn)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS shared_memories CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS user_memories CASCADE;")
|
||||
op.execute("DROP TYPE IF EXISTS memorycategory;")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'memorycategory') THEN
|
||||
CREATE TYPE memorycategory AS ENUM (
|
||||
'preference',
|
||||
'fact',
|
||||
'instruction',
|
||||
'context'
|
||||
);
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS user_memories (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
search_space_id INTEGER REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||
memory_text TEXT NOT NULL,
|
||||
category memorycategory NOT NULL DEFAULT 'fact',
|
||||
embedding vector({EMBEDDING_DIM}),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_user_memories_user_id ON user_memories(user_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_user_memories_search_space_id ON user_memories(search_space_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_user_memories_updated_at ON user_memories(updated_at);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_user_memories_category ON user_memories(category);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_user_memories_user_search_space ON user_memories(user_id, search_space_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS user_memories_vector_index ON user_memories USING hnsw (embedding public.vector_cosine_ops);"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS shared_memories (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||
created_by_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
memory_text TEXT NOT NULL,
|
||||
category memorycategory NOT NULL DEFAULT 'fact',
|
||||
embedding vector({EMBEDDING_DIM})
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_shared_memories_search_space_id ON shared_memories(search_space_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_shared_memories_updated_at ON shared_memories(updated_at);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_shared_memories_created_by_id ON shared_memories(created_by_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS shared_memories_vector_index ON shared_memories USING hnsw (embedding public.vector_cosine_ops);"
|
||||
)
|
||||
|
|
@ -38,6 +38,7 @@ from app.agents.new_chat.llm_config import AgentConfig
|
|||
from app.agents.new_chat.middleware import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
MemoryInjectionMiddleware,
|
||||
SurfSenseFilesystemMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.system_prompt import (
|
||||
|
|
@ -168,8 +169,7 @@ async def create_surfsense_deep_agent(
|
|||
- generate_podcast: Generate audio podcasts from content
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
- scrape_webpage: Extract content from webpages
|
||||
- save_memory: Store facts/preferences about the user
|
||||
- recall_memory: Retrieve relevant user memories
|
||||
- update_memory: Update the user's personal or team memory document
|
||||
|
||||
The agent also includes TodoListMiddleware by default (via create_deep_agent) which provides:
|
||||
- write_todos: Create and update planning/todo lists for complex tasks
|
||||
|
|
@ -281,6 +281,7 @@ async def create_surfsense_deep_agent(
|
|||
"available_connectors": available_connectors,
|
||||
"available_document_types": available_document_types,
|
||||
"max_input_tokens": _max_input_tokens,
|
||||
"llm": llm,
|
||||
}
|
||||
|
||||
# Disable Notion action tools if no Notion connector is configured
|
||||
|
|
@ -425,9 +426,16 @@ async def create_surfsense_deep_agent(
|
|||
)
|
||||
|
||||
# -- Build the middleware stack (mirrors create_deep_agent internals) ------
|
||||
_memory_middleware = MemoryInjectionMiddleware(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
thread_visibility=visibility,
|
||||
)
|
||||
|
||||
# General-purpose subagent middleware
|
||||
gp_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
SurfSenseFilesystemMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
|
|
@ -447,6 +455,7 @@ async def create_surfsense_deep_agent(
|
|||
# Main agent middleware
|
||||
deepagent_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
KnowledgeBaseSearchMiddleware(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
|
|
|
|||
239
surfsense_backend/app/agents/new_chat/memory_extraction.py
Normal file
239
surfsense_backend/app/agents/new_chat/memory_extraction.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
"""Background memory extraction for the SurfSense agent.
|
||||
|
||||
After each agent response, if the agent did not call ``update_memory`` during
|
||||
the turn, this module can run a lightweight LLM call to decide whether the
|
||||
latest message contains long-term information worth persisting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.new_chat.tools.update_memory import _save_memory
|
||||
from app.db import SearchSpace, User, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MEMORY_EXTRACT_PROMPT = """\
|
||||
You are a memory extraction assistant. Analyze the user's message and decide \
|
||||
if it contains any long-term information worth persisting to memory.
|
||||
|
||||
Worth remembering: preferences, background/identity, goals, projects, \
|
||||
instructions, tools/languages they use, decisions, expertise, workplace — \
|
||||
durable facts that will matter in future conversations.
|
||||
|
||||
NOT worth remembering: greetings, one-off factual questions, session \
|
||||
logistics, ephemeral requests, follow-up clarifications with no new personal \
|
||||
info, things that only matter for the current task.
|
||||
|
||||
If the message contains memorizable information, output the FULL updated \
|
||||
memory document with the new facts merged into the existing content. Follow \
|
||||
these rules:
|
||||
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
|
||||
freely. Keep heading names short (2-3 words) and natural. Do NOT include the user's
|
||||
name in headings.
|
||||
- Keep entries as single bullet points. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- Every bullet MUST use format: - (YYYY-MM-DD) [fact|pref|instr] text
|
||||
[fact] = durable facts, [pref] = preferences, [instr] = standing instructions.
|
||||
- Use the user's first name (from <user_name>) in entry text, not "the user".
|
||||
- If a new fact contradicts an existing entry, update the existing entry.
|
||||
- Do not duplicate information that is already present.
|
||||
|
||||
If nothing is worth remembering, output exactly: NO_UPDATE
|
||||
|
||||
<user_name>{user_name}</user_name>
|
||||
|
||||
<current_memory>
|
||||
{current_memory}
|
||||
</current_memory>
|
||||
|
||||
<user_message>
|
||||
{user_message}
|
||||
</user_message>"""
|
||||
|
||||
_TEAM_MEMORY_EXTRACT_PROMPT = """\
|
||||
You are a team-memory extraction assistant. Analyze the latest message and \
|
||||
decide if it contains durable TEAM-level information worth persisting.
|
||||
|
||||
Decision policy:
|
||||
- Prioritize recall for durable team context, while avoiding personal-only facts.
|
||||
- Do NOT require explicit consensus language. A direct team-level statement can
|
||||
be stored if it is stable and broadly useful for future team chats.
|
||||
- If evidence is weak or clearly tentative, output NO_UPDATE.
|
||||
|
||||
Worth remembering (team-level only):
|
||||
- Decisions and defaults that guide future team work
|
||||
- Team conventions/standards (naming, review policy, coding norms)
|
||||
- Stable org/project facts (locations, ownership, constraints)
|
||||
- Long-lived architecture/process facts
|
||||
- Ongoing priorities that are likely relevant beyond this turn
|
||||
|
||||
NOT worth remembering:
|
||||
- Personal preferences or biography of one person
|
||||
- Questions, brainstorming, tentative ideas, or speculation
|
||||
- One-off requests, status updates, TODOs, logistics for this session
|
||||
- Information scoped only to a single ephemeral task
|
||||
|
||||
If the message contains memorizable team information, output the FULL updated \
|
||||
team memory document with new facts merged into existing content. Follow rules:
|
||||
- Every entry MUST be under a ## heading. Preserve existing headings; create new ones
|
||||
freely. Keep heading names short (2-3 words) and natural.
|
||||
- Keep entries as single bullet points. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- Every bullet MUST use format: - (YYYY-MM-DD) [fact] text
|
||||
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr].
|
||||
- If a new fact contradicts an existing entry, update the existing entry.
|
||||
- Do not duplicate existing information.
|
||||
- Preserve neutral team phrasing; avoid person-specific memory unless role-anchored.
|
||||
|
||||
If nothing is worth remembering, output exactly: NO_UPDATE
|
||||
|
||||
<current_team_memory>
|
||||
{current_memory}
|
||||
</current_team_memory>
|
||||
|
||||
<latest_message_author>
|
||||
{author}
|
||||
</latest_message_author>
|
||||
|
||||
<latest_message>
|
||||
{user_message}
|
||||
</latest_message>"""
|
||||
|
||||
|
||||
async def extract_and_save_memory(
|
||||
*,
|
||||
user_message: str,
|
||||
user_id: str | None,
|
||||
llm: Any,
|
||||
) -> None:
|
||||
"""Background task: extract memorizable info and persist it.
|
||||
|
||||
Designed to be fire-and-forget — catches all exceptions internally.
|
||||
"""
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
try:
|
||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
result = await session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return
|
||||
|
||||
old_memory = user.memory_md
|
||||
first_name = (
|
||||
user.display_name.strip().split()[0]
|
||||
if user.display_name and user.display_name.strip()
|
||||
else "The user"
|
||||
)
|
||||
prompt = _MEMORY_EXTRACT_PROMPT.format(
|
||||
current_memory=old_memory or "(empty)",
|
||||
user_message=user_message,
|
||||
user_name=first_name,
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "memory-extraction"]},
|
||||
)
|
||||
text = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
).strip()
|
||||
|
||||
if text == "NO_UPDATE" or not text:
|
||||
logger.debug("Memory extraction: no update needed (user %s)", uid)
|
||||
return
|
||||
|
||||
save_result = await _save_memory(
|
||||
updated_memory=text,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
logger.info(
|
||||
"Background memory extraction for user %s: %s",
|
||||
uid,
|
||||
save_result.get("status"),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Background user memory extraction failed")
|
||||
|
||||
|
||||
async def extract_and_save_team_memory(
|
||||
*,
|
||||
user_message: str,
|
||||
search_space_id: int | None,
|
||||
llm: Any,
|
||||
author_display_name: str | None = None,
|
||||
) -> None:
|
||||
"""Background task: extract team-level memory and persist it.
|
||||
|
||||
Runs only for shared threads. Designed to be fire-and-forget and catches
|
||||
exceptions internally.
|
||||
"""
|
||||
if not search_space_id:
|
||||
return
|
||||
|
||||
try:
|
||||
async with shielded_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return
|
||||
|
||||
old_memory = space.shared_memory_md
|
||||
prompt = _TEAM_MEMORY_EXTRACT_PROMPT.format(
|
||||
current_memory=old_memory or "(empty)",
|
||||
author=author_display_name or "Unknown team member",
|
||||
user_message=user_message,
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "team-memory-extraction"]},
|
||||
)
|
||||
text = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
).strip()
|
||||
|
||||
if text == "NO_UPDATE" or not text:
|
||||
logger.debug(
|
||||
"Team memory extraction: no update needed (space %s)",
|
||||
search_space_id,
|
||||
)
|
||||
return
|
||||
|
||||
save_result = await _save_memory(
|
||||
updated_memory=text,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
logger.info(
|
||||
"Background team memory extraction for space %s: %s",
|
||||
search_space_id,
|
||||
save_result.get("status"),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Background team memory extraction failed")
|
||||
|
|
@ -9,9 +9,13 @@ from app.agents.new_chat.middleware.filesystem import (
|
|||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.memory_injection import (
|
||||
MemoryInjectionMiddleware,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DedupHITLToolCallsMiddleware",
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"MemoryInjectionMiddleware",
|
||||
"SurfSenseFilesystemMiddleware",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,138 @@
|
|||
"""Memory injection middleware for the SurfSense agent.
|
||||
|
||||
Injects memory markdown into the system prompt on every turn:
|
||||
- Private threads: only personal memory (<user_memory>)
|
||||
- Shared threads: only team memory (<team_memory>)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
|
||||
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Injects memory markdown into the conversation on every turn."""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
user_id: str | UUID | None,
|
||||
search_space_id: int,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> None:
|
||||
self.user_id = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
self.search_space_id = search_space_id
|
||||
self.visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
if not isinstance(last_message, HumanMessage):
|
||||
return None
|
||||
|
||||
memory_blocks: list[str] = []
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
||||
team_memory = await self._load_team_memory(session)
|
||||
if team_memory:
|
||||
chars = len(team_memory)
|
||||
memory_blocks.append(
|
||||
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||
f"{team_memory}\n"
|
||||
f"</team_memory>"
|
||||
)
|
||||
if chars > MEMORY_SOFT_LIMIT:
|
||||
memory_blocks.append(
|
||||
f"<memory_warning>Team memory is at "
|
||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||
f"the hard limit. On your next update_memory call, consolidate "
|
||||
f"by merging duplicates, removing outdated entries, and "
|
||||
f"shortening descriptions before adding anything new."
|
||||
f"</memory_warning>"
|
||||
)
|
||||
elif self.user_id is not None:
|
||||
user_memory, display_name = await self._load_user_memory(session)
|
||||
if display_name and display_name.strip():
|
||||
first_name = display_name.strip().split()[0]
|
||||
memory_blocks.append(f"<user_name>{first_name}</user_name>")
|
||||
if user_memory:
|
||||
chars = len(user_memory)
|
||||
memory_blocks.append(
|
||||
f'<user_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||
f"{user_memory}\n"
|
||||
f"</user_memory>"
|
||||
)
|
||||
if chars > MEMORY_SOFT_LIMIT:
|
||||
memory_blocks.append(
|
||||
f"<memory_warning>Your personal memory is at "
|
||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||
f"the hard limit. On your next update_memory call, consolidate "
|
||||
f"by merging duplicates, removing outdated entries, and "
|
||||
f"shortening descriptions before adding anything new."
|
||||
f"</memory_warning>"
|
||||
)
|
||||
|
||||
if not memory_blocks:
|
||||
return None
|
||||
|
||||
memory_text = "\n\n".join(memory_blocks)
|
||||
memory_msg = SystemMessage(content=memory_text)
|
||||
|
||||
new_messages = list(messages)
|
||||
insert_idx = 1 if len(new_messages) > 1 else 0
|
||||
new_messages.insert(insert_idx, memory_msg)
|
||||
|
||||
return {"messages": new_messages}
|
||||
|
||||
async def _load_user_memory(
|
||||
self, session: AsyncSession
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Return (memory_content, display_name)."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(User.memory_md, User.display_name).where(User.id == self.user_id)
|
||||
)
|
||||
row = result.one_or_none()
|
||||
if row is None:
|
||||
return None, None
|
||||
return row.memory_md or None, row.display_name
|
||||
except Exception:
|
||||
logger.exception("Failed to load user memory")
|
||||
return None, None
|
||||
|
||||
async def _load_team_memory(self, session: AsyncSession) -> str | None:
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace.shared_memory_md).where(
|
||||
SearchSpace.id == self.search_space_id
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
return row if row else None
|
||||
except Exception:
|
||||
logger.exception("Failed to load team memory")
|
||||
return None
|
||||
|
|
@ -40,6 +40,13 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
|||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||
</knowledge_base_only_policy>
|
||||
|
||||
<memory_protocol>
|
||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||
reveal durable facts about the user (role, interests, preferences, projects,
|
||||
background, or standing instructions)? If yes, you MUST call update_memory
|
||||
alongside your normal response — do not defer this to a later turn.
|
||||
</memory_protocol>
|
||||
|
||||
</system_instruction>
|
||||
"""
|
||||
|
||||
|
|
@ -71,6 +78,13 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
|||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||
</knowledge_base_only_policy>
|
||||
|
||||
<memory_protocol>
|
||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
||||
or key facts)? If yes, you MUST call update_memory alongside your normal response —
|
||||
do not defer this to a later turn.
|
||||
</memory_protocol>
|
||||
|
||||
</system_instruction>
|
||||
"""
|
||||
|
||||
|
|
@ -248,115 +262,97 @@ _TOOL_INSTRUCTIONS["web_search"] = """
|
|||
"""
|
||||
|
||||
# Memory tool instructions have private and shared variants.
|
||||
# We store them keyed as "save_memory" / "recall_memory" with sub-keys.
|
||||
# We store them keyed as "update_memory" with sub-keys.
|
||||
_MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
||||
"save_memory": {
|
||||
"update_memory": {
|
||||
"private": """
|
||||
- save_memory: Save facts, preferences, or context for personalized responses.
|
||||
- Use this when the user explicitly or implicitly shares information worth remembering.
|
||||
- Trigger scenarios:
|
||||
* User says "remember this", "keep this in mind", "note that", or similar
|
||||
* User shares personal preferences (e.g., "I prefer Python over JavaScript")
|
||||
* User shares facts about themselves (e.g., "I'm a senior developer at Company X")
|
||||
* User gives standing instructions (e.g., "always respond in bullet points")
|
||||
* User shares project context (e.g., "I'm working on migrating our codebase to TypeScript")
|
||||
- update_memory: Update your personal memory document about the user.
|
||||
- Your current memory is already in <user_memory> in your context. The `chars` and
|
||||
`limit` attributes show your current usage and the maximum allowed size.
|
||||
- This is your curated long-term memory — the distilled essence of what you know about
|
||||
the user, not raw conversation logs.
|
||||
- Call update_memory when:
|
||||
* The user explicitly asks to remember or forget something
|
||||
* The user shares durable facts or preferences that will matter in future conversations
|
||||
- The user's first name is provided in <user_name>. Use it in memory entries
|
||||
instead of "the user" (e.g. "{name} works at..." not "The user works at...").
|
||||
Do not store the name itself as a separate memory entry.
|
||||
- Do not store short-lived or ephemeral info: one-off questions, greetings,
|
||||
session logistics, or things that only matter for the current task.
|
||||
- Args:
|
||||
- content: The fact/preference to remember. Phrase it clearly:
|
||||
* "User prefers dark mode for all interfaces"
|
||||
* "User is a senior Python developer"
|
||||
* "User wants responses in bullet point format"
|
||||
* "User is working on project called ProjectX"
|
||||
- category: Type of memory:
|
||||
* "preference": User preferences (coding style, tools, formats)
|
||||
* "fact": Facts about the user (role, expertise, background)
|
||||
* "instruction": Standing instructions (response format, communication style)
|
||||
* "context": Current context (ongoing projects, goals, challenges)
|
||||
- Returns: Confirmation of saved memory
|
||||
- IMPORTANT: Only save information that would be genuinely useful for future conversations.
|
||||
Don't save trivial or temporary information.
|
||||
- updated_memory: The FULL updated markdown document (not a diff).
|
||||
Merge new facts with existing ones, update contradictions, remove outdated entries.
|
||||
Treat every update as a curation pass — consolidate, don't just append.
|
||||
- Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text
|
||||
Markers:
|
||||
[fact] — durable facts (role, background, projects, tools, expertise)
|
||||
[pref] — preferences (response style, languages, formats, tools)
|
||||
[instr] — standing instructions (always/never do, response rules)
|
||||
- Keep it concise and well under the character limit shown in <user_memory>.
|
||||
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
|
||||
natural. Do NOT include the user's name in headings. Organize by context — e.g.
|
||||
who they are, what they're focused on, how they prefer things. Create, split, or
|
||||
merge headings freely as the memory grows.
|
||||
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- During consolidation, prioritize keeping: [instr] > [pref] > [fact].
|
||||
""",
|
||||
"shared": """
|
||||
- save_memory: Save a fact, preference, or context to the team's shared memory for future reference.
|
||||
- Use this when the user or a team member says "remember this", "keep this in mind", or similar in this shared chat.
|
||||
- Use when the team agrees on something to remember (e.g., decisions, conventions).
|
||||
- Someone shares a preference or fact that should be visible to the whole team.
|
||||
- The saved information will be available in future shared conversations in this space.
|
||||
- update_memory: Update the team's shared memory document for this search space.
|
||||
- Your current team memory is already in <team_memory> in your context. The `chars`
|
||||
and `limit` attributes show current usage and the maximum allowed size.
|
||||
- This is the team's curated long-term memory — decisions, conventions, key facts.
|
||||
- NEVER store personal memory in team memory (e.g. personal bio, individual
|
||||
preferences, or user-only standing instructions).
|
||||
- Call update_memory when:
|
||||
* A team member explicitly asks to remember or forget something
|
||||
* The conversation surfaces durable team decisions, conventions, or facts
|
||||
that will matter in future conversations
|
||||
- Do not store short-lived or ephemeral info: one-off questions, greetings,
|
||||
session logistics, or things that only matter for the current task.
|
||||
- Args:
|
||||
- content: The fact/preference/context to remember. Phrase it clearly, e.g. "API keys are stored in Vault", "The team prefers weekly demos on Fridays"
|
||||
- category: Type of memory. One of:
|
||||
* "preference": Team or workspace preferences
|
||||
* "fact": Facts the team agreed on (e.g., processes, locations)
|
||||
* "instruction": Standing instructions for the team
|
||||
* "context": Current context (e.g., ongoing projects, goals)
|
||||
- Returns: Confirmation of saved memory; returned context may include who added it (added_by).
|
||||
- IMPORTANT: Only save information that would be genuinely useful for future team conversations in this space.
|
||||
""",
|
||||
},
|
||||
"recall_memory": {
|
||||
"private": """
|
||||
- recall_memory: Retrieve relevant memories about the user for personalized responses.
|
||||
- Use this to access stored information about the user.
|
||||
- Trigger scenarios:
|
||||
* You need user context to give a better, more personalized answer
|
||||
* User references something they mentioned before
|
||||
* User asks "what do you know about me?" or similar
|
||||
* Personalization would significantly improve response quality
|
||||
* Before making recommendations that should consider user preferences
|
||||
- Args:
|
||||
- query: Optional search query to find specific memories (e.g., "programming preferences")
|
||||
- category: Optional filter by category ("preference", "fact", "instruction", "context")
|
||||
- top_k: Number of memories to retrieve (default: 5)
|
||||
- Returns: Relevant memories formatted as context
|
||||
- IMPORTANT: Use the recalled memories naturally in your response without explicitly
|
||||
stating "Based on your memory..." - integrate the context seamlessly.
|
||||
""",
|
||||
"shared": """
|
||||
- recall_memory: Recall relevant team memories for this space to provide contextual responses.
|
||||
- Use when you need team context to answer (e.g., "where do we store X?", "what did we decide about Y?").
|
||||
- Use when someone asks about something the team agreed to remember.
|
||||
- Use when team preferences or conventions would improve the response.
|
||||
- Args:
|
||||
- query: Optional search query to find specific memories. If not provided, returns the most recent memories.
|
||||
- category: Optional filter by category ("preference", "fact", "instruction", "context")
|
||||
- top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||
- Returns: Relevant team memories and formatted context (may include added_by). Integrate naturally without saying "Based on team memory...".
|
||||
- updated_memory: The FULL updated markdown document (not a diff).
|
||||
Merge new facts with existing ones, update contradictions, remove outdated entries.
|
||||
Treat every update as a curation pass — consolidate, don't just append.
|
||||
- Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text
|
||||
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory.
|
||||
- Keep it concise and well under the character limit shown in <team_memory>.
|
||||
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
|
||||
natural. Organize by context — e.g. what the team decided, current architecture,
|
||||
active processes. Create, split, or merge headings freely as the memory grows.
|
||||
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
|
||||
""",
|
||||
},
|
||||
}
|
||||
|
||||
_MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = {
|
||||
"save_memory": {
|
||||
"update_memory": {
|
||||
"private": """
|
||||
- User: "Remember that I prefer TypeScript over JavaScript"
|
||||
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
|
||||
- User: "I'm a data scientist working on ML pipelines"
|
||||
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
|
||||
- User: "Always give me code examples in Python"
|
||||
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
|
||||
- <user_name>Alex</user_name>, <user_memory> is empty. User: "I'm a space enthusiast, explain astrophage to me"
|
||||
- The user casually shared a durable fact. Use their first name in the entry, short neutral heading:
|
||||
update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n")
|
||||
- User: "Remember that I prefer concise answers over detailed explanations"
|
||||
- Durable preference. Merge with existing memory, add a new heading:
|
||||
update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n\\n## Response style\\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\\n")
|
||||
- User: "I actually moved to Tokyo last month"
|
||||
- Updated fact, date prefix reflects when recorded:
|
||||
update_memory(updated_memory="## Interests & background\\n...\\n\\n## Personal context\\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\\n...")
|
||||
- User: "I'm a freelance photographer working on a nature documentary"
|
||||
- Durable background info under a fitting heading:
|
||||
update_memory(updated_memory="...\\n\\n## Current focus\\n- (2025-03-15) [fact] Alex is a freelance photographer\\n- (2025-03-15) [fact] Alex is working on a nature documentary\\n")
|
||||
- User: "Always respond in bullet points"
|
||||
- Standing instruction:
|
||||
update_memory(updated_memory="...\\n\\n## Response style\\n- (2025-03-15) [instr] Always respond to Alex in bullet points\\n")
|
||||
""",
|
||||
"shared": """
|
||||
- User: "Remember that API keys are stored in Vault"
|
||||
- Call: `save_memory(content="API keys are stored in Vault", category="fact")`
|
||||
- User: "Let's remember that the team prefers weekly demos on Fridays"
|
||||
- Call: `save_memory(content="The team prefers weekly demos on Fridays", category="preference")`
|
||||
""",
|
||||
},
|
||||
"recall_memory": {
|
||||
"private": """
|
||||
- User: "What programming language should I use for this project?"
|
||||
- First recall: `recall_memory(query="programming language preferences")`
|
||||
- Then provide a personalized recommendation based on their preferences
|
||||
- User: "What do you know about me?"
|
||||
- Call: `recall_memory(top_k=10)`
|
||||
- Then summarize the stored memories
|
||||
""",
|
||||
"shared": """
|
||||
- User: "What did we decide about the release date?"
|
||||
- First recall: `recall_memory(query="release date decision")`
|
||||
- Then answer based on the team memories
|
||||
- User: "Where do we document onboarding?"
|
||||
- Call: `recall_memory(query="onboarding documentation")`
|
||||
- Then answer using the recalled team context
|
||||
- User: "Let's remember that we decided to do weekly standup meetings on Mondays"
|
||||
- Durable team decision:
|
||||
update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\\n...")
|
||||
- User: "Our office is in downtown Seattle, 5th floor"
|
||||
- Durable team fact:
|
||||
update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\\n...")
|
||||
""",
|
||||
},
|
||||
}
|
||||
|
|
@ -456,8 +452,7 @@ _ALL_TOOL_NAMES_ORDERED = [
|
|||
"generate_report",
|
||||
"generate_image",
|
||||
"scrape_webpage",
|
||||
"save_memory",
|
||||
"recall_memory",
|
||||
"update_memory",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ Available tools:
|
|||
- generate_video_presentation: Generate video presentations with slides and narration
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
- scrape_webpage: Extract content from webpages
|
||||
- save_memory: Store facts/preferences about the user
|
||||
- recall_memory: Retrieve relevant user memories
|
||||
- update_memory: Update the user's / team's memory document
|
||||
"""
|
||||
|
||||
# Registry exports
|
||||
|
|
@ -33,7 +32,7 @@ from .registry import (
|
|||
)
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
||||
from .video_presentation import create_generate_video_presentation_tool
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -47,10 +46,10 @@ __all__ = [
|
|||
"create_generate_image_tool",
|
||||
"create_generate_podcast_tool",
|
||||
"create_generate_video_presentation_tool",
|
||||
"create_recall_memory_tool",
|
||||
"create_save_memory_tool",
|
||||
"create_scrape_webpage_tool",
|
||||
"create_search_surfsense_docs_tool",
|
||||
"create_update_memory_tool",
|
||||
"create_update_team_memory_tool",
|
||||
"format_documents_for_context",
|
||||
"get_all_tool_names",
|
||||
"get_default_enabled_tools",
|
||||
|
|
|
|||
|
|
@ -94,11 +94,7 @@ from .podcast import create_generate_podcast_tool
|
|||
from .report import create_generate_report_tool
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
from .shared_memory import (
|
||||
create_recall_shared_memory_tool,
|
||||
create_save_shared_memory_tool,
|
||||
)
|
||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
||||
from .video_presentation import create_generate_video_presentation_tool
|
||||
from .web_search import create_web_search_tool
|
||||
|
||||
|
|
@ -214,42 +210,31 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
requires=["db_session"],
|
||||
),
|
||||
# =========================================================================
|
||||
# USER MEMORY TOOLS - private or team store by thread_visibility
|
||||
# MEMORY TOOL - single update_memory, private or team by thread_visibility
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="save_memory",
|
||||
description="Save facts, preferences, or context for personalized or team responses",
|
||||
name="update_memory",
|
||||
description="Save important long-term facts, preferences, and instructions to the (personal or team) memory",
|
||||
factory=lambda deps: (
|
||||
create_save_shared_memory_tool(
|
||||
create_update_team_memory_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
created_by_id=deps["user_id"],
|
||||
db_session=deps["db_session"],
|
||||
llm=deps.get("llm"),
|
||||
)
|
||||
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||
else create_save_memory_tool(
|
||||
else create_update_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
llm=deps.get("llm"),
|
||||
)
|
||||
),
|
||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="recall_memory",
|
||||
description="Recall relevant memories (personal or team) for context",
|
||||
factory=lambda deps: (
|
||||
create_recall_shared_memory_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||
else create_recall_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
),
|
||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||
requires=[
|
||||
"user_id",
|
||||
"search_space_id",
|
||||
"db_session",
|
||||
"thread_visibility",
|
||||
"llm",
|
||||
],
|
||||
),
|
||||
# =========================================================================
|
||||
# LINEAR TOOLS - create, update, delete issues
|
||||
|
|
|
|||
|
|
@ -1,281 +0,0 @@
|
|||
"""Shared (team) memory backend for search-space-scoped AI context."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import MemoryCategory, SharedMemory, User
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RECALL_TOP_K = 5
|
||||
MAX_MEMORIES_PER_SEARCH_SPACE = 250
|
||||
|
||||
|
||||
async def get_shared_memory_count(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> int:
|
||||
result = await db_session.execute(
|
||||
select(SharedMemory).where(SharedMemory.search_space_id == search_space_id)
|
||||
)
|
||||
return len(result.scalars().all())
|
||||
|
||||
|
||||
async def delete_oldest_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> None:
|
||||
result = await db_session.execute(
|
||||
select(SharedMemory)
|
||||
.where(SharedMemory.search_space_id == search_space_id)
|
||||
.order_by(SharedMemory.updated_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
oldest = result.scalars().first()
|
||||
if oldest:
|
||||
await db_session.delete(oldest)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
def _to_uuid(value: str | UUID) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
return UUID(value)
|
||||
|
||||
|
||||
async def save_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
created_by_id: str | UUID,
|
||||
content: str,
|
||||
category: str = "fact",
|
||||
) -> dict[str, Any]:
|
||||
category = category.lower() if category else "fact"
|
||||
valid = ["preference", "fact", "instruction", "context"]
|
||||
if category not in valid:
|
||||
category = "fact"
|
||||
try:
|
||||
count = await get_shared_memory_count(db_session, search_space_id)
|
||||
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
|
||||
await delete_oldest_shared_memory(db_session, search_space_id)
|
||||
embedding = await asyncio.to_thread(embed_text, content)
|
||||
row = SharedMemory(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=_to_uuid(created_by_id),
|
||||
memory_text=content,
|
||||
category=MemoryCategory(category),
|
||||
embedding=embedding,
|
||||
)
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(row)
|
||||
return {
|
||||
"status": "saved",
|
||||
"memory_id": row.id,
|
||||
"memory_text": content,
|
||||
"category": category,
|
||||
"message": f"I'll remember: {content}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save shared memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": "Failed to save memory. Please try again.",
|
||||
}
|
||||
|
||||
|
||||
async def recall_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||
) -> dict[str, Any]:
|
||||
top_k = min(max(top_k, 1), 20)
|
||||
try:
|
||||
valid_categories = ["preference", "fact", "instruction", "context"]
|
||||
stmt = select(SharedMemory).where(
|
||||
SharedMemory.search_space_id == search_space_id
|
||||
)
|
||||
if category and category in valid_categories:
|
||||
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
|
||||
if query:
|
||||
query_embedding = await asyncio.to_thread(embed_text, query)
|
||||
stmt = stmt.order_by(
|
||||
SharedMemory.embedding.op("<=>")(query_embedding)
|
||||
).limit(top_k)
|
||||
else:
|
||||
stmt = stmt.order_by(SharedMemory.updated_at.desc()).limit(top_k)
|
||||
result = await db_session.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
memory_list = [
|
||||
{
|
||||
"id": m.id,
|
||||
"memory_text": m.memory_text,
|
||||
"category": m.category.value if m.category else "unknown",
|
||||
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
|
||||
"created_by_id": str(m.created_by_id) if m.created_by_id else None,
|
||||
}
|
||||
for m in rows
|
||||
]
|
||||
created_by_ids = list(
|
||||
{m["created_by_id"] for m in memory_list if m["created_by_id"]}
|
||||
)
|
||||
created_by_map: dict[str, str] = {}
|
||||
if created_by_ids:
|
||||
uuids = [UUID(uid) for uid in created_by_ids]
|
||||
users_result = await db_session.execute(
|
||||
select(User).where(User.id.in_(uuids))
|
||||
)
|
||||
for u in users_result.scalars().all():
|
||||
created_by_map[str(u.id)] = u.display_name or "A team member"
|
||||
formatted_context = format_shared_memories_for_context(
|
||||
memory_list, created_by_map
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(memory_list),
|
||||
"memories": memory_list,
|
||||
"formatted_context": formatted_context,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to recall shared memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"memories": [],
|
||||
"formatted_context": "Failed to recall memories.",
|
||||
}
|
||||
|
||||
|
||||
def format_shared_memories_for_context(
|
||||
memories: list[dict[str, Any]],
|
||||
created_by_map: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
if not memories:
|
||||
return "No relevant team memories found."
|
||||
created_by_map = created_by_map or {}
|
||||
parts = ["<team_memories>"]
|
||||
for memory in memories:
|
||||
category = memory.get("category", "unknown")
|
||||
text = memory.get("memory_text", "")
|
||||
updated = memory.get("updated_at", "")
|
||||
created_by_id = memory.get("created_by_id")
|
||||
added_by = (
|
||||
created_by_map.get(str(created_by_id), "A team member")
|
||||
if created_by_id is not None
|
||||
else "A team member"
|
||||
)
|
||||
parts.append(
|
||||
f" <memory category='{category}' updated='{updated}' added_by='{added_by}'>{text}</memory>"
|
||||
)
|
||||
parts.append("</team_memories>")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def create_save_shared_memory_tool(
|
||||
search_space_id: int,
|
||||
created_by_id: str | UUID,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the save_memory tool for shared (team) chats.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID
|
||||
created_by_id: The user ID of the person adding the memory
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for saving team memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def save_memory(
|
||||
content: str,
|
||||
category: str = "fact",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Save a fact, preference, or context to the team's shared memory for future reference.
|
||||
|
||||
Use this tool when:
|
||||
- User or a team member says "remember this", "keep this in mind", or similar in this shared chat
|
||||
- The team agrees on something to remember (e.g., decisions, conventions, where things live)
|
||||
- Someone shares a preference or fact that should be visible to the whole team
|
||||
|
||||
The saved information will be available in future shared conversations in this space.
|
||||
|
||||
Args:
|
||||
content: The fact/preference/context to remember.
|
||||
Phrase it clearly, e.g., "API keys are stored in Vault",
|
||||
"The team prefers weekly demos on Fridays"
|
||||
category: Type of memory. One of:
|
||||
- "preference": Team or workspace preferences
|
||||
- "fact": Facts the team agreed on (e.g., processes, locations)
|
||||
- "instruction": Standing instructions for the team
|
||||
- "context": Current context (e.g., ongoing projects, goals)
|
||||
|
||||
Returns:
|
||||
A dictionary with the save status and memory details
|
||||
"""
|
||||
return await save_shared_memory(
|
||||
db_session, search_space_id, created_by_id, content, category
|
||||
)
|
||||
|
||||
return save_memory
|
||||
|
||||
|
||||
def create_recall_shared_memory_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the recall_memory tool for shared (team) chats.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for recalling team memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def recall_memory(
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Recall relevant team memories for this space to provide contextual responses.
|
||||
|
||||
Use this tool when:
|
||||
- You need team context to answer (e.g., "where do we store X?", "what did we decide about Y?")
|
||||
- Someone asks about something the team agreed to remember
|
||||
- Team preferences or conventions would improve the response
|
||||
|
||||
Args:
|
||||
query: Optional search query to find specific memories.
|
||||
If not provided, returns the most recent memories.
|
||||
category: Optional category filter. One of:
|
||||
"preference", "fact", "instruction", "context"
|
||||
top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||
|
||||
Returns:
|
||||
A dictionary containing relevant memories and formatted context
|
||||
"""
|
||||
return await recall_shared_memory(
|
||||
db_session, search_space_id, query, category, top_k
|
||||
)
|
||||
|
||||
return recall_memory
|
||||
389
surfsense_backend/app/agents/new_chat/tools/update_memory.py
Normal file
389
surfsense_backend/app/agents/new_chat/tools/update_memory.py
Normal file
|
|
@ -0,0 +1,389 @@
|
|||
"""Markdown-document memory tool for the SurfSense agent.
|
||||
|
||||
Replaces the old row-per-fact save_memory / recall_memory tools with a single
|
||||
update_memory tool that overwrites a freeform markdown TEXT column. The LLM
|
||||
always sees the current memory in <user_memory> / <team_memory> tags injected
|
||||
by MemoryInjectionMiddleware, so it passes the FULL updated document each time.
|
||||
|
||||
Overflow handling:
|
||||
- Soft limit (18K chars): a warning is returned telling the agent to
|
||||
consolidate on the next update.
|
||||
- Hard limit (25K chars): a forced LLM-driven rewrite compresses the document.
|
||||
If it still exceeds the limit after rewriting, the save is rejected.
|
||||
- Diff validation: warns when entire ``##`` sections are dropped or when the
|
||||
document shrinks by more than 60%.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import SearchSpace, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_SOFT_LIMIT = 18_000
|
||||
MEMORY_HARD_LIMIT = 25_000
|
||||
|
||||
_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
|
||||
_HEADING_NORMALIZE_RE = re.compile(r"\s+")
|
||||
|
||||
_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
|
||||
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
|
||||
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Diff validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_headings(memory: str) -> set[str]:
|
||||
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
|
||||
return set(_SECTION_HEADING_RE.findall(memory))
|
||||
|
||||
|
||||
def _normalize_heading(heading: str) -> str:
|
||||
"""Normalize heading text for robust scope checks."""
|
||||
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
|
||||
|
||||
|
||||
def _validate_memory_scope(
|
||||
content: str, scope: Literal["user", "team"]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Reject personal-only markers ([pref], [instr]) in team memory."""
|
||||
if scope != "team":
|
||||
return None
|
||||
|
||||
markers = set(_MARKER_RE.findall(content))
|
||||
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
|
||||
if leaked:
|
||||
tags = ", ".join(f"[{m}]" for m in leaked)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Team memory cannot include personal markers: {tags}. "
|
||||
"Use [fact] only in team memory."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _validate_bullet_format(content: str) -> list[str]:
|
||||
"""Return warnings for bullet lines that don't match the required format.
|
||||
|
||||
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
|
||||
"""
|
||||
warnings: list[str] = []
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped.startswith("- "):
|
||||
continue
|
||||
if not _BULLET_FORMAT_RE.match(stripped):
|
||||
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
||||
warnings.append(f"Malformed bullet: {short}")
|
||||
return warnings
|
||||
|
||||
|
||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||
"""Return a list of warning strings about suspicious changes."""
|
||||
if not old_memory:
|
||||
return []
|
||||
|
||||
warnings: list[str] = []
|
||||
old_headings = _extract_headings(old_memory)
|
||||
new_headings = _extract_headings(new_memory)
|
||||
dropped = old_headings - new_headings
|
||||
if dropped:
|
||||
names = ", ".join(sorted(dropped))
|
||||
warnings.append(
|
||||
f"Sections removed: {names}. "
|
||||
"If unintentional, the user can restore from the settings page."
|
||||
)
|
||||
|
||||
old_len = len(old_memory)
|
||||
new_len = len(new_memory)
|
||||
if old_len > 0 and new_len < old_len * 0.4:
|
||||
warnings.append(
|
||||
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
|
||||
"Possible data loss."
|
||||
)
|
||||
return warnings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Size validation & soft warning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _validate_memory_size(content: str) -> dict[str, Any] | None:
|
||||
"""Return an error/warning dict if *content* is too large, else None."""
|
||||
length = len(content)
|
||||
if length > MEMORY_HARD_LIMIT:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
|
||||
f"({length:,} chars). Consolidate by merging related items, "
|
||||
"removing outdated entries, and shortening descriptions. "
|
||||
"Then call update_memory again."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _soft_warning(content: str) -> str | None:
|
||||
"""Return a warning string if content exceeds the soft limit."""
|
||||
length = len(content)
|
||||
if length > MEMORY_SOFT_LIMIT:
|
||||
return (
|
||||
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
|
||||
"Consolidate by merging related items and removing less important "
|
||||
"entries on your next update."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forced rewrite when memory exceeds the hard limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FORCED_REWRITE_PROMPT = """\
|
||||
You are a memory curator. The following memory document exceeds the character \
|
||||
limit and must be shortened.
|
||||
|
||||
RULES:
|
||||
1. Rewrite the document to be under {target} characters.
|
||||
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
|
||||
or rename headings to consolidate, but keep names personal and descriptive.
|
||||
3. Priority for keeping content: [instr] > [pref] > [fact].
|
||||
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
|
||||
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
|
||||
6. Preserve the user's first name in entries — do not replace it with "the user".
|
||||
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
|
||||
|
||||
<memory_document>
|
||||
{content}
|
||||
</memory_document>"""
|
||||
|
||||
|
||||
async def _forced_rewrite(content: str, llm: Any) -> str | None:
|
||||
"""Use a focused LLM call to compress *content* under the hard limit.
|
||||
|
||||
Returns the rewritten string, or ``None`` if the call fails.
|
||||
"""
|
||||
try:
|
||||
prompt = _FORCED_REWRITE_PROMPT.format(
|
||||
target=MEMORY_HARD_LIMIT, content=content
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
text = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
)
|
||||
return text.strip()
|
||||
except Exception:
|
||||
logger.exception("Forced rewrite LLM call failed")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared save-and-respond logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _save_memory(
|
||||
*,
|
||||
updated_memory: str,
|
||||
old_memory: str | None,
|
||||
llm: Any | None,
|
||||
apply_fn,
|
||||
commit_fn,
|
||||
rollback_fn,
|
||||
label: str,
|
||||
scope: Literal["user", "team"],
|
||||
) -> dict[str, Any]:
|
||||
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
||||
return a response dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
updated_memory : str
|
||||
The new document the agent submitted.
|
||||
old_memory : str | None
|
||||
The previously persisted document (for diff checks).
|
||||
llm : Any | None
|
||||
LLM instance for forced rewrite (may be ``None``).
|
||||
apply_fn : callable(str) -> None
|
||||
Callback that sets the new memory on the ORM object.
|
||||
commit_fn : coroutine
|
||||
``session.commit``.
|
||||
rollback_fn : coroutine
|
||||
``session.rollback``.
|
||||
label : str
|
||||
Human label for log messages (e.g. "user memory", "team memory").
|
||||
"""
|
||||
content = updated_memory
|
||||
|
||||
# --- forced rewrite if over the hard limit ---
|
||||
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
|
||||
rewritten = await _forced_rewrite(content, llm)
|
||||
if rewritten is not None and len(rewritten) < len(content):
|
||||
content = rewritten
|
||||
|
||||
# --- hard-limit gate (reject if still too large after rewrite) ---
|
||||
size_err = _validate_memory_size(content)
|
||||
if size_err:
|
||||
return size_err
|
||||
|
||||
scope_err = _validate_memory_scope(content, scope)
|
||||
if scope_err:
|
||||
return scope_err
|
||||
|
||||
# --- persist ---
|
||||
try:
|
||||
apply_fn(content)
|
||||
await commit_fn()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update %s: %s", label, e)
|
||||
await rollback_fn()
|
||||
return {"status": "error", "message": f"Failed to update {label}: {e}"}
|
||||
|
||||
# --- build response ---
|
||||
resp: dict[str, Any] = {
|
||||
"status": "saved",
|
||||
"message": f"{label.capitalize()} updated.",
|
||||
}
|
||||
|
||||
if content is not updated_memory:
|
||||
resp["notice"] = "Memory was automatically rewritten to fit within limits."
|
||||
|
||||
diff_warnings = _validate_diff(old_memory, content)
|
||||
if diff_warnings:
|
||||
resp["diff_warnings"] = diff_warnings
|
||||
|
||||
format_warnings = _validate_bullet_format(content)
|
||||
if format_warnings:
|
||||
resp["format_warnings"] = format_warnings
|
||||
|
||||
warning = _soft_warning(content)
|
||||
if warning:
|
||||
resp["warning"] = warning
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_update_memory_tool(
|
||||
user_id: str | UUID,
|
||||
db_session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
):
|
||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
@tool
|
||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the user's personal memory document.
|
||||
|
||||
Your current memory is shown in <user_memory> in the system prompt.
|
||||
When the user shares important long-term information (preferences,
|
||||
facts, instructions, context), rewrite the memory document to include
|
||||
the new information. Merge new facts with existing ones, update
|
||||
contradictions, remove outdated entries, and keep it concise.
|
||||
|
||||
Args:
|
||||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return {"status": "error", "message": "User not found."}
|
||||
|
||||
old_memory = user.memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update user memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to update memory: {e}",
|
||||
}
|
||||
|
||||
return update_memory
|
||||
|
||||
|
||||
def create_update_team_memory_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the team's shared memory document for this search space.
|
||||
|
||||
Your current team memory is shown in <team_memory> in the system
|
||||
prompt. When the team shares important long-term information
|
||||
(decisions, conventions, key facts, priorities), rewrite the memory
|
||||
document to include the new information. Merge new facts with
|
||||
existing ones, update contradictions, remove outdated entries, and
|
||||
keep it concise.
|
||||
|
||||
Args:
|
||||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return {"status": "error", "message": "Search space not found."}
|
||||
|
||||
old_memory = space.shared_memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update team memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to update team memory: {e}",
|
||||
}
|
||||
|
||||
return update_memory
|
||||
|
|
@ -1,351 +0,0 @@
|
|||
"""
|
||||
User memory tools for the SurfSense agent.
|
||||
|
||||
This module provides tools for storing and retrieving user memories,
|
||||
enabling personalized AI responses similar to Claude's memory feature.
|
||||
|
||||
Features:
|
||||
- save_memory: Store facts, preferences, and context about the user
|
||||
- recall_memory: Retrieve relevant memories using semantic search
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import MemoryCategory, UserMemory
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
# Default number of memories to retrieve
|
||||
DEFAULT_RECALL_TOP_K = 5
|
||||
|
||||
# Maximum number of memories per user (to prevent unbounded growth)
|
||||
MAX_MEMORIES_PER_USER = 100
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _to_uuid(user_id: str) -> UUID:
|
||||
"""Convert a string user_id to a UUID object."""
|
||||
if isinstance(user_id, UUID):
|
||||
return user_id
|
||||
return UUID(user_id)
|
||||
|
||||
|
||||
async def get_user_memory_count(
|
||||
db_session: AsyncSession,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> int:
|
||||
"""Get the count of memories for a user."""
|
||||
uuid_user_id = _to_uuid(user_id)
|
||||
query = select(UserMemory).where(UserMemory.user_id == uuid_user_id)
|
||||
if search_space_id is not None:
|
||||
query = query.where(
|
||||
(UserMemory.search_space_id == search_space_id)
|
||||
| (UserMemory.search_space_id.is_(None))
|
||||
)
|
||||
result = await db_session.execute(query)
|
||||
return len(result.scalars().all())
|
||||
|
||||
|
||||
async def delete_oldest_memory(
|
||||
db_session: AsyncSession,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> None:
|
||||
"""Delete the oldest memory for a user to make room for new ones."""
|
||||
uuid_user_id = _to_uuid(user_id)
|
||||
query = (
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == uuid_user_id)
|
||||
.order_by(UserMemory.updated_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
if search_space_id is not None:
|
||||
query = query.where(
|
||||
(UserMemory.search_space_id == search_space_id)
|
||||
| (UserMemory.search_space_id.is_(None))
|
||||
)
|
||||
result = await db_session.execute(query)
|
||||
oldest_memory = result.scalars().first()
|
||||
if oldest_memory:
|
||||
await db_session.delete(oldest_memory)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
def format_memories_for_context(memories: list[dict[str, Any]]) -> str:
|
||||
"""Format retrieved memories into a readable context string for the LLM."""
|
||||
if not memories:
|
||||
return "No relevant memories found for this user."
|
||||
|
||||
parts = ["<user_memories>"]
|
||||
for memory in memories:
|
||||
category = memory.get("category", "unknown")
|
||||
text = memory.get("memory_text", "")
|
||||
updated = memory.get("updated_at", "")
|
||||
parts.append(
|
||||
f" <memory category='{category}' updated='{updated}'>{text}</memory>"
|
||||
)
|
||||
parts.append("</user_memories>")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tool Factory Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_save_memory_tool(
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the save_memory tool.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID
|
||||
search_space_id: The search space ID (for space-specific memories)
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for saving user memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def save_memory(
|
||||
content: str,
|
||||
category: str = "fact",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Save a fact, preference, or context about the user for future reference.
|
||||
|
||||
Use this tool when:
|
||||
- User explicitly says "remember this", "keep this in mind", or similar
|
||||
- User shares personal preferences (e.g., "I prefer Python over JavaScript")
|
||||
- User shares important facts about themselves (name, role, interests, projects)
|
||||
- User gives standing instructions (e.g., "always respond in bullet points")
|
||||
- User shares relevant context (e.g., "I'm working on project X")
|
||||
|
||||
The saved information will be available in future conversations to provide
|
||||
more personalized and contextual responses.
|
||||
|
||||
Args:
|
||||
content: The fact/preference/context to remember.
|
||||
Phrase it clearly, e.g., "User prefers dark mode",
|
||||
"User is a senior Python developer", "User is working on an AI project"
|
||||
category: Type of memory. One of:
|
||||
- "preference": User preferences (e.g., coding style, tools, formats)
|
||||
- "fact": Facts about the user (e.g., name, role, expertise)
|
||||
- "instruction": Standing instructions (e.g., response format preferences)
|
||||
- "context": Current context (e.g., ongoing projects, goals)
|
||||
|
||||
Returns:
|
||||
A dictionary with the save status and memory details
|
||||
"""
|
||||
# Normalize and validate category (LLMs may send uppercase)
|
||||
category = category.lower() if category else "fact"
|
||||
valid_categories = ["preference", "fact", "instruction", "context"]
|
||||
if category not in valid_categories:
|
||||
category = "fact"
|
||||
|
||||
try:
|
||||
# Convert user_id to UUID
|
||||
uuid_user_id = _to_uuid(user_id)
|
||||
|
||||
# Check if we've hit the memory limit
|
||||
memory_count = await get_user_memory_count(
|
||||
db_session, user_id, search_space_id
|
||||
)
|
||||
if memory_count >= MAX_MEMORIES_PER_USER:
|
||||
# Delete oldest memory to make room
|
||||
await delete_oldest_memory(db_session, user_id, search_space_id)
|
||||
|
||||
embedding = await asyncio.to_thread(embed_text, content)
|
||||
|
||||
# Create new memory using ORM
|
||||
# The pgvector Vector column type handles embedding conversion automatically
|
||||
new_memory = UserMemory(
|
||||
user_id=uuid_user_id,
|
||||
search_space_id=search_space_id,
|
||||
memory_text=content,
|
||||
category=MemoryCategory(category), # Convert string to enum
|
||||
embedding=embedding, # Pass embedding directly (list or numpy array)
|
||||
)
|
||||
|
||||
db_session.add(new_memory)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(new_memory)
|
||||
|
||||
return {
|
||||
"status": "saved",
|
||||
"memory_id": new_memory.id,
|
||||
"memory_text": content,
|
||||
"category": category,
|
||||
"message": f"I'll remember: {content}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to save memory for user {user_id}: {e}")
|
||||
# Rollback the session to clear any failed transaction state
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": "Failed to save memory. Please try again.",
|
||||
}
|
||||
|
||||
return save_memory
|
||||
|
||||
|
||||
def create_recall_memory_tool(
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the recall_memory tool.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID
|
||||
search_space_id: The search space ID
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for recalling user memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def recall_memory(
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Recall relevant memories about the user to provide personalized responses.
|
||||
|
||||
Use this tool when:
|
||||
- You need user context to give a better, more personalized answer
|
||||
- User asks about their preferences or past information they shared
|
||||
- User references something they told you before
|
||||
- Personalization would significantly improve the response quality
|
||||
- User asks "what do you know about me?" or similar
|
||||
|
||||
Args:
|
||||
query: Optional search query to find specific memories.
|
||||
If not provided, returns the most recent memories.
|
||||
Example: "programming preferences", "current projects"
|
||||
category: Optional category filter. One of:
|
||||
"preference", "fact", "instruction", "context"
|
||||
If not provided, searches all categories.
|
||||
top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||
|
||||
Returns:
|
||||
A dictionary containing relevant memories and formatted context
|
||||
"""
|
||||
top_k = min(max(top_k, 1), 20) # Clamp between 1 and 20
|
||||
|
||||
try:
|
||||
# Convert user_id to UUID
|
||||
uuid_user_id = _to_uuid(user_id)
|
||||
|
||||
if query:
|
||||
query_embedding = await asyncio.to_thread(embed_text, query)
|
||||
|
||||
# Build query with vector similarity
|
||||
stmt = (
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == uuid_user_id)
|
||||
.where(
|
||||
(UserMemory.search_space_id == search_space_id)
|
||||
| (UserMemory.search_space_id.is_(None))
|
||||
)
|
||||
)
|
||||
|
||||
# Add category filter if specified
|
||||
if category and category in [
|
||||
"preference",
|
||||
"fact",
|
||||
"instruction",
|
||||
"context",
|
||||
]:
|
||||
stmt = stmt.where(UserMemory.category == MemoryCategory(category))
|
||||
|
||||
# Order by vector similarity
|
||||
stmt = stmt.order_by(
|
||||
UserMemory.embedding.op("<=>")(query_embedding)
|
||||
).limit(top_k)
|
||||
|
||||
else:
|
||||
# No query - return most recent memories
|
||||
stmt = (
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == uuid_user_id)
|
||||
.where(
|
||||
(UserMemory.search_space_id == search_space_id)
|
||||
| (UserMemory.search_space_id.is_(None))
|
||||
)
|
||||
)
|
||||
|
||||
# Add category filter if specified
|
||||
if category and category in [
|
||||
"preference",
|
||||
"fact",
|
||||
"instruction",
|
||||
"context",
|
||||
]:
|
||||
stmt = stmt.where(UserMemory.category == MemoryCategory(category))
|
||||
|
||||
stmt = stmt.order_by(UserMemory.updated_at.desc()).limit(top_k)
|
||||
|
||||
result = await db_session.execute(stmt)
|
||||
memories = result.scalars().all()
|
||||
|
||||
# Format memories for response
|
||||
memory_list = [
|
||||
{
|
||||
"id": m.id,
|
||||
"memory_text": m.memory_text,
|
||||
"category": m.category.value if m.category else "unknown",
|
||||
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
|
||||
}
|
||||
for m in memories
|
||||
]
|
||||
|
||||
formatted_context = format_memories_for_context(memory_list)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(memory_list),
|
||||
"memories": memory_list,
|
||||
"formatted_context": formatted_context,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to recall memories for user {user_id}: {e}")
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"memories": [],
|
||||
"formatted_context": "Failed to recall memories.",
|
||||
}
|
||||
|
||||
return recall_memory
|
||||
|
|
@ -861,99 +861,6 @@ class ChatSessionState(BaseModel):
|
|||
ai_responding_to_user = relationship("User")
|
||||
|
||||
|
||||
class MemoryCategory(StrEnum):
|
||||
"""Categories for user memories."""
|
||||
|
||||
# Using lowercase keys to match PostgreSQL enum values
|
||||
preference = "preference" # User preferences (e.g., "prefers dark mode")
|
||||
fact = "fact" # Facts about the user (e.g., "is a Python developer")
|
||||
instruction = (
|
||||
"instruction" # Standing instructions (e.g., "always respond in bullet points")
|
||||
)
|
||||
context = "context" # Contextual information (e.g., "working on project X")
|
||||
|
||||
|
||||
class UserMemory(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Private memory: facts, preferences, context per user per search space.
|
||||
Used only for private chats (not shared/team chats).
|
||||
"""
|
||||
|
||||
__tablename__ = "user_memories"
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
# Optional association with a search space (if memory is space-specific)
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# The actual memory content
|
||||
memory_text = Column(Text, nullable=False)
|
||||
# Category for organization and filtering
|
||||
category = Column(
|
||||
SQLAlchemyEnum(MemoryCategory),
|
||||
nullable=False,
|
||||
default=MemoryCategory.fact,
|
||||
)
|
||||
# Vector embedding for semantic search
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
|
||||
# Track when memory was last updated
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="memories")
|
||||
search_space = relationship("SearchSpace", back_populates="user_memories")
|
||||
|
||||
|
||||
class SharedMemory(BaseModel, TimestampMixin):
|
||||
__tablename__ = "shared_memories"
|
||||
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
created_by_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
memory_text = Column(Text, nullable=False)
|
||||
category = Column(
|
||||
SQLAlchemyEnum(MemoryCategory),
|
||||
nullable=False,
|
||||
default=MemoryCategory.fact,
|
||||
)
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
search_space = relationship("SearchSpace", back_populates="shared_memories")
|
||||
created_by = relationship("User")
|
||||
|
||||
|
||||
class Folder(BaseModel, TimestampMixin):
|
||||
__tablename__ = "folders"
|
||||
|
||||
|
|
@ -1394,6 +1301,8 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
Text, nullable=True, default=""
|
||||
) # User's custom instructions
|
||||
|
||||
shared_memory_md = Column(Text, nullable=True, server_default="")
|
||||
|
||||
# Search space-level LLM preferences (shared by all members)
|
||||
# Note: ID values:
|
||||
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
|
||||
|
|
@ -1516,20 +1425,6 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# User memories associated with this search space
|
||||
user_memories = relationship(
|
||||
"UserMemory",
|
||||
back_populates="search_space",
|
||||
order_by="UserMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
shared_memories = relationship(
|
||||
"SharedMemory",
|
||||
back_populates="search_space",
|
||||
order_by="SharedMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||
__tablename__ = "search_source_connectors"
|
||||
|
|
@ -2037,14 +1932,6 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# User memories for personalized AI responses
|
||||
memories = relationship(
|
||||
"UserMemory",
|
||||
back_populates="user",
|
||||
order_by="UserMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Incentive tasks completed by this user
|
||||
incentive_tasks = relationship(
|
||||
"UserIncentiveTask",
|
||||
|
|
@ -2072,6 +1959,8 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
|
||||
last_login = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
memory_md = Column(Text, nullable=True, server_default="")
|
||||
|
||||
# Refresh tokens for this user
|
||||
refresh_tokens = relationship(
|
||||
"RefreshToken",
|
||||
|
|
@ -2157,14 +2046,6 @@ else:
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# User memories for personalized AI responses
|
||||
memories = relationship(
|
||||
"UserMemory",
|
||||
back_populates="user",
|
||||
order_by="UserMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Incentive tasks completed by this user
|
||||
incentive_tasks = relationship(
|
||||
"UserIncentiveTask",
|
||||
|
|
@ -2192,6 +2073,8 @@ else:
|
|||
|
||||
last_login = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
memory_md = Column(Text, nullable=True, server_default="")
|
||||
|
||||
# Refresh tokens for this user
|
||||
refresh_tokens = relationship(
|
||||
"RefreshToken",
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ async def parse_with_azure_doc_intelligence(file_path: str) -> str:
|
|||
async with client:
|
||||
with open(file_path, "rb") as f:
|
||||
poller = await client.begin_analyze_document(
|
||||
"prebuilt-read",
|
||||
"prebuilt-layout",
|
||||
body=f,
|
||||
output_content_format=DocumentContentFormat.MARKDOWN,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from .jira_add_connector_route import router as jira_add_connector_router
|
|||
from .linear_add_connector_route import router as linear_add_connector_router
|
||||
from .logs_routes import router as logs_router
|
||||
from .luma_add_connector_route import router as luma_add_connector_router
|
||||
from .memory_routes import router as memory_router
|
||||
from .model_list_routes import router as model_list_router
|
||||
from .new_chat_routes import router as new_chat_router
|
||||
from .new_llm_config_routes import router as new_llm_config_router
|
||||
|
|
@ -100,4 +101,5 @@ router.include_router(incentive_tasks_router) # Incentive tasks for earning fre
|
|||
router.include_router(stripe_router) # Stripe checkout for additional page packs
|
||||
router.include_router(youtube_router) # YouTube playlist resolution
|
||||
router.include_router(prompts_router)
|
||||
router.include_router(memory_router) # User personal memory (memory.md style)
|
||||
router.include_router(autocomplete_router) # Lightweight autocomplete with KB context
|
||||
|
|
|
|||
153
surfsense_backend/app/routes/memory_routes.py
Normal file
153
surfsense_backend/app/routes/memory_routes.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
"""Routes for user memory management (personal memory.md)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.llm_config import (
|
||||
create_chat_litellm_from_agent_config,
|
||||
load_agent_llm_config_for_search_space,
|
||||
)
|
||||
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
|
||||
from app.db import User, get_async_session
|
||||
from app.users import current_active_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class MemoryRead(BaseModel):
|
||||
memory_md: str
|
||||
|
||||
|
||||
class MemoryUpdate(BaseModel):
|
||||
memory_md: str
|
||||
|
||||
|
||||
class MemoryEditRequest(BaseModel):
|
||||
query: str
|
||||
search_space_id: int
|
||||
|
||||
|
||||
_MEMORY_EDIT_PROMPT = """\
|
||||
You are a memory editor. The user wants to modify their memory document. \
|
||||
Apply the user's instruction to the existing memory document and output the \
|
||||
FULL updated document.
|
||||
|
||||
RULES:
|
||||
1. If the instruction asks to add something, add it with format: \
|
||||
- (YYYY-MM-DD) [fact|pref|instr] text, under an existing or new ## heading. \
|
||||
Heading names should be personal and descriptive, not generic categories.
|
||||
2. If the instruction asks to remove something, remove the matching entry.
|
||||
3. If the instruction asks to change something, update the matching entry.
|
||||
4. Preserve existing ## headings and all other entries.
|
||||
5. Every bullet must include a marker: [fact], [pref], or [instr].
|
||||
6. Use the user's first name (from <user_name>) in entries instead of "the user".
|
||||
7. Output ONLY the updated markdown — no explanations, no wrapping.
|
||||
|
||||
<user_name>{user_name}</user_name>
|
||||
|
||||
<current_memory>
|
||||
{current_memory}
|
||||
</current_memory>
|
||||
|
||||
<user_instruction>
|
||||
{instruction}
|
||||
</user_instruction>"""
|
||||
|
||||
|
||||
@router.get("/users/me/memory", response_model=MemoryRead)
|
||||
async def get_user_memory(
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
await session.refresh(user, ["memory_md"])
|
||||
return MemoryRead(memory_md=user.memory_md or "")
|
||||
|
||||
|
||||
@router.put("/users/me/memory", response_model=MemoryRead)
|
||||
async def update_user_memory(
|
||||
body: MemoryUpdate,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if len(body.memory_md) > MEMORY_HARD_LIMIT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit ({len(body.memory_md):,} chars).",
|
||||
)
|
||||
user.memory_md = body.memory_md
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user, ["memory_md"])
|
||||
return MemoryRead(memory_md=user.memory_md or "")
|
||||
|
||||
|
||||
@router.post("/users/me/memory/edit", response_model=MemoryRead)
|
||||
async def edit_user_memory(
|
||||
body: MemoryEditRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Apply a natural language edit to the user's personal memory via LLM."""
|
||||
agent_config = await load_agent_llm_config_for_search_space(
|
||||
session, body.search_space_id
|
||||
)
|
||||
if not agent_config:
|
||||
raise HTTPException(status_code=500, detail="No LLM configuration available.")
|
||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||
if not llm:
|
||||
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
|
||||
|
||||
await session.refresh(user, ["memory_md", "display_name"])
|
||||
current_memory = user.memory_md or ""
|
||||
first_name = (
|
||||
user.display_name.strip().split()[0]
|
||||
if user.display_name and user.display_name.strip()
|
||||
else "The user"
|
||||
)
|
||||
|
||||
prompt = _MEMORY_EDIT_PROMPT.format(
|
||||
current_memory=current_memory or "(empty)",
|
||||
instruction=body.query,
|
||||
user_name=first_name,
|
||||
)
|
||||
try:
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "memory-edit"]},
|
||||
)
|
||||
updated = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
).strip()
|
||||
except Exception as e:
|
||||
logger.exception("Memory edit LLM call failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail="Memory edit failed.") from e
|
||||
|
||||
if not updated:
|
||||
raise HTTPException(status_code=400, detail="LLM returned empty result.")
|
||||
|
||||
result = await _save_memory(
|
||||
updated_memory=updated,
|
||||
old_memory=current_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
|
||||
if result.get("status") == "error":
|
||||
raise HTTPException(status_code=400, detail=result["message"])
|
||||
|
||||
await session.refresh(user, ["memory_md"])
|
||||
return MemoryRead(memory_md=user.memory_md or "")
|
||||
|
|
@ -1,10 +1,17 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.llm_config import (
|
||||
create_chat_litellm_from_agent_config,
|
||||
load_agent_llm_config_for_search_space,
|
||||
)
|
||||
from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_memory
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGenerationConfig,
|
||||
|
|
@ -34,6 +41,34 @@ logger = logging.getLogger(__name__)
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
class _TeamMemoryEditRequest(PydanticBaseModel):
|
||||
query: str
|
||||
|
||||
|
||||
_TEAM_MEMORY_EDIT_PROMPT = """\
|
||||
You are a memory editor for a team workspace. The user wants to modify the \
|
||||
team's shared memory document. Apply the user's instruction to the existing \
|
||||
memory document and output the FULL updated document.
|
||||
|
||||
RULES:
|
||||
1. If the instruction asks to add something, add it with format: \
|
||||
- (YYYY-MM-DD) [fact] text, under an existing or new ## heading. \
|
||||
Heading names should be descriptive, not generic categories.
|
||||
2. If the instruction asks to remove something, remove the matching entry.
|
||||
3. If the instruction asks to change something, update the matching entry.
|
||||
4. Preserve existing ## headings and all other entries.
|
||||
5. NEVER use [pref] or [instr] markers. Team memory uses [fact] only.
|
||||
6. Output ONLY the updated markdown — no explanations, no wrapping.
|
||||
|
||||
<current_memory>
|
||||
{current_memory}
|
||||
</current_memory>
|
||||
|
||||
<user_instruction>
|
||||
{instruction}
|
||||
</user_instruction>"""
|
||||
|
||||
|
||||
async def create_default_roles_and_membership(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
|
|
@ -255,6 +290,16 @@ async def update_search_space(
|
|||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
update_data = search_space_update.model_dump(exclude_unset=True)
|
||||
|
||||
if (
|
||||
"shared_memory_md" in update_data
|
||||
and len(update_data["shared_memory_md"] or "") > MEMORY_HARD_LIMIT
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Team memory exceeds {MEMORY_HARD_LIMIT:,} character limit.",
|
||||
)
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(db_search_space, key, value)
|
||||
await session.commit()
|
||||
|
|
@ -269,6 +314,76 @@ async def update_search_space(
|
|||
) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/searchspaces/{search_space_id}/memory/edit",
|
||||
response_model=SearchSpaceRead,
|
||||
)
|
||||
async def edit_team_memory(
|
||||
search_space_id: int,
|
||||
body: _TeamMemoryEditRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Apply a natural language edit to the team memory via LLM."""
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
|
||||
agent_config = await load_agent_llm_config_for_search_space(
|
||||
session, search_space_id
|
||||
)
|
||||
if not agent_config:
|
||||
raise HTTPException(status_code=500, detail="No LLM configuration available.")
|
||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||
if not llm:
|
||||
raise HTTPException(status_code=500, detail="Failed to create LLM instance.")
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
db_search_space = result.scalars().first()
|
||||
if not db_search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
current_memory = db_search_space.shared_memory_md or ""
|
||||
|
||||
prompt = _TEAM_MEMORY_EDIT_PROMPT.format(
|
||||
current_memory=current_memory or "(empty)",
|
||||
instruction=body.query,
|
||||
)
|
||||
try:
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal", "memory-edit"]},
|
||||
)
|
||||
updated = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
).strip()
|
||||
except Exception as e:
|
||||
logger.exception("Team memory edit LLM call failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail="Team memory edit failed.") from e
|
||||
|
||||
if not updated:
|
||||
raise HTTPException(status_code=400, detail="LLM returned empty result.")
|
||||
|
||||
save_result = await _save_memory(
|
||||
updated_memory=updated,
|
||||
old_memory=current_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(db_search_space, "shared_memory_md", content),
|
||||
commit_fn=session.commit,
|
||||
rollback_fn=session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
|
||||
if save_result.get("status") == "error":
|
||||
raise HTTPException(status_code=400, detail=save_result["message"])
|
||||
|
||||
await session.refresh(db_search_space)
|
||||
return db_search_space
|
||||
|
||||
|
||||
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
|
||||
async def delete_search_space(
|
||||
search_space_id: int,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class SearchSpaceUpdate(BaseModel):
|
|||
description: str | None = None
|
||||
citations_enabled: bool | None = None
|
||||
qna_custom_instructions: str | None = None
|
||||
shared_memory_md: str | None = None
|
||||
|
||||
|
||||
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
||||
|
|
@ -29,6 +30,7 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
|||
user_id: uuid.UUID
|
||||
citations_enabled: bool
|
||||
qna_custom_instructions: str | None = None
|
||||
shared_memory_md: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,10 @@ from app.agents.new_chat.llm_config import (
|
|||
load_agent_config,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.agents.new_chat.memory_extraction import (
|
||||
extract_and_save_memory,
|
||||
extract_and_save_team_memory,
|
||||
)
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
|
|
@ -59,8 +63,6 @@ from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_hea
|
|||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def format_mentioned_surfsense_docs_as_context(
|
||||
documents: list[SurfsenseDocsDocument],
|
||||
|
|
@ -141,6 +143,7 @@ class StreamResult:
|
|||
is_interrupted: bool = False
|
||||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
|
||||
agent_called_update_memory: bool = False
|
||||
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
@ -183,6 +186,7 @@ async def _stream_agent_events(
|
|||
last_active_step_items: list[str] = initial_step_items or []
|
||||
just_finished_tool: bool = False
|
||||
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
|
||||
called_update_memory: bool = False
|
||||
|
||||
def next_thinking_step_id() -> str:
|
||||
nonlocal thinking_step_counter
|
||||
|
|
@ -490,6 +494,9 @@ async def _stream_agent_events(
|
|||
tool_name = event.get("name", "unknown_tool")
|
||||
raw_output = event.get("data", {}).get("output", "")
|
||||
|
||||
if tool_name == "update_memory":
|
||||
called_update_memory = True
|
||||
|
||||
if hasattr(raw_output, "content"):
|
||||
content = raw_output.content
|
||||
if isinstance(content, str):
|
||||
|
|
@ -1111,6 +1118,7 @@ async def _stream_agent_events(
|
|||
yield completion_event
|
||||
|
||||
result.accumulated_text = accumulated_text
|
||||
result.agent_called_update_memory = called_update_memory
|
||||
|
||||
state = await agent.aget_state(config)
|
||||
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
|
||||
|
|
@ -1540,6 +1548,27 @@ async def stream_new_chat(
|
|||
chat_id, generated_title
|
||||
)
|
||||
|
||||
# Fire background memory extraction if the agent didn't handle it.
|
||||
# Shared threads write to team memory; private threads write to user memory.
|
||||
if not stream_result.agent_called_update_memory:
|
||||
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||
asyncio.create_task(
|
||||
extract_and_save_team_memory(
|
||||
user_message=user_query,
|
||||
search_space_id=search_space_id,
|
||||
llm=llm,
|
||||
author_display_name=current_user_display_name,
|
||||
)
|
||||
)
|
||||
elif user_id:
|
||||
asyncio.create_task(
|
||||
extract_and_save_memory(
|
||||
user_message=user_query,
|
||||
user_id=user_id,
|
||||
llm=llm,
|
||||
)
|
||||
)
|
||||
|
||||
# Finish the step and message
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,198 @@
|
|||
"""Unit tests for memory scope validation and bullet format validation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.tools.update_memory import (
|
||||
_save_memory,
|
||||
_validate_bullet_format,
|
||||
_validate_memory_scope,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _Recorder:
|
||||
def __init__(self) -> None:
|
||||
self.applied_content: str | None = None
|
||||
self.commit_calls = 0
|
||||
self.rollback_calls = 0
|
||||
|
||||
def apply(self, content: str) -> None:
|
||||
self.applied_content = content
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commit_calls += 1
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rollback_calls += 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_memory_scope — marker-based
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_memory_scope_rejects_pref_marker_in_team_scope() -> None:
|
||||
content = "- (2026-04-10) [pref] Prefers dark mode\n"
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is not None
|
||||
assert result["status"] == "error"
|
||||
assert "[pref]" in result["message"]
|
||||
|
||||
|
||||
def test_validate_memory_scope_rejects_instr_marker_in_team_scope() -> None:
|
||||
content = "- (2026-04-10) [instr] Always respond in Spanish\n"
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is not None
|
||||
assert result["status"] == "error"
|
||||
assert "[instr]" in result["message"]
|
||||
|
||||
|
||||
def test_validate_memory_scope_rejects_both_personal_markers_in_team() -> None:
|
||||
content = (
|
||||
"- (2026-04-10) [pref] Prefers dark mode\n"
|
||||
"- (2026-04-10) [instr] Always respond in Spanish\n"
|
||||
)
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is not None
|
||||
assert result["status"] == "error"
|
||||
assert "[instr]" in result["message"]
|
||||
assert "[pref]" in result["message"]
|
||||
|
||||
|
||||
def test_validate_memory_scope_allows_fact_in_team_scope() -> None:
|
||||
content = "- (2026-04-10) [fact] Office is in downtown Seattle\n"
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_memory_scope_allows_all_markers_in_user_scope() -> None:
|
||||
content = (
|
||||
"- (2026-04-10) [fact] Python developer\n"
|
||||
"- (2026-04-10) [pref] Prefers concise answers\n"
|
||||
"- (2026-04-10) [instr] Always use bullet points\n"
|
||||
)
|
||||
result = _validate_memory_scope(content, "user")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_memory_scope_allows_any_heading_in_team() -> None:
|
||||
content = "## Architecture\n- (2026-04-10) [fact] Uses PostgreSQL for persistence\n"
|
||||
result = _validate_memory_scope(content, "team")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_memory_scope_allows_any_heading_in_user() -> None:
|
||||
content = "## My Projects\n- (2026-04-10) [fact] Working on SurfSense\n"
|
||||
result = _validate_memory_scope(content, "user")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_bullet_format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_bullet_format_passes_valid_bullets() -> None:
|
||||
content = (
|
||||
"## Work\n"
|
||||
"- (2026-04-10) [fact] Senior Python developer\n"
|
||||
"- (2026-04-10) [pref] Prefers dark mode\n"
|
||||
"- (2026-04-10) [instr] Always respond in bullet points\n"
|
||||
)
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert warnings == []
|
||||
|
||||
|
||||
def test_validate_bullet_format_warns_on_missing_marker() -> None:
|
||||
content = "- (2026-04-10) Senior Python developer\n"
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert len(warnings) == 1
|
||||
assert "Malformed bullet" in warnings[0]
|
||||
|
||||
|
||||
def test_validate_bullet_format_warns_on_missing_date() -> None:
|
||||
content = "- [fact] Senior Python developer\n"
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert len(warnings) == 1
|
||||
assert "Malformed bullet" in warnings[0]
|
||||
|
||||
|
||||
def test_validate_bullet_format_warns_on_unknown_marker() -> None:
|
||||
content = "- (2026-04-10) [context] Working on project X\n"
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert len(warnings) == 1
|
||||
assert "Malformed bullet" in warnings[0]
|
||||
|
||||
|
||||
def test_validate_bullet_format_ignores_non_bullet_lines() -> None:
|
||||
content = "## Some Heading\nSome paragraph text\n"
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert warnings == []
|
||||
|
||||
|
||||
def test_validate_bullet_format_warns_on_old_format_without_marker() -> None:
|
||||
content = "## About the user\n- (2026-04-10) Likes cats\n"
|
||||
warnings = _validate_bullet_format(content)
|
||||
assert len(warnings) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _save_memory — end-to-end with marker scope check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_blocks_pref_in_team_before_commit() -> None:
|
||||
recorder = _Recorder()
|
||||
result = await _save_memory(
|
||||
updated_memory="- (2026-04-10) [pref] Prefers dark mode\n",
|
||||
old_memory=None,
|
||||
llm=None,
|
||||
apply_fn=recorder.apply,
|
||||
commit_fn=recorder.commit,
|
||||
rollback_fn=recorder.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
assert result["status"] == "error"
|
||||
assert recorder.commit_calls == 0
|
||||
assert recorder.applied_content is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_allows_fact_in_team_and_commits() -> None:
|
||||
recorder = _Recorder()
|
||||
content = "- (2026-04-10) [fact] Weekly standup on Mondays\n"
|
||||
result = await _save_memory(
|
||||
updated_memory=content,
|
||||
old_memory=None,
|
||||
llm=None,
|
||||
apply_fn=recorder.apply,
|
||||
commit_fn=recorder.commit,
|
||||
rollback_fn=recorder.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
assert result["status"] == "saved"
|
||||
assert recorder.commit_calls == 1
|
||||
assert recorder.applied_content == content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memory_includes_format_warnings() -> None:
|
||||
recorder = _Recorder()
|
||||
content = "- (2026-04-10) Missing marker text\n"
|
||||
result = await _save_memory(
|
||||
updated_memory=content,
|
||||
old_memory=None,
|
||||
llm=None,
|
||||
apply_fn=recorder.apply,
|
||||
commit_fn=recorder.commit,
|
||||
rollback_fn=recorder.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
assert result["status"] == "saved"
|
||||
assert "format_warnings" in result
|
||||
assert len(result["format_warnings"]) == 1
|
||||
|
|
@ -0,0 +1,291 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pen } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { z } from "zod";
|
||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import { PlateEditor } from "@/components/editor/plate-editor";
|
||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
|
||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||
|
||||
const MEMORY_HARD_LIMIT = 25_000;
|
||||
|
||||
const MemoryReadSchema = z.object({
|
||||
memory_md: z.string(),
|
||||
});
|
||||
|
||||
export function MemoryContent() {
|
||||
const activeSearchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
||||
const [memory, setMemory] = useState("");
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [editQuery, setEditQuery] = useState("");
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [showInput, setShowInput] = useState(false);
|
||||
const textareaRef = useRef<HTMLInputElement>(null);
|
||||
const inputContainerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const fetchMemory = useCallback(async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
const data = await baseApiService.get("/api/v1/users/me/memory", MemoryReadSchema);
|
||||
setMemory(data.memory_md);
|
||||
} catch {
|
||||
toast.error("Failed to load memory");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
fetchMemory();
|
||||
}, [fetchMemory]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!showInput) return;
|
||||
|
||||
const handlePointerDownOutside = (event: MouseEvent | TouchEvent) => {
|
||||
const target = event.target;
|
||||
if (!(target instanceof Node)) return;
|
||||
if (inputContainerRef.current?.contains(target)) return;
|
||||
|
||||
setShowInput(false);
|
||||
};
|
||||
|
||||
document.addEventListener("mousedown", handlePointerDownOutside);
|
||||
document.addEventListener("touchstart", handlePointerDownOutside, { passive: true });
|
||||
|
||||
return () => {
|
||||
document.removeEventListener("mousedown", handlePointerDownOutside);
|
||||
document.removeEventListener("touchstart", handlePointerDownOutside);
|
||||
};
|
||||
}, [showInput]);
|
||||
|
||||
const handleClear = async () => {
|
||||
try {
|
||||
setSaving(true);
|
||||
const data = await baseApiService.put("/api/v1/users/me/memory", MemoryReadSchema, {
|
||||
body: { memory_md: "" },
|
||||
});
|
||||
setMemory(data.memory_md);
|
||||
toast.success("Memory cleared");
|
||||
} catch {
|
||||
toast.error("Failed to clear memory");
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleEdit = async () => {
|
||||
const query = editQuery.trim();
|
||||
if (!query) return;
|
||||
|
||||
try {
|
||||
setEditing(true);
|
||||
const data = await baseApiService.post("/api/v1/users/me/memory/edit", MemoryReadSchema, {
|
||||
body: { query, search_space_id: Number(activeSearchSpaceId) },
|
||||
});
|
||||
setMemory(data.memory_md);
|
||||
setEditQuery("");
|
||||
setShowInput(false);
|
||||
toast.success("Memory updated");
|
||||
} catch {
|
||||
toast.error("Failed to edit memory");
|
||||
} finally {
|
||||
setEditing(false);
|
||||
}
|
||||
};
|
||||
|
||||
const openInput = () => {
|
||||
setShowInput(true);
|
||||
requestAnimationFrame(() => textareaRef.current?.focus());
|
||||
};
|
||||
|
||||
const handleDownload = () => {
|
||||
if (!memory) return;
|
||||
try {
|
||||
const blob = new Blob([memory], { type: "text/markdown;charset=utf-8" });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = "personal-memory.md";
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
} catch {
|
||||
toast.error("Failed to download memory");
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopyMarkdown = async () => {
|
||||
if (!memory) return;
|
||||
try {
|
||||
await navigator.clipboard.writeText(memory);
|
||||
toast.success("Copied to clipboard");
|
||||
} catch {
|
||||
toast.error("Failed to copy memory");
|
||||
}
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
handleEdit();
|
||||
}
|
||||
};
|
||||
|
||||
const displayMemory = memory.replace(/\(\d{4}-\d{2}-\d{2}\)\s*\[(fact|pref|instr)\]\s*/g, "");
|
||||
const charCount = memory.length;
|
||||
|
||||
const getCounterColor = () => {
|
||||
if (charCount > MEMORY_HARD_LIMIT) return "text-red-500";
|
||||
if (charCount > 15_000) return "text-orange-500";
|
||||
if (charCount > 10_000) return "text-yellow-500";
|
||||
return "text-muted-foreground";
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Spinner size="md" className="text-muted-foreground" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!memory) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-16 text-center">
|
||||
<h3 className="text-base font-medium text-foreground">What does SurfSense remember?</h3>
|
||||
<p className="mt-2 max-w-sm text-sm text-muted-foreground">
|
||||
Nothing yet. SurfSense picks up on your preferences and context as you chat.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<Alert className="bg-muted/50 py-3 md:py-4">
|
||||
<Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
|
||||
<AlertDescription className="text-xs md:text-sm">
|
||||
<p>
|
||||
SurfSense uses this personal memory to personalize your responses across all
|
||||
conversations.
|
||||
</p>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<div className="relative h-[380px] rounded-lg border bg-background">
|
||||
<div className="h-full overflow-y-auto scrollbar-thin">
|
||||
<PlateEditor
|
||||
markdown={displayMemory}
|
||||
readOnly
|
||||
preset="readonly"
|
||||
variant="default"
|
||||
editorVariant="none"
|
||||
className="px-5 py-4 text-sm min-h-full"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{showInput ? (
|
||||
<div className="absolute bottom-3 inset-x-3 z-10">
|
||||
<div
|
||||
ref={inputContainerRef}
|
||||
className="relative flex h-[54px] items-center gap-2 rounded-[9999px] border bg-muted/60 backdrop-blur-sm pl-4 pr-1 shadow-sm"
|
||||
>
|
||||
<input
|
||||
ref={textareaRef}
|
||||
type="text"
|
||||
value={editQuery}
|
||||
onChange={(e) => setEditQuery(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="Tell SurfSense what to remember or forget"
|
||||
disabled={editing}
|
||||
className="flex-1 bg-transparent text-sm outline-none placeholder:text-muted-foreground/70"
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
onClick={handleEdit}
|
||||
disabled={editing || !editQuery.trim()}
|
||||
className={`h-11 w-11 shrink-0 rounded-full ${
|
||||
editing ? "" : "bg-muted-foreground/15 hover:bg-muted-foreground/20"
|
||||
}`}
|
||||
>
|
||||
{editing ? (
|
||||
<Spinner size="sm" />
|
||||
) : (
|
||||
<ArrowUp className="!h-5 !w-5 text-foreground" strokeWidth={2.25} />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<Button
|
||||
type="button"
|
||||
size="icon"
|
||||
variant="secondary"
|
||||
onClick={openInput}
|
||||
className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm"
|
||||
>
|
||||
<Pen className="!h-5 !w-5" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className={`text-xs shrink-0 ${getCounterColor()}`}>
|
||||
{charCount.toLocaleString()} / {MEMORY_HARD_LIMIT.toLocaleString()}
|
||||
<span className="hidden sm:inline"> characters</span>
|
||||
<span className="sm:hidden"> chars</span>
|
||||
{charCount > 15_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
|
||||
{charCount > MEMORY_HARD_LIMIT && " - Exceeds limit"}
|
||||
</span>
|
||||
<div className="flex items-center gap-1.5 sm:gap-2">
|
||||
<Button
|
||||
type="button"
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
className="text-xs sm:text-sm"
|
||||
onClick={handleClear}
|
||||
disabled={saving || editing || !memory}
|
||||
>
|
||||
<span className="hidden sm:inline">Reset Memory</span>
|
||||
<span className="sm:hidden">Reset</span>
|
||||
</Button>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button type="button" variant="secondary" size="sm" disabled={!memory}>
|
||||
Export
|
||||
<ChevronDown className="h-3 w-3 opacity-60" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem onClick={handleCopyMarkdown}>
|
||||
<ClipboardCopy className="h-4 w-4 mr-2" />
|
||||
Copy as Markdown
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={handleDownload}>
|
||||
<Download className="h-4 w-4 mr-2" />
|
||||
Download as Markdown
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { Receipt } from "lucide-react";
|
||||
import { ReceiptText } from "lucide-react";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import {
|
||||
|
|
@ -65,7 +65,7 @@ export function PurchaseHistoryContent() {
|
|||
if (purchases.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center gap-2 py-16 text-center">
|
||||
<Receipt className="h-8 w-8 text-muted-foreground" />
|
||||
<ReceiptText className="h-8 w-8 text-muted-foreground" />
|
||||
<p className="text-sm font-medium">No purchases yet</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Your page-pack purchases will appear here after checkout.
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ export const viewport: Viewport = {
|
|||
export const metadata: Metadata = {
|
||||
title: "SurfSense - Open Source NotebookLM Alternative for Teams",
|
||||
description:
|
||||
"Connect any LLM to your internal knowledge sources and chat with it in real time alongside your team. SurfSense is an open source alternative to NotebookLM, built for enterprise AI search and knowledge management.",
|
||||
"An open source, privacy focused alternative to NotebookLM for teams with no data limits, built for enterprise AI search and knowledge management.",
|
||||
keywords: [
|
||||
"enterprise ai",
|
||||
"enterprise search",
|
||||
|
|
@ -85,7 +85,7 @@ export const metadata: Metadata = {
|
|||
openGraph: {
|
||||
title: "SurfSense - Open Source NotebookLM Alternative for Teams",
|
||||
description:
|
||||
"Connect any LLM to your internal knowledge sources and chat with it in real time alongside your team. Open source enterprise AI search and knowledge management.",
|
||||
"An open source, privacy focused alternative to NotebookLM for teams with no data limits. Open source enterprise AI search and knowledge management.",
|
||||
url: "https://surfsense.com",
|
||||
siteName: "SurfSense",
|
||||
type: "website",
|
||||
|
|
@ -103,7 +103,7 @@ export const metadata: Metadata = {
|
|||
card: "summary_large_image",
|
||||
title: "SurfSense - Open Source NotebookLM Alternative for Teams",
|
||||
description:
|
||||
"Connect any LLM to your internal knowledge sources and chat with it in real time alongside your team. Open source enterprise AI search and knowledge management.",
|
||||
"An open source, privacy focused alternative to NotebookLM for teams with no data limits. Open source enterprise AI search and knowledge management.",
|
||||
creator: "https://surfsense.com",
|
||||
site: "https://surfsense.com",
|
||||
images: [
|
||||
|
|
|
|||
|
|
@ -21,5 +21,3 @@ export const userSettingsDialogAtom = atom<UserSettingsDialogState>({
|
|||
});
|
||||
|
||||
export const teamDialogAtom = atom<boolean>(false);
|
||||
|
||||
export const morePagesDialogAtom = atom<boolean>(false);
|
||||
|
|
|
|||
|
|
@ -76,12 +76,8 @@ const GenerateImageToolUI = dynamic(
|
|||
import("@/components/tool-ui/generate-image").then((m) => ({ default: m.GenerateImageToolUI })),
|
||||
{ ssr: false }
|
||||
);
|
||||
const SaveMemoryToolUI = dynamic(
|
||||
() => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.SaveMemoryToolUI })),
|
||||
{ ssr: false }
|
||||
);
|
||||
const RecallMemoryToolUI = dynamic(
|
||||
() => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.RecallMemoryToolUI })),
|
||||
const UpdateMemoryToolUI = dynamic(
|
||||
() => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.UpdateMemoryToolUI })),
|
||||
{ ssr: false }
|
||||
);
|
||||
const SandboxExecuteToolUI = dynamic(
|
||||
|
|
@ -386,8 +382,7 @@ const AssistantMessageInner: FC = () => {
|
|||
generate_video_presentation: GenerateVideoPresentationToolUI,
|
||||
display_image: GenerateImageToolUI,
|
||||
generate_image: GenerateImageToolUI,
|
||||
save_memory: SaveMemoryToolUI,
|
||||
recall_memory: RecallMemoryToolUI,
|
||||
update_memory: UpdateMemoryToolUI,
|
||||
execute: SandboxExecuteToolUI,
|
||||
create_notion_page: CreateNotionPageToolUI,
|
||||
update_notion_page: UpdateNotionPageToolUI,
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ import {
|
|||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
|
|
@ -92,7 +93,7 @@ import { useMediaQuery } from "@/hooks/use-media-query";
|
|||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const COMPOSER_PLACEHOLDER = "Ask anything · Type / for prompts · Type @ to mention docs";
|
||||
const COMPOSER_PLACEHOLDER = "Ask anything, type / for prompts, type @ to mention docs";
|
||||
|
||||
export const Thread: FC = () => {
|
||||
return <ThreadContent />;
|
||||
|
|
@ -804,7 +805,7 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
const isDesktop = useMediaQuery("(min-width: 640px)");
|
||||
const { openDialog: openUploadDialog } = useDocumentUploadDialog();
|
||||
const [toolsScrollPos, setToolsScrollPos] = useState<"top" | "middle" | "bottom">("top");
|
||||
const toolsRafRef = useRef<number>();
|
||||
const toolsRafRef = useRef<number | undefined>(undefined);
|
||||
const handleToolsScroll = useCallback((e: React.UIEvent<HTMLDivElement>) => {
|
||||
const el = e.currentTarget;
|
||||
if (toolsRafRef.current) return;
|
||||
|
|
@ -1021,8 +1022,23 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
</div>
|
||||
)}
|
||||
{!filteredTools?.length && (
|
||||
<div className="px-4 py-6 text-center text-sm text-muted-foreground">
|
||||
Loading tools...
|
||||
<div className="px-4 pt-3 pb-2">
|
||||
<Skeleton className="h-3 w-16 mb-2" />
|
||||
{["t1", "t2", "t3", "t4"].map((k) => (
|
||||
<div key={k} className="flex items-center gap-3 py-2">
|
||||
<Skeleton className="size-4 rounded shrink-0" />
|
||||
<Skeleton className="h-3.5 flex-1" />
|
||||
<Skeleton className="h-5 w-9 rounded-full shrink-0" />
|
||||
</div>
|
||||
))}
|
||||
<Skeleton className="h-3 w-24 mt-3 mb-2" />
|
||||
{["c1", "c2", "c3"].map((k) => (
|
||||
<div key={k} className="flex items-center gap-3 py-2">
|
||||
<Skeleton className="size-4 rounded shrink-0" />
|
||||
<Skeleton className="h-3.5 flex-1" />
|
||||
<Skeleton className="h-5 w-9 rounded-full shrink-0" />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
|
@ -1058,12 +1074,12 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
side="bottom"
|
||||
align="start"
|
||||
sideOffset={12}
|
||||
className="w-[calc(100vw-2rem)] max-w-56 sm:max-w-72 sm:w-72 p-0 select-none"
|
||||
className="w-[calc(100vw-2rem)] max-w-48 sm:max-w-56 sm:w-56 p-0 select-none"
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
>
|
||||
<div className="sr-only">Manage Tools</div>
|
||||
<div
|
||||
className="max-h-48 sm:max-h-64 overflow-y-auto overscroll-none py-0.5 sm:py-1"
|
||||
className="max-h-44 sm:max-h-56 overflow-y-auto overscroll-none py-0.5"
|
||||
onScroll={handleToolsScroll}
|
||||
style={{
|
||||
maskImage: `linear-gradient(to bottom, ${toolsScrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${toolsScrollPos === "bottom" ? "black" : "transparent"})`,
|
||||
|
|
@ -1074,22 +1090,22 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
.filter((g) => !g.connectorIcon)
|
||||
.map((group) => (
|
||||
<div key={group.label}>
|
||||
<div className="px-2.5 sm:px-3 pt-2 pb-0.5 text-[10px] sm:text-xs text-muted-foreground/80 font-normal select-none">
|
||||
<div className="px-2 sm:px-2.5 pt-1.5 pb-0.5 text-[9px] sm:text-[10px] text-muted-foreground/80 font-normal select-none">
|
||||
{group.label}
|
||||
</div>
|
||||
{group.tools.map((tool) => {
|
||||
const isDisabled = disabledToolsSet.has(tool.name);
|
||||
const ToolIcon = getToolIcon(tool.name);
|
||||
const row = (
|
||||
<div className="flex w-full items-center gap-2 sm:gap-3 px-2.5 sm:px-3 py-1 sm:py-1.5 hover:bg-muted-foreground/10 transition-colors">
|
||||
<ToolIcon className="size-3.5 sm:size-4 shrink-0 text-muted-foreground" />
|
||||
<span className="flex-1 min-w-0 text-xs sm:text-sm font-medium truncate">
|
||||
<div className="flex w-full items-center gap-1.5 sm:gap-2 px-2 sm:px-2.5 py-0.5 sm:py-1 hover:bg-muted-foreground/10 transition-colors">
|
||||
<ToolIcon className="size-3 sm:size-3.5 shrink-0 text-muted-foreground" />
|
||||
<span className="flex-1 min-w-0 text-[11px] sm:text-xs font-medium truncate">
|
||||
{formatToolName(tool.name)}
|
||||
</span>
|
||||
<Switch
|
||||
checked={!isDisabled}
|
||||
onCheckedChange={() => toggleTool(tool.name)}
|
||||
className="shrink-0 scale-[0.6] sm:scale-75"
|
||||
className="shrink-0 scale-50 sm:scale-[0.6]"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
|
@ -1106,7 +1122,7 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
))}
|
||||
{groupedTools.some((g) => g.connectorIcon) && (
|
||||
<div>
|
||||
<div className="px-2.5 sm:px-3 pt-2 pb-0.5 text-[10px] sm:text-xs text-muted-foreground/80 font-normal select-none">
|
||||
<div className="px-2 sm:px-2.5 pt-1.5 pb-0.5 text-[9px] sm:text-[10px] text-muted-foreground/80 font-normal select-none">
|
||||
Connector Actions
|
||||
</div>
|
||||
{groupedTools
|
||||
|
|
@ -1118,26 +1134,26 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
const allDisabled = toolNames.every((n) => disabledToolsSet.has(n));
|
||||
const groupDef = TOOL_GROUPS.find((g) => g.label === group.label);
|
||||
const row = (
|
||||
<div className="flex w-full items-center gap-2 sm:gap-3 px-2.5 sm:px-3 py-1 sm:py-1.5 hover:bg-muted-foreground/10 transition-colors">
|
||||
<div className="flex w-full items-center gap-1.5 sm:gap-2 px-2 sm:px-2.5 py-0.5 sm:py-1 hover:bg-muted-foreground/10 transition-colors">
|
||||
{iconInfo ? (
|
||||
<Image
|
||||
src={iconInfo.src}
|
||||
alt={iconInfo.alt}
|
||||
width={16}
|
||||
height={16}
|
||||
className="size-3.5 sm:size-4 shrink-0 select-none pointer-events-none"
|
||||
width={14}
|
||||
height={14}
|
||||
className="size-3 sm:size-3.5 shrink-0 select-none pointer-events-none"
|
||||
draggable={false}
|
||||
/>
|
||||
) : (
|
||||
<Wrench className="size-3.5 sm:size-4 shrink-0 text-muted-foreground" />
|
||||
<Wrench className="size-3 sm:size-3.5 shrink-0 text-muted-foreground" />
|
||||
)}
|
||||
<span className="flex-1 min-w-0 text-xs sm:text-sm font-medium truncate">
|
||||
<span className="flex-1 min-w-0 text-[11px] sm:text-xs font-medium truncate">
|
||||
{group.label}
|
||||
</span>
|
||||
<Switch
|
||||
checked={!allDisabled}
|
||||
onCheckedChange={() => toggleToolGroup(toolNames)}
|
||||
className="shrink-0 scale-[0.6] sm:scale-75"
|
||||
className="shrink-0 scale-50 sm:scale-[0.6]"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
|
@ -1158,8 +1174,23 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
</div>
|
||||
)}
|
||||
{!filteredTools?.length && (
|
||||
<div className="px-3 py-4 text-center text-xs text-muted-foreground">
|
||||
Loading tools...
|
||||
<div className="px-2 sm:px-2.5 pt-1.5 pb-1">
|
||||
<Skeleton className="h-2 w-12 mb-1.5" />
|
||||
{["dt1", "dt2", "dt3", "dt4"].map((k) => (
|
||||
<div key={k} className="flex items-center gap-1.5 sm:gap-2 py-0.5 sm:py-1">
|
||||
<Skeleton className="size-3 sm:size-3.5 rounded shrink-0" />
|
||||
<Skeleton className="h-2.5 sm:h-3 flex-1" />
|
||||
<Skeleton className="h-3.5 sm:h-4 w-7 sm:w-8 rounded-full shrink-0" />
|
||||
</div>
|
||||
))}
|
||||
<Skeleton className="h-2 w-20 mt-2 mb-1.5" />
|
||||
{["dc1", "dc2", "dc3"].map((k) => (
|
||||
<div key={k} className="flex items-center gap-1.5 sm:gap-2 py-0.5 sm:py-1">
|
||||
<Skeleton className="size-3 sm:size-3.5 rounded shrink-0" />
|
||||
<Skeleton className="h-2.5 sm:h-3 flex-1" />
|
||||
<Skeleton className="h-3.5 sm:h-4 w-7 sm:w-8 rounded-full shrink-0" />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
|
@ -1297,7 +1328,7 @@ const TOOL_GROUPS: ToolGroup[] = [
|
|||
},
|
||||
{
|
||||
label: "Memory",
|
||||
tools: ["save_memory", "recall_memory"],
|
||||
tools: ["update_memory"],
|
||||
},
|
||||
{
|
||||
label: "Gmail",
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import { MarkdownPlugin, remarkMdx } from "@platejs/markdown";
|
||||
import { slateToHtml } from "@slate-serializers/html";
|
||||
import type { AnyPluginConfig, Descendant, Value } from "platejs";
|
||||
import { createPlatePlugin, Key, Plate, usePlateEditor } from "platejs/react";
|
||||
import { createPlatePlugin, Key, Plate, useEditorReadOnly, usePlateEditor } from "platejs/react";
|
||||
import { useEffect, useMemo, useRef } from "react";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
|
|
@ -60,6 +60,24 @@ export interface PlateEditorProps {
|
|||
extraPlugins?: AnyPluginConfig[];
|
||||
}
|
||||
|
||||
function PlateEditorContent({
|
||||
editorVariant,
|
||||
placeholder,
|
||||
}: {
|
||||
editorVariant: PlateEditorProps["editorVariant"];
|
||||
placeholder?: string;
|
||||
}) {
|
||||
const isReadOnly = useEditorReadOnly();
|
||||
|
||||
return (
|
||||
<Editor
|
||||
variant={editorVariant}
|
||||
placeholder={isReadOnly ? undefined : placeholder}
|
||||
className="min-h-full"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export function PlateEditor({
|
||||
markdown,
|
||||
html,
|
||||
|
|
@ -188,7 +206,7 @@ export function PlateEditor({
|
|||
}}
|
||||
>
|
||||
<EditorContainer variant={variant} className={className}>
|
||||
<Editor variant={editorVariant} placeholder={placeholder} />
|
||||
<PlateEditorContent editorVariant={editorVariant} placeholder={placeholder} />
|
||||
</EditorContainer>
|
||||
</Plate>
|
||||
</EditorSaveContext.Provider>
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import type { AnyPluginConfig } from "platejs";
|
||||
import { TrailingBlockPlugin } from "platejs";
|
||||
|
||||
import { AutoformatKit } from "@/components/editor/plugins/autoformat-kit";
|
||||
import { BasicNodesKit } from "@/components/editor/plugins/basic-nodes-kit";
|
||||
|
|
@ -36,6 +37,7 @@ export const fullPreset: AnyPluginConfig[] = [
|
|||
...FloatingToolbarKit,
|
||||
...AutoformatKit,
|
||||
...DndKit,
|
||||
TrailingBlockPlugin,
|
||||
];
|
||||
|
||||
/**
|
||||
|
|
@ -48,8 +50,8 @@ export const minimalPreset: AnyPluginConfig[] = [
|
|||
...ListKit,
|
||||
...CodeBlockKit,
|
||||
...LinkKit,
|
||||
...FloatingToolbarKit,
|
||||
...AutoformatKit,
|
||||
TrailingBlockPlugin,
|
||||
];
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
"use client";
|
||||
|
||||
import { IconBrandGithub } from "@tabler/icons-react";
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { motion, useMotionValue, useSpring } from "motion/react";
|
||||
import * as React from "react";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-digit scrolling wheel
|
||||
|
|
@ -244,28 +246,21 @@ function NavbarGitHubStars({
|
|||
href = "https://github.com/MODSetter/SurfSense",
|
||||
className,
|
||||
}: NavbarGitHubStarsProps) {
|
||||
const [stars, setStars] = React.useState(0);
|
||||
const [isLoading, setIsLoading] = React.useState(true);
|
||||
|
||||
React.useEffect(() => {
|
||||
const abortController = new AbortController();
|
||||
fetch(`https://api.github.com/repos/${username}/${repo}`, {
|
||||
signal: abortController.signal,
|
||||
})
|
||||
.then((res) => res.json())
|
||||
.then((data) => {
|
||||
if (data && typeof data.stargazers_count === "number") {
|
||||
setStars(data.stargazers_count);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
if (err instanceof Error && err.name !== "AbortError") {
|
||||
console.error("Error fetching stars:", err);
|
||||
}
|
||||
})
|
||||
.finally(() => setIsLoading(false));
|
||||
return () => abortController.abort();
|
||||
}, [username, repo]);
|
||||
const { data: stars = 0, isLoading } = useQuery({
|
||||
queryKey: cacheKeys.github.repoStars(username, repo),
|
||||
queryFn: async ({ signal }) => {
|
||||
const res = await fetch(
|
||||
`https://api.github.com/repos/${username}/${repo}`,
|
||||
{ signal },
|
||||
);
|
||||
const data = await res.json();
|
||||
if (data && typeof data.stargazers_count === "number") {
|
||||
return data.stargazers_count as number;
|
||||
}
|
||||
return 0;
|
||||
},
|
||||
staleTime: 5 * 60 * 1000,
|
||||
});
|
||||
|
||||
return (
|
||||
<a
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import { ChevronDown, Download, Monitor } from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import Link from "next/link";
|
||||
import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import React, { memo, useCallback, useEffect, useRef, useState } from "react";
|
||||
import Balancer from "react-wrap-balancer";
|
||||
import {
|
||||
DropdownMenu,
|
||||
|
|
@ -12,6 +12,11 @@ import {
|
|||
} from "@/components/ui/dropdown-menu";
|
||||
import { ExpandedMediaOverlay, useExpandedMedia } from "@/components/ui/expanded-gif-overlay";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import {
|
||||
GITHUB_RELEASES_URL,
|
||||
getAssetLabel,
|
||||
usePrimaryDownload,
|
||||
} from "@/lib/desktop-download-utils";
|
||||
import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config";
|
||||
import { trackLoginAttempt } from "@/lib/posthog/events";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
|
@ -200,107 +205,8 @@ function GetStartedButton() {
|
|||
);
|
||||
}
|
||||
|
||||
type OSInfo = {
|
||||
os: "macOS" | "Windows" | "Linux";
|
||||
arch: "arm64" | "x64";
|
||||
};
|
||||
|
||||
function useUserOS(): OSInfo {
|
||||
const [info, setInfo] = useState<OSInfo>({ os: "macOS", arch: "arm64" });
|
||||
useEffect(() => {
|
||||
const ua = navigator.userAgent;
|
||||
let os: OSInfo["os"] = "macOS";
|
||||
let arch: OSInfo["arch"] = "x64";
|
||||
|
||||
if (/Windows/i.test(ua)) {
|
||||
os = "Windows";
|
||||
arch = "x64";
|
||||
} else if (/Linux/i.test(ua)) {
|
||||
os = "Linux";
|
||||
arch = "x64";
|
||||
} else {
|
||||
os = "macOS";
|
||||
arch = /Mac/.test(ua) && !/Intel/.test(ua) ? "arm64" : "arm64";
|
||||
}
|
||||
|
||||
const uaData = (navigator as Navigator & { userAgentData?: { architecture?: string } })
|
||||
.userAgentData;
|
||||
if (uaData?.architecture === "arm") arch = "arm64";
|
||||
else if (uaData?.architecture === "x86") arch = "x64";
|
||||
|
||||
setInfo({ os, arch });
|
||||
}, []);
|
||||
return info;
|
||||
}
|
||||
|
||||
interface ReleaseAsset {
|
||||
name: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
function useLatestRelease() {
|
||||
const [assets, setAssets] = useState<ReleaseAsset[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
fetch("https://api.github.com/repos/MODSetter/SurfSense/releases/latest", {
|
||||
signal: controller.signal,
|
||||
})
|
||||
.then((r) => r.json())
|
||||
.then((data) => {
|
||||
if (data?.assets) {
|
||||
setAssets(
|
||||
data.assets
|
||||
.filter((a: { name: string }) => /\.(exe|dmg|AppImage|deb)$/.test(a.name))
|
||||
.map((a: { name: string; browser_download_url: string }) => ({
|
||||
name: a.name,
|
||||
url: a.browser_download_url,
|
||||
}))
|
||||
);
|
||||
}
|
||||
})
|
||||
.catch(() => {});
|
||||
return () => controller.abort();
|
||||
}, []);
|
||||
|
||||
return assets;
|
||||
}
|
||||
|
||||
const ASSET_LABELS: Record<string, string> = {
|
||||
".exe": "Windows (exe)",
|
||||
"-arm64.dmg": "macOS Apple Silicon (dmg)",
|
||||
"-x64.dmg": "macOS Intel (dmg)",
|
||||
"-arm64.zip": "macOS Apple Silicon (zip)",
|
||||
"-x64.zip": "macOS Intel (zip)",
|
||||
".AppImage": "Linux (AppImage)",
|
||||
".deb": "Linux (deb)",
|
||||
};
|
||||
|
||||
function getAssetLabel(name: string): string {
|
||||
for (const [suffix, label] of Object.entries(ASSET_LABELS)) {
|
||||
if (name.endsWith(suffix)) return label;
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
function DownloadButton() {
|
||||
const { os, arch } = useUserOS();
|
||||
const assets = useLatestRelease();
|
||||
|
||||
const { primary, alternatives } = useMemo(() => {
|
||||
if (assets.length === 0) return { primary: null, alternatives: [] };
|
||||
|
||||
const matchers: Record<string, (n: string) => boolean> = {
|
||||
Windows: (n) => n.endsWith(".exe"),
|
||||
macOS: (n) => n.endsWith(`-${arch}.dmg`),
|
||||
Linux: (n) => n.endsWith(".AppImage"),
|
||||
};
|
||||
|
||||
const match = matchers[os];
|
||||
const primary = assets.find((a) => match(a.name)) ?? null;
|
||||
const alternatives = assets.filter((a) => a !== primary);
|
||||
return { primary, alternatives };
|
||||
}, [assets, os, arch]);
|
||||
const { os, primary, alternatives } = usePrimaryDownload();
|
||||
|
||||
const fallbackUrl = GITHUB_RELEASES_URL;
|
||||
|
||||
|
|
@ -504,5 +410,3 @@ const TabVideo = memo(function TabVideo({ src }: { src: string }) {
|
|||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
const GITHUB_RELEASES_URL = "https://github.com/MODSetter/SurfSense/releases/latest";
|
||||
|
|
|
|||
|
|
@ -9,11 +9,28 @@ import { Logo } from "@/components/Logo";
|
|||
import { ThemeTogglerComponent } from "@/components/theme/theme-toggle";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface NavItem {
|
||||
name: string;
|
||||
link: string;
|
||||
}
|
||||
|
||||
interface NavbarProps {
|
||||
/** Override the scrolled-state background classes (desktop & mobile). */
|
||||
scrolledBgClassName?: string;
|
||||
}
|
||||
|
||||
interface DesktopNavProps {
|
||||
navItems: NavItem[];
|
||||
isScrolled: boolean;
|
||||
scrolledBgClassName?: string;
|
||||
}
|
||||
|
||||
interface MobileNavProps {
|
||||
navItems: NavItem[];
|
||||
isScrolled: boolean;
|
||||
scrolledBgClassName?: string;
|
||||
}
|
||||
|
||||
export const Navbar = ({ scrolledBgClassName }: NavbarProps = {}) => {
|
||||
const [isScrolled, setIsScrolled] = useState(false);
|
||||
|
||||
|
|
@ -52,7 +69,7 @@ export const Navbar = ({ scrolledBgClassName }: NavbarProps = {}) => {
|
|||
);
|
||||
};
|
||||
|
||||
const DesktopNav = ({ navItems, isScrolled, scrolledBgClassName }: any) => {
|
||||
const DesktopNav = ({ navItems, isScrolled, scrolledBgClassName }: DesktopNavProps) => {
|
||||
const [hovered, setHovered] = useState<number | null>(null);
|
||||
return (
|
||||
<motion.div
|
||||
|
|
@ -75,7 +92,7 @@ const DesktopNav = ({ navItems, isScrolled, scrolledBgClassName }: any) => {
|
|||
<span className="dark:text-white/90 text-gray-800 text-lg font-bold">SurfSense</span>
|
||||
</Link>
|
||||
<div className="hidden flex-1 flex-row items-center justify-center space-x-2 text-sm font-medium text-zinc-600 transition duration-200 hover:text-zinc-800 lg:flex lg:space-x-2">
|
||||
{navItems.map((navItem: any, idx: number) => (
|
||||
{navItems.map((navItem: NavItem, idx: number) => (
|
||||
<Link
|
||||
onMouseEnter={() => setHovered(idx)}
|
||||
onMouseLeave={() => setHovered(null)}
|
||||
|
|
@ -118,7 +135,7 @@ const DesktopNav = ({ navItems, isScrolled, scrolledBgClassName }: any) => {
|
|||
);
|
||||
};
|
||||
|
||||
const MobileNav = ({ navItems, isScrolled, scrolledBgClassName }: any) => {
|
||||
const MobileNav = ({ navItems, isScrolled, scrolledBgClassName }: MobileNavProps) => {
|
||||
const [open, setOpen] = useState(false);
|
||||
const navRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
|
|
@ -183,7 +200,7 @@ const MobileNav = ({ navItems, isScrolled, scrolledBgClassName }: any) => {
|
|||
transition={{ duration: 0.2, ease: "easeOut" }}
|
||||
className="absolute inset-x-0 top-full mt-1 z-20 flex w-full flex-col items-start justify-start gap-4 rounded-xl bg-white/90 backdrop-blur-xl border border-white/20 shadow-2xl px-4 py-6 dark:bg-neutral-950/90 dark:border-neutral-800/50"
|
||||
>
|
||||
{navItems.map((navItem: any, idx: number) => (
|
||||
{navItems.map((navItem: NavItem, idx: number) => (
|
||||
<Link
|
||||
key={`link=${idx}`}
|
||||
href={navItem.link}
|
||||
|
|
|
|||
|
|
@ -91,12 +91,12 @@ export function SidebarSlideOutPanel({
|
|||
|
||||
{/* Panel extending from sidebar's right edge, flush with the wrapper border */}
|
||||
<motion.div
|
||||
style={{ width, left: "100%", top: -1, bottom: -1 }}
|
||||
initial={{ x: -width }}
|
||||
animate={{ x: 0 }}
|
||||
exit={{ x: -width }}
|
||||
initial={{ width: 0 }}
|
||||
animate={{ width }}
|
||||
exit={{ width: 0 }}
|
||||
transition={{ type: "tween", duration: 0.2, ease: [0.4, 0, 0.2, 1] }}
|
||||
className="absolute z-20 overflow-hidden"
|
||||
style={{ left: "100%", top: -1, bottom: -1 }}
|
||||
>
|
||||
<div
|
||||
style={{ width }}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import {
|
||||
Check,
|
||||
ChevronUp,
|
||||
Download,
|
||||
ExternalLink,
|
||||
Info,
|
||||
Languages,
|
||||
|
|
@ -29,6 +30,8 @@ import {
|
|||
} from "@/components/ui/dropdown-menu";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { useLocaleContext } from "@/contexts/LocaleContext";
|
||||
import { usePlatform } from "@/hooks/use-platform";
|
||||
import { GITHUB_RELEASES_URL, usePrimaryDownload } from "@/lib/desktop-download-utils";
|
||||
import { APP_VERSION } from "@/lib/env-config";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { User } from "../../types/layout.types";
|
||||
|
|
@ -149,10 +152,13 @@ export function SidebarUserProfile({
|
|||
}: SidebarUserProfileProps) {
|
||||
const t = useTranslations("sidebar");
|
||||
const { locale, setLocale } = useLocaleContext();
|
||||
const { isDesktop } = usePlatform();
|
||||
const { os, primary } = usePrimaryDownload();
|
||||
const [isLoggingOut, setIsLoggingOut] = useState(false);
|
||||
const bgColor = stringToColor(user.email);
|
||||
const initials = getInitials(user.email);
|
||||
const displayName = user.name || user.email.split("@")[0];
|
||||
const downloadUrl = primary?.url ?? GITHUB_RELEASES_URL;
|
||||
|
||||
const handleLanguageChange = (newLocale: "en" | "es" | "pt" | "hi" | "zh") => {
|
||||
setLocale(newLocale);
|
||||
|
|
@ -294,6 +300,15 @@ export function SidebarUserProfile({
|
|||
</DropdownMenuPortal>
|
||||
</DropdownMenuSub>
|
||||
|
||||
{!isDesktop && (
|
||||
<DropdownMenuItem asChild className="font-medium">
|
||||
<a href={downloadUrl} target="_blank" rel="noopener noreferrer">
|
||||
<Download className="h-4 w-4" strokeWidth={2.5} />
|
||||
{t("download_for_os", { os })}
|
||||
</a>
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
|
||||
<DropdownMenuSeparator className="dark:bg-neutral-700" />
|
||||
|
||||
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
|
||||
|
|
@ -439,6 +454,15 @@ export function SidebarUserProfile({
|
|||
</DropdownMenuPortal>
|
||||
</DropdownMenuSub>
|
||||
|
||||
{!isDesktop && (
|
||||
<DropdownMenuItem asChild className="font-medium">
|
||||
<a href={downloadUrl} target="_blank" rel="noopener noreferrer">
|
||||
<Download className="h-4 w-4" strokeWidth={2.5} />
|
||||
{t("download_for_os", { os })}
|
||||
</a>
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
|
||||
<DropdownMenuSeparator className="dark:bg-neutral-700" />
|
||||
|
||||
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
"use client";
|
||||
|
||||
import { Pricing } from "@/components/pricing";
|
||||
|
||||
const demoPlans = [
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ export function PublicChatSnapshotsManager({
|
|||
});
|
||||
} catch (error) {
|
||||
console.error("Failed to delete snapshot:", error);
|
||||
toast.error("Failed to delete snapshot");
|
||||
} finally {
|
||||
setDeletingId(undefined);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,14 +18,39 @@ import {
|
|||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||
|
||||
function ReportPanelSkeleton() {
|
||||
return (
|
||||
<div className="space-y-6 p-6">
|
||||
<div className="h-6 w-3/4 rounded-md bg-muted/60 animate-pulse" />
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse" />
|
||||
<div className="h-3 w-[95%] rounded-md bg-muted/60 animate-pulse [animation-delay:100ms]" />
|
||||
<div className="h-3 w-[88%] rounded-md bg-muted/60 animate-pulse [animation-delay:200ms]" />
|
||||
<div className="h-3 w-[60%] rounded-md bg-muted/60 animate-pulse [animation-delay:300ms]" />
|
||||
</div>
|
||||
<div className="h-5 w-2/5 rounded-md bg-muted/60 animate-pulse [animation-delay:400ms]" />
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:500ms]" />
|
||||
<div className="h-3 w-[92%] rounded-md bg-muted/60 animate-pulse [animation-delay:600ms]" />
|
||||
<div className="h-3 w-[97%] rounded-md bg-muted/60 animate-pulse [animation-delay:700ms]" />
|
||||
</div>
|
||||
<div className="h-5 w-1/3 rounded-md bg-muted/60 animate-pulse [animation-delay:800ms]" />
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-[90%] rounded-md bg-muted/60 animate-pulse [animation-delay:900ms]" />
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:1000ms]" />
|
||||
<div className="h-3 w-[75%] rounded-md bg-muted/60 animate-pulse [animation-delay:1100ms]" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const PlateEditor = dynamic(
|
||||
() => import("@/components/editor/plate-editor").then((m) => ({ default: m.PlateEditor })),
|
||||
{ ssr: false, loading: () => <Skeleton className="h-64 w-full" /> }
|
||||
{ ssr: false, loading: () => <ReportPanelSkeleton /> }
|
||||
);
|
||||
|
||||
/**
|
||||
|
|
@ -59,46 +84,6 @@ const ReportContentResponseSchema = z.object({
|
|||
type ReportContentResponse = z.infer<typeof ReportContentResponseSchema>;
|
||||
type VersionInfo = z.infer<typeof VersionInfoSchema>;
|
||||
|
||||
/**
|
||||
* Shimmer loading skeleton for report panel
|
||||
*/
|
||||
function ReportPanelSkeleton() {
|
||||
return (
|
||||
<div className="space-y-6 p-6">
|
||||
{/* Title skeleton */}
|
||||
<div className="h-6 w-3/4 rounded-md bg-muted/60 animate-pulse" />
|
||||
|
||||
{/* Paragraph 1 */}
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse" />
|
||||
<div className="h-3 w-[95%] rounded-md bg-muted/60 animate-pulse [animation-delay:100ms]" />
|
||||
<div className="h-3 w-[88%] rounded-md bg-muted/60 animate-pulse [animation-delay:200ms]" />
|
||||
<div className="h-3 w-[60%] rounded-md bg-muted/60 animate-pulse [animation-delay:300ms]" />
|
||||
</div>
|
||||
|
||||
{/* Heading */}
|
||||
<div className="h-5 w-2/5 rounded-md bg-muted/60 animate-pulse [animation-delay:400ms]" />
|
||||
|
||||
{/* Paragraph 2 */}
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:500ms]" />
|
||||
<div className="h-3 w-[92%] rounded-md bg-muted/60 animate-pulse [animation-delay:600ms]" />
|
||||
<div className="h-3 w-[97%] rounded-md bg-muted/60 animate-pulse [animation-delay:700ms]" />
|
||||
</div>
|
||||
|
||||
{/* Heading */}
|
||||
<div className="h-5 w-1/3 rounded-md bg-muted/60 animate-pulse [animation-delay:800ms]" />
|
||||
|
||||
{/* Paragraph 3 */}
|
||||
<div className="space-y-2.5">
|
||||
<div className="h-3 w-[90%] rounded-md bg-muted/60 animate-pulse [animation-delay:900ms]" />
|
||||
<div className="h-3 w-full rounded-md bg-muted/60 animate-pulse [animation-delay:1000ms]" />
|
||||
<div className="h-3 w-[75%] rounded-md bg-muted/60 animate-pulse [animation-delay:1100ms]" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Inner content component used by desktop panel, mobile drawer, and the layout right panel
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -190,7 +190,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
? "model"
|
||||
: "models"}
|
||||
</span>{" "}
|
||||
available from your administrator. Use the model selector to view and select them.
|
||||
available from your administrator.
|
||||
</p>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import {
|
|||
FileText,
|
||||
ImageIcon,
|
||||
RefreshCw,
|
||||
Shuffle,
|
||||
} from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
|
@ -44,7 +43,6 @@ import {
|
|||
} from "@/components/ui/select";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { getProviderIcon } from "@/lib/provider-icons";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const ROLE_DESCRIPTIONS = {
|
||||
|
|
@ -79,8 +77,8 @@ const ROLE_DESCRIPTIONS = {
|
|||
icon: Eye,
|
||||
title: "Vision LLM",
|
||||
description: "Vision-capable model for screenshot analysis and context extraction",
|
||||
color: "text-amber-600 dark:text-amber-400",
|
||||
bgColor: "bg-amber-500/10",
|
||||
color: "text-muted-foreground",
|
||||
bgColor: "bg-muted",
|
||||
prefKey: "vision_llm_config_id" as const,
|
||||
configType: "vision" as const,
|
||||
},
|
||||
|
|
@ -205,11 +203,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
),
|
||||
];
|
||||
|
||||
const isAssignmentComplete =
|
||||
allLLMConfigs.some((c) => c.id === assignments.agent_llm_id) &&
|
||||
allLLMConfigs.some((c) => c.id === assignments.document_summary_llm_id) &&
|
||||
allImageConfigs.some((c) => c.id === assignments.image_generation_config_id);
|
||||
|
||||
const isLoading =
|
||||
configsLoading ||
|
||||
preferencesLoading ||
|
||||
|
|
@ -231,7 +224,7 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
return (
|
||||
<div className="space-y-5 md:space-y-6">
|
||||
{/* Header actions */}
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center justify-start">
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
|
|
@ -239,15 +232,9 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
disabled={isLoading}
|
||||
className="gap-2"
|
||||
>
|
||||
<RefreshCw className="h-3.5 w-3.5" />
|
||||
<RefreshCw className={cn("h-3.5 w-3.5", isLoading && "animate-spin")} />
|
||||
Refresh
|
||||
</Button>
|
||||
{isAssignmentComplete && !isLoading && !hasError && (
|
||||
<Badge variant="outline" className="text-xs gap-1.5 text-muted-foreground">
|
||||
<CircleCheck className="h-3 w-3" />
|
||||
All roles assigned
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Error Alert */}
|
||||
|
|
@ -343,8 +330,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
|
||||
const assignedConfig = roleAllConfigs.find((config) => config.id === currentAssignment);
|
||||
const isAssigned = !!assignedConfig;
|
||||
const isAutoMode =
|
||||
assignedConfig && "is_auto_mode" in assignedConfig && assignedConfig.is_auto_mode;
|
||||
|
||||
return (
|
||||
<div key={key}>
|
||||
|
|
@ -389,7 +374,7 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
<SelectTrigger className="w-full h-9 md:h-10 text-xs md:text-sm">
|
||||
<SelectValue placeholder="Select a configuration" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="max-w-[calc(100vw-2rem)]">
|
||||
<SelectContent className="max-w-[calc(100vw-2rem)] select-none">
|
||||
<SelectItem
|
||||
value="unassigned"
|
||||
className="text-xs md:text-sm py-1.5 md:py-2"
|
||||
|
|
@ -412,21 +397,9 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
className="text-xs md:text-sm py-1.5 md:py-2"
|
||||
>
|
||||
<div className="flex items-center gap-1 md:gap-1.5 flex-wrap min-w-0">
|
||||
{isAuto ? (
|
||||
<Shuffle className="size-3 md:size-3.5 shrink-0 text-muted-foreground" />
|
||||
) : (
|
||||
getProviderIcon(config.provider, {
|
||||
className: "size-3 md:size-3.5 shrink-0",
|
||||
})
|
||||
)}
|
||||
<span className="truncate text-xs md:text-sm">
|
||||
{config.name}
|
||||
</span>
|
||||
{!isAuto && (
|
||||
<span className="text-muted-foreground text-[10px] md:text-[11px] truncate">
|
||||
({config.model_name})
|
||||
</span>
|
||||
)}
|
||||
{isAuto && (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
|
|
@ -455,15 +428,9 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
className="text-xs md:text-sm py-1.5 md:py-2"
|
||||
>
|
||||
<div className="flex items-center gap-1 md:gap-1.5 flex-wrap min-w-0">
|
||||
{getProviderIcon(config.provider, {
|
||||
className: "size-3 md:size-3.5 shrink-0",
|
||||
})}
|
||||
<span className="truncate text-xs md:text-sm">
|
||||
{config.name}
|
||||
</span>
|
||||
<span className="text-muted-foreground text-[10px] md:text-[11px] truncate">
|
||||
({config.model_name})
|
||||
</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
|
|
@ -472,63 +439,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
{/* Assigned Config Summary */}
|
||||
{assignedConfig && (
|
||||
<div
|
||||
className={cn(
|
||||
"rounded-lg p-3 border",
|
||||
isAutoMode
|
||||
? "bg-violet-50 dark:bg-violet-900/10 border-violet-200/50 dark:border-violet-800/30"
|
||||
: "bg-muted/40 border-border/50"
|
||||
)}
|
||||
>
|
||||
{isAutoMode ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<Shuffle
|
||||
className={cn(
|
||||
"w-3.5 h-3.5 shrink-0 text-violet-600 dark:text-violet-400"
|
||||
)}
|
||||
/>
|
||||
<div className="min-w-0">
|
||||
<p className="text-xs font-medium text-violet-700 dark:text-violet-300">
|
||||
Auto Mode
|
||||
</p>
|
||||
<p className="text-[10px] text-violet-600/70 dark:text-violet-400/70 mt-0.5">
|
||||
Routes across all available providers
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-start gap-2">
|
||||
<IconComponent className="w-3.5 h-3.5 shrink-0 mt-0.5 text-muted-foreground" />
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex items-center gap-1.5 flex-wrap">
|
||||
<span className="text-xs font-medium">{assignedConfig.name}</span>
|
||||
{"is_global" in assignedConfig && assignedConfig.is_global && (
|
||||
<Badge variant="secondary" className="text-[9px] px-1.5 py-0">
|
||||
🌐 Global
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-1.5 mt-1">
|
||||
{getProviderIcon(assignedConfig.provider, {
|
||||
className: "size-3 shrink-0",
|
||||
})}
|
||||
<code className="text-[10px] text-muted-foreground font-mono truncate">
|
||||
{assignedConfig.model_name}
|
||||
</code>
|
||||
</div>
|
||||
{assignedConfig.api_base && (
|
||||
<p className="text-[10px] text-muted-foreground/60 mt-1 truncate">
|
||||
{assignedConfig.api_base}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
<span className="font-medium">
|
||||
{globalConfigs.length} global {globalConfigs.length === 1 ? "model" : "models"}
|
||||
</span>{" "}
|
||||
available from your administrator. Use the model selector to view and select them.
|
||||
available from your administrator.
|
||||
</p>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
|
|
|||
|
|
@ -113,11 +113,11 @@ export function MorePagesContent() {
|
|||
{isLoading ? (
|
||||
<Card>
|
||||
<CardContent className="flex items-center gap-3 p-3">
|
||||
<Skeleton className="h-8 w-8 rounded-full bg-muted" />
|
||||
<Skeleton className="h-8 w-8 rounded-full" />
|
||||
<div className="flex-1 space-y-2">
|
||||
<Skeleton className="h-4 w-3/4 bg-muted" />
|
||||
<Skeleton className="h-4 w-3/4" />
|
||||
</div>
|
||||
<Skeleton className="h-8 w-16 bg-muted" />
|
||||
<Skeleton className="h-8 w-16" />
|
||||
</CardContent>
|
||||
</Card>
|
||||
) : (
|
||||
|
|
|
|||
|
|
@ -1,24 +0,0 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom } from "jotai";
|
||||
import { morePagesDialogAtom } from "@/atoms/settings/settings-dialog.atoms";
|
||||
import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog";
|
||||
import { MorePagesContent } from "./more-pages-content";
|
||||
|
||||
export function MorePagesDialog() {
|
||||
const [open, setOpen] = useAtom(morePagesDialogAtom);
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={setOpen}>
|
||||
<DialogContent
|
||||
className="select-none max-w-md w-[95vw] max-h-[90vh] flex flex-col p-0 gap-0 overflow-hidden"
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
>
|
||||
<DialogTitle className="sr-only">Get More Pages</DialogTitle>
|
||||
<div className="flex-1 overflow-y-auto p-6">
|
||||
<MorePagesContent />
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,7 +1,17 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom } from "jotai";
|
||||
import { Bot, Brain, Eye, FileText, Globe, ImageIcon, MessageSquare, Shield } from "lucide-react";
|
||||
import {
|
||||
BookText,
|
||||
Bot,
|
||||
Brain,
|
||||
CircleUser,
|
||||
Earth,
|
||||
Eye,
|
||||
ImageIcon,
|
||||
ListChecks,
|
||||
UserKey,
|
||||
} from "lucide-react";
|
||||
import dynamic from "next/dynamic";
|
||||
import { useTranslations } from "next-intl";
|
||||
import type React from "react";
|
||||
|
|
@ -59,6 +69,13 @@ const PublicChatSnapshotsManager = dynamic(
|
|||
})),
|
||||
{ ssr: false }
|
||||
);
|
||||
const TeamMemoryManager = dynamic(
|
||||
() =>
|
||||
import("@/components/settings/team-memory-manager").then((m) => ({
|
||||
default: m.TeamMemoryManager,
|
||||
})),
|
||||
{ ssr: false }
|
||||
);
|
||||
|
||||
interface SearchSpaceSettingsDialogProps {
|
||||
searchSpaceId: number;
|
||||
|
|
@ -69,9 +86,9 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
|||
const [state, setState] = useAtom(searchSpaceSettingsDialogAtom);
|
||||
|
||||
const navItems = [
|
||||
{ value: "general", label: t("nav_general"), icon: <FileText className="h-4 w-4" /> },
|
||||
{ value: "general", label: t("nav_general"), icon: <CircleUser className="h-4 w-4" /> },
|
||||
{ value: "roles", label: t("nav_role_assignments"), icon: <ListChecks className="h-4 w-4" /> },
|
||||
{ value: "models", label: t("nav_agent_configs"), icon: <Bot className="h-4 w-4" /> },
|
||||
{ value: "roles", label: t("nav_role_assignments"), icon: <Brain className="h-4 w-4" /> },
|
||||
{
|
||||
value: "image-models",
|
||||
label: t("nav_image_models"),
|
||||
|
|
@ -82,13 +99,18 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
|||
label: t("nav_vision_models"),
|
||||
icon: <Eye className="h-4 w-4" />,
|
||||
},
|
||||
{ value: "team-roles", label: t("nav_team_roles"), icon: <Shield className="h-4 w-4" /> },
|
||||
{ value: "team-roles", label: t("nav_team_roles"), icon: <UserKey className="h-4 w-4" /> },
|
||||
{
|
||||
value: "prompts",
|
||||
label: t("nav_system_instructions"),
|
||||
icon: <MessageSquare className="h-4 w-4" />,
|
||||
icon: <BookText className="h-4 w-4" />,
|
||||
},
|
||||
{ value: "public-links", label: t("nav_public_links"), icon: <Globe className="h-4 w-4" /> },
|
||||
{
|
||||
value: "team-memory",
|
||||
label: "Team Memory",
|
||||
icon: <Brain className="h-4 w-4" />,
|
||||
},
|
||||
{ value: "public-links", label: t("nav_public_links"), icon: <Earth className="h-4 w-4" /> },
|
||||
];
|
||||
|
||||
const content: Record<string, React.ReactNode> = {
|
||||
|
|
@ -99,6 +121,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
|||
"vision-models": <VisionModelManager searchSpaceId={searchSpaceId} />,
|
||||
"team-roles": <RolesManager searchSpaceId={searchSpaceId} />,
|
||||
prompts: <PromptConfigManager searchSpaceId={searchSpaceId} />,
|
||||
"team-memory": <TeamMemoryManager searchSpaceId={searchSpaceId} />,
|
||||
"public-links": <PublicChatSnapshotsManager searchSpaceId={searchSpaceId} />,
|
||||
};
|
||||
|
||||
|
|
|
|||
297
surfsense_web/components/settings/team-memory-manager.tsx
Normal file
297
surfsense_web/components/settings/team-memory-manager.tsx
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pen } from "lucide-react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { z } from "zod";
|
||||
import { updateSearchSpaceMutationAtom } from "@/atoms/search-spaces/search-space-mutation.atoms";
|
||||
import { PlateEditor } from "@/components/editor/plate-editor";
|
||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
|
||||
const MEMORY_HARD_LIMIT = 25_000;
|
||||
|
||||
const SearchSpaceSchema = z
|
||||
.object({
|
||||
shared_memory_md: z.string().optional().default(""),
|
||||
})
|
||||
.passthrough();
|
||||
|
||||
interface TeamMemoryManagerProps {
|
||||
searchSpaceId: number;
|
||||
}
|
||||
|
||||
export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) {
|
||||
const queryClient = useQueryClient();
|
||||
const { data: searchSpace, isLoading: loading } = useQuery({
|
||||
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()),
|
||||
queryFn: () => searchSpacesApiService.getSearchSpace({ id: searchSpaceId }),
|
||||
enabled: !!searchSpaceId,
|
||||
});
|
||||
|
||||
const { mutateAsync: updateSearchSpace } = useAtomValue(updateSearchSpaceMutationAtom);
|
||||
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [editQuery, setEditQuery] = useState("");
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [showInput, setShowInput] = useState(false);
|
||||
const textareaRef = useRef<HTMLInputElement>(null);
|
||||
const inputContainerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const memory = searchSpace?.shared_memory_md || "";
|
||||
|
||||
useEffect(() => {
|
||||
if (!showInput) return;
|
||||
|
||||
const handlePointerDownOutside = (event: MouseEvent | TouchEvent) => {
|
||||
const target = event.target;
|
||||
if (!(target instanceof Node)) return;
|
||||
if (inputContainerRef.current?.contains(target)) return;
|
||||
|
||||
setShowInput(false);
|
||||
};
|
||||
|
||||
document.addEventListener("mousedown", handlePointerDownOutside);
|
||||
document.addEventListener("touchstart", handlePointerDownOutside, { passive: true });
|
||||
|
||||
return () => {
|
||||
document.removeEventListener("mousedown", handlePointerDownOutside);
|
||||
document.removeEventListener("touchstart", handlePointerDownOutside);
|
||||
};
|
||||
}, [showInput]);
|
||||
|
||||
const handleClear = async () => {
|
||||
try {
|
||||
setSaving(true);
|
||||
await updateSearchSpace({
|
||||
id: searchSpaceId,
|
||||
data: { shared_memory_md: "" },
|
||||
});
|
||||
toast.success("Team memory cleared");
|
||||
} catch {
|
||||
toast.error("Failed to clear team memory");
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleEdit = async () => {
|
||||
const query = editQuery.trim();
|
||||
if (!query) return;
|
||||
|
||||
try {
|
||||
setEditing(true);
|
||||
await baseApiService.post(
|
||||
`/api/v1/searchspaces/${searchSpaceId}/memory/edit`,
|
||||
SearchSpaceSchema,
|
||||
{ body: { query } }
|
||||
);
|
||||
setEditQuery("");
|
||||
setShowInput(false);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()),
|
||||
});
|
||||
toast.success("Team memory updated");
|
||||
} catch {
|
||||
toast.error("Failed to edit team memory");
|
||||
} finally {
|
||||
setEditing(false);
|
||||
}
|
||||
};
|
||||
|
||||
const openInput = () => {
|
||||
setShowInput(true);
|
||||
requestAnimationFrame(() => textareaRef.current?.focus());
|
||||
};
|
||||
|
||||
const handleDownload = () => {
|
||||
if (!memory) return;
|
||||
try {
|
||||
const blob = new Blob([memory], { type: "text/markdown;charset=utf-8" });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = "team-memory.md";
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
} catch {
|
||||
toast.error("Failed to download team memory");
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopyMarkdown = async () => {
|
||||
if (!memory) return;
|
||||
try {
|
||||
await navigator.clipboard.writeText(memory);
|
||||
toast.success("Copied to clipboard");
|
||||
} catch {
|
||||
toast.error("Failed to copy team memory");
|
||||
}
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
handleEdit();
|
||||
}
|
||||
};
|
||||
|
||||
const displayMemory = memory.replace(/\(\d{4}-\d{2}-\d{2}\)\s*\[(fact|pref|instr)\]\s*/g, "");
|
||||
const charCount = memory.length;
|
||||
|
||||
const getCounterColor = () => {
|
||||
if (charCount > MEMORY_HARD_LIMIT) return "text-red-500";
|
||||
if (charCount > 15_000) return "text-orange-500";
|
||||
if (charCount > 10_000) return "text-yellow-500";
|
||||
return "text-muted-foreground";
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Spinner size="md" className="text-muted-foreground" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!memory) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-16 text-center">
|
||||
<h3 className="text-base font-medium text-foreground">
|
||||
What does SurfSense remember about your team?
|
||||
</h3>
|
||||
<p className="mt-2 max-w-sm text-sm text-muted-foreground">
|
||||
Nothing yet. SurfSense picks up on team decisions and conventions as your team chats.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<Alert className="bg-muted/50 py-3 md:py-4">
|
||||
<Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
|
||||
<AlertDescription className="text-xs md:text-sm">
|
||||
<p>
|
||||
SurfSense uses this shared memory to provide team-wide context across all conversations
|
||||
in this search space.
|
||||
</p>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<div className="relative h-[380px] rounded-lg border bg-background">
|
||||
<div className="h-full overflow-y-auto scrollbar-thin">
|
||||
<PlateEditor
|
||||
markdown={displayMemory}
|
||||
readOnly
|
||||
preset="readonly"
|
||||
variant="default"
|
||||
editorVariant="none"
|
||||
className="px-5 py-4 text-sm min-h-full"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{showInput ? (
|
||||
<div className="absolute bottom-3 inset-x-3 z-10">
|
||||
<div
|
||||
ref={inputContainerRef}
|
||||
className="relative flex h-[54px] items-center gap-2 rounded-[9999px] border bg-muted/60 backdrop-blur-sm pl-4 pr-1 shadow-sm"
|
||||
>
|
||||
<input
|
||||
ref={textareaRef}
|
||||
type="text"
|
||||
value={editQuery}
|
||||
onChange={(e) => setEditQuery(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="Tell SurfSense what to remember or forget about your team"
|
||||
disabled={editing}
|
||||
className="flex-1 bg-transparent text-sm outline-none placeholder:text-muted-foreground/70"
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
onClick={handleEdit}
|
||||
disabled={editing || !editQuery.trim()}
|
||||
className={`h-11 w-11 shrink-0 rounded-full ${
|
||||
editing ? "" : "bg-muted-foreground/15 hover:bg-muted-foreground/20"
|
||||
}`}
|
||||
>
|
||||
{editing ? (
|
||||
<Spinner size="sm" />
|
||||
) : (
|
||||
<ArrowUp className="!h-5 !w-5 text-foreground" strokeWidth={2.25} />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<Button
|
||||
type="button"
|
||||
size="icon"
|
||||
variant="secondary"
|
||||
onClick={openInput}
|
||||
className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm"
|
||||
>
|
||||
<Pen className="!h-5 !w-5" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className={`text-xs shrink-0 ${getCounterColor()}`}>
|
||||
{charCount.toLocaleString()} / {MEMORY_HARD_LIMIT.toLocaleString()}
|
||||
<span className="hidden sm:inline"> characters</span>
|
||||
<span className="sm:hidden"> chars</span>
|
||||
{charCount > 15_000 && charCount <= MEMORY_HARD_LIMIT && " - Approaching limit"}
|
||||
{charCount > MEMORY_HARD_LIMIT && " - Exceeds limit"}
|
||||
</span>
|
||||
<div className="flex items-center gap-1.5 sm:gap-2">
|
||||
<Button
|
||||
type="button"
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
className="text-xs sm:text-sm"
|
||||
onClick={handleClear}
|
||||
disabled={saving || editing || !memory}
|
||||
>
|
||||
<span className="hidden sm:inline">Reset Memory</span>
|
||||
<span className="sm:hidden">Reset</span>
|
||||
</Button>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button type="button" variant="secondary" size="sm" disabled={!memory}>
|
||||
Export
|
||||
<ChevronDown className="h-3 w-3 opacity-60" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem onClick={handleCopyMarkdown}>
|
||||
<ClipboardCopy className="h-4 w-4 mr-2" />
|
||||
Copy as Markdown
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={handleDownload}>
|
||||
<Download className="h-4 w-4 mr-2" />
|
||||
Download as Markdown
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom } from "jotai";
|
||||
import { Globe, KeyRound, Monitor, Receipt, Sparkles, User } from "lucide-react";
|
||||
import { Brain, CircleUser, Globe, KeyRound, Monitor, ReceiptText, Sparkles } from "lucide-react";
|
||||
import dynamic from "next/dynamic";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { useMemo } from "react";
|
||||
|
|
@ -51,6 +51,13 @@ const DesktopContent = dynamic(
|
|||
),
|
||||
{ ssr: false }
|
||||
);
|
||||
const MemoryContent = dynamic(
|
||||
() =>
|
||||
import("@/app/dashboard/[search_space_id]/user-settings/components/MemoryContent").then(
|
||||
(m) => ({ default: m.MemoryContent })
|
||||
),
|
||||
{ ssr: false }
|
||||
);
|
||||
|
||||
export function UserSettingsDialog() {
|
||||
const t = useTranslations("userSettings");
|
||||
|
|
@ -59,7 +66,7 @@ export function UserSettingsDialog() {
|
|||
|
||||
const navItems = useMemo(
|
||||
() => [
|
||||
{ value: "profile", label: t("profile_nav_label"), icon: <User className="h-4 w-4" /> },
|
||||
{ value: "profile", label: t("profile_nav_label"), icon: <CircleUser className="h-4 w-4" /> },
|
||||
{
|
||||
value: "api-key",
|
||||
label: t("api_key_nav_label"),
|
||||
|
|
@ -75,10 +82,15 @@ export function UserSettingsDialog() {
|
|||
label: "Community Prompts",
|
||||
icon: <Globe className="h-4 w-4" />,
|
||||
},
|
||||
{
|
||||
value: "memory",
|
||||
label: "Memory",
|
||||
icon: <Brain className="h-4 w-4" />,
|
||||
},
|
||||
{
|
||||
value: "purchases",
|
||||
label: "Purchase History",
|
||||
icon: <Receipt className="h-4 w-4" />,
|
||||
icon: <ReceiptText className="h-4 w-4" />,
|
||||
},
|
||||
...(isDesktop
|
||||
? [{ value: "desktop", label: "Desktop", icon: <Monitor className="h-4 w-4" /> }]
|
||||
|
|
@ -101,6 +113,7 @@ export function UserSettingsDialog() {
|
|||
{state.initialTab === "api-key" && <ApiKeyContent />}
|
||||
{state.initialTab === "prompts" && <PromptsContent />}
|
||||
{state.initialTab === "community-prompts" && <CommunityPromptsContent />}
|
||||
{state.initialTab === "memory" && <MemoryContent />}
|
||||
{state.initialTab === "purchases" && <PurchaseHistoryContent />}
|
||||
{state.initialTab === "desktop" && <DesktopContent />}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
? "model"
|
||||
: "models"}
|
||||
</span>{" "}
|
||||
available from your administrator. Use the model selector to view and select them.
|
||||
available from your administrator.
|
||||
</p>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
|
|
|||
|
|
@ -51,17 +51,11 @@ export {
|
|||
SandboxExecuteToolUI,
|
||||
} from "./sandbox-execute";
|
||||
export {
|
||||
type MemoryItem,
|
||||
type RecallMemoryArgs,
|
||||
RecallMemoryArgsSchema,
|
||||
type RecallMemoryResult,
|
||||
RecallMemoryResultSchema,
|
||||
RecallMemoryToolUI,
|
||||
type SaveMemoryArgs,
|
||||
SaveMemoryArgsSchema,
|
||||
type SaveMemoryResult,
|
||||
SaveMemoryResultSchema,
|
||||
SaveMemoryToolUI,
|
||||
type UpdateMemoryArgs,
|
||||
UpdateMemoryArgsSchema,
|
||||
type UpdateMemoryResult,
|
||||
UpdateMemoryResultSchema,
|
||||
UpdateMemoryToolUI,
|
||||
} from "./user-memory";
|
||||
export { GenerateVideoPresentationToolUI } from "./video-presentation";
|
||||
export { type WriteTodosData, WriteTodosSchema, WriteTodosToolUI } from "./write-todos";
|
||||
|
|
|
|||
|
|
@ -1,100 +1,38 @@
|
|||
"use client";
|
||||
|
||||
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
|
||||
import { BrainIcon, CheckIcon, Loader2Icon, SearchIcon, XIcon } from "lucide-react";
|
||||
import { AlertTriangleIcon, BrainIcon, CheckIcon, Loader2Icon, XIcon } from "lucide-react";
|
||||
import { z } from "zod";
|
||||
|
||||
// ============================================================================
|
||||
// Zod Schemas for save_memory tool
|
||||
// Zod Schemas for update_memory tool
|
||||
// ============================================================================
|
||||
|
||||
const SaveMemoryArgsSchema = z.object({
|
||||
content: z.string(),
|
||||
category: z.string().default("fact"),
|
||||
const UpdateMemoryArgsSchema = z.object({
|
||||
updated_memory: z.string(),
|
||||
});
|
||||
|
||||
const SaveMemoryResultSchema = z.object({
|
||||
const UpdateMemoryResultSchema = z.object({
|
||||
status: z.enum(["saved", "error"]),
|
||||
memory_id: z.number().nullish(),
|
||||
memory_text: z.string().nullish(),
|
||||
category: z.string().nullish(),
|
||||
message: z.string().nullish(),
|
||||
error: z.string().nullish(),
|
||||
warning: z.string().nullish(),
|
||||
});
|
||||
|
||||
type SaveMemoryArgs = z.infer<typeof SaveMemoryArgsSchema>;
|
||||
type SaveMemoryResult = z.infer<typeof SaveMemoryResultSchema>;
|
||||
type UpdateMemoryArgs = z.infer<typeof UpdateMemoryArgsSchema>;
|
||||
type UpdateMemoryResult = z.infer<typeof UpdateMemoryResultSchema>;
|
||||
|
||||
// ============================================================================
|
||||
// Zod Schemas for recall_memory tool
|
||||
// Update Memory Tool UI
|
||||
// ============================================================================
|
||||
|
||||
const RecallMemoryArgsSchema = z.object({
|
||||
query: z.string().nullish(),
|
||||
category: z.string().nullish(),
|
||||
top_k: z.number().default(5),
|
||||
});
|
||||
|
||||
const MemoryItemSchema = z.object({
|
||||
id: z.number(),
|
||||
memory_text: z.string(),
|
||||
category: z.string(),
|
||||
updated_at: z.string().nullish(),
|
||||
});
|
||||
|
||||
const RecallMemoryResultSchema = z.object({
|
||||
status: z.enum(["success", "error"]),
|
||||
count: z.number().nullish(),
|
||||
memories: z.array(MemoryItemSchema).nullish(),
|
||||
formatted_context: z.string().nullish(),
|
||||
error: z.string().nullish(),
|
||||
});
|
||||
|
||||
type RecallMemoryArgs = z.infer<typeof RecallMemoryArgsSchema>;
|
||||
type RecallMemoryResult = z.infer<typeof RecallMemoryResultSchema>;
|
||||
type MemoryItem = z.infer<typeof MemoryItemSchema>;
|
||||
|
||||
// ============================================================================
|
||||
// Category badge colors
|
||||
// ============================================================================
|
||||
|
||||
const categoryColors: Record<string, string> = {
|
||||
preference: "bg-blue-500/10 text-blue-600 dark:text-blue-400",
|
||||
fact: "bg-green-500/10 text-green-600 dark:text-green-400",
|
||||
instruction: "bg-purple-500/10 text-purple-600 dark:text-purple-400",
|
||||
context: "bg-orange-500/10 text-orange-600 dark:text-orange-400",
|
||||
};
|
||||
|
||||
function CategoryBadge({ category }: { category: string }) {
|
||||
const colorClass = categoryColors[category] || "bg-gray-500/10 text-gray-600 dark:text-gray-400";
|
||||
return (
|
||||
<span
|
||||
className={`inline-flex items-center rounded-full px-2 py-0.5 text-xs font-medium ${colorClass}`}
|
||||
>
|
||||
{category}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Save Memory Tool UI
|
||||
// ============================================================================
|
||||
|
||||
export const SaveMemoryToolUI = ({
|
||||
args,
|
||||
export const UpdateMemoryToolUI = ({
|
||||
result,
|
||||
status,
|
||||
}: ToolCallMessagePartProps<SaveMemoryArgs, SaveMemoryResult>) => {
|
||||
}: ToolCallMessagePartProps<UpdateMemoryArgs, UpdateMemoryResult>) => {
|
||||
const isRunning = status.type === "running" || status.type === "requires-action";
|
||||
const isComplete = status.type === "complete";
|
||||
const isError = result?.status === "error";
|
||||
|
||||
// Parse args safely
|
||||
const parsedArgs = SaveMemoryArgsSchema.safeParse(args);
|
||||
const content = parsedArgs.success ? parsedArgs.data.content : "";
|
||||
const category = parsedArgs.success ? parsedArgs.data.category : "fact";
|
||||
|
||||
// Loading state
|
||||
if (isRunning) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
|
|
@ -102,13 +40,12 @@ export const SaveMemoryToolUI = ({
|
|||
<Loader2Icon className="size-4 animate-spin text-primary" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-sm text-muted-foreground">Saving to memory...</span>
|
||||
<span className="text-sm text-muted-foreground">Updating memory...</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Error state
|
||||
if (isError) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border border-destructive/20 bg-destructive/5 px-4 py-3">
|
||||
|
|
@ -116,14 +53,13 @@ export const SaveMemoryToolUI = ({
|
|||
<XIcon className="size-4 text-destructive" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-sm text-destructive">Failed to save memory</span>
|
||||
{result?.error && <p className="mt-1 text-xs text-destructive/70">{result.error}</p>}
|
||||
<span className="text-sm text-destructive">Failed to update memory</span>
|
||||
{result?.message && <p className="mt-1 text-xs text-destructive/70">{result.message}</p>}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Success state
|
||||
if (isComplete && result?.status === "saved") {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border border-primary/20 bg-primary/5 px-4 py-3">
|
||||
|
|
@ -133,138 +69,19 @@ export const SaveMemoryToolUI = ({
|
|||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2">
|
||||
<CheckIcon className="size-3 text-green-500 shrink-0" />
|
||||
<span className="text-sm font-medium text-foreground">Memory saved</span>
|
||||
<CategoryBadge category={category} />
|
||||
<span className="text-sm font-medium text-foreground">Memory updated</span>
|
||||
</div>
|
||||
<p className="mt-1 truncate text-sm text-muted-foreground">{content}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Default/incomplete state - show what's being saved
|
||||
if (content) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
||||
<BrainIcon className="size-4 text-muted-foreground" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-muted-foreground">Saving memory</span>
|
||||
<CategoryBadge category={category} />
|
||||
</div>
|
||||
<p className="mt-1 truncate text-sm text-muted-foreground">{content}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Recall Memory Tool UI
|
||||
// ============================================================================
|
||||
|
||||
export const RecallMemoryToolUI = ({
|
||||
args,
|
||||
result,
|
||||
status,
|
||||
}: ToolCallMessagePartProps<RecallMemoryArgs, RecallMemoryResult>) => {
|
||||
const isRunning = status.type === "running" || status.type === "requires-action";
|
||||
const isComplete = status.type === "complete";
|
||||
const isError = result?.status === "error";
|
||||
|
||||
// Parse args safely
|
||||
const parsedArgs = RecallMemoryArgsSchema.safeParse(args);
|
||||
const query = parsedArgs.success ? parsedArgs.data.query : null;
|
||||
|
||||
// Loading state
|
||||
if (isRunning) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-primary/10">
|
||||
<Loader2Icon className="size-4 animate-spin text-primary" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-sm text-muted-foreground">
|
||||
{query ? `Searching memories for "${query}"...` : "Recalling memories..."}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Error state
|
||||
if (isError) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border border-destructive/20 bg-destructive/5 px-4 py-3">
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-destructive/10">
|
||||
<XIcon className="size-4 text-destructive" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-sm text-destructive">Failed to recall memories</span>
|
||||
{result?.error && <p className="mt-1 text-xs text-destructive/70">{result.error}</p>}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Success state with memories
|
||||
if (isComplete && result?.status === "success") {
|
||||
const memories = result.memories || [];
|
||||
const count = result.count || 0;
|
||||
|
||||
if (count === 0) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
||||
<SearchIcon className="size-4 text-muted-foreground" />
|
||||
</div>
|
||||
<span className="text-sm text-muted-foreground">No memories found</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="my-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<BrainIcon className="size-4 text-primary" />
|
||||
<span className="text-sm font-medium text-foreground">
|
||||
Recalled {count} {count === 1 ? "memory" : "memories"}
|
||||
</span>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
{memories.slice(0, 5).map((memory: MemoryItem) => (
|
||||
<div
|
||||
key={memory.id}
|
||||
className="flex items-start gap-2 rounded-md bg-muted/50 px-3 py-2"
|
||||
>
|
||||
<CategoryBadge category={memory.category} />
|
||||
<span className="text-sm text-muted-foreground flex-1">{memory.memory_text}</span>
|
||||
{result.warning && (
|
||||
<div className="mt-1.5 flex items-start gap-1.5">
|
||||
<AlertTriangleIcon className="size-3 text-yellow-500 shrink-0 mt-0.5" />
|
||||
<p className="text-xs text-yellow-600 dark:text-yellow-400">{result.warning}</p>
|
||||
</div>
|
||||
))}
|
||||
{memories.length > 5 && (
|
||||
<p className="text-xs text-muted-foreground">...and {memories.length - 5} more</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Default/incomplete state
|
||||
if (query) {
|
||||
return (
|
||||
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
||||
<SearchIcon className="size-4 text-muted-foreground" />
|
||||
</div>
|
||||
<span className="text-sm text-muted-foreground">Searching memories for "{query}"</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
|
|
@ -273,13 +90,8 @@ export const RecallMemoryToolUI = ({
|
|||
// ============================================================================
|
||||
|
||||
export {
|
||||
SaveMemoryArgsSchema,
|
||||
SaveMemoryResultSchema,
|
||||
RecallMemoryArgsSchema,
|
||||
RecallMemoryResultSchema,
|
||||
type SaveMemoryArgs,
|
||||
type SaveMemoryResult,
|
||||
type RecallMemoryArgs,
|
||||
type RecallMemoryResult,
|
||||
type MemoryItem,
|
||||
UpdateMemoryArgsSchema,
|
||||
UpdateMemoryResultSchema,
|
||||
type UpdateMemoryArgs,
|
||||
type UpdateMemoryResult,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ function AlertDialogContent({
|
|||
<AlertDialogPrimitive.Content
|
||||
data-slot="alert-dialog-content"
|
||||
className={cn(
|
||||
"bg-background dark:bg-neutral-900 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-xl p-6 shadow-2xl duration-200 sm:max-w-lg",
|
||||
"bg-background dark:bg-neutral-900 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-xl p-6 shadow-2xl duration-200 sm:max-w-lg select-none",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ const DialogContent = React.forwardRef<
|
|||
<DialogPrimitive.Content
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 bg-background dark:bg-neutral-900 p-6 shadow-2xl duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 rounded-xl focus:outline-none focus:ring-0 focus-visible:outline-none focus-visible:ring-0",
|
||||
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 bg-background dark:bg-neutral-900 p-6 shadow-2xl duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 rounded-xl focus:outline-none focus:ring-0 focus-visible:outline-none focus-visible:ring-0 select-none",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
|
|
|
|||
|
|
@ -54,8 +54,8 @@ const editorVariants = cva(
|
|||
cn(
|
||||
"group/editor",
|
||||
"relative w-full cursor-text select-text overflow-x-hidden whitespace-pre-wrap break-words",
|
||||
"rounded-md ring-offset-background focus-visible:outline-none",
|
||||
"**:data-slate-placeholder:!top-1/2 **:data-slate-placeholder:-translate-y-1/2 placeholder:text-muted-foreground/80 **:data-slate-placeholder:text-muted-foreground/80 **:data-slate-placeholder:opacity-100!",
|
||||
"rounded-none ring-offset-background focus-visible:outline-none",
|
||||
"placeholder:text-muted-foreground/80 **:data-slate-placeholder:text-muted-foreground/80 **:data-slate-placeholder:py-1",
|
||||
"[&_strong]:font-bold"
|
||||
),
|
||||
{
|
||||
|
|
|
|||
|
|
@ -65,8 +65,9 @@ export function FloatingToolbar({
|
|||
{...rootProps}
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"scrollbar-hide absolute z-50 overflow-x-auto whitespace-nowrap rounded-md border bg-popover p-1 opacity-100 shadow-md print:hidden",
|
||||
"scrollbar-hide absolute z-50 overflow-x-auto whitespace-nowrap rounded-md border dark:border-neutral-700 bg-muted p-1 opacity-100 shadow-md print:hidden",
|
||||
"max-w-[80vw]",
|
||||
"[&_button:hover]:bg-neutral-200 dark:[&_button:hover]:bg-neutral-700",
|
||||
className
|
||||
)}
|
||||
>
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import type { PlateElementProps } from "platejs/react";
|
|||
import { PlateElement } from "platejs/react";
|
||||
import * as React from "react";
|
||||
|
||||
const headingVariants = cva("relative mb-1", {
|
||||
const headingVariants = cva("relative mb-1 first:mt-0", {
|
||||
variants: {
|
||||
variant: {
|
||||
h1: "mt-[1.6em] pb-1 font-bold font-heading text-4xl",
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ const TOOL_ICONS: Record<string, LucideIcon> = {
|
|||
scrape_webpage: ScanLine,
|
||||
web_search: Globe,
|
||||
search_surfsense_docs: BookOpen,
|
||||
save_memory: Brain,
|
||||
recall_memory: Brain,
|
||||
update_memory: Brain,
|
||||
};
|
||||
|
||||
export function getToolIcon(name: string): LucideIcon {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ export const searchSpace = z.object({
|
|||
user_id: z.string(),
|
||||
citations_enabled: z.boolean(),
|
||||
qna_custom_instructions: z.string().nullable(),
|
||||
shared_memory_md: z.string().nullable().optional(),
|
||||
member_count: z.number(),
|
||||
is_owner: z.boolean(),
|
||||
});
|
||||
|
|
@ -54,6 +55,7 @@ export const updateSearchSpaceRequest = z.object({
|
|||
description: true,
|
||||
citations_enabled: true,
|
||||
qna_custom_instructions: true,
|
||||
shared_memory_md: true,
|
||||
})
|
||||
.partial(),
|
||||
});
|
||||
|
|
|
|||
108
surfsense_web/lib/desktop-download-utils.ts
Normal file
108
surfsense_web/lib/desktop-download-utils.ts
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
import { useEffect, useMemo, useState } from "react";
|
||||
|
||||
export type OSInfo = {
|
||||
os: "macOS" | "Windows" | "Linux";
|
||||
arch: "arm64" | "x64";
|
||||
};
|
||||
|
||||
export function useUserOS(): OSInfo {
|
||||
const [info, setInfo] = useState<OSInfo>({ os: "macOS", arch: "arm64" });
|
||||
useEffect(() => {
|
||||
const ua = navigator.userAgent;
|
||||
let os: OSInfo["os"] = "macOS";
|
||||
let arch: OSInfo["arch"] = "x64";
|
||||
|
||||
if (/Windows/i.test(ua)) {
|
||||
os = "Windows";
|
||||
arch = "x64";
|
||||
} else if (/Linux/i.test(ua)) {
|
||||
os = "Linux";
|
||||
arch = "x64";
|
||||
} else {
|
||||
os = "macOS";
|
||||
arch = /Mac/.test(ua) && !/Intel/.test(ua) ? "arm64" : "arm64";
|
||||
}
|
||||
|
||||
const uaData = (navigator as Navigator & { userAgentData?: { architecture?: string } })
|
||||
.userAgentData;
|
||||
if (uaData?.architecture === "arm") arch = "arm64";
|
||||
else if (uaData?.architecture === "x86") arch = "x64";
|
||||
|
||||
setInfo({ os, arch });
|
||||
}, []);
|
||||
return info;
|
||||
}
|
||||
|
||||
export interface ReleaseAsset {
|
||||
name: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export function useLatestRelease() {
|
||||
const [assets, setAssets] = useState<ReleaseAsset[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
fetch("https://api.github.com/repos/MODSetter/SurfSense/releases/latest", {
|
||||
signal: controller.signal,
|
||||
})
|
||||
.then((r) => r.json())
|
||||
.then((data) => {
|
||||
if (data?.assets) {
|
||||
setAssets(
|
||||
data.assets
|
||||
.filter((a: { name: string }) => /\.(exe|dmg|AppImage|deb)$/.test(a.name))
|
||||
.map((a: { name: string; browser_download_url: string }) => ({
|
||||
name: a.name,
|
||||
url: a.browser_download_url,
|
||||
}))
|
||||
);
|
||||
}
|
||||
})
|
||||
.catch(() => {});
|
||||
return () => controller.abort();
|
||||
}, []);
|
||||
|
||||
return assets;
|
||||
}
|
||||
|
||||
export const ASSET_LABELS: Record<string, string> = {
|
||||
".exe": "Windows (exe)",
|
||||
"-arm64.dmg": "macOS Apple Silicon (dmg)",
|
||||
"-x64.dmg": "macOS Intel (dmg)",
|
||||
"-arm64.zip": "macOS Apple Silicon (zip)",
|
||||
"-x64.zip": "macOS Intel (zip)",
|
||||
".AppImage": "Linux (AppImage)",
|
||||
".deb": "Linux (deb)",
|
||||
};
|
||||
|
||||
export function getAssetLabel(name: string): string {
|
||||
for (const [suffix, label] of Object.entries(ASSET_LABELS)) {
|
||||
if (name.endsWith(suffix)) return label;
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
export const GITHUB_RELEASES_URL = "https://github.com/MODSetter/SurfSense/releases/latest";
|
||||
|
||||
export function usePrimaryDownload() {
|
||||
const { os, arch } = useUserOS();
|
||||
const assets = useLatestRelease();
|
||||
|
||||
const { primary, alternatives } = useMemo(() => {
|
||||
if (assets.length === 0) return { primary: null, alternatives: [] };
|
||||
|
||||
const matchers: Record<string, (n: string) => boolean> = {
|
||||
Windows: (n) => n.endsWith(".exe"),
|
||||
macOS: (n) => n.endsWith(`-${arch}.dmg`),
|
||||
Linux: (n) => n.endsWith(".AppImage"),
|
||||
};
|
||||
|
||||
const match = matchers[os];
|
||||
const primary = assets.find((a) => match(a.name)) ?? null;
|
||||
const alternatives = assets.filter((a) => a !== primary);
|
||||
return { primary, alternatives };
|
||||
}, [assets, os, arch]);
|
||||
|
||||
return { os, arch, assets, primary, alternatives };
|
||||
}
|
||||
|
|
@ -92,6 +92,10 @@ export const cacheKeys = {
|
|||
publicChat: {
|
||||
byToken: (shareToken: string) => ["public-chat", shareToken] as const,
|
||||
},
|
||||
github: {
|
||||
repoStars: (username: string, repo: string) =>
|
||||
["github", "repo-stars", username, repo] as const,
|
||||
},
|
||||
publicChatSnapshots: {
|
||||
all: ["public-chat-snapshots"] as const,
|
||||
bySearchSpace: (searchSpaceId: number) =>
|
||||
|
|
|
|||
|
|
@ -697,6 +697,7 @@
|
|||
"learn_more": "Learn more",
|
||||
"documentation": "Documentation",
|
||||
"github": "GitHub",
|
||||
"download_for_os": "Download for {os}",
|
||||
"inbox": "Inbox",
|
||||
"search_inbox": "Search inbox",
|
||||
"mark_all_read": "Mark all as read",
|
||||
|
|
@ -784,7 +785,7 @@
|
|||
"homepage": {
|
||||
"hero_title_part1": "The AI Workspace",
|
||||
"hero_title_part2": "Built for Teams",
|
||||
"hero_description": "Connect any LLM to your internal knowledge sources and chat with it in real time alongside your team.",
|
||||
"hero_description": "An open source, privacy focused alternative to NotebookLM for teams with no data limits.",
|
||||
"cta_start_trial": "Start Free Trial",
|
||||
"cta_explore": "Explore",
|
||||
"integrations_title": "Integrations",
|
||||
|
|
|
|||
|
|
@ -697,6 +697,7 @@
|
|||
"learn_more": "Más información",
|
||||
"documentation": "Documentación",
|
||||
"github": "GitHub",
|
||||
"download_for_os": "Descargar para {os}",
|
||||
"inbox": "Bandeja de entrada",
|
||||
"search_inbox": "Buscar en bandeja de entrada",
|
||||
"mark_all_read": "Marcar todo como leído",
|
||||
|
|
|
|||
|
|
@ -697,6 +697,7 @@
|
|||
"learn_more": "और जानें",
|
||||
"documentation": "दस्तावेज़ीकरण",
|
||||
"github": "GitHub",
|
||||
"download_for_os": "{os} के लिए डाउनलोड करें",
|
||||
"inbox": "इनबॉक्स",
|
||||
"search_inbox": "इनबॉक्स में खोजें",
|
||||
"mark_all_read": "सभी पढ़ा हुआ चिह्नित करें",
|
||||
|
|
|
|||
|
|
@ -697,6 +697,7 @@
|
|||
"learn_more": "Saiba mais",
|
||||
"documentation": "Documentação",
|
||||
"github": "GitHub",
|
||||
"download_for_os": "Baixar para {os}",
|
||||
"inbox": "Caixa de entrada",
|
||||
"search_inbox": "Pesquisar caixa de entrada",
|
||||
"mark_all_read": "Marcar tudo como lido",
|
||||
|
|
|
|||
|
|
@ -681,6 +681,7 @@
|
|||
"learn_more": "了解更多",
|
||||
"documentation": "文档",
|
||||
"github": "GitHub",
|
||||
"download_for_os": "下载 {os} 版本",
|
||||
"inbox": "收件箱",
|
||||
"search_inbox": "搜索收件箱",
|
||||
"mark_all_read": "全部标记为已读",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue