diff --git a/surfsense_backend/alembic/versions/93_add_image_generations_table.py b/surfsense_backend/alembic/versions/93_add_image_generations_table.py new file mode 100644 index 000000000..151208229 --- /dev/null +++ b/surfsense_backend/alembic/versions/93_add_image_generations_table.py @@ -0,0 +1,292 @@ +"""Add image generation tables and search space preference + +Revision ID: 93 +Revises: 92 + +Changes: +1. Create image_generation_configs table (user-created image model configs) +2. Create image_generations table (stores generation requests/results) +3. Add image_generation_config_id column to searchspaces table +4. Add image generation permissions to existing system roles +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB, UUID + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "93" +down_revision: str | None = "92" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + connection = op.get_bind() + + # 1. Create imagegenprovider enum type if it doesn't exist + connection.execute( + sa.text( + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'imagegenprovider') THEN + CREATE TYPE imagegenprovider AS ENUM ( + 'OPENAI', 'AZURE_OPENAI', 'GOOGLE', 'VERTEX_AI', 'BEDROCK', + 'RECRAFT', 'OPENROUTER', 'XINFERENCE', 'NSCALE' + ); + END IF; + END + $$; + """ + ) + ) + + # 2. 
Create image_generation_configs table (uses imagegenprovider enum) + result = connection.execute( + sa.text( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'image_generation_configs')" + ) + ) + if not result.scalar(): + op.create_table( + "image_generation_configs", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("name", sa.String(100), nullable=False), + sa.Column("description", sa.String(500), nullable=True), + sa.Column( + "provider", + sa.Enum( + "OPENAI", "AZURE_OPENAI", "GOOGLE", "VERTEX_AI", "BEDROCK", + "RECRAFT", "OPENROUTER", "XINFERENCE", "NSCALE", + name="imagegenprovider", + create_type=False, + ), + nullable=False, + ), + sa.Column("custom_provider", sa.String(100), nullable=True), + sa.Column("model_name", sa.String(100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(500), nullable=True), + sa.Column("api_version", sa.String(50), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_image_generation_configs_name " + "ON image_generation_configs (name)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_image_generation_configs_search_space_id " + "ON image_generation_configs (search_space_id)" + ) + + # 3. 
Create image_generations table + result = connection.execute( + sa.text( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'image_generations')" + ) + ) + if not result.scalar(): + op.create_table( + "image_generations", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("prompt", sa.Text(), nullable=False), + sa.Column("model", sa.String(200), nullable=True), + sa.Column("n", sa.Integer(), nullable=True), + sa.Column("quality", sa.String(50), nullable=True), + sa.Column("size", sa.String(50), nullable=True), + sa.Column("style", sa.String(50), nullable=True), + sa.Column("response_format", sa.String(50), nullable=True), + sa.Column("image_generation_config_id", sa.Integer(), nullable=True), + sa.Column("response_data", JSONB(), nullable=True), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("created_by_id", UUID(as_uuid=True), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["created_by_id"], ["user.id"], ondelete="SET NULL" + ), + ) + + op.execute( + "CREATE INDEX IF NOT EXISTS ix_image_generations_search_space_id " + "ON image_generations (search_space_id)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_image_generations_created_by_id " + "ON image_generations (created_by_id)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_image_generations_created_at " + "ON image_generations (created_at)" + ) + + # 4. 
Add image_generation_config_id column to searchspaces + result = connection.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' + AND column_name = 'image_generation_config_id' + ) + """ + ) + ) + if not result.scalar(): + op.add_column( + "searchspaces", + sa.Column( + "image_generation_config_id", + sa.Integer(), + nullable=True, + server_default="0", + ), + ) + + # Drop old column name if it exists (from earlier version of this migration) + result = connection.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' + AND column_name = 'image_generation_llm_id' + ) + """ + ) + ) + if result.scalar(): + op.drop_column("searchspaces", "image_generation_llm_id") + + # Drop old column name on image_generations if it exists + result = connection.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'image_generations' + AND column_name = 'llm_config_id' + ) + """ + ) + ) + if result.scalar(): + op.drop_column("image_generations", "llm_config_id") + + # Drop old api_version column on image_generations if it exists + result = connection.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'image_generations' + AND column_name = 'api_version' + ) + """ + ) + ) + if result.scalar(): + op.drop_column("image_generations", "api_version") + + # 5. 
Add image generation permissions to existing system roles + connection.execute( + sa.text( + """ + UPDATE search_space_roles + SET permissions = array_cat( + permissions, + ARRAY['image_generations:create', 'image_generations:read'] + ) + WHERE is_system_role = true + AND name = 'Editor' + AND NOT ('image_generations:create' = ANY(permissions)) + """ + ) + ) + connection.execute( + sa.text( + """ + UPDATE search_space_roles + SET permissions = array_cat( + permissions, + ARRAY['image_generations:read'] + ) + WHERE is_system_role = true + AND name = 'Viewer' + AND NOT ('image_generations:read' = ANY(permissions)) + """ + ) + ) + + +def downgrade() -> None: + connection = op.get_bind() + + # Remove permissions + connection.execute( + sa.text( + """ + UPDATE search_space_roles + SET permissions = array_remove( + array_remove( + array_remove(permissions, 'image_generations:create'), + 'image_generations:read' + ), + 'image_generations:delete' + ) + WHERE is_system_role = true + """ + ) + ) + + # Remove image_generation_config_id from searchspaces + result = connection.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' + AND column_name = 'image_generation_config_id' + ) + """ + ) + ) + if result.scalar(): + op.drop_column("searchspaces", "image_generation_config_id") + + # Drop indexes and tables + op.execute("DROP INDEX IF EXISTS ix_image_generations_created_at") + op.execute("DROP INDEX IF EXISTS ix_image_generations_created_by_id") + op.execute("DROP INDEX IF EXISTS ix_image_generations_search_space_id") + op.execute("DROP TABLE IF EXISTS image_generations") + + op.execute("DROP INDEX IF EXISTS ix_image_generation_configs_search_space_id") + op.execute("DROP INDEX IF EXISTS ix_image_generation_configs_name") + op.execute("DROP TABLE IF EXISTS image_generation_configs") + + # Drop the imagegenprovider enum type + op.execute("DROP TYPE IF EXISTS imagegenprovider") diff --git 
a/surfsense_backend/alembic/versions/94_add_access_token_to_image_generations.py b/surfsense_backend/alembic/versions/94_add_access_token_to_image_generations.py new file mode 100644 index 000000000..09bea2c19 --- /dev/null +++ b/surfsense_backend/alembic/versions/94_add_access_token_to_image_generations.py @@ -0,0 +1,39 @@ +"""Add access_token column to image_generations + +Revision ID: 94 +Revises: 93 + +Adds an indexed access_token column to the image_generations table. +This token is stored per-record so that image serving URLs survive +SECRET_KEY rotation. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "94" +down_revision: str | None = "93" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # Add access_token column (nullable so existing rows are unaffected) + op.add_column( + "image_generations", + sa.Column("access_token", sa.String(64), nullable=True), + ) + op.create_index( + "ix_image_generations_access_token", + "image_generations", + ["access_token"], + ) + + +def downgrade() -> None: + op.drop_index("ix_image_generations_access_token", table_name="image_generations") + op.drop_column("image_generations", "access_token") diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 9c383c308..9da6ea3c2 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -133,6 +133,7 @@ async def create_surfsense_deep_agent( The agent comes with built-in tools that can be configured: - search_knowledge_base: Search the user's personal knowledge base - generate_podcast: Generate audio podcasts from content + - generate_image: Generate images from text descriptions using AI models - link_preview: Fetch rich previews for URLs - 
display_image: Display images in chat - scrape_webpage: Extract content from webpages diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index fdb80acb9..38bae230d 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -83,6 +83,7 @@ You have access to the following tools: * Showing an image from a URL the user explicitly mentioned in their message * Displaying images found in scraped webpage content (from scrape_webpage tool) * Showing a publicly accessible diagram or chart from a known URL + * Displaying an AI-generated image after calling the generate_image tool (ALWAYS required) CRITICAL - NEVER USE THIS TOOL FOR USER-UPLOADED ATTACHMENTS: When a user uploads/attaches an image file to their message: @@ -100,7 +101,23 @@ You have access to the following tools: - Returns: An image card with the image, title, and description - The image will automatically be displayed in the chat. -5. scrape_webpage: Scrape and extract the main content from a webpage. +5. generate_image: Generate images from text descriptions using AI image models. + - Use this when the user asks you to create, generate, draw, design, or make an image. + - Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork" + - Args: + - prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood. + - size: Image size. Options: "1024x1024" (square, default), "1536x1024" (landscape), "1024x1536" (portrait), "1792x1024" (wide) + - quality: Image quality. Options: "auto" (default), "high", "medium", "low" + - n: Number of images to generate (1-4, default: 1) + - Returns: A dictionary with the generated image URL in the "src" field, along with metadata. 
+ - CRITICAL: After calling generate_image, you MUST call `display_image` with the returned "src" URL + to actually show the image in the chat. The generate_image tool only generates the image and returns + the URL — it does NOT display anything. You must always follow up with display_image. + - IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim - + expand and improve the prompt with specific details about style, lighting, composition, and mood. + - If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details. + +6. scrape_webpage: Scrape and extract the main content from a webpage. - Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage. - IMPORTANT: This is different from link_preview: * link_preview: Only fetches metadata (title, description, thumbnail) for display @@ -123,7 +140,7 @@ You have access to the following tools: * Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content. * Don't show every image - just the most relevant 1-3 images that enhance understanding. -6. save_memory: Save facts, preferences, or context about the user for personalized responses. +7. save_memory: Save facts, preferences, or context about the user for personalized responses. - Use this when the user explicitly or implicitly shares information worth remembering. - Trigger scenarios: * User says "remember this", "keep this in mind", "note that", or similar @@ -146,7 +163,7 @@ You have access to the following tools: - IMPORTANT: Only save information that would be genuinely useful for future conversations. Don't save trivial or temporary information. -7. recall_memory: Retrieve relevant memories about the user for personalized responses. +8. recall_memory: Retrieve relevant memories about the user for personalized responses. - Use this to access stored information about the user. 
- Trigger scenarios: * You need user context to give a better, more personalized answer @@ -281,6 +298,22 @@ You have access to the following tools: - Then, if the content contains useful diagrams/images like `![Neural Network Diagram](https://example.com/nn-diagram.png)`: - Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")` - Then provide your explanation, referencing the displayed image + +- User: "Generate an image of a cat" + - Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere", size="1024x1024", quality="auto")` + - Step 2: Use the returned "src" URL to display it: `display_image(src="", alt="A fluffy orange tabby cat on a windowsill", title="Generated Image")` + +- User: "Create a landscape painting of mountains" + - Step 1: `generate_image(prompt="Majestic snow-capped mountain range at sunset, dramatic orange and purple sky, alpine meadow with wildflowers in the foreground, oil painting style with visible brushstrokes, inspired by the Hudson River School art movement", size="1536x1024", quality="high")` + - Step 2: `display_image(src="", alt="Mountain landscape painting", title="Generated Image")` + +- User: "Draw me a logo for a coffee shop called Bean Dream" + - Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding", size="1024x1024", quality="high")` + - Step 2: `display_image(src="", alt="Bean Dream coffee shop logo", title="Generated Image")` + +- User: "Make a wide banner image for my blog about AI" + - Step 1: `generate_image(prompt="Wide banner illustration for an AI technology blog, featuring abstract neural network 
patterns, glowing blue and purple connections, modern futuristic aesthetic, digital art style, clean and professional", size="1792x1024", quality="high")` + - Step 2: `display_image(src="", alt="AI blog banner", title="Generated Image")` """ diff --git a/surfsense_backend/app/agents/new_chat/tools/__init__.py b/surfsense_backend/app/agents/new_chat/tools/__init__.py index 9e1a4f19c..0a11951f0 100644 --- a/surfsense_backend/app/agents/new_chat/tools/__init__.py +++ b/surfsense_backend/app/agents/new_chat/tools/__init__.py @@ -8,6 +8,7 @@ Available tools: - search_knowledge_base: Search the user's personal knowledge base - search_surfsense_docs: Search Surfsense documentation for usage help - generate_podcast: Generate audio podcasts from content +- generate_image: Generate images from text descriptions using AI models - link_preview: Fetch rich previews for URLs - display_image: Display images in chat - scrape_webpage: Extract content from webpages @@ -18,6 +19,7 @@ Available tools: # Registry exports # Tool factory exports (for direct use) from .display_image import create_display_image_tool +from .generate_image import create_generate_image_tool from .knowledge_base import ( CONNECTOR_DESCRIPTIONS, create_search_knowledge_base_tool, @@ -47,6 +49,7 @@ __all__ = [ "build_tools", # Tool factories "create_display_image_tool", + "create_generate_image_tool", "create_generate_podcast_tool", "create_link_preview_tool", "create_recall_memory_tool", diff --git a/surfsense_backend/app/agents/new_chat/tools/display_image.py b/surfsense_backend/app/agents/new_chat/tools/display_image.py index 5eb846063..4424cc0d3 100644 --- a/surfsense_backend/app/agents/new_chat/tools/display_image.py +++ b/surfsense_backend/app/agents/new_chat/tools/display_image.py @@ -82,14 +82,20 @@ def create_display_image_tool(): domain = extract_domain(src) - # Determine aspect ratio based on common image sources - ratio = "16:9" # Default - if "unsplash.com" in src or "pexels.com" in src: + # 
Determine aspect ratio based on image source + # AI-generated images should use "auto" to preserve their native ratio + is_generated = "/image-generations/" in src + if is_generated: + ratio = "auto" + domain = "ai-generated" + elif "unsplash.com" in src or "pexels.com" in src: ratio = "16:9" elif ( "imgur.com" in src or "github.com" in src or "githubusercontent.com" in src ): ratio = "auto" + else: + ratio = "auto" return { "id": image_id, diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py new file mode 100644 index 000000000..091fb122f --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py @@ -0,0 +1,254 @@ +""" +Image generation tool for the SurfSense agent. + +This module provides a tool that generates images using litellm.aimage_generation() +and returns the result via the existing display_image tool format so the frontend +renders the generated image inline in the chat. + +Config resolution: +1. Uses the search space's image_generation_config_id preference +2. Falls back to Auto mode (router load balancing) if available +3. 
Supports global YAML configs (negative IDs) and user DB configs (positive IDs) +""" + +import logging +from typing import Any + +from langchain_core.tools import tool +from litellm import aimage_generation +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace +from app.utils.signed_image_urls import generate_image_token +from app.services.image_gen_router_service import ( + IMAGE_GEN_AUTO_MODE_ID, + ImageGenRouterService, + is_image_gen_auto_mode, +) + +logger = logging.getLogger(__name__) + +# Provider mapping (same as routes) +_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} + + +def _build_model_string( + provider: str, model_name: str, custom_provider: str | None +) -> str: + if custom_provider: + return f"{custom_provider}/{model_name}" + prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) + return f"{prefix}/{model_name}" + + +def _get_global_image_gen_config(config_id: int) -> dict | None: + """Get a global image gen config by negative ID.""" + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return cfg + return None + + +def create_generate_image_tool( + search_space_id: int, + db_session: AsyncSession, +): + """ + Factory function to create the generate_image tool. + + Args: + search_space_id: The search space ID (for config resolution) + db_session: Async database session + """ + + @tool + async def generate_image( + prompt: str, + size: str = "1024x1024", + quality: str = "auto", + n: int = 1, + ) -> dict[str, Any]: + """ + Generate an image from a text description using AI image models. + + Use this tool when the user asks you to create, generate, draw, or make an image. 
+ The generated image will be displayed directly in the chat. + + Args: + prompt: A detailed text description of the image to generate. + Be specific about subject, style, colors, composition, and mood. + size: Image size. Options: "1024x1024" (square), "1536x1024" (landscape), + "1024x1536" (portrait), "1792x1024" (wide). Default: "1024x1024" + quality: Image quality. Options: "auto" (default), "high", "medium", "low". + Default: "auto" + n: Number of images to generate (1-4). Default: 1 + + Returns: + A dictionary containing the generated image(s) for display in the chat. + """ + try: + # Resolve the image generation config from the search space preference + result = await db_session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if not search_space: + return {"error": "Search space not found"} + + config_id = ( + search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + ) + + # Build generation kwargs + # NOTE: 'style' is intentionally excluded from gen_kwargs because + # it is only supported by DALL-E 3 and causes errors with other + # models (e.g. gpt-image-1 rejects it as an unknown parameter). + # Since we can't predict which model auto-mode will route to, + # it's safest to omit it. + gen_kwargs: dict[str, Any] = {} + if n is not None and n > 1: + gen_kwargs["n"] = n + if quality: + gen_kwargs["quality"] = quality + if size: + gen_kwargs["size"] = size + + # Call litellm based on config type + if is_image_gen_auto_mode(config_id): + if not ImageGenRouterService.is_initialized(): + return { + "error": "No image generation models configured. " + "Please add an image model in Settings > Image Models." 
+ } + response = await ImageGenRouterService.aimage_generation( + prompt=prompt, model="auto", **gen_kwargs + ) + elif config_id < 0: + cfg = _get_global_image_gen_config(config_id) + if not cfg: + return {"error": f"Image generation config {config_id} not found"} + + model_string = _build_model_string( + cfg.get("provider", ""), + cfg["model_name"], + cfg.get("custom_provider"), + ) + gen_kwargs["api_key"] = cfg.get("api_key") + if cfg.get("api_base"): + gen_kwargs["api_base"] = cfg["api_base"] + if cfg.get("api_version"): + gen_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + gen_kwargs.update(cfg["litellm_params"]) + + response = await aimage_generation( + prompt=prompt, model=model_string, **gen_kwargs + ) + else: + # Positive ID = user-created ImageGenerationConfig + cfg_result = await db_session.execute( + select(ImageGenerationConfig).filter( + ImageGenerationConfig.id == config_id + ) + ) + db_cfg = cfg_result.scalars().first() + if not db_cfg: + return {"error": f"Image generation config {config_id} not found"} + + model_string = _build_model_string( + db_cfg.provider.value, + db_cfg.model_name, + db_cfg.custom_provider, + ) + gen_kwargs["api_key"] = db_cfg.api_key + if db_cfg.api_base: + gen_kwargs["api_base"] = db_cfg.api_base + if db_cfg.api_version: + gen_kwargs["api_version"] = db_cfg.api_version + if db_cfg.litellm_params: + gen_kwargs.update(db_cfg.litellm_params) + + response = await aimage_generation( + prompt=prompt, model=model_string, **gen_kwargs + ) + + # Parse the response and store in DB + response_dict = ( + response.model_dump() + if hasattr(response, "model_dump") + else dict(response) + ) + + # Generate a random access token for this image + access_token = generate_image_token() + + # Save to image_generations table for history + db_image_gen = ImageGeneration( + prompt=prompt, + model=getattr(response, "_hidden_params", {}).get("model"), + n=n, + quality=quality, + size=size, + 
image_generation_config_id=config_id, + response_data=response_dict, + search_space_id=search_space_id, + access_token=access_token, + ) + db_session.add(db_image_gen) + await db_session.commit() + await db_session.refresh(db_image_gen) + + # Extract image URLs from response + images = response_dict.get("data", []) + if not images: + return {"error": "No images were generated"} + + first_image = images[0] + revised_prompt = first_image.get("revised_prompt", prompt) + + # Resolve image URL: + # - If the API returned a URL, use it directly. + # - If the API returned b64_json (e.g. gpt-image-1), serve the + # image through our backend endpoint to avoid bloating the + # LLM context with megabytes of base64 data. + if first_image.get("url"): + image_url = first_image["url"] + elif first_image.get("b64_json"): + backend_url = config.BACKEND_URL or "http://localhost:8000" + image_url = ( + f"{backend_url}/api/v1/image-generations/" + f"{db_image_gen.id}/image?token={access_token}" + ) + else: + return {"error": "No displayable image data in the response"} + + return { + "src": image_url, + "alt": revised_prompt or prompt, + "title": "Generated Image", + "description": revised_prompt if revised_prompt != prompt else None, + "generated": True, + "prompt": prompt, + "image_count": len(images), + } + + except Exception as e: + logger.exception("Image generation failed in tool") + return { + "error": f"Image generation failed: {e!s}", + "prompt": prompt, + } + + return generate_image diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index c65445419..2cf43c973 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -44,6 +44,7 @@ from typing import Any from langchain_core.tools import BaseTool from .display_image import create_display_image_tool +from .generate_image import create_generate_image_tool from .knowledge_base 
import create_search_knowledge_base_tool from .link_preview import create_link_preview_tool from .mcp_tool import load_mcp_tools @@ -125,6 +126,16 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ factory=lambda deps: create_display_image_tool(), requires=[], ), + # Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.) + ToolDefinition( + name="generate_image", + description="Generate images from text descriptions using AI image models", + factory=lambda deps: create_generate_image_tool( + search_space_id=deps["search_space_id"], + db_session=deps["db_session"], + ), + requires=["search_space_id", "db_session"], + ), # Web scraping tool - extracts content from webpages ToolDefinition( name="scrape_webpage", diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 63da4e8ad..0f619097e 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -9,7 +9,7 @@ from app.agents.new_chat.checkpointer import ( close_checkpointer, setup_checkpointer_tables, ) -from app.config import config, initialize_llm_router +from app.config import config, initialize_image_gen_router, initialize_llm_router from app.db import User, create_db_and_tables, get_async_session from app.routes import router as crud_router from app.routes.auth_routes import router as auth_router @@ -26,6 +26,8 @@ async def lifespan(app: FastAPI): await setup_checkpointer_tables() # Initialize LLM Router for Auto mode load balancing initialize_llm_router() + # Initialize Image Generation Router for Auto mode load balancing + initialize_image_gen_router() # Seed Surfsense documentation await seed_surfsense_docs() yield diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 74b21fbf0..af406eab7 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -13,14 +13,15 @@ load_dotenv() @worker_process_init.connect def init_worker(**kwargs): - """Initialize the LLM Router when a Celery 
def load_global_image_gen_configs():
    """
    Load global image generation configurations from YAML file.

    Returns:
        list: List of global image generation config dictionaries, or empty list
    """
    global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"

    if not global_config_file.exists():
        return []

    try:
        with open(global_config_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # yaml.safe_load returns None for an empty document, and the key may be
        # present but explicitly null — normalize both cases to an empty list.
        return (data or {}).get("global_image_generation_configs") or []
    except Exception as e:
        print(f"Warning: Failed to load global image generation configs: {e}")
        return []


def load_image_gen_router_settings():
    """
    Load router settings for image generation Auto mode from YAML file.

    Returns:
        dict: Router settings (defaults merged with any YAML overrides)
    """
    default_settings = {
        "routing_strategy": "usage-based-routing",
        "num_retries": 3,
        "allowed_fails": 3,
        "cooldown_time": 60,
    }

    global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"

    if not global_config_file.exists():
        return default_settings

    try:
        with open(global_config_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # Guard against an empty file (None) or an explicit null value for the
        # settings key — `{**default_settings, **None}` would raise TypeError.
        settings = (data or {}).get("image_generation_router_settings") or {}
        return {**default_settings, **settings}
    except Exception as e:
        print(f"Warning: Failed to load image generation router settings: {e}")
        return default_settings


def initialize_image_gen_router():
    """
    Initialize the Image Generation Router service for Auto mode.
    This should be called during application startup (and per Celery worker).
    """
    image_gen_configs = load_global_image_gen_configs()
    router_settings = load_image_gen_router_settings()

    if not image_gen_configs:
        print(
            "Info: No global image generation configs found, "
            "Image Generation Auto mode will not be available"
        )
        return

    try:
        # Imported lazily so config loading never hard-depends on the service.
        from app.services.image_gen_router_service import ImageGenRouterService

        ImageGenRouterService.initialize(image_gen_configs, router_settings)
        print(
            f"Info: Image Generation Router initialized with {len(image_gen_configs)} models "
            f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})"
        )
    except Exception as e:
        print(f"Warning: Failed to initialize Image Generation Router: {e}")
+# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock, +# Recraft, OpenRouter, Xinference, Nscale +# +# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs. + +# Router Settings for Image Generation Auto Mode +image_generation_router_settings: + routing_strategy: "usage-based-routing" + num_retries: 3 + allowed_fails: 3 + cooldown_time: 60 + +global_image_generation_configs: + # Example: OpenAI DALL-E 3 + - id: -1 + name: "Global DALL-E 3" + description: "OpenAI's DALL-E 3 for high-quality image generation" + provider: "OPENAI" + model_name: "dall-e-3" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 50 + tpm: 100000 + litellm_params: {} + + # Example: OpenAI GPT Image 1 + - id: -2 + name: "Global GPT Image 1" + description: "OpenAI's GPT Image 1 model" + provider: "OPENAI" + model_name: "gpt-image-1" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 50 + tpm: 100000 + litellm_params: {} + + # Example: Azure OpenAI DALL-E 3 + - id: -3 + name: "Global Azure DALL-E 3" + description: "Azure-hosted DALL-E 3 deployment" + provider: "AZURE_OPENAI" + model_name: "azure/dall-e-3-deployment" + api_key: "your-azure-api-key-here" + api_base: "https://your-resource.openai.azure.com" + api_version: "2024-02-15-preview" + rpm: 50 + tpm: 100000 + litellm_params: + base_model: "dall-e-3" + + # Example: OpenRouter Gemini Image Generation + # - id: -4 + # name: "Global Gemini Image Gen" + # description: "Google Gemini image generation via OpenRouter" + # provider: "OPENROUTER" + # model_name: "google/gemini-2.5-flash-image" + # api_key: "your-openrouter-api-key-here" + # api_base: "" + # rpm: 30 + # tpm: 50000 + # litellm_params: {} + # Notes: # - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing # - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB) @@ -202,3 +269,10 @@ global_llm_configs: # - model_name format: "azure/" # - api_version: 
class ImageGenProvider(str, Enum):
    """Image generation providers supported by LiteLLM.

    A deliberate subset of the LLM providers — only the ones for which
    litellm exposes image generation support.
    See: https://docs.litellm.ai/docs/image_generation#supported-providers
    """

    OPENAI = "OPENAI"
    AZURE_OPENAI = "AZURE_OPENAI"
    # Google AI Studio (Gemini API)
    GOOGLE = "GOOGLE"
    VERTEX_AI = "VERTEX_AI"
    # AWS Bedrock
    BEDROCK = "BEDROCK"
    RECRAFT = "RECRAFT"
    OPENROUTER = "OPENROUTER"
    XINFERENCE = "XINFERENCE"
    NSCALE = "NSCALE"
+ """ + + __tablename__ = "image_generation_configs" + + name = Column(String(100), nullable=False, index=True) + description = Column(String(500), nullable=True) + + # Provider & model (uses ImageGenProvider, NOT LiteLLMProvider) + provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False) + custom_provider = Column(String(100), nullable=True) + model_name = Column(String(100), nullable=False) + + # Credentials + api_key = Column(String, nullable=False) + api_base = Column(String(500), nullable=True) + api_version = Column(String(50), nullable=True) # Azure-specific + + # Additional litellm parameters + litellm_params = Column(JSON, nullable=True, default={}) + + # Relationships + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship("SearchSpace", back_populates="image_generation_configs") + + +class ImageGeneration(BaseModel, TimestampMixin): + """ + Stores image generation requests and results using litellm.aimage_generation(). + + Since aimage_generation is a single async call (not a background job), + there is no status enum. A row with response_data means success; + a row with error_message means failure. + + Response data is stored as JSONB matching the litellm output format: + { + "created": int, + "data": [{"b64_json": str|None, "revised_prompt": str|None, "url": str|None}], + "usage": {"prompt_tokens": int, "completion_tokens": int, "total_tokens": int} + } + """ + + __tablename__ = "image_generations" + + # Request parameters (matching litellm.aimage_generation() params) + prompt = Column(Text, nullable=False) + model = Column(String(200), nullable=True) # e.g., "dall-e-3", "gpt-image-1" + n = Column(Integer, nullable=True, default=1) + quality = Column( + String(50), nullable=True + ) # "auto", "high", "medium", "low", "hd", "standard" + size = Column( + String(50), nullable=True + ) # "1024x1024", "1536x1024", "1024x1536", etc. 
+ style = Column(String(50), nullable=True) # Model-specific style parameter + response_format = Column( + String(50), nullable=True + ) # "url" or "b64_json" + + # Image generation config reference + # 0 = Auto mode (router), negative IDs = global configs from YAML, + # positive IDs = ImageGenerationConfig records in DB + image_generation_config_id = Column(Integer, nullable=True) + + # Response data (full litellm response as JSONB) — present on success + response_data = Column(JSONB, nullable=True) + # Error message — present on failure + error_message = Column(Text, nullable=True) + + # Signed access token for serving images via tags. + # Stored in DB so it survives SECRET_KEY rotation. + access_token = Column(String(64), nullable=True, index=True) + + # Foreign keys + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + created_by_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + + # Relationships + search_space = relationship("SearchSpace", back_populates="image_generations") + created_by = relationship("User", back_populates="image_generations") + + class SearchSpace(BaseModel, TimestampMixin): __tablename__ = "searchspaces" @@ -905,6 +1031,9 @@ class SearchSpace(BaseModel, TimestampMixin): document_summary_llm_id = Column( Integer, nullable=True, default=0 ) # For document summarization, defaults to Auto mode + image_generation_config_id = Column( + Integer, nullable=True, default=0 + ) # For image generation, defaults to Auto mode user_id = Column( UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False @@ -929,6 +1058,12 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="Podcast.id.desc()", cascade="all, delete-orphan", ) + image_generations = relationship( + "ImageGeneration", + back_populates="search_space", + order_by="ImageGeneration.id.desc()", + cascade="all, delete-orphan", + ) logs = relationship( 
"Log", back_populates="search_space", @@ -953,6 +1088,12 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="NewLLMConfig.id", cascade="all, delete-orphan", ) + image_generation_configs = relationship( + "ImageGenerationConfig", + back_populates="search_space", + order_by="ImageGenerationConfig.id", + cascade="all, delete-orphan", + ) # RBAC relationships roles = relationship( @@ -1333,6 +1474,13 @@ if config.AUTH_TYPE == "GOOGLE": passive_deletes=True, ) + # Image generations created by this user + image_generations = relationship( + "ImageGeneration", + back_populates="created_by", + passive_deletes=True, + ) + # User memories for personalized AI responses memories = relationship( "UserMemory", @@ -1405,6 +1553,13 @@ else: passive_deletes=True, ) + # Image generations created by this user + image_generations = relationship( + "ImageGeneration", + back_populates="created_by", + passive_deletes=True, + ) + # User memories for personalized AI responses memories = relationship( "UserMemory", diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 746c18c6d..683f3548b 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -11,6 +11,7 @@ from .confluence_add_connector_route import router as confluence_add_connector_r from .discord_add_connector_route import router as discord_add_connector_router from .documents_routes import router as documents_router from .editor_routes import router as editor_router +from .image_generation_routes import router as image_generation_router from .google_calendar_add_connector_route import ( router as google_calendar_add_connector_router, ) @@ -49,6 +50,7 @@ router.include_router(notes_router) router.include_router(new_chat_router) # Chat with assistant-ui persistence router.include_router(chat_comments_router) router.include_router(podcasts_router) # Podcast task status and audio +router.include_router(image_generation_router) # Image 
generation via litellm router.include_router(search_source_connectors_router) router.include_router(google_calendar_add_connector_router) router.include_router(google_gmail_add_connector_router) diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py new file mode 100644 index 000000000..9b79771eb --- /dev/null +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -0,0 +1,646 @@ +""" +Image Generation routes: +- CRUD for ImageGenerationConfig (user-created image model configs) +- Global image gen configs endpoint (from YAML) +- Image generation execution (calls litellm.aimage_generation()) +- CRUD for ImageGeneration records (results) +- Image serving endpoint (serves b64_json images from DB, protected by signed tokens) +""" + +import base64 +import logging + +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import Response +from litellm import aimage_generation +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import ( + ImageGeneration, + ImageGenerationConfig, + Permission, + SearchSpace, + SearchSpaceMembership, + User, + get_async_session, +) +from app.schemas import ( + GlobalImageGenConfigRead, + ImageGenerationConfigCreate, + ImageGenerationConfigRead, + ImageGenerationConfigUpdate, + ImageGenerationCreate, + ImageGenerationListRead, + ImageGenerationRead, +) +from app.services.image_gen_router_service import ( + IMAGE_GEN_AUTO_MODE_ID, + ImageGenRouterService, + is_image_gen_auto_mode, +) +from app.users import current_active_user +from app.utils.rbac import check_permission +from app.utils.signed_image_urls import verify_image_token + +router = APIRouter() +logger = logging.getLogger(__name__) + +# Provider mapping for building litellm model strings. +# Only includes providers that support image generation. 
+# See: https://docs.litellm.ai/docs/image_generation#supported-providers +_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", # Google AI Studio + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", # AWS Bedrock + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} + + +def _get_global_image_gen_config(config_id: int) -> dict | None: + """Get a global image generation configuration by ID (negative IDs).""" + if config_id == IMAGE_GEN_AUTO_MODE_ID: + return { + "id": IMAGE_GEN_AUTO_MODE_ID, + "name": "Auto (Load Balanced)", + "provider": "AUTO", + "model_name": "auto", + "is_auto_mode": True, + } + if config_id > 0: + return None + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return cfg + return None + + +def _build_model_string(provider: str, model_name: str, custom_provider: str | None) -> str: + """Build a litellm model string from provider + model_name.""" + if custom_provider: + return f"{custom_provider}/{model_name}" + prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) + return f"{prefix}/{model_name}" + + +async def _execute_image_generation( + session: AsyncSession, + image_gen: ImageGeneration, + search_space: SearchSpace, +) -> None: + """ + Call litellm.aimage_generation() with the appropriate config. + + Resolution order: + 1. Explicit image_generation_config_id on the request + 2. Search space's image_generation_config_id preference + 3. 
Falls back to Auto mode if available + """ + config_id = image_gen.image_generation_config_id + if config_id is None: + config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + image_gen.image_generation_config_id = config_id + + # Build kwargs + gen_kwargs = {} + if image_gen.n is not None: + gen_kwargs["n"] = image_gen.n + if image_gen.quality is not None: + gen_kwargs["quality"] = image_gen.quality + if image_gen.size is not None: + gen_kwargs["size"] = image_gen.size + if image_gen.style is not None: + gen_kwargs["style"] = image_gen.style + if image_gen.response_format is not None: + gen_kwargs["response_format"] = image_gen.response_format + + if is_image_gen_auto_mode(config_id): + if not ImageGenRouterService.is_initialized(): + raise ValueError( + "Auto mode requested but Image Generation Router not initialized. " + "Ensure global_llm_config.yaml has global_image_generation_configs." + ) + response = await ImageGenRouterService.aimage_generation( + prompt=image_gen.prompt, model="auto", **gen_kwargs + ) + elif config_id < 0: + # Global config from YAML + cfg = _get_global_image_gen_config(config_id) + if not cfg: + raise ValueError(f"Global image generation config {config_id} not found") + + model_string = _build_model_string( + cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider") + ) + gen_kwargs["api_key"] = cfg.get("api_key") + if cfg.get("api_base"): + gen_kwargs["api_base"] = cfg["api_base"] + if cfg.get("api_version"): + gen_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + gen_kwargs.update(cfg["litellm_params"]) + + # User model override + if image_gen.model: + model_string = image_gen.model + + response = await aimage_generation( + prompt=image_gen.prompt, model=model_string, **gen_kwargs + ) + else: + # Positive ID = DB ImageGenerationConfig + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_cfg = 
result.scalars().first() + if not db_cfg: + raise ValueError(f"Image generation config {config_id} not found") + + model_string = _build_model_string( + db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider + ) + gen_kwargs["api_key"] = db_cfg.api_key + if db_cfg.api_base: + gen_kwargs["api_base"] = db_cfg.api_base + if db_cfg.api_version: + gen_kwargs["api_version"] = db_cfg.api_version + if db_cfg.litellm_params: + gen_kwargs.update(db_cfg.litellm_params) + + # User model override + if image_gen.model: + model_string = image_gen.model + + response = await aimage_generation( + prompt=image_gen.prompt, model=model_string, **gen_kwargs + ) + + # Store response + image_gen.response_data = ( + response.model_dump() if hasattr(response, "model_dump") else dict(response) + ) + if not image_gen.model and hasattr(response, "_hidden_params"): + hidden = response._hidden_params + if isinstance(hidden, dict) and hidden.get("model"): + image_gen.model = hidden["model"] + + +# ============================================================================= +# Global Image Generation Configs (from YAML) +# ============================================================================= + + +@router.get( + "/global-image-generation-configs", + response_model=list[GlobalImageGenConfigRead], +) +async def get_global_image_gen_configs( + user: User = Depends(current_active_user), +): + """Get all global image generation configs. 
# =============================================================================
# Global Image Generation Configs (from YAML)
# =============================================================================


@router.get(
    "/global-image-generation-configs",
    response_model=list[GlobalImageGenConfigRead],
)
async def get_global_image_gen_configs(
    user: User = Depends(current_active_user),
):
    """Get all global image generation configs. API keys are hidden."""
    try:
        global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS
        safe_configs = []

        # The synthetic "Auto" entry is only offered when at least one real
        # global config exists for the router to balance across.
        if global_configs:
            safe_configs.append({
                # Use the shared constant rather than a hardcoded 0 so this
                # stays in sync with the router service's Auto-mode ID.
                "id": IMAGE_GEN_AUTO_MODE_ID,
                "name": "Auto (Load Balanced)",
                "description": "Automatically routes across available image generation providers.",
                "provider": "AUTO",
                "custom_provider": None,
                "model_name": "auto",
                "api_base": None,
                "api_version": None,
                "litellm_params": {},
                "is_global": True,
                "is_auto_mode": True,
            })

        for cfg in global_configs:
            # Expose only non-secret fields; api_key is deliberately omitted.
            safe_configs.append({
                "id": cfg.get("id"),
                "name": cfg.get("name"),
                "description": cfg.get("description"),
                "provider": cfg.get("provider"),
                "custom_provider": cfg.get("custom_provider"),
                "model_name": cfg.get("model_name"),
                "api_base": cfg.get("api_base") or None,
                "api_version": cfg.get("api_version") or None,
                "litellm_params": cfg.get("litellm_params", {}),
                "is_global": True,
            })

        return safe_configs
    except Exception as e:
        logger.exception("Failed to fetch global image generation configs")
        raise HTTPException(status_code=500, detail=f"Failed to fetch configs: {e!s}") from e


# =============================================================================
# ImageGenerationConfig CRUD
# =============================================================================


async def _load_image_gen_config_or_404(
    session: AsyncSession, config_id: int
) -> ImageGenerationConfig:
    """Fetch a user-created ImageGenerationConfig by primary key or raise 404.

    Shared by the get/update/delete endpoints so the lookup-and-404 logic
    exists in exactly one place.
    """
    result = await session.execute(
        select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
    )
    db_config = result.scalars().first()
    if not db_config:
        raise HTTPException(status_code=404, detail="Config not found")
    return db_config


@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead)
async def create_image_gen_config(
    config_data: ImageGenerationConfigCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create a new image generation config for a search space."""
    try:
        await check_permission(
            session, user, config_data.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to create image generation configs in this search space",
        )

        db_config = ImageGenerationConfig(**config_data.model_dump())
        session.add(db_config)
        await session.commit()
        await session.refresh(db_config)
        return db_config

    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create ImageGenerationConfig")
        raise HTTPException(status_code=500, detail=f"Failed to create config: {e!s}") from e


@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
async def list_image_gen_configs(
    search_space_id: int,
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List image generation configs for a search space, newest first."""
    try:
        await check_permission(
            session, user, search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to view image generation configs in this search space",
        )

        result = await session.execute(
            select(ImageGenerationConfig)
            .filter(ImageGenerationConfig.search_space_id == search_space_id)
            .order_by(ImageGenerationConfig.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        return result.scalars().all()

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to list ImageGenerationConfigs")
        raise HTTPException(status_code=500, detail=f"Failed to fetch configs: {e!s}") from e


@router.get("/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead)
async def get_image_gen_config(
    config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Get a specific image generation config by ID."""
    try:
        db_config = await _load_image_gen_config_or_404(session, config_id)

        await check_permission(
            session, user, db_config.search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to view image generation configs in this search space",
        )
        return db_config

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get ImageGenerationConfig")
        raise HTTPException(status_code=500, detail=f"Failed to fetch config: {e!s}") from e


@router.put("/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead)
async def update_image_gen_config(
    config_id: int,
    update_data: ImageGenerationConfigUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Update an existing image generation config.

    NOTE(review): there is no dedicated update permission for image
    generations, so this reuses IMAGE_GENERATIONS_CREATE — confirm intended.
    """
    try:
        db_config = await _load_image_gen_config_or_404(session, config_id)

        await check_permission(
            session, user, db_config.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to update image generation configs in this search space",
        )

        # Only touch fields the caller explicitly supplied.
        for key, value in update_data.model_dump(exclude_unset=True).items():
            setattr(db_config, key, value)

        await session.commit()
        await session.refresh(db_config)
        return db_config

    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to update ImageGenerationConfig")
        raise HTTPException(status_code=500, detail=f"Failed to update config: {e!s}") from e


@router.delete("/image-generation-configs/{config_id}", response_model=dict)
async def delete_image_gen_config(
    config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete an image generation config."""
    try:
        db_config = await _load_image_gen_config_or_404(session, config_id)

        await check_permission(
            session, user, db_config.search_space_id,
            Permission.IMAGE_GENERATIONS_DELETE.value,
            "You don't have permission to delete image generation configs in this search space",
        )

        await session.delete(db_config)
        await session.commit()
        return {"message": "Image generation config deleted successfully", "id": config_id}

    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to delete ImageGenerationConfig")
        raise HTTPException(status_code=500, detail=f"Failed to delete config: {e!s}") from e
@router.post("/image-generations", response_model=ImageGenerationRead)
async def create_image_generation(
    data: ImageGenerationCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create an image generation record and run the generation.

    The record is flushed before the provider call so a generation failure
    can still be persisted (with its error message) rather than lost.
    """
    try:
        await check_permission(
            session,
            user,
            data.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to create image generations in this search space",
        )

        space_lookup = await session.execute(
            select(SearchSpace).where(SearchSpace.id == data.search_space_id)
        )
        search_space = space_lookup.scalars().first()
        if search_space is None:
            raise HTTPException(status_code=404, detail="Search space not found")

        record = ImageGeneration(
            prompt=data.prompt,
            model=data.model,
            n=data.n,
            quality=data.quality,
            size=data.size,
            style=data.style,
            response_format=data.response_format,
            image_generation_config_id=data.image_generation_config_id,
            search_space_id=data.search_space_id,
            created_by_id=user.id,
        )
        session.add(record)
        await session.flush()

        # A provider failure is stored on the record instead of aborting the
        # request, so the client still receives the row with error_message set.
        try:
            await _execute_image_generation(session, record, search_space)
        except Exception as e:
            logger.exception("Image generation call failed")
            record.error_message = str(e)

        await session.commit()
        await session.refresh(record)
        return record

    except HTTPException:
        raise
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error during image generation"
        ) from None
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create image generation")
        raise HTTPException(
            status_code=500, detail=f"Image generation failed: {e!s}"
        ) from e


@router.get("/image-generations", response_model=list[ImageGenerationListRead])
async def list_image_generations(
    search_space_id: int | None = None,
    skip: int = 0,
    limit: int = 50,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List image generations, newest first.

    With ``search_space_id``: restricted to that space after a permission
    check. Without it: every generation in any space the user is a member
    of. ``limit`` is silently capped at 100.
    """
    if skip < 0 or limit < 1:
        raise HTTPException(status_code=400, detail="Invalid pagination parameters")
    limit = min(limit, 100)

    try:
        if search_space_id is None:
            stmt = (
                select(ImageGeneration)
                .join(SearchSpace)
                .join(SearchSpaceMembership)
                .filter(SearchSpaceMembership.user_id == user.id)
            )
        else:
            await check_permission(
                session,
                user,
                search_space_id,
                Permission.IMAGE_GENERATIONS_READ.value,
                "You don't have permission to read image generations in this search space",
            )
            stmt = select(ImageGeneration).filter(
                ImageGeneration.search_space_id == search_space_id
            )

        stmt = (
            stmt.order_by(ImageGeneration.created_at.desc()).offset(skip).limit(limit)
        )
        rows = (await session.execute(stmt)).scalars().all()
        return [ImageGenerationListRead.from_orm_with_count(row) for row in rows]

    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(
            status_code=500, detail="Database error fetching image generations"
        ) from None


@router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead)
async def get_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Get a specific image generation by ID (requires READ permission)."""
    try:
        lookup = await session.execute(
            select(ImageGeneration).where(ImageGeneration.id == image_gen_id)
        )
        record = lookup.scalars().first()
        if record is None:
            raise HTTPException(status_code=404, detail="Image generation not found")

        await check_permission(
            session,
            user,
            record.search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to read image generations in this search space",
        )
        return record

    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(
            status_code=500, detail="Database error fetching image generation"
        ) from None


@router.delete("/image-generations/{image_gen_id}", response_model=dict)
async def delete_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete an image generation record (requires DELETE permission)."""
    try:
        lookup = await session.execute(
            select(ImageGeneration).where(ImageGeneration.id == image_gen_id)
        )
        record = lookup.scalars().first()
        if record is None:
            raise HTTPException(status_code=404, detail="Image generation not found")

        await check_permission(
            session,
            user,
            record.search_space_id,
            Permission.IMAGE_GENERATIONS_DELETE.value,
            "You don't have permission to delete image generations in this search space",
        )

        await session.delete(record)
        await session.commit()
        # NOTE(review): unlike the config delete route, this response omits
        # the deleted id — confirm whether clients rely on that shape.
        return {"message": "Image generation deleted successfully"}

    except HTTPException:
        raise
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error deleting image generation"
        ) from None
@router.get("/image-generations/{image_gen_id}/image")
async def serve_generated_image(
    image_gen_id: int,
    token: str = Query(..., description="Signed access token"),
    index: int = 0,
    session: AsyncSession = Depends(get_async_session),
):
    """
    Serve a generated image by ID, protected by a signed token.

    The token is generated when the image URL is created by the generate_image
    tool and is verified against the value stored on the record, so only users
    holding a valid URL can view images — no auth headers required (which
    <img> tags cannot pass).

    Args:
        image_gen_id: The image generation record ID
        token: Access token compared against the stored one (query parameter)
        index: Which image to serve if multiple were generated (default: 0)

    Raises:
        HTTPException: 404 for missing record/data/index, 403 for a bad token.
    """
    try:
        result = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        image_gen = result.scalars().first()
        if not image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")

        # Constant-time comparison against the token stored on the record.
        if not verify_image_token(image_gen.access_token, token):
            raise HTTPException(status_code=403, detail="Invalid image access token")

        if not image_gen.response_data:
            raise HTTPException(status_code=404, detail="No image data available")

        images = image_gen.response_data.get("data", [])
        # FIX: also reject negative indices — previously a value like -1
        # passed the `index >= len(images)` check and served images[-1]
        # through Python's negative indexing, bypassing the intended
        # addressing scheme.
        if not images or index < 0 or index >= len(images):
            raise HTTPException(
                status_code=404, detail="Image not found at the specified index"
            )

        image_entry = images[index]

        # If the provider returned a hosted URL, redirect to it.
        if image_entry.get("url"):
            from fastapi.responses import RedirectResponse

            return RedirectResponse(url=image_entry["url"])

        # If there's b64_json data, decode and serve it inline.
        if image_entry.get("b64_json"):
            image_bytes = base64.b64decode(image_entry["b64_json"])
            return Response(
                content=image_bytes,
                media_type="image/png",
                headers={
                    "Cache-Control": "public, max-age=86400",
                    "Content-Disposition": f'inline; filename="generated-{image_gen_id}-{index}.png"',
                },
            )

        raise HTTPException(status_code=404, detail="No displayable image data")

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to serve generated image")
        raise HTTPException(status_code=500, detail=f"Failed to serve image: {e!s}") from e
+ """ + if config_id is None: + return None + + if config_id == 0: + return { + "id": 0, + "name": "Auto (Load Balanced)", + "description": "Automatically routes requests across available image generation providers", + "provider": "AUTO", + "model_name": "auto", + "is_global": True, + "is_auto_mode": True, + } + + if config_id < 0: + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + } + return None + + # Positive ID: query ImageGenerationConfig table + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if db_config: + return { + "id": db_config.id, + "name": db_config.name, + "description": db_config.description, + "provider": db_config.provider.value if db_config.provider else None, + "custom_provider": db_config.custom_provider, + "model_name": db_config.model_name, + "api_base": db_config.api_base, + "api_version": db_config.api_version, + "litellm_params": db_config.litellm_params or {}, + "created_at": db_config.created_at.isoformat() + if db_config.created_at + else None, + "search_space_id": db_config.search_space_id, + } + return None + + @router.get( "/search-spaces/{search_space_id}/llm-preferences", response_model=LLMPreferencesRead, @@ -423,12 +487,17 @@ async def get_llm_preferences( document_summary_llm = await _get_llm_config_by_id( session, search_space.document_summary_llm_id ) + image_generation_config = await _get_image_gen_config_by_id( + session, search_space.image_generation_config_id + ) return LLMPreferencesRead( 
agent_llm_id=search_space.agent_llm_id, document_summary_llm_id=search_space.document_summary_llm_id, + image_generation_config_id=search_space.image_generation_config_id, agent_llm=agent_llm, document_summary_llm=document_summary_llm, + image_generation_config=image_generation_config, ) except HTTPException: @@ -485,12 +554,17 @@ async def update_llm_preferences( document_summary_llm = await _get_llm_config_by_id( session, search_space.document_summary_llm_id ) + image_generation_config = await _get_image_gen_config_by_id( + session, search_space.image_generation_config_id + ) return LLMPreferencesRead( agent_llm_id=search_space.agent_llm_id, document_summary_llm_id=search_space.document_summary_llm_id, + image_generation_config_id=search_space.image_generation_config_id, agent_llm=agent_llm, document_summary_llm=document_summary_llm, + image_generation_config=image_generation_config, ) except HTTPException: diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 5ff166733..332af55fd 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -20,6 +20,16 @@ from .documents import ( PaginatedResponse, ) from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest +from .image_generation import ( + GlobalImageGenConfigRead, + ImageGenerationConfigCreate, + ImageGenerationConfigPublic, + ImageGenerationConfigRead, + ImageGenerationConfigUpdate, + ImageGenerationCreate, + ImageGenerationListRead, + ImageGenerationRead, +) from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate from .new_chat import ( ChatMessage, @@ -106,6 +116,16 @@ __all__ = [ "GlobalNewLLMConfigRead", "GoogleDriveIndexRequest", "GoogleDriveIndexingOptions", + "GlobalImageGenConfigRead", + # Image Generation Config schemas + "ImageGenerationConfigCreate", + "ImageGenerationConfigPublic", + "ImageGenerationConfigRead", + "ImageGenerationConfigUpdate", + # Image 
class ImageGenerationConfigBase(BaseModel):
    """Field definitions shared by all ImageGenerationConfig schemas."""

    name: str = Field(
        ..., max_length=100, description="User-friendly name for the config"
    )
    description: str | None = Field(
        None, max_length=500, description="Optional description"
    )
    provider: ImageGenProvider = Field(
        ...,
        description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)",
    )
    custom_provider: str | None = Field(
        None, max_length=100, description="Custom provider name"
    )
    model_name: str = Field(
        ..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)"
    )
    # Stored as-is; read schemas that expose configs publicly must omit it.
    api_key: str = Field(..., description="API key for the provider")
    api_base: str | None = Field(
        None, max_length=500, description="Optional API base URL"
    )
    api_version: str | None = Field(
        None,
        max_length=50,
        description="Azure-specific API version (e.g., '2024-02-15-preview')",
    )
    litellm_params: dict[str, Any] | None = Field(
        default=None, description="Additional LiteLLM parameters"
    )


class ImageGenerationConfigCreate(ImageGenerationConfigBase):
    """Payload for creating a new ImageGenerationConfig."""

    search_space_id: int = Field(
        ..., description="Search space ID to associate the config with"
    )


class ImageGenerationConfigUpdate(BaseModel):
    """Partial-update payload; every field is optional."""

    name: str | None = Field(None, max_length=100)
    description: str | None = Field(None, max_length=500)
    provider: ImageGenProvider | None = None
    custom_provider: str | None = Field(None, max_length=100)
    model_name: str | None = Field(None, max_length=100)
    api_key: str | None = None
    api_base: str | None = Field(None, max_length=500)
    api_version: str | None = Field(None, max_length=50)
    litellm_params: dict[str, Any] | None = None


class ImageGenerationConfigRead(ImageGenerationConfigBase):
    """Full read schema, including id, timestamps and the API key."""

    id: int
    created_at: datetime
    search_space_id: int

    model_config = ConfigDict(from_attributes=True)


class ImageGenerationConfigPublic(BaseModel):
    """Read schema for list views — identical to the full read schema
    except that the API key is deliberately omitted."""

    id: int
    name: str
    description: str | None = None
    provider: ImageGenProvider
    custom_provider: str | None = None
    model_name: str
    api_base: str | None = None
    api_version: str | None = None
    litellm_params: dict[str, Any] | None = None
    created_at: datetime
    search_space_id: int

    model_config = ConfigDict(from_attributes=True)
class ImageGenerationCreate(BaseModel):
    """Payload for submitting an image generation request."""

    prompt: str = Field(
        ...,
        min_length=1,
        max_length=4000,
        description="A text description of the desired image(s)",
    )
    model: str | None = Field(
        None,
        max_length=200,
        description="The model to use (e.g., 'dall-e-3', 'gpt-image-1'). Overrides the config model.",
    )
    n: int | None = Field(
        None,
        ge=1,
        le=10,
        description="Number of images to generate (1-10).",
    )
    quality: str | None = Field(None, max_length=50)
    size: str | None = Field(None, max_length=50)
    style: str | None = Field(None, max_length=50)
    response_format: str | None = Field(None, max_length=50)
    search_space_id: int = Field(
        ..., description="Search space ID to associate the generation with"
    )
    image_generation_config_id: int | None = Field(
        None,
        description=(
            "Image generation config ID. "
            "0 = Auto mode (router), negative = global YAML config, positive = DB config. "
            "If not provided, uses the search space's image_generation_config_id preference."
        ),
    )


class ImageGenerationRead(BaseModel):
    """Full read schema for a stored image generation record."""

    id: int
    prompt: str
    model: str | None = None
    n: int | None = None
    quality: str | None = None
    size: str | None = None
    style: str | None = None
    response_format: str | None = None
    image_generation_config_id: int | None = None
    # Raw provider response; contains the generated image entries.
    response_data: dict[str, Any] | None = None
    error_message: str | None = None
    search_space_id: int
    created_at: datetime

    model_config = ConfigDict(from_attributes=True)


class ImageGenerationListRead(BaseModel):
    """Compact list-view schema — omits response_data, adds derived fields."""

    id: int
    prompt: str
    model: str | None = None
    n: int | None = None
    quality: str | None = None
    size: str | None = None
    search_space_id: int
    created_at: datetime
    is_success: bool
    image_count: int | None = None

    model_config = ConfigDict(from_attributes=True)

    @classmethod
    def from_orm_with_count(cls, obj: Any) -> "ImageGenerationListRead":
        """Build a list-view model from an ORM row, deriving the image count
        from the stored provider response (None when no response exists)."""
        count: int | None = None
        payload = obj.response_data
        if payload and isinstance(payload, dict):
            entries = payload.get("data")
            if isinstance(entries, list):
                count = len(entries)

        return cls(
            id=obj.id,
            prompt=obj.prompt,
            model=obj.model,
            n=obj.n,
            quality=obj.quality,
            size=obj.size,
            search_space_id=obj.search_space_id,
            created_at=obj.created_at,
            is_success=obj.response_data is not None,
            image_count=count,
        )


# =============================================================================
# Global Image Gen Config (from YAML)
# =============================================================================


class GlobalImageGenConfigRead(BaseModel):
    """
    Read schema for admin-configured global image generation configs (YAML).
    Global configs carry negative IDs; the API key is never exposed.
    ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
    """

    id: int = Field(
        ...,
        description="Config ID: 0 for Auto mode, negative for global configs",
    )
    name: str
    description: str | None = None
    provider: str
    custom_provider: str | None = None
    model_name: str
    api_base: str | None = None
    api_version: str | None = None
    litellm_params: dict[str, Any] | None = None
    is_global: bool = True
    is_auto_mode: bool = False
# Sentinel config ID for "Auto" mode: requests go through the load-balancing
# router instead of a single named deployment.
IMAGE_GEN_AUTO_MODE_ID = 0

# Maps ImageGenProvider enum names to the litellm provider prefixes used when
# building "<provider>/<model>" strings. Only providers with image generation
# support are listed.
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
IMAGE_GEN_PROVIDER_MAP = {
    "OPENAI": "openai",
    "AZURE_OPENAI": "azure",
    "GOOGLE": "gemini",  # Google AI Studio
    "VERTEX_AI": "vertex_ai",
    "BEDROCK": "bedrock",  # AWS Bedrock
    "RECRAFT": "recraft",
    "OPENROUTER": "openrouter",
    "XINFERENCE": "xinference",
    "NSCALE": "nscale",
}
+ """ + + _instance = None + _router: Router | None = None + _model_list: list[dict] = [] + _router_settings: dict = {} + _initialized: bool = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @classmethod + def get_instance(cls) -> "ImageGenRouterService": + """Get the singleton instance of the router service.""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def initialize( + cls, + global_configs: list[dict], + router_settings: dict | None = None, + ) -> None: + """ + Initialize the router with global image generation configurations. + + Args: + global_configs: List of global image gen config dictionaries from YAML + router_settings: Optional router settings (routing_strategy, num_retries, etc.) + """ + instance = cls.get_instance() + + if instance._initialized: + logger.debug("Image Generation Router already initialized, skipping") + return + + # Build model list from global configs + model_list = [] + for config in global_configs: + deployment = cls._config_to_deployment(config) + if deployment: + model_list.append(deployment) + + if not model_list: + logger.warning( + "No valid image generation configs found for router initialization" + ) + return + + instance._model_list = model_list + instance._router_settings = router_settings or {} + + # Default router settings optimized for rate limit handling + default_settings = { + "routing_strategy": "usage-based-routing", + "num_retries": 3, + "allowed_fails": 3, + "cooldown_time": 60, + "retry_after": 5, + } + + # Merge with provided settings + final_settings = {**default_settings, **instance._router_settings} + + try: + instance._router = Router( + model_list=model_list, + routing_strategy=final_settings.get( + "routing_strategy", "usage-based-routing" + ), + num_retries=final_settings.get("num_retries", 3), + allowed_fails=final_settings.get("allowed_fails", 3), + 
cooldown_time=final_settings.get("cooldown_time", 60), + set_verbose=False, + ) + instance._initialized = True + logger.info( + f"Image Generation Router initialized with {len(model_list)} deployments, " + f"strategy: {final_settings.get('routing_strategy')}" + ) + except Exception as e: + logger.error(f"Failed to initialize Image Generation Router: {e}") + instance._router = None + + @classmethod + def _config_to_deployment(cls, config: dict) -> dict | None: + """ + Convert a global image gen config to a router deployment entry. + + Args: + config: Global image gen config dictionary + + Returns: + Router deployment dictionary or None if invalid + """ + try: + # Skip if essential fields are missing + if not config.get("model_name") or not config.get("api_key"): + return None + + # Build model string + if config.get("custom_provider"): + model_string = f"{config['custom_provider']}/{config['model_name']}" + else: + provider = config.get("provider", "").upper() + provider_prefix = IMAGE_GEN_PROVIDER_MAP.get( + provider, provider.lower() + ) + model_string = f"{provider_prefix}/{config['model_name']}" + + # Build litellm params + litellm_params: dict[str, Any] = { + "model": model_string, + "api_key": config.get("api_key"), + } + + # Add optional api_base + if config.get("api_base"): + litellm_params["api_base"] = config["api_base"] + + # Add api_version (required for Azure) + if config.get("api_version"): + litellm_params["api_version"] = config["api_version"] + + # Add any additional litellm parameters + if config.get("litellm_params"): + litellm_params.update(config["litellm_params"]) + + # All configs use same alias "auto" for unified routing + deployment: dict[str, Any] = { + "model_name": "auto", + "litellm_params": litellm_params, + } + + # Add rate limits from config if available + if config.get("rpm"): + deployment["rpm"] = config["rpm"] + if config.get("tpm"): + deployment["tpm"] = config["tpm"] + + return deployment + + except Exception as e: + 
logger.warning( + f"Failed to convert image gen config to deployment: {e}" + ) + return None + + @classmethod + def get_router(cls) -> Router | None: + """Get the initialized router instance.""" + instance = cls.get_instance() + return instance._router + + @classmethod + def is_initialized(cls) -> bool: + """Check if the router has been initialized.""" + instance = cls.get_instance() + return instance._initialized and instance._router is not None + + @classmethod + def get_model_count(cls) -> int: + """Get the number of models in the router.""" + instance = cls.get_instance() + return len(instance._model_list) + + @classmethod + async def aimage_generation( + cls, + prompt: str, + model: str = "auto", + n: int | None = None, + quality: str | None = None, + size: str | None = None, + style: str | None = None, + response_format: str | None = None, + timeout: int = 600, + **kwargs, + ) -> ImageResponse: + """ + Generate images using the router for load balancing. + + Uses Router.aimage_generation() which distributes requests + across configured image generation deployments. + + Args: + prompt: Text description of the desired image(s) + model: Model alias (default "auto" for router routing) + n: Number of images to generate + quality: Image quality setting + size: Image size + style: Style parameter + response_format: "url" or "b64_json" + timeout: Request timeout in seconds + **kwargs: Additional litellm params + + Returns: + ImageResponse from litellm + + Raises: + ValueError: If router is not initialized + """ + instance = cls.get_instance() + if not instance._router: + raise ValueError( + "Image Generation Router not initialized. " + "Ensure global_llm_config.yaml has global_image_generation_configs." 
def is_image_gen_auto_mode(config_id: int | None) -> bool:
    """Return True when *config_id* selects image generation Auto mode.

    Auto mode (ID 0) routes requests through the load-balancing router
    instead of one specific configured deployment.

    Args:
        config_id: The config ID to check (None is never Auto mode).
    """
    return config_id is not None and config_id == IMAGE_GEN_AUTO_MODE_ID
+ """ + return secrets.token_hex(32) + + +def verify_image_token(stored_token: str | None, provided_token: str) -> bool: + """ + Constant-time comparison of a stored token against a user-provided one. + + Args: + stored_token: The token persisted on the ImageGeneration record. + provided_token: The token from the URL query parameter. + + Returns: + True if the tokens match, False otherwise. + """ + if not stored_token or not provided_token: + return False + return hmac.compare_digest(stored_token, provided_token) diff --git a/surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx index 1a727f1b6..e6c973ac6 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx @@ -7,6 +7,7 @@ import { ChevronRight, FileText, Globe, + ImageIcon, type LucideIcon, Menu, MessageSquare, @@ -19,6 +20,7 @@ import { useTranslations } from "next-intl"; import { useCallback, useEffect, useState } from "react"; import { PublicChatSnapshotsManager } from "@/components/public-chat-snapshots/public-chat-snapshots-manager"; import { GeneralSettingsManager } from "@/components/settings/general-settings-manager"; +import { ImageModelManager } from "@/components/settings/image-model-manager"; import { LLMRoleManager } from "@/components/settings/llm-role-manager"; import { ModelConfigManager } from "@/components/settings/model-config-manager"; import { PromptConfigManager } from "@/components/settings/prompt-config-manager"; @@ -52,6 +54,12 @@ const settingsNavItems: SettingsNavItem[] = [ descriptionKey: "nav_role_assignments_desc", icon: Brain, }, + { + id: "image-models", + labelKey: "nav_image_models", + descriptionKey: "nav_image_models_desc", + icon: ImageIcon, + }, { id: "prompts", labelKey: "nav_system_instructions", @@ -282,8 +290,11 @@ function SettingsContent({ )} {activeSection === "models" && } - {activeSection === "roles" 
&& } - {activeSection === "prompts" && } + {activeSection === "roles" && } + {activeSection === "image-models" && ( + + )} + {activeSection === "prompts" && } {activeSection === "public-links" && ( )} diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts new file mode 100644 index 000000000..dbaf441d0 --- /dev/null +++ b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts @@ -0,0 +1,91 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + CreateImageGenConfigRequest, + GetImageGenConfigsResponse, + UpdateImageGenConfigRequest, + UpdateImageGenConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Mutation atom for creating a new ImageGenerationConfig + */ +export const createImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "create"], + enabled: !!searchSpaceId, + mutationFn: async (request: CreateImageGenConfigRequest) => { + return imageGenConfigApiService.createConfig(request); + }, + onSuccess: () => { + toast.success("Image model configuration created"); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to create image model configuration"); + }, + }; +}); + +/** + * Mutation atom for updating an existing ImageGenerationConfig + */ +export const updateImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = 
get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "update"], + enabled: !!searchSpaceId, + mutationFn: async (request: UpdateImageGenConfigRequest) => { + return imageGenConfigApiService.updateConfig(request); + }, + onSuccess: (_: UpdateImageGenConfigResponse, request: UpdateImageGenConfigRequest) => { + toast.success("Image model configuration updated"); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.byId(request.id), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to update image model configuration"); + }, + }; +}); + +/** + * Mutation atom for deleting an ImageGenerationConfig + */ +export const deleteImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "delete"], + enabled: !!searchSpaceId, + mutationFn: async (id: number) => { + return imageGenConfigApiService.deleteConfig(id); + }, + onSuccess: (_, id: number) => { + toast.success("Image model configuration deleted"); + queryClient.setQueryData( + cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + (oldData: GetImageGenConfigsResponse | undefined) => { + if (!oldData) return oldData; + return oldData.filter((config) => config.id !== id); + } + ); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to delete image model configuration"); + }, + }; +}); diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts new file mode 100644 index 000000000..a45e69a03 --- /dev/null +++ b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts @@ -0,0 +1,33 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import { imageGenConfigApiService } from 
"@/lib/apis/image-gen-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Query atom for fetching user-created image gen configs for the active search space + */ +export const imageGenConfigsAtom = atomWithQuery((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, // 5 minutes + queryFn: async () => { + return imageGenConfigApiService.getConfigs(Number(searchSpaceId)); + }, + }; +}); + +/** + * Query atom for fetching global image gen configs (from YAML, negative IDs) + */ +export const globalImageGenConfigsAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.imageGenConfigs.global(), + staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change + queryFn: async () => { + return imageGenConfigApiService.getGlobalConfigs(); + }, + }; +}); diff --git a/surfsense_web/components/Logo.tsx b/surfsense_web/components/Logo.tsx index 58f8d1c9f..9f5915777 100644 --- a/surfsense_web/components/Logo.tsx +++ b/surfsense_web/components/Logo.tsx @@ -4,16 +4,20 @@ import Image from "next/image"; import Link from "next/link"; import { cn } from "@/lib/utils"; -export const Logo = ({ className }: { className?: string }) => { - return ( - - logo - +export const Logo = ({ className, disableLink = false }: { className?: string; disableLink?: boolean }) => { + const image = ( + logo ); + + if (disableLink) { + return image; + } + + return {image}; }; diff --git a/surfsense_web/components/homepage/navbar.tsx b/surfsense_web/components/homepage/navbar.tsx index 4abd4031b..670e3c810 100644 --- a/surfsense_web/components/homepage/navbar.tsx +++ b/surfsense_web/components/homepage/navbar.tsx @@ -64,7 +64,7 @@ const DesktopNav = ({ navItems, isScrolled }: any) => { href="/" className="flex flex-1 
flex-row items-center gap-0.5 hover:opacity-80 transition-opacity" > - + SurfSense
@@ -145,7 +145,7 @@ const MobileNav = ({ navItems, isScrolled }: any) => { href="/" className="flex flex-row items-center gap-2 hover:opacity-80 transition-opacity" > - + SurfSense + ); + } + + return ( + + + + + + + + {totalModels > 3 && ( +
+ +
+ )} + + +
+ +

No image models found

+
+
+ + {/* Global Image Gen Configs */} + {filteredGlobal.length > 0 && ( + +
+ + Global Image Models +
+ {filteredGlobal.map((config) => { + const isSelected = currentConfig?.id === config.id; + const isAuto = "is_auto_mode" in config && config.is_auto_mode; + return ( + handleSelect(config.id)} + className={cn( + "mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50", + isSelected && "bg-accent/80", + isAuto && "border border-violet-200 dark:border-violet-800/50" + )} + > +
+
+ {isAuto ? ( + + ) : ( + + )} +
+
+
+ {config.name} + {isAuto && ( + + Recommended + + )} + {isSelected && } +
+ + {isAuto ? "Auto load balancing" : config.model_name} + +
+
+
+ ); + })} +
+ )} + + {/* User Image Gen Configs */} + {filteredUser.length > 0 && ( + <> + {filteredGlobal.length > 0 && } + +
+ + Your Image Models +
+ {filteredUser.map((config) => { + const isSelected = currentConfig?.id === config.id; + return ( + handleSelect(config.id)} + className={cn( + "mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50", + isSelected && "bg-accent/80" + )} + > +
+
+ +
+
+
+ {config.name} + {isSelected && ( + + )} +
+ + {config.model_name} + +
+
+
+ ); + })} +
+ + )} + + {/* Add New */} + {onAddNew && ( +
+ +
+ )} +
+
+
+
+ ); +} diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index ec1143e04..148028df2 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -392,8 +392,8 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp )} - {/* Add New Config Button */} -
+ {/* Add New Config Button */} +
+
+ + {/* Errors */} + + {errors.map((err) => ( + + + + {err?.message} + + + ))} + + + {/* Global info */} + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( + + + + + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length} global image model(s) + {" "} + available from your administrator. + + + )} + + {/* Active Preference Card */} + {!isLoading && allConfigs.length > 0 && ( + + + +
+
+ +
+
+ Active Image Model + + Select which model to use for image generation + +
+
+
+ + + {hasPrefChanges && ( +
+ + +
+ )} +
+
+
+ )} + + {/* Loading */} + {isLoading && ( + + + + + + )} + + {/* User Configs */} + {!isLoading && ( +
+
+

Your Image Models

+ +
+ + {(userConfigs?.length ?? 0) === 0 ? ( + + +
+ +
+

No Image Models Yet

+

+ Add your own image generation model (DALL-E 3, GPT Image 1, etc.) +

+ +
+
+ ) : ( + + + {userConfigs?.map((config) => ( + + + +
+
+
+
+
+
+ +
+
+
+

{config.name}

+ + {config.provider} + +
+ + {config.model_name} + + {config.description && ( +

{config.description}

+ )} +
+ + {new Date(config.created_at).toLocaleDateString()} +
+
+
+
+ + + + + + Edit + + + + + + + + Delete + + +
+
+
+
+ + + + ))} + + + )} +
+ )} + + {/* Create/Edit Dialog */} + { if (!open) { setIsDialogOpen(false); setEditingConfig(null); resetForm(); } }}> + + + + {editingConfig ? : } + {editingConfig ? "Edit Image Model" : "Add Image Model"} + + + {editingConfig ? "Update your image generation model" : "Configure a new image generation model (DALL-E 3, GPT Image 1, etc.)"} + + + +
+ {/* Name */} +
+ + setFormData((p) => ({ ...p, name: e.target.value }))} + /> +
+ + {/* Description */} +
+ + setFormData((p) => ({ ...p, description: e.target.value }))} + /> +
+ + + + {/* Provider */} +
+ + +
+ + {/* Model Name */} +
+ + {suggestedModels.length > 0 ? ( + + + + + + + setFormData((p) => ({ ...p, model_name: val }))} + /> + + + Type a custom model name + + + {suggestedModels.map((m) => ( + { + setFormData((p) => ({ ...p, model_name: m.value })); + setModelComboboxOpen(false); + }} + > + + {m.value} + {m.label} + + ))} + + + + + + ) : ( + setFormData((p) => ({ ...p, model_name: e.target.value }))} + /> + )} +
+ + {/* API Key */} +
+ + setFormData((p) => ({ ...p, api_key: e.target.value }))} + /> +
+ + {/* API Base (optional) */} +
+ + setFormData((p) => ({ ...p, api_base: e.target.value }))} + /> +
+ + {/* API Version (Azure) */} + {formData.provider === "AZURE_OPENAI" && ( +
+ + setFormData((p) => ({ ...p, api_version: e.target.value }))} + /> +
+ )} + + {/* Actions */} +
+ + +
+
+
+
+ + {/* Delete Confirmation */} + !open && setConfigToDelete(null)}> + + + + + Delete Image Model + + + Are you sure you want to delete {configToDelete?.name}? + + + + Cancel + + {isDeleting ? <>Deleting : <>Delete} + + + + +
+ ); +} diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx index dac68a358..22e3d8e08 100644 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -255,15 +255,15 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { )} - {/* Role Assignment Cards */} - {availableConfigs.length > 0 && ( -
- {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { - const IconComponent = role.icon; - const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments]; - const assignedConfig = availableConfigs.find( - (config) => config.id === currentAssignment - ); + {/* Role Assignment Cards */} + {availableConfigs.length > 0 && ( +
+ {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { + const IconComponent = role.icon; + const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments]; + const assignedConfig = availableConfigs.find( + (config) => config.id === currentAssignment + ); return ( -
- - handleRoleAssignment(`${key}_llm_id`, value)} + > + + + + + + Unassigned + - {/* Global Configurations */} - {globalConfigs.length > 0 && ( - <> -
- Global Configurations -
- {globalConfigs.map((config) => { - const isAutoMode = - "is_auto_mode" in config && config.is_auto_mode; - return ( - -
- {isAutoMode ? ( - - - AUTO - - ) : ( - - {config.provider} - - )} - {config.name} - {!isAutoMode && ( - - ({config.model_name}) - - )} - {isAutoMode ? ( - - Recommended - - ) : ( - - 🌐 Global - - )} -
-
- ); - })} - - )} + {/* Global Configurations */} + {globalConfigs.length > 0 && ( + <> +
+ Global Configurations +
+ {globalConfigs.map((config) => { + const isAutoMode = + "is_auto_mode" in config && config.is_auto_mode; + return ( + +
+ {isAutoMode ? ( + + + AUTO + + ) : ( + + {config.provider} + + )} + {config.name} + {!isAutoMode && ( + + ({config.model_name}) + + )} + {isAutoMode ? ( + + Recommended + + ) : ( + + 🌐 Global + + )} +
+
+ ); + })} + + )} - {/* Custom Configurations */} - {newLLMConfigs.length > 0 && ( - <> -
- Your Configurations -
- {newLLMConfigs - .filter( - (config) => config.id && config.id.toString().trim() !== "" - ) - .map((config) => ( - -
- - {config.provider} - - {config.name} - - ({config.model_name}) - -
-
- ))} - - )} -
- -
+ {/* Custom Configurations */} + {newLLMConfigs.length > 0 && ( + <> +
+ Your Configurations +
+ {newLLMConfigs + .filter( + (config) => config.id && config.id.toString().trim() !== "" + ) + .map((config) => ( + +
+ + {config.provider} + + {config.name} + + ({config.model_name}) + +
+
+ ))} + + )} + + +
{assignedConfig && (
; + return ; } /** diff --git a/surfsense_web/components/tool-ui/image/index.tsx b/surfsense_web/components/tool-ui/image/index.tsx index 42725d258..9ecfe4cfa 100644 --- a/surfsense_web/components/tool-ui/image/index.tsx +++ b/surfsense_web/components/tool-ui/image/index.tsx @@ -1,6 +1,6 @@ "use client"; -import { ExternalLinkIcon, ImageIcon } from "lucide-react"; +import { ExternalLinkIcon, ImageIcon, SparklesIcon } from "lucide-react"; import NextImage from "next/image"; import { Component, type ReactNode, useState } from "react"; import { z } from "zod"; @@ -25,7 +25,7 @@ const SerializableImageSchema = z.object({ id: z.string(), assetId: z.string(), src: z.string(), - alt: z.string().nullish(), // Made optional - will use fallback if missing + alt: z.string().nullish(), title: z.string().nullish(), description: z.string().nullish(), href: z.string().nullish(), @@ -49,7 +49,7 @@ export interface ImageProps { id: string; assetId: string; src: string; - alt?: string; // Optional with default fallback + alt?: string; title?: string; description?: string; href?: string; @@ -71,10 +71,8 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a if (!parsed.success) { console.warn("Invalid image data:", parsed.error.issues); - // Try to extract basic info and return a fallback object const obj = (result && typeof result === "object" ? result : {}) as Record; - // If we have at least id, assetId, and src, we can still render the image if ( typeof obj.id === "string" && typeof obj.assetId === "string" && @@ -89,7 +87,7 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a description: typeof obj.description === "string" ? obj.description : undefined, href: typeof obj.href === "string" ? obj.href : undefined, domain: typeof obj.domain === "string" ? 
obj.domain : undefined, - ratio: undefined, // Use default ratio + ratio: undefined, source: undefined, }; } @@ -97,7 +95,6 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a throw new Error(`Invalid image: ${parsed.error.issues.map((i) => i.message).join(", ")}`); } - // Provide fallback for alt if it's null/undefined return { ...parsed.data, alt: parsed.data.alt ?? "Image", @@ -105,7 +102,7 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a } /** - * Get aspect ratio class based on ratio prop + * Get aspect ratio class based on ratio prop (used for fixed-ratio images only) */ function getAspectRatioClass(ratio?: AspectRatio): string { switch (ratio) { @@ -119,7 +116,6 @@ function getAspectRatioClass(ratio?: AspectRatio): string { return "aspect-[9/16]"; case "21:9": return "aspect-[21/9]"; - case "auto": default: return "aspect-[4/3]"; } @@ -150,7 +146,7 @@ export class ImageErrorBoundary extends Component< if (this.state.hasError) { return ( -
+

Failed to load image

@@ -167,10 +163,10 @@ export class ImageErrorBoundary extends Component< /** * Loading skeleton for Image */ -export function ImageSkeleton({ maxWidth = "420px" }: { maxWidth?: string }) { +export function ImageSkeleton({ maxWidth = "512px" }: { maxWidth?: string }) { return ( -
+
@@ -183,7 +179,7 @@ export function ImageSkeleton({ maxWidth = "420px" }: { maxWidth?: string }) { export function ImageLoading({ title = "Loading image..." }: { title?: string }) { return ( -
+

{title}

@@ -197,7 +193,9 @@ export function ImageLoading({ title = "Loading image..." }: { title?: string }) * Image Component * * Display images with metadata and attribution. - * Features hover overlay with title and source attribution. + * - For "auto" ratio: renders the image at natural dimensions (no cropping) + * - For fixed ratios: uses a fixed aspect container with object-cover + * - Features hover overlay with title, description, and source attribution. */ export function Image({ id, @@ -207,16 +205,18 @@ export function Image({ description, href, domain, - ratio = "4:3", + ratio = "auto", fit = "cover", source, - maxWidth = "420px", + maxWidth = "512px", className, }: ImageProps) { const [isHovered, setIsHovered] = useState(false); const [imageError, setImageError] = useState(false); - const aspectRatioClass = getAspectRatioClass(ratio); + const [imageLoaded, setImageLoaded] = useState(false); const displayDomain = domain || source?.label; + const isGenerated = domain === "ai-generated"; + const isAutoRatio = !ratio || ratio === "auto"; const handleClick = () => { const targetUrl = href || source?.url || src; @@ -228,7 +228,7 @@ export function Image({ if (imageError) { return ( -
+

Image not available

@@ -243,6 +243,7 @@ export function Image({ id={id} className={cn( "group w-full overflow-hidden cursor-pointer transition-shadow duration-200 hover:shadow-lg", + isGenerated && "ring-1 ring-primary/10", className )} style={{ maxWidth }} @@ -258,71 +259,98 @@ export function Image({ role="button" tabIndex={0} > -
- {/* Image */} - setImageError(true)} - /> +
+ {isAutoRatio ? ( + /* Auto ratio: image renders at natural dimensions, no cropping */ + <> + {!imageLoaded && ( +
+ +
+ )} + {/* eslint-disable-next-line @next/next/no-img-element */} + {alt} setImageLoaded(true)} + onError={() => setImageError(true)} + /> + + ) : ( + /* Fixed ratio: constrained aspect container with fill */ +
+ setImageError(true)} + /> +
+ )} - {/* Hover overlay - appears on hover */} + {/* Hover overlay */}
- {/* Content at bottom */} -
- {/* Title */} +
{title && ( -

+

{title}

)} - - {/* Description */} {description && ( -

{description}

+

{description}

)} - - {/* Source attribution */} {displayDomain && (
- {source?.iconUrl ? ( + {isGenerated ? ( + + ) : source?.iconUrl ? ( ) : ( - + )} - {displayDomain} + {displayDomain}
)}
- {/* Always visible domain badge (bottom right, shown when NOT hovered) */} + {/* Badge when not hovered */} {displayDomain && !isHovered && (
+ {isGenerated && } {displayDomain}
diff --git a/surfsense_web/contracts/enums/image-gen-providers.ts b/surfsense_web/contracts/enums/image-gen-providers.ts new file mode 100644 index 000000000..8410aeb4b --- /dev/null +++ b/surfsense_web/contracts/enums/image-gen-providers.ts @@ -0,0 +1,105 @@ +export interface ImageGenProvider { + value: string; + label: string; + example: string; + description: string; + apiBase?: string; +} + +/** + * Image generation providers supported by LiteLLM. + * See: https://docs.litellm.ai/docs/image_generation#supported-providers + */ +export const IMAGE_GEN_PROVIDERS: ImageGenProvider[] = [ + { + value: "OPENAI", + label: "OpenAI", + example: "dall-e-3, gpt-image-1, dall-e-2", + description: "DALL-E and GPT Image models", + }, + { + value: "AZURE_OPENAI", + label: "Azure OpenAI", + example: "azure/dall-e-3, azure/gpt-image-1", + description: "OpenAI image models on Azure", + }, + { + value: "GOOGLE", + label: "Google AI Studio", + example: "gemini/imagen-3.0-generate-002", + description: "Google AI Studio image generation", + }, + { + value: "VERTEX_AI", + label: "Google Vertex AI", + example: "vertex_ai/imagegeneration@006", + description: "Vertex AI image generation models", + }, + { + value: "BEDROCK", + label: "AWS Bedrock", + example: "bedrock/stability.stable-diffusion-xl-v0", + description: "Stable Diffusion on AWS Bedrock", + }, + { + value: "RECRAFT", + label: "Recraft", + example: "recraft/recraftv3", + description: "AI-powered design and image generation", + }, + { + value: "OPENROUTER", + label: "OpenRouter", + example: "openrouter/google/gemini-2.5-flash-image", + description: "Image generation via OpenRouter", + }, + { + value: "XINFERENCE", + label: "Xinference", + example: "xinference/stable-diffusion-xl", + description: "Self-hosted Stable Diffusion models", + }, + { + value: "NSCALE", + label: "Nscale", + example: "nscale/flux.1-schnell", + description: "Nscale image generation", + }, +]; + +/** + * Image generation models organized by provider. 
+ */ +export interface ImageGenModel { + value: string; + label: string; + provider: string; +} + +export const IMAGE_GEN_MODELS: ImageGenModel[] = [ + // OpenAI + { value: "gpt-image-1", label: "GPT Image 1", provider: "OPENAI" }, + { value: "dall-e-3", label: "DALL-E 3", provider: "OPENAI" }, + { value: "dall-e-2", label: "DALL-E 2", provider: "OPENAI" }, + // Azure OpenAI + { value: "azure/dall-e-3", label: "DALL-E 3 (Azure)", provider: "AZURE_OPENAI" }, + { value: "azure/gpt-image-1", label: "GPT Image 1 (Azure)", provider: "AZURE_OPENAI" }, + // Recraft + { value: "recraft/recraftv3", label: "Recraft V3", provider: "RECRAFT" }, + // Bedrock + { + value: "bedrock/stability.stable-diffusion-xl-v0", + label: "Stable Diffusion XL", + provider: "BEDROCK", + }, + // Vertex AI + { + value: "vertex_ai/imagegeneration@006", + label: "Imagen 3", + provider: "VERTEX_AI", + }, +]; + +export function getImageGenModelsByProvider(provider: string): ImageGenModel[] { + return IMAGE_GEN_MODELS.filter((m) => m.provider === provider); +} diff --git a/surfsense_web/contracts/types/new-llm-config.types.ts b/surfsense_web/contracts/types/new-llm-config.types.ts index f397d4f08..3f0d39e5a 100644 --- a/surfsense_web/contracts/types/new-llm-config.types.ts +++ b/surfsense_web/contracts/types/new-llm-config.types.ts @@ -161,19 +161,105 @@ export const globalNewLLMConfig = z.object({ export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig); +// ============================================================================= +// Image Generation Config (separate table from NewLLMConfig) +// ============================================================================= + +/** + * ImageGenProvider enum - only providers that support image generation + * See: https://docs.litellm.ai/docs/image_generation#supported-providers + */ +export const imageGenProviderEnum = z.enum([ + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + 
"XINFERENCE", + "NSCALE", +]); + +export type ImageGenProvider = z.infer; + +/** + * ImageGenerationConfig - user-created image gen model configs + * Separate from NewLLMConfig: no system_instructions, no citations_enabled. + */ +export const imageGenerationConfig = z.object({ + id: z.number(), + name: z.string().max(100), + description: z.string().max(500).nullable().optional(), + provider: imageGenProviderEnum, + custom_provider: z.string().max(100).nullable().optional(), + model_name: z.string().max(100), + api_key: z.string(), + api_base: z.string().max(500).nullable().optional(), + api_version: z.string().max(50).nullable().optional(), + litellm_params: z.record(z.string(), z.any()).nullable().optional(), + created_at: z.string(), + search_space_id: z.number(), +}); + +export const createImageGenConfigRequest = imageGenerationConfig.omit({ + id: true, + created_at: true, +}); + +export const createImageGenConfigResponse = imageGenerationConfig; + +export const getImageGenConfigsResponse = z.array(imageGenerationConfig); + +export const updateImageGenConfigRequest = z.object({ + id: z.number(), + data: imageGenerationConfig + .omit({ id: true, created_at: true, search_space_id: true }) + .partial(), +}); + +export const updateImageGenConfigResponse = imageGenerationConfig; + +export const deleteImageGenConfigResponse = z.object({ + message: z.string(), + id: z.number(), +}); + +/** + * Global Image Generation Config - from YAML, has negative IDs + * ID 0 is reserved for "Auto" mode (LiteLLM Router load balancing) + */ +export const globalImageGenConfig = z.object({ + id: z.number(), + name: z.string(), + description: z.string().nullable().optional(), + provider: z.string(), + custom_provider: z.string().nullable().optional(), + model_name: z.string(), + api_base: z.string().nullable().optional(), + api_version: z.string().nullable().optional(), + litellm_params: z.record(z.string(), z.any()).nullable().optional(), + is_global: z.literal(true), + is_auto_mode: 
z.boolean().optional().default(false), +}); + +export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig); + // ============================================================================= // LLM Preferences (Role Assignments) // ============================================================================= /** * LLM Preferences schemas - for role assignments - * The agent_llm and document_summary_llm fields contain the full NewLLMConfig objects + * image_generation uses image_generation_config_id (not llm_id) */ export const llmPreferences = z.object({ agent_llm_id: z.union([z.number(), z.null()]).optional(), document_summary_llm_id: z.union([z.number(), z.null()]).optional(), + image_generation_config_id: z.union([z.number(), z.null()]).optional(), agent_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(), document_summary_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(), + image_generation_config: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(), }); /** @@ -193,6 +279,7 @@ export const updateLLMPreferencesRequest = z.object({ data: llmPreferences.pick({ agent_llm_id: true, document_summary_llm_id: true, + image_generation_config_id: true, }), }); @@ -219,6 +306,15 @@ export type GetDefaultSystemInstructionsResponse = z.infer< >; export type GlobalNewLLMConfig = z.infer; export type GetGlobalNewLLMConfigsResponse = z.infer; +export type ImageGenerationConfig = z.infer; +export type CreateImageGenConfigRequest = z.infer; +export type CreateImageGenConfigResponse = z.infer; +export type GetImageGenConfigsResponse = z.infer; +export type UpdateImageGenConfigRequest = z.infer; +export type UpdateImageGenConfigResponse = z.infer; +export type DeleteImageGenConfigResponse = z.infer; +export type GlobalImageGenConfig = z.infer; +export type GetGlobalImageGenConfigsResponse = z.infer; export type LLMPreferences = z.infer; export type GetLLMPreferencesRequest = z.infer; export type 
GetLLMPreferencesResponse = z.infer; diff --git a/surfsense_web/lib/apis/image-gen-config-api.service.ts b/surfsense_web/lib/apis/image-gen-config-api.service.ts new file mode 100644 index 000000000..84aeed3d8 --- /dev/null +++ b/surfsense_web/lib/apis/image-gen-config-api.service.ts @@ -0,0 +1,83 @@ +import { + type CreateImageGenConfigRequest, + createImageGenConfigRequest, + createImageGenConfigResponse, + type UpdateImageGenConfigRequest, + updateImageGenConfigRequest, + updateImageGenConfigResponse, + deleteImageGenConfigResponse, + getImageGenConfigsResponse, + getGlobalImageGenConfigsResponse, +} from "@/contracts/types/new-llm-config.types"; +import { ValidationError } from "../error"; +import { baseApiService } from "./base-api.service"; + +class ImageGenConfigApiService { + /** + * Get all global image generation configs (from YAML, negative IDs) + */ + getGlobalConfigs = async () => { + return baseApiService.get( + `/api/v1/global-image-generation-configs`, + getGlobalImageGenConfigsResponse + ); + }; + + /** + * Create a new image generation config for a search space + */ + createConfig = async (request: CreateImageGenConfigRequest) => { + const parsed = createImageGenConfigRequest.safeParse(request); + if (!parsed.success) { + const msg = parsed.error.issues.map((i) => i.message).join(", "); + throw new ValidationError(`Invalid request: ${msg}`); + } + return baseApiService.post( + `/api/v1/image-generation-configs`, + createImageGenConfigResponse, + { body: parsed.data } + ); + }; + + /** + * Get image generation configs for a search space + */ + getConfigs = async (searchSpaceId: number) => { + const params = new URLSearchParams({ + search_space_id: String(searchSpaceId), + }).toString(); + return baseApiService.get( + `/api/v1/image-generation-configs?${params}`, + getImageGenConfigsResponse + ); + }; + + /** + * Update an existing image generation config + */ + updateConfig = async (request: UpdateImageGenConfigRequest) => { + const parsed = 
updateImageGenConfigRequest.safeParse(request); + if (!parsed.success) { + const msg = parsed.error.issues.map((i) => i.message).join(", "); + throw new ValidationError(`Invalid request: ${msg}`); + } + const { id, data } = parsed.data; + return baseApiService.put( + `/api/v1/image-generation-configs/${id}`, + updateImageGenConfigResponse, + { body: data } + ); + }; + + /** + * Delete an image generation config + */ + deleteConfig = async (id: number) => { + return baseApiService.delete( + `/api/v1/image-generation-configs/${id}`, + deleteImageGenConfigResponse + ); + }; +} + +export const imageGenConfigApiService = new ImageGenConfigApiService(); diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 4d220a62a..c6981b28a 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -34,6 +34,11 @@ export const cacheKeys = { defaultInstructions: () => ["new-llm-configs", "default-instructions"] as const, global: () => ["new-llm-configs", "global"] as const, }, + imageGenConfigs: { + all: (searchSpaceId: number) => ["image-gen-configs", searchSpaceId] as const, + byId: (configId: number) => ["image-gen-configs", "detail", configId] as const, + global: () => ["image-gen-configs", "global"] as const, + }, auth: { user: ["auth", "user"] as const, }, diff --git a/surfsense_web/messages/en.json b/surfsense_web/messages/en.json index a1ef1f248..5a18f80c3 100644 --- a/surfsense_web/messages/en.json +++ b/surfsense_web/messages/en.json @@ -740,6 +740,8 @@ "nav_agent_configs_desc": "LLM models with prompts & citations", "nav_role_assignments": "Role Assignments", "nav_role_assignments_desc": "Assign configs to agent roles", + "nav_image_models": "Image Models", + "nav_image_models_desc": "Configure image generation models", "nav_system_instructions": "System Instructions", "nav_system_instructions_desc": "SearchSpace-wide AI instructions", "nav_public_links": "Public 
Chat Links", diff --git a/surfsense_web/messages/zh.json b/surfsense_web/messages/zh.json index 60a0d279f..1046b7296 100644 --- a/surfsense_web/messages/zh.json +++ b/surfsense_web/messages/zh.json @@ -725,6 +725,8 @@ "nav_agent_configs_desc": "LLM 模型配置提示词和引用", "nav_role_assignments": "角色分配", "nav_role_assignments_desc": "为代理角色分配配置", + "nav_image_models": "图像模型", + "nav_image_models_desc": "配置图像生成模型", "nav_system_instructions": "系统指令", "nav_system_instructions_desc": "搜索空间级别的 AI 指令", "nav_public_links": "公开聊天链接",