From 6465ea181a25a8c6d003572ea4707aa9e1dcf3cc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:09:18 +0530 Subject: [PATCH 01/13] refactor(chat): streamline NewChatPage component by removing unused functions and integrating new stream handling utilities for improved performance --- .../new-chat/[[...chat_id]]/page.tsx | 625 +++++++----------- 1 file changed, 255 insertions(+), 370 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index fe625f169..d1dd14e06 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -252,6 +252,168 @@ function tagPreAcceptSendFailure(error: unknown): unknown { }); } +type SharedStreamEventContext = { + contentPartsState: ContentPartsState; + toolsWithUI: ToolUIGate; + currentThinkingSteps: Map; + scheduleFlush: () => void; + forceFlush: () => void; + onTokenUsage?: (data: TokenUsageData) => void; + onToolOutputAvailable?: ( + event: Extract, + context: { + contentPartsState: ContentPartsState; + toolCallIndices: Map; + } + ) => void; +}; + +function createStreamFlushHelpers(flushMessages: () => void): { + batcher: FrameBatchedUpdater; + scheduleFlush: () => void; + forceFlush: () => void; +} { + const batcher = new FrameBatchedUpdater(); + const scheduleFlush = () => batcher.schedule(flushMessages); + // Force-flush helper: ``batcher.flush()`` is a no-op when + // ``dirty=false`` (e.g. a tool starts before any text streamed). + // ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so + // terminal events render promptly without the throttle delay. + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; + return { batcher, scheduleFlush, forceFlush }; +} + +function hasPersistableContent(contentParts: ContentPartsState["contentParts"], toolsWithUI: ToolUIGate) { + return contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) + ); +} + +function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { + const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; + const { contentParts, toolCallIndices } = contentPartsState; + + switch (parsed.type) { + case "text-delta": + appendText(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + return true; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + return true; + + case "finish-step": + return true; + + case "tool-input-start": + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); + forceFlush(); + return true; + + case "tool-input-delta": + // High-frequency event: deltas can fire dozens of times per call, + // so use throttled scheduleFlush (NOT forceFlush) to coalesce. + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + return true; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + argsText: finalArgsText, + langchainToolCallId: parsed.langchainToolCallId, + }); + } else { + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + parsed.input || {}, + false, + parsed.langchainToolCallId + ); + // addToolCall doesn't accept argsText today; backfill via + // updateToolCall so the new card renders pretty-printed JSON. + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); + } + forceFlush(); + return true; + } + + case "tool-output-available": + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); + markInterruptsCompleted(contentParts); + context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices }); + forceFlush(); + return true; + + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); + if (didUpdate) { + scheduleFlush(); + } + } + return true; + } + + case "data-token-usage": + context.onTokenUsage?.(parsed.data as TokenUsageData); + return true; + + case "error": + throw toStreamTerminalError(parsed); + + default: + return false; + } +} + +async function consumeSseEvents( + response: Response, + onEvent: (event: SSEEvent) => void | Promise +): Promise { + for await (const parsed of readSSEStream(response)) { + await onEvent(parsed); + } +} + /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -456,7 +618,7 @@ export default function NewChatPage() { threadId: number | null; assistantMsgId: string; content: unknown; - tokenUsage?: Record; + tokenUsage?: TokenUsageData; turnId?: string | null; logContext: string; onRemapped?: (newMsgId: string) => void; @@ -1055,8 +1217,6 @@ export default function NewChatPage() { // Prepare assistant message const assistantMsgId = `msg-assistant-${Date.now()}`; const currentThinkingSteps = new Map(); - const batcher = new FrameBatchedUpdater(); - const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, @@ -1065,11 +1225,12 @@ export default function NewChatPage() { }; const { contentParts, toolCallIndices } = contentPartsState; let wasInterrupted = false; - let tokenUsageData: Record | null = null; + let tokenUsageData: TokenUsageData | null = null; let newAccepted = false; let userPersisted = false; // Captured from ``data-turn-info`` at stream start. let streamedChatTurnId: string | null = null; + let streamBatcher: FrameBatchedUpdater | null = null; try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; @@ -1152,123 +1313,37 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - // Force-flush helper: ``batcher.flush()`` is a no-op when - // ``dirty=false`` (e.g. a tool starts before any text - // streamed). ``scheduleFlush(); batcher.flush()`` sets - // the dirty bit FIRST so terminal events render - // promptly without the 50ms throttle delay. - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - // High-frequency event: deltas can fire dozens - // of times per call, so use throttled - // scheduleFlush (NOT forceFlush) to coalesce. - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - // addToolCall doesn't accept argsText today; - // backfill via updateToolCall so the new card - // renders pretty-printed JSON. - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": { - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { - const idx = toolCallIndices.get(parsed.toolCallId); - if (idx !== undefined) { - const part = contentParts[idx]; - if (part?.type === "tool-call" && part.toolName === "generate_podcast") { - setActivePodcastTaskId(String(parsed.output.podcast_id)); + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageData = data; + tokenUsageStore.set(assistantMsgId, data); + }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } } } - } - forceFlush(); - break; - } - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - + }, + }) + ) { + return; + } + switch (parsed.type) { case "data-thread-title-update": { const titleData = parsed.data as { threadId: number; title: string }; if (titleData?.title && titleData?.threadId === currentThreadId) { @@ -1374,16 +1449,8 @@ export default function NewChatPage() { } break; } - - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw toStreamTerminalError(parsed); } - } + }); batcher.flush(); @@ -1425,7 +1492,7 @@ export default function NewChatPage() { trackChatResponseReceived(searchSpaceId, currentThreadId); } } catch (error) { - batcher.dispose(); + streamBatcher?.dispose(); await handleStreamTerminalError({ error, flow: "new", @@ -1448,13 +1515,7 @@ export default function NewChatPage() { } } - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); + const hasContent = hasPersistableContent(contentParts, toolsWithUI); if (hasContent && currentThreadId) { const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); await persistAssistantTurn({ @@ -1543,7 +1604,6 @@ export default function NewChatPage() { abortControllerRef.current = controller; const currentThinkingSteps = new Map(); - const batcher = new FrameBatchedUpdater(); const contentPartsState: ContentPartsState = { contentParts: [], @@ -1552,10 +1612,11 @@ export default function NewChatPage() { toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; - let tokenUsageData: Record | null = null; + let tokenUsageData: TokenUsageData | null = null; let resumeAccepted = false; // Captured from ``data-turn-info`` at stream start. let streamedChatTurnId: string | null = null; + let streamBatcher: FrameBatchedUpdater | null = null; const existingMsg = messages.find((m) => m.id === assistantMsgId); if (existingMsg && Array.isArray(existingMsg.content)) { @@ -1664,102 +1725,26 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageData = data; + tokenUsageStore.set(assistantMsgId, data); + }, + }) + ) { + return; + } switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - forceFlush(); - break; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - case "data-interrupt-request": { const interruptData = parsed.data as Record; const actionRequests = (interruptData.action_requests ?? []) as Array<{ @@ -1830,16 +1815,8 @@ export default function NewChatPage() { } break; } - - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw toStreamTerminalError(parsed); } - } + }); batcher.flush(); @@ -1855,7 +1832,7 @@ export default function NewChatPage() { }); } } catch (error) { - batcher.dispose(); + streamBatcher?.dispose(); await handleStreamTerminalError({ error, flow: "resume", @@ -1864,13 +1841,7 @@ export default function NewChatPage() { accepted: resumeAccepted, onAbort: async () => { if (!resumeAccepted) return; - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); + const hasContent = hasPersistableContent(contentParts, toolsWithUI); if (!hasContent) return; const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); await persistAssistantTurn({ @@ -1891,6 +1862,7 @@ export default function NewChatPage() { pendingInterrupt, messages, searchSpaceId, + queryClient, tokenUsageStore, handleStreamTerminalError, persistAssistantTurn, @@ -2045,15 +2017,15 @@ export default function NewChatPage() { currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; - const { contentParts, toolCallIndices } = contentPartsState; - const batcher = new FrameBatchedUpdater(); - let tokenUsageData: Record | null = null; + const { contentParts } = contentPartsState; + let tokenUsageData: TokenUsageData | null = null; let regenerateAccepted = false; let userPersisted = false; // Captured from ``data-turn-info`` at stream start; stamped // onto persisted messages so future edits can locate the // right LangGraph checkpoint. let streamedChatTurnId: string | null = null; + let streamBatcher: FrameBatchedUpdater | null = null; // Add placeholder messages to UI // Always add back the user message (with new query for edit, or original content for reload) @@ -2155,111 +2127,37 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { - const idx = toolCallIndices.get(parsed.toolCallId); - if (idx !== undefined) { - const part = contentParts[idx]; - if (part?.type === "tool-call" && part.toolName === "generate_podcast") { - setActivePodcastTaskId(String(parsed.output.podcast_id)); + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageData = data; + tokenUsageStore.set(assistantMsgId, data); + }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } } } - } - forceFlush(); - break; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - + }, + }) + ) { + return; + } + switch (parsed.type) { case "data-action-log": { if (threadId !== null) { applyActionLogSse(queryClient, threadId, searchSpaceId, parsed.data); @@ -2326,16 +2224,8 @@ export default function NewChatPage() { } break; } - - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw toStreamTerminalError(parsed); } - } + }); batcher.flush(); @@ -2364,7 +2254,7 @@ export default function NewChatPage() { trackChatResponseReceived(searchSpaceId, threadId); } } catch (error) { - batcher.dispose(); + streamBatcher?.dispose(); await handleStreamTerminalError({ error, flow: "regenerate", @@ -2384,13 +2274,7 @@ export default function NewChatPage() { }); userPersisted = Boolean(persistedUserMsgId); } - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); + const hasContent = hasPersistableContent(contentParts, toolsWithUI); if (!hasContent) return; const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); await persistAssistantTurn({ @@ -2428,6 +2312,7 @@ export default function NewChatPage() { disabledTools, messageDocumentsMap, setMessageDocumentsMap, + queryClient, tokenUsageStore, handleStreamTerminalError, persistAssistantTurn, From 86f6b285ce9cedbf529a7d8325f4457f602f997a Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:09:34 +0530 Subject: [PATCH 02/13] refactor(chat): introduce new stream handling utilities and restructure event processing for improved performance and maintainability --- .../new-chat/[[...chat_id]]/page.tsx | 205 +----------------- surfsense_web/lib/chat/stream-flush.ts | 19 ++ surfsense_web/lib/chat/stream-pipeline.ts | 191 ++++++++++++++++ 3 files changed, 217 insertions(+), 198 deletions(-) create mode 100644 surfsense_web/lib/chat/stream-flush.ts create mode 100644 surfsense_web/lib/chat/stream-pipeline.ts diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index d1dd14e06..82a12b6b1 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -71,23 +71,21 @@ import { setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; import { - addStepSeparator, addToolCall, - appendReasoning, - appendText, - appendToolInputDelta, buildContentForPersistence, buildContentForUI, type ContentPartsState, - endReasoning, - FrameBatchedUpdater, - readSSEStream, - type SSEEvent, + type FrameBatchedUpdater, type ThinkingStepData, type ToolUIGate, - updateThinkingSteps, updateToolCall, } from "@/lib/chat/streaming-state"; +import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; +import { + consumeSseEvents, + hasPersistableContent, + processSharedStreamEvent, +} from "@/lib/chat/stream-pipeline"; import { appendMessage, createThread, @@ -134,33 +132,6 @@ const MobileReportPanel = dynamic( { ssr: false } ); -/** - * After a tool produces output, mark any previously-decided interrupt tool - * calls as completed so the ApprovalCard can transition from shimmer to done. - */ -function markInterruptsCompleted(contentParts: Array<{ type: string; result?: unknown }>): void { - for (const part of contentParts) { - if ( - part.type === "tool-call" && - typeof part.result === "object" && - part.result !== null && - (part.result as Record).__interrupt__ === true && - (part.result as Record).__decided__ && - !(part.result as Record).__completed__ - ) { - part.result = { ...(part.result as Record), __completed__: true }; - } - } -} - -function toStreamTerminalError( - event: Extract -): Error & { errorCode?: string } { - return Object.assign(new Error(event.errorText || "Server error"), { - errorCode: event.errorCode, - }); -} - async function toHttpResponseError(response: Response): Promise { const statusDefaultCode = response.status === 409 @@ -252,168 +223,6 @@ function tagPreAcceptSendFailure(error: unknown): unknown { }); } -type SharedStreamEventContext = { - contentPartsState: ContentPartsState; - toolsWithUI: ToolUIGate; - currentThinkingSteps: Map; - scheduleFlush: () => void; - forceFlush: () => void; - onTokenUsage?: (data: TokenUsageData) => void; - onToolOutputAvailable?: ( - event: Extract, - context: { - contentPartsState: ContentPartsState; - toolCallIndices: Map; - } - ) => void; -}; - -function createStreamFlushHelpers(flushMessages: () => void): { - batcher: FrameBatchedUpdater; - scheduleFlush: () => void; - forceFlush: () => void; -} { - const batcher = new FrameBatchedUpdater(); - const scheduleFlush = () => batcher.schedule(flushMessages); - // Force-flush helper: ``batcher.flush()`` is a no-op when - // ``dirty=false`` (e.g. a tool starts before any text streamed). - // ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so - // terminal events render promptly without the throttle delay. - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; - return { batcher, scheduleFlush, forceFlush }; -} - -function hasPersistableContent(contentParts: ContentPartsState["contentParts"], toolsWithUI: ToolUIGate) { - return contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); -} - -function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { - const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; - const { contentParts, toolCallIndices } = contentPartsState; - - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - return true; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - return true; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - return true; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - return true; - - case "finish-step": - return true; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - return true; - - case "tool-input-delta": - // High-frequency event: deltas can fire dozens of times per call, - // so use throttled scheduleFlush (NOT forceFlush) to coalesce. - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - return true; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - // addToolCall doesn't accept argsText today; backfill via - // updateToolCall so the new card renders pretty-printed JSON. - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - return true; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices }); - forceFlush(); - return true; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - return true; - } - - case "data-token-usage": - context.onTokenUsage?.(parsed.data as TokenUsageData); - return true; - - case "error": - throw toStreamTerminalError(parsed); - - default: - return false; - } -} - -async function consumeSseEvents( - response: Response, - onEvent: (event: SSEEvent) => void | Promise -): Promise { - for await (const parsed of readSSEStream(response)) { - await onEvent(parsed); - } -} - /** * Zod schema for mentioned document info (for type-safe parsing) */ diff --git a/surfsense_web/lib/chat/stream-flush.ts b/surfsense_web/lib/chat/stream-flush.ts new file mode 100644 index 000000000..6d13c9237 --- /dev/null +++ b/surfsense_web/lib/chat/stream-flush.ts @@ -0,0 +1,19 @@ +import { FrameBatchedUpdater } from "@/lib/chat/streaming-state"; + +export function createStreamFlushHelpers(flushMessages: () => void): { + batcher: FrameBatchedUpdater; + scheduleFlush: () => void; + forceFlush: () => void; +} { + const batcher = new FrameBatchedUpdater(); + const scheduleFlush = () => batcher.schedule(flushMessages); + // Force-flush helper: ``batcher.flush()`` is a no-op when + // ``dirty=false`` (e.g. a tool starts before any text streamed). + // ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so + // terminal events render promptly without the throttle delay. + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; + return { batcher, scheduleFlush, forceFlush }; +} diff --git a/surfsense_web/lib/chat/stream-pipeline.ts b/surfsense_web/lib/chat/stream-pipeline.ts new file mode 100644 index 000000000..8957bdea3 --- /dev/null +++ b/surfsense_web/lib/chat/stream-pipeline.ts @@ -0,0 +1,191 @@ +import { + addStepSeparator, + addToolCall, + appendReasoning, + appendText, + appendToolInputDelta, + type ContentPartsState, + endReasoning, + readSSEStream, + type SSEEvent, + type ThinkingStepData, + type ToolUIGate, + updateThinkingSteps, + updateToolCall, +} from "@/lib/chat/streaming-state"; + +export type SharedStreamEventContext = { + contentPartsState: ContentPartsState; + toolsWithUI: ToolUIGate; + currentThinkingSteps: Map; + scheduleFlush: () => void; + forceFlush: () => void; + onTokenUsage?: (data: Extract["data"]) => void; + onToolOutputAvailable?: ( + event: Extract, + context: { + contentPartsState: ContentPartsState; + toolCallIndices: Map; + } + ) => void; +}; + +/** + * After a tool produces output, mark any previously-decided interrupt tool + * calls as completed so the ApprovalCard can transition from shimmer to done. + */ +export function markInterruptsCompleted( + contentParts: Array<{ type: string; result?: unknown }> +): void { + for (const part of contentParts) { + if ( + part.type === "tool-call" && + typeof part.result === "object" && + part.result !== null && + (part.result as Record).__interrupt__ === true && + (part.result as Record).__decided__ && + !(part.result as Record).__completed__ + ) { + part.result = { ...(part.result as Record), __completed__: true }; + } + } +} + +export function hasPersistableContent( + contentParts: ContentPartsState["contentParts"], + toolsWithUI: ToolUIGate +) { + return contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) + ); +} + +function toStreamTerminalError( + event: Extract +): Error & { errorCode?: string } { + return Object.assign(new Error(event.errorText || "Server error"), { + errorCode: event.errorCode, + }); +} + +export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { + const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; + const { contentParts, toolCallIndices } = contentPartsState; + + switch (parsed.type) { + case "text-delta": + appendText(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + return true; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + return true; + + case "finish-step": + return true; + + case "tool-input-start": + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); + forceFlush(); + return true; + + case "tool-input-delta": + // High-frequency event: deltas can fire dozens of times per call, + // so use throttled scheduleFlush (NOT forceFlush) to coalesce. + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + return true; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + argsText: finalArgsText, + langchainToolCallId: parsed.langchainToolCallId, + }); + } else { + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + parsed.input || {}, + false, + parsed.langchainToolCallId + ); + // addToolCall doesn't accept argsText today; backfill via + // updateToolCall so the new card renders pretty-printed JSON. + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); + } + forceFlush(); + return true; + } + + case "tool-output-available": + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); + markInterruptsCompleted(contentParts); + context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices }); + forceFlush(); + return true; + + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); + if (didUpdate) { + scheduleFlush(); + } + } + return true; + } + + case "data-token-usage": + context.onTokenUsage?.(parsed.data); + return true; + + case "error": + throw toStreamTerminalError(parsed); + + default: + return false; + } +} + +export async function consumeSseEvents( + response: Response, + onEvent: (event: SSEEvent) => void | Promise +): Promise { + for await (const parsed of readSSEStream(response)) { + await onEvent(parsed); + } +} From d65a3fdf76364b0705eaff0953f4d7283ecafde2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:22:34 +0530 Subject: [PATCH 03/13] refactor(chat): implement new error handling utilities and streamline interrupt request processing in NewChatPage for improved performance and maintainability --- .../new-chat/[[...chat_id]]/page.tsx | 238 +++--------------- surfsense_web/lib/chat/chat-request-errors.ts | 89 +++++++ surfsense_web/lib/chat/stream-side-effects.ts | 127 ++++++++++ 3 files changed, 246 insertions(+), 208 deletions(-) create mode 100644 surfsense_web/lib/chat/chat-request-errors.ts create mode 100644 surfsense_web/lib/chat/stream-side-effects.ts diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 82a12b6b1..02c2914be 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -64,6 +64,10 @@ import { classifyChatError, type ChatFlow, } from "@/lib/chat/chat-error-classifier"; +import { + tagPreAcceptSendFailure, + toHttpResponseError, +} from "@/lib/chat/chat-request-errors"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { isPodcastGenerating, @@ -71,14 +75,12 @@ import { setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; import { - addToolCall, buildContentForPersistence, buildContentForUI, type ContentPartsState, type FrameBatchedUpdater, type ThinkingStepData, type ToolUIGate, - updateToolCall, } from "@/lib/chat/streaming-state"; import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; import { @@ -86,6 +88,14 @@ import { hasPersistableContent, processSharedStreamEvent, } from "@/lib/chat/stream-pipeline"; +import { + applyTurnIdToAssistantMessageList, + applyInterruptRequestToContentParts, + mergeChatTurnIdIntoMessage, + mergeEditedInterruptAction, + markInterruptDecisionOnContentParts, + readStreamedChatTurnId, +} from "@/lib/chat/stream-side-effects"; import { appendMessage, createThread, @@ -132,97 +142,6 @@ const MobileReportPanel = dynamic( { ssr: false } ); -async function toHttpResponseError(response: Response): Promise { - const statusDefaultCode = - response.status === 409 - ? "THREAD_BUSY" - : response.status === 429 - ? "RATE_LIMITED" - : response.status === 401 || response.status === 403 - ? "AUTH_EXPIRED" - : "SERVER_ERROR"; - - let rawBody = ""; - try { - rawBody = await response.text(); - } catch { - // noop - } - - let parsedBody: Record | null = null; - if (rawBody) { - try { - const parsed = JSON.parse(rawBody); - if (typeof parsed === "object" && parsed !== null) { - parsedBody = parsed as Record; - } - } catch { - // noop - } - } - - const detail = parsedBody?.detail; - const detailObject = - typeof detail === "object" && detail !== null ? (detail as Record) : null; - const detailMessage = typeof detail === "string" ? detail : undefined; - const topLevelMessage = - typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined; - const detailNestedMessage = - typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined; - - const topLevelCode = - typeof parsedBody?.errorCode === "string" - ? parsedBody.errorCode - : typeof parsedBody?.error_code === "string" - ? parsedBody.error_code - : undefined; - const detailCode = - typeof detailObject?.errorCode === "string" - ? detailObject.errorCode - : typeof detailObject?.error_code === "string" - ? detailObject.error_code - : undefined; - - const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; - const message = - detailNestedMessage ?? - detailMessage ?? - topLevelMessage ?? - `Backend error: ${response.status}`; - - return Object.assign(new Error(message), { errorCode }); -} - -function tagPreAcceptSendFailure(error: unknown): unknown { - if (error instanceof Error) { - const withCode = error as Error & { errorCode?: string; code?: string }; - const existingCode = withCode.errorCode ?? withCode.code; - const passthroughCodes = new Set([ - "PREMIUM_QUOTA_EXHAUSTED", - "THREAD_BUSY", - "AUTH_EXPIRED", - "UNAUTHORIZED", - "RATE_LIMITED", - "NETWORK_ERROR", - "STREAM_PARSE_ERROR", - "TOOL_EXECUTION_ERROR", - "PERSIST_MESSAGE_FAILED", - "SERVER_ERROR", - ]); - if ( - existingCode && - passthroughCodes.has(existingCode) - ) { - return Object.assign(error, { errorCode: existingCode }); - } - return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" }); - } - - return Object.assign(new Error("Failed to send message before stream acceptance"), { - errorCode: "SEND_FAILED_PRE_ACCEPT", - }); -} - /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -264,29 +183,6 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { */ const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; -/** - * When a streamed message is persisted, the backend returns the durable - * ``turn_id`` (``configurable.turn_id`` from the agent run). Merge it - * into the assistant-ui message metadata so the per-turn "Revert turn" - * button can scope to this turn's actions even after a full chat reload. - */ -function mergeChatTurnIdIntoMessage( - msg: ThreadMessageLike, - turnId: string | null | undefined -): ThreadMessageLike { - if (!turnId) return msg; - const existingMeta = (msg.metadata ?? {}) as { custom?: Record }; - const existingCustom = existingMeta.custom ?? {}; - if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg; - return { - ...msg, - metadata: { - ...existingMeta, - custom: { ...existingCustom, chatTurnId: turnId }, - }, - }; -} - export default function NewChatPage() { const params = useParams(); const queryClient = useQueryClient(); @@ -1032,7 +928,7 @@ export default function NewChatPage() { currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; - const { contentParts, toolCallIndices } = contentPartsState; + const { contentParts } = contentPartsState; let wasInterrupted = false; let tokenUsageData: TokenUsageData | null = null; let newAccepted = false; @@ -1194,27 +1090,7 @@ export default function NewChatPage() { case "data-interrupt-request": { wasInterrupted = true; const interruptData = parsed.data as Record; - const actionRequests = (interruptData.action_requests ?? []) as Array<{ - name: string; - args: Record; - }>; - for (const action of actionRequests) { - const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => { - const part = contentParts[idx]; - return part?.type === "tool-call" && part.toolName === action.name; - }); - if (existingIdx) { - updateToolCall(contentPartsState, existingIdx[0], { - result: { __interrupt__: true, ...interruptData }, - }); - } else { - const tcId = `interrupt-${action.name}`; - addToolCall(contentPartsState, toolsWithUI, tcId, action.name, action.args, true); - updateToolCall(contentPartsState, tcId, { - result: { __interrupt__: true, ...interruptData }, - }); - } - } + applyInterruptRequestToContentParts(contentPartsState, toolsWithUI, interruptData); setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -1248,12 +1124,11 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + streamedChatTurnId = turnId; + if (turnId) { setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; @@ -1469,37 +1344,12 @@ export default function NewChatPage() { } // Merge edited args if present to fix race condition - if (decisions.length > 0 && decisions[0].type === "edit" && decisions[0].edited_action) { - const editedAction = decisions[0].edited_action; - for (const part of contentParts) { - if (part.type === "tool-call" && part.toolName === editedAction.name) { - const mergedArgs = { ...part.args, ...editedAction.args }; - part.args = mergedArgs; - // Sync argsText so the rendered card shows the - // edited inputs — assistant-ui prefers caller- - // supplied argsText over JSON.stringify(args). - part.argsText = JSON.stringify(mergedArgs, null, 2); - break; - } - } + if (decisions.length > 0 && decisions[0].type === "edit") { + mergeEditedInterruptAction(contentParts, decisions[0].edited_action); } const decisionType = decisions[0]?.type as "approve" | "reject" | undefined; - if (decisionType) { - for (const part of contentParts) { - if ( - part.type === "tool-call" && - typeof part.result === "object" && - part.result !== null && - "__interrupt__" in (part.result as Record) - ) { - part.result = { - ...(part.result as Record), - __decided__: decisionType, - }; - } - } - } + markInterruptDecisionOnContentParts(contentParts, decisionType); try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; @@ -1556,33 +1406,7 @@ export default function NewChatPage() { switch (parsed.type) { case "data-interrupt-request": { const interruptData = parsed.data as Record; - const actionRequests = (interruptData.action_requests ?? []) as Array<{ - name: string; - args: Record; - }>; - for (const action of actionRequests) { - const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => { - const part = contentParts[idx]; - return part?.type === "tool-call" && part.toolName === action.name; - }); - if (existingIdx) { - updateToolCall(contentPartsState, existingIdx[0], { - result: { - __interrupt__: true, - ...interruptData, - }, - }); - } else { - const tcId = `interrupt-${action.name}`; - addToolCall(contentPartsState, toolsWithUI, tcId, action.name, action.args, true); - updateToolCall(contentPartsState, tcId, { - result: { - __interrupt__: true, - ...interruptData, - }, - }); - } - } + applyInterruptRequestToContentParts(contentPartsState, toolsWithUI, interruptData); setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -1614,12 +1438,11 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + streamedChatTurnId = turnId; + if (turnId) { setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; @@ -1987,12 +1810,11 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + streamedChatTurnId = turnId; + if (turnId) { setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts new file mode 100644 index 000000000..3026e8203 --- /dev/null +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -0,0 +1,89 @@ +export async function toHttpResponseError( + response: Response +): Promise { + const statusDefaultCode = + response.status === 409 + ? "THREAD_BUSY" + : response.status === 429 + ? "RATE_LIMITED" + : response.status === 401 || response.status === 403 + ? "AUTH_EXPIRED" + : "SERVER_ERROR"; + + let rawBody = ""; + try { + rawBody = await response.text(); + } catch { + // noop + } + + let parsedBody: Record | null = null; + if (rawBody) { + try { + const parsed = JSON.parse(rawBody); + if (typeof parsed === "object" && parsed !== null) { + parsedBody = parsed as Record; + } + } catch { + // noop + } + } + + const detail = parsedBody?.detail; + const detailObject = + typeof detail === "object" && detail !== null ? (detail as Record) : null; + const detailMessage = typeof detail === "string" ? detail : undefined; + const topLevelMessage = + typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined; + const detailNestedMessage = + typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined; + + const topLevelCode = + typeof parsedBody?.errorCode === "string" + ? parsedBody.errorCode + : typeof parsedBody?.error_code === "string" + ? parsedBody.error_code + : undefined; + const detailCode = + typeof detailObject?.errorCode === "string" + ? detailObject.errorCode + : typeof detailObject?.error_code === "string" + ? detailObject.error_code + : undefined; + + const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; + const message = + detailNestedMessage ?? + detailMessage ?? + topLevelMessage ?? + `Backend error: ${response.status}`; + + return Object.assign(new Error(message), { errorCode }); +} + +export function tagPreAcceptSendFailure(error: unknown): unknown { + if (error instanceof Error) { + const withCode = error as Error & { errorCode?: string; code?: string }; + const existingCode = withCode.errorCode ?? withCode.code; + const passthroughCodes = new Set([ + "PREMIUM_QUOTA_EXHAUSTED", + "THREAD_BUSY", + "AUTH_EXPIRED", + "UNAUTHORIZED", + "RATE_LIMITED", + "NETWORK_ERROR", + "STREAM_PARSE_ERROR", + "TOOL_EXECUTION_ERROR", + "PERSIST_MESSAGE_FAILED", + "SERVER_ERROR", + ]); + if (existingCode && passthroughCodes.has(existingCode)) { + return Object.assign(error, { errorCode: existingCode }); + } + return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" }); + } + + return Object.assign(new Error("Failed to send message before stream acceptance"), { + errorCode: "SEND_FAILED_PRE_ACCEPT", + }); +} diff --git a/surfsense_web/lib/chat/stream-side-effects.ts b/surfsense_web/lib/chat/stream-side-effects.ts new file mode 100644 index 000000000..9cb349458 --- /dev/null +++ b/surfsense_web/lib/chat/stream-side-effects.ts @@ -0,0 +1,127 @@ +import type { ThreadMessageLike } from "@assistant-ui/react"; +import { + addToolCall, + type ContentPartsState, + type ToolUIGate, + updateToolCall, +} from "@/lib/chat/streaming-state"; + +type InterruptActionRequest = { + name: string; + args: Record; +}; + +export type EditedInterruptAction = { + name: string; + args: Record; +}; + +function readInterruptActions( + interruptData: Record +): InterruptActionRequest[] { + return (interruptData.action_requests ?? []) as InterruptActionRequest[]; +} + +/** + * Applies an interrupt request payload to tool-call parts. Existing tool cards + * are updated in-place; missing ones are upserted so approval UI always shows. + */ +export function applyInterruptRequestToContentParts( + contentPartsState: ContentPartsState, + toolsWithUI: ToolUIGate, + interruptData: Record +): void { + const { contentParts, toolCallIndices } = contentPartsState; + const actionRequests = readInterruptActions(interruptData); + for (const action of actionRequests) { + const existingEntry = Array.from(toolCallIndices.entries()).find(([, idx]) => { + const part = contentParts[idx]; + return part?.type === "tool-call" && part.toolName === action.name; + }); + + if (existingEntry) { + updateToolCall(contentPartsState, existingEntry[0], { + result: { __interrupt__: true, ...interruptData }, + }); + } else { + const toolCallId = `interrupt-${action.name}`; + addToolCall(contentPartsState, toolsWithUI, toolCallId, action.name, action.args, true); + updateToolCall(contentPartsState, toolCallId, { + result: { __interrupt__: true, ...interruptData }, + }); + } + } +} + +export function mergeEditedInterruptAction( + contentParts: ContentPartsState["contentParts"], + editedAction: EditedInterruptAction | undefined +): void { + if (!editedAction) return; + for (const part of contentParts) { + if (part.type === "tool-call" && part.toolName === editedAction.name) { + const mergedArgs = { ...part.args, ...editedAction.args }; + part.args = mergedArgs; + // assistant-ui prefers argsText over JSON.stringify(args) + part.argsText = JSON.stringify(mergedArgs, null, 2); + break; + } + } +} + +export function markInterruptDecisionOnContentParts( + contentParts: ContentPartsState["contentParts"], + decisionType: "approve" | "reject" | undefined +): void { + if (!decisionType) return; + for (const part of contentParts) { + if ( + part.type === "tool-call" && + typeof part.result === "object" && + part.result !== null && + "__interrupt__" in (part.result as Record) + ) { + part.result = { + ...(part.result as Record), + __decided__: decisionType, + }; + } + } +} + +/** + * When a streamed message is persisted, the backend returns the durable + * turn_id; merge it into assistant-ui metadata for turn-scoped actions. + */ +export function mergeChatTurnIdIntoMessage( + msg: ThreadMessageLike, + turnId: string | null | undefined +): ThreadMessageLike { + if (!turnId) return msg; + const existingMeta = (msg.metadata ?? {}) as { custom?: Record }; + const existingCustom = existingMeta.custom ?? {}; + if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg; + return { + ...msg, + metadata: { + ...existingMeta, + custom: { ...existingCustom, chatTurnId: turnId }, + }, + }; +} + +export function readStreamedChatTurnId(data: unknown): string | null { + if (typeof data !== "object" || data === null) return null; + const value = (data as { chat_turn_id?: unknown }).chat_turn_id; + return typeof value === "string" && value.length > 0 ? value : null; +} + +export function applyTurnIdToAssistantMessageList( + messages: ThreadMessageLike[], + assistantMsgId: string, + turnId: string +): ThreadMessageLike[] { + return messages.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, turnId) : m + ); +} From 4056bd1d6947703652e612ac425dabc3ec3c67da Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 22:37:11 +0530 Subject: [PATCH 04/13] refactor(chat): update resetCurrentThreadAtom to include shareToken and contentType for enhanced report panel state management --- surfsense_web/atoms/chat/current-thread.atom.ts | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/surfsense_web/atoms/chat/current-thread.atom.ts b/surfsense_web/atoms/chat/current-thread.atom.ts index d781df8d2..131c98309 100644 --- a/surfsense_web/atoms/chat/current-thread.atom.ts +++ b/surfsense_web/atoms/chat/current-thread.atom.ts @@ -26,7 +26,14 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat export const resetCurrentThreadAtom = atom(null, (_, set) => { set(currentThreadAtom, initialState); - set(reportPanelAtom, { isOpen: false, reportId: null, title: null, wordCount: null }); + set(reportPanelAtom, { + isOpen: false, + reportId: null, + title: null, + wordCount: null, + shareToken: null, + contentType: "markdown", + }); }); /** Target comment ID to scroll to (from URL navigation or inbox click) */ From af66fbf106921822a895536c358f2b1a9b93b7a8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 01:47:52 +0530 Subject: [PATCH 05/13] refactor(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling --- .../agents/new_chat/middleware/busy_mutex.py | 56 ++++- .../app/routes/new_chat_routes.py | 169 ++++++++++++++- surfsense_backend/app/schemas/new_chat.py | 18 ++ .../app/services/new_streaming_service.py | 11 +- .../app/tasks/chat/stream_new_chat.py | 75 ++++++- .../unit/agents/new_chat/test_busy_mutex.py | 30 +++ .../unit/test_stream_new_chat_contract.py | 139 ++++++++++--- .../new-chat/[[...chat_id]]/page.tsx | 194 +++++++++++++----- .../lib/chat/chat-error-classifier.ts | 18 +- surfsense_web/lib/chat/chat-request-errors.ts | 29 ++- surfsense_web/lib/chat/stream-pipeline.ts | 5 + surfsense_web/lib/chat/streaming-state.ts | 8 + 12 files changed, 671 insertions(+), 81 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py index c57d85004..d61a56533 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py +++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py @@ -33,6 +33,7 @@ from __future__ import annotations import asyncio import logging +import time import weakref from typing import Any @@ -58,6 +59,8 @@ class _ThreadLockManager: weakref.WeakValueDictionary() ) self._cancel_events: dict[str, asyncio.Event] = {} + self._cancel_requested_at_ms: dict[str, int] = {} + self._cancel_attempt_count: dict[str, int] = {} def lock_for(self, thread_id: str) -> asyncio.Lock: lock = self._locks.get(thread_id) @@ -76,14 +79,45 @@ class _ThreadLockManager: def request_cancel(self, thread_id: str) -> bool: event = self._cancel_events.get(thread_id) if event is None: - return False + event = asyncio.Event() + self._cancel_events[thread_id] = event event.set() + now_ms = int(time.time() * 1000) + self._cancel_requested_at_ms[thread_id] = now_ms + self._cancel_attempt_count[thread_id] = ( + self._cancel_attempt_count.get(thread_id, 0) + 1 + ) return True + def is_cancel_requested(self, thread_id: str) -> bool: + event = self._cancel_events.get(thread_id) + return bool(event and event.is_set()) + + def cancel_state(self, thread_id: str) -> tuple[int, int] | None: + if not self.is_cancel_requested(thread_id): + return None + attempts = self._cancel_attempt_count.get(thread_id, 1) + requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0) + return attempts, requested_at_ms + def reset(self, thread_id: str) -> None: event = self._cancel_events.get(thread_id) if event is not None: event.clear() + self._cancel_requested_at_ms.pop(thread_id, None) + self._cancel_attempt_count.pop(thread_id, None) + + def end_turn(self, thread_id: str) -> None: + """Best-effort terminal cleanup for a thread turn. + + This is intentionally idempotent and safe to call from outer stream + finally-blocks where middleware teardown might be skipped due to abort + or disconnect edge-cases. + """ + lock = self._locks.get(thread_id) + if lock is not None and lock.locked(): + lock.release() + self.reset(thread_id) # Module-level singleton — process-local but reused across all agent @@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event: def request_cancel(thread_id: str) -> bool: - """Trip the cancel event for ``thread_id``. Returns True if found.""" + """Trip the cancel event for ``thread_id``. Always returns True.""" return manager.request_cancel(thread_id) +def is_cancel_requested(thread_id: str) -> bool: + """Return whether ``thread_id`` currently has a pending cancel signal.""" + return manager.is_cancel_requested(thread_id) + + +def get_cancel_state(thread_id: str) -> tuple[int, int] | None: + """Return ``(attempt_count, requested_at_ms)`` for pending cancel state.""" + return manager.cancel_state(thread_id) + + def reset_cancel(thread_id: str) -> None: """Reset the cancel event for ``thread_id`` (called between turns).""" manager.reset(thread_id) +def end_turn(thread_id: str) -> None: + """Force end-of-turn cleanup for lock + cancel state.""" + manager.end_turn(thread_id) + + class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Block concurrent prompts on the same thread. @@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo __all__ = [ "BusyMutexMiddleware", + "end_turn", "get_cancel_event", + "get_cancel_state", + "is_cancel_requested", "manager", "request_cancel", "reset_cancel", diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index e04cce1b5..28b197ca2 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -15,7 +15,7 @@ import json import logging from datetime import UTC, datetime -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi.responses import StreamingResponse from sqlalchemy import func, or_ from sqlalchemy.exc import IntegrityError, OperationalError @@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import ( FilesystemSelection, LocalFilesystemMount, ) +from app.agents.new_chat.middleware.busy_mutex import ( + get_cancel_state, + is_cancel_requested, + manager, + request_cancel, +) from app.config import config from app.db import ( ChatComment, @@ -44,6 +50,7 @@ from app.db import ( ) from app.schemas.new_chat import ( AgentToolInfo, + CancelActiveTurnResponse, LocalFilesystemMountPayload, NewChatMessageRead, NewChatRequest, @@ -60,6 +67,7 @@ from app.schemas.new_chat import ( ThreadListItem, ThreadListResponse, TokenUsageSummary, + TurnStatusResponse, ) from app.services.token_tracking_service import record_token_usage from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat @@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import ( _logger = logging.getLogger(__name__) _background_tasks: set[asyncio.Task] = set() +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 router = APIRouter() @@ -137,6 +148,72 @@ def _resolve_filesystem_selection( ) +def _compute_turn_cancelling_retry_delay(attempt: int) -> int: + """Bounded exponential delay for TURN_CANCELLING retry hints.""" + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) + + +def _build_turn_status_payload(thread_id: int) -> dict[str, object]: + lock = manager.lock_for(str(thread_id)) + if not lock.locked(): + return {"status": "idle"} + + if is_cancel_requested(str(thread_id)): + cancel_state = get_cancel_state(str(thread_id)) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms + return { + "status": "cancelling", + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + } + + return {"status": "busy"} + + +def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None: + response.headers["retry-after-ms"] = str(retry_after_ms) + response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000)) + + +def _raise_if_thread_busy_for_start(thread_id: int) -> None: + status_payload = _build_turn_status_payload(thread_id) + status = status_payload["status"] + if status == "idle": + return + if status == "cancelling": + retry_after_ms = int(status_payload.get("retry_after_ms") or 0) + detail = { + "errorCode": "TURN_CANCELLING", + "message": "A previous response is still stopping. Please try again in a moment.", + "retry_after_ms": retry_after_ms if retry_after_ms > 0 else None, + "retry_after_at": status_payload.get("retry_after_at"), + } + headers = ( + { + "retry-after-ms": str(retry_after_ms), + "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)), + } + if retry_after_ms > 0 + else None + ) + raise HTTPException(status_code=409, detail=detail, headers=headers) + + raise HTTPException( + status_code=409, + detail={ + "errorCode": "THREAD_BUSY", + "message": "Another response is still finishing for this thread. Please try again in a moment.", + }, + ) + + def _find_pre_turn_checkpoint_id( checkpoint_tuples: list, *, @@ -1476,6 +1553,7 @@ async def handle_new_chat( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(request.chat_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, @@ -1550,6 +1628,93 @@ async def handle_new_chat( ) from None +@router.post( + "/threads/{thread_id}/cancel-active-turn", + response_model=CancelActiveTurnResponse, +) +async def cancel_active_turn( + thread_id: int, + response: Response, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Signal cancellation for the currently running turn on ``thread_id``.""" + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_UPDATE.value, + "You don't have permission to update chats in this search space", + ) + await check_thread_access(session, thread, user) + + status_payload = _build_turn_status_payload(thread_id) + if status_payload["status"] == "idle": + return CancelActiveTurnResponse( + status="idle", + error_code="NO_ACTIVE_TURN", + ) + + request_cancel(str(thread_id)) + response.status_code = 202 + updated_payload = _build_turn_status_payload(thread_id) + retry_after_ms = int(updated_payload.get("retry_after_ms") or 0) + retry_after_at = ( + int(updated_payload["retry_after_at"]) + if "retry_after_at" in updated_payload + else None + ) + if retry_after_ms > 0: + _set_retry_after_headers(response, retry_after_ms) + return CancelActiveTurnResponse( + status="cancelling", + error_code="TURN_CANCELLING", + retry_after_ms=retry_after_ms if retry_after_ms > 0 else None, + retry_after_at=retry_after_at, + ) + + +@router.get( + "/threads/{thread_id}/turn-status", + response_model=TurnStatusResponse, +) +async def get_turn_status( + thread_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to view chats in this search space", + ) + await check_thread_access(session, thread, user) + + status_payload = _build_turn_status_payload(thread_id) + return TurnStatusResponse( + status=status_payload["status"], # type: ignore[arg-type] + active_turn_id=None, + retry_after_ms=status_payload.get("retry_after_ms"), # type: ignore[arg-type] + retry_after_at=status_payload.get("retry_after_at"), # type: ignore[arg-type] + ) + + # ============================================================================= # Chat Regeneration Endpoint (Edit/Reload) # ============================================================================= @@ -1605,6 +1770,7 @@ async def regenerate_response( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(thread_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, @@ -2012,6 +2178,7 @@ async def resume_chat( ) await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(thread_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index c7284e901..ec5eefc07 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -335,6 +335,24 @@ class ResumeRequest(BaseModel): local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None +class CancelActiveTurnResponse(BaseModel): + """Response for canceling an active turn on a chat thread.""" + + status: Literal["cancelling", "idle"] + error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"] + retry_after_ms: int | None = None + retry_after_at: int | None = None + + +class TurnStatusResponse(BaseModel): + """Current turn execution status for a thread.""" + + status: Literal["idle", "busy", "cancelling"] + active_turn_id: str | None = None + retry_after_ms: int | None = None + retry_after_at: int | None = None + + # ============================================================================= # Public Chat Snapshot Schemas # ============================================================================= diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 842481f1c..55129668c 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -565,7 +565,12 @@ class VercelStreamingService: # Error Part # ========================================================================= - def format_error(self, error_text: str, error_code: str | None = None) -> str: + def format_error( + self, + error_text: str, + error_code: str | None = None, + extra: dict[str, object] | None = None, + ) -> str: """ Format an error message. @@ -579,9 +584,11 @@ class VercelStreamingService: Example output: data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"} """ - payload: dict[str, str] = {"type": "error", "errorText": error_text} + payload: dict[str, object] = {"type": "error", "errorText": error_text} if error_code: payload["errorCode"] = error_code + if extra: + payload.update(extra) return self._format_sse(payload) # ========================================================================= diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 2afa851b5..63c149771 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import ( extract_and_save_memory, extract_and_save_team_memory, ) +from app.agents.new_chat.middleware.busy_mutex import ( + end_turn, + get_cancel_state, + is_cancel_requested, +) from app.agents.new_chat.middleware.kb_persistence import ( commit_staged_filesystem_state, ) @@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content _background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 + + +def _compute_turn_cancelling_retry_delay(attempt: int) -> int: + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: @@ -401,15 +418,35 @@ def _classify_stream_exception( exc: Exception, *, flow_label: str, -) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]: +) -> tuple[ + str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None +]: raw = str(exc) if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None + if busy_thread_id and is_cancel_requested(busy_thread_id): + cancel_state = get_cancel_state(busy_thread_id) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(time.time() * 1000) + retry_after_ms + return ( + "thread_busy", + "TURN_CANCELLING", + "info", + True, + "A previous response is still stopping. Please try again in a moment.", + { + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + }, + ) return ( "thread_busy", "THREAD_BUSY", "warn", True, "Another response is still finishing for this thread. Please try again in a moment.", + None, ) parsed = _parse_error_payload(raw) @@ -431,6 +468,7 @@ def _classify_stream_exception( "warn", True, "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + None, ) return ( @@ -439,6 +477,7 @@ def _classify_stream_exception( "error", False, f"Error during {flow_label}: {raw}", + None, ) @@ -470,7 +509,7 @@ def _emit_stream_terminal_error( message=message, extra=extra, ) - return streaming_service.format_error(message, error_code=error_code) + return streaming_service.format_error(message, error_code=error_code, extra=extra) def _legacy_match_lc_id( @@ -2497,6 +2536,7 @@ async def stream_new_chat( "turn-info", {"chat_turn_id": stream_result.turn_id}, ) + yield streaming_service.format_data("turn-status", {"status": "busy"}) # Initial thinking step - analyzing the request if mentioned_surfsense_docs: @@ -2805,6 +2845,7 @@ async def stream_new_chat( task.add_done_callback(_background_tasks.discard) # Finish the step and message + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2819,11 +2860,19 @@ async def stream_new_chat( severity, is_expected, user_message, + error_extra, ) = _classify_stream_exception(e, flow_label="chat") error_message = f"Error during chat: {e!s}" print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) yield _emit_stream_error( message=user_message, @@ -2831,7 +2880,9 @@ async def stream_new_chat( error_code=error_code, severity=severity, is_expected=is_expected, + extra=error_extra, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2847,6 +2898,10 @@ async def stream_new_chat( # (CancelledError is a BaseException), and the rest of the # finally block — including session.close() — would never run. with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. Middleware + # teardown can be skipped on some client-abort paths. + end_turn(str(chat_id)) + # Release premium reservation if not finalized if _premium_request_id and _premium_reserved > 0 and user_id: try: @@ -3206,6 +3261,7 @@ async def stream_resume_chat( "turn-info", {"chat_turn_id": stream_result.turn_id}, ) + yield streaming_service.format_data("turn-status", {"status": "busy"}) _t_stream_start = time.perf_counter() _first_event_logged = False @@ -3305,6 +3361,7 @@ async def stream_resume_chat( }, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -3318,23 +3375,37 @@ async def stream_resume_chat( severity, is_expected, user_message, + error_extra, ) = _classify_stream_exception(e, flow_label="resume") error_message = f"Error during resume: {e!s}" print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) yield _emit_stream_error( message=user_message, error_kind=error_kind, error_code=error_code, severity=severity, is_expected=is_expected, + extra=error_extra, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() finally: with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. Middleware + # teardown can be skipped on some client-abort paths. + end_turn(str(chat_id)) + # Release premium reservation if not finalized if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id: try: diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py index 0c7bf17f6..c923dc499 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -7,7 +7,9 @@ import pytest from app.agents.new_chat.errors import BusyError from app.agents.new_chat.middleware.busy_mutex import ( BusyMutexMiddleware, + end_turn, get_cancel_event, + is_cancel_requested, manager, request_cancel, reset_cancel, @@ -88,3 +90,31 @@ async def test_no_thread_id_skipped_when_not_required() -> None: def test_reset_cancel_idempotent() -> None: # Should not raise even if event was never created reset_cancel("never-seen") + + +def test_request_cancel_creates_event_for_unseen_thread() -> None: + thread_id = "never-seen-cancel" + reset_cancel(thread_id) + + assert request_cancel(thread_id) is True + assert get_cancel_event(thread_id).is_set() + assert is_cancel_requested(thread_id) is True + + +@pytest.mark.asyncio +async def test_end_turn_force_clears_lock_and_cancel_state() -> None: + thread_id = "forced-end-turn" + mw = BusyMutexMiddleware() + runtime = _Runtime(thread_id) + + await mw.abefore_agent({}, runtime) + assert manager.lock_for(thread_id).locked() + + request_cancel(thread_id) + assert is_cancel_requested(thread_id) is True + + end_turn(thread_id) + + assert not manager.lock_for(thread_id).locked() + assert not get_cancel_event(thread_id).is_set() + assert is_cancel_requested(thread_id) is False diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 86ea7edd1..a1345c15c 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -8,6 +8,7 @@ import pytest import app.tasks.chat.stream_new_chat as stream_new_chat_module from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel from app.tasks.chat.stream_new_chat import ( StreamResult, _classify_stream_exception, @@ -147,7 +148,7 @@ def test_stream_exception_classifies_rate_limited(): exc = Exception( '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}' ) - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "rate_limited" @@ -155,11 +156,12 @@ def test_stream_exception_classifies_rate_limited(): assert severity == "warn" assert is_expected is True assert "temporarily rate-limited" in user_message + assert extra is None def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "thread_busy" @@ -167,11 +169,12 @@ def test_stream_exception_classifies_thread_busy(): assert severity == "warn" assert is_expected is True assert "still finishing for this thread" in user_message + assert extra is None def test_stream_exception_classifies_thread_busy_from_message(): exc = Exception("Thread is busy with another request") - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "thread_busy" @@ -179,6 +182,24 @@ def test_stream_exception_classifies_thread_busy_from_message(): assert severity == "warn" assert is_expected is True assert "still finishing for this thread" in user_message + assert extra is None + + +def test_stream_exception_classifies_turn_cancelling_when_cancel_requested(): + thread_id = "thread-cancelling-1" + reset_cancel(thread_id) + request_cancel(thread_id) + exc = BusyError(request_id=thread_id) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "TURN_CANCELLING" + assert severity == "info" + assert is_expected is True + assert "stopping" in user_message + assert isinstance(extra, dict) + assert "retry_after_ms" in extra def test_premium_classification_is_error_code_driven(): @@ -219,33 +240,33 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): def test_network_send_failures_use_unified_retry_toast_message(): classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" classifier_source = classifier_path.read_text(encoding="utf-8") - page_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + request_errors_path = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts" ) - page_source = page_path.read_text(encoding="utf-8") + request_errors_source = request_errors_path.read_text(encoding="utf-8") assert '"send_failed_pre_accept"' in classifier_source assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source + assert 'errorCode === "TURN_CANCELLING"' in classifier_source assert "if (withCode.code) return withCode.code;" in classifier_source assert 'userMessage: "Message not sent. Please retry."' in classifier_source assert 'userMessage: "Connection issue. Please try again."' in classifier_source - assert "tagPreAcceptSendFailure(error)" in page_source - assert "const passthroughCodes = new Set([" in page_source - assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source - assert '"THREAD_BUSY"' in page_source - assert '"AUTH_EXPIRED"' in page_source - assert '"UNAUTHORIZED"' in page_source - assert '"RATE_LIMITED"' in page_source - assert '"NETWORK_ERROR"' in page_source - assert '"STREAM_PARSE_ERROR"' in page_source - assert '"TOOL_EXECUTION_ERROR"' in page_source - assert '"PERSIST_MESSAGE_FAILED"' in page_source - assert '"SERVER_ERROR"' in page_source - assert "passthroughCodes.has(existingCode)" in page_source - assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source - assert 'errorCode: "NETWORK_ERROR"' not in page_source - assert "Failed to start chat. Please try again." not in page_source + assert "const passthroughCodes = new Set([" in request_errors_source + assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source + assert '"THREAD_BUSY"' in request_errors_source + assert '"TURN_CANCELLING"' in request_errors_source + assert '"AUTH_EXPIRED"' in request_errors_source + assert '"UNAUTHORIZED"' in request_errors_source + assert '"RATE_LIMITED"' in request_errors_source + assert '"NETWORK_ERROR"' in request_errors_source + assert '"STREAM_PARSE_ERROR"' in request_errors_source + assert '"TOOL_EXECUTION_ERROR"' in request_errors_source + assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source + assert '"SERVER_ERROR"' in request_errors_source + assert "passthroughCodes.has(existingCode)" in request_errors_source + assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source + assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source + assert "Failed to start chat. Please try again." not in classifier_source def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): @@ -269,3 +290,75 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows() # New flow persists only when accepted and not already persisted. assert "if (newAccepted && !userPersisted) {" in source + assert "const fetchWithTurnCancellingRetry = useCallback(" in source + assert "computeFallbackTurnCancellingRetryDelay" in source + assert 'withMeta.errorCode === "TURN_CANCELLING"' in source + assert 'withMeta.errorCode === "THREAD_BUSY"' in source + assert "await fetchWithTurnCancellingRetry(() =>" in source + + +def test_cancel_active_turn_route_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source + assert "response_model=CancelActiveTurnResponse" in source + assert 'status="cancelling",' in source + assert 'error_code="TURN_CANCELLING",' in source + assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source + assert "retry_after_at=" in source + assert 'status="idle",' in source + assert 'error_code="NO_ACTIVE_TURN",' in source + + +def test_turn_status_route_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source + assert "response_model=TurnStatusResponse" in source + assert "_build_turn_status_payload(thread_id)" in source + assert "Permission.CHATS_READ.value" in source + assert "_raise_if_thread_busy_for_start(" in source + + +def test_turn_cancelling_retry_policy_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source + assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source + assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source + assert "def _compute_turn_cancelling_retry_delay(" in source + assert "retry-after-ms" in source + assert '"Retry-After"' in source + assert '"errorCode": "TURN_CANCELLING"' in source + + +def test_turn_status_sse_contract_exists(): + stream_source = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/tasks/chat/stream_new_chat.py" + ).read_text(encoding="utf-8") + state_source = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts" + ).read_text(encoding="utf-8") + pipeline_source = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts" + ).read_text(encoding="utf-8") + + assert '"turn-status"' in stream_source + assert '"status": "busy"' in stream_source + assert '"status": "idle"' in stream_source + assert "type: \"data-turn-status\"" in state_source + assert "case \"data-turn-status\":" in pipeline_source + assert "end_turn(str(chat_id))" in stream_source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 02c2914be..1b25ca431 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -182,6 +182,20 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { * ``stream_new_chat.py``) keep the JSON from ballooning. */ const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; +const TURN_CANCELLING_INITIAL_DELAY_MS = 200; +const TURN_CANCELLING_BACKOFF_FACTOR = 2; +const TURN_CANCELLING_MAX_DELAY_MS = 1500; +const RECENT_CANCEL_WINDOW_MS = 5_000; + +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +function computeFallbackTurnCancellingRetryDelay(attempt: number): number { + const safeAttempt = Math.max(1, attempt); + const raw = TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1); + return Math.min(raw, TURN_CANCELLING_MAX_DELAY_MS); +} export default function NewChatPage() { const params = useParams(); @@ -193,6 +207,7 @@ export default function NewChatPage() { const [isRunning, setIsRunning] = useState(false); const [tokenUsageStore] = useState(() => createTokenUsageStore()); const abortControllerRef = useRef(null); + const recentCancelRequestedAtRef = useRef(0); const [pendingInterrupt, setPendingInterrupt] = useState<{ threadId: number; assistantMsgId: string; @@ -598,6 +613,36 @@ export default function NewChatPage() { [handleChatFailure] ); + const fetchWithTurnCancellingRetry = useCallback( + async (runFetch: () => Promise) => { + const maxAttempts = 4; + for (let attempt = 1; attempt <= maxAttempts; attempt += 1) { + const response = await runFetch(); + if (response.ok) { + return response; + } + const error = await toHttpResponseError(response); + const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number }; + const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING"; + const isRecentThreadBusyAfterCancel = + withMeta.errorCode === "THREAD_BUSY" && + Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS; + if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) { + const waitMs = + withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt); + await sleep(waitMs); + continue; + } + throw error; + } + + throw Object.assign(new Error("Turn cancellation retry limit exceeded"), { + errorCode: "TURN_CANCELLING", + }); + }, + [] + ); + // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message const initializeThread = useCallback(async () => { @@ -767,12 +812,39 @@ export default function NewChatPage() { // Cancel ongoing request const cancelRun = useCallback(async () => { + if (threadId) { + const token = getBearerToken(); + if (token) { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + try { + const response = await fetch( + `${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`, + { + method: "POST", + headers: { + Authorization: `Bearer ${token}`, + }, + } + ); + if (response.ok) { + const payload = (await response.json()) as { + error_code?: string; + }; + if (payload.error_code === "TURN_CANCELLING") { + recentCancelRequestedAtRef.current = Date.now(); + } + } + } catch (error) { + console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error); + } + } + } if (abortControllerRef.current) { abortControllerRef.current.abort(); abortControllerRef.current = null; } setIsRunning(false); - }, []); + }, [threadId]); // Handle new message from user const onNew = useCallback( @@ -971,29 +1043,33 @@ export default function NewChatPage() { setMentionedDocuments([]); } - const response = await fetch(`${backendUrl}/api/v1/new_chat`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - chat_id: currentThreadId, - user_query: userQuery.trim(), - search_space_id: searchSpaceId, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - messages: messageHistory, - mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined, - mentioned_surfsense_doc_ids: hasSurfsenseDocIds - ? mentionedDocumentIds.surfsense_doc_ids - : undefined, - disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, - ...(userImages.length > 0 ? { user_images: userImages } : {}), - }), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(`${backendUrl}/api/v1/new_chat`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + chat_id: currentThreadId, + user_query: userQuery.trim(), + search_space_id: searchSpaceId, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + messages: messageHistory, + mentioned_document_ids: hasDocumentIds + ? mentionedDocumentIds.document_ids + : undefined, + mentioned_surfsense_doc_ids: hasSurfsenseDocIds + ? mentionedDocumentIds.surfsense_doc_ids + : undefined, + disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, + ...(userImages.length > 0 ? { user_images: userImages } : {}), + }), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1033,6 +1109,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, onToolOutputAvailable: (event, sharedCtx) => { if (event.output?.status === "pending" && event.output?.podcast_id) { const idx = sharedCtx.toolCallIndices.get(event.toolCallId); @@ -1257,6 +1338,7 @@ export default function NewChatPage() { tokenUsageStore, pendingUserImageUrls, setPendingUserImageUrls, + fetchWithTurnCancellingRetry, handleStreamTerminalError, handleChatFailure, persistAssistantTurn, @@ -1354,21 +1436,23 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; const selection = await getAgentFilesystemSelection(searchSpaceId); - const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - search_space_id: searchSpaceId, - decisions, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - }), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + search_space_id: searchSpaceId, + decisions, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + }), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1399,6 +1483,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, }) ) { return; @@ -1496,6 +1585,7 @@ export default function NewChatPage() { searchSpaceId, queryClient, tokenUsageStore, + fetchWithTurnCancellingRetry, handleStreamTerminalError, persistAssistantTurn, ] @@ -1700,15 +1790,17 @@ export default function NewChatPage() { requestBody.revert_actions = true; } } - const response = await fetch(getRegenerateUrl(threadId), { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify(requestBody), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(getRegenerateUrl(threadId), { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify(requestBody), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1774,6 +1866,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, onToolOutputAvailable: (event, sharedCtx) => { if (event.output?.status === "pending" && event.output?.podcast_id) { const idx = sharedCtx.toolCallIndices.get(event.toolCallId); @@ -1945,6 +2042,7 @@ export default function NewChatPage() { setMessageDocumentsMap, queryClient, tokenUsageStore, + fetchWithTurnCancellingRetry, handleStreamTerminalError, persistAssistantTurn, persistUserTurn, diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 57341a4c3..7dfbfc1a1 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -147,6 +147,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if ( + errorCode === "TURN_CANCELLING" + ) { + return { + kind: "thread_busy", + channel: "toast", + severity: "info", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "A previous response is still stopping. Please try again in a moment.", + rawMessage, + errorCode: errorCode ?? "TURN_CANCELLING", + details: { flow: input.flow }, + }; + } + if ( errorCode === "THREAD_BUSY" ) { @@ -156,7 +172,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError severity: "warn", telemetryEvent: "chat_blocked", isExpected: true, - userMessage: "A previous response is still stopping. Please try again in a moment.", + userMessage: "Another response is still finishing for this thread. Please try again in a moment.", rawMessage, errorCode: errorCode ?? "THREAD_BUSY", details: { flow: input.flow }, diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts index 3026e8203..708831354 100644 --- a/surfsense_web/lib/chat/chat-request-errors.ts +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -1,6 +1,6 @@ export async function toHttpResponseError( response: Response -): Promise { +): Promise { const statusDefaultCode = response.status === 409 ? "THREAD_BUSY" @@ -52,13 +52,37 @@ export async function toHttpResponseError( : undefined; const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; + + const detailRetryAfterMs = + typeof detailObject?.retry_after_ms === "number" + ? detailObject.retry_after_ms + : typeof detailObject?.retryAfterMs === "number" + ? detailObject.retryAfterMs + : undefined; + const topRetryAfterMs = + typeof parsedBody?.retry_after_ms === "number" + ? parsedBody.retry_after_ms + : typeof parsedBody?.retryAfterMs === "number" + ? parsedBody.retryAfterMs + : undefined; + const headerRetryAfterMsRaw = response.headers.get("retry-after-ms"); + const headerRetryAfterMs = headerRetryAfterMsRaw ? Number.parseFloat(headerRetryAfterMsRaw) : NaN; + const retryAfterHeader = response.headers.get("retry-after"); + const retryAfterSeconds = retryAfterHeader ? Number.parseFloat(retryAfterHeader) : NaN; + const retryAfterMsFromHeader = Number.isFinite(headerRetryAfterMs) + ? Math.max(0, Math.round(headerRetryAfterMs)) + : Number.isFinite(retryAfterSeconds) + ? Math.max(0, Math.round(retryAfterSeconds * 1000)) + : undefined; + const retryAfterMs = + detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined; const message = detailNestedMessage ?? detailMessage ?? topLevelMessage ?? `Backend error: ${response.status}`; - return Object.assign(new Error(message), { errorCode }); + return Object.assign(new Error(message), { errorCode, retryAfterMs }); } export function tagPreAcceptSendFailure(error: unknown): unknown { @@ -68,6 +92,7 @@ export function tagPreAcceptSendFailure(error: unknown): unknown { const passthroughCodes = new Set([ "PREMIUM_QUOTA_EXHAUSTED", "THREAD_BUSY", + "TURN_CANCELLING", "AUTH_EXPIRED", "UNAUTHORIZED", "RATE_LIMITED", diff --git a/surfsense_web/lib/chat/stream-pipeline.ts b/surfsense_web/lib/chat/stream-pipeline.ts index 8957bdea3..c9118f949 100644 --- a/surfsense_web/lib/chat/stream-pipeline.ts +++ b/surfsense_web/lib/chat/stream-pipeline.ts @@ -21,6 +21,7 @@ export type SharedStreamEventContext = { scheduleFlush: () => void; forceFlush: () => void; onTokenUsage?: (data: Extract["data"]) => void; + onTurnStatus?: (data: Extract["data"]) => void; onToolOutputAvailable?: ( event: Extract, context: { @@ -173,6 +174,10 @@ export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStream context.onTokenUsage?.(parsed.data); return true; + case "data-turn-status": + context.onTurnStatus?.(parsed.data); + return true; + case "error": throw toStreamTerminalError(parsed); diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 445bbe83d..80e7bffbe 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -528,6 +528,14 @@ export type SSEEvent = }>; }; } + | { + type: "data-turn-status"; + data: { + status: "idle" | "busy" | "cancelling"; + retry_after_ms?: number; + retry_after_at?: number; + }; + } | { type: "data-token-usage"; data: { From a66c1576b965acc50ae89d8a0f71ed3db1b64077 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:09:53 +0530 Subject: [PATCH 06/13] refactor(chat): introduce ChatViewport and NestedScroll components for improved chat UI structure and functionality --- .../components/assistant-ui/chat-viewport.tsx | 44 +++++++ .../components/assistant-ui/nested-scroll.tsx | 24 ++++ .../assistant-ui/thread-scroll-to-bottom.tsx | 18 --- .../components/assistant-ui/thread.tsx | 108 +++--------------- .../components/assistant-ui/tool-fallback.tsx | 9 +- .../components/free-chat/free-thread.tsx | 43 ++----- .../components/public-chat/public-thread.tsx | 9 +- 7 files changed, 99 insertions(+), 156 deletions(-) create mode 100644 surfsense_web/components/assistant-ui/chat-viewport.tsx create mode 100644 surfsense_web/components/assistant-ui/nested-scroll.tsx delete mode 100644 surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx new file mode 100644 index 000000000..f91a8916a --- /dev/null +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -0,0 +1,44 @@ +"use client"; + +import { ThreadPrimitive } from "@assistant-ui/react"; +import { ArrowDownIcon } from "lucide-react"; +import type { FC, ReactNode } from "react"; +import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; + +const ChatScrollToBottom: FC = () => ( + + + + + +); + +export interface ChatViewportProps { + children: ReactNode; + footer?: ReactNode; +} + +export const ChatViewport: FC = ({ children, footer }) => ( + + {children} + {footer ? ( + + + {footer} + + ) : null} + +); diff --git a/surfsense_web/components/assistant-ui/nested-scroll.tsx b/surfsense_web/components/assistant-ui/nested-scroll.tsx new file mode 100644 index 000000000..5a4f8d36e --- /dev/null +++ b/surfsense_web/components/assistant-ui/nested-scroll.tsx @@ -0,0 +1,24 @@ +"use client"; + +import { forwardRef, type ComponentPropsWithoutRef, type WheelEvent } from "react"; + +export type NestedScrollProps = ComponentPropsWithoutRef<"div">; + +export const NestedScroll = forwardRef( + ({ onWheel, ...props }, ref) => { + const handleWheel = (event: WheelEvent) => { + const el = event.currentTarget; + const canScrollUp = el.scrollTop > 0; + const canScrollDown = el.scrollTop < el.scrollHeight - el.clientHeight - 1; + const goingUp = event.deltaY < 0; + const goingDown = event.deltaY > 0; + if ((goingUp && canScrollUp) || (goingDown && canScrollDown)) { + event.stopPropagation(); + } + onWheel?.(event); + }; + return
; + } +); + +NestedScroll.displayName = "NestedScroll"; diff --git a/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx b/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx deleted file mode 100644 index 394ba5d79..000000000 --- a/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx +++ /dev/null @@ -1,18 +0,0 @@ -import { ThreadPrimitive } from "@assistant-ui/react"; -import { ArrowDownIcon } from "lucide-react"; -import type { FC } from "react"; -import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; - -export const ThreadScrollToBottom: FC = () => { - return ( - - - - - - ); -}; diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 3e27e7adb..1d24a2a39 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -5,12 +5,10 @@ import { ThreadPrimitive, useAui, useAuiState, - useThreadViewportStore, } from "@assistant-ui/react"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; import { AlertCircle, - ArrowDownIcon, ArrowUpIcon, Camera, ChevronDown, @@ -55,6 +53,7 @@ import { import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status"; +import { ChatViewport } from "@/components/assistant-ui/chat-viewport"; import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup"; import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup"; import { @@ -112,10 +111,17 @@ const ThreadContent: FC = () => { ["--thread-max-width" as string]: "44rem", }} > - + !thread.isEmpty}> + + + !thread.isEmpty}> + + + + } > thread.isEmpty}> @@ -128,24 +134,7 @@ const ThreadContent: FC = () => { AssistantMessage, }} /> - - !thread.isEmpty}> -
- - - - - !thread.isEmpty}> - - - !thread.isEmpty}> - - - - + ); }; @@ -181,20 +170,6 @@ const PremiumQuotaPinnedAlert: FC = () => { ); }; -const ThreadScrollToBottom: FC = () => { - return ( - - - - - - ); -}; - const getTimeBasedGreeting = (user?: { display_name?: string | null; email?: string }): string => { const hour = new Date().getHours(); @@ -411,23 +386,9 @@ const Composer: FC = () => { >(new Map()); const documentPickerRef = useRef(null); const promptPickerRef = useRef(null); - const viewportRef = useRef(null); const { search_space_id, chat_id } = useParams(); const aui = useAui(); - const threadViewportStore = useThreadViewportStore(); const hasAutoFocusedRef = useRef(false); - const submitCleanupRef = useRef<(() => void) | null>(null); - - useEffect(() => { - return () => { - submitCleanupRef.current?.(); - }; - }, []); - - // Store viewport element reference on mount - useEffect(() => { - viewportRef.current = document.querySelector(".aui-thread-viewport"); - }, []); const electronAPI = useElectronAPI(); const [clipboardInitialText, setClipboardInitialText] = useState(); @@ -626,7 +587,6 @@ const Composer: FC = () => { [showDocumentPopover, showPromptPicker] ); - // Submit message (blocked during streaming, document picker open, or AI responding to another user) const handleSubmit = useCallback(() => { if (isThreadRunning || isBlockedByOtherUser) return; if (showDocumentPopover || showPromptPicker) return; @@ -638,50 +598,9 @@ const Composer: FC = () => { setClipboardInitialText(undefined); } - const viewportEl = viewportRef.current; - const heightBefore = viewportEl?.scrollHeight ?? 0; - aui.composer().send(); editorRef.current?.clear(); setMentionedDocuments([]); - - // With turnAnchor="top", ViewportSlack adds min-height to the last - // assistant message so that scrolling-to-bottom actually positions the - // user message at the TOP of the viewport. That slack height is - // calculated asynchronously (ResizeObserver → style → layout). - // Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes. - const scrollToBottom = () => - threadViewportStore.getState().scrollToBottom({ behavior: "instant" }); - - let lastHeight = heightBefore; - let frames = 0; - let cancelled = false; - const POLL_FRAMES = 30; - - const pollAndScroll = () => { - if (cancelled) return; - const el = viewportRef.current; - if (el) { - const h = el.scrollHeight; - if (h !== lastHeight) { - lastHeight = h; - scrollToBottom(); - } - } - if (++frames < POLL_FRAMES) { - requestAnimationFrame(pollAndScroll); - } - }; - requestAnimationFrame(pollAndScroll); - - const t1 = setTimeout(scrollToBottom, 100); - const t2 = setTimeout(scrollToBottom, 300); - - submitCleanupRef.current = () => { - cancelled = true; - clearTimeout(t1); - clearTimeout(t2); - }; }, [ showDocumentPopover, showPromptPicker, @@ -690,7 +609,6 @@ const Composer: FC = () => { clipboardInitialText, aui, setMentionedDocuments, - threadViewportStore, ]); const handleDocumentRemove = useCallback( diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index 66e2ebd4a..cf42cf398 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -13,6 +13,7 @@ import { isDoomLoopInterrupt, } from "@/components/tool-ui/doom-loop-approval"; import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; +import { NestedScroll } from "@/components/assistant-ui/nested-scroll"; import { AlertDialog, AlertDialogAction, @@ -475,7 +476,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { {(argsText || isRunning) && (

Inputs

-
+ {argsText ? (
 											{argsText}
@@ -489,7 +490,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
 											Waiting for input…
 										

)} -
+
)} {!isCancelled && result !== undefined && ( @@ -497,11 +498,11 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {

Result

-
+
 											{typeof result === "string" ? result : serializedResult}
 										
-
+
)} diff --git a/surfsense_web/components/free-chat/free-thread.tsx b/surfsense_web/components/free-chat/free-thread.tsx index bd237004a..933847b2b 100644 --- a/surfsense_web/components/free-chat/free-thread.tsx +++ b/surfsense_web/components/free-chat/free-thread.tsx @@ -1,11 +1,10 @@ "use client"; import { AuiIf, ThreadPrimitive } from "@assistant-ui/react"; -import { ArrowDownIcon } from "lucide-react"; import type { FC } from "react"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; +import { ChatViewport } from "@/components/assistant-ui/chat-viewport"; import { EditComposer } from "@/components/assistant-ui/edit-composer"; -import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { UserMessage } from "@/components/assistant-ui/user-message"; import { FreeComposer } from "./free-composer"; @@ -24,20 +23,6 @@ const FreeThreadWelcome: FC = () => { ); }; -const ThreadScrollToBottom: FC = () => { - return ( - - - - - - ); -}; - export const FreeThread: FC = () => { return ( { ["--thread-max-width" as string]: "44rem", }} > - !thread.isEmpty}> + + + } > thread.isEmpty}> @@ -62,21 +49,7 @@ export const FreeThread: FC = () => { AssistantMessage, }} /> - - !thread.isEmpty}> -
- - - - - !thread.isEmpty}> - - - - + ); }; diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index 22e914988..de91b4451 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -45,16 +45,17 @@ export const PublicThread: FC = ({ footer }) => { ["--thread-max-width" as string]: "44rem", }} > - + - - {/* Spacer to ensure footer doesn't overlap last message */} -
{footer && ( From 833b4dd441d0e8053bd2399076fedcf067917617 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:10:21 +0530 Subject: [PATCH 07/13] refactor(chat): simplify ChatViewport and footer structure for improved readability and maintainability --- .../components/assistant-ui/chat-viewport.tsx | 26 ++++++++++--------- .../components/assistant-ui/thread.tsx | 12 +++------ .../components/public-chat/public-thread.tsx | 2 +- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index f91a8916a..a1534df01 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -23,22 +23,24 @@ export interface ChatViewportProps { } export const ChatViewport: FC = ({ children, footer }) => ( - - {children} + <> + + {children} + {footer ? ( - {footer} - +
) : null} -
+ ); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 1d24a2a39..6c02a1efa 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -113,14 +113,10 @@ const ThreadContent: FC = () => { > - !thread.isEmpty}> - - - !thread.isEmpty}> - - - + !thread.isEmpty}> + + + } > thread.isEmpty}> diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index de91b4451..750b7410e 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -59,7 +59,7 @@ export const PublicThread: FC = ({ footer }) => {
{footer && ( -
+
{footer}
)} From 7b549f84445ef158b97d0270143a88c623d89ab7 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:38:21 +0530 Subject: [PATCH 08/13] refactor(chat): enhance ChatViewport with auto-scroll and top fade effect for improved user experience --- .../components/assistant-ui/chat-viewport.tsx | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index a1534df01..d3d664ace 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -23,24 +23,30 @@ export interface ChatViewportProps { } export const ChatViewport: FC = ({ children, footer }) => ( - <> - - {children} - + +
+ {children} {footer ? ( -
- - {footer} -
+
+ + {footer} +
+ ) : null} - + ); From 511f4fde6440378a111fb7bdc3f84cbf4b9c85c1 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:40:14 +0530 Subject: [PATCH 09/13] refactor(chat): update ChatViewport className for improved scroll behavior consistency --- surfsense_web/components/assistant-ui/chat-viewport.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index d3d664ace..f7f1ac188 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -29,7 +29,7 @@ export const ChatViewport: FC = ({ children, footer }) => ( scrollToBottomOnRunStart scrollToBottomOnInitialize scrollToBottomOnThreadSwitch - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 [scroll-behavior:smooth]" + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth" style={{ scrollbarGutter: "stable" }} >
Date: Fri, 1 May 2026 04:02:24 +0530 Subject: [PATCH 10/13] refactor(chat): enhance UserMessage component with mention parsing and segment rendering for improved message display --- .../components/assistant-ui/chat-viewport.tsx | 2 +- .../components/assistant-ui/user-message.tsx | 121 ++++++------------ .../lib/chat/parse-mention-segments.ts | 54 ++++++++ 3 files changed, 97 insertions(+), 80 deletions(-) create mode 100644 surfsense_web/lib/chat/parse-mention-segments.ts diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index f7f1ac188..c0684407e 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -39,7 +39,7 @@ export const ChatViewport: FC = ({ children, footer }) => ( {children} {footer ? (
diff --git a/surfsense_web/components/assistant-ui/user-message.tsx b/surfsense_web/components/assistant-ui/user-message.tsx index fb7212119..145ac2d7e 100644 --- a/surfsense_web/components/assistant-ui/user-message.tsx +++ b/surfsense_web/components/assistant-ui/user-message.tsx @@ -1,4 +1,10 @@ -import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react"; +import { + ActionBarPrimitive, + AuiIf, + MessagePrimitive, + useAuiState, + useMessagePartText, +} from "@assistant-ui/react"; import { useAtomValue } from "jotai"; import { CheckIcon, CopyIcon, Pencil } from "lucide-react"; import Image from "next/image"; @@ -7,6 +13,8 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; +import { parseMentionSegments } from "@/lib/chat/parse-mention-segments"; interface AuthorMetadata { displayName: string | null; @@ -47,23 +55,40 @@ const UserAvatar: FC = ({ displayName, avatarUrl }) => { ); }; -export const UserMessage: FC = () => { +const UserTextPart: FC = () => { const messageId = useAuiState(({ message }) => message?.id); - const messageText = useAuiState(({ message }) => - (message?.content ?? []) - .map((part) => - typeof part === "object" && - part !== null && - "type" in part && - (part as { type?: string }).type === "text" && - "text" in part - ? String((part as { text?: string }).text ?? "") - : "" - ) - .join("") - ); + const part = useMessagePartText(); + const text = (part as { text?: string }).text ?? ""; const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); - const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined; + const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? []; + + const segments = parseMentionSegments(text, mentionedDocs); + + return ( +

+ {segments.map((segment) => + segment.type === "text" ? ( + {segment.value} + ) : ( + + + {getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")} + + {segment.doc.title} + + ) + )} +

+ ); +}; + +const userMessageParts = { Text: UserTextPart }; + +export const UserMessage: FC = () => { const metadata = useAuiState(({ message }) => message?.metadata); const author = metadata?.custom?.author as AuthorMetadata | undefined; const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE"; @@ -78,11 +103,7 @@ export const UserMessage: FC = () => {
- {mentionedDocs && mentionedDocs.length > 0 ? ( - - ) : ( - - )} +
@@ -99,64 +120,6 @@ export const UserMessage: FC = () => { ); }; -const UserMessageWithMentionChips: FC<{ - text: string; - mentionedDocs: { id: number; title: string; document_type: string }[]; -}> = ({ text, mentionedDocs }) => { - type Segment = - | { type: "text"; value: string; start: number } - | { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number }; - - const tokens = mentionedDocs - .map((doc) => ({ doc, token: `@${doc.title}` })) - .sort((a, b) => b.token.length - a.token.length); - - const segments: Segment[] = []; - let i = 0; - let buffer = ""; - let bufferStart = 0; - while (i < text.length) { - const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i)); - if (tokenMatch) { - if (buffer) { - segments.push({ type: "text", value: buffer, start: bufferStart }); - buffer = ""; - } - segments.push({ type: "mention", doc: tokenMatch.doc, start: i }); - i += tokenMatch.token.length; - bufferStart = i; - continue; - } - if (!buffer) bufferStart = i; - buffer += text[i]; - i += 1; - } - if (buffer) { - segments.push({ type: "text", value: buffer, start: bufferStart }); - } - - return ( - - {segments.map((segment) => - segment.type === "text" ? ( - {segment.value} - ) : ( - - - {getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")} - - {segment.doc.title} - - ) - )} - - ); -}; - const UserActionBar: FC = () => { const isThreadRunning = useAuiState(({ thread }) => thread.isRunning); diff --git a/surfsense_web/lib/chat/parse-mention-segments.ts b/surfsense_web/lib/chat/parse-mention-segments.ts new file mode 100644 index 000000000..b9cf59792 --- /dev/null +++ b/surfsense_web/lib/chat/parse-mention-segments.ts @@ -0,0 +1,54 @@ +import type { MentionedDocumentInfo } from "@/atoms/chat/mentioned-documents.atom"; + +export type MentionSegment = + | { type: "text"; value: string; start: number } + | { type: "mention"; doc: MentionedDocumentInfo; start: number }; + +/** + * Tokenizes a user message into text and `@mention` segments. + * + * Pure: no React, no DOM, no side effects. Safe to unit-test and reuse. + * + * Mentions are matched greedily by longest title first so that a longer title + * (e.g. `@Project Roadmap`) is never shadowed by a shorter prefix + * (e.g. `@Project`). + */ +export function parseMentionSegments( + text: string, + docs: ReadonlyArray +): MentionSegment[] { + if (text.length === 0) return []; + if (docs.length === 0) return [{ type: "text", value: text, start: 0 }]; + + const tokens = docs + .map((doc) => ({ doc, token: `@${doc.title}` })) + .sort((a, b) => b.token.length - a.token.length); + + const segments: MentionSegment[] = []; + let i = 0; + let buffer = ""; + let bufferStart = 0; + + while (i < text.length) { + const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i)); + if (tokenMatch) { + if (buffer) { + segments.push({ type: "text", value: buffer, start: bufferStart }); + buffer = ""; + } + segments.push({ type: "mention", doc: tokenMatch.doc, start: i }); + i += tokenMatch.token.length; + bufferStart = i; + continue; + } + if (!buffer) bufferStart = i; + buffer += text[i]; + i += 1; + } + + if (buffer) { + segments.push({ type: "text", value: buffer, start: bufferStart }); + } + + return segments; +} From 0883ac88fb54653223ff2477724d372531fa1301 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:23:59 +0530 Subject: [PATCH 11/13] refactor(chat): enhance InlineMentionEditor with improved mention handling and text processing for better user interaction --- .../assistant-ui/inline-mention-editor.tsx | 1078 ++++++----------- 1 file changed, 391 insertions(+), 687 deletions(-) diff --git a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx index 05277f508..d92348080 100644 --- a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx +++ b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx @@ -1,26 +1,13 @@ "use client"; -import { X } from "lucide-react"; -import type { ReactElement } from "react"; -import { - createElement, - forwardRef, - useCallback, - useEffect, - useImperativeHandle, - useRef, - useState, -} from "react"; -import { renderToStaticMarkup } from "react-dom/server"; +import { type FC, forwardRef, useCallback, useImperativeHandle, useMemo, useRef } from "react"; +import { Plate, PlateContent, ParagraphPlugin, createPlatePlugin, usePlateEditor } from "platejs/react"; +import type { PlateElementProps } from "platejs/react"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { Document } from "@/contracts/types/document.types"; import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { cn } from "@/lib/utils"; -function renderElementToHTML(element: ReactElement): string { - return renderToStaticMarkup(element); -} - export interface MentionedDocument { id: number; title: string; @@ -61,38 +48,174 @@ interface InlineMentionEditorProps { initialText?: string; } -// Unique data attribute to identify chip elements -const CHIP_DATA_ATTR = "data-mention-chip"; -const CHIP_ID_ATTR = "data-mention-id"; -const CHIP_DOCTYPE_ATTR = "data-mention-doctype"; -const CHIP_STATUS_ATTR = "data-mention-status"; +type MentionStatusKind = "pending" | "processing" | "ready" | "failed"; +type ComposerTextNode = { text: string }; +type MentionElementNode = { + type: "mention"; + id: number; + title: string; + document_type?: string; + statusLabel?: string | null; + statusKind?: MentionStatusKind; + children: [{ text: "" }]; +}; +type ComposerNode = ComposerTextNode | MentionElementNode; +type ComposerParagraph = { type: "p"; children: ComposerNode[] }; +type ComposerValue = ComposerParagraph[]; + +const MENTION_TYPE = "mention"; +const MENTION_CHIP_CLASSNAME = + "inline-flex h-5 items-center gap-1 mx-0.5 rounded bg-primary/10 px-1 text-xs font-bold text-primary/60 select-none align-middle leading-none"; +const MENTION_CHIP_ICON_CLASSNAME = "flex items-center text-muted-foreground leading-none"; +const MENTION_CHIP_TITLE_CLASSNAME = "max-w-[120px] truncate leading-none"; +const COMPOSER_TEXT_METRICS_CLASSNAME = "text-sm leading-6"; + +const EMPTY_VALUE: ComposerValue = [{ type: "p", children: [{ text: "" }] }]; + +const MentionElement: FC> = ({ attributes, children, element }) => { + const statusClass = + element.statusKind === "failed" + ? "text-destructive" + : element.statusKind === "ready" + ? "text-emerald-700" + : "text-amber-700"; -/** - * Type guard to check if a node is a chip element - */ -function isChipElement(node: Node | null): node is HTMLSpanElement { return ( - node !== null && - node.nodeType === Node.ELEMENT_NODE && - (node as Element).hasAttribute(CHIP_DATA_ATTR) + + + + {getConnectorIcon(element.document_type ?? "UNKNOWN", "h-3 w-3")} + + + {element.title} + + {element.statusLabel ? ( + + {element.statusLabel} + + ) : null} + + {children} + ); +}; + +const MentionPlugin = createPlatePlugin({ + key: MENTION_TYPE, + node: { + isElement: true, + isInline: true, + isVoid: true, + type: MENTION_TYPE, + component: MentionElement, + }, +}); + +function isMentionNode(node: ComposerNode): node is MentionElementNode { + return typeof node === "object" && "type" in node && node.type === MENTION_TYPE; } -/** - * Safely parse chip ID from element attribute - */ -function getChipId(element: Element): number | null { - const idStr = element.getAttribute(CHIP_ID_ATTR); - if (!idStr) return null; - const id = parseInt(idStr, 10); - return Number.isNaN(id) ? null : id; +function getTextNode(node: ComposerNode): ComposerTextNode | null { + if (typeof node === "object" && "text" in node && typeof node.text === "string") return node; + return null; } -/** - * Get chip document type from element attribute - */ -function getChipDocType(element: Element): string { - return element.getAttribute(CHIP_DOCTYPE_ATTR) ?? "UNKNOWN"; +function toValueFromText(text: string): ComposerValue { + const lines = text.split("\n"); + if (lines.length === 0) return EMPTY_VALUE; + return lines.map((line) => ({ type: "p", children: [{ text: line }] })) as ComposerValue; +} + +function getPlainText(value: ComposerValue): string { + const lines = value.map((block) => + block.children + .map((node) => { + if (isMentionNode(node)) return `@${node.title}`; + return getTextNode(node)?.text ?? ""; + }) + .join("") + ); + return lines.join("\n").trim(); +} + +function getMentionedDocuments(value: ComposerValue): MentionedDocument[] { + const map = new Map(); + for (const block of value) { + for (const node of block.children) { + if (!isMentionNode(node)) continue; + const doc: MentionedDocument = { + id: node.id, + title: node.title, + document_type: node.document_type, + }; + map.set(getMentionDocKey(doc), doc); + } + } + return Array.from(map.values()); +} + +type EditorSelection = { + anchor: { path: number[]; offset: number }; + focus: { path: number[]; offset: number }; +} | null; + +function getCursorTextContext(value: ComposerValue, selection: EditorSelection) { + if (!selection || !selection.anchor || !selection.focus) return null; + if ( + selection.anchor.path.length < 2 || + selection.focus.path.length < 2 || + selection.anchor.path[0] !== selection.focus.path[0] || + selection.anchor.path[1] !== selection.focus.path[1] + ) { + return null; + } + + const block = value[selection.anchor.path[0]]; + if (!block) return null; + const child = block.children[selection.anchor.path[1]]; + const textNode = getTextNode(child); + if (!textNode) return null; + + return { + blockIndex: selection.anchor.path[0], + childIndex: selection.anchor.path[1], + text: textNode.text, + cursor: selection.anchor.offset, + }; +} + +function scanActiveTrigger(text: string, cursor: number) { + let wordStart = 0; + for (let i = cursor - 1; i >= 0; i--) { + if (text[i] === " " || text[i] === "\n") { + wordStart = i + 1; + break; + } + } + + let triggerChar: "@" | "/" | null = null; + let triggerIndex = -1; + for (let i = wordStart; i < cursor; i++) { + if (text[i] === "@" || text[i] === "/") { + triggerChar = text[i] as "@" | "/"; + triggerIndex = i; + break; + } + } + if (!triggerChar || triggerIndex === -1) return null; + + const query = text.slice(triggerIndex + 1, cursor); + if (query.startsWith(" ")) return null; + if ( + triggerChar === "/" && + triggerIndex > 0 && + text[triggerIndex - 1] !== " " && + text[triggerIndex - 1] !== "\n" + ) { + return null; + } + + return { triggerChar, query }; } export const InlineMentionEditor = forwardRef( @@ -113,393 +236,159 @@ export const InlineMentionEditor = forwardRef { - const editorRef = useRef(null); - const [isEmpty, setIsEmpty] = useState(true); - const [mentionedDocs, setMentionedDocs] = useState>( - () => new Map() - ); - const isComposingRef = useRef(false); - const lastSelectionRangeRef = useRef(null); - const isRangeInsideEditor = useCallback((range: Range | null): range is Range => { - if (!range || !editorRef.current) return false; - return ( - editorRef.current.contains(range.startContainer) && - editorRef.current.contains(range.endContainer) - ); - }, []); - const isSelectionInsideEditor = useCallback( - (selection: Selection | null): selection is Selection => { - if (!selection || selection.rangeCount === 0 || !editorRef.current) return false; - const range = selection.getRangeAt(0); - return isRangeInsideEditor(range); - }, - [isRangeInsideEditor] - ); + const editableRef = useRef(null); + const editor = usePlateEditor({ + readOnly: disabled, + plugins: [ParagraphPlugin, MentionPlugin], + value: initialText ? toValueFromText(initialText) : EMPTY_VALUE, + }); - const rememberSelection = useCallback(() => { - const selection = window.getSelection(); - if (!isSelectionInsideEditor(selection)) return; - lastSelectionRangeRef.current = selection.getRangeAt(0).cloneRange(); - }, [isSelectionInsideEditor]); - - const restoreRememberedSelection = useCallback((): Selection | null => { - const selection = window.getSelection(); - if (!selection) return null; - if (!isRangeInsideEditor(lastSelectionRangeRef.current)) return null; - selection.removeAllRanges(); - selection.addRange(lastSelectionRangeRef.current.cloneRange()); - return selection; - }, [isRangeInsideEditor]); - - useEffect(() => { - const handleSelectionChange = () => { - if (document.activeElement !== editorRef.current) return; - rememberSelection(); - }; - document.addEventListener("selectionchange", handleSelectionChange); - return () => document.removeEventListener("selectionchange", handleSelectionChange); - }, [rememberSelection]); - - useEffect(() => { - if (!initialText || !editorRef.current) return; - editorRef.current.innerText = initialText; - editorRef.current.appendChild(document.createElement("br")); - editorRef.current.appendChild(document.createElement("br")); - setIsEmpty(false); - onChange?.(initialText, []); - editorRef.current.focus(); - const sel = window.getSelection(); - const range = document.createRange(); - range.selectNodeContents(editorRef.current); - range.collapse(false); - sel?.removeAllRanges(); - sel?.addRange(range); - const anchor = document.createElement("span"); - range.insertNode(anchor); - anchor.scrollIntoView({ block: "end" }); - anchor.remove(); - }, [initialText, onChange]); - - // Focus at the end of the editor const focusAtEnd = useCallback(() => { - if (!editorRef.current) return; - editorRef.current.focus(); + const el = editableRef.current; + if (!el) return; + el.focus(); const selection = window.getSelection(); const range = document.createRange(); - range.selectNodeContents(editorRef.current); + range.selectNodeContents(el); range.collapse(false); selection?.removeAllRanges(); selection?.addRange(range); }, []); - // Get plain text content with inline mention tokens for chips. - // This preserves the original query structure sent to the backend/LLM. - const getText = useCallback((): string => { - if (!editorRef.current) return ""; + const getCurrentValue = useCallback(() => (editor.children as ComposerValue) ?? EMPTY_VALUE, [editor]); - const extractText = (node: Node): string => { - if (node.nodeType === Node.TEXT_NODE) { - return node.textContent ?? ""; - } - - if (node.nodeType === Node.ELEMENT_NODE) { - const element = node as Element; - - // Preserve mention chips as inline @title tokens. - if (element.hasAttribute(CHIP_DATA_ATTR)) { - const title = element.querySelector("[data-mention-title='true']")?.textContent?.trim(); - if (title) { - return `@${title}`; - } - return ""; - } - - let result = ""; - for (const child of Array.from(element.childNodes)) { - result += extractText(child); - } - return result; - } - - return ""; - }; - - return extractText(editorRef.current).trim(); - }, []); - - // Get all mentioned documents - const getMentionedDocuments = useCallback((): MentionedDocument[] => { - return Array.from(mentionedDocs.values()); - }, [mentionedDocs]); - - const syncEditorState = useCallback( - (docsOverride?: Map) => { - const docs = docsOverride - ? Array.from(docsOverride.values()) - : Array.from(mentionedDocs.values()); - const text = getText(); - const empty = text.length === 0 && docs.length === 0; - setIsEmpty(empty); + const emitState = useCallback( + (nextValue: ComposerValue) => { + const text = getPlainText(nextValue); + const docs = getMentionedDocuments(nextValue); onChange?.(text, docs); - }, - [getText, mentionedDocs, onChange] - ); - // Create a chip element for a document - const createChipElement = useCallback( - (doc: MentionedDocument): HTMLSpanElement => { - const chip = document.createElement("span"); - chip.setAttribute(CHIP_DATA_ATTR, "true"); - chip.setAttribute(CHIP_ID_ATTR, String(doc.id)); - chip.setAttribute(CHIP_DOCTYPE_ATTR, doc.document_type ?? "UNKNOWN"); - chip.contentEditable = "false"; - chip.className = - "inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none cursor-default"; - chip.style.userSelect = "none"; - chip.style.verticalAlign = "baseline"; - - // Container that swaps between icon and remove button on hover - const iconContainer = document.createElement("span"); - iconContainer.className = "shrink-0 flex items-center size-3 relative"; - - const iconSpan = document.createElement("span"); - iconSpan.className = "flex items-center text-muted-foreground"; - iconSpan.innerHTML = renderElementToHTML( - getConnectorIcon(doc.document_type ?? "UNKNOWN", "h-3 w-3") - ); - - const removeBtn = document.createElement("button"); - removeBtn.type = "button"; - removeBtn.className = - "size-3 items-center justify-center rounded-full text-muted-foreground transition-colors"; - removeBtn.style.display = "none"; - removeBtn.innerHTML = renderElementToHTML( - createElement(X, { className: "h-3 w-3", strokeWidth: 2.5 }) - ); - removeBtn.onclick = (e) => { - e.preventDefault(); - e.stopPropagation(); - chip.remove(); - const docKey = getMentionDocKey(doc); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(docKey); - syncEditorState(next); - return next; - }); - onDocumentRemove?.(doc.id, doc.document_type); - focusAtEnd(); - }; - - const titleSpan = document.createElement("span"); - titleSpan.className = "max-w-[120px] truncate"; - titleSpan.textContent = doc.title; - titleSpan.title = doc.title; - titleSpan.setAttribute("data-mention-title", "true"); - - const statusSpan = document.createElement("span"); - statusSpan.setAttribute(CHIP_STATUS_ATTR, "true"); - statusSpan.className = "text-[10px] font-semibold opacity-80 hidden"; - - const isTouchDevice = window.matchMedia("(hover: none)").matches; - if (isTouchDevice) { - // Mobile: icon on left, title, X on right - chip.appendChild(iconSpan); - chip.appendChild(titleSpan); - chip.appendChild(statusSpan); - removeBtn.style.display = "flex"; - removeBtn.className += " ml-0.5"; - chip.appendChild(removeBtn); - } else { - // Desktop: icon/X swap on hover in the same slot - iconContainer.appendChild(iconSpan); - iconContainer.appendChild(removeBtn); - chip.addEventListener("mouseenter", () => { - iconSpan.style.display = "none"; - removeBtn.style.display = "flex"; - }); - chip.addEventListener("mouseleave", () => { - iconSpan.style.display = ""; - removeBtn.style.display = "none"; - }); - chip.appendChild(iconContainer); - chip.appendChild(titleSpan); - chip.appendChild(statusSpan); + const cursorCtx = getCursorTextContext(nextValue, editor.selection); + if (!cursorCtx) { + onMentionClose?.(); + onActionClose?.(); + return; } - return chip; + const trigger = scanActiveTrigger(cursorCtx.text, cursorCtx.cursor); + if (!trigger) { + onMentionClose?.(); + onActionClose?.(); + return; + } + + if (trigger.triggerChar === "@") { + onMentionTrigger?.(trigger.query); + onActionClose?.(); + return; + } + + onActionTrigger?.(trigger.query); + onMentionClose?.(); }, - [focusAtEnd, onDocumentRemove, syncEditorState] + [editor.selection, onActionClose, onActionTrigger, onChange, onMentionClose, onMentionTrigger] + ); + + const setValue = useCallback( + (nextValue: ComposerValue) => { + const tf = editor.tf as { setValue: (value: ComposerValue) => void }; + tf.setValue(nextValue); + emitState(nextValue); + }, + [editor, emitState] ); - // Insert a document chip at the current cursor position const insertDocumentChip = useCallback( ( doc: Pick, options?: { removeTriggerText?: boolean } ) => { - if (!editorRef.current) return; + if (typeof doc.id !== "number" || typeof doc.title !== "string") return; + const removeTriggerText = options?.removeTriggerText ?? true; - - // Validate required fields for type safety - if (typeof doc.id !== "number" || typeof doc.title !== "string") { - console.warn("[InlineMentionEditor] Invalid document passed to insertDocumentChip:", doc); - return; - } - - const mentionDoc: MentionedDocument = { + const current = getCurrentValue(); + const selection = editor.selection; + const mentionNode: MentionElementNode = { + type: MENTION_TYPE, id: doc.id, title: doc.title, document_type: doc.document_type, + children: [{ text: "" }], }; - // Add to mentioned docs map using unique key - const docKey = getMentionDocKey(doc); - setMentionedDocs((prev) => new Map(prev).set(docKey, mentionDoc)); - const nextDocs = new Map(mentionedDocs); - nextDocs.set(docKey, mentionDoc); - - // Find and remove the @query text - const selection = window.getSelection(); - const hasActiveSelection = isSelectionInsideEditor(selection); - const resolvedSelection = hasActiveSelection ? selection : restoreRememberedSelection(); - if ( - !resolvedSelection || - resolvedSelection.rangeCount === 0 || - !isSelectionInsideEditor(resolvedSelection) - ) { - // No valid in-editor selection: deterministically insert at end. - editorRef.current.focus(); - const endSelection = window.getSelection(); - if (!endSelection) return; - const endRange = document.createRange(); - endRange.selectNodeContents(editorRef.current); - endRange.collapse(false); - endSelection.removeAllRanges(); - endSelection.addRange(endRange); - - const chip = createChipElement(mentionDoc); - endRange.insertNode(chip); - endRange.setStartAfter(chip); - endRange.collapse(true); - const space = document.createTextNode(" "); - endRange.insertNode(space); - endRange.setStartAfter(space); - endRange.collapse(true); - endSelection.removeAllRanges(); - endSelection.addRange(endRange); - - syncEditorState(nextDocs); - rememberSelection(); + const cursorCtx = getCursorTextContext(current, selection); + if (!cursorCtx) { + const lastBlock = current[current.length - 1] ?? { type: "p", children: [{ text: "" }] }; + const appended: ComposerValue = [ + ...current.slice(0, -1), + { + ...lastBlock, + children: [...lastBlock.children, mentionNode, { text: " " }], + }, + ]; + setValue(appended); + requestAnimationFrame(focusAtEnd); return; } - // Find the @ symbol before the cursor and remove it along with any query text - const range = resolvedSelection.getRangeAt(0); - const textNode = range.startContainer; - - if (textNode.nodeType === Node.TEXT_NODE && removeTriggerText) { - const text = textNode.textContent || ""; - const cursorPos = range.startOffset; - - // Find the @ symbol before cursor - let atIndex = -1; - for (let i = cursorPos - 1; i >= 0; i--) { - if (text[i] === "@") { - atIndex = i; - break; - } - } - - if (atIndex !== -1) { - // Remove @query and insert chip - const beforeAt = text.slice(0, atIndex); - const afterCursor = text.slice(cursorPos); - - // Create chip - const chip = createChipElement(mentionDoc); - - // Replace text node content - const parent = textNode.parentNode; - if (parent) { - const beforeNode = document.createTextNode(beforeAt); - const afterNode = document.createTextNode(` ${afterCursor}`); - - parent.insertBefore(beforeNode, textNode); - parent.insertBefore(chip, textNode); - parent.insertBefore(afterNode, textNode); - parent.removeChild(textNode); - - // Set cursor after the chip - const newRange = document.createRange(); - newRange.setStart(afterNode, 1); - newRange.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(newRange); - rememberSelection(); - } - } else { - // No @ found, just insert at cursor - const chip = createChipElement(mentionDoc); - range.insertNode(chip); - range.setStartAfter(chip); - range.collapse(true); - - // Add space after chip - const space = document.createTextNode(" "); - range.insertNode(space); - range.setStartAfter(space); - range.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(range); - rememberSelection(); - } - } else { - // Either explicit non-trigger insertion or no @query present. - const chip = createChipElement(mentionDoc); - range.insertNode(chip); - range.setStartAfter(chip); - range.collapse(true); - const space = document.createTextNode(" "); - range.insertNode(space); - range.setStartAfter(space); - range.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(range); - rememberSelection(); + const block = current[cursorCtx.blockIndex]; + const currentChild = getTextNode(block.children[cursorCtx.childIndex]); + if (!currentChild) { + const children = [...block.children]; + children.splice(cursorCtx.childIndex + 1, 0, mentionNode, { text: " " }); + const next = [...current]; + next[cursorCtx.blockIndex] = { ...block, children }; + setValue(next as ComposerValue); + requestAnimationFrame(focusAtEnd); + return; } - syncEditorState(nextDocs); + const text = currentChild.text; + let removeStart = cursorCtx.cursor; + if (removeTriggerText) { + for (let i = cursorCtx.cursor - 1; i >= 0; i--) { + if (text[i] === "@") { + removeStart = i; + break; + } + if (text[i] === " " || text[i] === "\n") break; + } + } + + const before = text.slice(0, removeStart); + const after = text.slice(cursorCtx.cursor); + const replacement: ComposerNode[] = []; + if (before.length > 0) replacement.push({ text: before }); + replacement.push(mentionNode); + replacement.push({ text: ` ${after}` }); + + const children = [...block.children]; + children.splice(cursorCtx.childIndex, 1, ...replacement); + const next = [...current]; + next[cursorCtx.blockIndex] = { ...block, children }; + setValue(next as ComposerValue); + requestAnimationFrame(focusAtEnd); }, - [ - createChipElement, - isSelectionInsideEditor, - mentionedDocs, - rememberSelection, - restoreRememberedSelection, - syncEditorState, - ] + [editor.selection, focusAtEnd, getCurrentValue, setValue] ); - // Clear the editor - const clear = useCallback(() => { - if (editorRef.current) { - editorRef.current.innerHTML = ""; - const emptyDocs = new Map(); - setMentionedDocs(emptyDocs); - syncEditorState(emptyDocs); - } - }, [syncEditorState]); - - // Replace editor content with plain text and place cursor at end - const setText = useCallback( - (text: string) => { - if (!editorRef.current) return; - editorRef.current.innerText = text; - syncEditorState(); - focusAtEnd(); + const removeDocumentChip = useCallback( + (docId: number, docType?: string) => { + const current = getCurrentValue(); + let changed = false; + const next = current.map((block) => { + const children = block.children.filter((node) => { + if (!isMentionNode(node)) return true; + const match = node.id === docId && (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); + if (match) changed = true; + return !match; + }); + return { ...block, children: children.length ? children : [{ text: "" }] }; + }); + if (!changed) return; + setValue(next as ComposerValue); }, - [focusAtEnd, syncEditorState] + [getCurrentValue, setValue] ); const setDocumentChipStatus = useCallback( @@ -507,327 +396,142 @@ export const InlineMentionEditor = forwardRef { - if (!editorRef.current) return; - - const chips = editorRef.current.querySelectorAll( - `span[${CHIP_DATA_ATTR}="true"]` - ); - for (const chip of chips) { - const chipId = getChipId(chip); - const chipType = getChipDocType(chip); - if (chipId !== docId) continue; - if ((docType ?? "UNKNOWN") !== chipType) continue; - - const statusEl = chip.querySelector(`span[${CHIP_STATUS_ATTR}="true"]`); - if (!statusEl) continue; - - if (!statusLabel) { - statusEl.textContent = ""; - statusEl.className = "text-[10px] font-semibold opacity-80 hidden"; - continue; - } - - const statusClass = - statusKind === "failed" - ? "text-destructive" - : statusKind === "processing" - ? "text-amber-700" - : statusKind === "ready" - ? "text-emerald-700" - : "text-amber-700"; - statusEl.textContent = statusLabel; - statusEl.className = `text-[10px] font-semibold opacity-80 ${statusClass}`; - } + const current = getCurrentValue(); + let changed = false; + const next = current.map((block) => ({ + ...block, + children: block.children.map((node) => { + if (!isMentionNode(node)) return node; + const sameType = (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); + if (node.id !== docId || !sameType) return node; + changed = true; + return { + ...node, + statusLabel, + statusKind: statusLabel ? statusKind : undefined, + }; + }), + })); + if (!changed) return; + setValue(next as ComposerValue); }, - [] + [getCurrentValue, setValue] ); - const removeDocumentChip = useCallback( - (docId: number, docType?: string) => { - if (!editorRef.current) return; - const chipKey = getMentionDocKey({ id: docId, document_type: docType }); - const chips = editorRef.current.querySelectorAll( - `span[${CHIP_DATA_ATTR}="true"]` - ); - for (const chip of chips) { - if (getChipId(chip) === docId && getChipDocType(chip) === (docType ?? "UNKNOWN")) { - chip.remove(); - break; - } - } - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); + const clear = useCallback(() => { + setValue(EMPTY_VALUE); + }, [setValue]); + + const setText = useCallback( + (text: string) => { + setValue(toValueFromText(text)); + requestAnimationFrame(focusAtEnd); }, - [syncEditorState] + [focusAtEnd, setValue] ); - // Expose methods via ref - useImperativeHandle(ref, () => ({ - focus: () => editorRef.current?.focus(), - clear, - setText, - getText, - getMentionedDocuments, - insertDocumentChip, - removeDocumentChip, - setDocumentChipStatus, - })); + const getText = useCallback(() => getPlainText(getCurrentValue()), [getCurrentValue]); + const getMentionedDocs = useCallback( + () => getMentionedDocuments(getCurrentValue()), + [getCurrentValue] + ); - // Handle input changes - const handleInput = useCallback(() => { - if (!editorRef.current) return; + useImperativeHandle( + ref, + () => ({ + focus: () => editableRef.current?.focus(), + clear, + setText, + getText, + getMentionedDocuments: getMentionedDocs, + insertDocumentChip, + removeDocumentChip, + setDocumentChipStatus, + }), + [clear, getMentionedDocs, getText, insertDocumentChip, removeDocumentChip, setDocumentChipStatus, setText] + ); - const text = getText(); - const empty = text.length === 0 && mentionedDocs.size === 0; - setIsEmpty(empty); - - // Unified trigger scan: find the leftmost @ or / in the current word. - // Whichever trigger was typed first owns the token — the other character - // is treated as part of the query, not as a separate trigger. - const selection = window.getSelection(); - let shouldTriggerMention = false; - let mentionQuery = ""; - let shouldTriggerAction = false; - let actionQuery = ""; - - if (selection && selection.rangeCount > 0) { - const range = selection.getRangeAt(0); - const textNode = range.startContainer; - - if (textNode.nodeType === Node.TEXT_NODE) { - const textContent = textNode.textContent || ""; - const cursorPos = range.startOffset; - - let wordStart = 0; - for (let i = cursorPos - 1; i >= 0; i--) { - if (textContent[i] === " " || textContent[i] === "\n") { - wordStart = i + 1; - break; - } - } - - let triggerChar: "@" | "/" | null = null; - let triggerIndex = -1; - for (let i = wordStart; i < cursorPos; i++) { - if (textContent[i] === "@" || textContent[i] === "/") { - triggerChar = textContent[i] as "@" | "/"; - triggerIndex = i; - break; - } - } - - if (triggerChar === "@" && triggerIndex !== -1) { - const query = textContent.slice(triggerIndex + 1, cursorPos); - if (!query.startsWith(" ")) { - shouldTriggerMention = true; - mentionQuery = query; - } - } else if (triggerChar === "/" && triggerIndex !== -1) { - if ( - triggerIndex === 0 || - textContent[triggerIndex - 1] === " " || - textContent[triggerIndex - 1] === "\n" - ) { - const query = textContent.slice(triggerIndex + 1, cursorPos); - if (!query.startsWith(" ")) { - shouldTriggerAction = true; - actionQuery = query; - } - } - } - } - } - - // If no @ found before cursor, check if text contains @ at all - // If text is empty or doesn't contain @, close the mention - if (!shouldTriggerMention) { - if (text.length === 0 || !text.includes("@")) { - onMentionClose?.(); - } else { - // Text contains @ but not before cursor, close mention - onMentionClose?.(); - } - } else { - onMentionTrigger?.(mentionQuery); - } - - if (!shouldTriggerAction) { - onActionClose?.(); - } else { - onActionTrigger?.(actionQuery); - } - - // Notify parent of change - onChange?.(text, Array.from(mentionedDocs.values())); - rememberSelection(); - }, [ - getText, - mentionedDocs, - onChange, - onMentionTrigger, - onMentionClose, - onActionTrigger, - onActionClose, - rememberSelection, - ]); - - // Handle keydown const handleKeyDown = useCallback( (e: React.KeyboardEvent) => { - // Let parent handle navigation keys when mention popover is open - if (onKeyDown) { - onKeyDown(e); - if (e.defaultPrevented) return; - } + onKeyDown?.(e); + if (e.defaultPrevented) return; - // Handle Enter for submit (without shift) if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); onSubmit?.(); return; } - // Handle backspace on chips - if (e.key === "Backspace") { - const selection = window.getSelection(); - if (selection && selection.rangeCount > 0) { - const range = selection.getRangeAt(0); - if (range.collapsed) { - // Check if cursor is right after a chip - const node = range.startContainer; - const offset = range.startOffset; - - if (node.nodeType === Node.TEXT_NODE && offset === 0) { - // Check previous sibling using type guard - const prevSibling = node.previousSibling; - if (isChipElement(prevSibling)) { - e.preventDefault(); - const chipId = getChipId(prevSibling); - const chipDocType = getChipDocType(prevSibling); - if (chipId !== null) { - prevSibling.remove(); - const chipKey = getMentionDocKey({ - id: chipId, - document_type: chipDocType, - }); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); - // Notify parent that a document was removed - onDocumentRemove?.(chipId, chipDocType); - } - return; - } - // Check if we're about to delete @ at the start - const textContent = node.textContent || ""; - if (textContent.length > 0 && textContent[0] === "@") { - // Will delete @, close mention popover - setTimeout(() => { - onMentionClose?.(); - }, 0); - } - } else if (node.nodeType === Node.TEXT_NODE && offset > 0) { - // Check if we're about to delete @ - const textContent = node.textContent || ""; - if (textContent[offset - 1] === "@") { - // Will delete @, close mention popover - setTimeout(() => { - onMentionClose?.(); - }, 0); - } - } else if (node.nodeType === Node.ELEMENT_NODE && offset > 0) { - // Check if previous child is a chip using type guard - const prevChild = (node as Element).childNodes[offset - 1]; - if (isChipElement(prevChild)) { - e.preventDefault(); - const chipId = getChipId(prevChild); - const chipDocType = getChipDocType(prevChild); - if (chipId !== null) { - prevChild.remove(); - const chipKey = getMentionDocKey({ - id: chipId, - document_type: chipDocType, - }); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); - // Notify parent that a document was removed - onDocumentRemove?.(chipId, chipDocType); - } - } - } - } - } + if (e.key !== "Backspace") return; + const selection = editor.selection; + if (!selection || !selection.anchor || !selection.focus) return; + if ( + selection.anchor.path.length < 2 || + selection.focus.path.length < 2 || + selection.anchor.path[0] !== selection.focus.path[0] + ) { + return; } + if (selection.anchor.offset !== 0 || selection.focus.offset !== 0) return; + + const value = getCurrentValue(); + const block = value[selection.anchor.path[0]]; + if (!block) return; + const childIndex = selection.anchor.path[1]; + if (childIndex <= 0) return; + const prev = block.children[childIndex - 1]; + if (!isMentionNode(prev)) return; + + e.preventDefault(); + removeDocumentChip(prev.id, prev.document_type); + onDocumentRemove?.(prev.id, prev.document_type); }, - [onKeyDown, onSubmit, onDocumentRemove, onMentionClose, syncEditorState] + [ + editor.selection, + getCurrentValue, + onDocumentRemove, + onKeyDown, + onSubmit, + removeDocumentChip, + ] ); - // Handle paste - strip formatting - const handlePaste = useCallback((e: React.ClipboardEvent) => { - e.preventDefault(); - const text = e.clipboardData.getData("text/plain"); - document.execCommand("insertText", false, text); - }, []); - - // Handle composition (for IME input) - const handleCompositionStart = useCallback(() => { - isComposingRef.current = true; - }, []); - - const handleCompositionEnd = useCallback(() => { - isComposingRef.current = false; - handleInput(); - }, [handleInput]); + const editableProps = useMemo( + () => ({ + placeholder, + onPaste: (e: React.ClipboardEvent) => { + e.preventDefault(); + const text = e.clipboardData.getData("text/plain"); + const tf = editor.tf as { insertText: (value: string) => void }; + tf.insertText(text); + }, + onKeyDown: handleKeyDown, + }), + [editor, handleKeyDown, placeholder] + ); return (
- {/* biome-ignore lint/a11y/noStaticElementInteractions: contenteditable mention editor requires a div for inline chips */} -
- {/* Placeholder with fade animation on change */} - {isEmpty && ( - - )} + { + emitState(value as ComposerValue); + }} + > + +
); } From 04da62a5541d446ccb2111dc4caed69f188806cc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:28:24 +0530 Subject: [PATCH 12/13] refactor(chat): improve AssistantMessage component with fixed comment trigger slot and enhanced visibility handling --- .../assistant-ui/assistant-message.tsx | 70 +++++++++++-------- 1 file changed, 39 insertions(+), 31 deletions(-) diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index bfe0434b4..711bb2fe2 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -548,8 +548,10 @@ const AssistantMessageInner: FC = () => {
)} -
- +
+
+ +
); @@ -642,35 +644,41 @@ export const AssistantMessage: FC = () => { className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150" data-role="assistant" > - {/* Comment trigger — right-aligned, just below user query on all screen sizes */} - {showCommentTrigger && ( -
- -
- )} + {/* Fixed trigger slot prevents any vertical reflow when visibility changes */} +
+ +
{/* Desktop floating comment panel — overlays on top of chat content */} {showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && ( From 5826e5264d68595fcf7b0e67c03739109ae05e50 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:39:33 +0530 Subject: [PATCH 13/13] refactor(chat): add TruncatedNameWithTooltip component in model selector --- .../components/new-chat/model-selector.tsx | 93 ++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 9fe9dd8da..1a0f8c5ba 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -236,6 +236,93 @@ interface DisplayItem { isAutoMode: boolean; } +const TruncatedNameWithTooltip: React.FC<{ + text: string; + className?: string; + enableTooltip: boolean; +}> = ({ text, className, enableTooltip }) => { + const textRef = useRef(null); + const openTimerRef = useRef(undefined); + const [isTruncated, setIsTruncated] = useState(false); + const [open, setOpen] = useState(false); + + const recalcTruncation = useCallback(() => { + const el = textRef.current; + if (!el) return; + setIsTruncated(el.scrollWidth > el.clientWidth + 1); + }, []); + + useEffect(() => { + if (!enableTooltip) return; + const el = textRef.current; + if (!el) return; + + const raf = requestAnimationFrame(recalcTruncation); + recalcTruncation(); + + const observer = new ResizeObserver(recalcTruncation); + observer.observe(el); + if (el.parentElement) observer.observe(el.parentElement); + window.addEventListener("resize", recalcTruncation); + + return () => { + cancelAnimationFrame(raf); + observer.disconnect(); + window.removeEventListener("resize", recalcTruncation); + }; + }, [enableTooltip, recalcTruncation]); + + useEffect(() => { + // Recompute when row text changes. + void text; + requestAnimationFrame(recalcTruncation); + }, [text, recalcTruncation]); + + useEffect( + () => () => { + if (openTimerRef.current) window.clearTimeout(openTimerRef.current); + }, + [] + ); + + if (!enableTooltip) { + return ( + + {text} + + ); + } + + const handleOpenChange = (nextOpen: boolean) => { + if (openTimerRef.current) { + window.clearTimeout(openTimerRef.current); + openTimerRef.current = undefined; + } + if (!nextOpen) { + setOpen(false); + return; + } + if (!isTruncated) return; + openTimerRef.current = window.setTimeout(() => { + setOpen(true); + openTimerRef.current = undefined; + }, 220); + }; + + return ( + + + + {text} + + + + {text} + + + ); +}; + // ─── Component ────────────────────────────────────────────────────── interface ModelSelectorProps { @@ -936,7 +1023,11 @@ export function ModelSelector({ {/* Model info */}
- {config.name} + {isAutoMode && (