feat: enhance chat functionality with chain-of-thought display and thinking steps management

This commit is contained in:
Anish Sarkar 2025-12-22 22:54:22 +05:30
parent 2f622891ae
commit 8a99752f2f
8 changed files with 857 additions and 48 deletions

View file

@ -7,10 +7,13 @@ import {
MessagePrimitive,
ThreadPrimitive,
useAssistantState,
useMessage,
} from "@assistant-ui/react";
import {
ArrowDownIcon,
ArrowUpIcon,
Brain,
CheckCircle2,
CheckIcon,
ChevronLeftIcon,
ChevronRightIcon,
@ -19,6 +22,8 @@ import {
Loader2,
PencilIcon,
RefreshCwIcon,
Search,
Sparkles,
SquareIcon,
} from "lucide-react";
import Image from "next/image";
@ -32,44 +37,134 @@ import {
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import {
ChainOfThought,
ChainOfThoughtContent,
ChainOfThoughtItem,
ChainOfThoughtStep,
ChainOfThoughtTrigger,
} from "@/components/prompt-kit/chain-of-thought";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
export const Thread: FC = () => {
/**
* Props for the Thread component
*/
interface ThreadProps {
messageThinkingSteps?: Map<string, ThinkingStep[]>;
}
// Context to pass thinking steps to AssistantMessage
import { createContext, useContext } from "react";
const ThinkingStepsContext = createContext<Map<string, ThinkingStep[]>>(new Map());
/**
* Get icon based on step status and title
*/
function getStepIcon(status: "pending" | "in_progress" | "completed", title: string) {
const titleLower = title.toLowerCase();
if (status === "in_progress") {
return <Loader2 className="size-4 animate-spin text-primary" />;
}
if (status === "completed") {
return <CheckCircle2 className="size-4 text-emerald-500" />;
}
if (titleLower.includes("search") || titleLower.includes("knowledge")) {
return <Search className="size-4 text-muted-foreground" />;
}
if (titleLower.includes("analy") || titleLower.includes("understand")) {
return <Brain className="size-4 text-muted-foreground" />;
}
return <Sparkles className="size-4 text-muted-foreground" />;
}
/**
* Chain of thought display component
*/
const ThinkingStepsDisplay: FC<{ steps: ThinkingStep[] }> = ({ steps }) => {
if (steps.length === 0) return null;
return (
<ThreadPrimitive.Root
className="aui-root aui-thread-root @container flex h-full flex-col bg-background"
style={{
["--thread-max-width" as string]: "44rem",
}}
>
<ThreadPrimitive.Viewport
turnAnchor="top"
className="aui-thread-viewport relative flex flex-1 flex-col overflow-x-auto overflow-y-scroll scroll-smooth px-4 pt-4"
<div className="mx-auto w-full max-w-(--thread-max-width) px-2 py-2">
<ChainOfThought>
{steps.map((step) => {
const icon = getStepIcon(step.status, step.title);
return (
<ChainOfThoughtStep
key={step.id}
defaultOpen={step.status === "in_progress"}
>
<ChainOfThoughtTrigger
leftIcon={icon}
swapIconOnHover={step.status !== "in_progress"}
className={cn(
step.status === "in_progress" && "text-foreground font-medium",
step.status === "completed" && "text-muted-foreground"
)}
>
{step.title}
</ChainOfThoughtTrigger>
{step.items && step.items.length > 0 && (
<ChainOfThoughtContent>
{step.items.map((item, index) => (
<ChainOfThoughtItem key={`${step.id}-item-${index}`}>
{item}
</ChainOfThoughtItem>
))}
</ChainOfThoughtContent>
)}
</ChainOfThoughtStep>
);
})}
</ChainOfThought>
</div>
);
};
export const Thread: FC<ThreadProps> = ({ messageThinkingSteps = new Map() }) => {
return (
<ThinkingStepsContext.Provider value={messageThinkingSteps}>
<ThreadPrimitive.Root
className="aui-root aui-thread-root @container flex h-full flex-col bg-background"
style={{
["--thread-max-width" as string]: "44rem",
}}
>
<AssistantIf condition={({ thread }) => thread.isEmpty}>
<ThreadWelcome />
</AssistantIf>
<ThreadPrimitive.Messages
components={{
UserMessage,
EditComposer,
AssistantMessage,
}}
/>
<ThreadPrimitive.ViewportFooter className="aui-thread-viewport-footer sticky bottom-0 mx-auto mt-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-background pb-4 md:pb-6">
<ThreadScrollToBottom />
<AssistantIf condition={({ thread }) => !thread.isEmpty}>
<div className="fade-in slide-in-from-bottom-4 animate-in duration-500 ease-out fill-mode-both">
<Composer />
</div>
<ThreadPrimitive.Viewport
turnAnchor="top"
className="aui-thread-viewport relative flex flex-1 flex-col overflow-x-auto overflow-y-scroll scroll-smooth px-4 pt-4"
>
<AssistantIf condition={({ thread }) => thread.isEmpty}>
<ThreadWelcome />
</AssistantIf>
</ThreadPrimitive.ViewportFooter>
</ThreadPrimitive.Viewport>
</ThreadPrimitive.Root>
<ThreadPrimitive.Messages
components={{
UserMessage,
EditComposer,
AssistantMessage,
}}
/>
<ThreadPrimitive.ViewportFooter className="aui-thread-viewport-footer sticky bottom-0 mx-auto mt-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-background pb-4 md:pb-6">
<ThreadScrollToBottom />
<AssistantIf condition={({ thread }) => !thread.isEmpty}>
<div className="fade-in slide-in-from-bottom-4 animate-in duration-500 ease-out fill-mode-both">
<Composer />
</div>
</AssistantIf>
</ThreadPrimitive.ViewportFooter>
</ThreadPrimitive.Viewport>
</ThreadPrimitive.Root>
</ThinkingStepsContext.Provider>
);
};
@ -197,7 +292,7 @@ const ThreadWelcome: FC = () => {
const Composer: FC = () => {
return (
<ComposerPrimitive.Root className="aui-composer-root relative flex w-full flex-col">
<ComposerPrimitive.AttachmentDropzone className="aui-composer-attachment-dropzone flex w-full flex-col rounded-2xl border-input bg-muted px-1 pt-2 outline-none transition-shadow has-[textarea:focus-visible]:border-ring has-[textarea:focus-visible]:ring-2 has-[textarea:focus-visible]:ring-ring/20 data-[dragging=true]:border-ring data-[dragging=true]:border-dashed data-[dragging=true]:bg-accent/50">
<ComposerPrimitive.AttachmentDropzone className="aui-composer-attachment-dropzone flex w-full flex-col rounded-2xl border-input bg-muted px-1 pt-2 outline-none transition-shadow data-[dragging=true]:border-ring data-[dragging=true]:border-dashed data-[dragging=true]:bg-accent/50">
<ComposerAttachments />
<ComposerPrimitive.Input
placeholder="Ask SurfSense"
@ -297,12 +392,22 @@ const MessageError: FC = () => {
);
};
const AssistantMessage: FC = () => {
const AssistantMessageInner: FC = () => {
const thinkingStepsMap = useContext(ThinkingStepsContext);
// Get the current message ID to look up thinking steps
const messageId = useMessage((m) => m.id);
const thinkingSteps = thinkingStepsMap.get(messageId) || [];
return (
<MessagePrimitive.Root
className="aui-assistant-message-root fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
data-role="assistant"
>
<>
{/* Show thinking steps BEFORE the text response */}
{thinkingSteps.length > 0 && (
<div className="mb-3">
<ThinkingStepsDisplay steps={thinkingSteps} />
</div>
)}
<div className="aui-assistant-message-content wrap-break-word px-2 text-foreground leading-relaxed">
<MessagePrimitive.Parts
components={{
@ -317,6 +422,17 @@ const AssistantMessage: FC = () => {
<BranchPicker />
<AssistantActionBar />
</div>
</>
);
};
const AssistantMessage: FC = () => {
return (
<MessagePrimitive.Root
className="aui-assistant-message-root fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
data-role="assistant"
>
<AssistantMessageInner />
</MessagePrimitive.Root>
);
};