mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
feat: added image gen support
This commit is contained in:
parent
459ffd2b78
commit
19e2857343
39 changed files with 3950 additions and 181 deletions
|
|
@ -8,6 +8,7 @@ Available tools:
|
|||
- search_knowledge_base: Search the user's personal knowledge base
|
||||
- search_surfsense_docs: Search Surfsense documentation for usage help
|
||||
- generate_podcast: Generate audio podcasts from content
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
- link_preview: Fetch rich previews for URLs
|
||||
- display_image: Display images in chat
|
||||
- scrape_webpage: Extract content from webpages
|
||||
|
|
@ -18,6 +19,7 @@ Available tools:
|
|||
# Registry exports
|
||||
# Tool factory exports (for direct use)
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .knowledge_base import (
|
||||
CONNECTOR_DESCRIPTIONS,
|
||||
create_search_knowledge_base_tool,
|
||||
|
|
@ -47,6 +49,7 @@ __all__ = [
|
|||
"build_tools",
|
||||
# Tool factories
|
||||
"create_display_image_tool",
|
||||
"create_generate_image_tool",
|
||||
"create_generate_podcast_tool",
|
||||
"create_link_preview_tool",
|
||||
"create_recall_memory_tool",
|
||||
|
|
|
|||
|
|
@ -82,14 +82,20 @@ def create_display_image_tool():
|
|||
|
||||
domain = extract_domain(src)
|
||||
|
||||
# Determine aspect ratio based on common image sources
|
||||
ratio = "16:9" # Default
|
||||
if "unsplash.com" in src or "pexels.com" in src:
|
||||
# Determine aspect ratio based on image source
|
||||
# AI-generated images should use "auto" to preserve their native ratio
|
||||
is_generated = "/image-generations/" in src
|
||||
if is_generated:
|
||||
ratio = "auto"
|
||||
domain = "ai-generated"
|
||||
elif "unsplash.com" in src or "pexels.com" in src:
|
||||
ratio = "16:9"
|
||||
elif (
|
||||
"imgur.com" in src or "github.com" in src or "githubusercontent.com" in src
|
||||
):
|
||||
ratio = "auto"
|
||||
else:
|
||||
ratio = "auto"
|
||||
|
||||
return {
|
||||
"id": image_id,
|
||||
|
|
|
|||
254
surfsense_backend/app/agents/new_chat/tools/generate_image.py
Normal file
254
surfsense_backend/app/agents/new_chat/tools/generate_image.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""
|
||||
Image generation tool for the SurfSense agent.
|
||||
|
||||
This module provides a tool that generates images using litellm.aimage_generation()
|
||||
and returns the result via the existing display_image tool format so the frontend
|
||||
renders the generated image inline in the chat.
|
||||
|
||||
Config resolution:
|
||||
1. Uses the search space's image_generation_config_id preference
|
||||
2. Falls back to Auto mode (router load balancing) if available
|
||||
3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from litellm import aimage_generation
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping (same as routes)
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image gen config by negative ID."""
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
|
||||
|
||||
def create_generate_image_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the generate_image tool.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID (for config resolution)
|
||||
db_session: Async database session
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def generate_image(
|
||||
prompt: str,
|
||||
size: str = "1024x1024",
|
||||
quality: str = "auto",
|
||||
n: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate an image from a text description using AI image models.
|
||||
|
||||
Use this tool when the user asks you to create, generate, draw, or make an image.
|
||||
The generated image will be displayed directly in the chat.
|
||||
|
||||
Args:
|
||||
prompt: A detailed text description of the image to generate.
|
||||
Be specific about subject, style, colors, composition, and mood.
|
||||
size: Image size. Options: "1024x1024" (square), "1536x1024" (landscape),
|
||||
"1024x1536" (portrait), "1792x1024" (wide). Default: "1024x1024"
|
||||
quality: Image quality. Options: "auto" (default), "high", "medium", "low".
|
||||
Default: "auto"
|
||||
n: Number of images to generate (1-4). Default: 1
|
||||
|
||||
Returns:
|
||||
A dictionary containing the generated image(s) for display in the chat.
|
||||
"""
|
||||
try:
|
||||
# Resolve the image generation config from the search space preference
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return {"error": "Search space not found"}
|
||||
|
||||
config_id = (
|
||||
search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
|
||||
# Build generation kwargs
|
||||
# NOTE: 'style' is intentionally excluded from gen_kwargs because
|
||||
# it is only supported by DALL-E 3 and causes errors with other
|
||||
# models (e.g. gpt-image-1 rejects it as an unknown parameter).
|
||||
# Since we can't predict which model auto-mode will route to,
|
||||
# it's safest to omit it.
|
||||
gen_kwargs: dict[str, Any] = {}
|
||||
if n is not None and n > 1:
|
||||
gen_kwargs["n"] = n
|
||||
if quality:
|
||||
gen_kwargs["quality"] = quality
|
||||
if size:
|
||||
gen_kwargs["size"] = size
|
||||
|
||||
# Call litellm based on config type
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
return {
|
||||
"error": "No image generation models configured. "
|
||||
"Please add an image model in Settings > Image Models."
|
||||
}
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=prompt, model="auto", **gen_kwargs
|
||||
)
|
||||
elif config_id < 0:
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
return {"error": f"Image generation config {config_id} not found"}
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg["model_name"],
|
||||
cfg.get("custom_provider"),
|
||||
)
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = user-created ImageGenerationConfig
|
||||
cfg_result = await db_session.execute(
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id
|
||||
)
|
||||
)
|
||||
db_cfg = cfg_result.scalars().first()
|
||||
if not db_cfg:
|
||||
return {"error": f"Image generation config {config_id} not found"}
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value,
|
||||
db_cfg.model_name,
|
||||
db_cfg.custom_provider,
|
||||
)
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
|
||||
# Parse the response and store in DB
|
||||
response_dict = (
|
||||
response.model_dump()
|
||||
if hasattr(response, "model_dump")
|
||||
else dict(response)
|
||||
)
|
||||
|
||||
# Generate a random access token for this image
|
||||
access_token = generate_image_token()
|
||||
|
||||
# Save to image_generations table for history
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=prompt,
|
||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||
n=n,
|
||||
quality=quality,
|
||||
size=size,
|
||||
image_generation_config_id=config_id,
|
||||
response_data=response_dict,
|
||||
search_space_id=search_space_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
db_session.add(db_image_gen)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(db_image_gen)
|
||||
|
||||
# Extract image URLs from response
|
||||
images = response_dict.get("data", [])
|
||||
if not images:
|
||||
return {"error": "No images were generated"}
|
||||
|
||||
first_image = images[0]
|
||||
revised_prompt = first_image.get("revised_prompt", prompt)
|
||||
|
||||
# Resolve image URL:
|
||||
# - If the API returned a URL, use it directly.
|
||||
# - If the API returned b64_json (e.g. gpt-image-1), serve the
|
||||
# image through our backend endpoint to avoid bloating the
|
||||
# LLM context with megabytes of base64 data.
|
||||
if first_image.get("url"):
|
||||
image_url = first_image["url"]
|
||||
elif first_image.get("b64_json"):
|
||||
backend_url = config.BACKEND_URL or "http://localhost:8000"
|
||||
image_url = (
|
||||
f"{backend_url}/api/v1/image-generations/"
|
||||
f"{db_image_gen.id}/image?token={access_token}"
|
||||
)
|
||||
else:
|
||||
return {"error": "No displayable image data in the response"}
|
||||
|
||||
return {
|
||||
"src": image_url,
|
||||
"alt": revised_prompt or prompt,
|
||||
"title": "Generated Image",
|
||||
"description": revised_prompt if revised_prompt != prompt else None,
|
||||
"generated": True,
|
||||
"prompt": prompt,
|
||||
"image_count": len(images),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Image generation failed in tool")
|
||||
return {
|
||||
"error": f"Image generation failed: {e!s}",
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
return generate_image
|
||||
|
|
@ -44,6 +44,7 @@ from typing import Any
|
|||
from langchain_core.tools import BaseTool
|
||||
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
from .link_preview import create_link_preview_tool
|
||||
from .mcp_tool import load_mcp_tools
|
||||
|
|
@ -125,6 +126,16 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
factory=lambda deps: create_display_image_tool(),
|
||||
requires=[],
|
||||
),
|
||||
# Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.)
|
||||
ToolDefinition(
|
||||
name="generate_image",
|
||||
description="Generate images from text descriptions using AI image models",
|
||||
factory=lambda deps: create_generate_image_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
),
|
||||
requires=["search_space_id", "db_session"],
|
||||
),
|
||||
# Web scraping tool - extracts content from webpages
|
||||
ToolDefinition(
|
||||
name="scrape_webpage",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue