mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
Merge pull request #1327 from AnishSarkar22/feat/chat-state-unification
refactor(chat): unify streaming state flow and improve chat viewport + mention UX
This commit is contained in:
commit
d335e96ec2
27 changed files with 1953 additions and 1647 deletions
File diff suppressed because it is too large
Load diff
|
|
@ -26,7 +26,14 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat
|
|||
|
||||
export const resetCurrentThreadAtom = atom(null, (_, set) => {
|
||||
set(currentThreadAtom, initialState);
|
||||
set(reportPanelAtom, { isOpen: false, reportId: null, title: null, wordCount: null });
|
||||
set(reportPanelAtom, {
|
||||
isOpen: false,
|
||||
reportId: null,
|
||||
title: null,
|
||||
wordCount: null,
|
||||
shareToken: null,
|
||||
contentType: "markdown",
|
||||
});
|
||||
});
|
||||
|
||||
/** Target comment ID to scroll to (from URL navigation or inbox click) */
|
||||
|
|
|
|||
|
|
@ -548,8 +548,10 @@ const AssistantMessageInner: FC = () => {
|
|||
</div>
|
||||
)}
|
||||
|
||||
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 flex items-center gap-2">
|
||||
<AssistantActionBar />
|
||||
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 h-6">
|
||||
<div className="h-full opacity-100 transition-opacity">
|
||||
<AssistantActionBar />
|
||||
</div>
|
||||
</div>
|
||||
</CitationMetadataProvider>
|
||||
);
|
||||
|
|
@ -642,35 +644,41 @@ export const AssistantMessage: FC = () => {
|
|||
className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
|
||||
data-role="assistant"
|
||||
>
|
||||
{/* Comment trigger — right-aligned, just below user query on all screen sizes */}
|
||||
{showCommentTrigger && (
|
||||
<div className="mr-2 mb-1 flex justify-end">
|
||||
<button
|
||||
ref={isDesktop ? commentTriggerRef : undefined}
|
||||
type="button"
|
||||
onClick={
|
||||
isDesktop ? () => setIsInlineOpen((prev) => !prev) : () => setIsSheetOpen(true)
|
||||
}
|
||||
className={cn(
|
||||
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
|
||||
isDesktop && isInlineOpen
|
||||
? "bg-primary/10 text-primary"
|
||||
: hasComments
|
||||
? "text-primary hover:bg-primary/10"
|
||||
: "text-muted-foreground hover:text-foreground hover:bg-muted"
|
||||
)}
|
||||
>
|
||||
<MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} />
|
||||
{hasComments ? (
|
||||
<span>
|
||||
{commentCount} {commentCount === 1 ? "comment" : "comments"}
|
||||
</span>
|
||||
) : (
|
||||
<span>Add comment</span>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* Fixed trigger slot prevents any vertical reflow when visibility changes */}
|
||||
<div className="mr-2 mb-1 flex h-7 justify-end">
|
||||
<button
|
||||
ref={isDesktop ? commentTriggerRef : undefined}
|
||||
type="button"
|
||||
onClick={
|
||||
showCommentTrigger
|
||||
? isDesktop
|
||||
? () => setIsInlineOpen((prev) => !prev)
|
||||
: () => setIsSheetOpen(true)
|
||||
: undefined
|
||||
}
|
||||
aria-hidden={!showCommentTrigger}
|
||||
tabIndex={showCommentTrigger ? 0 : -1}
|
||||
className={cn(
|
||||
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
|
||||
"opacity-0 pointer-events-none",
|
||||
showCommentTrigger && "opacity-100 pointer-events-auto",
|
||||
isDesktop && isInlineOpen
|
||||
? "bg-primary/10 text-primary"
|
||||
: hasComments
|
||||
? "text-primary hover:bg-primary/10"
|
||||
: "text-muted-foreground hover:text-foreground hover:bg-muted"
|
||||
)}
|
||||
>
|
||||
<MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} />
|
||||
{hasComments ? (
|
||||
<span>
|
||||
{commentCount} {commentCount === 1 ? "comment" : "comments"}
|
||||
</span>
|
||||
) : (
|
||||
<span>Add comment</span>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Desktop floating comment panel — overlays on top of chat content */}
|
||||
{showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && (
|
||||
|
|
|
|||
52
surfsense_web/components/assistant-ui/chat-viewport.tsx
Normal file
52
surfsense_web/components/assistant-ui/chat-viewport.tsx
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
"use client";
|
||||
|
||||
import { ThreadPrimitive } from "@assistant-ui/react";
|
||||
import { ArrowDownIcon } from "lucide-react";
|
||||
import type { FC, ReactNode } from "react";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
|
||||
const ChatScrollToBottom: FC = () => (
|
||||
<ThreadPrimitive.ScrollToBottom asChild>
|
||||
<TooltipIconButton
|
||||
tooltip="Scroll to bottom"
|
||||
variant="outline"
|
||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
||||
>
|
||||
<ArrowDownIcon />
|
||||
</TooltipIconButton>
|
||||
</ThreadPrimitive.ScrollToBottom>
|
||||
);
|
||||
|
||||
export interface ChatViewportProps {
|
||||
children: ReactNode;
|
||||
footer?: ReactNode;
|
||||
}
|
||||
|
||||
export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => (
|
||||
<ThreadPrimitive.Viewport
|
||||
turnAnchor="top"
|
||||
autoScroll
|
||||
scrollToBottomOnRunStart
|
||||
scrollToBottomOnInitialize
|
||||
scrollToBottomOnThreadSwitch
|
||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth"
|
||||
style={{ scrollbarGutter: "stable" }}
|
||||
>
|
||||
<div
|
||||
aria-hidden
|
||||
className="aui-chat-viewport-top-fade pointer-events-none sticky top-0 z-10 -mx-4 h-2 shrink-0 bg-gradient-to-b from-main-panel from-20% to-transparent"
|
||||
/>
|
||||
{children}
|
||||
{footer ? (
|
||||
<ThreadPrimitive.ViewportFooter
|
||||
className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 mt-auto flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6"
|
||||
style={{ paddingBottom: "max(0.5rem, env(safe-area-inset-bottom))" }}
|
||||
>
|
||||
<div className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-3 overflow-visible">
|
||||
<ChatScrollToBottom />
|
||||
{footer}
|
||||
</div>
|
||||
</ThreadPrimitive.ViewportFooter>
|
||||
) : null}
|
||||
</ThreadPrimitive.Viewport>
|
||||
);
|
||||
File diff suppressed because it is too large
Load diff
24
surfsense_web/components/assistant-ui/nested-scroll.tsx
Normal file
24
surfsense_web/components/assistant-ui/nested-scroll.tsx
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
"use client";
|
||||
|
||||
import { forwardRef, type ComponentPropsWithoutRef, type WheelEvent } from "react";
|
||||
|
||||
export type NestedScrollProps = ComponentPropsWithoutRef<"div">;
|
||||
|
||||
export const NestedScroll = forwardRef<HTMLDivElement, NestedScrollProps>(
|
||||
({ onWheel, ...props }, ref) => {
|
||||
const handleWheel = (event: WheelEvent<HTMLDivElement>) => {
|
||||
const el = event.currentTarget;
|
||||
const canScrollUp = el.scrollTop > 0;
|
||||
const canScrollDown = el.scrollTop < el.scrollHeight - el.clientHeight - 1;
|
||||
const goingUp = event.deltaY < 0;
|
||||
const goingDown = event.deltaY > 0;
|
||||
if ((goingUp && canScrollUp) || (goingDown && canScrollDown)) {
|
||||
event.stopPropagation();
|
||||
}
|
||||
onWheel?.(event);
|
||||
};
|
||||
return <div ref={ref} onWheel={handleWheel} {...props} />;
|
||||
}
|
||||
);
|
||||
|
||||
NestedScroll.displayName = "NestedScroll";
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
import { ThreadPrimitive } from "@assistant-ui/react";
|
||||
import { ArrowDownIcon } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
|
||||
export const ThreadScrollToBottom: FC = () => {
|
||||
return (
|
||||
<ThreadPrimitive.ScrollToBottom asChild>
|
||||
<TooltipIconButton
|
||||
tooltip="Scroll to bottom"
|
||||
variant="outline"
|
||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
||||
>
|
||||
<ArrowDownIcon />
|
||||
</TooltipIconButton>
|
||||
</ThreadPrimitive.ScrollToBottom>
|
||||
);
|
||||
};
|
||||
|
|
@ -5,12 +5,10 @@ import {
|
|||
ThreadPrimitive,
|
||||
useAui,
|
||||
useAuiState,
|
||||
useThreadViewportStore,
|
||||
} from "@assistant-ui/react";
|
||||
import { useAtom, useAtomValue, useSetAtom } from "jotai";
|
||||
import {
|
||||
AlertCircle,
|
||||
ArrowDownIcon,
|
||||
ArrowUpIcon,
|
||||
Camera,
|
||||
ChevronDown,
|
||||
|
|
@ -55,6 +53,7 @@ import {
|
|||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
||||
import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status";
|
||||
import { ChatViewport } from "@/components/assistant-ui/chat-viewport";
|
||||
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
|
||||
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
|
||||
import {
|
||||
|
|
@ -112,10 +111,13 @@ const ThreadContent: FC = () => {
|
|||
["--thread-max-width" as string]: "44rem",
|
||||
}}
|
||||
>
|
||||
<ThreadPrimitive.Viewport
|
||||
turnAnchor="top"
|
||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
|
||||
style={{ scrollbarGutter: "stable" }}
|
||||
<ChatViewport
|
||||
footer={
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<PremiumQuotaPinnedAlert />
|
||||
<Composer />
|
||||
</AuiIf>
|
||||
}
|
||||
>
|
||||
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
||||
<ThreadWelcome />
|
||||
|
|
@ -128,24 +130,7 @@ const ThreadContent: FC = () => {
|
|||
AssistantMessage,
|
||||
}}
|
||||
/>
|
||||
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<div className="grow" />
|
||||
</AuiIf>
|
||||
|
||||
<ThreadPrimitive.ViewportFooter
|
||||
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6"
|
||||
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
|
||||
>
|
||||
<ThreadScrollToBottom />
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<PremiumQuotaPinnedAlert />
|
||||
</AuiIf>
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<Composer />
|
||||
</AuiIf>
|
||||
</ThreadPrimitive.ViewportFooter>
|
||||
</ThreadPrimitive.Viewport>
|
||||
</ChatViewport>
|
||||
</ThreadPrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
|
@ -181,20 +166,6 @@ const PremiumQuotaPinnedAlert: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const ThreadScrollToBottom: FC = () => {
|
||||
return (
|
||||
<ThreadPrimitive.ScrollToBottom asChild>
|
||||
<TooltipIconButton
|
||||
tooltip="Scroll to bottom"
|
||||
variant="outline"
|
||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
||||
>
|
||||
<ArrowDownIcon />
|
||||
</TooltipIconButton>
|
||||
</ThreadPrimitive.ScrollToBottom>
|
||||
);
|
||||
};
|
||||
|
||||
const getTimeBasedGreeting = (user?: { display_name?: string | null; email?: string }): string => {
|
||||
const hour = new Date().getHours();
|
||||
|
||||
|
|
@ -411,23 +382,9 @@ const Composer: FC = () => {
|
|||
>(new Map());
|
||||
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
|
||||
const promptPickerRef = useRef<PromptPickerRef>(null);
|
||||
const viewportRef = useRef<Element | null>(null);
|
||||
const { search_space_id, chat_id } = useParams();
|
||||
const aui = useAui();
|
||||
const threadViewportStore = useThreadViewportStore();
|
||||
const hasAutoFocusedRef = useRef(false);
|
||||
const submitCleanupRef = useRef<(() => void) | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
submitCleanupRef.current?.();
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Store viewport element reference on mount
|
||||
useEffect(() => {
|
||||
viewportRef.current = document.querySelector(".aui-thread-viewport");
|
||||
}, []);
|
||||
|
||||
const electronAPI = useElectronAPI();
|
||||
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
|
||||
|
|
@ -626,7 +583,6 @@ const Composer: FC = () => {
|
|||
[showDocumentPopover, showPromptPicker]
|
||||
);
|
||||
|
||||
// Submit message (blocked during streaming, document picker open, or AI responding to another user)
|
||||
const handleSubmit = useCallback(() => {
|
||||
if (isThreadRunning || isBlockedByOtherUser) return;
|
||||
if (showDocumentPopover || showPromptPicker) return;
|
||||
|
|
@ -638,50 +594,9 @@ const Composer: FC = () => {
|
|||
setClipboardInitialText(undefined);
|
||||
}
|
||||
|
||||
const viewportEl = viewportRef.current;
|
||||
const heightBefore = viewportEl?.scrollHeight ?? 0;
|
||||
|
||||
aui.composer().send();
|
||||
editorRef.current?.clear();
|
||||
setMentionedDocuments([]);
|
||||
|
||||
// With turnAnchor="top", ViewportSlack adds min-height to the last
|
||||
// assistant message so that scrolling-to-bottom actually positions the
|
||||
// user message at the TOP of the viewport. That slack height is
|
||||
// calculated asynchronously (ResizeObserver → style → layout).
|
||||
// Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes.
|
||||
const scrollToBottom = () =>
|
||||
threadViewportStore.getState().scrollToBottom({ behavior: "instant" });
|
||||
|
||||
let lastHeight = heightBefore;
|
||||
let frames = 0;
|
||||
let cancelled = false;
|
||||
const POLL_FRAMES = 30;
|
||||
|
||||
const pollAndScroll = () => {
|
||||
if (cancelled) return;
|
||||
const el = viewportRef.current;
|
||||
if (el) {
|
||||
const h = el.scrollHeight;
|
||||
if (h !== lastHeight) {
|
||||
lastHeight = h;
|
||||
scrollToBottom();
|
||||
}
|
||||
}
|
||||
if (++frames < POLL_FRAMES) {
|
||||
requestAnimationFrame(pollAndScroll);
|
||||
}
|
||||
};
|
||||
requestAnimationFrame(pollAndScroll);
|
||||
|
||||
const t1 = setTimeout(scrollToBottom, 100);
|
||||
const t2 = setTimeout(scrollToBottom, 300);
|
||||
|
||||
submitCleanupRef.current = () => {
|
||||
cancelled = true;
|
||||
clearTimeout(t1);
|
||||
clearTimeout(t2);
|
||||
};
|
||||
}, [
|
||||
showDocumentPopover,
|
||||
showPromptPicker,
|
||||
|
|
@ -690,7 +605,6 @@ const Composer: FC = () => {
|
|||
clipboardInitialText,
|
||||
aui,
|
||||
setMentionedDocuments,
|
||||
threadViewportStore,
|
||||
]);
|
||||
|
||||
const handleDocumentRemove = useCallback(
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import {
|
|||
isDoomLoopInterrupt,
|
||||
} from "@/components/tool-ui/doom-loop-approval";
|
||||
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
|
||||
import { NestedScroll } from "@/components/assistant-ui/nested-scroll";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
|
|
@ -475,7 +476,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
|||
{(argsText || isRunning) && (
|
||||
<div className="flex flex-col gap-1 min-w-0">
|
||||
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
|
||||
<div className="max-h-48 overflow-auto rounded-md bg-muted/40">
|
||||
<NestedScroll className="max-h-48 overflow-auto rounded-md bg-muted/40">
|
||||
{argsText ? (
|
||||
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||
{argsText}
|
||||
|
|
@ -489,7 +490,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
|||
Waiting for input…
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</NestedScroll>
|
||||
</div>
|
||||
)}
|
||||
{!isCancelled && result !== undefined && (
|
||||
|
|
@ -497,11 +498,11 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
|||
<Separator />
|
||||
<div className="flex flex-col gap-1 min-w-0">
|
||||
<p className="text-xs font-medium text-muted-foreground">Result</p>
|
||||
<div className="max-h-64 overflow-auto rounded-md bg-muted/40">
|
||||
<NestedScroll className="max-h-64 overflow-auto rounded-md bg-muted/40">
|
||||
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||
{typeof result === "string" ? result : serializedResult}
|
||||
</pre>
|
||||
</div>
|
||||
</NestedScroll>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react";
|
||||
import {
|
||||
ActionBarPrimitive,
|
||||
AuiIf,
|
||||
MessagePrimitive,
|
||||
useAuiState,
|
||||
useMessagePartText,
|
||||
} from "@assistant-ui/react";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { CheckIcon, CopyIcon, Pencil } from "lucide-react";
|
||||
import Image from "next/image";
|
||||
|
|
@ -7,6 +13,8 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
|
|||
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
||||
import { parseMentionSegments } from "@/lib/chat/parse-mention-segments";
|
||||
|
||||
interface AuthorMetadata {
|
||||
displayName: string | null;
|
||||
|
|
@ -47,23 +55,40 @@ const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
|
|||
);
|
||||
};
|
||||
|
||||
export const UserMessage: FC = () => {
|
||||
const UserTextPart: FC = () => {
|
||||
const messageId = useAuiState(({ message }) => message?.id);
|
||||
const messageText = useAuiState(({ message }) =>
|
||||
(message?.content ?? [])
|
||||
.map((part) =>
|
||||
typeof part === "object" &&
|
||||
part !== null &&
|
||||
"type" in part &&
|
||||
(part as { type?: string }).type === "text" &&
|
||||
"text" in part
|
||||
? String((part as { text?: string }).text ?? "")
|
||||
: ""
|
||||
)
|
||||
.join("")
|
||||
);
|
||||
const part = useMessagePartText();
|
||||
const text = (part as { text?: string }).text ?? "";
|
||||
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
|
||||
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
|
||||
const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? [];
|
||||
|
||||
const segments = parseMentionSegments(text, mentionedDocs);
|
||||
|
||||
return (
|
||||
<p style={{ whiteSpace: "pre-line" }} className="break-words">
|
||||
{segments.map((segment) =>
|
||||
segment.type === "text" ? (
|
||||
<span key={`txt-${segment.start}`}>{segment.value}</span>
|
||||
) : (
|
||||
<span
|
||||
key={`mention-${getMentionDocKey(segment.doc)}-${segment.start}`}
|
||||
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-middle leading-none"
|
||||
title={segment.doc.title}
|
||||
>
|
||||
<span className="flex items-center text-muted-foreground">
|
||||
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
|
||||
</span>
|
||||
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
|
||||
</span>
|
||||
)
|
||||
)}
|
||||
</p>
|
||||
);
|
||||
};
|
||||
|
||||
const userMessageParts = { Text: UserTextPart };
|
||||
|
||||
export const UserMessage: FC = () => {
|
||||
const metadata = useAuiState(({ message }) => message?.metadata);
|
||||
const author = metadata?.custom?.author as AuthorMetadata | undefined;
|
||||
const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE";
|
||||
|
|
@ -78,11 +103,7 @@ export const UserMessage: FC = () => {
|
|||
<div className="aui-user-message-content-wrapper flex items-end gap-2">
|
||||
<div className="relative flex-1 min-w-0">
|
||||
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
|
||||
{mentionedDocs && mentionedDocs.length > 0 ? (
|
||||
<UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} />
|
||||
) : (
|
||||
<MessagePrimitive.Parts />
|
||||
)}
|
||||
<MessagePrimitive.Parts components={userMessageParts} />
|
||||
</div>
|
||||
<div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto">
|
||||
<UserActionBar />
|
||||
|
|
@ -99,64 +120,6 @@ export const UserMessage: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const UserMessageWithMentionChips: FC<{
|
||||
text: string;
|
||||
mentionedDocs: { id: number; title: string; document_type: string }[];
|
||||
}> = ({ text, mentionedDocs }) => {
|
||||
type Segment =
|
||||
| { type: "text"; value: string; start: number }
|
||||
| { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number };
|
||||
|
||||
const tokens = mentionedDocs
|
||||
.map((doc) => ({ doc, token: `@${doc.title}` }))
|
||||
.sort((a, b) => b.token.length - a.token.length);
|
||||
|
||||
const segments: Segment[] = [];
|
||||
let i = 0;
|
||||
let buffer = "";
|
||||
let bufferStart = 0;
|
||||
while (i < text.length) {
|
||||
const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i));
|
||||
if (tokenMatch) {
|
||||
if (buffer) {
|
||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
||||
buffer = "";
|
||||
}
|
||||
segments.push({ type: "mention", doc: tokenMatch.doc, start: i });
|
||||
i += tokenMatch.token.length;
|
||||
bufferStart = i;
|
||||
continue;
|
||||
}
|
||||
if (!buffer) bufferStart = i;
|
||||
buffer += text[i];
|
||||
i += 1;
|
||||
}
|
||||
if (buffer) {
|
||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
||||
}
|
||||
|
||||
return (
|
||||
<span className="whitespace-pre-wrap break-words">
|
||||
{segments.map((segment) =>
|
||||
segment.type === "text" ? (
|
||||
<span key={`txt-${segment.start}`}>{segment.value}</span>
|
||||
) : (
|
||||
<span
|
||||
key={`mention-${segment.doc.document_type}:${segment.doc.id}-${segment.start}`}
|
||||
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-baseline"
|
||||
title={segment.doc.title}
|
||||
>
|
||||
<span className="flex items-center text-muted-foreground">
|
||||
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
|
||||
</span>
|
||||
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
|
||||
</span>
|
||||
)
|
||||
)}
|
||||
</span>
|
||||
);
|
||||
};
|
||||
|
||||
const UserActionBar: FC = () => {
|
||||
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
"use client";
|
||||
|
||||
import { AuiIf, ThreadPrimitive } from "@assistant-ui/react";
|
||||
import { ArrowDownIcon } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
||||
import { ChatViewport } from "@/components/assistant-ui/chat-viewport";
|
||||
import { EditComposer } from "@/components/assistant-ui/edit-composer";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
import { UserMessage } from "@/components/assistant-ui/user-message";
|
||||
import { FreeComposer } from "./free-composer";
|
||||
|
||||
|
|
@ -24,20 +23,6 @@ const FreeThreadWelcome: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const ThreadScrollToBottom: FC = () => {
|
||||
return (
|
||||
<ThreadPrimitive.ScrollToBottom asChild>
|
||||
<TooltipIconButton
|
||||
tooltip="Scroll to bottom"
|
||||
variant="outline"
|
||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
||||
>
|
||||
<ArrowDownIcon />
|
||||
</TooltipIconButton>
|
||||
</ThreadPrimitive.ScrollToBottom>
|
||||
);
|
||||
};
|
||||
|
||||
export const FreeThread: FC = () => {
|
||||
return (
|
||||
<ThreadPrimitive.Root
|
||||
|
|
@ -46,10 +31,12 @@ export const FreeThread: FC = () => {
|
|||
["--thread-max-width" as string]: "44rem",
|
||||
}}
|
||||
>
|
||||
<ThreadPrimitive.Viewport
|
||||
turnAnchor="top"
|
||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
|
||||
style={{ scrollbarGutter: "stable" }}
|
||||
<ChatViewport
|
||||
footer={
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<FreeComposer />
|
||||
</AuiIf>
|
||||
}
|
||||
>
|
||||
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
||||
<FreeThreadWelcome />
|
||||
|
|
@ -62,21 +49,7 @@ export const FreeThread: FC = () => {
|
|||
AssistantMessage,
|
||||
}}
|
||||
/>
|
||||
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<div className="grow" />
|
||||
</AuiIf>
|
||||
|
||||
<ThreadPrimitive.ViewportFooter
|
||||
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6"
|
||||
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
|
||||
>
|
||||
<ThreadScrollToBottom />
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<FreeComposer />
|
||||
</AuiIf>
|
||||
</ThreadPrimitive.ViewportFooter>
|
||||
</ThreadPrimitive.Viewport>
|
||||
</ChatViewport>
|
||||
</ThreadPrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -236,6 +236,93 @@ interface DisplayItem {
|
|||
isAutoMode: boolean;
|
||||
}
|
||||
|
||||
const TruncatedNameWithTooltip: React.FC<{
|
||||
text: string;
|
||||
className?: string;
|
||||
enableTooltip: boolean;
|
||||
}> = ({ text, className, enableTooltip }) => {
|
||||
const textRef = useRef<HTMLSpanElement>(null);
|
||||
const openTimerRef = useRef<number | undefined>(undefined);
|
||||
const [isTruncated, setIsTruncated] = useState(false);
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
const recalcTruncation = useCallback(() => {
|
||||
const el = textRef.current;
|
||||
if (!el) return;
|
||||
setIsTruncated(el.scrollWidth > el.clientWidth + 1);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!enableTooltip) return;
|
||||
const el = textRef.current;
|
||||
if (!el) return;
|
||||
|
||||
const raf = requestAnimationFrame(recalcTruncation);
|
||||
recalcTruncation();
|
||||
|
||||
const observer = new ResizeObserver(recalcTruncation);
|
||||
observer.observe(el);
|
||||
if (el.parentElement) observer.observe(el.parentElement);
|
||||
window.addEventListener("resize", recalcTruncation);
|
||||
|
||||
return () => {
|
||||
cancelAnimationFrame(raf);
|
||||
observer.disconnect();
|
||||
window.removeEventListener("resize", recalcTruncation);
|
||||
};
|
||||
}, [enableTooltip, recalcTruncation]);
|
||||
|
||||
useEffect(() => {
|
||||
// Recompute when row text changes.
|
||||
void text;
|
||||
requestAnimationFrame(recalcTruncation);
|
||||
}, [text, recalcTruncation]);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
if (openTimerRef.current) window.clearTimeout(openTimerRef.current);
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
if (!enableTooltip) {
|
||||
return (
|
||||
<span ref={textRef} className={cn("block max-w-full", className)}>
|
||||
{text}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
const handleOpenChange = (nextOpen: boolean) => {
|
||||
if (openTimerRef.current) {
|
||||
window.clearTimeout(openTimerRef.current);
|
||||
openTimerRef.current = undefined;
|
||||
}
|
||||
if (!nextOpen) {
|
||||
setOpen(false);
|
||||
return;
|
||||
}
|
||||
if (!isTruncated) return;
|
||||
openTimerRef.current = window.setTimeout(() => {
|
||||
setOpen(true);
|
||||
openTimerRef.current = undefined;
|
||||
}, 220);
|
||||
};
|
||||
|
||||
return (
|
||||
<Tooltip open={open} onOpenChange={handleOpenChange}>
|
||||
<TooltipTrigger asChild>
|
||||
<span ref={textRef} className={cn("block max-w-full", className)}>
|
||||
{text}
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top" align="start">
|
||||
{text}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
// ─── Component ──────────────────────────────────────────────────────
|
||||
|
||||
interface ModelSelectorProps {
|
||||
|
|
@ -936,7 +1023,11 @@ export function ModelSelector({
|
|||
{/* Model info */}
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<span className="font-medium text-sm truncate">{config.name}</span>
|
||||
<TruncatedNameWithTooltip
|
||||
text={config.name}
|
||||
enableTooltip={!isMobile}
|
||||
className="font-medium text-sm truncate"
|
||||
/>
|
||||
{isAutoMode && (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
|
|
|
|||
|
|
@ -45,20 +45,21 @@ export const PublicThread: FC<PublicThreadProps> = ({ footer }) => {
|
|||
["--thread-max-width" as string]: "44rem",
|
||||
}}
|
||||
>
|
||||
<ThreadPrimitive.Viewport className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4">
|
||||
<ThreadPrimitive.Viewport
|
||||
scrollToBottomOnInitialize
|
||||
scrollToBottomOnThreadSwitch
|
||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4 pb-6"
|
||||
>
|
||||
<ThreadPrimitive.Messages
|
||||
components={{
|
||||
UserMessage: PublicUserMessage,
|
||||
AssistantMessage: PublicAssistantMessage,
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* Spacer to ensure footer doesn't overlap last message */}
|
||||
<div className="h-24" />
|
||||
</ThreadPrimitive.Viewport>
|
||||
|
||||
{footer && (
|
||||
<div className="sticky bottom-0 z-20 border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60">
|
||||
<div className="border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60">
|
||||
{footer}
|
||||
</div>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -147,6 +147,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
|
|||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "TURN_CANCELLING"
|
||||
) {
|
||||
return {
|
||||
kind: "thread_busy",
|
||||
channel: "toast",
|
||||
severity: "info",
|
||||
telemetryEvent: "chat_blocked",
|
||||
isExpected: true,
|
||||
userMessage: "A previous response is still stopping. Please try again in a moment.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "TURN_CANCELLING",
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "THREAD_BUSY"
|
||||
) {
|
||||
|
|
@ -156,7 +172,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
|
|||
severity: "warn",
|
||||
telemetryEvent: "chat_blocked",
|
||||
isExpected: true,
|
||||
userMessage: "A previous response is still stopping. Please try again in a moment.",
|
||||
userMessage: "Another response is still finishing for this thread. Please try again in a moment.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "THREAD_BUSY",
|
||||
details: { flow: input.flow },
|
||||
|
|
|
|||
114
surfsense_web/lib/chat/chat-request-errors.ts
Normal file
114
surfsense_web/lib/chat/chat-request-errors.ts
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
export async function toHttpResponseError(
|
||||
response: Response
|
||||
): Promise<Error & { errorCode?: string; retryAfterMs?: number }> {
|
||||
const statusDefaultCode =
|
||||
response.status === 409
|
||||
? "THREAD_BUSY"
|
||||
: response.status === 429
|
||||
? "RATE_LIMITED"
|
||||
: response.status === 401 || response.status === 403
|
||||
? "AUTH_EXPIRED"
|
||||
: "SERVER_ERROR";
|
||||
|
||||
let rawBody = "";
|
||||
try {
|
||||
rawBody = await response.text();
|
||||
} catch {
|
||||
// noop
|
||||
}
|
||||
|
||||
let parsedBody: Record<string, unknown> | null = null;
|
||||
if (rawBody) {
|
||||
try {
|
||||
const parsed = JSON.parse(rawBody);
|
||||
if (typeof parsed === "object" && parsed !== null) {
|
||||
parsedBody = parsed as Record<string, unknown>;
|
||||
}
|
||||
} catch {
|
||||
// noop
|
||||
}
|
||||
}
|
||||
|
||||
const detail = parsedBody?.detail;
|
||||
const detailObject =
|
||||
typeof detail === "object" && detail !== null ? (detail as Record<string, unknown>) : null;
|
||||
const detailMessage = typeof detail === "string" ? detail : undefined;
|
||||
const topLevelMessage =
|
||||
typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined;
|
||||
const detailNestedMessage =
|
||||
typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined;
|
||||
|
||||
const topLevelCode =
|
||||
typeof parsedBody?.errorCode === "string"
|
||||
? parsedBody.errorCode
|
||||
: typeof parsedBody?.error_code === "string"
|
||||
? parsedBody.error_code
|
||||
: undefined;
|
||||
const detailCode =
|
||||
typeof detailObject?.errorCode === "string"
|
||||
? detailObject.errorCode
|
||||
: typeof detailObject?.error_code === "string"
|
||||
? detailObject.error_code
|
||||
: undefined;
|
||||
|
||||
const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode;
|
||||
|
||||
const detailRetryAfterMs =
|
||||
typeof detailObject?.retry_after_ms === "number"
|
||||
? detailObject.retry_after_ms
|
||||
: typeof detailObject?.retryAfterMs === "number"
|
||||
? detailObject.retryAfterMs
|
||||
: undefined;
|
||||
const topRetryAfterMs =
|
||||
typeof parsedBody?.retry_after_ms === "number"
|
||||
? parsedBody.retry_after_ms
|
||||
: typeof parsedBody?.retryAfterMs === "number"
|
||||
? parsedBody.retryAfterMs
|
||||
: undefined;
|
||||
const headerRetryAfterMsRaw = response.headers.get("retry-after-ms");
|
||||
const headerRetryAfterMs = headerRetryAfterMsRaw ? Number.parseFloat(headerRetryAfterMsRaw) : NaN;
|
||||
const retryAfterHeader = response.headers.get("retry-after");
|
||||
const retryAfterSeconds = retryAfterHeader ? Number.parseFloat(retryAfterHeader) : NaN;
|
||||
const retryAfterMsFromHeader = Number.isFinite(headerRetryAfterMs)
|
||||
? Math.max(0, Math.round(headerRetryAfterMs))
|
||||
: Number.isFinite(retryAfterSeconds)
|
||||
? Math.max(0, Math.round(retryAfterSeconds * 1000))
|
||||
: undefined;
|
||||
const retryAfterMs =
|
||||
detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined;
|
||||
const message =
|
||||
detailNestedMessage ??
|
||||
detailMessage ??
|
||||
topLevelMessage ??
|
||||
`Backend error: ${response.status}`;
|
||||
|
||||
return Object.assign(new Error(message), { errorCode, retryAfterMs });
|
||||
}
|
||||
|
||||
export function tagPreAcceptSendFailure(error: unknown): unknown {
|
||||
if (error instanceof Error) {
|
||||
const withCode = error as Error & { errorCode?: string; code?: string };
|
||||
const existingCode = withCode.errorCode ?? withCode.code;
|
||||
const passthroughCodes = new Set([
|
||||
"PREMIUM_QUOTA_EXHAUSTED",
|
||||
"THREAD_BUSY",
|
||||
"TURN_CANCELLING",
|
||||
"AUTH_EXPIRED",
|
||||
"UNAUTHORIZED",
|
||||
"RATE_LIMITED",
|
||||
"NETWORK_ERROR",
|
||||
"STREAM_PARSE_ERROR",
|
||||
"TOOL_EXECUTION_ERROR",
|
||||
"PERSIST_MESSAGE_FAILED",
|
||||
"SERVER_ERROR",
|
||||
]);
|
||||
if (existingCode && passthroughCodes.has(existingCode)) {
|
||||
return Object.assign(error, { errorCode: existingCode });
|
||||
}
|
||||
return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" });
|
||||
}
|
||||
|
||||
return Object.assign(new Error("Failed to send message before stream acceptance"), {
|
||||
errorCode: "SEND_FAILED_PRE_ACCEPT",
|
||||
});
|
||||
}
|
||||
54
surfsense_web/lib/chat/parse-mention-segments.ts
Normal file
54
surfsense_web/lib/chat/parse-mention-segments.ts
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
import type { MentionedDocumentInfo } from "@/atoms/chat/mentioned-documents.atom";
|
||||
|
||||
export type MentionSegment =
|
||||
| { type: "text"; value: string; start: number }
|
||||
| { type: "mention"; doc: MentionedDocumentInfo; start: number };
|
||||
|
||||
/**
|
||||
* Tokenizes a user message into text and `@mention` segments.
|
||||
*
|
||||
* Pure: no React, no DOM, no side effects. Safe to unit-test and reuse.
|
||||
*
|
||||
* Mentions are matched greedily by longest title first so that a longer title
|
||||
* (e.g. `@Project Roadmap`) is never shadowed by a shorter prefix
|
||||
* (e.g. `@Project`).
|
||||
*/
|
||||
export function parseMentionSegments(
|
||||
text: string,
|
||||
docs: ReadonlyArray<MentionedDocumentInfo>
|
||||
): MentionSegment[] {
|
||||
if (text.length === 0) return [];
|
||||
if (docs.length === 0) return [{ type: "text", value: text, start: 0 }];
|
||||
|
||||
const tokens = docs
|
||||
.map((doc) => ({ doc, token: `@${doc.title}` }))
|
||||
.sort((a, b) => b.token.length - a.token.length);
|
||||
|
||||
const segments: MentionSegment[] = [];
|
||||
let i = 0;
|
||||
let buffer = "";
|
||||
let bufferStart = 0;
|
||||
|
||||
while (i < text.length) {
|
||||
const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i));
|
||||
if (tokenMatch) {
|
||||
if (buffer) {
|
||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
||||
buffer = "";
|
||||
}
|
||||
segments.push({ type: "mention", doc: tokenMatch.doc, start: i });
|
||||
i += tokenMatch.token.length;
|
||||
bufferStart = i;
|
||||
continue;
|
||||
}
|
||||
if (!buffer) bufferStart = i;
|
||||
buffer += text[i];
|
||||
i += 1;
|
||||
}
|
||||
|
||||
if (buffer) {
|
||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
||||
}
|
||||
|
||||
return segments;
|
||||
}
|
||||
19
surfsense_web/lib/chat/stream-flush.ts
Normal file
19
surfsense_web/lib/chat/stream-flush.ts
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
import { FrameBatchedUpdater } from "@/lib/chat/streaming-state";
|
||||
|
||||
export function createStreamFlushHelpers(flushMessages: () => void): {
|
||||
batcher: FrameBatchedUpdater;
|
||||
scheduleFlush: () => void;
|
||||
forceFlush: () => void;
|
||||
} {
|
||||
const batcher = new FrameBatchedUpdater();
|
||||
const scheduleFlush = () => batcher.schedule(flushMessages);
|
||||
// Force-flush helper: ``batcher.flush()`` is a no-op when
|
||||
// ``dirty=false`` (e.g. a tool starts before any text streamed).
|
||||
// ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so
|
||||
// terminal events render promptly without the throttle delay.
|
||||
const forceFlush = () => {
|
||||
scheduleFlush();
|
||||
batcher.flush();
|
||||
};
|
||||
return { batcher, scheduleFlush, forceFlush };
|
||||
}
|
||||
196
surfsense_web/lib/chat/stream-pipeline.ts
Normal file
196
surfsense_web/lib/chat/stream-pipeline.ts
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
import {
|
||||
addStepSeparator,
|
||||
addToolCall,
|
||||
appendReasoning,
|
||||
appendText,
|
||||
appendToolInputDelta,
|
||||
type ContentPartsState,
|
||||
endReasoning,
|
||||
readSSEStream,
|
||||
type SSEEvent,
|
||||
type ThinkingStepData,
|
||||
type ToolUIGate,
|
||||
updateThinkingSteps,
|
||||
updateToolCall,
|
||||
} from "@/lib/chat/streaming-state";
|
||||
|
||||
export type SharedStreamEventContext = {
|
||||
contentPartsState: ContentPartsState;
|
||||
toolsWithUI: ToolUIGate;
|
||||
currentThinkingSteps: Map<string, ThinkingStepData>;
|
||||
scheduleFlush: () => void;
|
||||
forceFlush: () => void;
|
||||
onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void;
|
||||
onTurnStatus?: (data: Extract<SSEEvent, { type: "data-turn-status" }>["data"]) => void;
|
||||
onToolOutputAvailable?: (
|
||||
event: Extract<SSEEvent, { type: "tool-output-available" }>,
|
||||
context: {
|
||||
contentPartsState: ContentPartsState;
|
||||
toolCallIndices: Map<string, number>;
|
||||
}
|
||||
) => void;
|
||||
};
|
||||
|
||||
/**
|
||||
* After a tool produces output, mark any previously-decided interrupt tool
|
||||
* calls as completed so the ApprovalCard can transition from shimmer to done.
|
||||
*/
|
||||
export function markInterruptsCompleted(
|
||||
contentParts: Array<{ type: string; result?: unknown }>
|
||||
): void {
|
||||
for (const part of contentParts) {
|
||||
if (
|
||||
part.type === "tool-call" &&
|
||||
typeof part.result === "object" &&
|
||||
part.result !== null &&
|
||||
(part.result as Record<string, unknown>).__interrupt__ === true &&
|
||||
(part.result as Record<string, unknown>).__decided__ &&
|
||||
!(part.result as Record<string, unknown>).__completed__
|
||||
) {
|
||||
part.result = { ...(part.result as Record<string, unknown>), __completed__: true };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function hasPersistableContent(
|
||||
contentParts: ContentPartsState["contentParts"],
|
||||
toolsWithUI: ToolUIGate
|
||||
) {
|
||||
return contentParts.some(
|
||||
(part) =>
|
||||
(part.type === "text" && part.text.length > 0) ||
|
||||
(part.type === "reasoning" && part.text.length > 0) ||
|
||||
(part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName)))
|
||||
);
|
||||
}
|
||||
|
||||
function toStreamTerminalError(
|
||||
event: Extract<SSEEvent, { type: "error" }>
|
||||
): Error & { errorCode?: string } {
|
||||
return Object.assign(new Error(event.errorText || "Server error"), {
|
||||
errorCode: event.errorCode,
|
||||
});
|
||||
}
|
||||
|
||||
export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean {
|
||||
const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context;
|
||||
const { contentParts, toolCallIndices } = contentPartsState;
|
||||
|
||||
switch (parsed.type) {
|
||||
case "text-delta":
|
||||
appendText(contentPartsState, parsed.delta);
|
||||
scheduleFlush();
|
||||
return true;
|
||||
|
||||
case "reasoning-delta":
|
||||
appendReasoning(contentPartsState, parsed.delta);
|
||||
scheduleFlush();
|
||||
return true;
|
||||
|
||||
case "reasoning-end":
|
||||
endReasoning(contentPartsState);
|
||||
scheduleFlush();
|
||||
return true;
|
||||
|
||||
case "start-step":
|
||||
addStepSeparator(contentPartsState);
|
||||
scheduleFlush();
|
||||
return true;
|
||||
|
||||
case "finish-step":
|
||||
return true;
|
||||
|
||||
case "tool-input-start":
|
||||
addToolCall(
|
||||
contentPartsState,
|
||||
toolsWithUI,
|
||||
parsed.toolCallId,
|
||||
parsed.toolName,
|
||||
{},
|
||||
false,
|
||||
parsed.langchainToolCallId
|
||||
);
|
||||
forceFlush();
|
||||
return true;
|
||||
|
||||
case "tool-input-delta":
|
||||
// High-frequency event: deltas can fire dozens of times per call,
|
||||
// so use throttled scheduleFlush (NOT forceFlush) to coalesce.
|
||||
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
|
||||
scheduleFlush();
|
||||
return true;
|
||||
|
||||
case "tool-input-available": {
|
||||
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
|
||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||
args: parsed.input || {},
|
||||
argsText: finalArgsText,
|
||||
langchainToolCallId: parsed.langchainToolCallId,
|
||||
});
|
||||
} else {
|
||||
addToolCall(
|
||||
contentPartsState,
|
||||
toolsWithUI,
|
||||
parsed.toolCallId,
|
||||
parsed.toolName,
|
||||
parsed.input || {},
|
||||
false,
|
||||
parsed.langchainToolCallId
|
||||
);
|
||||
// addToolCall doesn't accept argsText today; backfill via
|
||||
// updateToolCall so the new card renders pretty-printed JSON.
|
||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||
argsText: finalArgsText,
|
||||
});
|
||||
}
|
||||
forceFlush();
|
||||
return true;
|
||||
}
|
||||
|
||||
case "tool-output-available":
|
||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||
result: parsed.output,
|
||||
langchainToolCallId: parsed.langchainToolCallId,
|
||||
});
|
||||
markInterruptsCompleted(contentParts);
|
||||
context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices });
|
||||
forceFlush();
|
||||
return true;
|
||||
|
||||
case "data-thinking-step": {
|
||||
const stepData = parsed.data as ThinkingStepData;
|
||||
if (stepData?.id) {
|
||||
currentThinkingSteps.set(stepData.id, stepData);
|
||||
const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps);
|
||||
if (didUpdate) {
|
||||
scheduleFlush();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
case "data-token-usage":
|
||||
context.onTokenUsage?.(parsed.data);
|
||||
return true;
|
||||
|
||||
case "data-turn-status":
|
||||
context.onTurnStatus?.(parsed.data);
|
||||
return true;
|
||||
|
||||
case "error":
|
||||
throw toStreamTerminalError(parsed);
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export async function consumeSseEvents(
|
||||
response: Response,
|
||||
onEvent: (event: SSEEvent) => void | Promise<void>
|
||||
): Promise<void> {
|
||||
for await (const parsed of readSSEStream(response)) {
|
||||
await onEvent(parsed);
|
||||
}
|
||||
}
|
||||
127
surfsense_web/lib/chat/stream-side-effects.ts
Normal file
127
surfsense_web/lib/chat/stream-side-effects.ts
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
import type { ThreadMessageLike } from "@assistant-ui/react";
|
||||
import {
|
||||
addToolCall,
|
||||
type ContentPartsState,
|
||||
type ToolUIGate,
|
||||
updateToolCall,
|
||||
} from "@/lib/chat/streaming-state";
|
||||
|
||||
type InterruptActionRequest = {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
};
|
||||
|
||||
export type EditedInterruptAction = {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
};
|
||||
|
||||
function readInterruptActions(
|
||||
interruptData: Record<string, unknown>
|
||||
): InterruptActionRequest[] {
|
||||
return (interruptData.action_requests ?? []) as InterruptActionRequest[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies an interrupt request payload to tool-call parts. Existing tool cards
|
||||
* are updated in-place; missing ones are upserted so approval UI always shows.
|
||||
*/
|
||||
export function applyInterruptRequestToContentParts(
|
||||
contentPartsState: ContentPartsState,
|
||||
toolsWithUI: ToolUIGate,
|
||||
interruptData: Record<string, unknown>
|
||||
): void {
|
||||
const { contentParts, toolCallIndices } = contentPartsState;
|
||||
const actionRequests = readInterruptActions(interruptData);
|
||||
for (const action of actionRequests) {
|
||||
const existingEntry = Array.from(toolCallIndices.entries()).find(([, idx]) => {
|
||||
const part = contentParts[idx];
|
||||
return part?.type === "tool-call" && part.toolName === action.name;
|
||||
});
|
||||
|
||||
if (existingEntry) {
|
||||
updateToolCall(contentPartsState, existingEntry[0], {
|
||||
result: { __interrupt__: true, ...interruptData },
|
||||
});
|
||||
} else {
|
||||
const toolCallId = `interrupt-${action.name}`;
|
||||
addToolCall(contentPartsState, toolsWithUI, toolCallId, action.name, action.args, true);
|
||||
updateToolCall(contentPartsState, toolCallId, {
|
||||
result: { __interrupt__: true, ...interruptData },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function mergeEditedInterruptAction(
|
||||
contentParts: ContentPartsState["contentParts"],
|
||||
editedAction: EditedInterruptAction | undefined
|
||||
): void {
|
||||
if (!editedAction) return;
|
||||
for (const part of contentParts) {
|
||||
if (part.type === "tool-call" && part.toolName === editedAction.name) {
|
||||
const mergedArgs = { ...part.args, ...editedAction.args };
|
||||
part.args = mergedArgs;
|
||||
// assistant-ui prefers argsText over JSON.stringify(args)
|
||||
part.argsText = JSON.stringify(mergedArgs, null, 2);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function markInterruptDecisionOnContentParts(
|
||||
contentParts: ContentPartsState["contentParts"],
|
||||
decisionType: "approve" | "reject" | undefined
|
||||
): void {
|
||||
if (!decisionType) return;
|
||||
for (const part of contentParts) {
|
||||
if (
|
||||
part.type === "tool-call" &&
|
||||
typeof part.result === "object" &&
|
||||
part.result !== null &&
|
||||
"__interrupt__" in (part.result as Record<string, unknown>)
|
||||
) {
|
||||
part.result = {
|
||||
...(part.result as Record<string, unknown>),
|
||||
__decided__: decisionType,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* When a streamed message is persisted, the backend returns the durable
|
||||
* turn_id; merge it into assistant-ui metadata for turn-scoped actions.
|
||||
*/
|
||||
export function mergeChatTurnIdIntoMessage(
|
||||
msg: ThreadMessageLike,
|
||||
turnId: string | null | undefined
|
||||
): ThreadMessageLike {
|
||||
if (!turnId) return msg;
|
||||
const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> };
|
||||
const existingCustom = existingMeta.custom ?? {};
|
||||
if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg;
|
||||
return {
|
||||
...msg,
|
||||
metadata: {
|
||||
...existingMeta,
|
||||
custom: { ...existingCustom, chatTurnId: turnId },
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function readStreamedChatTurnId(data: unknown): string | null {
|
||||
if (typeof data !== "object" || data === null) return null;
|
||||
const value = (data as { chat_turn_id?: unknown }).chat_turn_id;
|
||||
return typeof value === "string" && value.length > 0 ? value : null;
|
||||
}
|
||||
|
||||
export function applyTurnIdToAssistantMessageList(
|
||||
messages: ThreadMessageLike[],
|
||||
assistantMsgId: string,
|
||||
turnId: string
|
||||
): ThreadMessageLike[] {
|
||||
return messages.map((m) =>
|
||||
m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, turnId) : m
|
||||
);
|
||||
}
|
||||
|
|
@ -528,6 +528,14 @@ export type SSEEvent =
|
|||
}>;
|
||||
};
|
||||
}
|
||||
| {
|
||||
type: "data-turn-status";
|
||||
data: {
|
||||
status: "idle" | "busy" | "cancelling";
|
||||
retry_after_ms?: number;
|
||||
retry_after_at?: number;
|
||||
};
|
||||
}
|
||||
| {
|
||||
type: "data-token-usage";
|
||||
data: {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue