diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 125177084..9f4280063 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -189,3 +189,75 @@ def test_premium_classification_is_error_code_driven(): assert "RATE_LIMIT_KEYWORDS" not in source assert "normalized.includes(" not in source assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source + + +def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook(): + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + source = page_path.read_text(encoding="utf-8") + + assert "onPreAcceptFailure?: () => Promise;" in source + assert "if (!accepted) {" in source + assert "await onPreAcceptFailure?.();" in source + assert "await onAcceptedStreamError?.();" in source + assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source + assert "setMessageDocumentsMap((prev) => {" in source + + +def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): + user_message_path = ( + Path(__file__).resolve().parents[3] / "surfsense_web/components/assistant-ui/user-message.tsx" + ) + source = user_message_path.read_text(encoding="utf-8") + + assert "Not sent. Edit and retry." not in source + assert "failed_pre_accept" not in source + + +def test_network_send_failures_use_unified_retry_toast_message(): + classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + classifier_source = classifier_path.read_text(encoding="utf-8") + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + page_source = page_path.read_text(encoding="utf-8") + + assert '"send_failed_pre_accept"' in classifier_source + assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source + assert "if (withCode.code) return withCode.code;" in classifier_source + assert 'userMessage: "Message not sent. Please retry."' in classifier_source + assert 'userMessage: "Connection issue. Please try again."' in classifier_source + assert "tagPreAcceptSendFailure(error)" in page_source + assert 'existingCode === "THREAD_BUSY"' in page_source + assert 'existingCode === "AUTH_EXPIRED"' in page_source + assert 'existingCode === "UNAUTHORIZED"' in page_source + assert 'existingCode === "RATE_LIMITED"' in page_source + assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source + assert 'errorCode: "NETWORK_ERROR"' not in page_source + assert "Failed to start chat. Please try again." not in page_source + + +def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + source = page_path.read_text(encoding="utf-8") + + # Each flow tracks accepted boundary and passes it into shared terminal handling. + assert "let newAccepted = false;" in source + assert "let resumeAccepted = false;" in source + assert "let regenerateAccepted = false;" in source + assert "accepted: newAccepted," in source + assert "accepted: resumeAccepted," in source + assert "accepted: regenerateAccepted," in source + + # Pre-accept abort in resume/regenerate exits without persistence. + assert "if (!resumeAccepted) return;" in source + assert "if (!regenerateAccepted) return;" in source + + # New flow persists only when accepted and not already persisted. + assert "if (newAccepted && !userPersisted) {" in source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 70e188612..80ee9e9cd 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -206,6 +206,26 @@ async function toHttpResponseError(response: Response): Promise Promise; + onPreAcceptFailure?: () => Promise; onAcceptedStreamError?: () => Promise; }) => { if (error instanceof Error && error.name === "AbortError") { @@ -625,12 +647,14 @@ export default function NewChatPage() { return; } - if (accepted) { + if (!accepted) { + await onPreAcceptFailure?.(); + } else { await onAcceptedStreamError?.(); } await handleChatFailure({ - error, + error: !accepted ? tagPreAcceptSendFailure(error) : error, flow, threadId, assistantMsgId: accepted ? assistantMsgId : "no-persist-assistant", @@ -863,7 +887,12 @@ export default function NewChatPage() { ); } catch (error) { console.error("[NewChatPage] Failed to create thread:", error); - toast.error("Failed to start chat. Please try again."); + await handleChatFailure({ + error: tagPreAcceptSendFailure(error), + flow: "new", + threadId: currentThreadId, + assistantMsgId: "no-persist-assistant", + }); return; } } @@ -948,27 +977,6 @@ export default function NewChatPage() { }); } - appendMessage(currentThreadId, { - role: "user", - content: persistContent, - }) - .then((savedMessage) => { - const newUserMsgId = `msg-${savedMessage.id}`; - setMessages((prev) => - prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) - ); - setMessageDocumentsMap((prev) => { - const docs = prev[userMsgId]; - if (!docs) return prev; - const { [userMsgId]: _, ...rest } = prev; - return { ...rest, [newUserMsgId]: docs }; - }); - if (isNewThread) { - queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); - } - }) - .catch((err) => console.error("Failed to persist user message:", err)); - // Start streaming response setIsRunning(true); const controller = new AbortController(); @@ -988,17 +996,7 @@ export default function NewChatPage() { let wasInterrupted = false; let tokenUsageData: Record | null = null; let newAccepted = false; - - // Add placeholder assistant message - setMessages((prev) => [ - ...prev, - { - id: assistantMsgId, - role: "assistant", - content: [{ type: "text", text: "" }], - createdAt: new Date(), - }, - ]); + let userPersisted = false; try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; @@ -1062,6 +1060,15 @@ export default function NewChatPage() { throw await toHttpResponseError(response); } newAccepted = true; + setMessages((prev) => [ + ...prev, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]); const flushMessages = () => { setMessages((prev) => @@ -1224,6 +1231,20 @@ export default function NewChatPage() { // Skip persistence for interrupted messages -- handleResume will persist the final version const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0 && !wasInterrupted) { + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + logContext: "new chat", + }); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } + } + await persistAssistantTurn({ threadId: currentThreadId, assistantMsgId, @@ -1251,6 +1272,20 @@ export default function NewChatPage() { assistantMsgId, accepted: newAccepted, onAbort: async () => { + if (newAccepted && !userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + logContext: "new chat (aborted)", + }); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } + } + // Request was cancelled by user - persist partial response if any content was received const hasContent = contentParts.some( (part) => @@ -1267,6 +1302,29 @@ export default function NewChatPage() { }); } }, + onAcceptedStreamError: async () => { + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + logContext: "new chat (stream error)", + }); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } + } + }, + onPreAcceptFailure: async () => { + setMessages((prev) => prev.filter((m) => m.id !== userMsgId)); + setMessageDocumentsMap((prev) => { + if (!(userMsgId in prev)) return prev; + const { [userMsgId]: _removed, ...rest } = prev; + return rest; + }); + }, }); } finally { setIsRunning(false); @@ -1291,7 +1349,9 @@ export default function NewChatPage() { setPendingUserImageUrls, toolsWithUI, handleStreamTerminalError, + handleChatFailure, persistAssistantTurn, + persistUserTurn, ] ); @@ -1548,6 +1608,22 @@ export default function NewChatPage() { threadId: resumeThreadId, assistantMsgId, accepted: resumeAccepted, + onAbort: async () => { + if (!resumeAccepted) return; + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + ); + if (!hasContent) return; + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId: resumeThreadId, + assistantMsgId, + content: partialContent, + logContext: "partial resumed chat", + }); + }, }); } finally { setIsRunning(false); @@ -1882,6 +1958,33 @@ export default function NewChatPage() { threadId, assistantMsgId, accepted: regenerateAccepted, + onAbort: async () => { + if (!regenerateAccepted) return; + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + logContext: "regenerated (aborted)", + }); + userPersisted = Boolean(persistedUserMsgId); + } + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + ); + if (!hasContent) return; + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId, + assistantMsgId, + content: partialContent, + tokenUsage: tokenUsageData ?? undefined, + logContext: "partial regenerated chat", + }); + }, onAcceptedStreamError: async () => { if (!userPersisted) { const persistedUserMsgId = await persistUserTurn({ diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 4341f7dc5..57341a4c3 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -3,6 +3,7 @@ export type ChatFlow = "new" | "resume" | "regenerate"; export type ChatErrorKind = | "premium_quota_exhausted" | "thread_busy" + | "send_failed_pre_accept" | "auth_expired" | "rate_limited" | "network_offline" @@ -54,8 +55,9 @@ function getErrorMessage(error: unknown): string { function getErrorCode(error: unknown, parsedJson: Record | null): string | undefined { if (error instanceof Error) { - const withCode = error as Error & { errorCode?: string }; + const withCode = error as Error & { errorCode?: string; code?: string }; if (withCode.errorCode) return withCode.errorCode; + if (withCode.code) return withCode.code; } if (typeof error === "object" && error !== null) { @@ -161,6 +163,20 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if (errorCode === "SEND_FAILED_PRE_ACCEPT") { + return { + kind: "send_failed_pre_accept", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "Message not sent. Please retry.", + rawMessage, + errorCode: errorCode ?? "SEND_FAILED_PRE_ACCEPT", + details: { flow: input.flow }, + }; + } + if ( errorCode === "AUTH_EXPIRED" || errorCode === "UNAUTHORIZED" @@ -196,16 +212,14 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "NETWORK_ERROR" - ) { + if (errorCode === "NETWORK_ERROR") { return { kind: "network_offline", channel: "toast", severity: "warn", telemetryEvent: "chat_error", isExpected: true, - userMessage: "Connection issue detected. Check your internet and try again.", + userMessage: "Connection issue. Please try again.", rawMessage, errorCode: errorCode ?? "NETWORK_ERROR", details: { flow: input.flow },