* added image tool

* show image in playground

* store images in Redis

* new images use unique urls

* moved from redis to s3 for image urls

* removed unnecessary changes

* removed the bubble around assistant messages

* added a download button on hover on image

* increased image size and removed border

* revert the bubbes for the assistant messages
This commit is contained in:
arkml 2025-09-11 20:50:20 +05:30 committed by GitHub
parent af0fcce127
commit 158777b045
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 489 additions and 22 deletions

View file

@ -6,6 +6,9 @@ import { z } from "zod";
import { composio } from "@/src/application/lib/composio/composio";
import { SignJWT } from "jose";
import crypto from "crypto";
import { GoogleGenerativeAI } from "@google/generative-ai";
import { tempBinaryCache } from "@/src/application/services/temp-binary-cache";
import { S3Client, PutObjectCommand } from "@aws-sdk/client-s3";
// Internal dependencies
import { embeddingModel } from "@/app/lib/embedding";
@ -31,6 +34,87 @@ const openai = createOpenAI({
baseURL: PROVIDER_BASE_URL,
});
// Image generation (Gemini) defaults
const DEFAULT_IMAGE_MODEL = "gemini-2.5-flash-image-preview";
// Helper to generate an image using Gemini
export async function invokeGenerateImageTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
prompt: string,
options?: {
modelName?: string;
}
): Promise<{
texts: string[];
images: { mimeType: string; bytes: number; dataBase64: string }[];
model: string;
}> {
const log = logger.child("invokeGenerateImageTool");
const apiKey = process.env.GOOGLE_API_KEY || process.env.GEMINI_API_KEY || "";
if (!apiKey) {
throw new Error("Missing API key. Set GOOGLE_API_KEY or GEMINI_API_KEY.");
}
const modelName = options?.modelName || DEFAULT_IMAGE_MODEL;
const client = new GoogleGenerativeAI(apiKey);
const model = client.getGenerativeModel({ model: modelName });
log.log(`Generating image with model: ${modelName}`);
const result = await model.generateContent(prompt);
const response = result.response as any;
// Track usage if available
try {
const inputTokens = response?.usageMetadata?.promptTokenCount || 0;
const outputTokens = response?.usageMetadata?.candidatesTokenCount || 0;
usageTracker.track({
type: "LLM_USAGE",
modelName: modelName,
inputTokens,
outputTokens,
context: "agents_runtime.gemini_image_generation",
});
} catch (_) {
// ignore usage tracking errors
}
const candidates = (response?.candidates ?? []) as any[];
if (!candidates.length) {
throw new Error("No candidates returned in response.");
}
const parts = (candidates[0]?.content?.parts ?? []) as any[];
if (!parts.length) {
throw new Error("No parts in candidate content.");
}
const texts: string[] = [];
const images: { mimeType: string; bytes: number; dataBase64: string }[] = [];
for (const part of parts) {
if (typeof part.text === "string" && part.text.length) {
texts.push(part.text);
continue;
}
const dataB64 = part?.inlineData?.data as string | undefined;
if (dataB64) {
const mime = part?.inlineData?.mimeType || "image/png";
const buf = Buffer.from(dataB64, "base64");
images.push({ mimeType: mime, bytes: buf.length, dataBase64: dataB64 });
}
}
if (!images.length) {
log.log("No image part found in response.");
}
return { texts, images, model: modelName };
}
// Helper to handle mock tool responses
export async function invokeMockTool(
logger: PrefixLogger,
@ -528,6 +612,108 @@ export function createComposioTool(
});
}
// Helper to create a Gemini image generation tool
export function createGenerateImageTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowTool>,
projectId: string,
): Tool {
const { name, description, parameters } = config;
return tool({
name,
description,
strict: false,
parameters: {
type: 'object',
properties: parameters.properties,
required: parameters.required || [],
additionalProperties: true,
},
async execute(input: any) {
try {
const prompt: string = input?.prompt || '';
if (!prompt) {
return JSON.stringify({ error: "Missing required field: prompt" });
}
const modelName: string | undefined = input?.modelName;
const result = await invokeGenerateImageTool(
logger,
usageTracker,
prompt,
{ modelName }
);
// If S3 bucket configured, store in S3 under generated_images/<c>/<d>/<filename>
const s3Bucket = process.env.UPLOADS_S3_BUCKET || '';
if (s3Bucket) {
const s3Region = process.env.UPLOADS_AWS_REGION || 'us-east-1';
const s3 = new S3Client({
region: s3Region,
credentials: process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY ? {
accessKeyId: process.env.AWS_ACCESS_KEY_ID,
secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY,
} as any : undefined,
});
const images = await Promise.all(result.images.map(async (img) => {
const buf = Buffer.from(img.dataBase64, 'base64');
const ext = img.mimeType === 'image/jpeg' ? '.jpg' : img.mimeType === 'image/webp' ? '.webp' : '.png';
const base = `${projectId}-${Math.floor(Math.random() * 1e12).toString(36)}`;
const last2 = base.slice(-2).padStart(2, '0');
const dirA = last2.charAt(0);
const dirB = last2.charAt(1);
const filename = `${base}${ext}`;
const key = `generated_images/${dirA}/${dirB}/${filename}`;
await s3.send(new PutObjectCommand({
Bucket: s3Bucket,
Key: key,
Body: buf,
ContentType: img.mimeType,
}));
const url = `/api/generated-images/${dirA}/${dirB}/${filename}`;
return { mimeType: img.mimeType, bytes: buf.length, url };
}));
const payload = {
model: result.model,
texts: result.texts,
images,
storage: 's3',
} as any;
return JSON.stringify(payload);
}
// Otherwise, use in-memory temp cache URLs
const ttlSec = 10 * 60; // 10 minutes
const ttlMs = ttlSec * 1000;
const images = result.images.map(img => {
try {
const buf = Buffer.from(img.dataBase64, 'base64');
const id = tempBinaryCache.put(buf, img.mimeType, ttlMs);
const url = `/api/tmp-images/${id}`;
return { mimeType: img.mimeType, bytes: buf.length, url };
} catch {
return { mimeType: img.mimeType, bytes: img.bytes, url: null };
}
});
const payload = {
model: result.model,
texts: result.texts,
images,
storage: 'temp',
expiresInSec: ttlSec,
} as any;
return JSON.stringify(payload);
} catch (error) {
logger.log(`Error executing generate image tool ${name}:`, error);
return JSON.stringify({
error: "Tool execution failed!",
});
}
}
});
}
export function createTools(
logger: PrefixLogger,
usageTracker: UsageTracker,
@ -541,7 +727,7 @@ export function createTools(
toolLogger.log(`=== CREATING ${Object.keys(toolConfig).length} TOOLS ===`);
for (const [toolName, config] of Object.entries(toolConfig)) {
toolLogger.log(`creating tool: ${toolName} (type: ${config.mockTool ? 'mock' : config.isMcp ? 'mcp' : config.isComposio ? 'composio' : 'webhook'})`);
toolLogger.log(`creating tool: ${toolName} (type: ${config.mockTool ? 'mock' : config.isMcp ? 'mcp' : config.isComposio ? 'composio' : config.isGeminiImage ? 'gemini-image' : 'webhook'})`);
if (config.mockTool) {
tools[toolName] = createMockTool(logger, usageTracker, config);
@ -552,6 +738,9 @@ export function createTools(
} else if (config.isComposio) {
tools[toolName] = createComposioTool(logger, usageTracker, config, projectId);
toolLogger.log(`✓ created composio tool: ${toolName}`);
} else if (config.isGeminiImage) {
tools[toolName] = createGenerateImageTool(logger, usageTracker, config, projectId);
toolLogger.log(`✓ created gemini image tool: ${toolName}`);
} else if (config.isWebhook) {
tools[toolName] = createWebhookTool(logger, usageTracker, config, projectId);
toolLogger.log(`✓ created webhook tool: ${toolName} (fallback)`);
@ -563,4 +752,4 @@ export function createTools(
toolLogger.log(`=== TOOL CREATION COMPLETE ===`);
return tools;
}
}

View file

@ -0,0 +1,47 @@
import crypto from 'crypto';
type Entry = {
buf: Buffer;
mimeType: string;
expiresAt: number; // epoch ms
};
class TempBinaryCache {
private store = new Map<string, Entry>();
private cleanupInterval: NodeJS.Timeout | null = null;
constructor() {
this.startCleanup();
}
private startCleanup() {
if (this.cleanupInterval) return;
this.cleanupInterval = setInterval(() => {
const now = Date.now();
for (const [id, entry] of this.store.entries()) {
if (entry.expiresAt <= now) this.store.delete(id);
}
}, 60_000); // every minute
if (this.cleanupInterval.unref) this.cleanupInterval.unref();
}
put(buf: Buffer, mimeType: string, ttlMs: number = 10 * 60 * 1000): string {
const id = crypto.randomUUID();
const expiresAt = Date.now() + ttlMs;
this.store.set(id, { buf, mimeType, expiresAt });
return id;
}
get(id: string): { buf: Buffer; mimeType: string } | undefined {
const entry = this.store.get(id);
if (!entry) return undefined;
if (entry.expiresAt <= Date.now()) {
this.store.delete(id);
return undefined;
}
return { buf: entry.buf, mimeType: entry.mimeType };
}
}
export const tempBinaryCache = new TempBinaryCache();