refactor: improve write_todos tool and UI components

- Refactored the write_todos tool to enhance argument and result schemas using Zod for better validation and type safety.
- Updated the WriteTodosToolUI to streamline the rendering logic and improve loading states, ensuring a smoother user experience.
- Enhanced the Plan and TodoItem components to better handle streaming states and display progress, providing clearer feedback during task management.
- Cleaned up code formatting and structure for improved readability and maintainability.
This commit is contained in:
Anish Sarkar 2025-12-26 17:49:56 +05:30
parent 2c86287264
commit ebc04f590e
18 changed files with 833 additions and 751 deletions

View file

@ -5,7 +5,7 @@ This module provides a tool for creating and displaying a planning/todo list
in the chat UI. It helps the agent break down complex tasks into steps. in the chat UI. It helps the agent break down complex tasks into steps.
""" """
from typing import Any, Literal from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
@ -91,4 +91,3 @@ def create_write_todos_tool():
} }
return write_todos return write_todos

View file

@ -76,12 +76,17 @@ class WebCrawlerConnector:
return result, None return result, None
except Exception as firecrawl_error: except Exception as firecrawl_error:
# Firecrawl failed, fallback to Chromium # Firecrawl failed, fallback to Chromium
logger.warning(f"[webcrawler] Firecrawl failed, falling back to Chromium+Trafilatura for: {url}") logger.warning(
f"[webcrawler] Firecrawl failed, falling back to Chromium+Trafilatura for: {url}"
)
try: try:
result = await self._crawl_with_chromium(url) result = await self._crawl_with_chromium(url)
return result, None return result, None
except Exception as chromium_error: except Exception as chromium_error:
return None, f"Both Firecrawl and Chromium failed. Firecrawl error: {firecrawl_error!s}, Chromium error: {chromium_error!s}" return (
None,
f"Both Firecrawl and Chromium failed. Firecrawl error: {firecrawl_error!s}, Chromium error: {chromium_error!s}",
)
else: else:
# No Firecrawl API key, use Chromium directly # No Firecrawl API key, use Chromium directly
logger.info(f"[webcrawler] Using Chromium+Trafilatura for: {url}") logger.info(f"[webcrawler] Using Chromium+Trafilatura for: {url}")

View file

@ -149,7 +149,11 @@ function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
if (typeof part !== "object" || part === null || !("type" in part)) return true; if (typeof part !== "object" || part === null || !("type" in part)) return true;
const partType = (part as { type: string }).type; const partType = (part as { type: string }).type;
// Filter out thinking-steps, mentioned-documents, and attachments // Filter out thinking-steps, mentioned-documents, and attachments
return partType !== "thinking-steps" && partType !== "mentioned-documents" && partType !== "attachments"; return (
partType !== "thinking-steps" &&
partType !== "mentioned-documents" &&
partType !== "attachments"
);
}); });
content = content =
filteredContent.length > 0 filteredContent.length > 0
@ -319,7 +323,13 @@ export default function NewChatPage() {
} finally { } finally {
setIsInitializing(false); setIsInitializing(false);
} }
}, [urlChatId, setMessageDocumentsMap, setMentionedDocumentIds, setMentionedDocuments, hydratePlanState]); }, [
urlChatId,
setMessageDocumentsMap,
setMentionedDocumentIds,
setMentionedDocuments,
hydratePlanState,
]);
// Initialize on mount // Initialize on mount
useEffect(() => { useEffect(() => {
@ -786,9 +796,7 @@ export default function NewChatPage() {
appendMessage(currentThreadId, { appendMessage(currentThreadId, {
role: "assistant", role: "assistant",
content: partialContent, content: partialContent,
}).catch((err) => }).catch((err) => console.error("Failed to persist partial assistant message:", err));
console.error("Failed to persist partial assistant message:", err)
);
} }
return; return;
} }

View file

@ -10,20 +10,20 @@
import { atom } from "jotai"; import { atom } from "jotai";
export interface PlanTodo { export interface PlanTodo {
id: string; id: string;
label: string; label: string;
status: "pending" | "in_progress" | "completed" | "cancelled"; status: "pending" | "in_progress" | "completed" | "cancelled";
description?: string; description?: string;
} }
export interface PlanState { export interface PlanState {
id: string; id: string;
title: string; title: string;
description?: string; description?: string;
todos: PlanTodo[]; todos: PlanTodo[];
lastUpdated: number; lastUpdated: number;
/** The toolCallId of the first component that rendered this plan */ /** The toolCallId of the first component that rendered this plan */
ownerToolCallId: string; ownerToolCallId: string;
} }
/** /**
@ -38,14 +38,14 @@ let firstPlanOwner: { toolCallId: string; title: string } | null = null;
* All subsequent calls update the state but don't render their own card. * All subsequent calls update the state but don't render their own card.
*/ */
export function registerPlanOwner(title: string, toolCallId: string): boolean { export function registerPlanOwner(title: string, toolCallId: string): boolean {
if (!firstPlanOwner) { if (!firstPlanOwner) {
// First plan in this conversation - claim ownership // First plan in this conversation - claim ownership
firstPlanOwner = { toolCallId, title }; firstPlanOwner = { toolCallId, title };
return true; return true;
} }
// Check if we're the owner // Check if we're the owner
return firstPlanOwner.toolCallId === toolCallId; return firstPlanOwner.toolCallId === toolCallId;
} }
/** /**
@ -53,35 +53,35 @@ export function registerPlanOwner(title: string, toolCallId: string): boolean {
* Returns the first plan's title if one exists, otherwise the provided title * Returns the first plan's title if one exists, otherwise the provided title
*/ */
export function getCanonicalPlanTitle(title: string): string { export function getCanonicalPlanTitle(title: string): string {
return firstPlanOwner?.title || title; return firstPlanOwner?.title || title;
} }
/** /**
* Check if a plan already exists in this conversation * Check if a plan already exists in this conversation
*/ */
export function hasPlan(): boolean { export function hasPlan(): boolean {
return firstPlanOwner !== null; return firstPlanOwner !== null;
} }
/** /**
* Get the first plan's info * Get the first plan's info
*/ */
export function getFirstPlanInfo(): { toolCallId: string; title: string } | null { export function getFirstPlanInfo(): { toolCallId: string; title: string } | null {
return firstPlanOwner; return firstPlanOwner;
} }
/** /**
* Check if a toolCallId is the owner of the plan SYNCHRONOUSLY * Check if a toolCallId is the owner of the plan SYNCHRONOUSLY
*/ */
export function isPlanOwner(toolCallId: string): boolean { export function isPlanOwner(toolCallId: string): boolean {
return !firstPlanOwner || firstPlanOwner.toolCallId === toolCallId; return !firstPlanOwner || firstPlanOwner.toolCallId === toolCallId;
} }
/** /**
* Clear ownership registry (call when starting a new chat) * Clear ownership registry (call when starting a new chat)
*/ */
export function clearPlanOwnerRegistry(): void { export function clearPlanOwnerRegistry(): void {
firstPlanOwner = null; firstPlanOwner = null;
} }
/** /**
@ -94,56 +94,53 @@ export const planStatesAtom = atom<Map<string, PlanState>>(new Map());
* Input type for updating plan state * Input type for updating plan state
*/ */
export interface UpdatePlanInput { export interface UpdatePlanInput {
id: string; id: string;
title: string; title: string;
description?: string; description?: string;
todos: PlanTodo[]; todos: PlanTodo[];
toolCallId: string; toolCallId: string;
} }
/** /**
* Helper atom to update a plan state * Helper atom to update a plan state
*/ */
export const updatePlanStateAtom = atom( export const updatePlanStateAtom = atom(null, (get, set, plan: UpdatePlanInput) => {
null, const states = new Map(get(planStatesAtom));
(get, set, plan: UpdatePlanInput) => {
const states = new Map(get(planStatesAtom));
// Register ownership synchronously if not already done // Register ownership synchronously if not already done
registerPlanOwner(plan.title, plan.toolCallId); registerPlanOwner(plan.title, plan.toolCallId);
// Get the actual owner from the first plan // Get the actual owner from the first plan
const ownerToolCallId = firstPlanOwner?.toolCallId || plan.toolCallId; const ownerToolCallId = firstPlanOwner?.toolCallId || plan.toolCallId;
// Always use the canonical (first) title for the plan key // Always use the canonical (first) title for the plan key
const canonicalTitle = getCanonicalPlanTitle(plan.title); const canonicalTitle = getCanonicalPlanTitle(plan.title);
states.set(canonicalTitle, { states.set(canonicalTitle, {
id: plan.id, id: plan.id,
title: canonicalTitle, title: canonicalTitle,
description: plan.description, description: plan.description,
todos: plan.todos, todos: plan.todos,
lastUpdated: Date.now(), lastUpdated: Date.now(),
ownerToolCallId, ownerToolCallId,
}); });
set(planStatesAtom, states); set(planStatesAtom, states);
} });
);
/** /**
* Helper atom to get the latest plan state by title * Helper atom to get the latest plan state by title
*/ */
export const getPlanStateAtom = atom((get) => { export const getPlanStateAtom = atom((get) => {
const states = get(planStatesAtom); const states = get(planStatesAtom);
return (title: string) => states.get(title); return (title: string) => states.get(title);
}); });
/** /**
* Helper atom to clear all plan states (useful when starting a new chat) * Helper atom to clear all plan states (useful when starting a new chat)
*/ */
export const clearPlanStatesAtom = atom(null, (get, set) => { export const clearPlanStatesAtom = atom(null, (get, set) => {
clearPlanOwnerRegistry(); clearPlanOwnerRegistry();
set(planStatesAtom, new Map()); set(planStatesAtom, new Map());
}); });
/** /**
@ -151,84 +148,80 @@ export const clearPlanStatesAtom = atom(null, (get, set) => {
* Call this when loading messages from the database to restore plan state * Call this when loading messages from the database to restore plan state
*/ */
export interface HydratePlanInput { export interface HydratePlanInput {
toolCallId: string; toolCallId: string;
result: { result: {
id?: string; id?: string;
title?: string; title?: string;
description?: string; description?: string;
todos?: Array<{ todos?: Array<{
id: string; id: string;
label: string; label: string;
status: "pending" | "in_progress" | "completed" | "cancelled"; status: "pending" | "in_progress" | "completed" | "cancelled";
description?: string; description?: string;
}>; }>;
}; };
} }
export const hydratePlanStateAtom = atom( export const hydratePlanStateAtom = atom(null, (get, set, plan: HydratePlanInput) => {
null, if (!plan.result?.todos || plan.result.todos.length === 0) return;
(get, set, plan: HydratePlanInput) => {
if (!plan.result?.todos || plan.result.todos.length === 0) return;
const states = new Map(get(planStatesAtom)); const states = new Map(get(planStatesAtom));
const title = plan.result.title || "Planning Approach"; const title = plan.result.title || "Planning Approach";
// Register this as the owner if no plan exists yet // Register this as the owner if no plan exists yet
registerPlanOwner(title, plan.toolCallId); registerPlanOwner(title, plan.toolCallId);
// Get the canonical title // Get the canonical title
const canonicalTitle = getCanonicalPlanTitle(title); const canonicalTitle = getCanonicalPlanTitle(title);
const ownerToolCallId = firstPlanOwner?.toolCallId || plan.toolCallId; const ownerToolCallId = firstPlanOwner?.toolCallId || plan.toolCallId;
// Only set if this is newer or doesn't exist // Only set if this is newer or doesn't exist
const existing = states.get(canonicalTitle); const existing = states.get(canonicalTitle);
if (!existing) { if (!existing) {
states.set(canonicalTitle, { states.set(canonicalTitle, {
id: plan.result.id || `plan-${Date.now()}`, id: plan.result.id || `plan-${Date.now()}`,
title: canonicalTitle, title: canonicalTitle,
description: plan.result.description, description: plan.result.description,
todos: plan.result.todos, todos: plan.result.todos,
lastUpdated: Date.now(), lastUpdated: Date.now(),
ownerToolCallId, ownerToolCallId,
}); });
set(planStatesAtom, states); set(planStatesAtom, states);
} }
} });
);
/** /**
* Extract write_todos tool call data from message content * Extract write_todos tool call data from message content
* Returns an array of { toolCallId, result } for each write_todos call found * Returns an array of { toolCallId, result } for each write_todos call found
*/ */
export function extractWriteTodosFromContent(content: unknown): HydratePlanInput[] { export function extractWriteTodosFromContent(content: unknown): HydratePlanInput[] {
if (!Array.isArray(content)) return []; if (!Array.isArray(content)) return [];
const results: HydratePlanInput[] = []; const results: HydratePlanInput[] = [];
for (const part of content) { for (const part of content) {
if ( if (
typeof part === "object" && typeof part === "object" &&
part !== null && part !== null &&
"type" in part && "type" in part &&
(part as { type: string }).type === "tool-call" && (part as { type: string }).type === "tool-call" &&
"toolName" in part && "toolName" in part &&
(part as { toolName: string }).toolName === "write_todos" && (part as { toolName: string }).toolName === "write_todos" &&
"toolCallId" in part && "toolCallId" in part &&
"result" in part "result" in part
) { ) {
const toolCall = part as { const toolCall = part as {
toolCallId: string; toolCallId: string;
result: HydratePlanInput["result"]; result: HydratePlanInput["result"];
}; };
if (toolCall.result) { if (toolCall.result) {
results.push({ results.push({
toolCallId: toolCall.toolCallId, toolCallId: toolCall.toolCallId,
result: toolCall.result, result: toolCall.result,
}); });
} }
} }
} }
return results; return results;
} }

View file

@ -910,7 +910,7 @@ const AssistantMessageInner: FC = () => {
<MessageError /> <MessageError />
</div> </div>
<div className="aui-assistant-message-footer mt-1 ml-2 flex"> <div className="aui-assistant-message-footer mt-1 mb-5 ml-2 flex">
<BranchPicker /> <BranchPicker />
<AssistantActionBar /> <AssistantActionBar />
</div> </div>

View file

@ -3,16 +3,16 @@
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
export interface LoaderProps { export interface LoaderProps {
variant?: "text-shimmer"; variant?: "text-shimmer";
size?: "sm" | "md" | "lg"; size?: "sm" | "md" | "lg";
text?: string; text?: string;
className?: string; className?: string;
} }
const textSizes = { const textSizes = {
sm: "text-xs", sm: "text-xs",
md: "text-sm", md: "text-sm",
lg: "text-base", lg: "text-base",
} as const; } as const;
/** /**
@ -20,55 +20,47 @@ const textSizes = {
* Used for in-progress states in write_todos and chain-of-thought * Used for in-progress states in write_todos and chain-of-thought
*/ */
export function TextShimmerLoader({ export function TextShimmerLoader({
text = "Thinking", text = "Thinking",
className, className,
size = "md", size = "md",
}: { }: {
text?: string; text?: string;
className?: string; className?: string;
size?: "sm" | "md" | "lg"; size?: "sm" | "md" | "lg";
}) { }) {
return ( return (
<> <>
<style> <style>
{` {`
@keyframes shimmer { @keyframes shimmer {
0% { background-position: 200% 50%; } 0% { background-position: 200% 50%; }
100% { background-position: -200% 50%; } 100% { background-position: -200% 50%; }
} }
`} `}
</style> </style>
<span <span
className={cn( className={cn(
"bg-[linear-gradient(to_right,var(--muted-foreground)_40%,var(--foreground)_60%,var(--muted-foreground)_80%)]", "bg-[linear-gradient(to_right,var(--muted-foreground)_40%,var(--foreground)_60%,var(--muted-foreground)_80%)]",
"bg-[length:200%_auto] bg-clip-text font-medium text-transparent", "bg-[length:200%_auto] bg-clip-text font-medium text-transparent",
"animate-[shimmer_4s_infinite_linear]", "animate-[shimmer_4s_infinite_linear]",
textSizes[size], textSizes[size],
className className
)} )}
> >
{text} {text}
</span> </span>
</> </>
); );
} }
/** /**
* Loader component - currently only supports text-shimmer variant * Loader component - currently only supports text-shimmer variant
* Can be extended with more variants if needed in the future * Can be extended with more variants if needed in the future
*/ */
export function Loader({ export function Loader({ variant = "text-shimmer", size = "md", text, className }: LoaderProps) {
variant = "text-shimmer", switch (variant) {
size = "md", case "text-shimmer":
text, default:
className, return <TextShimmerLoader text={text} size={size} className={className} />;
}: LoaderProps) { }
switch (variant) {
case "text-shimmer":
default:
return (
<TextShimmerLoader text={text} size={size} className={className} />
);
}
} }

View file

@ -3,7 +3,7 @@
import { useQuery, useQueryClient } from "@tanstack/react-query"; import { useQuery, useQueryClient } from "@tanstack/react-query";
import { useAtomValue, useSetAtom } from "jotai"; import { useAtomValue, useSetAtom } from "jotai";
import { Trash2 } from "lucide-react"; import { Trash2 } from "lucide-react";
import { useRouter } from "next/navigation"; import { useParams, useRouter } from "next/navigation";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { useCallback, useMemo, useState } from "react"; import { useCallback, useMemo, useState } from "react";
import { hasUnsavedEditorChangesAtom, pendingEditorNavigationAtom } from "@/atoms/editor/ui.atoms"; import { hasUnsavedEditorChangesAtom, pendingEditorNavigationAtom } from "@/atoms/editor/ui.atoms";
@ -50,7 +50,13 @@ export function AppSidebarProvider({
const t = useTranslations("dashboard"); const t = useTranslations("dashboard");
const tCommon = useTranslations("common"); const tCommon = useTranslations("common");
const router = useRouter(); const router = useRouter();
const params = useParams();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
// Get current chat ID from URL params
const currentChatId = params?.chat_id
? Number(Array.isArray(params.chat_id) ? params.chat_id[0] : params.chat_id)
: null;
const [isDeletingThread, setIsDeletingThread] = useState(false); const [isDeletingThread, setIsDeletingThread] = useState(false);
// Editor state for handling unsaved changes // Editor state for handling unsaved changes
@ -61,7 +67,6 @@ export function AppSidebarProvider({
const { const {
data: threadsData, data: threadsData,
error: threadError, error: threadError,
isLoading: isLoadingThreads,
refetch: refetchThreads, refetch: refetchThreads,
} = useQuery({ } = useQuery({
queryKey: ["threads", searchSpaceId], queryKey: ["threads", searchSpaceId],
@ -73,7 +78,6 @@ export function AppSidebarProvider({
data: searchSpace, data: searchSpace,
isLoading: isLoadingSearchSpace, isLoading: isLoadingSearchSpace,
error: searchSpaceError, error: searchSpaceError,
refetch: fetchSearchSpace,
} = useQuery({ } = useQuery({
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId), queryKey: cacheKeys.searchSpaces.detail(searchSpaceId),
queryFn: () => searchSpacesApiService.getSearchSpace({ id: Number(searchSpaceId) }), queryFn: () => searchSpacesApiService.getSearchSpace({ id: Number(searchSpaceId) }),
@ -83,12 +87,7 @@ export function AppSidebarProvider({
const { data: user } = useAtomValue(currentUserAtom); const { data: user } = useAtomValue(currentUserAtom);
// Fetch notes // Fetch notes
const { const { data: notesData, refetch: refetchNotes } = useQuery({
data: notesData,
error: notesError,
isLoading: isLoadingNotes,
refetch: refetchNotes,
} = useQuery({
queryKey: ["notes", searchSpaceId], queryKey: ["notes", searchSpaceId],
queryFn: () => queryFn: () =>
notesApiService.getNotes({ notesApiService.getNotes({
@ -108,11 +107,6 @@ export function AppSidebarProvider({
} | null>(null); } | null>(null);
const [isDeletingNote, setIsDeletingNote] = useState(false); const [isDeletingNote, setIsDeletingNote] = useState(false);
// Retry function
const retryFetch = useCallback(() => {
fetchSearchSpace();
}, [fetchSearchSpace]);
// Transform threads to the format expected by AppSidebar // Transform threads to the format expected by AppSidebar
const recentChats = useMemo(() => { const recentChats = useMemo(() => {
if (!threadsData?.threads) return []; if (!threadsData?.threads) return [];
@ -149,8 +143,10 @@ export function AppSidebarProvider({
await deleteThread(threadToDelete.id); await deleteThread(threadToDelete.id);
// Invalidate threads query to refresh the list // Invalidate threads query to refresh the list
queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] }); queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
// Navigate to new-chat after successful deletion // Only navigate to new-chat if the deleted chat is currently open
router.push(`/dashboard/${searchSpaceId}/new-chat`); if (currentChatId === threadToDelete.id) {
router.push(`/dashboard/${searchSpaceId}/new-chat`);
}
} catch (error) { } catch (error) {
console.error("Error deleting thread:", error); console.error("Error deleting thread:", error);
} finally { } finally {
@ -158,7 +154,7 @@ export function AppSidebarProvider({
setShowDeleteDialog(false); setShowDeleteDialog(false);
setThreadToDelete(null); setThreadToDelete(null);
} }
}, [threadToDelete, queryClient, searchSpaceId, router]); }, [threadToDelete, queryClient, searchSpaceId, router, currentChatId]);
// Handle delete note with confirmation // Handle delete note with confirmation
const handleDeleteNote = useCallback(async () => { const handleDeleteNote = useCallback(async () => {

View file

@ -2,6 +2,7 @@
import { makeAssistantToolUI } from "@assistant-ui/react"; import { makeAssistantToolUI } from "@assistant-ui/react";
import { AlertCircleIcon, ImageIcon } from "lucide-react"; import { AlertCircleIcon, ImageIcon } from "lucide-react";
import { z } from "zod";
import { import {
Image, Image,
ImageErrorBoundary, ImageErrorBoundary,
@ -9,27 +10,41 @@ import {
parseSerializableImage, parseSerializableImage,
} from "@/components/tool-ui/image"; } from "@/components/tool-ui/image";
/** // ============================================================================
* Type definitions for the display_image tool // Zod Schemas
*/ // ============================================================================
interface DisplayImageArgs {
src: string;
alt?: string;
title?: string;
description?: string;
}
interface DisplayImageResult { /**
id: string; * Schema for display_image tool arguments
assetId: string; */
src: string; const DisplayImageArgsSchema = z.object({
alt?: string; // Made optional - parseSerializableImage provides fallback src: z.string(),
title?: string; alt: z.string().nullish(),
description?: string; title: z.string().nullish(),
domain?: string; description: z.string().nullish(),
ratio?: string; });
error?: string;
} /**
* Schema for display_image tool result
*/
const DisplayImageResultSchema = z.object({
id: z.string(),
assetId: z.string(),
src: z.string(),
alt: z.string().nullish(),
title: z.string().nullish(),
description: z.string().nullish(),
domain: z.string().nullish(),
ratio: z.string().nullish(),
error: z.string().nullish(),
});
// ============================================================================
// Types
// ============================================================================
type DisplayImageArgs = z.infer<typeof DisplayImageArgsSchema>;
type DisplayImageResult = z.infer<typeof DisplayImageResultSchema>;
/** /**
* Error state component shown when image display fails * Error state component shown when image display fails
@ -142,4 +157,9 @@ export const DisplayImageToolUI = makeAssistantToolUI<DisplayImageArgs, DisplayI
}, },
}); });
export type { DisplayImageArgs, DisplayImageResult }; export {
DisplayImageArgsSchema,
DisplayImageResultSchema,
type DisplayImageArgs,
type DisplayImageResult,
};

View file

@ -24,6 +24,8 @@ export {
type ThinkingStep, type ThinkingStep,
} from "./deepagent-thinking"; } from "./deepagent-thinking";
export { export {
DisplayImageArgsSchema,
DisplayImageResultSchema,
type DisplayImageArgs, type DisplayImageArgs,
type DisplayImageResult, type DisplayImageResult,
DisplayImageToolUI, DisplayImageToolUI,
@ -39,9 +41,13 @@ export {
type SerializableImage, type SerializableImage,
} from "./image"; } from "./image";
export { export {
LinkPreviewArgsSchema,
LinkPreviewResultSchema,
type LinkPreviewArgs, type LinkPreviewArgs,
type LinkPreviewResult, type LinkPreviewResult,
LinkPreviewToolUI, LinkPreviewToolUI,
MultiLinkPreviewArgsSchema,
MultiLinkPreviewResultSchema,
type MultiLinkPreviewArgs, type MultiLinkPreviewArgs,
type MultiLinkPreviewResult, type MultiLinkPreviewResult,
MultiLinkPreviewToolUI, MultiLinkPreviewToolUI,
@ -56,6 +62,8 @@ export {
type SerializableMediaCard, type SerializableMediaCard,
} from "./media-card"; } from "./media-card";
export { export {
ScrapeWebpageArgsSchema,
ScrapeWebpageResultSchema,
type ScrapeWebpageArgs, type ScrapeWebpageArgs,
type ScrapeWebpageResult, type ScrapeWebpageResult,
ScrapeWebpageToolUI, ScrapeWebpageToolUI,
@ -71,6 +79,8 @@ export {
} from "./plan"; } from "./plan";
export { export {
WriteTodosToolUI, WriteTodosToolUI,
WriteTodosArgsSchema,
WriteTodosResultSchema,
type WriteTodosArgs, type WriteTodosArgs,
type WriteTodosResult, type WriteTodosResult,
} from "./write-todos"; } from "./write-todos";

View file

@ -2,6 +2,7 @@
import { makeAssistantToolUI } from "@assistant-ui/react"; import { makeAssistantToolUI } from "@assistant-ui/react";
import { AlertCircleIcon, ExternalLinkIcon, LinkIcon } from "lucide-react"; import { AlertCircleIcon, ExternalLinkIcon, LinkIcon } from "lucide-react";
import { z } from "zod";
import { import {
MediaCard, MediaCard,
MediaCardErrorBoundary, MediaCardErrorBoundary,
@ -10,25 +11,39 @@ import {
type SerializableMediaCard, type SerializableMediaCard,
} from "@/components/tool-ui/media-card"; } from "@/components/tool-ui/media-card";
/** // ============================================================================
* Type definitions for the link_preview tool // Zod Schemas
*/ // ============================================================================
interface LinkPreviewArgs {
url: string;
title?: string;
}
interface LinkPreviewResult { /**
id: string; * Schema for link_preview tool arguments
assetId: string; */
kind: "link"; const LinkPreviewArgsSchema = z.object({
href: string; url: z.string(),
title: string; title: z.string().nullish(),
description?: string; });
thumb?: string;
domain?: string; /**
error?: string; * Schema for link_preview tool result
} */
const LinkPreviewResultSchema = z.object({
id: z.string(),
assetId: z.string(),
kind: z.literal("link"),
href: z.string(),
title: z.string(),
description: z.string().nullish(),
thumb: z.string().nullish(),
domain: z.string().nullish(),
error: z.string().nullish(),
});
// ============================================================================
// Types
// ============================================================================
type LinkPreviewArgs = z.infer<typeof LinkPreviewArgsSchema>;
type LinkPreviewResult = z.infer<typeof LinkPreviewResultSchema>;
/** /**
* Error state component shown when link preview fails * Error state component shown when link preview fails
@ -150,20 +165,35 @@ export const LinkPreviewToolUI = makeAssistantToolUI<LinkPreviewArgs, LinkPrevie
}, },
}); });
/** // ============================================================================
* Multiple Link Previews Tool UI Component // Multi Link Preview Schemas
* // ============================================================================
* This component handles cases where multiple links need to be previewed.
* It renders a grid of link preview cards.
*/
interface MultiLinkPreviewArgs {
urls: string[];
}
interface MultiLinkPreviewResult { /**
previews: LinkPreviewResult[]; * Schema for multi_link_preview tool arguments
errors?: { url: string; error: string }[]; */
} const MultiLinkPreviewArgsSchema = z.object({
urls: z.array(z.string()),
});
/**
* Schema for error items in multi_link_preview result
*/
const MultiLinkPreviewErrorSchema = z.object({
url: z.string(),
error: z.string(),
});
/**
* Schema for multi_link_preview tool result
*/
const MultiLinkPreviewResultSchema = z.object({
previews: z.array(LinkPreviewResultSchema),
errors: z.array(MultiLinkPreviewErrorSchema).nullish(),
});
type MultiLinkPreviewArgs = z.infer<typeof MultiLinkPreviewArgsSchema>;
type MultiLinkPreviewResult = z.infer<typeof MultiLinkPreviewResultSchema>;
export const MultiLinkPreviewToolUI = makeAssistantToolUI< export const MultiLinkPreviewToolUI = makeAssistantToolUI<
MultiLinkPreviewArgs, MultiLinkPreviewArgs,
@ -217,4 +247,13 @@ export const MultiLinkPreviewToolUI = makeAssistantToolUI<
}, },
}); });
export type { LinkPreviewArgs, LinkPreviewResult, MultiLinkPreviewArgs, MultiLinkPreviewResult }; export {
LinkPreviewArgsSchema,
LinkPreviewResultSchema,
MultiLinkPreviewArgsSchema,
MultiLinkPreviewResultSchema,
type LinkPreviewArgs,
type LinkPreviewResult,
type MultiLinkPreviewArgs,
type MultiLinkPreviewResult,
};

View file

@ -11,43 +11,42 @@ export * from "./schema";
// ============================================================================ // ============================================================================
interface PlanErrorBoundaryProps { interface PlanErrorBoundaryProps {
children: ReactNode; children: ReactNode;
fallback?: ReactNode; fallback?: ReactNode;
} }
interface PlanErrorBoundaryState { interface PlanErrorBoundaryState {
hasError: boolean; hasError: boolean;
error?: Error; error?: Error;
} }
export class PlanErrorBoundary extends Component<PlanErrorBoundaryProps, PlanErrorBoundaryState> { export class PlanErrorBoundary extends Component<PlanErrorBoundaryProps, PlanErrorBoundaryState> {
constructor(props: PlanErrorBoundaryProps) { constructor(props: PlanErrorBoundaryProps) {
super(props); super(props);
this.state = { hasError: false }; this.state = { hasError: false };
} }
static getDerivedStateFromError(error: Error): PlanErrorBoundaryState { static getDerivedStateFromError(error: Error): PlanErrorBoundaryState {
return { hasError: true, error }; return { hasError: true, error };
} }
render() { render() {
if (this.state.hasError) { if (this.state.hasError) {
if (this.props.fallback) { if (this.props.fallback) {
return this.props.fallback; return this.props.fallback;
} }
return ( return (
<Card className="w-full max-w-xl border-destructive/50"> <Card className="w-full max-w-xl border-destructive/50">
<CardContent className="pt-6"> <CardContent className="pt-6">
<div className="flex items-center gap-2 text-destructive"> <div className="flex items-center gap-2 text-destructive">
<span className="text-sm">Failed to render plan</span> <span className="text-sm">Failed to render plan</span>
</div> </div>
</CardContent> </CardContent>
</Card> </Card>
); );
} }
return this.props.children; return this.props.children;
} }
} }

View file

@ -1,19 +1,13 @@
"use client"; "use client";
import { import { CheckCircle2, Circle, CircleDashed, PartyPopper, XCircle } from "lucide-react";
CheckCircle2,
Circle,
CircleDashed,
PartyPopper,
XCircle,
} from "lucide-react";
import type { FC } from "react"; import type { FC } from "react";
import { useMemo, useState } from "react"; import { useMemo, useState } from "react";
import { import {
Accordion, Accordion,
AccordionContent, AccordionContent,
AccordionItem, AccordionItem,
AccordionTrigger, AccordionTrigger,
} from "@/components/ui/accordion"; } from "@/components/ui/accordion";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
@ -29,37 +23,33 @@ import type { PlanTodo, TodoStatus } from "./schema";
// ============================================================================ // ============================================================================
interface StatusIconProps { interface StatusIconProps {
status: TodoStatus; status: TodoStatus;
className?: string; className?: string;
/** When false, in_progress items show as static (no spinner) */ /** When false, in_progress items show as static (no spinner) */
isStreaming?: boolean; isStreaming?: boolean;
} }
const StatusIcon: FC<StatusIconProps> = ({ status, className, isStreaming = true }) => { const StatusIcon: FC<StatusIconProps> = ({ status, className, isStreaming = true }) => {
const baseClass = cn("size-4 shrink-0", className); const baseClass = cn("size-4 shrink-0", className);
switch (status) { switch (status) {
case "completed": case "completed":
return <CheckCircle2 className={cn(baseClass, "text-emerald-500")} />; return <CheckCircle2 className={cn(baseClass, "text-emerald-500")} />;
case "in_progress": case "in_progress":
// Only animate the spinner if we're actively streaming // Only animate the spinner if we're actively streaming
// When streaming is stopped, show as a static dashed circle // When streaming is stopped, show as a static dashed circle
return ( return (
<CircleDashed <CircleDashed
className={cn( className={cn(baseClass, "text-primary", isStreaming && "animate-spin")}
baseClass, style={isStreaming ? { animationDuration: "3s" } : undefined}
"text-primary", />
isStreaming && "animate-spin" );
)} case "cancelled":
style={isStreaming ? { animationDuration: "3s" } : undefined} return <XCircle className={cn(baseClass, "text-destructive")} />;
/> case "pending":
); default:
case "cancelled": return <Circle className={cn(baseClass, "text-muted-foreground")} />;
return <XCircle className={cn(baseClass, "text-destructive")} />; }
case "pending":
default:
return <Circle className={cn(baseClass, "text-muted-foreground")} />;
}
}; };
// ============================================================================ // ============================================================================
@ -67,55 +57,50 @@ const StatusIcon: FC<StatusIconProps> = ({ status, className, isStreaming = true
// ============================================================================ // ============================================================================
interface TodoItemProps { interface TodoItemProps {
todo: PlanTodo; todo: PlanTodo;
/** When false, in_progress items show as static (no spinner/pulse) */ /** When false, in_progress items show as static (no spinner/pulse) */
isStreaming?: boolean; isStreaming?: boolean;
} }
const TodoItem: FC<TodoItemProps> = ({ todo, isStreaming = true }) => { const TodoItem: FC<TodoItemProps> = ({ todo, isStreaming = true }) => {
const isStrikethrough = todo.status === "completed" || todo.status === "cancelled"; const isStrikethrough = todo.status === "completed" || todo.status === "cancelled";
// Only show shimmer animation if streaming and in progress // Only show shimmer animation if streaming and in progress
const isShimmer = todo.status === "in_progress" && isStreaming; const isShimmer = todo.status === "in_progress" && isStreaming;
// Render the label with optional shimmer effect // Render the label with optional shimmer effect
const renderLabel = () => { const renderLabel = () => {
if (isShimmer) { if (isShimmer) {
return <TextShimmerLoader text={todo.label} size="md" />; return <TextShimmerLoader text={todo.label} size="md" />;
} }
return ( return (
<span <span className={cn("text-sm", isStrikethrough && "line-through text-muted-foreground")}>
className={cn( {todo.label}
"text-sm", </span>
isStrikethrough && "line-through text-muted-foreground" );
)} };
>
{todo.label}
</span>
);
};
if (todo.description) { if (todo.description) {
return ( return (
<AccordionItem value={todo.id} className="border-0"> <AccordionItem value={todo.id} className="border-0">
<AccordionTrigger className="py-2 hover:no-underline"> <AccordionTrigger className="py-2 hover:no-underline">
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<StatusIcon status={todo.status} isStreaming={isStreaming} /> <StatusIcon status={todo.status} isStreaming={isStreaming} />
{renderLabel()} {renderLabel()}
</div> </div>
</AccordionTrigger> </AccordionTrigger>
<AccordionContent className="pb-2 pl-6"> <AccordionContent className="pb-2 pl-6">
<p className="text-sm text-muted-foreground">{todo.description}</p> <p className="text-sm text-muted-foreground">{todo.description}</p>
</AccordionContent> </AccordionContent>
</AccordionItem> </AccordionItem>
); );
} }
return ( return (
<div className="flex items-center gap-2 py-2"> <div className="flex items-center gap-2 py-2">
<StatusIcon status={todo.status} isStreaming={isStreaming} /> <StatusIcon status={todo.status} isStreaming={isStreaming} />
{renderLabel()} {renderLabel()}
</div> </div>
); );
}; };
// ============================================================================ // ============================================================================
@ -123,159 +108,158 @@ const TodoItem: FC<TodoItemProps> = ({ todo, isStreaming = true }) => {
// ============================================================================ // ============================================================================
export interface PlanProps { export interface PlanProps {
id: string; id: string;
title: string; title: string;
description?: string; description?: string;
todos: PlanTodo[]; todos: PlanTodo[];
maxVisibleTodos?: number; maxVisibleTodos?: number;
showProgress?: boolean; showProgress?: boolean;
/** When false, in_progress items show as static (no spinner/pulse animations) */ /** When false, in_progress items show as static (no spinner/pulse animations) */
isStreaming?: boolean; isStreaming?: boolean;
responseActions?: Action[] | ActionsConfig; responseActions?: Action[] | ActionsConfig;
className?: string; className?: string;
onResponseAction?: (actionId: string) => void; onResponseAction?: (actionId: string) => void;
onBeforeResponseAction?: (actionId: string) => boolean; onBeforeResponseAction?: (actionId: string) => boolean;
} }
export const Plan: FC<PlanProps> = ({ export const Plan: FC<PlanProps> = ({
id, id,
title, title,
description, description,
todos, todos,
maxVisibleTodos = 4, maxVisibleTodos = 4,
showProgress = true, showProgress = true,
isStreaming = true, isStreaming = true,
responseActions, responseActions,
className, className,
onResponseAction, onResponseAction,
onBeforeResponseAction, onBeforeResponseAction,
}) => { }) => {
const [isExpanded, setIsExpanded] = useState(false); const [isExpanded, setIsExpanded] = useState(false);
// Calculate progress // Calculate progress
const progress = useMemo(() => { const progress = useMemo(() => {
const completed = todos.filter((t) => t.status === "completed").length; const completed = todos.filter((t) => t.status === "completed").length;
const total = todos.filter((t) => t.status !== "cancelled").length; const total = todos.filter((t) => t.status !== "cancelled").length;
return { completed, total, percentage: total > 0 ? (completed / total) * 100 : 0 }; return { completed, total, percentage: total > 0 ? (completed / total) * 100 : 0 };
}, [todos]); }, [todos]);
const isAllComplete = progress.completed === progress.total && progress.total > 0; const isAllComplete = progress.completed === progress.total && progress.total > 0;
// Split todos for collapsible display // Split todos for collapsible display
const visibleTodos = todos.slice(0, maxVisibleTodos); const visibleTodos = todos.slice(0, maxVisibleTodos);
const hiddenTodos = todos.slice(maxVisibleTodos); const hiddenTodos = todos.slice(maxVisibleTodos);
const hasHiddenTodos = hiddenTodos.length > 0; const hasHiddenTodos = hiddenTodos.length > 0;
// Check if any todo has a description (for accordion mode) // Check if any todo has a description (for accordion mode)
const hasDescriptions = todos.some((t) => t.description); const hasDescriptions = todos.some((t) => t.description);
// Handle action click // Handle action click
const handleAction = (actionId: string) => { const handleAction = (actionId: string) => {
if (onBeforeResponseAction && !onBeforeResponseAction(actionId)) { if (onBeforeResponseAction && !onBeforeResponseAction(actionId)) {
return; return;
} }
onResponseAction?.(actionId); onResponseAction?.(actionId);
}; };
// Normalize actions to array // Normalize actions to array
const actionArray: Action[] = useMemo(() => { const actionArray: Action[] = useMemo(() => {
if (!responseActions) return []; if (!responseActions) return [];
if (Array.isArray(responseActions)) return responseActions; if (Array.isArray(responseActions)) return responseActions;
return [ return [
responseActions.confirm && { ...responseActions.confirm, id: "confirm" }, responseActions.confirm && { ...responseActions.confirm, id: "confirm" },
responseActions.cancel && { ...responseActions.cancel, id: "cancel" }, responseActions.cancel && { ...responseActions.cancel, id: "cancel" },
].filter(Boolean) as Action[]; ].filter(Boolean) as Action[];
}, [responseActions]); }, [responseActions]);
const TodoList: FC<{ items: PlanTodo[] }> = ({ items }) => { const TodoList: FC<{ items: PlanTodo[] }> = ({ items }) => {
if (hasDescriptions) { if (hasDescriptions) {
return ( return (
<Accordion type="single" collapsible className="w-full"> <Accordion type="single" collapsible className="w-full">
{items.map((todo) => ( {items.map((todo) => (
<TodoItem key={todo.id} todo={todo} isStreaming={isStreaming} /> <TodoItem key={todo.id} todo={todo} isStreaming={isStreaming} />
))} ))}
</Accordion> </Accordion>
); );
} }
return ( return (
<div className="space-y-0"> <div className="space-y-0">
{items.map((todo) => ( {items.map((todo) => (
<TodoItem key={todo.id} todo={todo} isStreaming={isStreaming} /> <TodoItem key={todo.id} todo={todo} isStreaming={isStreaming} />
))} ))}
</div> </div>
); );
}; };
return ( return (
<Card id={id} className={cn("w-full max-w-xl", className)}> <Card id={id} className={cn("w-full max-w-xl", className)}>
<CardHeader className="pb-3"> <CardHeader className="pb-3">
<div className="flex items-start justify-between gap-2"> <div className="flex items-start justify-between gap-2">
<div className="flex-1 min-w-0"> <div className="flex-1 min-w-0">
<CardTitle className="text-base font-semibold">{title}</CardTitle> <CardTitle className="text-base font-semibold">{title}</CardTitle>
{description && ( {description && (
<CardDescription className="mt-1 text-sm">{description}</CardDescription> <CardDescription className="mt-1 text-sm">{description}</CardDescription>
)} )}
</div> </div>
{isAllComplete && ( {isAllComplete && (
<div className="flex items-center gap-1 text-emerald-500"> <div className="flex items-center gap-1 text-emerald-500">
<PartyPopper className="size-5" /> <PartyPopper className="size-5" />
</div> </div>
)} )}
</div> </div>
{showProgress && ( {showProgress && (
<div className="mt-3 space-y-1.5"> <div className="mt-3 space-y-1.5">
<div className="flex items-center justify-between text-xs text-muted-foreground"> <div className="flex items-center justify-between text-xs text-muted-foreground">
<span> <span>
{progress.completed} of {progress.total} complete {progress.completed} of {progress.total} complete
</span> </span>
<span>{Math.round(progress.percentage)}%</span> <span>{Math.round(progress.percentage)}%</span>
</div> </div>
<Progress value={progress.percentage} className="h-1.5" /> <Progress value={progress.percentage} className="h-1.5" />
</div> </div>
)} )}
</CardHeader> </CardHeader>
<CardContent className="pt-0"> <CardContent className="pt-0">
<TodoList items={visibleTodos} /> <TodoList items={visibleTodos} />
{hasHiddenTodos && ( {hasHiddenTodos && (
<Collapsible open={isExpanded} onOpenChange={setIsExpanded}> <Collapsible open={isExpanded} onOpenChange={setIsExpanded}>
<CollapsibleTrigger asChild> <CollapsibleTrigger asChild>
<Button <Button
variant="ghost" variant="ghost"
size="sm" size="sm"
className="w-full mt-2 text-xs text-muted-foreground hover:text-foreground" className="w-full mt-2 text-xs text-muted-foreground hover:text-foreground"
> >
{isExpanded {isExpanded
? "Show less" ? "Show less"
: `Show ${hiddenTodos.length} more ${hiddenTodos.length === 1 ? "task" : "tasks"}`} : `Show ${hiddenTodos.length} more ${hiddenTodos.length === 1 ? "task" : "tasks"}`}
</Button> </Button>
</CollapsibleTrigger> </CollapsibleTrigger>
<CollapsibleContent> <CollapsibleContent>
<TodoList items={hiddenTodos} /> <TodoList items={hiddenTodos} />
</CollapsibleContent> </CollapsibleContent>
</Collapsible> </Collapsible>
)} )}
{actionArray.length > 0 && ( {actionArray.length > 0 && (
<div className="flex flex-wrap gap-2 pt-4 mt-2 border-t"> <div className="flex flex-wrap gap-2 pt-4 mt-2 border-t">
{actionArray.map((action) => ( {actionArray.map((action) => (
<Button <Button
key={action.id} key={action.id}
variant={action.variant || "default"} variant={action.variant || "default"}
size="sm" size="sm"
disabled={action.disabled} disabled={action.disabled}
onClick={() => handleAction(action.id)} onClick={() => handleAction(action.id)}
> >
{action.label} {action.label}
</Button> </Button>
))} ))}
</div> </div>
)} )}
</CardContent> </CardContent>
</Card> </Card>
); );
}; };

View file

@ -1,5 +1,4 @@
import { z } from "zod"; import { z } from "zod";
import { ActionSchema } from "../shared/schema";
/** /**
* Todo item status * Todo item status
@ -11,10 +10,10 @@ export type TodoStatus = z.infer<typeof TodoStatusSchema>;
* Single todo item in a plan * Single todo item in a plan
*/ */
export const PlanTodoSchema = z.object({ export const PlanTodoSchema = z.object({
id: z.string(), id: z.string(),
label: z.string(), label: z.string(),
status: TodoStatusSchema, status: TodoStatusSchema,
description: z.string().optional(), description: z.string().optional(),
}); });
export type PlanTodo = z.infer<typeof PlanTodoSchema>; export type PlanTodo = z.infer<typeof PlanTodoSchema>;
@ -23,12 +22,12 @@ export type PlanTodo = z.infer<typeof PlanTodoSchema>;
* Serializable plan schema for tool results * Serializable plan schema for tool results
*/ */
export const SerializablePlanSchema = z.object({ export const SerializablePlanSchema = z.object({
id: z.string(), id: z.string(),
title: z.string(), title: z.string(),
description: z.string().optional(), description: z.string().optional(),
todos: z.array(PlanTodoSchema).min(1), todos: z.array(PlanTodoSchema).min(1),
maxVisibleTodos: z.number().optional(), maxVisibleTodos: z.number().optional(),
showProgress: z.boolean().optional(), showProgress: z.boolean().optional(),
}); });
export type SerializablePlan = z.infer<typeof SerializablePlanSchema>; export type SerializablePlan = z.infer<typeof SerializablePlanSchema>;
@ -37,31 +36,31 @@ export type SerializablePlan = z.infer<typeof SerializablePlanSchema>;
* Parse and validate a serializable plan from tool result * Parse and validate a serializable plan from tool result
*/ */
export function parseSerializablePlan(data: unknown): SerializablePlan { export function parseSerializablePlan(data: unknown): SerializablePlan {
const result = SerializablePlanSchema.safeParse(data); const result = SerializablePlanSchema.safeParse(data);
if (!result.success) { if (!result.success) {
console.warn("Invalid plan data:", result.error.issues); console.warn("Invalid plan data:", result.error.issues);
// Try to extract basic info for fallback // Try to extract basic info for fallback
const obj = (data && typeof data === "object" ? data : {}) as Record<string, unknown>; const obj = (data && typeof data === "object" ? data : {}) as Record<string, unknown>;
return { return {
id: typeof obj.id === "string" ? obj.id : "unknown", id: typeof obj.id === "string" ? obj.id : "unknown",
title: typeof obj.title === "string" ? obj.title : "Plan", title: typeof obj.title === "string" ? obj.title : "Plan",
description: typeof obj.description === "string" ? obj.description : undefined, description: typeof obj.description === "string" ? obj.description : undefined,
todos: Array.isArray(obj.todos) todos: Array.isArray(obj.todos)
? obj.todos.map((t, i) => ({ ? obj.todos.map((t, i) => ({
id: typeof (t as any)?.id === "string" ? (t as any).id : `todo-${i}`, id: typeof (t as any)?.id === "string" ? (t as any).id : `todo-${i}`,
label: typeof (t as any)?.label === "string" ? (t as any).label : "Task", label: typeof (t as any)?.label === "string" ? (t as any).label : "Task",
status: TodoStatusSchema.safeParse((t as any)?.status).success status: TodoStatusSchema.safeParse((t as any)?.status).success
? (t as any).status ? (t as any).status
: "pending", : "pending",
description: typeof (t as any)?.description === "string" ? (t as any).description : undefined, description:
})) typeof (t as any)?.description === "string" ? (t as any).description : undefined,
: [{ id: "1", label: "No tasks", status: "pending" as const }], }))
}; : [{ id: "1", label: "No tasks", status: "pending" as const }],
} };
}
return result.data; return result.data;
} }

View file

@ -2,6 +2,7 @@
import { makeAssistantToolUI } from "@assistant-ui/react"; import { makeAssistantToolUI } from "@assistant-ui/react";
import { AlertCircleIcon, FileTextIcon } from "lucide-react"; import { AlertCircleIcon, FileTextIcon } from "lucide-react";
import { z } from "zod";
import { import {
Article, Article,
ArticleErrorBoundary, ArticleErrorBoundary,
@ -9,30 +10,44 @@ import {
parseSerializableArticle, parseSerializableArticle,
} from "@/components/tool-ui/article"; } from "@/components/tool-ui/article";
/** // ============================================================================
* Type definitions for the scrape_webpage tool // Zod Schemas
*/ // ============================================================================
interface ScrapeWebpageArgs {
url: string;
max_length?: number;
}
interface ScrapeWebpageResult { /**
id: string; * Schema for scrape_webpage tool arguments
assetId: string; */
kind: "article"; const ScrapeWebpageArgsSchema = z.object({
href: string; url: z.string(),
title: string; max_length: z.number().nullish(),
description?: string; });
content?: string;
domain?: string; /**
author?: string; * Schema for scrape_webpage tool result
date?: string; */
word_count?: number; const ScrapeWebpageResultSchema = z.object({
was_truncated?: boolean; id: z.string(),
crawler_type?: string; assetId: z.string(),
error?: string; kind: z.literal("article"),
} href: z.string(),
title: z.string(),
description: z.string().nullish(),
content: z.string().nullish(),
domain: z.string().nullish(),
author: z.string().nullish(),
date: z.string().nullish(),
word_count: z.number().nullish(),
was_truncated: z.boolean().nullish(),
crawler_type: z.string().nullish(),
error: z.string().nullish(),
});
// ============================================================================
// Types
// ============================================================================
type ScrapeWebpageArgs = z.infer<typeof ScrapeWebpageArgsSchema>;
type ScrapeWebpageResult = z.infer<typeof ScrapeWebpageResultSchema>;
/** /**
* Error state component shown when webpage scraping fails * Error state component shown when webpage scraping fails
@ -154,4 +169,9 @@ export const ScrapeWebpageToolUI = makeAssistantToolUI<ScrapeWebpageArgs, Scrape
}, },
}); });
export type { ScrapeWebpageArgs, ScrapeWebpageResult }; export {
ScrapeWebpageArgsSchema,
ScrapeWebpageResultSchema,
type ScrapeWebpageArgs,
type ScrapeWebpageResult,
};

View file

@ -5,38 +5,37 @@ import { Button } from "@/components/ui/button";
import type { Action, ActionsConfig } from "./schema"; import type { Action, ActionsConfig } from "./schema";
interface ActionButtonsProps { interface ActionButtonsProps {
actions?: Action[] | ActionsConfig; actions?: Action[] | ActionsConfig;
onAction?: (actionId: string) => void; onAction?: (actionId: string) => void;
disabled?: boolean; disabled?: boolean;
} }
export const ActionButtons: FC<ActionButtonsProps> = ({ actions, onAction, disabled }) => { export const ActionButtons: FC<ActionButtonsProps> = ({ actions, onAction, disabled }) => {
if (!actions) return null; if (!actions) return null;
// Normalize actions to array format // Normalize actions to array format
const actionArray: Action[] = Array.isArray(actions) const actionArray: Action[] = Array.isArray(actions)
? actions ? actions
: [ : ([
actions.confirm && { ...actions.confirm, id: "confirm" }, actions.confirm && { ...actions.confirm, id: "confirm" },
actions.cancel && { ...actions.cancel, id: "cancel" }, actions.cancel && { ...actions.cancel, id: "cancel" },
].filter(Boolean) as Action[]; ].filter(Boolean) as Action[]);
if (actionArray.length === 0) return null; if (actionArray.length === 0) return null;
return ( return (
<div className="flex flex-wrap gap-2 pt-3"> <div className="flex flex-wrap gap-2 pt-3">
{actionArray.map((action) => ( {actionArray.map((action) => (
<Button <Button
key={action.id} key={action.id}
variant={action.variant || "default"} variant={action.variant || "default"}
size="sm" size="sm"
disabled={disabled || action.disabled} disabled={disabled || action.disabled}
onClick={() => onAction?.(action.id)} onClick={() => onAction?.(action.id)}
> >
{action.label} {action.label}
</Button> </Button>
))} ))}
</div> </div>
); );
}; };

View file

@ -1,3 +1,2 @@
export * from "./schema"; export * from "./schema";
export * from "./action-buttons"; export * from "./action-buttons";

View file

@ -4,10 +4,10 @@ import { z } from "zod";
* Shared action schema for tool UI components * Shared action schema for tool UI components
*/ */
export const ActionSchema = z.object({ export const ActionSchema = z.object({
id: z.string(), id: z.string(),
label: z.string(), label: z.string(),
variant: z.enum(["default", "secondary", "destructive", "outline", "ghost", "link"]).optional(), variant: z.enum(["default", "secondary", "destructive", "outline", "ghost", "link"]).optional(),
disabled: z.boolean().optional(), disabled: z.boolean().optional(),
}); });
export type Action = z.infer<typeof ActionSchema>; export type Action = z.infer<typeof ActionSchema>;
@ -16,9 +16,8 @@ export type Action = z.infer<typeof ActionSchema>;
* Actions configuration schema * Actions configuration schema
*/ */
export const ActionsConfigSchema = z.object({ export const ActionsConfigSchema = z.object({
confirm: ActionSchema.optional(), confirm: ActionSchema.optional(),
cancel: ActionSchema.optional(), cancel: ActionSchema.optional(),
}); });
export type ActionsConfig = z.infer<typeof ActionsConfigSchema>; export type ActionsConfig = z.infer<typeof ActionsConfigSchema>;

View file

@ -4,54 +4,76 @@ import { makeAssistantToolUI, useAssistantState } from "@assistant-ui/react";
import { useAtomValue, useSetAtom } from "jotai"; import { useAtomValue, useSetAtom } from "jotai";
import { Loader2 } from "lucide-react"; import { Loader2 } from "lucide-react";
import { useEffect, useMemo } from "react"; import { useEffect, useMemo } from "react";
import { z } from "zod";
import { import {
getCanonicalPlanTitle, getCanonicalPlanTitle,
planStatesAtom, planStatesAtom,
registerPlanOwner, registerPlanOwner,
updatePlanStateAtom, updatePlanStateAtom,
} from "@/atoms/chat/plan-state.atom"; } from "@/atoms/chat/plan-state.atom";
import { Plan, PlanErrorBoundary, parseSerializablePlan } from "./plan"; import { Plan, PlanErrorBoundary, parseSerializablePlan, TodoStatusSchema } from "./plan";
// ============================================================================
// Zod Schemas
// ============================================================================
/** /**
* Tool arguments for write_todos * Schema for a single todo item in the args
*/ */
interface WriteTodosArgs { const WriteTodosArgsTodoSchema = z.object({
title?: string; id: z.string(),
description?: string; content: z.string(),
todos?: Array<{ status: TodoStatusSchema,
id: string; });
content: string;
status: "pending" | "in_progress" | "completed" | "cancelled";
}>;
}
/** /**
* Tool result for write_todos * Schema for write_todos tool arguments
*/ */
interface WriteTodosResult { const WriteTodosArgsSchema = z.object({
id: string; title: z.string().nullish(),
title: string; description: z.string().nullish(),
description?: string; todos: z.array(WriteTodosArgsTodoSchema).nullish(),
todos: Array<{ });
id: string;
label: string; /**
status: "pending" | "in_progress" | "completed" | "cancelled"; * Schema for a single todo item in the result
description?: string; */
}>; const WriteTodosResultTodoSchema = z.object({
} id: z.string(),
label: z.string(),
status: TodoStatusSchema,
description: z.string().nullish(),
});
/**
* Schema for write_todos tool result
*/
const WriteTodosResultSchema = z.object({
id: z.string(),
title: z.string(),
description: z.string().nullish(),
todos: z.array(WriteTodosResultTodoSchema),
});
// ============================================================================
// Types
// ============================================================================
type WriteTodosArgs = z.infer<typeof WriteTodosArgsSchema>;
type WriteTodosResult = z.infer<typeof WriteTodosResultSchema>;
/** /**
* Loading state component * Loading state component
*/ */
function WriteTodosLoading() { function WriteTodosLoading() {
return ( return (
<div className="my-4 w-full max-w-xl rounded-2xl border bg-card/60 px-5 py-4 shadow-sm"> <div className="my-4 w-full max-w-xl rounded-2xl border bg-card/60 px-5 py-4 shadow-sm">
<div className="flex items-center gap-3"> <div className="flex items-center gap-3">
<Loader2 className="size-5 animate-spin text-primary" /> <Loader2 className="size-5 animate-spin text-primary" />
<span className="text-sm text-muted-foreground">Creating plan...</span> <span className="text-sm text-muted-foreground">Creating plan...</span>
</div> </div>
</div> </div>
); );
} }
/** /**
@ -59,20 +81,20 @@ function WriteTodosLoading() {
* This handles the case where the LLM is streaming the tool call * This handles the case where the LLM is streaming the tool call
*/ */
function transformArgsToResult(args: WriteTodosArgs): WriteTodosResult | null { function transformArgsToResult(args: WriteTodosArgs): WriteTodosResult | null {
if (!args.todos || !Array.isArray(args.todos) || args.todos.length === 0) { if (!args.todos || !Array.isArray(args.todos) || args.todos.length === 0) {
return null; return null;
} }
return { return {
id: `plan-${Date.now()}`, id: `plan-${Date.now()}`,
title: args.title || "Planning Approach", title: args.title || "Planning Approach",
description: args.description, description: args.description,
todos: args.todos.map((todo, index) => ({ todos: args.todos.map((todo, index) => ({
id: todo.id || `todo-${index}`, id: todo.id || `todo-${index}`,
label: todo.content || "Task", label: todo.content || "Task",
status: todo.status || "pending", status: todo.status || "pending",
})), })),
}; };
} }
/** /**
@ -87,116 +109,115 @@ function transformArgsToResult(args: WriteTodosArgs): WriteTodosResult | null {
* layout shift when plans are updated. * layout shift when plans are updated.
*/ */
export const WriteTodosToolUI = makeAssistantToolUI<WriteTodosArgs, WriteTodosResult>({ export const WriteTodosToolUI = makeAssistantToolUI<WriteTodosArgs, WriteTodosResult>({
toolName: "write_todos", toolName: "write_todos",
render: function WriteTodosUI({ args, result, status, toolCallId }) { render: function WriteTodosUI({ args, result, status, toolCallId }) {
const updatePlanState = useSetAtom(updatePlanStateAtom); const updatePlanState = useSetAtom(updatePlanStateAtom);
const planStates = useAtomValue(planStatesAtom); const planStates = useAtomValue(planStatesAtom);
// Check if the THREAD is running (not just this tool) // Check if the THREAD is running (not just this tool)
// This hook subscribes to state changes, so it re-renders when thread stops // This hook subscribes to state changes, so it re-renders when thread stops
const isThreadRunning = useAssistantState(({ thread }) => thread.isRunning); const isThreadRunning = useAssistantState(({ thread }) => thread.isRunning);
// Get the plan data (from result or args) // Get the plan data (from result or args)
const planData = result || transformArgsToResult(args); const planData = result || transformArgsToResult(args);
const rawTitle = planData?.title || args.title || "Planning Approach"; const rawTitle = planData?.title || args.title || "Planning Approach";
// SYNCHRONOUS ownership check - happens immediately, no race conditions // SYNCHRONOUS ownership check - happens immediately, no race conditions
// ONE PLAN PER CONVERSATION: Only first write_todos call becomes owner // ONE PLAN PER CONVERSATION: Only first write_todos call becomes owner
const isOwner = useMemo(() => { const isOwner = useMemo(() => {
return registerPlanOwner(rawTitle, toolCallId); return registerPlanOwner(rawTitle, toolCallId);
}, [rawTitle, toolCallId]); }, [rawTitle, toolCallId]);
// Get canonical title - always use the FIRST plan's title // Get canonical title - always use the FIRST plan's title
// This ensures all updates go to the same plan state // This ensures all updates go to the same plan state
const planTitle = useMemo(() => getCanonicalPlanTitle(rawTitle), [rawTitle]); const planTitle = useMemo(() => getCanonicalPlanTitle(rawTitle), [rawTitle]);
// Register/update the plan state - ALWAYS use canonical title // Register/update the plan state - ALWAYS use canonical title
useEffect(() => { useEffect(() => {
if (planData) { if (planData) {
updatePlanState({ updatePlanState({
id: planData.id, id: planData.id,
title: planTitle, // Use canonical title, not raw title title: planTitle, // Use canonical title, not raw title
description: planData.description, description: planData.description,
todos: planData.todos, todos: planData.todos,
toolCallId, toolCallId,
}); });
} }
}, [planData, planTitle, updatePlanState, toolCallId]); }, [planData, planTitle, updatePlanState, toolCallId]);
// Update when result changes (for streaming updates) // Update when result changes (for streaming updates)
useEffect(() => { useEffect(() => {
if (result) { if (result) {
updatePlanState({ updatePlanState({
id: result.id, id: result.id,
title: planTitle, // Use canonical title, not raw title title: planTitle, // Use canonical title, not raw title
description: result.description, description: result.description,
todos: result.todos, todos: result.todos,
toolCallId, toolCallId,
}); });
} }
}, [result, planTitle, updatePlanState, toolCallId]); }, [result, planTitle, updatePlanState, toolCallId]);
// Get the current plan state (may be updated by other components) // Get the current plan state (may be updated by other components)
const currentPlanState = planStates.get(planTitle); const currentPlanState = planStates.get(planTitle);
// If we're NOT the owner, render nothing (the owner will render) // If we're NOT the owner, render nothing (the owner will render)
if (!isOwner) { if (!isOwner) {
return null; return null;
} }
// Loading state - tool is still running (no data yet) // Loading state - tool is still running (no data yet)
if (status.type === "running" || status.type === "requires-action") { if (status.type === "running" || status.type === "requires-action") {
// Try to show partial results from args while streaming // Try to show partial results from args while streaming
const partialResult = transformArgsToResult(args); const partialResult = transformArgsToResult(args);
if (partialResult) { if (partialResult) {
const plan = parseSerializablePlan(partialResult); const plan = parseSerializablePlan(partialResult);
return ( return (
<div className="my-4"> <div className="my-4">
<PlanErrorBoundary> <PlanErrorBoundary>
<Plan {...plan} showProgress={true} isStreaming={isThreadRunning} /> <Plan {...plan} showProgress={true} isStreaming={isThreadRunning} />
</PlanErrorBoundary> </PlanErrorBoundary>
</div> </div>
); );
} }
return <WriteTodosLoading />; return <WriteTodosLoading />;
} }
// Incomplete/cancelled state // Incomplete/cancelled state
if (status.type === "incomplete") { if (status.type === "incomplete") {
// For cancelled or errors, try to show what we have from args or shared state // For cancelled or errors, try to show what we have from args or shared state
// Use isThreadRunning to determine if we should still animate // Use isThreadRunning to determine if we should still animate
const fallbackResult = currentPlanState || transformArgsToResult(args); const fallbackResult = currentPlanState || transformArgsToResult(args);
if (fallbackResult) { if (fallbackResult) {
const plan = parseSerializablePlan(fallbackResult); const plan = parseSerializablePlan(fallbackResult);
return ( return (
<div className="my-4"> <div className="my-4">
<PlanErrorBoundary> <PlanErrorBoundary>
<Plan {...plan} showProgress={true} isStreaming={isThreadRunning} /> <Plan {...plan} showProgress={true} isStreaming={isThreadRunning} />
</PlanErrorBoundary> </PlanErrorBoundary>
</div> </div>
); );
} }
return null; return null;
} }
// Success - render the plan using the LATEST shared state // Success - render the plan using the LATEST shared state
// Use isThreadRunning to determine if we should animate in_progress items // Use isThreadRunning to determine if we should animate in_progress items
// (LLM may still be working on tasks even though this tool call completed) // (LLM may still be working on tasks even though this tool call completed)
const planToRender = currentPlanState || result; const planToRender = currentPlanState || result;
if (!planToRender) { if (!planToRender) {
return <WriteTodosLoading />; return <WriteTodosLoading />;
} }
const plan = parseSerializablePlan(planToRender); const plan = parseSerializablePlan(planToRender);
return ( return (
<div className="my-4"> <div className="my-4">
<PlanErrorBoundary> <PlanErrorBoundary>
<Plan {...plan} showProgress={true} isStreaming={isThreadRunning} /> <Plan {...plan} showProgress={true} isStreaming={isThreadRunning} />
</PlanErrorBoundary> </PlanErrorBoundary>
</div> </div>
); );
}, },
}); });
export type { WriteTodosArgs, WriteTodosResult }; export { WriteTodosArgsSchema, WriteTodosResultSchema, type WriteTodosArgs, type WriteTodosResult };