From 8b704b2fef539d70ae55e5ba43551ec1cc94793d Mon Sep 17 00:00:00 2001
From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com>
Date: Thu, 4 Jun 2026 14:15:48 +0530
Subject: [PATCH] feat(chat): Introduce centralized thread metadata management
and update chat visibility handling with new hooks for thread mutations
---
.../chat/test_thread_visibility.py | 279 ++++++++++++++++++
.../new-chat/[[...chat_id]]/page.tsx | 36 ++-
.../atoms/chat/current-thread.atom.ts | 43 +++
.../components/assistant-ui/thread.tsx | 2 +-
.../layout/providers/LayoutDataProvider.tsx | 33 +--
.../components/layout/ui/header/Header.tsx | 12 +-
.../layout/ui/sidebar/AllChatsSidebar.tsx | 95 +++---
.../components/new-chat/chat-share-button.tsx | 30 +-
surfsense_web/hooks/use-thread-mutations.ts | 158 ++++++++++
surfsense_web/lib/chat/thread-cache.ts | 249 ++++++++++++++++
10 files changed, 832 insertions(+), 105 deletions(-)
create mode 100644 surfsense_backend/tests/integration/chat/test_thread_visibility.py
create mode 100644 surfsense_web/hooks/use-thread-mutations.ts
create mode 100644 surfsense_web/lib/chat/thread-cache.ts
diff --git a/surfsense_backend/tests/integration/chat/test_thread_visibility.py b/surfsense_backend/tests/integration/chat/test_thread_visibility.py
new file mode 100644
index 000000000..464d389db
--- /dev/null
+++ b/surfsense_backend/tests/integration/chat/test_thread_visibility.py
@@ -0,0 +1,279 @@
+"""Integration tests for new-chat thread visibility invariants.
+
+These tests exercise the route handlers directly with real DB-backed
+users, memberships, and permissions. The important contract is that a
+thread shared with a search space stays shared across normal metadata
+updates until the creator explicitly makes it private again.
+"""
+
+from __future__ import annotations
+
+import uuid
+
+import pytest
+import pytest_asyncio
+from fastapi import HTTPException
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import (
+ ChatVisibility,
+ SearchSpace,
+ SearchSpaceMembership,
+ SearchSpaceRole,
+ User,
+)
+from app.routes import new_chat_routes
+from app.schemas.new_chat import (
+ NewChatThreadCreate,
+ NewChatThreadUpdate,
+ NewChatThreadVisibilityUpdate,
+)
+
+pytestmark = pytest.mark.integration
+
+
+@pytest_asyncio.fixture
+async def db_member(db_session: AsyncSession, db_search_space: SearchSpace) -> User:
+ member = User(
+ id=uuid.uuid4(),
+ email="member@surfsense.net",
+ hashed_password="hashed",
+ is_active=True,
+ is_superuser=False,
+ is_verified=True,
+ )
+ db_session.add(member)
+ await db_session.flush()
+
+ role = (
+ (
+ await db_session.execute(
+ select(SearchSpaceRole).where(
+ SearchSpaceRole.search_space_id == db_search_space.id,
+ SearchSpaceRole.name == "Editor",
+ )
+ )
+ )
+ .scalars()
+ .one()
+ )
+ db_session.add(
+ SearchSpaceMembership(
+ user_id=member.id,
+ search_space_id=db_search_space.id,
+ role_id=role.id,
+ is_owner=False,
+ )
+ )
+ await db_session.flush()
+ return member
+
+
+async def _create_thread(
+ db_session: AsyncSession,
+ db_user: User,
+ db_search_space: SearchSpace,
+ *,
+ title: str = "Visibility Invariant Chat",
+):
+ return await new_chat_routes.create_thread(
+ NewChatThreadCreate(
+ title=title,
+ archived=False,
+ search_space_id=db_search_space.id,
+ visibility=ChatVisibility.PRIVATE,
+ ),
+ session=db_session,
+ user=db_user,
+ )
+
+
+def _active_thread_ids(response) -> set[int]:
+ return {thread.id for thread in response.threads}
+
+
+def _search_thread_ids(response) -> set[int]:
+ return {thread.id for thread in response}
+
+
+async def test_private_thread_is_hidden_from_other_search_space_member(
+ db_session: AsyncSession,
+ db_user: User,
+ db_member: User,
+ db_search_space: SearchSpace,
+):
+ thread = await _create_thread(db_session, db_user, db_search_space)
+
+ member_threads = await new_chat_routes.list_threads(
+ search_space_id=db_search_space.id,
+ session=db_session,
+ user=db_member,
+ )
+ member_search = await new_chat_routes.search_threads(
+ search_space_id=db_search_space.id,
+ title="Visibility",
+ session=db_session,
+ user=db_member,
+ )
+
+ assert thread.id not in _active_thread_ids(member_threads)
+ assert thread.id not in _search_thread_ids(member_search)
+ with pytest.raises(HTTPException) as exc_info:
+ await new_chat_routes.get_thread_full(
+ thread_id=thread.id,
+ session=db_session,
+ user=db_member,
+ )
+ assert exc_info.value.status_code == 403
+
+
+async def test_creator_can_share_thread_and_member_can_list_search_read_it(
+ db_session: AsyncSession,
+ db_user: User,
+ db_member: User,
+ db_search_space: SearchSpace,
+):
+ thread = await _create_thread(db_session, db_user, db_search_space)
+
+ updated = await new_chat_routes.update_thread_visibility(
+ thread_id=thread.id,
+ visibility_update=NewChatThreadVisibilityUpdate(
+ visibility=ChatVisibility.SEARCH_SPACE,
+ ),
+ session=db_session,
+ user=db_user,
+ )
+
+ member_threads = await new_chat_routes.list_threads(
+ search_space_id=db_search_space.id,
+ session=db_session,
+ user=db_member,
+ )
+ member_search = await new_chat_routes.search_threads(
+ search_space_id=db_search_space.id,
+ title="Visibility",
+ session=db_session,
+ user=db_member,
+ )
+ full_thread = await new_chat_routes.get_thread_full(
+ thread_id=thread.id,
+ session=db_session,
+ user=db_member,
+ )
+
+ assert updated.visibility == ChatVisibility.SEARCH_SPACE
+ assert thread.id in _active_thread_ids(member_threads)
+ assert thread.id in _search_thread_ids(member_search)
+ assert full_thread["id"] == thread.id
+ assert full_thread["visibility"] == ChatVisibility.SEARCH_SPACE
+
+
+async def test_rename_and_archive_do_not_reset_shared_visibility(
+ db_session: AsyncSession,
+ db_user: User,
+ db_search_space: SearchSpace,
+):
+ thread = await _create_thread(db_session, db_user, db_search_space)
+ await new_chat_routes.update_thread_visibility(
+ thread_id=thread.id,
+ visibility_update=NewChatThreadVisibilityUpdate(
+ visibility=ChatVisibility.SEARCH_SPACE,
+ ),
+ session=db_session,
+ user=db_user,
+ )
+
+ renamed = await new_chat_routes.update_thread(
+ thread_id=thread.id,
+ thread_update=NewChatThreadUpdate(title="Renamed Shared Chat"),
+ session=db_session,
+ user=db_user,
+ )
+ archived = await new_chat_routes.update_thread(
+ thread_id=thread.id,
+ thread_update=NewChatThreadUpdate(archived=True),
+ session=db_session,
+ user=db_user,
+ )
+
+ assert renamed.visibility == ChatVisibility.SEARCH_SPACE
+ assert archived.visibility == ChatVisibility.SEARCH_SPACE
+ assert archived.archived is True
+
+
+async def test_non_creator_cannot_change_shared_thread_back_to_private(
+ db_session: AsyncSession,
+ db_user: User,
+ db_member: User,
+ db_search_space: SearchSpace,
+):
+ thread = await _create_thread(db_session, db_user, db_search_space)
+ await new_chat_routes.update_thread_visibility(
+ thread_id=thread.id,
+ visibility_update=NewChatThreadVisibilityUpdate(
+ visibility=ChatVisibility.SEARCH_SPACE,
+ ),
+ session=db_session,
+ user=db_user,
+ )
+
+ with pytest.raises(HTTPException) as exc_info:
+ await new_chat_routes.update_thread_visibility(
+ thread_id=thread.id,
+ visibility_update=NewChatThreadVisibilityUpdate(
+ visibility=ChatVisibility.PRIVATE,
+ ),
+ session=db_session,
+ user=db_member,
+ )
+
+ assert exc_info.value.status_code == 403
+
+
+async def test_creator_can_make_shared_thread_private_again(
+ db_session: AsyncSession,
+ db_user: User,
+ db_member: User,
+ db_search_space: SearchSpace,
+):
+ thread = await _create_thread(db_session, db_user, db_search_space)
+ await new_chat_routes.update_thread_visibility(
+ thread_id=thread.id,
+ visibility_update=NewChatThreadVisibilityUpdate(
+ visibility=ChatVisibility.SEARCH_SPACE,
+ ),
+ session=db_session,
+ user=db_user,
+ )
+
+ private_again = await new_chat_routes.update_thread_visibility(
+ thread_id=thread.id,
+ visibility_update=NewChatThreadVisibilityUpdate(
+ visibility=ChatVisibility.PRIVATE,
+ ),
+ session=db_session,
+ user=db_user,
+ )
+ member_threads = await new_chat_routes.list_threads(
+ search_space_id=db_search_space.id,
+ session=db_session,
+ user=db_member,
+ )
+ member_search = await new_chat_routes.search_threads(
+ search_space_id=db_search_space.id,
+ title="Visibility",
+ session=db_session,
+ user=db_member,
+ )
+
+ assert private_again.visibility == ChatVisibility.PRIVATE
+ assert thread.id not in _active_thread_ids(member_threads)
+ assert thread.id not in _search_thread_ids(member_search)
+ with pytest.raises(HTTPException) as exc_info:
+ await new_chat_routes.get_thread_full(
+ thread_id=thread.id,
+ session=db_session,
+ user=db_member,
+ )
+ assert exc_info.value.status_code == 403
diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
index f8ca9bbc2..399cbdf99 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
@@ -18,6 +18,7 @@ import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms";
import {
clearTargetCommentIdAtom,
currentThreadAtom,
+ setCurrentThreadMetadataAtom,
setTargetCommentIdAtom,
} from "@/atoms/chat/current-thread.atom";
import {
@@ -375,7 +376,8 @@ export default function NewChatPage() {
const mentionedDocuments = useAtomValue(mentionedDocumentsAtom);
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
const setMentionedDocuments = useSetAtom(mentionedDocumentsAtom);
- const setCurrentThreadState = useSetAtom(currentThreadAtom);
+ const currentThreadState = useAtomValue(currentThreadAtom);
+ const setCurrentThreadMetadata = useSetAtom(setCurrentThreadMetadataAtom);
const setPremiumAlertForThread = useSetAtom(setPremiumAlertForThreadAtom);
const setTargetCommentId = useSetAtom(setTargetCommentIdAtom);
const clearTargetCommentId = useSetAtom(clearTargetCommentIdAtom);
@@ -772,13 +774,31 @@ export default function NewChatPage() {
// Sync current thread state to atom
useEffect(() => {
- setCurrentThreadState((prev) => ({
- ...prev,
- id: currentThread?.id ?? null,
- visibility: currentThread?.visibility ?? null,
- hasComments: currentThread?.has_comments ?? false,
- }));
- }, [currentThread, setCurrentThreadState]);
+ if (!currentThread) {
+ setCurrentThreadMetadata({
+ id: null,
+ visibility: null,
+ hasComments: false,
+ });
+ return;
+ }
+
+ const visibility =
+ currentThreadState.id === currentThread.id && currentThreadState.visibility !== null
+ ? currentThreadState.visibility
+ : currentThread.visibility;
+
+ setCurrentThreadMetadata({
+ id: currentThread.id,
+ visibility,
+ hasComments: currentThread.has_comments ?? false,
+ });
+ }, [
+ currentThread,
+ currentThreadState.id,
+ currentThreadState.visibility,
+ setCurrentThreadMetadata,
+ ]);
// Cleanup on unmount - abort any in-flight requests
useEffect(() => {
diff --git a/surfsense_web/atoms/chat/current-thread.atom.ts b/surfsense_web/atoms/chat/current-thread.atom.ts
index 131c98309..98a554af4 100644
--- a/surfsense_web/atoms/chat/current-thread.atom.ts
+++ b/surfsense_web/atoms/chat/current-thread.atom.ts
@@ -8,6 +8,18 @@ interface CurrentThreadState {
hasComments: boolean;
}
+interface CurrentThreadMetadataPatch {
+ id: number | null;
+ visibility?: ChatVisibility | null;
+ hasComments?: boolean;
+}
+
+interface CurrentThreadMetadataUpdate {
+ id: number;
+ visibility?: ChatVisibility | null;
+ hasComments?: boolean;
+}
+
const initialState: CurrentThreadState = {
id: null,
visibility: null,
@@ -24,6 +36,37 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat
set(currentThreadAtom, { ...get(currentThreadAtom), visibility: newVisibility });
});
+export const setCurrentThreadMetadataAtom = atom(
+ null,
+ (get, set, metadata: CurrentThreadMetadataPatch) => {
+ const current = get(currentThreadAtom);
+
+ set(currentThreadAtom, {
+ ...current,
+ id: metadata.id,
+ visibility: "visibility" in metadata ? (metadata.visibility ?? null) : current.visibility,
+ hasComments:
+ "hasComments" in metadata ? (metadata.hasComments ?? false) : current.hasComments,
+ });
+ }
+);
+
+export const patchCurrentThreadMetadataAtom = atom(
+ null,
+ (get, set, patch: CurrentThreadMetadataUpdate) => {
+ const current = get(currentThreadAtom);
+ if (current.id !== patch.id) {
+ return;
+ }
+
+ set(currentThreadAtom, {
+ ...current,
+ visibility: "visibility" in patch ? (patch.visibility ?? null) : current.visibility,
+ hasComments: "hasComments" in patch ? (patch.hasComments ?? false) : current.hasComments,
+ });
+ }
+);
+
export const resetCurrentThreadAtom = atom(null, (_, set) => {
set(currentThreadAtom, initialState);
set(reportPanelAtom, {
diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx
index 458bdabfb..f56503969 100644
--- a/surfsense_web/components/assistant-ui/thread.tsx
+++ b/surfsense_web/components/assistant-ui/thread.tsx
@@ -899,7 +899,7 @@ const Composer: FC = () => {
diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx
index 5fac87973..8d66566d7 100644
--- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx
+++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx
@@ -1,6 +1,6 @@
"use client";
-import { useQuery, useQueryClient } from "@tanstack/react-query";
+import { useQuery } from "@tanstack/react-query";
import { useAtom, useAtomValue, useSetAtom } from "jotai";
import { AlertTriangle, Inbox, LibraryBig, Workflow } from "lucide-react";
import { useParams, usePathname, useRouter } from "next/navigation";
@@ -44,10 +44,11 @@ import { Spinner } from "@/components/ui/spinner";
import { useAnnouncements } from "@/hooks/use-announcements";
import { useInbox } from "@/hooks/use-inbox";
import { useIsMobile } from "@/hooks/use-mobile";
+import { useArchiveThread, useDeleteThread, useRenameThread } from "@/hooks/use-thread-mutations";
import { notificationsApiService } from "@/lib/apis/notifications-api.service";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import { getLoginPath, logout } from "@/lib/auth-utils";
-import { deleteThread, fetchThreads, updateThread } from "@/lib/chat/thread-persistence";
+import { fetchThreads } from "@/lib/chat/thread-persistence";
import { resetUser, trackLogout } from "@/lib/posthog/events";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import type { ChatItem, NavItem, SearchSpace } from "../types/layout.types";
@@ -77,7 +78,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
const router = useRouter();
const params = useParams();
const pathname = usePathname();
- const queryClient = useQueryClient();
const { theme, setTheme } = useTheme();
const isMobile = useIsMobile();
@@ -96,6 +96,9 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
const resetCurrentThread = useSetAtom(resetCurrentThreadAtom);
const syncChatTab = useSetAtom(syncChatTabAtom);
const removeChatTab = useSetAtom(removeChatTabAtom);
+ const { mutateAsync: archiveThread } = useArchiveThread(searchSpaceId);
+ const { mutateAsync: deleteThread } = useDeleteThread(searchSpaceId);
+ const { mutateAsync: renameThread } = useRenameThread(searchSpaceId);
// Key used to force-remount the page component (e.g. after deleting the active chat
// when the router is out of sync due to replaceState)
@@ -542,18 +545,14 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
: tSidebar("chat_unarchived") || "Chat restored";
try {
- await updateThread(chat.id, { archived: newArchivedState });
+ await archiveThread({ threadId: chat.id, archived: newArchivedState });
toast.success(successMessage);
- // Invalidate queries to refresh UI (React Query will only refetch active queries)
- queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
- queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] });
- queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] });
} catch (error) {
console.error("Error archiving thread:", error);
toast.error(tSidebar("error_archiving_chat") || "Failed to archive chat");
}
},
- [queryClient, searchSpaceId, tSidebar]
+ [archiveThread, tSidebar]
);
const handleSettings = useCallback(() => {
@@ -591,9 +590,8 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
if (!chatToDelete) return;
setIsDeletingChat(true);
try {
- await deleteThread(chatToDelete.id);
+ await deleteThread({ threadId: chatToDelete.id });
const fallbackTab = removeChatTab(chatToDelete.id);
- queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
if (currentChatId === chatToDelete.id) {
resetCurrentThread();
if (fallbackTab?.type === "chat" && fallbackTab.chatUrl) {
@@ -617,7 +615,7 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
}
}, [
chatToDelete,
- queryClient,
+ deleteThread,
searchSpaceId,
resetCurrentThread,
currentChatId,
@@ -632,11 +630,12 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
if (!chatToRename || !newChatTitle.trim()) return;
setIsRenamingChat(true);
try {
- await updateThread(chatToRename.id, { title: newChatTitle.trim() });
+ await renameThread({
+ threadId: chatToRename.id,
+ title: newChatTitle.trim(),
+ previousTitle: chatToRename.name,
+ });
toast.success(tSidebar("chat_renamed") || "Chat renamed");
- queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
- queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] });
- queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] });
} catch (error) {
console.error("Error renaming thread:", error);
toast.error(tSidebar("error_renaming_chat") || "Failed to rename chat");
@@ -646,7 +645,7 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
setChatToRename(null);
setNewChatTitle("");
}
- }, [chatToRename, newChatTitle, queryClient, searchSpaceId, tSidebar]);
+ }, [chatToRename, newChatTitle, renameThread, tSidebar]);
// Detect if we're on the chat page (needs overflow-hidden for chat's own scroll)
const isChatPage = pathname?.includes("/new-chat") ?? false;
diff --git a/surfsense_web/components/layout/ui/header/Header.tsx b/surfsense_web/components/layout/ui/header/Header.tsx
index c6ccfddc6..572f61869 100644
--- a/surfsense_web/components/layout/ui/header/Header.tsx
+++ b/surfsense_web/components/layout/ui/header/Header.tsx
@@ -8,7 +8,7 @@ import { activeTabAtom } from "@/atoms/tabs/tabs.atom";
import { ActionLogButton } from "@/components/agent-action-log/action-log-button";
import { ChatHeader } from "@/components/new-chat/chat-header";
import { ChatShareButton } from "@/components/new-chat/chat-share-button";
-import type { ChatVisibility, ThreadRecord } from "@/lib/chat/thread-persistence";
+import type { ThreadRecord } from "@/lib/chat/thread-persistence";
interface HeaderProps {
mobileMenuTrigger?: React.ReactNode;
@@ -38,12 +38,12 @@ export function Header({ mobileMenuTrigger }: HeaderProps) {
}
const threadForButton: ThreadRecord | null =
- hasThread && currentThreadState.id !== null
+ hasThread && currentThreadState.id !== null && searchSpaceId
? {
id: currentThreadState.id,
visibility: currentThreadState.visibility ?? "PRIVATE",
created_by_id: null,
- search_space_id: 0,
+ search_space_id: Number(searchSpaceId),
title: "",
archived: false,
created_at: "",
@@ -51,8 +51,6 @@ export function Header({ mobileMenuTrigger }: HeaderProps) {
}
: null;
- const handleVisibilityChange = (_visibility: ChatVisibility) => {};
-
return (
{/* Left side - Mobile menu trigger + Model selector */}
@@ -66,9 +64,7 @@ export function Header({ mobileMenuTrigger }: HeaderProps) {
{/* Right side - Actions */}
{hasThread &&
}
- {hasThread && (
-
- )}
+ {hasThread &&
}
);
diff --git a/surfsense_web/components/layout/ui/sidebar/AllChatsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/AllChatsSidebar.tsx
index 7149f5f00..3099d40ea 100644
--- a/surfsense_web/components/layout/ui/sidebar/AllChatsSidebar.tsx
+++ b/surfsense_web/components/layout/ui/sidebar/AllChatsSidebar.tsx
@@ -43,12 +43,8 @@ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip
import { useDebouncedValue } from "@/hooks/use-debounced-value";
import { useLongPress } from "@/hooks/use-long-press";
import { useIsMobile } from "@/hooks/use-mobile";
-import {
- deleteThread,
- fetchThreads,
- searchThreads,
- updateThread,
-} from "@/lib/chat/thread-persistence";
+import { useArchiveThread, useDeleteThread, useRenameThread } from "@/hooks/use-thread-mutations";
+import { fetchThreads, searchThreads } from "@/lib/chat/thread-persistence";
import { formatThreadTimestamp } from "@/lib/format-date";
import { cn } from "@/lib/utils";
import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel";
@@ -74,6 +70,9 @@ export function AllChatsSidebarContent({
const queryClient = useQueryClient();
const isMobile = useIsMobile();
const removeChatTab = useSetAtom(removeChatTabAtom);
+ const { mutateAsync: deleteThread } = useDeleteThread(searchSpaceId);
+ const { mutateAsync: archiveThread } = useArchiveThread(searchSpaceId);
+ const { mutateAsync: renameThread } = useRenameThread(searchSpaceId);
const currentChatId = Array.isArray(params.chat_id)
? Number(params.chat_id[0])
@@ -154,12 +153,9 @@ export function AllChatsSidebarContent({
async (threadId: number) => {
setDeletingThreadId(threadId);
try {
- await deleteThread(threadId);
+ await deleteThread({ threadId });
const fallbackTab = removeChatTab(threadId);
toast.success(t("chat_deleted") || "Chat deleted successfully");
- queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] });
- queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] });
- queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
if (currentChatId === threadId) {
onOpenChange(false);
@@ -178,22 +174,19 @@ export function AllChatsSidebarContent({
setDeletingThreadId(null);
}
},
- [queryClient, searchSpaceId, t, currentChatId, router, onOpenChange, removeChatTab]
+ [deleteThread, t, currentChatId, router, onOpenChange, removeChatTab, searchSpaceId]
);
const handleToggleArchive = useCallback(
async (threadId: number, currentlyArchived: boolean) => {
setArchivingThreadId(threadId);
try {
- await updateThread(threadId, { archived: !currentlyArchived });
+ await archiveThread({ threadId, archived: !currentlyArchived });
toast.success(
currentlyArchived
? t("chat_unarchived") || "Chat restored"
: t("chat_archived") || "Chat archived"
);
- queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] });
- queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] });
- queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
} catch (error) {
console.error("Error archiving thread:", error);
toast.error(t("error_archiving_chat") || "Failed to archive chat");
@@ -201,7 +194,7 @@ export function AllChatsSidebarContent({
setArchivingThreadId(null);
}
},
- [queryClient, searchSpaceId, t]
+ [archiveThread, t]
);
const handleStartRename = useCallback((threadId: number, title: string) => {
@@ -214,14 +207,12 @@ export function AllChatsSidebarContent({
if (!renamingThread || !newTitle.trim()) return;
setIsRenaming(true);
try {
- await updateThread(renamingThread.id, { title: newTitle.trim() });
- toast.success(t("chat_renamed") || "Chat renamed");
- queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] });
- queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] });
- queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
- queryClient.invalidateQueries({
- queryKey: ["threads", searchSpaceId, "detail", String(renamingThread.id)],
+ await renameThread({
+ threadId: renamingThread.id,
+ title: newTitle.trim(),
+ previousTitle: renamingThread.title,
});
+ toast.success(t("chat_renamed") || "Chat renamed");
} catch (error) {
console.error("Error renaming thread:", error);
toast.error(t("error_renaming_chat") || "Failed to rename chat");
@@ -231,7 +222,7 @@ export function AllChatsSidebarContent({
setRenamingThread(null);
setNewTitle("");
}
- }, [renamingThread, newTitle, queryClient, searchSpaceId, t]);
+ }, [renamingThread, newTitle, renameThread, t]);
const handleClearSearch = useCallback(() => {
setSearchQuery("");
@@ -448,34 +439,36 @@ export function AllChatsSidebarContent({
- {!thread.archived && (
- handleStartRename(thread.id, thread.title || "New Chat")}
- >
-
- {t("rename") || "Rename"}
-
- )}
- handleToggleArchive(thread.id, thread.archived)}
- disabled={isArchiving}
- >
- {thread.archived ? (
- <>
-
- {t("unarchive") || "Restore"}
- >
- ) : (
- <>
-
- {t("archive") || "Archive"}
- >
+ {!thread.archived && (
+
+ handleStartRename(thread.id, thread.title || "New Chat")
+ }
+ >
+
+ {t("rename") || "Rename"}
+
)}
-
- handleDeleteThread(thread.id)}>
-
- {t("delete") || "Delete"}
-
+ handleToggleArchive(thread.id, thread.archived)}
+ disabled={isArchiving}
+ >
+ {thread.archived ? (
+ <>
+
+ {t("unarchive") || "Restore"}
+ >
+ ) : (
+ <>
+
+ {t("archive") || "Archive"}
+ >
+ )}
+
+ handleDeleteThread(thread.id)}>
+
+ {t("delete") || "Delete"}
+
diff --git a/surfsense_web/components/new-chat/chat-share-button.tsx b/surfsense_web/components/new-chat/chat-share-button.tsx
index 101f73ade..f46656de8 100644
--- a/surfsense_web/components/new-chat/chat-share-button.tsx
+++ b/surfsense_web/components/new-chat/chat-share-button.tsx
@@ -1,24 +1,21 @@
"use client";
import { useQuery, useQueryClient } from "@tanstack/react-query";
-import { useAtomValue, useSetAtom } from "jotai";
+import { useAtomValue } from "jotai";
import { Earth, User, Users } from "lucide-react";
import { useRouter } from "next/navigation";
import { useCallback, useState } from "react";
import { toast } from "sonner";
-import { currentThreadAtom, setThreadVisibilityAtom } from "@/atoms/chat/current-thread.atom";
+import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
import { myAccessAtom } from "@/atoms/members/members-query.atoms";
import { createPublicChatSnapshotMutationAtom } from "@/atoms/public-chat-snapshots/public-chat-snapshots-mutation.atoms";
import { Button } from "@/components/ui/button";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
+import { useUpdateThreadVisibility } from "@/hooks/use-thread-mutations";
import { chatThreadsApiService } from "@/lib/apis/chat-threads-api.service";
-import {
- type ChatVisibility,
- type ThreadRecord,
- updateThreadVisibility,
-} from "@/lib/chat/thread-persistence";
+import type { ChatVisibility, ThreadRecord } from "@/lib/chat/thread-persistence";
import { cn } from "@/lib/utils";
interface ChatShareButtonProps {
@@ -54,7 +51,7 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS
// Use Jotai atom for visibility (single source of truth)
const currentThreadState = useAtomValue(currentThreadAtom);
- const setThreadVisibility = useSetAtom(setThreadVisibilityAtom);
+ const { mutateAsync: updateVisibility } = useUpdateThreadVisibility(thread?.search_space_id ?? 0);
// Snapshot creation mutation
const { mutateAsync: createSnapshot, isPending: isCreatingSnapshot } = useAtomValue(
@@ -90,30 +87,23 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS
return;
}
- // Update Jotai atom immediately for instant UI feedback
- setThreadVisibility(newVisibility);
-
try {
- await updateThreadVisibility(thread.id, newVisibility);
-
- // Refetch threads list to update sidebar
- await queryClient.refetchQueries({
- predicate: (query) => Array.isArray(query.queryKey) && query.queryKey[0] === "threads",
+ const updatedThread = await updateVisibility({
+ thread,
+ visibility: newVisibility,
});
- onVisibilityChange?.(newVisibility);
+ onVisibilityChange?.(updatedThread.visibility);
toast.success(
newVisibility === "SEARCH_SPACE" ? "Chat shared with search space" : "Chat is now private"
);
setOpen(false);
} catch (error) {
console.error("Failed to update visibility:", error);
- // Revert Jotai state on error
- setThreadVisibility(thread.visibility ?? "PRIVATE");
toast.error("Failed to update sharing settings");
}
},
- [thread, currentVisibility, onVisibilityChange, queryClient, setThreadVisibility]
+ [thread, currentVisibility, onVisibilityChange, updateVisibility]
);
const handleCreatePublicLink = useCallback(async () => {
diff --git a/surfsense_web/hooks/use-thread-mutations.ts b/surfsense_web/hooks/use-thread-mutations.ts
new file mode 100644
index 000000000..e3ae35e6b
--- /dev/null
+++ b/surfsense_web/hooks/use-thread-mutations.ts
@@ -0,0 +1,158 @@
+"use client";
+
+import { useMutation, useQueryClient } from "@tanstack/react-query";
+import { useAtomValue, useSetAtom } from "jotai";
+import {
+ currentThreadAtom,
+ patchCurrentThreadMetadataAtom,
+ resetCurrentThreadAtom,
+} from "@/atoms/chat/current-thread.atom";
+import {
+ moveThreadArchiveState,
+ patchThreadEverywhere,
+ removeThreadEverywhere,
+ replaceThreadEverywhere,
+} from "@/lib/chat/thread-cache";
+import {
+ type ChatVisibility,
+ deleteThread,
+ type ThreadRecord,
+ updateThread,
+ updateThreadVisibility,
+} from "@/lib/chat/thread-persistence";
+
+type SearchSpaceKey = number | string;
+
+interface VisibilityVariables {
+ thread: ThreadRecord;
+ visibility: ChatVisibility;
+}
+
+interface RenameVariables {
+ threadId: number;
+ title: string;
+ previousTitle?: string;
+}
+
+interface ArchiveVariables {
+ threadId: number;
+ archived: boolean;
+}
+
+interface DeleteVariables {
+ threadId: number;
+}
+
+interface VisibilityRollback {
+ threadId: number;
+ visibility: ChatVisibility;
+}
+
+interface RenameRollback {
+ threadId: number;
+ title?: string;
+}
+
+interface ArchiveRollback {
+ threadId: number;
+ archived: boolean;
+}
+
+export function useUpdateThreadVisibility(searchSpaceId: SearchSpaceKey) {
+ const queryClient = useQueryClient();
+ const currentThread = useAtomValue(currentThreadAtom);
+ const patchCurrentThreadMetadata = useSetAtom(patchCurrentThreadMetadataAtom);
+
+ return useMutation
({
+ mutationFn: ({ thread, visibility }) => updateThreadVisibility(thread.id, visibility),
+ onMutate: ({ thread, visibility }) => {
+ const previousVisibility = thread.visibility ?? "PRIVATE";
+
+ patchThreadEverywhere(queryClient, searchSpaceId, thread.id, { visibility });
+ if (currentThread.id === thread.id) {
+ patchCurrentThreadMetadata({ id: thread.id, visibility });
+ }
+
+ return { threadId: thread.id, visibility: previousVisibility };
+ },
+ onError: (_error, _variables, rollback) => {
+ if (!rollback) return;
+ patchThreadEverywhere(queryClient, searchSpaceId, rollback.threadId, {
+ visibility: rollback.visibility,
+ });
+ if (currentThread.id === rollback.threadId) {
+ patchCurrentThreadMetadata({
+ id: rollback.threadId,
+ visibility: rollback.visibility,
+ });
+ }
+ },
+ onSuccess: (thread) => {
+ replaceThreadEverywhere(queryClient, searchSpaceId, thread);
+ if (currentThread.id === thread.id) {
+ patchCurrentThreadMetadata({
+ id: thread.id,
+ visibility: thread.visibility,
+ ...(thread.has_comments !== undefined ? { hasComments: thread.has_comments } : {}),
+ });
+ }
+ },
+ });
+}
+
+export function useRenameThread(searchSpaceId: SearchSpaceKey) {
+ const queryClient = useQueryClient();
+
+ return useMutation({
+ mutationFn: ({ threadId, title }) => updateThread(threadId, { title }),
+ onMutate: ({ threadId, title, previousTitle }) => {
+ patchThreadEverywhere(queryClient, searchSpaceId, threadId, { title });
+ return { threadId, title: previousTitle };
+ },
+ onError: (_error, _variables, rollback) => {
+ if (!rollback || rollback.title === undefined) return;
+ patchThreadEverywhere(queryClient, searchSpaceId, rollback.threadId, {
+ title: rollback.title,
+ });
+ },
+ onSuccess: (thread) => {
+ replaceThreadEverywhere(queryClient, searchSpaceId, thread);
+ },
+ });
+}
+
+export function useArchiveThread(searchSpaceId: SearchSpaceKey) {
+ const queryClient = useQueryClient();
+
+ return useMutation({
+ mutationFn: ({ threadId, archived }) => updateThread(threadId, { archived }),
+ onMutate: ({ threadId, archived }) => {
+ moveThreadArchiveState(queryClient, searchSpaceId, threadId, archived);
+ return { threadId, archived: !archived };
+ },
+ onError: (_error, _variables, rollback) => {
+ if (!rollback) return;
+ moveThreadArchiveState(queryClient, searchSpaceId, rollback.threadId, rollback.archived);
+ },
+ onSuccess: (thread) => {
+ replaceThreadEverywhere(queryClient, searchSpaceId, thread);
+ moveThreadArchiveState(queryClient, searchSpaceId, thread.id, thread.archived);
+ },
+ });
+}
+
+export function useDeleteThread(searchSpaceId: SearchSpaceKey) {
+ const queryClient = useQueryClient();
+ const currentThread = useAtomValue(currentThreadAtom);
+ const resetCurrentThread = useSetAtom(resetCurrentThreadAtom);
+
+ return useMutation({
+ mutationFn: ({ threadId }) => deleteThread(threadId),
+ onSuccess: (_data, { threadId }) => {
+ removeThreadEverywhere(queryClient, searchSpaceId, threadId);
+ if (currentThread.id === threadId) {
+ resetCurrentThread();
+ }
+ },
+ });
+}
diff --git a/surfsense_web/lib/chat/thread-cache.ts b/surfsense_web/lib/chat/thread-cache.ts
new file mode 100644
index 000000000..789704032
--- /dev/null
+++ b/surfsense_web/lib/chat/thread-cache.ts
@@ -0,0 +1,249 @@
+import type { QueryClient, QueryKey } from "@tanstack/react-query";
+import type {
+ ThreadListItem,
+ ThreadListResponse,
+ ThreadRecord,
+} from "@/lib/chat/thread-persistence";
+
+type SearchSpaceKey = number | string;
+
+type ThreadMetadataPatch = Partial &
+ Partial & {
+ has_comments?: boolean;
+ };
+
+function isSameSearchSpace(keyValue: unknown, searchSpaceId: SearchSpaceKey): boolean {
+ return String(keyValue) === String(searchSpaceId);
+}
+
+function isThreadListResponse(value: unknown): value is ThreadListResponse {
+ return (
+ typeof value === "object" &&
+ value !== null &&
+ Array.isArray((value as ThreadListResponse).threads) &&
+ Array.isArray((value as ThreadListResponse).archived_threads)
+ );
+}
+
+function isThreadListItemArray(value: unknown): value is ThreadListItem[] {
+ return Array.isArray(value);
+}
+
+function listItemPatchFromMetadata(patch: ThreadMetadataPatch): Partial {
+ const listPatch: Partial = {};
+
+ if (patch.title !== undefined) listPatch.title = patch.title;
+ if (patch.archived !== undefined) listPatch.archived = patch.archived;
+ if (patch.visibility !== undefined) listPatch.visibility = patch.visibility;
+ if (patch.created_by_id !== undefined) listPatch.created_by_id = patch.created_by_id;
+ if (patch.created_at !== undefined) listPatch.createdAt = patch.created_at;
+ if (patch.updated_at !== undefined) listPatch.updatedAt = patch.updated_at;
+ if (patch.createdAt !== undefined) listPatch.createdAt = patch.createdAt;
+ if (patch.updatedAt !== undefined) listPatch.updatedAt = patch.updatedAt;
+
+ return listPatch;
+}
+
+function patchListItem(
+ item: ThreadListItem,
+ threadId: number,
+ patch: ThreadMetadataPatch
+): ThreadListItem {
+ if (item.id !== threadId) return item;
+ return {
+ ...item,
+ ...listItemPatchFromMetadata(patch),
+ };
+}
+
+function patchThreadListResponse(
+ response: ThreadListResponse,
+ threadId: number,
+ patch: ThreadMetadataPatch
+): ThreadListResponse {
+ return {
+ ...response,
+ threads: response.threads.map((item) => patchListItem(item, threadId, patch)),
+ archived_threads: response.archived_threads.map((item) => patchListItem(item, threadId, patch)),
+ };
+}
+
+function patchThreadListItems(
+ items: ThreadListItem[],
+ threadId: number,
+ patch: ThreadMetadataPatch
+): ThreadListItem[] {
+ return items.map((item) => patchListItem(item, threadId, patch));
+}
+
+function patchThreadRecord(
+ record: ThreadRecord,
+ threadId: number,
+ patch: ThreadMetadataPatch
+): ThreadRecord {
+ if (record.id !== threadId) return record;
+ return {
+ ...record,
+ ...patch,
+ };
+}
+
+function threadListQueryFilter(searchSpaceId: SearchSpaceKey) {
+ return {
+ predicate: ({ queryKey }: { queryKey: QueryKey }) =>
+ Array.isArray(queryKey) &&
+ queryKey[0] === "threads" &&
+ isSameSearchSpace(queryKey[1], searchSpaceId),
+ };
+}
+
+function allThreadsQueryFilter(searchSpaceId: SearchSpaceKey) {
+ return {
+ predicate: ({ queryKey }: { queryKey: QueryKey }) =>
+ Array.isArray(queryKey) &&
+ queryKey[0] === "all-threads" &&
+ isSameSearchSpace(queryKey[1], searchSpaceId),
+ };
+}
+
+function searchThreadsQueryFilter(searchSpaceId: SearchSpaceKey) {
+ return {
+ predicate: ({ queryKey }: { queryKey: QueryKey }) =>
+ Array.isArray(queryKey) &&
+ queryKey[0] === "search-threads" &&
+ isSameSearchSpace(queryKey[1], searchSpaceId),
+ };
+}
+
+function threadDetailQueryFilter(searchSpaceId: SearchSpaceKey, threadId: number) {
+ return {
+ predicate: ({ queryKey }: { queryKey: QueryKey }) =>
+ Array.isArray(queryKey) &&
+ ((queryKey[0] === "threads" &&
+ queryKey[1] === "detail" &&
+ Number(queryKey[2]) === threadId) ||
+ (queryKey[0] === "threads" &&
+ isSameSearchSpace(queryKey[1], searchSpaceId) &&
+ queryKey[2] === "detail" &&
+ Number(queryKey[3]) === threadId)),
+ };
+}
+
+function updateThreadListResponse(
+ queryClient: QueryClient,
+ filter: ReturnType,
+ threadId: number,
+ patch: ThreadMetadataPatch
+): void {
+ queryClient.setQueriesData(filter, (old) => {
+ if (!isThreadListResponse(old)) return old;
+ return patchThreadListResponse(old, threadId, patch);
+ });
+}
+
+export function patchThreadEverywhere(
+ queryClient: QueryClient,
+ searchSpaceId: SearchSpaceKey,
+ threadId: number,
+ patch: ThreadMetadataPatch
+): void {
+ updateThreadListResponse(queryClient, threadListQueryFilter(searchSpaceId), threadId, patch);
+ updateThreadListResponse(queryClient, allThreadsQueryFilter(searchSpaceId), threadId, patch);
+
+ queryClient.setQueriesData(searchThreadsQueryFilter(searchSpaceId), (old) => {
+ if (!isThreadListItemArray(old)) return old;
+ return patchThreadListItems(old, threadId, patch);
+ });
+
+ queryClient.setQueriesData(
+ threadDetailQueryFilter(searchSpaceId, threadId),
+ (old) => {
+ if (!old) return old;
+ return patchThreadRecord(old, threadId, patch);
+ }
+ );
+}
+
+export function replaceThreadEverywhere(
+ queryClient: QueryClient,
+ searchSpaceId: SearchSpaceKey,
+ thread: ThreadRecord
+): void {
+ patchThreadEverywhere(queryClient, searchSpaceId, thread.id, thread);
+}
+
+export function removeThreadEverywhere(
+ queryClient: QueryClient,
+ searchSpaceId: SearchSpaceKey,
+ threadId: number
+): void {
+ const removeFromListResponse = (old: ThreadListResponse | undefined) => {
+ if (!isThreadListResponse(old)) return old;
+ return {
+ ...old,
+ threads: old.threads.filter((thread) => thread.id !== threadId),
+ archived_threads: old.archived_threads.filter((thread) => thread.id !== threadId),
+ };
+ };
+
+ queryClient.setQueriesData(
+ threadListQueryFilter(searchSpaceId),
+ removeFromListResponse
+ );
+ queryClient.setQueriesData(
+ allThreadsQueryFilter(searchSpaceId),
+ removeFromListResponse
+ );
+ queryClient.setQueriesData(searchThreadsQueryFilter(searchSpaceId), (old) => {
+ if (!isThreadListItemArray(old)) return old;
+ return old.filter((thread) => thread.id !== threadId);
+ });
+ queryClient.removeQueries(threadDetailQueryFilter(searchSpaceId, threadId));
+}
+
+export function moveThreadArchiveState(
+ queryClient: QueryClient,
+ searchSpaceId: SearchSpaceKey,
+ threadId: number,
+ archived: boolean
+): void {
+ const moveInListResponse = (old: ThreadListResponse | undefined) => {
+ if (!isThreadListResponse(old)) return old;
+
+ const activeWithoutThread = old.threads.filter((thread) => thread.id !== threadId);
+ const archivedWithoutThread = old.archived_threads.filter((thread) => thread.id !== threadId);
+ const existing =
+ old.threads.find((thread) => thread.id === threadId) ??
+ old.archived_threads.find((thread) => thread.id === threadId);
+
+ if (!existing) return old;
+
+ const updated = { ...existing, archived };
+
+ return {
+ ...old,
+ threads: archived ? activeWithoutThread : [updated, ...activeWithoutThread],
+ archived_threads: archived ? [updated, ...archivedWithoutThread] : archivedWithoutThread,
+ };
+ };
+
+ queryClient.setQueriesData(
+ threadListQueryFilter(searchSpaceId),
+ moveInListResponse
+ );
+ queryClient.setQueriesData(
+ allThreadsQueryFilter(searchSpaceId),
+ moveInListResponse
+ );
+ queryClient.setQueriesData(searchThreadsQueryFilter(searchSpaceId), (old) => {
+ if (!isThreadListItemArray(old)) return old;
+ return old.map((thread) => (thread.id === threadId ? { ...thread, archived } : thread));
+ });
+ queryClient.setQueriesData(
+ threadDetailQueryFilter(searchSpaceId, threadId),
+ (old) => {
+ if (!old || old.id !== threadId) return old;
+ return { ...old, archived };
+ }
+ );
+}