mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-21 18:55:16 +02:00
error
This commit is contained in:
parent
48fb38bafc
commit
92aa3f4eab
3 changed files with 72 additions and 20 deletions
|
|
@ -158,46 +158,79 @@ def create_save_memory_tool(
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary with the save status and memory details
|
A dictionary with the save status and memory details
|
||||||
"""
|
"""
|
||||||
# Validate category
|
# Log at the very start
|
||||||
|
logger.info(f">>> SAVE_MEMORY TOOL CALLED: content='{content}', category='{category}'")
|
||||||
|
print(f">>> SAVE_MEMORY TOOL CALLED: content='{content}', category='{category}'")
|
||||||
|
|
||||||
|
# Normalize and validate category (LLMs may send uppercase)
|
||||||
|
category = category.lower() if category else "fact"
|
||||||
valid_categories = ["preference", "fact", "instruction", "context"]
|
valid_categories = ["preference", "fact", "instruction", "context"]
|
||||||
if category not in valid_categories:
|
if category not in valid_categories:
|
||||||
category = "fact"
|
category = "fact"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(f"save_memory called: user_id={user_id}, search_space_id={search_space_id}, content={content[:50]}...")
|
||||||
|
|
||||||
# Convert user_id to UUID
|
# Convert user_id to UUID
|
||||||
uuid_user_id = _to_uuid(user_id)
|
uuid_user_id = _to_uuid(user_id)
|
||||||
|
logger.info(f"UUID conversion successful: {uuid_user_id}")
|
||||||
|
|
||||||
# Check if we've hit the memory limit
|
# Check if we've hit the memory limit
|
||||||
memory_count = await get_user_memory_count(
|
memory_count = await get_user_memory_count(
|
||||||
db_session, user_id, search_space_id
|
db_session, user_id, search_space_id
|
||||||
)
|
)
|
||||||
|
logger.info(f"Current memory count: {memory_count}")
|
||||||
|
|
||||||
if memory_count >= MAX_MEMORIES_PER_USER:
|
if memory_count >= MAX_MEMORIES_PER_USER:
|
||||||
# Delete oldest memory to make room
|
# Delete oldest memory to make room
|
||||||
await delete_oldest_memory(db_session, user_id, search_space_id)
|
await delete_oldest_memory(db_session, user_id, search_space_id)
|
||||||
|
|
||||||
# Generate embedding for the memory
|
# Generate embedding for the memory
|
||||||
|
logger.info("Generating embedding...")
|
||||||
embedding = config.embedding_model_instance.embed(content)
|
embedding = config.embedding_model_instance.embed(content)
|
||||||
|
logger.info(f"Embedding generated, type: {type(embedding)}, len: {len(embedding) if hasattr(embedding, '__len__') else 'N/A'}")
|
||||||
|
|
||||||
# Map string category to enum
|
# Convert numpy array to list of Python floats for PostgreSQL
|
||||||
category_enum = MemoryCategory(category)
|
import numpy as np
|
||||||
|
if isinstance(embedding, np.ndarray):
|
||||||
|
embedding_list = embedding.tolist()
|
||||||
|
else:
|
||||||
|
embedding_list = list(embedding)
|
||||||
|
|
||||||
# Create new memory
|
# Create new memory using ORM with proper enum handling
|
||||||
new_memory = UserMemory(
|
# Use the enum's value attribute directly
|
||||||
user_id=uuid_user_id,
|
from sqlalchemy import text as sql_text
|
||||||
search_space_id=search_space_id,
|
|
||||||
memory_text=content,
|
now = datetime.now(UTC)
|
||||||
category=category_enum,
|
|
||||||
embedding=embedding,
|
# Use raw SQL with proper parameter binding for asyncpg
|
||||||
updated_at=datetime.now(UTC),
|
insert_sql = sql_text("""
|
||||||
|
INSERT INTO user_memories (user_id, search_space_id, memory_text, category, embedding, updated_at, created_at)
|
||||||
|
VALUES (:user_id, :search_space_id, :memory_text, CAST(:category AS memorycategory), :embedding, :updated_at, :created_at)
|
||||||
|
RETURNING id
|
||||||
|
""")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
insert_sql,
|
||||||
|
{
|
||||||
|
"user_id": uuid_user_id,
|
||||||
|
"search_space_id": search_space_id,
|
||||||
|
"memory_text": content,
|
||||||
|
"category": category, # Already lowercase string
|
||||||
|
"embedding": str(embedding_list), # Convert to string format for pgvector
|
||||||
|
"updated_at": now,
|
||||||
|
"created_at": now,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
new_memory_id = result.scalar_one()
|
||||||
db_session.add(new_memory)
|
|
||||||
|
logger.info("Committing...")
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
await db_session.refresh(new_memory)
|
logger.info(f"Memory saved successfully with id: {new_memory_id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "saved",
|
"status": "saved",
|
||||||
"memory_id": new_memory.id,
|
"memory_id": new_memory_id,
|
||||||
"memory_text": content,
|
"memory_text": content,
|
||||||
"category": category,
|
"category": category,
|
||||||
"message": f"I'll remember: {content}",
|
"message": f"I'll remember: {content}",
|
||||||
|
|
@ -205,6 +238,8 @@ def create_save_memory_tool(
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Failed to save memory for user {user_id}: {e}")
|
logger.exception(f"Failed to save memory for user {user_id}: {e}")
|
||||||
|
# Rollback the session to clear any failed transaction state
|
||||||
|
await db_session.rollback()
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -260,10 +295,15 @@ def create_recall_memory_tool(
|
||||||
A dictionary containing relevant memories and formatted context
|
A dictionary containing relevant memories and formatted context
|
||||||
"""
|
"""
|
||||||
top_k = min(max(top_k, 1), 20) # Clamp between 1 and 20
|
top_k = min(max(top_k, 1), 20) # Clamp between 1 and 20
|
||||||
|
|
||||||
|
# Log at the very start
|
||||||
|
logger.info(f">>> RECALL_MEMORY TOOL CALLED: query='{query}', category='{category}', top_k={top_k}")
|
||||||
|
print(f">>> RECALL_MEMORY TOOL CALLED: query='{query}', category='{category}', top_k={top_k}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Convert user_id to UUID
|
# Convert user_id to UUID
|
||||||
uuid_user_id = _to_uuid(user_id)
|
uuid_user_id = _to_uuid(user_id)
|
||||||
|
logger.info(f"Recall memory for user: {uuid_user_id}, search_space: {search_space_id}")
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
# Semantic search using embeddings
|
# Semantic search using embeddings
|
||||||
|
|
@ -307,6 +347,8 @@ def create_recall_memory_tool(
|
||||||
|
|
||||||
result = await db_session.execute(stmt)
|
result = await db_session.execute(stmt)
|
||||||
memories = result.scalars().all()
|
memories = result.scalars().all()
|
||||||
|
|
||||||
|
logger.info(f"Found {len(memories)} memories")
|
||||||
|
|
||||||
# Format memories for response
|
# Format memories for response
|
||||||
memory_list = [
|
memory_list = [
|
||||||
|
|
@ -318,8 +360,12 @@ def create_recall_memory_tool(
|
||||||
}
|
}
|
||||||
for m in memories
|
for m in memories
|
||||||
]
|
]
|
||||||
|
|
||||||
|
logger.info(f"Formatted memory list: {memory_list}")
|
||||||
|
|
||||||
formatted_context = format_memories_for_context(memory_list)
|
formatted_context = format_memories_for_context(memory_list)
|
||||||
|
|
||||||
|
logger.info(f"Returning {len(memory_list)} memories")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
|
|
@ -329,6 +375,8 @@ def create_recall_memory_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.exception(f"Failed to recall memories for user {user_id}: {e}")
|
||||||
|
await db_session.rollback()
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
|
||||||
|
|
@ -475,10 +475,11 @@ class ChatCommentMention(BaseModel, TimestampMixin):
|
||||||
class MemoryCategory(str, Enum):
|
class MemoryCategory(str, Enum):
|
||||||
"""Categories for user memories."""
|
"""Categories for user memories."""
|
||||||
|
|
||||||
PREFERENCE = "preference" # User preferences (e.g., "prefers dark mode")
|
# Using lowercase keys to match PostgreSQL enum values
|
||||||
FACT = "fact" # Facts about the user (e.g., "is a Python developer")
|
preference = "preference" # User preferences (e.g., "prefers dark mode")
|
||||||
INSTRUCTION = "instruction" # Standing instructions (e.g., "always respond in bullet points")
|
fact = "fact" # Facts about the user (e.g., "is a Python developer")
|
||||||
CONTEXT = "context" # Contextual information (e.g., "working on project X")
|
instruction = "instruction" # Standing instructions (e.g., "always respond in bullet points")
|
||||||
|
context = "context" # Contextual information (e.g., "working on project X")
|
||||||
|
|
||||||
|
|
||||||
class UserMemory(BaseModel, TimestampMixin):
|
class UserMemory(BaseModel, TimestampMixin):
|
||||||
|
|
@ -510,7 +511,7 @@ class UserMemory(BaseModel, TimestampMixin):
|
||||||
category = Column(
|
category = Column(
|
||||||
SQLAlchemyEnum(MemoryCategory),
|
SQLAlchemyEnum(MemoryCategory),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=MemoryCategory.FACT,
|
default=MemoryCategory.fact,
|
||||||
)
|
)
|
||||||
# Vector embedding for semantic search
|
# Vector embedding for semantic search
|
||||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ import { DisplayImageToolUI } from "@/components/tool-ui/display-image";
|
||||||
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
|
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
|
||||||
import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview";
|
import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview";
|
||||||
import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
|
import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
|
||||||
|
import { SaveMemoryToolUI, RecallMemoryToolUI } from "@/components/tool-ui/user-memory";
|
||||||
// import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
|
// import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
|
||||||
import { getBearerToken } from "@/lib/auth-utils";
|
import { getBearerToken } from "@/lib/auth-utils";
|
||||||
import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter";
|
import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter";
|
||||||
|
|
@ -1056,6 +1057,8 @@ export default function NewChatPage() {
|
||||||
<LinkPreviewToolUI />
|
<LinkPreviewToolUI />
|
||||||
<DisplayImageToolUI />
|
<DisplayImageToolUI />
|
||||||
<ScrapeWebpageToolUI />
|
<ScrapeWebpageToolUI />
|
||||||
|
<SaveMemoryToolUI />
|
||||||
|
<RecallMemoryToolUI />
|
||||||
{/* <WriteTodosToolUI /> Disabled for now */}
|
{/* <WriteTodosToolUI /> Disabled for now */}
|
||||||
<div className="flex flex-col h-[calc(100vh-64px)] overflow-hidden">
|
<div className="flex flex-col h-[calc(100vh-64px)] overflow-hidden">
|
||||||
<Thread
|
<Thread
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue