mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-11 00:32:38 +02:00
Merge pull request #722 from manojag115/feature/user-memory
Add user memory feature to SurfSense
This commit is contained in:
commit
cc658789e4
12 changed files with 966 additions and 0 deletions
135
surfsense_backend/alembic/versions/73_add_user_memories_table.py
Normal file
135
surfsense_backend/alembic/versions/73_add_user_memories_table.py
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
"""Add user_memories table for AI memory feature
|
||||||
|
|
||||||
|
Revision ID: 73
|
||||||
|
Revises: 72
|
||||||
|
Create Date: 2026-01-20
|
||||||
|
|
||||||
|
This migration adds the user_memories table which enables Claude-like memory
|
||||||
|
functionality - allowing the AI to remember facts, preferences, and context
|
||||||
|
about users across conversations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "73"
|
||||||
|
down_revision: str | None = "72"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
# Get embedding dimension from config
|
||||||
|
EMBEDDING_DIM = config.embedding_model_instance.dimension
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Create user_memories table and MemoryCategory enum."""
|
||||||
|
|
||||||
|
# Create the MemoryCategory enum type
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'memorycategory') THEN
|
||||||
|
CREATE TYPE memorycategory AS ENUM (
|
||||||
|
'preference',
|
||||||
|
'fact',
|
||||||
|
'instruction',
|
||||||
|
'context'
|
||||||
|
);
|
||||||
|
END IF;
|
||||||
|
END$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create user_memories table
|
||||||
|
op.execute(
|
||||||
|
f"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT FROM information_schema.tables
|
||||||
|
WHERE table_name = 'user_memories'
|
||||||
|
) THEN
|
||||||
|
CREATE TABLE user_memories (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||||
|
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
search_space_id INTEGER REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||||
|
memory_text TEXT NOT NULL,
|
||||||
|
category memorycategory NOT NULL DEFAULT 'fact',
|
||||||
|
embedding vector({EMBEDDING_DIM}),
|
||||||
|
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
END IF;
|
||||||
|
END$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create indexes for efficient querying
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
-- Index on user_id for filtering memories by user
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM pg_indexes
|
||||||
|
WHERE tablename = 'user_memories' AND indexname = 'ix_user_memories_user_id'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX ix_user_memories_user_id ON user_memories(user_id);
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Index on search_space_id for filtering memories by search space
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM pg_indexes
|
||||||
|
WHERE tablename = 'user_memories' AND indexname = 'ix_user_memories_search_space_id'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX ix_user_memories_search_space_id ON user_memories(search_space_id);
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Index on updated_at for ordering by recency
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM pg_indexes
|
||||||
|
WHERE tablename = 'user_memories' AND indexname = 'ix_user_memories_updated_at'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX ix_user_memories_updated_at ON user_memories(updated_at);
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Index on category for filtering by memory type
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM pg_indexes
|
||||||
|
WHERE tablename = 'user_memories' AND indexname = 'ix_user_memories_category'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX ix_user_memories_category ON user_memories(category);
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Composite index for common query pattern (user + search space)
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM pg_indexes
|
||||||
|
WHERE tablename = 'user_memories' AND indexname = 'ix_user_memories_user_search_space'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX ix_user_memories_user_search_space ON user_memories(user_id, search_space_id);
|
||||||
|
END IF;
|
||||||
|
END$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create vector index for semantic search
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE INDEX IF NOT EXISTS user_memories_vector_index
|
||||||
|
ON user_memories USING hnsw (embedding public.vector_cosine_ops);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Drop user_memories table and MemoryCategory enum."""
|
||||||
|
|
||||||
|
# Drop the table
|
||||||
|
op.execute("DROP TABLE IF EXISTS user_memories CASCADE;")
|
||||||
|
|
||||||
|
# Drop the enum type
|
||||||
|
op.execute("DROP TYPE IF EXISTS memorycategory;")
|
||||||
|
|
@ -34,6 +34,7 @@ async def create_surfsense_deep_agent(
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
connector_service: ConnectorService,
|
connector_service: ConnectorService,
|
||||||
checkpointer: Checkpointer,
|
checkpointer: Checkpointer,
|
||||||
|
user_id: str | None = None,
|
||||||
agent_config: AgentConfig | None = None,
|
agent_config: AgentConfig | None = None,
|
||||||
enabled_tools: list[str] | None = None,
|
enabled_tools: list[str] | None = None,
|
||||||
disabled_tools: list[str] | None = None,
|
disabled_tools: list[str] | None = None,
|
||||||
|
|
@ -49,6 +50,8 @@ async def create_surfsense_deep_agent(
|
||||||
- link_preview: Fetch rich previews for URLs
|
- link_preview: Fetch rich previews for URLs
|
||||||
- display_image: Display images in chat
|
- display_image: Display images in chat
|
||||||
- scrape_webpage: Extract content from webpages
|
- scrape_webpage: Extract content from webpages
|
||||||
|
- save_memory: Store facts/preferences about the user
|
||||||
|
- recall_memory: Retrieve relevant user memories
|
||||||
|
|
||||||
The agent also includes TodoListMiddleware by default (via create_deep_agent) which provides:
|
The agent also includes TodoListMiddleware by default (via create_deep_agent) which provides:
|
||||||
- write_todos: Create and update planning/todo lists for complex tasks
|
- write_todos: Create and update planning/todo lists for complex tasks
|
||||||
|
|
@ -64,6 +67,7 @@ async def create_surfsense_deep_agent(
|
||||||
connector_service: Initialized connector service for knowledge base search
|
connector_service: Initialized connector service for knowledge base search
|
||||||
checkpointer: LangGraph checkpointer for conversation state persistence.
|
checkpointer: LangGraph checkpointer for conversation state persistence.
|
||||||
Use AsyncPostgresSaver for production or MemorySaver for testing.
|
Use AsyncPostgresSaver for production or MemorySaver for testing.
|
||||||
|
user_id: The current user's UUID string (required for memory tools)
|
||||||
agent_config: Optional AgentConfig from NewLLMConfig for prompt configuration.
|
agent_config: Optional AgentConfig from NewLLMConfig for prompt configuration.
|
||||||
If None, uses default system prompt with citations enabled.
|
If None, uses default system prompt with citations enabled.
|
||||||
enabled_tools: Explicit list of tool names to enable. If None, all default tools
|
enabled_tools: Explicit list of tool names to enable. If None, all default tools
|
||||||
|
|
@ -118,6 +122,7 @@ async def create_surfsense_deep_agent(
|
||||||
"db_session": db_session,
|
"db_session": db_session,
|
||||||
"connector_service": connector_service,
|
"connector_service": connector_service,
|
||||||
"firecrawl_api_key": firecrawl_api_key,
|
"firecrawl_api_key": firecrawl_api_key,
|
||||||
|
"user_id": user_id, # Required for memory tools
|
||||||
}
|
}
|
||||||
|
|
||||||
# Build tools using the async registry (includes MCP tools)
|
# Build tools using the async registry (includes MCP tools)
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,45 @@ You have access to the following tools:
|
||||||
* This makes your response more visual and engaging.
|
* This makes your response more visual and engaging.
|
||||||
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
|
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
|
||||||
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
||||||
|
|
||||||
|
6. save_memory: Save facts, preferences, or context about the user for personalized responses.
|
||||||
|
- Use this when the user explicitly or implicitly shares information worth remembering.
|
||||||
|
- Trigger scenarios:
|
||||||
|
* User says "remember this", "keep this in mind", "note that", or similar
|
||||||
|
* User shares personal preferences (e.g., "I prefer Python over JavaScript")
|
||||||
|
* User shares facts about themselves (e.g., "I'm a senior developer at Company X")
|
||||||
|
* User gives standing instructions (e.g., "always respond in bullet points")
|
||||||
|
* User shares project context (e.g., "I'm working on migrating our codebase to TypeScript")
|
||||||
|
- Args:
|
||||||
|
- content: The fact/preference to remember. Phrase it clearly:
|
||||||
|
* "User prefers dark mode for all interfaces"
|
||||||
|
* "User is a senior Python developer"
|
||||||
|
* "User wants responses in bullet point format"
|
||||||
|
* "User is working on project called ProjectX"
|
||||||
|
- category: Type of memory:
|
||||||
|
* "preference": User preferences (coding style, tools, formats)
|
||||||
|
* "fact": Facts about the user (role, expertise, background)
|
||||||
|
* "instruction": Standing instructions (response format, communication style)
|
||||||
|
* "context": Current context (ongoing projects, goals, challenges)
|
||||||
|
- Returns: Confirmation of saved memory
|
||||||
|
- IMPORTANT: Only save information that would be genuinely useful for future conversations.
|
||||||
|
Don't save trivial or temporary information.
|
||||||
|
|
||||||
|
7. recall_memory: Retrieve relevant memories about the user for personalized responses.
|
||||||
|
- Use this to access stored information about the user.
|
||||||
|
- Trigger scenarios:
|
||||||
|
* You need user context to give a better, more personalized answer
|
||||||
|
* User references something they mentioned before
|
||||||
|
* User asks "what do you know about me?" or similar
|
||||||
|
* Personalization would significantly improve response quality
|
||||||
|
* Before making recommendations that should consider user preferences
|
||||||
|
- Args:
|
||||||
|
- query: Optional search query to find specific memories (e.g., "programming preferences")
|
||||||
|
- category: Optional filter by category ("preference", "fact", "instruction", "context")
|
||||||
|
- top_k: Number of memories to retrieve (default: 5)
|
||||||
|
- Returns: Relevant memories formatted as context
|
||||||
|
- IMPORTANT: Use the recalled memories naturally in your response without explicitly
|
||||||
|
stating "Based on your memory..." - integrate the context seamlessly.
|
||||||
</tools>
|
</tools>
|
||||||
<tool_call_examples>
|
<tool_call_examples>
|
||||||
- User: "How do I install SurfSense?"
|
- User: "How do I install SurfSense?"
|
||||||
|
|
@ -136,6 +175,23 @@ You have access to the following tools:
|
||||||
- User: "What did I discuss on Slack last week about the React migration?"
|
- User: "What did I discuss on Slack last week about the React migration?"
|
||||||
- Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")`
|
- Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")`
|
||||||
|
|
||||||
|
- User: "Remember that I prefer TypeScript over JavaScript"
|
||||||
|
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
|
||||||
|
|
||||||
|
- User: "I'm a data scientist working on ML pipelines"
|
||||||
|
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
|
||||||
|
|
||||||
|
- User: "Always give me code examples in Python"
|
||||||
|
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
|
||||||
|
|
||||||
|
- User: "What programming language should I use for this project?"
|
||||||
|
- First recall: `recall_memory(query="programming language preferences")`
|
||||||
|
- Then provide a personalized recommendation based on their preferences
|
||||||
|
|
||||||
|
- User: "What do you know about me?"
|
||||||
|
- Call: `recall_memory(top_k=10)`
|
||||||
|
- Then summarize the stored memories
|
||||||
|
|
||||||
- User: "Give me a podcast about AI trends based on what we discussed"
|
- User: "Give me a podcast about AI trends based on what we discussed"
|
||||||
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
|
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,8 @@ Available tools:
|
||||||
- link_preview: Fetch rich previews for URLs
|
- link_preview: Fetch rich previews for URLs
|
||||||
- display_image: Display images in chat
|
- display_image: Display images in chat
|
||||||
- scrape_webpage: Extract content from webpages
|
- scrape_webpage: Extract content from webpages
|
||||||
|
- save_memory: Store facts/preferences about the user
|
||||||
|
- recall_memory: Retrieve relevant user memories
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Registry exports
|
# Registry exports
|
||||||
|
|
@ -33,6 +35,7 @@ from .registry import (
|
||||||
)
|
)
|
||||||
from .scrape_webpage import create_scrape_webpage_tool
|
from .scrape_webpage import create_scrape_webpage_tool
|
||||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||||
|
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Registry
|
# Registry
|
||||||
|
|
@ -43,6 +46,8 @@ __all__ = [
|
||||||
"create_display_image_tool",
|
"create_display_image_tool",
|
||||||
"create_generate_podcast_tool",
|
"create_generate_podcast_tool",
|
||||||
"create_link_preview_tool",
|
"create_link_preview_tool",
|
||||||
|
"create_recall_memory_tool",
|
||||||
|
"create_save_memory_tool",
|
||||||
"create_scrape_webpage_tool",
|
"create_scrape_webpage_tool",
|
||||||
"create_search_knowledge_base_tool",
|
"create_search_knowledge_base_tool",
|
||||||
"create_search_surfsense_docs_tool",
|
"create_search_surfsense_docs_tool",
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ from .mcp_tool import load_mcp_tools
|
||||||
from .podcast import create_generate_podcast_tool
|
from .podcast import create_generate_podcast_tool
|
||||||
from .scrape_webpage import create_scrape_webpage_tool
|
from .scrape_webpage import create_scrape_webpage_tool
|
||||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||||
|
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Tool Definition
|
# Tool Definition
|
||||||
|
|
@ -138,6 +139,31 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
requires=["db_session"],
|
requires=["db_session"],
|
||||||
),
|
),
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
# USER MEMORY TOOLS - Claude-like memory feature
|
||||||
|
# =========================================================================
|
||||||
|
# Save memory tool - stores facts/preferences about the user
|
||||||
|
ToolDefinition(
|
||||||
|
name="save_memory",
|
||||||
|
description="Save facts, preferences, or context about the user for personalized responses",
|
||||||
|
factory=lambda deps: create_save_memory_tool(
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
),
|
||||||
|
requires=["user_id", "search_space_id", "db_session"],
|
||||||
|
),
|
||||||
|
# Recall memory tool - retrieves relevant user memories
|
||||||
|
ToolDefinition(
|
||||||
|
name="recall_memory",
|
||||||
|
description="Recall user memories for personalized and contextual responses",
|
||||||
|
factory=lambda deps: create_recall_memory_tool(
|
||||||
|
user_id=deps["user_id"],
|
||||||
|
search_space_id=deps["search_space_id"],
|
||||||
|
db_session=deps["db_session"],
|
||||||
|
),
|
||||||
|
requires=["user_id", "search_space_id", "db_session"],
|
||||||
|
),
|
||||||
|
# =========================================================================
|
||||||
# ADD YOUR CUSTOM TOOLS BELOW
|
# ADD YOUR CUSTOM TOOLS BELOW
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Example:
|
# Example:
|
||||||
|
|
|
||||||
352
surfsense_backend/app/agents/new_chat/tools/user_memory.py
Normal file
352
surfsense_backend/app/agents/new_chat/tools/user_memory.py
Normal file
|
|
@ -0,0 +1,352 @@
|
||||||
|
"""
|
||||||
|
User memory tools for the SurfSense agent.
|
||||||
|
|
||||||
|
This module provides tools for storing and retrieving user memories,
|
||||||
|
enabling personalized AI responses similar to Claude's memory feature.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- save_memory: Store facts, preferences, and context about the user
|
||||||
|
- recall_memory: Retrieve relevant memories using semantic search
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import MemoryCategory, UserMemory
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Constants
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Default number of memories to retrieve
|
||||||
|
DEFAULT_RECALL_TOP_K = 5
|
||||||
|
|
||||||
|
# Maximum number of memories per user (to prevent unbounded growth)
|
||||||
|
MAX_MEMORIES_PER_USER = 100
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Helper Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _to_uuid(user_id: str) -> UUID:
|
||||||
|
"""Convert a string user_id to a UUID object."""
|
||||||
|
if isinstance(user_id, UUID):
|
||||||
|
return user_id
|
||||||
|
return UUID(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_memory_count(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Get the count of memories for a user."""
|
||||||
|
uuid_user_id = _to_uuid(user_id)
|
||||||
|
query = select(UserMemory).where(UserMemory.user_id == uuid_user_id)
|
||||||
|
if search_space_id is not None:
|
||||||
|
query = query.where(
|
||||||
|
(UserMemory.search_space_id == search_space_id)
|
||||||
|
| (UserMemory.search_space_id.is_(None))
|
||||||
|
)
|
||||||
|
result = await db_session.execute(query)
|
||||||
|
return len(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_oldest_memory(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Delete the oldest memory for a user to make room for new ones."""
|
||||||
|
uuid_user_id = _to_uuid(user_id)
|
||||||
|
query = (
|
||||||
|
select(UserMemory)
|
||||||
|
.where(UserMemory.user_id == uuid_user_id)
|
||||||
|
.order_by(UserMemory.updated_at.asc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
if search_space_id is not None:
|
||||||
|
query = query.where(
|
||||||
|
(UserMemory.search_space_id == search_space_id)
|
||||||
|
| (UserMemory.search_space_id.is_(None))
|
||||||
|
)
|
||||||
|
result = await db_session.execute(query)
|
||||||
|
oldest_memory = result.scalars().first()
|
||||||
|
if oldest_memory:
|
||||||
|
await db_session.delete(oldest_memory)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def format_memories_for_context(memories: list[dict[str, Any]]) -> str:
|
||||||
|
"""Format retrieved memories into a readable context string for the LLM."""
|
||||||
|
if not memories:
|
||||||
|
return "No relevant memories found for this user."
|
||||||
|
|
||||||
|
parts = ["<user_memories>"]
|
||||||
|
for memory in memories:
|
||||||
|
category = memory.get("category", "unknown")
|
||||||
|
text = memory.get("memory_text", "")
|
||||||
|
updated = memory.get("updated_at", "")
|
||||||
|
parts.append(
|
||||||
|
f" <memory category='{category}' updated='{updated}'>{text}</memory>"
|
||||||
|
)
|
||||||
|
parts.append("</user_memories>")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Tool Factory Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def create_save_memory_tool(
|
||||||
|
user_id: str,
|
||||||
|
search_space_id: int,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the save_memory tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID
|
||||||
|
search_space_id: The search space ID (for space-specific memories)
|
||||||
|
db_session: Database session for executing queries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A configured tool function for saving user memories
|
||||||
|
"""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def save_memory(
|
||||||
|
content: str,
|
||||||
|
category: str = "fact",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Save a fact, preference, or context about the user for future reference.
|
||||||
|
|
||||||
|
Use this tool when:
|
||||||
|
- User explicitly says "remember this", "keep this in mind", or similar
|
||||||
|
- User shares personal preferences (e.g., "I prefer Python over JavaScript")
|
||||||
|
- User shares important facts about themselves (name, role, interests, projects)
|
||||||
|
- User gives standing instructions (e.g., "always respond in bullet points")
|
||||||
|
- User shares relevant context (e.g., "I'm working on project X")
|
||||||
|
|
||||||
|
The saved information will be available in future conversations to provide
|
||||||
|
more personalized and contextual responses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The fact/preference/context to remember.
|
||||||
|
Phrase it clearly, e.g., "User prefers dark mode",
|
||||||
|
"User is a senior Python developer", "User is working on an AI project"
|
||||||
|
category: Type of memory. One of:
|
||||||
|
- "preference": User preferences (e.g., coding style, tools, formats)
|
||||||
|
- "fact": Facts about the user (e.g., name, role, expertise)
|
||||||
|
- "instruction": Standing instructions (e.g., response format preferences)
|
||||||
|
- "context": Current context (e.g., ongoing projects, goals)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the save status and memory details
|
||||||
|
"""
|
||||||
|
# Normalize and validate category (LLMs may send uppercase)
|
||||||
|
category = category.lower() if category else "fact"
|
||||||
|
valid_categories = ["preference", "fact", "instruction", "context"]
|
||||||
|
if category not in valid_categories:
|
||||||
|
category = "fact"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert user_id to UUID
|
||||||
|
uuid_user_id = _to_uuid(user_id)
|
||||||
|
|
||||||
|
# Check if we've hit the memory limit
|
||||||
|
memory_count = await get_user_memory_count(
|
||||||
|
db_session, user_id, search_space_id
|
||||||
|
)
|
||||||
|
if memory_count >= MAX_MEMORIES_PER_USER:
|
||||||
|
# Delete oldest memory to make room
|
||||||
|
await delete_oldest_memory(db_session, user_id, search_space_id)
|
||||||
|
|
||||||
|
# Generate embedding for the memory
|
||||||
|
embedding = config.embedding_model_instance.embed(content)
|
||||||
|
|
||||||
|
# Create new memory using ORM
|
||||||
|
# The pgvector Vector column type handles embedding conversion automatically
|
||||||
|
new_memory = UserMemory(
|
||||||
|
user_id=uuid_user_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
memory_text=content,
|
||||||
|
category=MemoryCategory(category), # Convert string to enum
|
||||||
|
embedding=embedding, # Pass embedding directly (list or numpy array)
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(new_memory)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(new_memory)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "saved",
|
||||||
|
"memory_id": new_memory.id,
|
||||||
|
"memory_text": content,
|
||||||
|
"category": category,
|
||||||
|
"message": f"I'll remember: {content}",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Failed to save memory for user {user_id}: {e}")
|
||||||
|
# Rollback the session to clear any failed transaction state
|
||||||
|
await db_session.rollback()
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"message": "Failed to save memory. Please try again.",
|
||||||
|
}
|
||||||
|
|
||||||
|
return save_memory
|
||||||
|
|
||||||
|
|
||||||
|
def create_recall_memory_tool(
|
||||||
|
user_id: str,
|
||||||
|
search_space_id: int,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the recall_memory tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID
|
||||||
|
search_space_id: The search space ID
|
||||||
|
db_session: Database session for executing queries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A configured tool function for recalling user memories
|
||||||
|
"""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def recall_memory(
|
||||||
|
query: str | None = None,
|
||||||
|
category: str | None = None,
|
||||||
|
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Recall relevant memories about the user to provide personalized responses.
|
||||||
|
|
||||||
|
Use this tool when:
|
||||||
|
- You need user context to give a better, more personalized answer
|
||||||
|
- User asks about their preferences or past information they shared
|
||||||
|
- User references something they told you before
|
||||||
|
- Personalization would significantly improve the response quality
|
||||||
|
- User asks "what do you know about me?" or similar
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Optional search query to find specific memories.
|
||||||
|
If not provided, returns the most recent memories.
|
||||||
|
Example: "programming preferences", "current projects"
|
||||||
|
category: Optional category filter. One of:
|
||||||
|
"preference", "fact", "instruction", "context"
|
||||||
|
If not provided, searches all categories.
|
||||||
|
top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing relevant memories and formatted context
|
||||||
|
"""
|
||||||
|
top_k = min(max(top_k, 1), 20) # Clamp between 1 and 20
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert user_id to UUID
|
||||||
|
uuid_user_id = _to_uuid(user_id)
|
||||||
|
|
||||||
|
if query:
|
||||||
|
# Semantic search using embeddings
|
||||||
|
query_embedding = config.embedding_model_instance.embed(query)
|
||||||
|
|
||||||
|
# Build query with vector similarity
|
||||||
|
stmt = (
|
||||||
|
select(UserMemory)
|
||||||
|
.where(UserMemory.user_id == uuid_user_id)
|
||||||
|
.where(
|
||||||
|
(UserMemory.search_space_id == search_space_id)
|
||||||
|
| (UserMemory.search_space_id.is_(None))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add category filter if specified
|
||||||
|
if category and category in [
|
||||||
|
"preference",
|
||||||
|
"fact",
|
||||||
|
"instruction",
|
||||||
|
"context",
|
||||||
|
]:
|
||||||
|
stmt = stmt.where(UserMemory.category == MemoryCategory(category))
|
||||||
|
|
||||||
|
# Order by vector similarity
|
||||||
|
stmt = stmt.order_by(
|
||||||
|
UserMemory.embedding.op("<=>")(query_embedding)
|
||||||
|
).limit(top_k)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# No query - return most recent memories
|
||||||
|
stmt = (
|
||||||
|
select(UserMemory)
|
||||||
|
.where(UserMemory.user_id == uuid_user_id)
|
||||||
|
.where(
|
||||||
|
(UserMemory.search_space_id == search_space_id)
|
||||||
|
| (UserMemory.search_space_id.is_(None))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add category filter if specified
|
||||||
|
if category and category in [
|
||||||
|
"preference",
|
||||||
|
"fact",
|
||||||
|
"instruction",
|
||||||
|
"context",
|
||||||
|
]:
|
||||||
|
stmt = stmt.where(UserMemory.category == MemoryCategory(category))
|
||||||
|
|
||||||
|
stmt = stmt.order_by(UserMemory.updated_at.desc()).limit(top_k)
|
||||||
|
|
||||||
|
result = await db_session.execute(stmt)
|
||||||
|
memories = result.scalars().all()
|
||||||
|
|
||||||
|
# Format memories for response
|
||||||
|
memory_list = [
|
||||||
|
{
|
||||||
|
"id": m.id,
|
||||||
|
"memory_text": m.memory_text,
|
||||||
|
"category": m.category.value if m.category else "unknown",
|
||||||
|
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
|
||||||
|
}
|
||||||
|
for m in memories
|
||||||
|
]
|
||||||
|
|
||||||
|
formatted_context = format_memories_for_context(memory_list)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"count": len(memory_list),
|
||||||
|
"memories": memory_list,
|
||||||
|
"formatted_context": formatted_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Failed to recall memories for user {user_id}: {e}")
|
||||||
|
await db_session.rollback()
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"memories": [],
|
||||||
|
"formatted_context": "Failed to recall memories.",
|
||||||
|
}
|
||||||
|
|
||||||
|
return recall_memory
|
||||||
|
|
@ -472,6 +472,66 @@ class ChatCommentMention(BaseModel, TimestampMixin):
|
||||||
mentioned_user = relationship("User")
|
mentioned_user = relationship("User")
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryCategory(str, Enum):
|
||||||
|
"""Categories for user memories."""
|
||||||
|
|
||||||
|
# Using lowercase keys to match PostgreSQL enum values
|
||||||
|
preference = "preference" # User preferences (e.g., "prefers dark mode")
|
||||||
|
fact = "fact" # Facts about the user (e.g., "is a Python developer")
|
||||||
|
instruction = (
|
||||||
|
"instruction" # Standing instructions (e.g., "always respond in bullet points")
|
||||||
|
)
|
||||||
|
context = "context" # Contextual information (e.g., "working on project X")
|
||||||
|
|
||||||
|
|
||||||
|
class UserMemory(BaseModel, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Stores facts, preferences, and context about users for personalized AI responses.
|
||||||
|
Similar to Claude's memory feature - enables the AI to remember user information
|
||||||
|
across conversations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "user_memories"
|
||||||
|
|
||||||
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("user.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
# Optional association with a search space (if memory is space-specific)
|
||||||
|
search_space_id = Column(
|
||||||
|
Integer,
|
||||||
|
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The actual memory content
|
||||||
|
memory_text = Column(Text, nullable=False)
|
||||||
|
# Category for organization and filtering
|
||||||
|
category = Column(
|
||||||
|
SQLAlchemyEnum(MemoryCategory),
|
||||||
|
nullable=False,
|
||||||
|
default=MemoryCategory.fact,
|
||||||
|
)
|
||||||
|
# Vector embedding for semantic search
|
||||||
|
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||||
|
|
||||||
|
# Track when memory was last updated
|
||||||
|
updated_at = Column(
|
||||||
|
TIMESTAMP(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
default=lambda: datetime.now(UTC),
|
||||||
|
onupdate=lambda: datetime.now(UTC),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
user = relationship("User", back_populates="memories")
|
||||||
|
search_space = relationship("SearchSpace", back_populates="user_memories")
|
||||||
|
|
||||||
|
|
||||||
class Document(BaseModel, TimestampMixin):
|
class Document(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "documents"
|
__tablename__ = "documents"
|
||||||
|
|
||||||
|
|
@ -659,6 +719,14 @@ class SearchSpace(BaseModel, TimestampMixin):
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# User memories associated with this search space
|
||||||
|
user_memories = relationship(
|
||||||
|
"UserMemory",
|
||||||
|
back_populates="search_space",
|
||||||
|
order_by="UserMemory.updated_at.desc()",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "search_source_connectors"
|
__tablename__ = "search_source_connectors"
|
||||||
|
|
@ -967,6 +1035,14 @@ if config.AUTH_TYPE == "GOOGLE":
|
||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# User memories for personalized AI responses
|
||||||
|
memories = relationship(
|
||||||
|
"UserMemory",
|
||||||
|
back_populates="user",
|
||||||
|
order_by="UserMemory.updated_at.desc()",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
# Page usage tracking for ETL services
|
# Page usage tracking for ETL services
|
||||||
pages_limit = Column(
|
pages_limit = Column(
|
||||||
Integer,
|
Integer,
|
||||||
|
|
@ -1010,6 +1086,14 @@ else:
|
||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# User memories for personalized AI responses
|
||||||
|
memories = relationship(
|
||||||
|
"UserMemory",
|
||||||
|
back_populates="user",
|
||||||
|
order_by="UserMemory.updated_at.desc()",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
# Page usage tracking for ETL services
|
# Page usage tracking for ETL services
|
||||||
pages_limit = Column(
|
pages_limit = Column(
|
||||||
Integer,
|
Integer,
|
||||||
|
|
|
||||||
|
|
@ -990,6 +990,7 @@ async def handle_new_chat(
|
||||||
search_space_id=request.search_space_id,
|
search_space_id=request.search_space_id,
|
||||||
chat_id=request.chat_id,
|
chat_id=request.chat_id,
|
||||||
session=session,
|
session=session,
|
||||||
|
user_id=str(user.id), # Pass user ID for memory tools
|
||||||
llm_config_id=llm_config_id,
|
llm_config_id=llm_config_id,
|
||||||
attachments=request.attachments,
|
attachments=request.attachments,
|
||||||
mentioned_document_ids=request.mentioned_document_ids,
|
mentioned_document_ids=request.mentioned_document_ids,
|
||||||
|
|
|
||||||
|
|
@ -149,6 +149,7 @@ async def stream_new_chat(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
|
user_id: str | None = None,
|
||||||
llm_config_id: int = -1,
|
llm_config_id: int = -1,
|
||||||
attachments: list[ChatAttachment] | None = None,
|
attachments: list[ChatAttachment] | None = None,
|
||||||
mentioned_document_ids: list[int] | None = None,
|
mentioned_document_ids: list[int] | None = None,
|
||||||
|
|
@ -166,6 +167,7 @@ async def stream_new_chat(
|
||||||
search_space_id: The search space ID
|
search_space_id: The search space ID
|
||||||
chat_id: The chat ID (used as LangGraph thread_id for memory)
|
chat_id: The chat ID (used as LangGraph thread_id for memory)
|
||||||
session: The database session
|
session: The database session
|
||||||
|
user_id: The current user's UUID string (for memory tools)
|
||||||
llm_config_id: The LLM configuration ID (default: -1 for first global config)
|
llm_config_id: The LLM configuration ID (default: -1 for first global config)
|
||||||
messages: Optional chat history from frontend (list of ChatMessage)
|
messages: Optional chat history from frontend (list of ChatMessage)
|
||||||
attachments: Optional attachments with extracted content
|
attachments: Optional attachments with extracted content
|
||||||
|
|
@ -243,6 +245,7 @@ async def stream_new_chat(
|
||||||
db_session=session,
|
db_session=session,
|
||||||
connector_service=connector_service,
|
connector_service=connector_service,
|
||||||
checkpointer=checkpointer,
|
checkpointer=checkpointer,
|
||||||
|
user_id=user_id, # Pass user ID for memory tools
|
||||||
agent_config=agent_config, # Pass prompt configuration
|
agent_config=agent_config, # Pass prompt configuration
|
||||||
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured
|
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ import { DisplayImageToolUI } from "@/components/tool-ui/display-image";
|
||||||
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
|
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
|
||||||
import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview";
|
import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview";
|
||||||
import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
|
import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
|
||||||
|
import { SaveMemoryToolUI, RecallMemoryToolUI } from "@/components/tool-ui/user-memory";
|
||||||
// import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
|
// import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
|
||||||
import { getBearerToken } from "@/lib/auth-utils";
|
import { getBearerToken } from "@/lib/auth-utils";
|
||||||
import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter";
|
import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter";
|
||||||
|
|
@ -1056,6 +1057,8 @@ export default function NewChatPage() {
|
||||||
<LinkPreviewToolUI />
|
<LinkPreviewToolUI />
|
||||||
<DisplayImageToolUI />
|
<DisplayImageToolUI />
|
||||||
<ScrapeWebpageToolUI />
|
<ScrapeWebpageToolUI />
|
||||||
|
<SaveMemoryToolUI />
|
||||||
|
<RecallMemoryToolUI />
|
||||||
{/* <WriteTodosToolUI /> Disabled for now */}
|
{/* <WriteTodosToolUI /> Disabled for now */}
|
||||||
<div className="flex flex-col h-[calc(100vh-64px)] overflow-hidden">
|
<div className="flex flex-col h-[calc(100vh-64px)] overflow-hidden">
|
||||||
<Thread
|
<Thread
|
||||||
|
|
|
||||||
|
|
@ -77,4 +77,17 @@ export {
|
||||||
ScrapeWebpageResultSchema,
|
ScrapeWebpageResultSchema,
|
||||||
ScrapeWebpageToolUI,
|
ScrapeWebpageToolUI,
|
||||||
} from "./scrape-webpage";
|
} from "./scrape-webpage";
|
||||||
|
export {
|
||||||
|
type MemoryItem,
|
||||||
|
type RecallMemoryArgs,
|
||||||
|
RecallMemoryArgsSchema,
|
||||||
|
type RecallMemoryResult,
|
||||||
|
RecallMemoryResultSchema,
|
||||||
|
RecallMemoryToolUI,
|
||||||
|
type SaveMemoryArgs,
|
||||||
|
SaveMemoryArgsSchema,
|
||||||
|
type SaveMemoryResult,
|
||||||
|
SaveMemoryResultSchema,
|
||||||
|
SaveMemoryToolUI,
|
||||||
|
} from "./user-memory";
|
||||||
export { type WriteTodosData, WriteTodosSchema, WriteTodosToolUI } from "./write-todos";
|
export { type WriteTodosData, WriteTodosSchema, WriteTodosToolUI } from "./write-todos";
|
||||||
|
|
|
||||||
283
surfsense_web/components/tool-ui/user-memory.tsx
Normal file
283
surfsense_web/components/tool-ui/user-memory.tsx
Normal file
|
|
@ -0,0 +1,283 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { makeAssistantToolUI } from "@assistant-ui/react";
|
||||||
|
import { BrainIcon, CheckIcon, Loader2Icon, SearchIcon, XIcon } from "lucide-react";
|
||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Zod Schemas for save_memory tool
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
const SaveMemoryArgsSchema = z.object({
|
||||||
|
content: z.string(),
|
||||||
|
category: z.string().default("fact"),
|
||||||
|
});
|
||||||
|
|
||||||
|
const SaveMemoryResultSchema = z.object({
|
||||||
|
status: z.enum(["saved", "error"]),
|
||||||
|
memory_id: z.number().nullish(),
|
||||||
|
memory_text: z.string().nullish(),
|
||||||
|
category: z.string().nullish(),
|
||||||
|
message: z.string().nullish(),
|
||||||
|
error: z.string().nullish(),
|
||||||
|
});
|
||||||
|
|
||||||
|
type SaveMemoryArgs = z.infer<typeof SaveMemoryArgsSchema>;
|
||||||
|
type SaveMemoryResult = z.infer<typeof SaveMemoryResultSchema>;
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Zod Schemas for recall_memory tool
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
const RecallMemoryArgsSchema = z.object({
|
||||||
|
query: z.string().nullish(),
|
||||||
|
category: z.string().nullish(),
|
||||||
|
top_k: z.number().default(5),
|
||||||
|
});
|
||||||
|
|
||||||
|
const MemoryItemSchema = z.object({
|
||||||
|
id: z.number(),
|
||||||
|
memory_text: z.string(),
|
||||||
|
category: z.string(),
|
||||||
|
updated_at: z.string().nullish(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const RecallMemoryResultSchema = z.object({
|
||||||
|
status: z.enum(["success", "error"]),
|
||||||
|
count: z.number().nullish(),
|
||||||
|
memories: z.array(MemoryItemSchema).nullish(),
|
||||||
|
formatted_context: z.string().nullish(),
|
||||||
|
error: z.string().nullish(),
|
||||||
|
});
|
||||||
|
|
||||||
|
type RecallMemoryArgs = z.infer<typeof RecallMemoryArgsSchema>;
|
||||||
|
type RecallMemoryResult = z.infer<typeof RecallMemoryResultSchema>;
|
||||||
|
type MemoryItem = z.infer<typeof MemoryItemSchema>;
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Category badge colors
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
const categoryColors: Record<string, string> = {
|
||||||
|
preference: "bg-blue-500/10 text-blue-600 dark:text-blue-400",
|
||||||
|
fact: "bg-green-500/10 text-green-600 dark:text-green-400",
|
||||||
|
instruction: "bg-purple-500/10 text-purple-600 dark:text-purple-400",
|
||||||
|
context: "bg-orange-500/10 text-orange-600 dark:text-orange-400",
|
||||||
|
};
|
||||||
|
|
||||||
|
function CategoryBadge({ category }: { category: string }) {
|
||||||
|
const colorClass = categoryColors[category] || "bg-gray-500/10 text-gray-600 dark:text-gray-400";
|
||||||
|
return (
|
||||||
|
<span
|
||||||
|
className={`inline-flex items-center rounded-full px-2 py-0.5 text-xs font-medium ${colorClass}`}
|
||||||
|
>
|
||||||
|
{category}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Save Memory Tool UI
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
export const SaveMemoryToolUI = makeAssistantToolUI<SaveMemoryArgs, SaveMemoryResult>({
|
||||||
|
toolName: "save_memory",
|
||||||
|
render: function SaveMemoryUI({ args, result, status }) {
|
||||||
|
const isRunning = status.type === "running" || status.type === "requires-action";
|
||||||
|
const isComplete = status.type === "complete";
|
||||||
|
const isError = result?.status === "error";
|
||||||
|
|
||||||
|
// Parse args safely
|
||||||
|
const parsedArgs = SaveMemoryArgsSchema.safeParse(args);
|
||||||
|
const content = parsedArgs.success ? parsedArgs.data.content : "";
|
||||||
|
const category = parsedArgs.success ? parsedArgs.data.category : "fact";
|
||||||
|
|
||||||
|
// Loading state
|
||||||
|
if (isRunning) {
|
||||||
|
return (
|
||||||
|
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||||
|
<div className="flex size-8 items-center justify-center rounded-full bg-primary/10">
|
||||||
|
<Loader2Icon className="size-4 animate-spin text-primary" />
|
||||||
|
</div>
|
||||||
|
<div className="flex-1">
|
||||||
|
<span className="text-sm text-muted-foreground">Saving to memory...</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error state
|
||||||
|
if (isError) {
|
||||||
|
return (
|
||||||
|
<div className="my-3 flex items-center gap-3 rounded-lg border border-destructive/20 bg-destructive/5 px-4 py-3">
|
||||||
|
<div className="flex size-8 items-center justify-center rounded-full bg-destructive/10">
|
||||||
|
<XIcon className="size-4 text-destructive" />
|
||||||
|
</div>
|
||||||
|
<div className="flex-1">
|
||||||
|
<span className="text-sm text-destructive">Failed to save memory</span>
|
||||||
|
{result?.error && <p className="mt-1 text-xs text-destructive/70">{result.error}</p>}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success state
|
||||||
|
if (isComplete && result?.status === "saved") {
|
||||||
|
return (
|
||||||
|
<div className="my-3 flex items-center gap-3 rounded-lg border border-primary/20 bg-primary/5 px-4 py-3">
|
||||||
|
<div className="flex size-8 items-center justify-center rounded-full bg-primary/10">
|
||||||
|
<BrainIcon className="size-4 text-primary" />
|
||||||
|
</div>
|
||||||
|
<div className="flex-1 min-w-0">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<CheckIcon className="size-3 text-green-500 shrink-0" />
|
||||||
|
<span className="text-sm font-medium text-foreground">Memory saved</span>
|
||||||
|
<CategoryBadge category={category} />
|
||||||
|
</div>
|
||||||
|
<p className="mt-1 truncate text-sm text-muted-foreground">{content}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default/incomplete state - show what's being saved
|
||||||
|
if (content) {
|
||||||
|
return (
|
||||||
|
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||||
|
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
||||||
|
<BrainIcon className="size-4 text-muted-foreground" />
|
||||||
|
</div>
|
||||||
|
<div className="flex-1 min-w-0">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="text-sm text-muted-foreground">Saving memory</span>
|
||||||
|
<CategoryBadge category={category} />
|
||||||
|
</div>
|
||||||
|
<p className="mt-1 truncate text-sm text-muted-foreground">{content}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Recall Memory Tool UI
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
export const RecallMemoryToolUI = makeAssistantToolUI<RecallMemoryArgs, RecallMemoryResult>({
|
||||||
|
toolName: "recall_memory",
|
||||||
|
render: function RecallMemoryUI({ args, result, status }) {
|
||||||
|
const isRunning = status.type === "running" || status.type === "requires-action";
|
||||||
|
const isComplete = status.type === "complete";
|
||||||
|
const isError = result?.status === "error";
|
||||||
|
|
||||||
|
// Parse args safely
|
||||||
|
const parsedArgs = RecallMemoryArgsSchema.safeParse(args);
|
||||||
|
const query = parsedArgs.success ? parsedArgs.data.query : null;
|
||||||
|
|
||||||
|
// Loading state
|
||||||
|
if (isRunning) {
|
||||||
|
return (
|
||||||
|
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||||
|
<div className="flex size-8 items-center justify-center rounded-full bg-primary/10">
|
||||||
|
<Loader2Icon className="size-4 animate-spin text-primary" />
|
||||||
|
</div>
|
||||||
|
<div className="flex-1">
|
||||||
|
<span className="text-sm text-muted-foreground">
|
||||||
|
{query ? `Searching memories for "${query}"...` : "Recalling memories..."}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error state
|
||||||
|
if (isError) {
|
||||||
|
return (
|
||||||
|
<div className="my-3 flex items-center gap-3 rounded-lg border border-destructive/20 bg-destructive/5 px-4 py-3">
|
||||||
|
<div className="flex size-8 items-center justify-center rounded-full bg-destructive/10">
|
||||||
|
<XIcon className="size-4 text-destructive" />
|
||||||
|
</div>
|
||||||
|
<div className="flex-1">
|
||||||
|
<span className="text-sm text-destructive">Failed to recall memories</span>
|
||||||
|
{result?.error && <p className="mt-1 text-xs text-destructive/70">{result.error}</p>}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success state with memories
|
||||||
|
if (isComplete && result?.status === "success") {
|
||||||
|
const memories = result.memories || [];
|
||||||
|
const count = result.count || 0;
|
||||||
|
|
||||||
|
if (count === 0) {
|
||||||
|
return (
|
||||||
|
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||||
|
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
||||||
|
<SearchIcon className="size-4 text-muted-foreground" />
|
||||||
|
</div>
|
||||||
|
<span className="text-sm text-muted-foreground">No memories found</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="my-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||||
|
<div className="flex items-center gap-2 mb-2">
|
||||||
|
<BrainIcon className="size-4 text-primary" />
|
||||||
|
<span className="text-sm font-medium text-foreground">
|
||||||
|
Recalled {count} {count === 1 ? "memory" : "memories"}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-2">
|
||||||
|
{memories.slice(0, 5).map((memory: MemoryItem) => (
|
||||||
|
<div
|
||||||
|
key={memory.id}
|
||||||
|
className="flex items-start gap-2 rounded-md bg-muted/50 px-3 py-2"
|
||||||
|
>
|
||||||
|
<CategoryBadge category={memory.category} />
|
||||||
|
<span className="text-sm text-muted-foreground flex-1">{memory.memory_text}</span>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
{memories.length > 5 && (
|
||||||
|
<p className="text-xs text-muted-foreground">...and {memories.length - 5} more</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default/incomplete state
|
||||||
|
if (query) {
|
||||||
|
return (
|
||||||
|
<div className="my-3 flex items-center gap-3 rounded-lg border bg-card/60 px-4 py-3">
|
||||||
|
<div className="flex size-8 items-center justify-center rounded-full bg-muted">
|
||||||
|
<SearchIcon className="size-4 text-muted-foreground" />
|
||||||
|
</div>
|
||||||
|
<span className="text-sm text-muted-foreground">Searching memories for "{query}"</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Exports
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
export {
|
||||||
|
SaveMemoryArgsSchema,
|
||||||
|
SaveMemoryResultSchema,
|
||||||
|
RecallMemoryArgsSchema,
|
||||||
|
RecallMemoryResultSchema,
|
||||||
|
type SaveMemoryArgs,
|
||||||
|
type SaveMemoryResult,
|
||||||
|
type RecallMemoryArgs,
|
||||||
|
type RecallMemoryResult,
|
||||||
|
type MemoryItem,
|
||||||
|
};
|
||||||
Loading…
Add table
Add a link
Reference in a new issue