mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-09 07:42:39 +02:00
feat: add GenerateImageToolUI component for rendering generated images with error handling and loading states
This commit is contained in:
parent
a99791009a
commit
15a81dbf41
2 changed files with 166 additions and 30 deletions
138
surfsense_web/components/tool-ui/generate-image.tsx
Normal file
138
surfsense_web/components/tool-ui/generate-image.tsx
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import type { ToolCallMessagePartProps } from "@assistant-ui/react";
|
||||||
|
import { AlertCircleIcon, ImageIcon } from "lucide-react";
|
||||||
|
import { z } from "zod";
|
||||||
|
import {
|
||||||
|
Image,
|
||||||
|
ImageErrorBoundary,
|
||||||
|
ImageLoading,
|
||||||
|
parseSerializableImage,
|
||||||
|
} from "@/components/tool-ui/image";
|
||||||
|
|
||||||
|
const GenerateImageArgsSchema = z.object({
|
||||||
|
prompt: z.string(),
|
||||||
|
n: z.number().nullish(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const GenerateImageResultSchema = z.object({
|
||||||
|
id: z.string(),
|
||||||
|
assetId: z.string(),
|
||||||
|
src: z.string(),
|
||||||
|
alt: z.string().nullish(),
|
||||||
|
title: z.string().nullish(),
|
||||||
|
description: z.string().nullish(),
|
||||||
|
domain: z.string().nullish(),
|
||||||
|
ratio: z.string().nullish(),
|
||||||
|
generated: z.boolean().nullish(),
|
||||||
|
prompt: z.string().nullish(),
|
||||||
|
image_count: z.number().nullish(),
|
||||||
|
error: z.string().nullish(),
|
||||||
|
});
|
||||||
|
|
||||||
|
type GenerateImageArgs = z.infer<typeof GenerateImageArgsSchema>;
|
||||||
|
type GenerateImageResult = z.infer<typeof GenerateImageResultSchema>;
|
||||||
|
|
||||||
|
function ImageErrorState({ prompt, error }: { prompt: string; error: string }) {
|
||||||
|
return (
|
||||||
|
<div className="my-4 overflow-hidden rounded-xl border border-destructive/20 bg-destructive/5 p-4 max-w-md">
|
||||||
|
<div className="flex items-center gap-4">
|
||||||
|
<div className="flex size-12 shrink-0 items-center justify-center rounded-lg bg-destructive/10">
|
||||||
|
<AlertCircleIcon className="size-6 text-destructive" />
|
||||||
|
</div>
|
||||||
|
<div className="flex-1 min-w-0">
|
||||||
|
<p className="font-medium text-destructive text-sm">Image generation failed</p>
|
||||||
|
<p className="text-muted-foreground text-xs mt-0.5 truncate">{prompt}</p>
|
||||||
|
<p className="text-muted-foreground text-xs mt-1">{error}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function ImageCancelledState({ prompt }: { prompt: string }) {
|
||||||
|
return (
|
||||||
|
<div className="my-4 rounded-xl border border-muted p-4 text-muted-foreground max-w-md">
|
||||||
|
<p className="flex items-center gap-2">
|
||||||
|
<ImageIcon className="size-4" />
|
||||||
|
<span className="line-through truncate">Generate: {prompt}</span>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function ParsedImage({ result }: { result: unknown }) {
|
||||||
|
const image = parseSerializableImage(result);
|
||||||
|
return (
|
||||||
|
<Image
|
||||||
|
id={image.id}
|
||||||
|
assetId={image.assetId}
|
||||||
|
src={image.src}
|
||||||
|
alt={image.alt}
|
||||||
|
title={image.title ?? undefined}
|
||||||
|
description={image.description ?? undefined}
|
||||||
|
href={image.href ?? undefined}
|
||||||
|
domain={image.domain ?? undefined}
|
||||||
|
ratio={image.ratio ?? undefined}
|
||||||
|
source={image.source ?? undefined}
|
||||||
|
maxWidth="512px"
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tool UI for generate_image — renders the generated image directly
|
||||||
|
* from the tool result directly.
|
||||||
|
*/
|
||||||
|
export const GenerateImageToolUI = ({ args, result, status }: ToolCallMessagePartProps<GenerateImageArgs, GenerateImageResult>) => {
|
||||||
|
const prompt = args.prompt || "Generating image...";
|
||||||
|
|
||||||
|
if (status.type === "running" || status.type === "requires-action") {
|
||||||
|
return (
|
||||||
|
<div className="my-4">
|
||||||
|
<ImageLoading title="Generating image" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status.type === "incomplete") {
|
||||||
|
if (status.reason === "cancelled") {
|
||||||
|
return <ImageCancelledState prompt={prompt} />;
|
||||||
|
}
|
||||||
|
if (status.reason === "error") {
|
||||||
|
return (
|
||||||
|
<ImageErrorState
|
||||||
|
prompt={prompt}
|
||||||
|
error={typeof status.error === "string" ? status.error : "An error occurred"}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!result) {
|
||||||
|
return (
|
||||||
|
<div className="my-4">
|
||||||
|
<ImageLoading title="Loading" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result.error) {
|
||||||
|
return <ImageErrorState prompt={prompt} error={result.error} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="my-4">
|
||||||
|
<ImageErrorBoundary>
|
||||||
|
<ParsedImage result={result} />
|
||||||
|
</ImageErrorBoundary>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export {
|
||||||
|
GenerateImageArgsSchema,
|
||||||
|
GenerateImageResultSchema,
|
||||||
|
type GenerateImageArgs,
|
||||||
|
type GenerateImageResult,
|
||||||
|
};
|
||||||
|
|
@ -6,7 +6,7 @@ import { Component, type ReactNode, useState } from "react";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { Badge } from "@/components/ui/badge";
|
import { Badge } from "@/components/ui/badge";
|
||||||
import { Card } from "@/components/ui/card";
|
import { Card } from "@/components/ui/card";
|
||||||
import { Spinner } from "@/components/ui/spinner";
|
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -145,14 +145,14 @@ export class ImageErrorBoundary extends Component<
|
||||||
render() {
|
render() {
|
||||||
if (this.state.hasError) {
|
if (this.state.hasError) {
|
||||||
return (
|
return (
|
||||||
<Card className="w-full max-w-md overflow-hidden">
|
<Card className="w-full max-w-md overflow-hidden rounded-2xl border-0 shadow-none select-none">
|
||||||
<div className="aspect-square bg-muted flex items-center justify-center">
|
<div className="aspect-square bg-muted flex items-center justify-center">
|
||||||
<div className="flex flex-col items-center gap-2 text-muted-foreground">
|
<div className="flex flex-col items-center gap-2 text-muted-foreground">
|
||||||
<ImageIcon className="size-8" />
|
<ImageIcon className="size-8" />
|
||||||
<p className="text-sm">Failed to load image</p>
|
<p className="text-sm">Failed to load image</p>
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
</div>
|
||||||
|
</Card>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -165,7 +165,7 @@ export class ImageErrorBoundary extends Component<
|
||||||
*/
|
*/
|
||||||
export function ImageSkeleton({ maxWidth = "512px" }: { maxWidth?: string }) {
|
export function ImageSkeleton({ maxWidth = "512px" }: { maxWidth?: string }) {
|
||||||
return (
|
return (
|
||||||
<Card className="w-full overflow-hidden animate-pulse" style={{ maxWidth }}>
|
<Card className="w-full overflow-hidden rounded-2xl border-0 shadow-none select-none animate-pulse" style={{ maxWidth }}>
|
||||||
<div className="aspect-square bg-muted flex items-center justify-center">
|
<div className="aspect-square bg-muted flex items-center justify-center">
|
||||||
<ImageIcon className="size-12 text-muted-foreground/30" />
|
<ImageIcon className="size-12 text-muted-foreground/30" />
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -176,14 +176,11 @@ export function ImageSkeleton({ maxWidth = "512px" }: { maxWidth?: string }) {
|
||||||
/**
|
/**
|
||||||
* Image Loading State
|
* Image Loading State
|
||||||
*/
|
*/
|
||||||
export function ImageLoading({ title = "Loading image..." }: { title?: string }) {
|
export function ImageLoading({ title = "Loading", maxWidth = "512px" }: { title?: string; maxWidth?: string }) {
|
||||||
return (
|
return (
|
||||||
<Card className="w-full max-w-md overflow-hidden">
|
<Card className="w-full overflow-hidden rounded-2xl border-0 shadow-none select-none" style={{ maxWidth }}>
|
||||||
<div className="aspect-square bg-muted flex items-center justify-center">
|
<div className="aspect-square bg-muted flex items-center justify-center">
|
||||||
<div className="flex flex-col items-center gap-3">
|
<TextShimmerLoader text={title} size="md" />
|
||||||
<Spinner size="lg" className="text-muted-foreground" />
|
|
||||||
<p className="text-muted-foreground text-sm">{title}</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
);
|
);
|
||||||
|
|
@ -214,8 +211,8 @@ export function Image({
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
const [imageError, setImageError] = useState(false);
|
const [imageError, setImageError] = useState(false);
|
||||||
const [imageLoaded, setImageLoaded] = useState(false);
|
const [imageLoaded, setImageLoaded] = useState(false);
|
||||||
const displayDomain = domain || source?.label;
|
|
||||||
const isGenerated = domain === "ai-generated";
|
const isGenerated = domain === "ai-generated";
|
||||||
|
const displayDomain = isGenerated ? "AI Generated" : (domain || source?.label);
|
||||||
const isAutoRatio = !ratio || ratio === "auto";
|
const isAutoRatio = !ratio || ratio === "auto";
|
||||||
|
|
||||||
const handleClick = () => {
|
const handleClick = () => {
|
||||||
|
|
@ -227,7 +224,7 @@ export function Image({
|
||||||
|
|
||||||
if (imageError) {
|
if (imageError) {
|
||||||
return (
|
return (
|
||||||
<Card id={id} className={cn("w-full overflow-hidden", className)} style={{ maxWidth }}>
|
<Card id={id} className={cn("w-full overflow-hidden rounded-2xl border-0 shadow-none select-none", className)} style={{ maxWidth }}>
|
||||||
<div className="aspect-square bg-muted flex items-center justify-center">
|
<div className="aspect-square bg-muted flex items-center justify-center">
|
||||||
<div className="flex flex-col items-center gap-2 text-muted-foreground">
|
<div className="flex flex-col items-center gap-2 text-muted-foreground">
|
||||||
<ImageIcon className="size-8" />
|
<ImageIcon className="size-8" />
|
||||||
|
|
@ -242,8 +239,7 @@ export function Image({
|
||||||
<Card
|
<Card
|
||||||
id={id}
|
id={id}
|
||||||
className={cn(
|
className={cn(
|
||||||
"group w-full overflow-hidden cursor-pointer transition-shadow duration-200 hover:shadow-lg",
|
"group w-full overflow-hidden rounded-2xl border-0 shadow-none select-none cursor-pointer transition-shadow duration-200 hover:shadow-lg",
|
||||||
isGenerated && "ring-1 ring-primary/10",
|
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
style={{ maxWidth }}
|
style={{ maxWidth }}
|
||||||
|
|
@ -263,20 +259,24 @@ export function Image({
|
||||||
{isAutoRatio ? (
|
{isAutoRatio ? (
|
||||||
/* Auto ratio: image renders at natural dimensions, no cropping */
|
/* Auto ratio: image renders at natural dimensions, no cropping */
|
||||||
<>
|
<>
|
||||||
{!imageLoaded && (
|
{!imageLoaded && (
|
||||||
<div className="aspect-square flex items-center justify-center">
|
<div className="aspect-square flex items-center justify-center">
|
||||||
<Spinner size="lg" className="text-muted-foreground" />
|
<TextShimmerLoader text="Loading" size="md" />
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
<NextImage
|
||||||
<img
|
|
||||||
src={src}
|
src={src}
|
||||||
alt={alt}
|
alt={alt}
|
||||||
|
width={0}
|
||||||
|
height={0}
|
||||||
|
sizes="100vw"
|
||||||
|
loading="eager"
|
||||||
className={cn(
|
className={cn(
|
||||||
"w-full h-auto transition-transform duration-300",
|
"w-full h-auto transition-transform duration-300",
|
||||||
isHovered && "scale-[1.02]",
|
isHovered && "scale-[1.02]",
|
||||||
!imageLoaded && "hidden"
|
!imageLoaded && "hidden"
|
||||||
)}
|
)}
|
||||||
|
unoptimized
|
||||||
onLoad={() => setImageLoaded(true)}
|
onLoad={() => setImageLoaded(true)}
|
||||||
onError={() => setImageError(true)}
|
onError={() => setImageError(true)}
|
||||||
/>
|
/>
|
||||||
|
|
@ -316,11 +316,9 @@ export function Image({
|
||||||
{description && (
|
{description && (
|
||||||
<p className="text-white/80 text-xs line-clamp-2 mb-1.5">{description}</p>
|
<p className="text-white/80 text-xs line-clamp-2 mb-1.5">{description}</p>
|
||||||
)}
|
)}
|
||||||
{displayDomain && (
|
{displayDomain && !isGenerated && (
|
||||||
<div className="flex items-center gap-1.5">
|
<div className="flex items-center gap-1.5">
|
||||||
{isGenerated ? (
|
{source?.iconUrl ? (
|
||||||
<SparklesIcon className="size-3.5 text-white/70" />
|
|
||||||
) : source?.iconUrl ? (
|
|
||||||
<NextImage
|
<NextImage
|
||||||
src={source.iconUrl}
|
src={source.iconUrl}
|
||||||
alt={source.label}
|
alt={source.label}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue