diff --git a/surfsense_backend/tests/integration/chat/test_thread_visibility.py b/surfsense_backend/tests/integration/chat/test_thread_visibility.py new file mode 100644 index 000000000..464d389db --- /dev/null +++ b/surfsense_backend/tests/integration/chat/test_thread_visibility.py @@ -0,0 +1,279 @@ +"""Integration tests for new-chat thread visibility invariants. + +These tests exercise the route handlers directly with real DB-backed +users, memberships, and permissions. The important contract is that a +thread shared with a search space stays shared across normal metadata +updates until the creator explicitly makes it private again. +""" + +from __future__ import annotations + +import uuid + +import pytest +import pytest_asyncio +from fastapi import HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + ChatVisibility, + SearchSpace, + SearchSpaceMembership, + SearchSpaceRole, + User, +) +from app.routes import new_chat_routes +from app.schemas.new_chat import ( + NewChatThreadCreate, + NewChatThreadUpdate, + NewChatThreadVisibilityUpdate, +) + +pytestmark = pytest.mark.integration + + +@pytest_asyncio.fixture +async def db_member(db_session: AsyncSession, db_search_space: SearchSpace) -> User: + member = User( + id=uuid.uuid4(), + email="member@surfsense.net", + hashed_password="hashed", + is_active=True, + is_superuser=False, + is_verified=True, + ) + db_session.add(member) + await db_session.flush() + + role = ( + ( + await db_session.execute( + select(SearchSpaceRole).where( + SearchSpaceRole.search_space_id == db_search_space.id, + SearchSpaceRole.name == "Editor", + ) + ) + ) + .scalars() + .one() + ) + db_session.add( + SearchSpaceMembership( + user_id=member.id, + search_space_id=db_search_space.id, + role_id=role.id, + is_owner=False, + ) + ) + await db_session.flush() + return member + + +async def _create_thread( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, + *, + title: str = "Visibility Invariant Chat", +): + return await new_chat_routes.create_thread( + NewChatThreadCreate( + title=title, + archived=False, + search_space_id=db_search_space.id, + visibility=ChatVisibility.PRIVATE, + ), + session=db_session, + user=db_user, + ) + + +def _active_thread_ids(response) -> set[int]: + return {thread.id for thread in response.threads} + + +def _search_thread_ids(response) -> set[int]: + return {thread.id for thread in response} + + +async def test_private_thread_is_hidden_from_other_search_space_member( + db_session: AsyncSession, + db_user: User, + db_member: User, + db_search_space: SearchSpace, +): + thread = await _create_thread(db_session, db_user, db_search_space) + + member_threads = await new_chat_routes.list_threads( + search_space_id=db_search_space.id, + session=db_session, + user=db_member, + ) + member_search = await new_chat_routes.search_threads( + search_space_id=db_search_space.id, + title="Visibility", + session=db_session, + user=db_member, + ) + + assert thread.id not in _active_thread_ids(member_threads) + assert thread.id not in _search_thread_ids(member_search) + with pytest.raises(HTTPException) as exc_info: + await new_chat_routes.get_thread_full( + thread_id=thread.id, + session=db_session, + user=db_member, + ) + assert exc_info.value.status_code == 403 + + +async def test_creator_can_share_thread_and_member_can_list_search_read_it( + db_session: AsyncSession, + db_user: User, + db_member: User, + db_search_space: SearchSpace, +): + thread = await _create_thread(db_session, db_user, db_search_space) + + updated = await new_chat_routes.update_thread_visibility( + thread_id=thread.id, + visibility_update=NewChatThreadVisibilityUpdate( + visibility=ChatVisibility.SEARCH_SPACE, + ), + session=db_session, + user=db_user, + ) + + member_threads = await new_chat_routes.list_threads( + search_space_id=db_search_space.id, + session=db_session, + user=db_member, + ) + member_search = await new_chat_routes.search_threads( + search_space_id=db_search_space.id, + title="Visibility", + session=db_session, + user=db_member, + ) + full_thread = await new_chat_routes.get_thread_full( + thread_id=thread.id, + session=db_session, + user=db_member, + ) + + assert updated.visibility == ChatVisibility.SEARCH_SPACE + assert thread.id in _active_thread_ids(member_threads) + assert thread.id in _search_thread_ids(member_search) + assert full_thread["id"] == thread.id + assert full_thread["visibility"] == ChatVisibility.SEARCH_SPACE + + +async def test_rename_and_archive_do_not_reset_shared_visibility( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + thread = await _create_thread(db_session, db_user, db_search_space) + await new_chat_routes.update_thread_visibility( + thread_id=thread.id, + visibility_update=NewChatThreadVisibilityUpdate( + visibility=ChatVisibility.SEARCH_SPACE, + ), + session=db_session, + user=db_user, + ) + + renamed = await new_chat_routes.update_thread( + thread_id=thread.id, + thread_update=NewChatThreadUpdate(title="Renamed Shared Chat"), + session=db_session, + user=db_user, + ) + archived = await new_chat_routes.update_thread( + thread_id=thread.id, + thread_update=NewChatThreadUpdate(archived=True), + session=db_session, + user=db_user, + ) + + assert renamed.visibility == ChatVisibility.SEARCH_SPACE + assert archived.visibility == ChatVisibility.SEARCH_SPACE + assert archived.archived is True + + +async def test_non_creator_cannot_change_shared_thread_back_to_private( + db_session: AsyncSession, + db_user: User, + db_member: User, + db_search_space: SearchSpace, +): + thread = await _create_thread(db_session, db_user, db_search_space) + await new_chat_routes.update_thread_visibility( + thread_id=thread.id, + visibility_update=NewChatThreadVisibilityUpdate( + visibility=ChatVisibility.SEARCH_SPACE, + ), + session=db_session, + user=db_user, + ) + + with pytest.raises(HTTPException) as exc_info: + await new_chat_routes.update_thread_visibility( + thread_id=thread.id, + visibility_update=NewChatThreadVisibilityUpdate( + visibility=ChatVisibility.PRIVATE, + ), + session=db_session, + user=db_member, + ) + + assert exc_info.value.status_code == 403 + + +async def test_creator_can_make_shared_thread_private_again( + db_session: AsyncSession, + db_user: User, + db_member: User, + db_search_space: SearchSpace, +): + thread = await _create_thread(db_session, db_user, db_search_space) + await new_chat_routes.update_thread_visibility( + thread_id=thread.id, + visibility_update=NewChatThreadVisibilityUpdate( + visibility=ChatVisibility.SEARCH_SPACE, + ), + session=db_session, + user=db_user, + ) + + private_again = await new_chat_routes.update_thread_visibility( + thread_id=thread.id, + visibility_update=NewChatThreadVisibilityUpdate( + visibility=ChatVisibility.PRIVATE, + ), + session=db_session, + user=db_user, + ) + member_threads = await new_chat_routes.list_threads( + search_space_id=db_search_space.id, + session=db_session, + user=db_member, + ) + member_search = await new_chat_routes.search_threads( + search_space_id=db_search_space.id, + title="Visibility", + session=db_session, + user=db_member, + ) + + assert private_again.visibility == ChatVisibility.PRIVATE + assert thread.id not in _active_thread_ids(member_threads) + assert thread.id not in _search_thread_ids(member_search) + with pytest.raises(HTTPException) as exc_info: + await new_chat_routes.get_thread_full( + thread_id=thread.id, + session=db_session, + user=db_member, + ) + assert exc_info.value.status_code == 403 diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index f8ca9bbc2..399cbdf99 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -18,6 +18,7 @@ import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms"; import { clearTargetCommentIdAtom, currentThreadAtom, + setCurrentThreadMetadataAtom, setTargetCommentIdAtom, } from "@/atoms/chat/current-thread.atom"; import { @@ -375,7 +376,8 @@ export default function NewChatPage() { const mentionedDocuments = useAtomValue(mentionedDocumentsAtom); const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); const setMentionedDocuments = useSetAtom(mentionedDocumentsAtom); - const setCurrentThreadState = useSetAtom(currentThreadAtom); + const currentThreadState = useAtomValue(currentThreadAtom); + const setCurrentThreadMetadata = useSetAtom(setCurrentThreadMetadataAtom); const setPremiumAlertForThread = useSetAtom(setPremiumAlertForThreadAtom); const setTargetCommentId = useSetAtom(setTargetCommentIdAtom); const clearTargetCommentId = useSetAtom(clearTargetCommentIdAtom); @@ -772,13 +774,31 @@ export default function NewChatPage() { // Sync current thread state to atom useEffect(() => { - setCurrentThreadState((prev) => ({ - ...prev, - id: currentThread?.id ?? null, - visibility: currentThread?.visibility ?? null, - hasComments: currentThread?.has_comments ?? false, - })); - }, [currentThread, setCurrentThreadState]); + if (!currentThread) { + setCurrentThreadMetadata({ + id: null, + visibility: null, + hasComments: false, + }); + return; + } + + const visibility = + currentThreadState.id === currentThread.id && currentThreadState.visibility !== null + ? currentThreadState.visibility + : currentThread.visibility; + + setCurrentThreadMetadata({ + id: currentThread.id, + visibility, + hasComments: currentThread.has_comments ?? false, + }); + }, [ + currentThread, + currentThreadState.id, + currentThreadState.visibility, + setCurrentThreadMetadata, + ]); // Cleanup on unmount - abort any in-flight requests useEffect(() => { diff --git a/surfsense_web/atoms/chat/current-thread.atom.ts b/surfsense_web/atoms/chat/current-thread.atom.ts index 131c98309..98a554af4 100644 --- a/surfsense_web/atoms/chat/current-thread.atom.ts +++ b/surfsense_web/atoms/chat/current-thread.atom.ts @@ -8,6 +8,18 @@ interface CurrentThreadState { hasComments: boolean; } +interface CurrentThreadMetadataPatch { + id: number | null; + visibility?: ChatVisibility | null; + hasComments?: boolean; +} + +interface CurrentThreadMetadataUpdate { + id: number; + visibility?: ChatVisibility | null; + hasComments?: boolean; +} + const initialState: CurrentThreadState = { id: null, visibility: null, @@ -24,6 +36,37 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat set(currentThreadAtom, { ...get(currentThreadAtom), visibility: newVisibility }); }); +export const setCurrentThreadMetadataAtom = atom( + null, + (get, set, metadata: CurrentThreadMetadataPatch) => { + const current = get(currentThreadAtom); + + set(currentThreadAtom, { + ...current, + id: metadata.id, + visibility: "visibility" in metadata ? (metadata.visibility ?? null) : current.visibility, + hasComments: + "hasComments" in metadata ? (metadata.hasComments ?? false) : current.hasComments, + }); + } +); + +export const patchCurrentThreadMetadataAtom = atom( + null, + (get, set, patch: CurrentThreadMetadataUpdate) => { + const current = get(currentThreadAtom); + if (current.id !== patch.id) { + return; + } + + set(currentThreadAtom, { + ...current, + visibility: "visibility" in patch ? (patch.visibility ?? null) : current.visibility, + hasComments: "hasComments" in patch ? (patch.hasComments ?? false) : current.hasComments, + }); + } +); + export const resetCurrentThreadAtom = atom(null, (_, set) => { set(currentThreadAtom, initialState); set(reportPanelAtom, { diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 458bdabfb..f56503969 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -899,7 +899,7 @@ const Composer: FC = () => {
diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 5fac87973..8d66566d7 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -1,6 +1,6 @@ "use client"; -import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useQuery } from "@tanstack/react-query"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; import { AlertTriangle, Inbox, LibraryBig, Workflow } from "lucide-react"; import { useParams, usePathname, useRouter } from "next/navigation"; @@ -44,10 +44,11 @@ import { Spinner } from "@/components/ui/spinner"; import { useAnnouncements } from "@/hooks/use-announcements"; import { useInbox } from "@/hooks/use-inbox"; import { useIsMobile } from "@/hooks/use-mobile"; +import { useArchiveThread, useDeleteThread, useRenameThread } from "@/hooks/use-thread-mutations"; import { notificationsApiService } from "@/lib/apis/notifications-api.service"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { getLoginPath, logout } from "@/lib/auth-utils"; -import { deleteThread, fetchThreads, updateThread } from "@/lib/chat/thread-persistence"; +import { fetchThreads } from "@/lib/chat/thread-persistence"; import { resetUser, trackLogout } from "@/lib/posthog/events"; import { cacheKeys } from "@/lib/query-client/cache-keys"; import type { ChatItem, NavItem, SearchSpace } from "../types/layout.types"; @@ -77,7 +78,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid const router = useRouter(); const params = useParams(); const pathname = usePathname(); - const queryClient = useQueryClient(); const { theme, setTheme } = useTheme(); const isMobile = useIsMobile(); @@ -96,6 +96,9 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid const resetCurrentThread = useSetAtom(resetCurrentThreadAtom); const syncChatTab = useSetAtom(syncChatTabAtom); const removeChatTab = useSetAtom(removeChatTabAtom); + const { mutateAsync: archiveThread } = useArchiveThread(searchSpaceId); + const { mutateAsync: deleteThread } = useDeleteThread(searchSpaceId); + const { mutateAsync: renameThread } = useRenameThread(searchSpaceId); // Key used to force-remount the page component (e.g. after deleting the active chat // when the router is out of sync due to replaceState) @@ -542,18 +545,14 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid : tSidebar("chat_unarchived") || "Chat restored"; try { - await updateThread(chat.id, { archived: newArchivedState }); + await archiveThread({ threadId: chat.id, archived: newArchivedState }); toast.success(successMessage); - // Invalidate queries to refresh UI (React Query will only refetch active queries) - queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] }); - queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] }); - queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] }); } catch (error) { console.error("Error archiving thread:", error); toast.error(tSidebar("error_archiving_chat") || "Failed to archive chat"); } }, - [queryClient, searchSpaceId, tSidebar] + [archiveThread, tSidebar] ); const handleSettings = useCallback(() => { @@ -591,9 +590,8 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid if (!chatToDelete) return; setIsDeletingChat(true); try { - await deleteThread(chatToDelete.id); + await deleteThread({ threadId: chatToDelete.id }); const fallbackTab = removeChatTab(chatToDelete.id); - queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] }); if (currentChatId === chatToDelete.id) { resetCurrentThread(); if (fallbackTab?.type === "chat" && fallbackTab.chatUrl) { @@ -617,7 +615,7 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid } }, [ chatToDelete, - queryClient, + deleteThread, searchSpaceId, resetCurrentThread, currentChatId, @@ -632,11 +630,12 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid if (!chatToRename || !newChatTitle.trim()) return; setIsRenamingChat(true); try { - await updateThread(chatToRename.id, { title: newChatTitle.trim() }); + await renameThread({ + threadId: chatToRename.id, + title: newChatTitle.trim(), + previousTitle: chatToRename.name, + }); toast.success(tSidebar("chat_renamed") || "Chat renamed"); - queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] }); - queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] }); - queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] }); } catch (error) { console.error("Error renaming thread:", error); toast.error(tSidebar("error_renaming_chat") || "Failed to rename chat"); @@ -646,7 +645,7 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid setChatToRename(null); setNewChatTitle(""); } - }, [chatToRename, newChatTitle, queryClient, searchSpaceId, tSidebar]); + }, [chatToRename, newChatTitle, renameThread, tSidebar]); // Detect if we're on the chat page (needs overflow-hidden for chat's own scroll) const isChatPage = pathname?.includes("/new-chat") ?? false; diff --git a/surfsense_web/components/layout/ui/header/Header.tsx b/surfsense_web/components/layout/ui/header/Header.tsx index c6ccfddc6..572f61869 100644 --- a/surfsense_web/components/layout/ui/header/Header.tsx +++ b/surfsense_web/components/layout/ui/header/Header.tsx @@ -8,7 +8,7 @@ import { activeTabAtom } from "@/atoms/tabs/tabs.atom"; import { ActionLogButton } from "@/components/agent-action-log/action-log-button"; import { ChatHeader } from "@/components/new-chat/chat-header"; import { ChatShareButton } from "@/components/new-chat/chat-share-button"; -import type { ChatVisibility, ThreadRecord } from "@/lib/chat/thread-persistence"; +import type { ThreadRecord } from "@/lib/chat/thread-persistence"; interface HeaderProps { mobileMenuTrigger?: React.ReactNode; @@ -38,12 +38,12 @@ export function Header({ mobileMenuTrigger }: HeaderProps) { } const threadForButton: ThreadRecord | null = - hasThread && currentThreadState.id !== null + hasThread && currentThreadState.id !== null && searchSpaceId ? { id: currentThreadState.id, visibility: currentThreadState.visibility ?? "PRIVATE", created_by_id: null, - search_space_id: 0, + search_space_id: Number(searchSpaceId), title: "", archived: false, created_at: "", @@ -51,8 +51,6 @@ export function Header({ mobileMenuTrigger }: HeaderProps) { } : null; - const handleVisibilityChange = (_visibility: ChatVisibility) => {}; - return (
{/* Left side - Mobile menu trigger + Model selector */} @@ -66,9 +64,7 @@ export function Header({ mobileMenuTrigger }: HeaderProps) { {/* Right side - Actions */}
{hasThread && } - {hasThread && ( - - )} + {hasThread && }
); diff --git a/surfsense_web/components/layout/ui/sidebar/AllChatsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/AllChatsSidebar.tsx index 7149f5f00..3099d40ea 100644 --- a/surfsense_web/components/layout/ui/sidebar/AllChatsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/AllChatsSidebar.tsx @@ -43,12 +43,8 @@ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip import { useDebouncedValue } from "@/hooks/use-debounced-value"; import { useLongPress } from "@/hooks/use-long-press"; import { useIsMobile } from "@/hooks/use-mobile"; -import { - deleteThread, - fetchThreads, - searchThreads, - updateThread, -} from "@/lib/chat/thread-persistence"; +import { useArchiveThread, useDeleteThread, useRenameThread } from "@/hooks/use-thread-mutations"; +import { fetchThreads, searchThreads } from "@/lib/chat/thread-persistence"; import { formatThreadTimestamp } from "@/lib/format-date"; import { cn } from "@/lib/utils"; import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel"; @@ -74,6 +70,9 @@ export function AllChatsSidebarContent({ const queryClient = useQueryClient(); const isMobile = useIsMobile(); const removeChatTab = useSetAtom(removeChatTabAtom); + const { mutateAsync: deleteThread } = useDeleteThread(searchSpaceId); + const { mutateAsync: archiveThread } = useArchiveThread(searchSpaceId); + const { mutateAsync: renameThread } = useRenameThread(searchSpaceId); const currentChatId = Array.isArray(params.chat_id) ? Number(params.chat_id[0]) @@ -154,12 +153,9 @@ export function AllChatsSidebarContent({ async (threadId: number) => { setDeletingThreadId(threadId); try { - await deleteThread(threadId); + await deleteThread({ threadId }); const fallbackTab = removeChatTab(threadId); toast.success(t("chat_deleted") || "Chat deleted successfully"); - queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] }); - queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] }); - queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] }); if (currentChatId === threadId) { onOpenChange(false); @@ -178,22 +174,19 @@ export function AllChatsSidebarContent({ setDeletingThreadId(null); } }, - [queryClient, searchSpaceId, t, currentChatId, router, onOpenChange, removeChatTab] + [deleteThread, t, currentChatId, router, onOpenChange, removeChatTab, searchSpaceId] ); const handleToggleArchive = useCallback( async (threadId: number, currentlyArchived: boolean) => { setArchivingThreadId(threadId); try { - await updateThread(threadId, { archived: !currentlyArchived }); + await archiveThread({ threadId, archived: !currentlyArchived }); toast.success( currentlyArchived ? t("chat_unarchived") || "Chat restored" : t("chat_archived") || "Chat archived" ); - queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] }); - queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] }); - queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] }); } catch (error) { console.error("Error archiving thread:", error); toast.error(t("error_archiving_chat") || "Failed to archive chat"); @@ -201,7 +194,7 @@ export function AllChatsSidebarContent({ setArchivingThreadId(null); } }, - [queryClient, searchSpaceId, t] + [archiveThread, t] ); const handleStartRename = useCallback((threadId: number, title: string) => { @@ -214,14 +207,12 @@ export function AllChatsSidebarContent({ if (!renamingThread || !newTitle.trim()) return; setIsRenaming(true); try { - await updateThread(renamingThread.id, { title: newTitle.trim() }); - toast.success(t("chat_renamed") || "Chat renamed"); - queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] }); - queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] }); - queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] }); - queryClient.invalidateQueries({ - queryKey: ["threads", searchSpaceId, "detail", String(renamingThread.id)], + await renameThread({ + threadId: renamingThread.id, + title: newTitle.trim(), + previousTitle: renamingThread.title, }); + toast.success(t("chat_renamed") || "Chat renamed"); } catch (error) { console.error("Error renaming thread:", error); toast.error(t("error_renaming_chat") || "Failed to rename chat"); @@ -231,7 +222,7 @@ export function AllChatsSidebarContent({ setRenamingThread(null); setNewTitle(""); } - }, [renamingThread, newTitle, queryClient, searchSpaceId, t]); + }, [renamingThread, newTitle, renameThread, t]); const handleClearSearch = useCallback(() => { setSearchQuery(""); @@ -448,34 +439,36 @@ export function AllChatsSidebarContent({ - {!thread.archived && ( - handleStartRename(thread.id, thread.title || "New Chat")} - > - - {t("rename") || "Rename"} - - )} - handleToggleArchive(thread.id, thread.archived)} - disabled={isArchiving} - > - {thread.archived ? ( - <> - - {t("unarchive") || "Restore"} - - ) : ( - <> - - {t("archive") || "Archive"} - + {!thread.archived && ( + + handleStartRename(thread.id, thread.title || "New Chat") + } + > + + {t("rename") || "Rename"} + )} - - handleDeleteThread(thread.id)}> - - {t("delete") || "Delete"} - + handleToggleArchive(thread.id, thread.archived)} + disabled={isArchiving} + > + {thread.archived ? ( + <> + + {t("unarchive") || "Restore"} + + ) : ( + <> + + {t("archive") || "Archive"} + + )} + + handleDeleteThread(thread.id)}> + + {t("delete") || "Delete"} +
diff --git a/surfsense_web/components/new-chat/chat-share-button.tsx b/surfsense_web/components/new-chat/chat-share-button.tsx index 101f73ade..f46656de8 100644 --- a/surfsense_web/components/new-chat/chat-share-button.tsx +++ b/surfsense_web/components/new-chat/chat-share-button.tsx @@ -1,24 +1,21 @@ "use client"; import { useQuery, useQueryClient } from "@tanstack/react-query"; -import { useAtomValue, useSetAtom } from "jotai"; +import { useAtomValue } from "jotai"; import { Earth, User, Users } from "lucide-react"; import { useRouter } from "next/navigation"; import { useCallback, useState } from "react"; import { toast } from "sonner"; -import { currentThreadAtom, setThreadVisibilityAtom } from "@/atoms/chat/current-thread.atom"; +import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { myAccessAtom } from "@/atoms/members/members-query.atoms"; import { createPublicChatSnapshotMutationAtom } from "@/atoms/public-chat-snapshots/public-chat-snapshots-mutation.atoms"; import { Button } from "@/components/ui/button"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { useUpdateThreadVisibility } from "@/hooks/use-thread-mutations"; import { chatThreadsApiService } from "@/lib/apis/chat-threads-api.service"; -import { - type ChatVisibility, - type ThreadRecord, - updateThreadVisibility, -} from "@/lib/chat/thread-persistence"; +import type { ChatVisibility, ThreadRecord } from "@/lib/chat/thread-persistence"; import { cn } from "@/lib/utils"; interface ChatShareButtonProps { @@ -54,7 +51,7 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS // Use Jotai atom for visibility (single source of truth) const currentThreadState = useAtomValue(currentThreadAtom); - const setThreadVisibility = useSetAtom(setThreadVisibilityAtom); + const { mutateAsync: updateVisibility } = useUpdateThreadVisibility(thread?.search_space_id ?? 0); // Snapshot creation mutation const { mutateAsync: createSnapshot, isPending: isCreatingSnapshot } = useAtomValue( @@ -90,30 +87,23 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS return; } - // Update Jotai atom immediately for instant UI feedback - setThreadVisibility(newVisibility); - try { - await updateThreadVisibility(thread.id, newVisibility); - - // Refetch threads list to update sidebar - await queryClient.refetchQueries({ - predicate: (query) => Array.isArray(query.queryKey) && query.queryKey[0] === "threads", + const updatedThread = await updateVisibility({ + thread, + visibility: newVisibility, }); - onVisibilityChange?.(newVisibility); + onVisibilityChange?.(updatedThread.visibility); toast.success( newVisibility === "SEARCH_SPACE" ? "Chat shared with search space" : "Chat is now private" ); setOpen(false); } catch (error) { console.error("Failed to update visibility:", error); - // Revert Jotai state on error - setThreadVisibility(thread.visibility ?? "PRIVATE"); toast.error("Failed to update sharing settings"); } }, - [thread, currentVisibility, onVisibilityChange, queryClient, setThreadVisibility] + [thread, currentVisibility, onVisibilityChange, updateVisibility] ); const handleCreatePublicLink = useCallback(async () => { diff --git a/surfsense_web/hooks/use-thread-mutations.ts b/surfsense_web/hooks/use-thread-mutations.ts new file mode 100644 index 000000000..e3ae35e6b --- /dev/null +++ b/surfsense_web/hooks/use-thread-mutations.ts @@ -0,0 +1,158 @@ +"use client"; + +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { useAtomValue, useSetAtom } from "jotai"; +import { + currentThreadAtom, + patchCurrentThreadMetadataAtom, + resetCurrentThreadAtom, +} from "@/atoms/chat/current-thread.atom"; +import { + moveThreadArchiveState, + patchThreadEverywhere, + removeThreadEverywhere, + replaceThreadEverywhere, +} from "@/lib/chat/thread-cache"; +import { + type ChatVisibility, + deleteThread, + type ThreadRecord, + updateThread, + updateThreadVisibility, +} from "@/lib/chat/thread-persistence"; + +type SearchSpaceKey = number | string; + +interface VisibilityVariables { + thread: ThreadRecord; + visibility: ChatVisibility; +} + +interface RenameVariables { + threadId: number; + title: string; + previousTitle?: string; +} + +interface ArchiveVariables { + threadId: number; + archived: boolean; +} + +interface DeleteVariables { + threadId: number; +} + +interface VisibilityRollback { + threadId: number; + visibility: ChatVisibility; +} + +interface RenameRollback { + threadId: number; + title?: string; +} + +interface ArchiveRollback { + threadId: number; + archived: boolean; +} + +export function useUpdateThreadVisibility(searchSpaceId: SearchSpaceKey) { + const queryClient = useQueryClient(); + const currentThread = useAtomValue(currentThreadAtom); + const patchCurrentThreadMetadata = useSetAtom(patchCurrentThreadMetadataAtom); + + return useMutation({ + mutationFn: ({ thread, visibility }) => updateThreadVisibility(thread.id, visibility), + onMutate: ({ thread, visibility }) => { + const previousVisibility = thread.visibility ?? "PRIVATE"; + + patchThreadEverywhere(queryClient, searchSpaceId, thread.id, { visibility }); + if (currentThread.id === thread.id) { + patchCurrentThreadMetadata({ id: thread.id, visibility }); + } + + return { threadId: thread.id, visibility: previousVisibility }; + }, + onError: (_error, _variables, rollback) => { + if (!rollback) return; + patchThreadEverywhere(queryClient, searchSpaceId, rollback.threadId, { + visibility: rollback.visibility, + }); + if (currentThread.id === rollback.threadId) { + patchCurrentThreadMetadata({ + id: rollback.threadId, + visibility: rollback.visibility, + }); + } + }, + onSuccess: (thread) => { + replaceThreadEverywhere(queryClient, searchSpaceId, thread); + if (currentThread.id === thread.id) { + patchCurrentThreadMetadata({ + id: thread.id, + visibility: thread.visibility, + ...(thread.has_comments !== undefined ? { hasComments: thread.has_comments } : {}), + }); + } + }, + }); +} + +export function useRenameThread(searchSpaceId: SearchSpaceKey) { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: ({ threadId, title }) => updateThread(threadId, { title }), + onMutate: ({ threadId, title, previousTitle }) => { + patchThreadEverywhere(queryClient, searchSpaceId, threadId, { title }); + return { threadId, title: previousTitle }; + }, + onError: (_error, _variables, rollback) => { + if (!rollback || rollback.title === undefined) return; + patchThreadEverywhere(queryClient, searchSpaceId, rollback.threadId, { + title: rollback.title, + }); + }, + onSuccess: (thread) => { + replaceThreadEverywhere(queryClient, searchSpaceId, thread); + }, + }); +} + +export function useArchiveThread(searchSpaceId: SearchSpaceKey) { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: ({ threadId, archived }) => updateThread(threadId, { archived }), + onMutate: ({ threadId, archived }) => { + moveThreadArchiveState(queryClient, searchSpaceId, threadId, archived); + return { threadId, archived: !archived }; + }, + onError: (_error, _variables, rollback) => { + if (!rollback) return; + moveThreadArchiveState(queryClient, searchSpaceId, rollback.threadId, rollback.archived); + }, + onSuccess: (thread) => { + replaceThreadEverywhere(queryClient, searchSpaceId, thread); + moveThreadArchiveState(queryClient, searchSpaceId, thread.id, thread.archived); + }, + }); +} + +export function useDeleteThread(searchSpaceId: SearchSpaceKey) { + const queryClient = useQueryClient(); + const currentThread = useAtomValue(currentThreadAtom); + const resetCurrentThread = useSetAtom(resetCurrentThreadAtom); + + return useMutation({ + mutationFn: ({ threadId }) => deleteThread(threadId), + onSuccess: (_data, { threadId }) => { + removeThreadEverywhere(queryClient, searchSpaceId, threadId); + if (currentThread.id === threadId) { + resetCurrentThread(); + } + }, + }); +} diff --git a/surfsense_web/lib/chat/thread-cache.ts b/surfsense_web/lib/chat/thread-cache.ts new file mode 100644 index 000000000..789704032 --- /dev/null +++ b/surfsense_web/lib/chat/thread-cache.ts @@ -0,0 +1,249 @@ +import type { QueryClient, QueryKey } from "@tanstack/react-query"; +import type { + ThreadListItem, + ThreadListResponse, + ThreadRecord, +} from "@/lib/chat/thread-persistence"; + +type SearchSpaceKey = number | string; + +type ThreadMetadataPatch = Partial & + Partial & { + has_comments?: boolean; + }; + +function isSameSearchSpace(keyValue: unknown, searchSpaceId: SearchSpaceKey): boolean { + return String(keyValue) === String(searchSpaceId); +} + +function isThreadListResponse(value: unknown): value is ThreadListResponse { + return ( + typeof value === "object" && + value !== null && + Array.isArray((value as ThreadListResponse).threads) && + Array.isArray((value as ThreadListResponse).archived_threads) + ); +} + +function isThreadListItemArray(value: unknown): value is ThreadListItem[] { + return Array.isArray(value); +} + +function listItemPatchFromMetadata(patch: ThreadMetadataPatch): Partial { + const listPatch: Partial = {}; + + if (patch.title !== undefined) listPatch.title = patch.title; + if (patch.archived !== undefined) listPatch.archived = patch.archived; + if (patch.visibility !== undefined) listPatch.visibility = patch.visibility; + if (patch.created_by_id !== undefined) listPatch.created_by_id = patch.created_by_id; + if (patch.created_at !== undefined) listPatch.createdAt = patch.created_at; + if (patch.updated_at !== undefined) listPatch.updatedAt = patch.updated_at; + if (patch.createdAt !== undefined) listPatch.createdAt = patch.createdAt; + if (patch.updatedAt !== undefined) listPatch.updatedAt = patch.updatedAt; + + return listPatch; +} + +function patchListItem( + item: ThreadListItem, + threadId: number, + patch: ThreadMetadataPatch +): ThreadListItem { + if (item.id !== threadId) return item; + return { + ...item, + ...listItemPatchFromMetadata(patch), + }; +} + +function patchThreadListResponse( + response: ThreadListResponse, + threadId: number, + patch: ThreadMetadataPatch +): ThreadListResponse { + return { + ...response, + threads: response.threads.map((item) => patchListItem(item, threadId, patch)), + archived_threads: response.archived_threads.map((item) => patchListItem(item, threadId, patch)), + }; +} + +function patchThreadListItems( + items: ThreadListItem[], + threadId: number, + patch: ThreadMetadataPatch +): ThreadListItem[] { + return items.map((item) => patchListItem(item, threadId, patch)); +} + +function patchThreadRecord( + record: ThreadRecord, + threadId: number, + patch: ThreadMetadataPatch +): ThreadRecord { + if (record.id !== threadId) return record; + return { + ...record, + ...patch, + }; +} + +function threadListQueryFilter(searchSpaceId: SearchSpaceKey) { + return { + predicate: ({ queryKey }: { queryKey: QueryKey }) => + Array.isArray(queryKey) && + queryKey[0] === "threads" && + isSameSearchSpace(queryKey[1], searchSpaceId), + }; +} + +function allThreadsQueryFilter(searchSpaceId: SearchSpaceKey) { + return { + predicate: ({ queryKey }: { queryKey: QueryKey }) => + Array.isArray(queryKey) && + queryKey[0] === "all-threads" && + isSameSearchSpace(queryKey[1], searchSpaceId), + }; +} + +function searchThreadsQueryFilter(searchSpaceId: SearchSpaceKey) { + return { + predicate: ({ queryKey }: { queryKey: QueryKey }) => + Array.isArray(queryKey) && + queryKey[0] === "search-threads" && + isSameSearchSpace(queryKey[1], searchSpaceId), + }; +} + +function threadDetailQueryFilter(searchSpaceId: SearchSpaceKey, threadId: number) { + return { + predicate: ({ queryKey }: { queryKey: QueryKey }) => + Array.isArray(queryKey) && + ((queryKey[0] === "threads" && + queryKey[1] === "detail" && + Number(queryKey[2]) === threadId) || + (queryKey[0] === "threads" && + isSameSearchSpace(queryKey[1], searchSpaceId) && + queryKey[2] === "detail" && + Number(queryKey[3]) === threadId)), + }; +} + +function updateThreadListResponse( + queryClient: QueryClient, + filter: ReturnType, + threadId: number, + patch: ThreadMetadataPatch +): void { + queryClient.setQueriesData(filter, (old) => { + if (!isThreadListResponse(old)) return old; + return patchThreadListResponse(old, threadId, patch); + }); +} + +export function patchThreadEverywhere( + queryClient: QueryClient, + searchSpaceId: SearchSpaceKey, + threadId: number, + patch: ThreadMetadataPatch +): void { + updateThreadListResponse(queryClient, threadListQueryFilter(searchSpaceId), threadId, patch); + updateThreadListResponse(queryClient, allThreadsQueryFilter(searchSpaceId), threadId, patch); + + queryClient.setQueriesData(searchThreadsQueryFilter(searchSpaceId), (old) => { + if (!isThreadListItemArray(old)) return old; + return patchThreadListItems(old, threadId, patch); + }); + + queryClient.setQueriesData( + threadDetailQueryFilter(searchSpaceId, threadId), + (old) => { + if (!old) return old; + return patchThreadRecord(old, threadId, patch); + } + ); +} + +export function replaceThreadEverywhere( + queryClient: QueryClient, + searchSpaceId: SearchSpaceKey, + thread: ThreadRecord +): void { + patchThreadEverywhere(queryClient, searchSpaceId, thread.id, thread); +} + +export function removeThreadEverywhere( + queryClient: QueryClient, + searchSpaceId: SearchSpaceKey, + threadId: number +): void { + const removeFromListResponse = (old: ThreadListResponse | undefined) => { + if (!isThreadListResponse(old)) return old; + return { + ...old, + threads: old.threads.filter((thread) => thread.id !== threadId), + archived_threads: old.archived_threads.filter((thread) => thread.id !== threadId), + }; + }; + + queryClient.setQueriesData( + threadListQueryFilter(searchSpaceId), + removeFromListResponse + ); + queryClient.setQueriesData( + allThreadsQueryFilter(searchSpaceId), + removeFromListResponse + ); + queryClient.setQueriesData(searchThreadsQueryFilter(searchSpaceId), (old) => { + if (!isThreadListItemArray(old)) return old; + return old.filter((thread) => thread.id !== threadId); + }); + queryClient.removeQueries(threadDetailQueryFilter(searchSpaceId, threadId)); +} + +export function moveThreadArchiveState( + queryClient: QueryClient, + searchSpaceId: SearchSpaceKey, + threadId: number, + archived: boolean +): void { + const moveInListResponse = (old: ThreadListResponse | undefined) => { + if (!isThreadListResponse(old)) return old; + + const activeWithoutThread = old.threads.filter((thread) => thread.id !== threadId); + const archivedWithoutThread = old.archived_threads.filter((thread) => thread.id !== threadId); + const existing = + old.threads.find((thread) => thread.id === threadId) ?? + old.archived_threads.find((thread) => thread.id === threadId); + + if (!existing) return old; + + const updated = { ...existing, archived }; + + return { + ...old, + threads: archived ? activeWithoutThread : [updated, ...activeWithoutThread], + archived_threads: archived ? [updated, ...archivedWithoutThread] : archivedWithoutThread, + }; + }; + + queryClient.setQueriesData( + threadListQueryFilter(searchSpaceId), + moveInListResponse + ); + queryClient.setQueriesData( + allThreadsQueryFilter(searchSpaceId), + moveInListResponse + ); + queryClient.setQueriesData(searchThreadsQueryFilter(searchSpaceId), (old) => { + if (!isThreadListItemArray(old)) return old; + return old.map((thread) => (thread.id === threadId ? { ...thread, archived } : thread)); + }); + queryClient.setQueriesData( + threadDetailQueryFilter(searchSpaceId, threadId), + (old) => { + if (!old || old.id !== threadId) return old; + return { ...old, archived }; + } + ); +}