From 527afff60161d51eb6048ac6f9f61c0e67656fb9 Mon Sep 17 00:00:00 2001 From: arkml Date: Fri, 19 Sep 2025 15:26:13 +0530 Subject: [PATCH] make generate image accept an input image --- .../lib/agents-runtime/agent-tools.ts | 146 +++++++++++++++++- 1 file changed, 142 insertions(+), 4 deletions(-) diff --git a/apps/rowboat/src/application/lib/agents-runtime/agent-tools.ts b/apps/rowboat/src/application/lib/agents-runtime/agent-tools.ts index 545ce677..c4b9590a 100644 --- a/apps/rowboat/src/application/lib/agents-runtime/agent-tools.ts +++ b/apps/rowboat/src/application/lib/agents-runtime/agent-tools.ts @@ -8,7 +8,7 @@ import { SignJWT } from "jose"; import crypto from "crypto"; import { GoogleGenerativeAI } from "@google/generative-ai"; import { tempBinaryCache } from "@/src/application/services/temp-binary-cache"; -import { S3Client, PutObjectCommand } from "@aws-sdk/client-s3"; +import { S3Client, PutObjectCommand, GetObjectCommand, HeadObjectCommand } from "@aws-sdk/client-s3"; // Internal dependencies import { embeddingModel } from "@/app/lib/embedding"; @@ -44,6 +44,7 @@ export async function invokeGenerateImageTool( prompt: string, options?: { modelName?: string; + inputImageUrl?: string; } ): Promise<{ texts: string[]; @@ -62,7 +63,140 @@ export async function invokeGenerateImageTool( const model = client.getGenerativeModel({ model: modelName }); log.log(`Generating image with model: ${modelName}`); - const result = await model.generateContent(prompt); + + let result: any; + const inputImageUrl = options?.inputImageUrl; + if (inputImageUrl) { + try { + // Resolve the image into inlineData for Gemini + let imageBuf: Buffer | null = null; + let imageMime: string = 'image/png'; + + if (inputImageUrl.startsWith('/api/tmp-images/')) { + const id = inputImageUrl.split('/api/tmp-images/')[1]; + const entry = tempBinaryCache.get(id); + if (entry) { + imageBuf = entry.buf; + imageMime = entry.mimeType || imageMime; + } + } else if (inputImageUrl.startsWith('/api/uploaded-images/')) { + const bucket = process.env.RAG_UPLOADS_S3_BUCKET || ''; + if (bucket) { + const region = process.env.RAG_UPLOADS_S3_REGION || 'us-east-1'; + const s3 = new S3Client({ + region, + credentials: process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY ? { + accessKeyId: process.env.AWS_ACCESS_KEY_ID, + secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY, + } as any : undefined, + }); + const id = inputImageUrl.split('/api/uploaded-images/')[1]; + const last2 = id.slice(-2).padStart(2, '0'); + const dirA = last2.charAt(0); + const dirB = last2.charAt(1); + const baseKey = `uploaded_images/${dirA}/${dirB}/${id}`; + const exts = ['.png', '.jpg', '.webp', '.bin']; + let foundExt: string | null = null; + for (const ext of exts) { + try { + await s3.send(new HeadObjectCommand({ Bucket: bucket, Key: `${baseKey}${ext}` })); + foundExt = ext; break; + } catch {} + } + if (foundExt) { + const key = `${baseKey}${foundExt}`; + const resp = await s3.send(new GetObjectCommand({ Bucket: bucket, Key: key })); + const chunks: Buffer[] = []; + const body = resp.Body as any; + const nodeStream = typeof body?.pipe === 'function' ? body : undefined; + if (nodeStream) { + imageMime = resp.ContentType || imageMime; + await new Promise((resolve, reject) => { + nodeStream.on('data', (c: Buffer) => chunks.push(Buffer.isBuffer(c) ? c : Buffer.from(c))); + nodeStream.on('end', () => resolve()); + nodeStream.on('error', reject); + }); + imageBuf = Buffer.concat(chunks); + } + } + } + } else if (inputImageUrl.startsWith('/api/generated-images/')) { + const bucket = process.env.RAG_UPLOADS_S3_BUCKET || ''; + if (bucket) { + const region = process.env.RAG_UPLOADS_S3_REGION || 'us-east-1'; + const s3 = new S3Client({ + region, + credentials: process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY ? { + accessKeyId: process.env.AWS_ACCESS_KEY_ID, + secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY, + } as any : undefined, + }); + const id = inputImageUrl.split('/api/generated-images/')[1]; + const last2 = id.slice(-2).padStart(2, '0'); + const dirA = last2.charAt(0); + const dirB = last2.charAt(1); + const baseKey = `generated_images/${dirA}/${dirB}/${id}`; + const exts = ['.png', '.jpg', '.webp']; + let foundExt: string | null = null; + for (const ext of exts) { + try { + await s3.send(new HeadObjectCommand({ Bucket: bucket, Key: `${baseKey}${ext}` })); + foundExt = ext; break; + } catch {} + } + if (foundExt) { + const key = `${baseKey}${foundExt}`; + const resp = await s3.send(new GetObjectCommand({ Bucket: bucket, Key: key })); + const chunks: Buffer[] = []; + const body = resp.Body as any; + const nodeStream = typeof body?.pipe === 'function' ? body : undefined; + if (nodeStream) { + imageMime = resp.ContentType || imageMime; + await new Promise((resolve, reject) => { + nodeStream.on('data', (c: Buffer) => chunks.push(Buffer.isBuffer(c) ? c : Buffer.from(c))); + nodeStream.on('end', () => resolve()); + nodeStream.on('error', reject); + }); + imageBuf = Buffer.concat(chunks); + } + } + } + } else if (inputImageUrl.startsWith('data:')) { + // data URL + const m = inputImageUrl.match(/^data:([^;]+);base64,(.*)$/); + if (m) { + imageMime = m[1]; + imageBuf = Buffer.from(m[2], 'base64'); + } + } else if (/^https?:\/\//.test(inputImageUrl)) { + // Best-effort network fetch (may fail if egress restricted) + try { + const resp = await fetch(inputImageUrl); + const ab = await resp.arrayBuffer(); + imageBuf = Buffer.from(ab); + imageMime = resp.headers.get('content-type') || imageMime; + } catch { + // ignore + } + } + + if (imageBuf) { + const parts: any[] = [ + { inlineData: { data: imageBuf.toString('base64'), mimeType: imageMime } }, + prompt, + ]; + result = await model.generateContent(parts as any); + } else { + // Fallback to text-only + result = await model.generateContent(prompt); + } + } catch (e) { + log.log('Falling back to text-only generation due to input image error'); + result = await model.generateContent(prompt); + } + } else { + result = await model.generateContent(prompt); + } const response = result.response as any; // Track usage if available @@ -627,7 +761,10 @@ export function createGenerateImageTool( strict: false, parameters: { type: 'object', - properties: parameters.properties, + properties: { + ...parameters.properties, + input_image_url: { type: 'string', description: 'Optional URL of an input image to condition generation.' }, + }, required: parameters.required || [], additionalProperties: true, }, @@ -638,11 +775,12 @@ export function createGenerateImageTool( return JSON.stringify({ error: "Missing required field: prompt" }); } const modelName: string | undefined = input?.modelName; + const inputImageUrl: string | undefined = input?.input_image_url; const result = await invokeGenerateImageTool( logger, usageTracker, prompt, - { modelName } + { modelName, inputImageUrl } ); // If S3 bucket configured, store in S3 under generated_images/// const s3Bucket = process.env.RAG_UPLOADS_S3_BUCKET || '';