chore: ran linting

This commit is contained in:
Anish Sarkar 2026-04-09 18:10:34 +05:30
parent b8091114b5
commit f38ea77940
14 changed files with 137 additions and 111 deletions

View file

@ -60,11 +60,21 @@ def downgrade() -> None:
); );
""" """
) )
op.execute("CREATE INDEX IF NOT EXISTS ix_user_memories_user_id ON user_memories(user_id);") op.execute(
op.execute("CREATE INDEX IF NOT EXISTS ix_user_memories_search_space_id ON user_memories(search_space_id);") "CREATE INDEX IF NOT EXISTS ix_user_memories_user_id ON user_memories(user_id);"
op.execute("CREATE INDEX IF NOT EXISTS ix_user_memories_updated_at ON user_memories(updated_at);") )
op.execute("CREATE INDEX IF NOT EXISTS ix_user_memories_category ON user_memories(category);") op.execute(
op.execute("CREATE INDEX IF NOT EXISTS ix_user_memories_user_search_space ON user_memories(user_id, search_space_id);") "CREATE INDEX IF NOT EXISTS ix_user_memories_search_space_id ON user_memories(search_space_id);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_user_memories_updated_at ON user_memories(updated_at);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_user_memories_category ON user_memories(category);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_user_memories_user_search_space ON user_memories(user_id, search_space_id);"
)
op.execute( op.execute(
"CREATE INDEX IF NOT EXISTS user_memories_vector_index ON user_memories USING hnsw (embedding public.vector_cosine_ops);" "CREATE INDEX IF NOT EXISTS user_memories_vector_index ON user_memories USING hnsw (embedding public.vector_cosine_ops);"
) )
@ -83,9 +93,15 @@ def downgrade() -> None:
); );
""" """
) )
op.execute("CREATE INDEX IF NOT EXISTS ix_shared_memories_search_space_id ON shared_memories(search_space_id);") op.execute(
op.execute("CREATE INDEX IF NOT EXISTS ix_shared_memories_updated_at ON shared_memories(updated_at);") "CREATE INDEX IF NOT EXISTS ix_shared_memories_search_space_id ON shared_memories(search_space_id);"
op.execute("CREATE INDEX IF NOT EXISTS ix_shared_memories_created_by_id ON shared_memories(created_by_id);") )
op.execute(
"CREATE INDEX IF NOT EXISTS ix_shared_memories_updated_at ON shared_memories(updated_at);"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_shared_memories_created_by_id ON shared_memories(created_by_id);"
)
op.execute( op.execute(
"CREATE INDEX IF NOT EXISTS shared_memories_vector_index ON shared_memories USING hnsw (embedding public.vector_cosine_ops);" "CREATE INDEX IF NOT EXISTS shared_memories_vector_index ON shared_memories USING hnsw (embedding public.vector_cosine_ops);"
) )

View file

@ -106,9 +106,7 @@ async def _call_extraction_llm(
config={"tags": ["surfsense:internal", "memory-extraction"]}, config={"tags": ["surfsense:internal", "memory-extraction"]},
) )
text = ( text = (
response.content response.content if isinstance(response.content, str) else str(response.content)
if isinstance(response.content, str)
else str(response.content)
).strip() ).strip()
if text == "NO_UPDATE" or not text: if text == "NO_UPDATE" or not text:
@ -155,9 +153,7 @@ async def _extract_user_memory(
uid = UUID(user_id) if isinstance(user_id, str) else user_id uid = UUID(user_id) if isinstance(user_id, str) else user_id
async with shielded_async_session() as session: async with shielded_async_session() as session:
result = await session.execute( result = await session.execute(select(User).where(User.id == uid))
select(User).where(User.id == uid)
)
user = result.scalars().first() user = result.scalars().first()
if not user: if not user:
return return

View file

@ -91,9 +91,7 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
return {"messages": new_messages} return {"messages": new_messages}
async def _load_user_memory( async def _load_user_memory(self, session: AsyncSession) -> tuple[str | None, bool]:
self, session: AsyncSession
) -> tuple[str | None, bool]:
"""Return (memory_content, is_persisted). """Return (memory_content, is_persisted).
When the user has no saved memory but has a display name, a seed When the user has no saved memory but has a display name, a seed
@ -102,9 +100,7 @@ class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
""" """
try: try:
result = await session.execute( result = await session.execute(
select(User.memory_md, User.display_name).where( select(User.memory_md, User.display_name).where(User.id == self.user_id)
User.id == self.user_id
)
) )
row = result.one_or_none() row = result.one_or_none()
if row is None: if row is None:

View file

@ -228,7 +228,13 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
llm=deps.get("llm"), llm=deps.get("llm"),
) )
), ),
requires=["user_id", "search_space_id", "db_session", "thread_visibility", "llm"], requires=[
"user_id",
"search_space_id",
"db_session",
"thread_visibility",
"llm",
],
), ),
# ========================================================================= # =========================================================================
# LINEAR TOOLS - create, update, delete issues # LINEAR TOOLS - create, update, delete issues

View file

@ -42,6 +42,7 @@ _SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
# Pinned-section helpers # Pinned-section helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _extract_pinned_headings(memory: str) -> set[str]: def _extract_pinned_headings(memory: str) -> set[str]:
"""Return the set of ``## …`` headings that contain ``(pinned)``.""" """Return the set of ``## …`` headings that contain ``(pinned)``."""
return set(_PINNED_RE.findall(memory)) return set(_PINNED_RE.findall(memory))
@ -59,9 +60,7 @@ def _extract_section_map(memory: str) -> dict[str, str]:
return sections return sections
def _validate_pinned_preserved( def _validate_pinned_preserved(old_memory: str | None, new_memory: str) -> str | None:
old_memory: str | None, new_memory: str
) -> str | None:
"""Return an error message if pinned headings from *old_memory* are missing """Return an error message if pinned headings from *old_memory* are missing
in *new_memory*, else ``None``.""" in *new_memory*, else ``None``."""
if not old_memory: if not old_memory:
@ -81,9 +80,7 @@ def _validate_pinned_preserved(
return None return None
def _restore_missing_pinned( def _restore_missing_pinned(old_memory: str, consolidated: str) -> str:
old_memory: str, consolidated: str
) -> str:
"""Prepend any pinned sections from *old_memory* that are absent in """Prepend any pinned sections from *old_memory* that are absent in
*consolidated*.""" *consolidated*."""
old_pinned = _extract_pinned_headings(old_memory) old_pinned = _extract_pinned_headings(old_memory)
@ -109,14 +106,13 @@ def _restore_missing_pinned(
# Diff validation # Diff validation
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _extract_headings(memory: str) -> set[str]: def _extract_headings(memory: str) -> set[str]:
"""Return all ``## …`` heading texts (without the ``## `` prefix).""" """Return all ``## …`` heading texts (without the ``## `` prefix)."""
return set(_SECTION_HEADING_RE.findall(memory)) return set(_SECTION_HEADING_RE.findall(memory))
def _validate_diff( def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
old_memory: str | None, new_memory: str
) -> list[str]:
"""Return a list of warning strings about suspicious changes.""" """Return a list of warning strings about suspicious changes."""
if not old_memory: if not old_memory:
return [] return []
@ -146,6 +142,7 @@ def _validate_diff(
# Size validation & soft warning # Size validation & soft warning
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _validate_memory_size(content: str) -> dict[str, Any] | None: def _validate_memory_size(content: str) -> dict[str, Any] | None:
"""Return an error/warning dict if *content* is too large, else None.""" """Return an error/warning dict if *content* is too large, else None."""
length = len(content) length = len(content)
@ -199,17 +196,13 @@ RULES:
</memory_document>""" </memory_document>"""
async def _auto_consolidate( async def _auto_consolidate(content: str, llm: Any) -> str | None:
content: str, llm: Any
) -> str | None:
"""Use a focused LLM call to consolidate *content* under the soft limit. """Use a focused LLM call to consolidate *content* under the soft limit.
Returns the consolidated string, or ``None`` if consolidation fails. Returns the consolidated string, or ``None`` if consolidation fails.
""" """
try: try:
prompt = _CONSOLIDATION_PROMPT.format( prompt = _CONSOLIDATION_PROMPT.format(target=MEMORY_SOFT_LIMIT, content=content)
target=MEMORY_SOFT_LIMIT, content=content
)
response = await llm.ainvoke( response = await llm.ainvoke(
[HumanMessage(content=prompt)], [HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]}, config={"tags": ["surfsense:internal"]},
@ -229,6 +222,7 @@ async def _auto_consolidate(
# Shared save-and-respond logic # Shared save-and-respond logic
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def _save_memory( async def _save_memory(
*, *,
updated_memory: str, updated_memory: str,
@ -295,12 +289,13 @@ async def _save_memory(
return {"status": "error", "message": f"Failed to update {label}: {e}"} return {"status": "error", "message": f"Failed to update {label}: {e}"}
# --- build response --- # --- build response ---
resp: dict[str, Any] = {"status": "saved", "message": f"{label.capitalize()} updated."} resp: dict[str, Any] = {
"status": "saved",
"message": f"{label.capitalize()} updated.",
}
if content is not updated_memory: if content is not updated_memory:
resp["notice"] = ( resp["notice"] = "Memory was automatically consolidated to fit within limits."
"Memory was automatically consolidated to fit within limits."
)
diff_warnings = _validate_diff(old_memory, content) diff_warnings = _validate_diff(old_memory, content)
if diff_warnings: if diff_warnings:
@ -317,6 +312,7 @@ async def _save_memory(
# Tool factories # Tool factories
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def create_update_memory_tool( def create_update_memory_tool(
user_id: str | UUID, user_id: str | UUID,
db_session: AsyncSession, db_session: AsyncSession,
@ -338,9 +334,7 @@ def create_update_memory_tool(
updated_memory: The FULL updated markdown document (not a diff). updated_memory: The FULL updated markdown document (not a diff).
""" """
try: try:
result = await db_session.execute( result = await db_session.execute(select(User).where(User.id == uid))
select(User).where(User.id == uid)
)
user = result.scalars().first() user = result.scalars().first()
if not user: if not user:
return {"status": "error", "message": "User not found."} return {"status": "error", "message": "User not found."}

View file

@ -257,7 +257,10 @@ async def update_search_space(
update_data = search_space_update.model_dump(exclude_unset=True) update_data = search_space_update.model_dump(exclude_unset=True)
if "shared_memory_md" in update_data and len(update_data["shared_memory_md"] or "") > MEMORY_HARD_LIMIT: if (
"shared_memory_md" in update_data
and len(update_data["shared_memory_md"] or "") > MEMORY_HARD_LIMIT
):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Team memory exceeds {MEMORY_HARD_LIMIT:,} character limit.", detail=f"Team memory exceeds {MEMORY_HARD_LIMIT:,} character limit.",

View file

@ -29,7 +29,6 @@ from sqlalchemy.future import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.memory_extraction import extract_and_save_memory
from app.agents.new_chat.checkpointer import get_checkpointer from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.llm_config import ( from app.agents.new_chat.llm_config import (
AgentConfig, AgentConfig,
@ -38,6 +37,7 @@ from app.agents.new_chat.llm_config import (
load_agent_config, load_agent_config,
load_llm_config_from_yaml, load_llm_config_from_yaml,
) )
from app.agents.new_chat.memory_extraction import extract_and_save_memory
from app.db import ( from app.db import (
ChatVisibility, ChatVisibility,
NewChatMessage, NewChatMessage,

View file

@ -4,9 +4,9 @@ import { Info } from "lucide-react";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useState } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
import { z } from "zod"; import { z } from "zod";
import { PlateEditor } from "@/components/editor/plate-editor";
import { Alert, AlertDescription } from "@/components/ui/alert"; import { Alert, AlertDescription } from "@/components/ui/alert";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { PlateEditor } from "@/components/editor/plate-editor";
import { Spinner } from "@/components/ui/spinner"; import { Spinner } from "@/components/ui/spinner";
import { baseApiService } from "@/lib/apis/base-api.service"; import { baseApiService } from "@/lib/apis/base-api.service";
@ -99,7 +99,10 @@ export function MemoryContent() {
<Alert className="bg-muted/50 py-3 md:py-4"> <Alert className="bg-muted/50 py-3 md:py-4">
<Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" /> <Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
<AlertDescription className="text-xs md:text-sm"> <AlertDescription className="text-xs md:text-sm">
<p>SurfSense uses this personal memory to personalize your responses across all conversations. Supports <span className="font-medium">Markdown</span> formatting.</p> <p>
SurfSense uses this personal memory to personalize your responses across all
conversations. Supports <span className="font-medium">Markdown</span> formatting.
</p>
</AlertDescription> </AlertDescription>
</Alert> </Alert>

View file

@ -1,7 +1,17 @@
"use client"; "use client";
import { useAtom } from "jotai"; import { useAtom } from "jotai";
import { Bot, BookMarked, Brain, Eye, FileText, Globe, ImageIcon, MessageSquare, Shield } from "lucide-react"; import {
BookMarked,
Bot,
Brain,
Eye,
FileText,
Globe,
ImageIcon,
MessageSquare,
Shield,
} from "lucide-react";
import dynamic from "next/dynamic"; import dynamic from "next/dynamic";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import type React from "react"; import type React from "react";
@ -60,7 +70,10 @@ const PublicChatSnapshotsManager = dynamic(
{ ssr: false } { ssr: false }
); );
const TeamMemoryManager = dynamic( const TeamMemoryManager = dynamic(
() => import("@/components/settings/team-memory-manager").then(m => ({ default: m.TeamMemoryManager })), () =>
import("@/components/settings/team-memory-manager").then((m) => ({
default: m.TeamMemoryManager,
})),
{ ssr: false } { ssr: false }
); );

View file

@ -6,9 +6,9 @@ import { Info } from "lucide-react";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useState } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
import { updateSearchSpaceMutationAtom } from "@/atoms/search-spaces/search-space-mutation.atoms"; import { updateSearchSpaceMutationAtom } from "@/atoms/search-spaces/search-space-mutation.atoms";
import { PlateEditor } from "@/components/editor/plate-editor";
import { Alert, AlertDescription } from "@/components/ui/alert"; import { Alert, AlertDescription } from "@/components/ui/alert";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { PlateEditor } from "@/components/editor/plate-editor";
import { Spinner } from "@/components/ui/spinner"; import { Spinner } from "@/components/ui/spinner";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys"; import { cacheKeys } from "@/lib/query-client/cache-keys";
@ -20,10 +20,7 @@ interface TeamMemoryManagerProps {
} }
export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) { export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) {
const { const { data: searchSpace, isLoading: loading } = useQuery({
data: searchSpace,
isLoading: loading,
} = useQuery({
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()), queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()),
queryFn: () => searchSpacesApiService.getSearchSpace({ id: searchSpaceId }), queryFn: () => searchSpacesApiService.getSearchSpace({ id: searchSpaceId }),
enabled: !!searchSpaceId, enabled: !!searchSpaceId,
@ -45,9 +42,7 @@ export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) {
setMemory(trimmed); setMemory(trimmed);
}, []); }, []);
const hasChanges = const hasChanges = !!searchSpace && (searchSpace.shared_memory_md || "") !== memory;
!!searchSpace &&
(searchSpace.shared_memory_md || "") !== memory;
const handleSave = async () => { const handleSave = async () => {
try { try {
@ -103,7 +98,10 @@ export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) {
<Alert className="bg-muted/50 py-3 md:py-4"> <Alert className="bg-muted/50 py-3 md:py-4">
<Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" /> <Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
<AlertDescription className="text-xs md:text-sm"> <AlertDescription className="text-xs md:text-sm">
<p>SurfSense uses this shared memory to provide team-wide context across all conversations in this search space. Supports <span className="font-medium">Markdown</span> formatting.</p> <p>
SurfSense uses this shared memory to provide team-wide context across all conversations
in this search space. Supports <span className="font-medium">Markdown</span> formatting.
</p>
</AlertDescription> </AlertDescription>
</Alert> </Alert>
@ -134,7 +132,7 @@ export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) {
variant="destructive" variant="destructive"
size="sm" size="sm"
onClick={handleClear} onClick={handleClear}
disabled={saving || !(searchSpace?.shared_memory_md)} disabled={saving || !searchSpace?.shared_memory_md}
> >
Clear All Clear All
</Button> </Button>

View file

@ -52,7 +52,10 @@ const DesktopContent = dynamic(
{ ssr: false } { ssr: false }
); );
const MemoryContent = dynamic( const MemoryContent = dynamic(
() => import("@/app/dashboard/[search_space_id]/user-settings/components/MemoryContent").then(m => ({ default: m.MemoryContent })), () =>
import("@/app/dashboard/[search_space_id]/user-settings/components/MemoryContent").then(
(m) => ({ default: m.MemoryContent })
),
{ ssr: false } { ssr: false }
); );

View file

@ -54,9 +54,7 @@ export const UpdateMemoryToolUI = ({
</div> </div>
<div className="flex-1"> <div className="flex-1">
<span className="text-sm text-destructive">Failed to update memory</span> <span className="text-sm text-destructive">Failed to update memory</span>
{result?.message && ( {result?.message && <p className="mt-1 text-xs text-destructive/70">{result.message}</p>}
<p className="mt-1 text-xs text-destructive/70">{result.message}</p>
)}
</div> </div>
</div> </div>
); );