This commit is contained in:
Manoj Aggarwal 2026-01-20 15:34:01 -08:00
parent 48fb38bafc
commit 92aa3f4eab
3 changed files with 72 additions and 20 deletions

View file

@ -158,46 +158,79 @@ def create_save_memory_tool(
Returns: Returns:
A dictionary with the save status and memory details A dictionary with the save status and memory details
""" """
# Validate category # Log at the very start
logger.info(f">>> SAVE_MEMORY TOOL CALLED: content='{content}', category='{category}'")
print(f">>> SAVE_MEMORY TOOL CALLED: content='{content}', category='{category}'")
# Normalize and validate category (LLMs may send uppercase)
category = category.lower() if category else "fact"
valid_categories = ["preference", "fact", "instruction", "context"] valid_categories = ["preference", "fact", "instruction", "context"]
if category not in valid_categories: if category not in valid_categories:
category = "fact" category = "fact"
try: try:
logger.info(f"save_memory called: user_id={user_id}, search_space_id={search_space_id}, content={content[:50]}...")
# Convert user_id to UUID # Convert user_id to UUID
uuid_user_id = _to_uuid(user_id) uuid_user_id = _to_uuid(user_id)
logger.info(f"UUID conversion successful: {uuid_user_id}")
# Check if we've hit the memory limit # Check if we've hit the memory limit
memory_count = await get_user_memory_count( memory_count = await get_user_memory_count(
db_session, user_id, search_space_id db_session, user_id, search_space_id
) )
logger.info(f"Current memory count: {memory_count}")
if memory_count >= MAX_MEMORIES_PER_USER: if memory_count >= MAX_MEMORIES_PER_USER:
# Delete oldest memory to make room # Delete oldest memory to make room
await delete_oldest_memory(db_session, user_id, search_space_id) await delete_oldest_memory(db_session, user_id, search_space_id)
# Generate embedding for the memory # Generate embedding for the memory
logger.info("Generating embedding...")
embedding = config.embedding_model_instance.embed(content) embedding = config.embedding_model_instance.embed(content)
logger.info(f"Embedding generated, type: {type(embedding)}, len: {len(embedding) if hasattr(embedding, '__len__') else 'N/A'}")
# Map string category to enum # Convert numpy array to list of Python floats for PostgreSQL
category_enum = MemoryCategory(category) import numpy as np
if isinstance(embedding, np.ndarray):
embedding_list = embedding.tolist()
else:
embedding_list = list(embedding)
# Create new memory # Create new memory using ORM with proper enum handling
new_memory = UserMemory( # Use the enum's value attribute directly
user_id=uuid_user_id, from sqlalchemy import text as sql_text
search_space_id=search_space_id,
memory_text=content, now = datetime.now(UTC)
category=category_enum,
embedding=embedding, # Use raw SQL with proper parameter binding for asyncpg
updated_at=datetime.now(UTC), insert_sql = sql_text("""
INSERT INTO user_memories (user_id, search_space_id, memory_text, category, embedding, updated_at, created_at)
VALUES (:user_id, :search_space_id, :memory_text, CAST(:category AS memorycategory), :embedding, :updated_at, :created_at)
RETURNING id
""")
result = await db_session.execute(
insert_sql,
{
"user_id": uuid_user_id,
"search_space_id": search_space_id,
"memory_text": content,
"category": category, # Already lowercase string
"embedding": str(embedding_list), # Convert to string format for pgvector
"updated_at": now,
"created_at": now,
}
) )
new_memory_id = result.scalar_one()
db_session.add(new_memory) logger.info("Committing...")
await db_session.commit() await db_session.commit()
await db_session.refresh(new_memory) logger.info(f"Memory saved successfully with id: {new_memory_id}")
return { return {
"status": "saved", "status": "saved",
"memory_id": new_memory.id, "memory_id": new_memory_id,
"memory_text": content, "memory_text": content,
"category": category, "category": category,
"message": f"I'll remember: {content}", "message": f"I'll remember: {content}",
@ -205,6 +238,8 @@ def create_save_memory_tool(
except Exception as e: except Exception as e:
logger.exception(f"Failed to save memory for user {user_id}: {e}") logger.exception(f"Failed to save memory for user {user_id}: {e}")
# Rollback the session to clear any failed transaction state
await db_session.rollback()
return { return {
"status": "error", "status": "error",
"error": str(e), "error": str(e),
@ -261,9 +296,14 @@ def create_recall_memory_tool(
""" """
top_k = min(max(top_k, 1), 20) # Clamp between 1 and 20 top_k = min(max(top_k, 1), 20) # Clamp between 1 and 20
# Log at the very start
logger.info(f">>> RECALL_MEMORY TOOL CALLED: query='{query}', category='{category}', top_k={top_k}")
print(f">>> RECALL_MEMORY TOOL CALLED: query='{query}', category='{category}', top_k={top_k}")
try: try:
# Convert user_id to UUID # Convert user_id to UUID
uuid_user_id = _to_uuid(user_id) uuid_user_id = _to_uuid(user_id)
logger.info(f"Recall memory for user: {uuid_user_id}, search_space: {search_space_id}")
if query: if query:
# Semantic search using embeddings # Semantic search using embeddings
@ -308,6 +348,8 @@ def create_recall_memory_tool(
result = await db_session.execute(stmt) result = await db_session.execute(stmt)
memories = result.scalars().all() memories = result.scalars().all()
logger.info(f"Found {len(memories)} memories")
# Format memories for response # Format memories for response
memory_list = [ memory_list = [
{ {
@ -319,8 +361,12 @@ def create_recall_memory_tool(
for m in memories for m in memories
] ]
logger.info(f"Formatted memory list: {memory_list}")
formatted_context = format_memories_for_context(memory_list) formatted_context = format_memories_for_context(memory_list)
logger.info(f"Returning {len(memory_list)} memories")
return { return {
"status": "success", "status": "success",
"count": len(memory_list), "count": len(memory_list),
@ -329,6 +375,8 @@ def create_recall_memory_tool(
} }
except Exception as e: except Exception as e:
logger.exception(f"Failed to recall memories for user {user_id}: {e}")
await db_session.rollback()
return { return {
"status": "error", "status": "error",
"error": str(e), "error": str(e),

View file

@ -475,10 +475,11 @@ class ChatCommentMention(BaseModel, TimestampMixin):
class MemoryCategory(str, Enum): class MemoryCategory(str, Enum):
"""Categories for user memories.""" """Categories for user memories."""
PREFERENCE = "preference" # User preferences (e.g., "prefers dark mode") # Using lowercase keys to match PostgreSQL enum values
FACT = "fact" # Facts about the user (e.g., "is a Python developer") preference = "preference" # User preferences (e.g., "prefers dark mode")
INSTRUCTION = "instruction" # Standing instructions (e.g., "always respond in bullet points") fact = "fact" # Facts about the user (e.g., "is a Python developer")
CONTEXT = "context" # Contextual information (e.g., "working on project X") instruction = "instruction" # Standing instructions (e.g., "always respond in bullet points")
context = "context" # Contextual information (e.g., "working on project X")
class UserMemory(BaseModel, TimestampMixin): class UserMemory(BaseModel, TimestampMixin):
@ -510,7 +511,7 @@ class UserMemory(BaseModel, TimestampMixin):
category = Column( category = Column(
SQLAlchemyEnum(MemoryCategory), SQLAlchemyEnum(MemoryCategory),
nullable=False, nullable=False,
default=MemoryCategory.FACT, default=MemoryCategory.fact,
) )
# Vector embedding for semantic search # Vector embedding for semantic search
embedding = Column(Vector(config.embedding_model_instance.dimension)) embedding = Column(Vector(config.embedding_model_instance.dimension))

View file

@ -32,6 +32,7 @@ import { DisplayImageToolUI } from "@/components/tool-ui/display-image";
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast"; import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview"; import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview";
import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage"; import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
import { SaveMemoryToolUI, RecallMemoryToolUI } from "@/components/tool-ui/user-memory";
// import { WriteTodosToolUI } from "@/components/tool-ui/write-todos"; // import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
import { getBearerToken } from "@/lib/auth-utils"; import { getBearerToken } from "@/lib/auth-utils";
import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter"; import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter";
@ -1056,6 +1057,8 @@ export default function NewChatPage() {
<LinkPreviewToolUI /> <LinkPreviewToolUI />
<DisplayImageToolUI /> <DisplayImageToolUI />
<ScrapeWebpageToolUI /> <ScrapeWebpageToolUI />
<SaveMemoryToolUI />
<RecallMemoryToolUI />
{/* <WriteTodosToolUI /> Disabled for now */} {/* <WriteTodosToolUI /> Disabled for now */}
<div className="flex flex-col h-[calc(100vh-64px)] overflow-hidden"> <div className="flex flex-col h-[calc(100vh-64px)] overflow-hidden">
<Thread <Thread