refactor: improve write_todos tool and UI components

- Refactored the write_todos tool to enhance argument and result schemas using Zod for better validation and type safety.
- Updated the WriteTodosToolUI to streamline the rendering logic and improve loading states, ensuring a smoother user experience.
- Enhanced the Plan and TodoItem components to better handle streaming states and display progress, providing clearer feedback during task management.
- Cleaned up code formatting and structure for improved readability and maintainability.
This commit is contained in:
Anish Sarkar 2025-12-26 17:49:56 +05:30
parent 2c86287264
commit ebc04f590e
18 changed files with 833 additions and 751 deletions

View file

@ -5,7 +5,7 @@ This module provides a tool for creating and displaying a planning/todo list
in the chat UI. It helps the agent break down complex tasks into steps. in the chat UI. It helps the agent break down complex tasks into steps.
""" """
from typing import Any, Literal from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
@ -91,4 +91,3 @@ def create_write_todos_tool():
} }
return write_todos return write_todos

View file

@ -76,12 +76,17 @@ class WebCrawlerConnector:
return result, None return result, None
except Exception as firecrawl_error: except Exception as firecrawl_error:
# Firecrawl failed, fallback to Chromium # Firecrawl failed, fallback to Chromium
logger.warning(f"[webcrawler] Firecrawl failed, falling back to Chromium+Trafilatura for: {url}") logger.warning(
f"[webcrawler] Firecrawl failed, falling back to Chromium+Trafilatura for: {url}"
)
try: try:
result = await self._crawl_with_chromium(url) result = await self._crawl_with_chromium(url)
return result, None return result, None
except Exception as chromium_error: except Exception as chromium_error:
return None, f"Both Firecrawl and Chromium failed. Firecrawl error: {firecrawl_error!s}, Chromium error: {chromium_error!s}" return (
None,
f"Both Firecrawl and Chromium failed. Firecrawl error: {firecrawl_error!s}, Chromium error: {chromium_error!s}",
)
else: else:
# No Firecrawl API key, use Chromium directly # No Firecrawl API key, use Chromium directly
logger.info(f"[webcrawler] Using Chromium+Trafilatura for: {url}") logger.info(f"[webcrawler] Using Chromium+Trafilatura for: {url}")

View file

@ -149,7 +149,11 @@ function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
if (typeof part !== "object" || part === null || !("type" in part)) return true; if (typeof part !== "object" || part === null || !("type" in part)) return true;
const partType = (part as { type: string }).type; const partType = (part as { type: string }).type;
// Filter out thinking-steps, mentioned-documents, and attachments // Filter out thinking-steps, mentioned-documents, and attachments
return partType !== "thinking-steps" && partType !== "mentioned-documents" && partType !== "attachments"; return (
partType !== "thinking-steps" &&
partType !== "mentioned-documents" &&
partType !== "attachments"
);
}); });
content = content =
filteredContent.length > 0 filteredContent.length > 0
@ -319,7 +323,13 @@ export default function NewChatPage() {
} finally { } finally {
setIsInitializing(false); setIsInitializing(false);
} }
}, [urlChatId, setMessageDocumentsMap, setMentionedDocumentIds, setMentionedDocuments, hydratePlanState]); }, [
urlChatId,
setMessageDocumentsMap,
setMentionedDocumentIds,
setMentionedDocuments,
hydratePlanState,
]);
// Initialize on mount // Initialize on mount
useEffect(() => { useEffect(() => {
@ -786,9 +796,7 @@ export default function NewChatPage() {
appendMessage(currentThreadId, { appendMessage(currentThreadId, {
role: "assistant", role: "assistant",
content: partialContent, content: partialContent,
}).catch((err) => }).catch((err) => console.error("Failed to persist partial assistant message:", err));
console.error("Failed to persist partial assistant message:", err)
);
} }
return; return;
} }

View file

@ -104,9 +104,7 @@ export interface UpdatePlanInput {
/** /**
* Helper atom to update a plan state * Helper atom to update a plan state
*/ */
export const updatePlanStateAtom = atom( export const updatePlanStateAtom = atom(null, (get, set, plan: UpdatePlanInput) => {
null,
(get, set, plan: UpdatePlanInput) => {
const states = new Map(get(planStatesAtom)); const states = new Map(get(planStatesAtom));
// Register ownership synchronously if not already done // Register ownership synchronously if not already done
@ -127,8 +125,7 @@ export const updatePlanStateAtom = atom(
ownerToolCallId, ownerToolCallId,
}); });
set(planStatesAtom, states); set(planStatesAtom, states);
} });
);
/** /**
* Helper atom to get the latest plan state by title * Helper atom to get the latest plan state by title
@ -165,9 +162,7 @@ export interface HydratePlanInput {
}; };
} }
export const hydratePlanStateAtom = atom( export const hydratePlanStateAtom = atom(null, (get, set, plan: HydratePlanInput) => {
null,
(get, set, plan: HydratePlanInput) => {
if (!plan.result?.todos || plan.result.todos.length === 0) return; if (!plan.result?.todos || plan.result.todos.length === 0) return;
const states = new Map(get(planStatesAtom)); const states = new Map(get(planStatesAtom));
@ -193,8 +188,7 @@ export const hydratePlanStateAtom = atom(
}); });
set(planStatesAtom, states); set(planStatesAtom, states);
} }
} });
);
/** /**
* Extract write_todos tool call data from message content * Extract write_todos tool call data from message content
@ -231,4 +225,3 @@ export function extractWriteTodosFromContent(content: unknown): HydratePlanInput
return results; return results;
} }

View file

@ -910,7 +910,7 @@ const AssistantMessageInner: FC = () => {
<MessageError /> <MessageError />
</div> </div>
<div className="aui-assistant-message-footer mt-1 ml-2 flex"> <div className="aui-assistant-message-footer mt-1 mb-5 ml-2 flex">
<BranchPicker /> <BranchPicker />
<AssistantActionBar /> <AssistantActionBar />
</div> </div>

View file

@ -57,18 +57,10 @@ export function TextShimmerLoader({
* Loader component - currently only supports text-shimmer variant * Loader component - currently only supports text-shimmer variant
* Can be extended with more variants if needed in the future * Can be extended with more variants if needed in the future
*/ */
export function Loader({ export function Loader({ variant = "text-shimmer", size = "md", text, className }: LoaderProps) {
variant = "text-shimmer",
size = "md",
text,
className,
}: LoaderProps) {
switch (variant) { switch (variant) {
case "text-shimmer": case "text-shimmer":
default: default:
return ( return <TextShimmerLoader text={text} size={size} className={className} />;
<TextShimmerLoader text={text} size={size} className={className} />
);
} }
} }

View file

@ -3,7 +3,7 @@
import { useQuery, useQueryClient } from "@tanstack/react-query"; import { useQuery, useQueryClient } from "@tanstack/react-query";
import { useAtomValue, useSetAtom } from "jotai"; import { useAtomValue, useSetAtom } from "jotai";
import { Trash2 } from "lucide-react"; import { Trash2 } from "lucide-react";
import { useRouter } from "next/navigation"; import { useParams, useRouter } from "next/navigation";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { useCallback, useMemo, useState } from "react"; import { useCallback, useMemo, useState } from "react";
import { hasUnsavedEditorChangesAtom, pendingEditorNavigationAtom } from "@/atoms/editor/ui.atoms"; import { hasUnsavedEditorChangesAtom, pendingEditorNavigationAtom } from "@/atoms/editor/ui.atoms";
@ -50,7 +50,13 @@ export function AppSidebarProvider({
const t = useTranslations("dashboard"); const t = useTranslations("dashboard");
const tCommon = useTranslations("common"); const tCommon = useTranslations("common");
const router = useRouter(); const router = useRouter();
const params = useParams();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
// Get current chat ID from URL params
const currentChatId = params?.chat_id
? Number(Array.isArray(params.chat_id) ? params.chat_id[0] : params.chat_id)
: null;
const [isDeletingThread, setIsDeletingThread] = useState(false); const [isDeletingThread, setIsDeletingThread] = useState(false);
// Editor state for handling unsaved changes // Editor state for handling unsaved changes
@ -61,7 +67,6 @@ export function AppSidebarProvider({
const { const {
data: threadsData, data: threadsData,
error: threadError, error: threadError,
isLoading: isLoadingThreads,
refetch: refetchThreads, refetch: refetchThreads,
} = useQuery({ } = useQuery({
queryKey: ["threads", searchSpaceId], queryKey: ["threads", searchSpaceId],
@ -73,7 +78,6 @@ export function AppSidebarProvider({
data: searchSpace, data: searchSpace,
isLoading: isLoadingSearchSpace, isLoading: isLoadingSearchSpace,
error: searchSpaceError, error: searchSpaceError,
refetch: fetchSearchSpace,
} = useQuery({ } = useQuery({
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId), queryKey: cacheKeys.searchSpaces.detail(searchSpaceId),
queryFn: () => searchSpacesApiService.getSearchSpace({ id: Number(searchSpaceId) }), queryFn: () => searchSpacesApiService.getSearchSpace({ id: Number(searchSpaceId) }),
@ -83,12 +87,7 @@ export function AppSidebarProvider({
const { data: user } = useAtomValue(currentUserAtom); const { data: user } = useAtomValue(currentUserAtom);
// Fetch notes // Fetch notes
const { const { data: notesData, refetch: refetchNotes } = useQuery({
data: notesData,
error: notesError,
isLoading: isLoadingNotes,
refetch: refetchNotes,
} = useQuery({
queryKey: ["notes", searchSpaceId], queryKey: ["notes", searchSpaceId],
queryFn: () => queryFn: () =>
notesApiService.getNotes({ notesApiService.getNotes({
@ -108,11 +107,6 @@ export function AppSidebarProvider({
} | null>(null); } | null>(null);
const [isDeletingNote, setIsDeletingNote] = useState(false); const [isDeletingNote, setIsDeletingNote] = useState(false);
// Retry function
const retryFetch = useCallback(() => {
fetchSearchSpace();
}, [fetchSearchSpace]);
// Transform threads to the format expected by AppSidebar // Transform threads to the format expected by AppSidebar
const recentChats = useMemo(() => { const recentChats = useMemo(() => {
if (!threadsData?.threads) return []; if (!threadsData?.threads) return [];
@ -149,8 +143,10 @@ export function AppSidebarProvider({
await deleteThread(threadToDelete.id); await deleteThread(threadToDelete.id);
// Invalidate threads query to refresh the list // Invalidate threads query to refresh the list
queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] }); queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
// Navigate to new-chat after successful deletion // Only navigate to new-chat if the deleted chat is currently open
if (currentChatId === threadToDelete.id) {
router.push(`/dashboard/${searchSpaceId}/new-chat`); router.push(`/dashboard/${searchSpaceId}/new-chat`);
}
} catch (error) { } catch (error) {
console.error("Error deleting thread:", error); console.error("Error deleting thread:", error);
} finally { } finally {
@ -158,7 +154,7 @@ export function AppSidebarProvider({
setShowDeleteDialog(false); setShowDeleteDialog(false);
setThreadToDelete(null); setThreadToDelete(null);
} }
}, [threadToDelete, queryClient, searchSpaceId, router]); }, [threadToDelete, queryClient, searchSpaceId, router, currentChatId]);
// Handle delete note with confirmation // Handle delete note with confirmation
const handleDeleteNote = useCallback(async () => { const handleDeleteNote = useCallback(async () => {

View file

@ -2,6 +2,7 @@
import { makeAssistantToolUI } from "@assistant-ui/react"; import { makeAssistantToolUI } from "@assistant-ui/react";
import { AlertCircleIcon, ImageIcon } from "lucide-react"; import { AlertCircleIcon, ImageIcon } from "lucide-react";
import { z } from "zod";
import { import {
Image, Image,
ImageErrorBoundary, ImageErrorBoundary,
@ -9,27 +10,41 @@ import {
parseSerializableImage, parseSerializableImage,
} from "@/components/tool-ui/image"; } from "@/components/tool-ui/image";
/** // ============================================================================
* Type definitions for the display_image tool // Zod Schemas
*/ // ============================================================================
interface DisplayImageArgs {
src: string;
alt?: string;
title?: string;
description?: string;
}
interface DisplayImageResult { /**
id: string; * Schema for display_image tool arguments
assetId: string; */
src: string; const DisplayImageArgsSchema = z.object({
alt?: string; // Made optional - parseSerializableImage provides fallback src: z.string(),
title?: string; alt: z.string().nullish(),
description?: string; title: z.string().nullish(),
domain?: string; description: z.string().nullish(),
ratio?: string; });
error?: string;
} /**
* Schema for display_image tool result
*/
const DisplayImageResultSchema = z.object({
id: z.string(),
assetId: z.string(),
src: z.string(),
alt: z.string().nullish(),
title: z.string().nullish(),
description: z.string().nullish(),
domain: z.string().nullish(),
ratio: z.string().nullish(),
error: z.string().nullish(),
});
// ============================================================================
// Types
// ============================================================================
type DisplayImageArgs = z.infer<typeof DisplayImageArgsSchema>;
type DisplayImageResult = z.infer<typeof DisplayImageResultSchema>;
/** /**
* Error state component shown when image display fails * Error state component shown when image display fails
@ -142,4 +157,9 @@ export const DisplayImageToolUI = makeAssistantToolUI<DisplayImageArgs, DisplayI
}, },
}); });
export type { DisplayImageArgs, DisplayImageResult }; export {
DisplayImageArgsSchema,
DisplayImageResultSchema,
type DisplayImageArgs,
type DisplayImageResult,
};

View file

@ -24,6 +24,8 @@ export {
type ThinkingStep, type ThinkingStep,
} from "./deepagent-thinking"; } from "./deepagent-thinking";
export { export {
DisplayImageArgsSchema,
DisplayImageResultSchema,
type DisplayImageArgs, type DisplayImageArgs,
type DisplayImageResult, type DisplayImageResult,
DisplayImageToolUI, DisplayImageToolUI,
@ -39,9 +41,13 @@ export {
type SerializableImage, type SerializableImage,
} from "./image"; } from "./image";
export { export {
LinkPreviewArgsSchema,
LinkPreviewResultSchema,
type LinkPreviewArgs, type LinkPreviewArgs,
type LinkPreviewResult, type LinkPreviewResult,
LinkPreviewToolUI, LinkPreviewToolUI,
MultiLinkPreviewArgsSchema,
MultiLinkPreviewResultSchema,
type MultiLinkPreviewArgs, type MultiLinkPreviewArgs,
type MultiLinkPreviewResult, type MultiLinkPreviewResult,
MultiLinkPreviewToolUI, MultiLinkPreviewToolUI,
@ -56,6 +62,8 @@ export {
type SerializableMediaCard, type SerializableMediaCard,
} from "./media-card"; } from "./media-card";
export { export {
ScrapeWebpageArgsSchema,
ScrapeWebpageResultSchema,
type ScrapeWebpageArgs, type ScrapeWebpageArgs,
type ScrapeWebpageResult, type ScrapeWebpageResult,
ScrapeWebpageToolUI, ScrapeWebpageToolUI,
@ -71,6 +79,8 @@ export {
} from "./plan"; } from "./plan";
export { export {
WriteTodosToolUI, WriteTodosToolUI,
WriteTodosArgsSchema,
WriteTodosResultSchema,
type WriteTodosArgs, type WriteTodosArgs,
type WriteTodosResult, type WriteTodosResult,
} from "./write-todos"; } from "./write-todos";

View file

@ -2,6 +2,7 @@
import { makeAssistantToolUI } from "@assistant-ui/react"; import { makeAssistantToolUI } from "@assistant-ui/react";
import { AlertCircleIcon, ExternalLinkIcon, LinkIcon } from "lucide-react"; import { AlertCircleIcon, ExternalLinkIcon, LinkIcon } from "lucide-react";
import { z } from "zod";
import { import {
MediaCard, MediaCard,
MediaCardErrorBoundary, MediaCardErrorBoundary,
@ -10,25 +11,39 @@ import {
type SerializableMediaCard, type SerializableMediaCard,
} from "@/components/tool-ui/media-card"; } from "@/components/tool-ui/media-card";
/** // ============================================================================
* Type definitions for the link_preview tool // Zod Schemas
*/ // ============================================================================
interface LinkPreviewArgs {
url: string;
title?: string;
}
interface LinkPreviewResult { /**
id: string; * Schema for link_preview tool arguments
assetId: string; */
kind: "link"; const LinkPreviewArgsSchema = z.object({
href: string; url: z.string(),
title: string; title: z.string().nullish(),
description?: string; });
thumb?: string;
domain?: string; /**
error?: string; * Schema for link_preview tool result
} */
const LinkPreviewResultSchema = z.object({
id: z.string(),
assetId: z.string(),
kind: z.literal("link"),
href: z.string(),
title: z.string(),
description: z.string().nullish(),
thumb: z.string().nullish(),
domain: z.string().nullish(),
error: z.string().nullish(),
});
// ============================================================================
// Types
// ============================================================================
type LinkPreviewArgs = z.infer<typeof LinkPreviewArgsSchema>;
type LinkPreviewResult = z.infer<typeof LinkPreviewResultSchema>;
/** /**
* Error state component shown when link preview fails * Error state component shown when link preview fails
@ -150,20 +165,35 @@ export const LinkPreviewToolUI = makeAssistantToolUI<LinkPreviewArgs, LinkPrevie
}, },
}); });
/** // ============================================================================
* Multiple Link Previews Tool UI Component // Multi Link Preview Schemas
* // ============================================================================
* This component handles cases where multiple links need to be previewed.
* It renders a grid of link preview cards.
*/
interface MultiLinkPreviewArgs {
urls: string[];
}
interface MultiLinkPreviewResult { /**
previews: LinkPreviewResult[]; * Schema for multi_link_preview tool arguments
errors?: { url: string; error: string }[]; */
} const MultiLinkPreviewArgsSchema = z.object({
urls: z.array(z.string()),
});
/**
* Schema for error items in multi_link_preview result
*/
const MultiLinkPreviewErrorSchema = z.object({
url: z.string(),
error: z.string(),
});
/**
* Schema for multi_link_preview tool result
*/
const MultiLinkPreviewResultSchema = z.object({
previews: z.array(LinkPreviewResultSchema),
errors: z.array(MultiLinkPreviewErrorSchema).nullish(),
});
type MultiLinkPreviewArgs = z.infer<typeof MultiLinkPreviewArgsSchema>;
type MultiLinkPreviewResult = z.infer<typeof MultiLinkPreviewResultSchema>;
export const MultiLinkPreviewToolUI = makeAssistantToolUI< export const MultiLinkPreviewToolUI = makeAssistantToolUI<
MultiLinkPreviewArgs, MultiLinkPreviewArgs,
@ -217,4 +247,13 @@ export const MultiLinkPreviewToolUI = makeAssistantToolUI<
}, },
}); });
export type { LinkPreviewArgs, LinkPreviewResult, MultiLinkPreviewArgs, MultiLinkPreviewResult }; export {
LinkPreviewArgsSchema,
LinkPreviewResultSchema,
MultiLinkPreviewArgsSchema,
MultiLinkPreviewResultSchema,
type LinkPreviewArgs,
type LinkPreviewResult,
type MultiLinkPreviewArgs,
type MultiLinkPreviewResult,
};

View file

@ -50,4 +50,3 @@ export class PlanErrorBoundary extends Component<PlanErrorBoundaryProps, PlanErr
return this.props.children; return this.props.children;
} }
} }

View file

@ -1,12 +1,6 @@
"use client"; "use client";
import { import { CheckCircle2, Circle, CircleDashed, PartyPopper, XCircle } from "lucide-react";
CheckCircle2,
Circle,
CircleDashed,
PartyPopper,
XCircle,
} from "lucide-react";
import type { FC } from "react"; import type { FC } from "react";
import { useMemo, useState } from "react"; import { useMemo, useState } from "react";
import { import {
@ -46,11 +40,7 @@ const StatusIcon: FC<StatusIconProps> = ({ status, className, isStreaming = true
// When streaming is stopped, show as a static dashed circle // When streaming is stopped, show as a static dashed circle
return ( return (
<CircleDashed <CircleDashed
className={cn( className={cn(baseClass, "text-primary", isStreaming && "animate-spin")}
baseClass,
"text-primary",
isStreaming && "animate-spin"
)}
style={isStreaming ? { animationDuration: "3s" } : undefined} style={isStreaming ? { animationDuration: "3s" } : undefined}
/> />
); );
@ -83,12 +73,7 @@ const TodoItem: FC<TodoItemProps> = ({ todo, isStreaming = true }) => {
return <TextShimmerLoader text={todo.label} size="md" />; return <TextShimmerLoader text={todo.label} size="md" />;
} }
return ( return (
<span <span className={cn("text-sm", isStrikethrough && "line-through text-muted-foreground")}>
className={cn(
"text-sm",
isStrikethrough && "line-through text-muted-foreground"
)}
>
{todo.label} {todo.label}
</span> </span>
); );
@ -278,4 +263,3 @@ export const Plan: FC<PlanProps> = ({
</Card> </Card>
); );
}; };

View file

@ -1,5 +1,4 @@
import { z } from "zod"; import { z } from "zod";
import { ActionSchema } from "../shared/schema";
/** /**
* Todo item status * Todo item status
@ -56,7 +55,8 @@ export function parseSerializablePlan(data: unknown): SerializablePlan {
status: TodoStatusSchema.safeParse((t as any)?.status).success status: TodoStatusSchema.safeParse((t as any)?.status).success
? (t as any).status ? (t as any).status
: "pending", : "pending",
description: typeof (t as any)?.description === "string" ? (t as any).description : undefined, description:
typeof (t as any)?.description === "string" ? (t as any).description : undefined,
})) }))
: [{ id: "1", label: "No tasks", status: "pending" as const }], : [{ id: "1", label: "No tasks", status: "pending" as const }],
}; };
@ -64,4 +64,3 @@ export function parseSerializablePlan(data: unknown): SerializablePlan {
return result.data; return result.data;
} }

View file

@ -2,6 +2,7 @@
import { makeAssistantToolUI } from "@assistant-ui/react"; import { makeAssistantToolUI } from "@assistant-ui/react";
import { AlertCircleIcon, FileTextIcon } from "lucide-react"; import { AlertCircleIcon, FileTextIcon } from "lucide-react";
import { z } from "zod";
import { import {
Article, Article,
ArticleErrorBoundary, ArticleErrorBoundary,
@ -9,30 +10,44 @@ import {
parseSerializableArticle, parseSerializableArticle,
} from "@/components/tool-ui/article"; } from "@/components/tool-ui/article";
/** // ============================================================================
* Type definitions for the scrape_webpage tool // Zod Schemas
*/ // ============================================================================
interface ScrapeWebpageArgs {
url: string;
max_length?: number;
}
interface ScrapeWebpageResult { /**
id: string; * Schema for scrape_webpage tool arguments
assetId: string; */
kind: "article"; const ScrapeWebpageArgsSchema = z.object({
href: string; url: z.string(),
title: string; max_length: z.number().nullish(),
description?: string; });
content?: string;
domain?: string; /**
author?: string; * Schema for scrape_webpage tool result
date?: string; */
word_count?: number; const ScrapeWebpageResultSchema = z.object({
was_truncated?: boolean; id: z.string(),
crawler_type?: string; assetId: z.string(),
error?: string; kind: z.literal("article"),
} href: z.string(),
title: z.string(),
description: z.string().nullish(),
content: z.string().nullish(),
domain: z.string().nullish(),
author: z.string().nullish(),
date: z.string().nullish(),
word_count: z.number().nullish(),
was_truncated: z.boolean().nullish(),
crawler_type: z.string().nullish(),
error: z.string().nullish(),
});
// ============================================================================
// Types
// ============================================================================
type ScrapeWebpageArgs = z.infer<typeof ScrapeWebpageArgsSchema>;
type ScrapeWebpageResult = z.infer<typeof ScrapeWebpageResultSchema>;
/** /**
* Error state component shown when webpage scraping fails * Error state component shown when webpage scraping fails
@ -154,4 +169,9 @@ export const ScrapeWebpageToolUI = makeAssistantToolUI<ScrapeWebpageArgs, Scrape
}, },
}); });
export type { ScrapeWebpageArgs, ScrapeWebpageResult }; export {
ScrapeWebpageArgsSchema,
ScrapeWebpageResultSchema,
type ScrapeWebpageArgs,
type ScrapeWebpageResult,
};

View file

@ -16,10 +16,10 @@ export const ActionButtons: FC<ActionButtonsProps> = ({ actions, onAction, disab
// Normalize actions to array format // Normalize actions to array format
const actionArray: Action[] = Array.isArray(actions) const actionArray: Action[] = Array.isArray(actions)
? actions ? actions
: [ : ([
actions.confirm && { ...actions.confirm, id: "confirm" }, actions.confirm && { ...actions.confirm, id: "confirm" },
actions.cancel && { ...actions.cancel, id: "cancel" }, actions.cancel && { ...actions.cancel, id: "cancel" },
].filter(Boolean) as Action[]; ].filter(Boolean) as Action[]);
if (actionArray.length === 0) return null; if (actionArray.length === 0) return null;
@ -39,4 +39,3 @@ export const ActionButtons: FC<ActionButtonsProps> = ({ actions, onAction, disab
</div> </div>
); );
}; };

View file

@ -1,3 +1,2 @@
export * from "./schema"; export * from "./schema";
export * from "./action-buttons"; export * from "./action-buttons";

View file

@ -21,4 +21,3 @@ export const ActionsConfigSchema = z.object({
}); });
export type ActionsConfig = z.infer<typeof ActionsConfigSchema>; export type ActionsConfig = z.infer<typeof ActionsConfigSchema>;

View file

@ -4,41 +4,63 @@ import { makeAssistantToolUI, useAssistantState } from "@assistant-ui/react";
import { useAtomValue, useSetAtom } from "jotai"; import { useAtomValue, useSetAtom } from "jotai";
import { Loader2 } from "lucide-react"; import { Loader2 } from "lucide-react";
import { useEffect, useMemo } from "react"; import { useEffect, useMemo } from "react";
import { z } from "zod";
import { import {
getCanonicalPlanTitle, getCanonicalPlanTitle,
planStatesAtom, planStatesAtom,
registerPlanOwner, registerPlanOwner,
updatePlanStateAtom, updatePlanStateAtom,
} from "@/atoms/chat/plan-state.atom"; } from "@/atoms/chat/plan-state.atom";
import { Plan, PlanErrorBoundary, parseSerializablePlan } from "./plan"; import { Plan, PlanErrorBoundary, parseSerializablePlan, TodoStatusSchema } from "./plan";
// ============================================================================
// Zod Schemas
// ============================================================================
/** /**
* Tool arguments for write_todos * Schema for a single todo item in the args
*/ */
interface WriteTodosArgs { const WriteTodosArgsTodoSchema = z.object({
title?: string; id: z.string(),
description?: string; content: z.string(),
todos?: Array<{ status: TodoStatusSchema,
id: string; });
content: string;
status: "pending" | "in_progress" | "completed" | "cancelled";
}>;
}
/** /**
* Tool result for write_todos * Schema for write_todos tool arguments
*/ */
interface WriteTodosResult { const WriteTodosArgsSchema = z.object({
id: string; title: z.string().nullish(),
title: string; description: z.string().nullish(),
description?: string; todos: z.array(WriteTodosArgsTodoSchema).nullish(),
todos: Array<{ });
id: string;
label: string; /**
status: "pending" | "in_progress" | "completed" | "cancelled"; * Schema for a single todo item in the result
description?: string; */
}>; const WriteTodosResultTodoSchema = z.object({
} id: z.string(),
label: z.string(),
status: TodoStatusSchema,
description: z.string().nullish(),
});
/**
* Schema for write_todos tool result
*/
const WriteTodosResultSchema = z.object({
id: z.string(),
title: z.string(),
description: z.string().nullish(),
todos: z.array(WriteTodosResultTodoSchema),
});
// ============================================================================
// Types
// ============================================================================
type WriteTodosArgs = z.infer<typeof WriteTodosArgsSchema>;
type WriteTodosResult = z.infer<typeof WriteTodosResultSchema>;
/** /**
* Loading state component * Loading state component
@ -198,5 +220,4 @@ export const WriteTodosToolUI = makeAssistantToolUI<WriteTodosArgs, WriteTodosRe
}, },
}); });
export type { WriteTodosArgs, WriteTodosResult }; export { WriteTodosArgsSchema, WriteTodosResultSchema, type WriteTodosArgs, type WriteTodosResult };