Merge remote-tracking branch 'upstream/dev' into fix/documents

This commit is contained in:
Anish Sarkar 2026-02-06 12:13:26 +05:30
commit 0fdd194d92
60 changed files with 5086 additions and 248 deletions

View file

@ -143,6 +143,15 @@ STT_SERVICE=local/base
PAGES_LIMIT=500
# Residential Proxy Configuration (anonymous-proxies.net)
# Used for web crawling, link previews, and YouTube transcript fetching to avoid IP bans.
# Leave commented out to disable proxying.
# RESIDENTIAL_PROXY_USERNAME=your_proxy_username
# RESIDENTIAL_PROXY_PASSWORD=your_proxy_password
# RESIDENTIAL_PROXY_HOSTNAME=rotating.dnsproxifier.com:31230
# RESIDENTIAL_PROXY_LOCATION=
# RESIDENTIAL_PROXY_TYPE=1
FIRECRAWL_API_KEY=fcr-01J0000000000000000000000
# File Parser Service

View file

@ -0,0 +1,299 @@
"""Add image generation tables and search space preference
Revision ID: 93
Revises: 92
Changes:
1. Create image_generation_configs table (user-created image model configs)
2. Create image_generations table (stores generation requests/results)
3. Add image_generation_config_id column to searchspaces table
4. Add image generation permissions to existing system roles
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import ENUM as PG_ENUM, JSONB, UUID
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "93"  # this migration's ID
down_revision: str | None = "92"  # applied on top of revision 92
branch_labels: str | Sequence[str] | None = None  # no branch labels
depends_on: str | Sequence[str] | None = None  # no cross-branch dependencies
def upgrade() -> None:
    """Apply revision 93.

    Creates the image generation tables, adds the search-space preference
    column, and grants image generation permissions to system roles. Every
    step is guarded by an existence check so the migration can be safely
    re-run after a partial failure.
    """
    connection = op.get_bind()

    # 1. Create imagegenprovider enum type if it doesn't exist.
    # Wrapped in a DO block because CREATE TYPE has no IF NOT EXISTS form.
    connection.execute(
        sa.text(
            """
            DO $$
            BEGIN
                IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'imagegenprovider') THEN
                    CREATE TYPE imagegenprovider AS ENUM (
                        'OPENAI', 'AZURE_OPENAI', 'GOOGLE', 'VERTEX_AI', 'BEDROCK',
                        'RECRAFT', 'OPENROUTER', 'XINFERENCE', 'NSCALE'
                    );
                END IF;
            END
            $$;
            """
        )
    )

    # 2. Create image_generation_configs table (uses imagegenprovider enum)
    result = connection.execute(
        sa.text(
            "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'image_generation_configs')"
        )
    )
    if not result.scalar():
        op.create_table(
            "image_generation_configs",
            sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
            sa.Column("name", sa.String(100), nullable=False),
            sa.Column("description", sa.String(500), nullable=True),
            sa.Column(
                "provider",
                PG_ENUM(
                    "OPENAI",
                    "AZURE_OPENAI",
                    "GOOGLE",
                    "VERTEX_AI",
                    "BEDROCK",
                    "RECRAFT",
                    "OPENROUTER",
                    "XINFERENCE",
                    "NSCALE",
                    name="imagegenprovider",
                    # Reuse the enum created in step 1 instead of creating it again.
                    create_type=False,
                ),
                nullable=False,
            ),
            sa.Column("custom_provider", sa.String(100), nullable=True),
            sa.Column("model_name", sa.String(100), nullable=False),
            sa.Column("api_key", sa.String(), nullable=False),
            sa.Column("api_base", sa.String(500), nullable=True),
            sa.Column("api_version", sa.String(50), nullable=True),
            # Extra passthrough kwargs for litellm (free-form JSON).
            sa.Column("litellm_params", sa.JSON(), nullable=True),
            sa.Column("search_space_id", sa.Integer(), nullable=False),
            sa.Column(
                "created_at",
                sa.TIMESTAMP(timezone=True),
                server_default=sa.text("now()"),
                nullable=False,
            ),
            sa.PrimaryKeyConstraint("id"),
            sa.ForeignKeyConstraint(
                ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
            ),
        )
    # Raw SQL so IF NOT EXISTS keeps index creation idempotent even when the
    # table pre-existed without its indexes (partial earlier run).
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_image_generation_configs_name "
        "ON image_generation_configs (name)"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_image_generation_configs_search_space_id "
        "ON image_generation_configs (search_space_id)"
    )

    # 3. Create image_generations table
    result = connection.execute(
        sa.text(
            "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'image_generations')"
        )
    )
    if not result.scalar():
        op.create_table(
            "image_generations",
            sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
            sa.Column("prompt", sa.Text(), nullable=False),
            sa.Column("model", sa.String(200), nullable=True),
            sa.Column("n", sa.Integer(), nullable=True),
            sa.Column("quality", sa.String(50), nullable=True),
            sa.Column("size", sa.String(50), nullable=True),
            sa.Column("style", sa.String(50), nullable=True),
            sa.Column("response_format", sa.String(50), nullable=True),
            # NOTE: intentionally no FK here — negative IDs refer to global
            # YAML configs that have no row in image_generation_configs.
            sa.Column("image_generation_config_id", sa.Integer(), nullable=True),
            sa.Column("response_data", JSONB(), nullable=True),
            sa.Column("error_message", sa.Text(), nullable=True),
            sa.Column("search_space_id", sa.Integer(), nullable=False),
            sa.Column("created_by_id", UUID(as_uuid=True), nullable=True),
            sa.Column(
                "created_at",
                sa.TIMESTAMP(timezone=True),
                server_default=sa.text("now()"),
                nullable=False,
            ),
            sa.PrimaryKeyConstraint("id"),
            sa.ForeignKeyConstraint(
                ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
            ),
            # Keep the generation record when the creating user is deleted.
            sa.ForeignKeyConstraint(
                ["created_by_id"], ["user.id"], ondelete="SET NULL"
            ),
        )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_image_generations_search_space_id "
        "ON image_generations (search_space_id)"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_image_generations_created_by_id "
        "ON image_generations (created_by_id)"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_image_generations_created_at "
        "ON image_generations (created_at)"
    )

    # 4. Add image_generation_config_id column to searchspaces
    result = connection.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'searchspaces'
                AND column_name = 'image_generation_config_id'
            )
            """
        )
    )
    if not result.scalar():
        op.add_column(
            "searchspaces",
            sa.Column(
                "image_generation_config_id",
                sa.Integer(),
                nullable=True,
                # 0 appears to be the "Auto mode" sentinel — TODO confirm
                # against IMAGE_GEN_AUTO_MODE_ID in the router service.
                server_default="0",
            ),
        )

    # Drop old column name if it exists (from earlier version of this migration)
    result = connection.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'searchspaces'
                AND column_name = 'image_generation_llm_id'
            )
            """
        )
    )
    if result.scalar():
        op.drop_column("searchspaces", "image_generation_llm_id")

    # Drop old column name on image_generations if it exists
    result = connection.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'image_generations'
                AND column_name = 'llm_config_id'
            )
            """
        )
    )
    if result.scalar():
        op.drop_column("image_generations", "llm_config_id")

    # Drop old api_version column on image_generations if it exists
    result = connection.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'image_generations'
                AND column_name = 'api_version'
            )
            """
        )
    )
    if result.scalar():
        op.drop_column("image_generations", "api_version")

    # 5. Add image generation permissions to existing system roles.
    # Editors get create + read; the ANY() guard keeps the update idempotent.
    connection.execute(
        sa.text(
            """
            UPDATE search_space_roles
            SET permissions = array_cat(
                permissions,
                ARRAY['image_generations:create', 'image_generations:read']
            )
            WHERE is_system_role = true
            AND name = 'Editor'
            AND NOT ('image_generations:create' = ANY(permissions))
            """
        )
    )
    # Viewers get read-only access.
    connection.execute(
        sa.text(
            """
            UPDATE search_space_roles
            SET permissions = array_cat(
                permissions,
                ARRAY['image_generations:read']
            )
            WHERE is_system_role = true
            AND name = 'Viewer'
            AND NOT ('image_generations:read' = ANY(permissions))
            """
        )
    )
def downgrade() -> None:
    """Revert revision 93.

    Removes granted permissions, the searchspaces preference column, the
    image generation tables and indexes, and finally the enum type — which
    must be dropped last, after the table whose column uses it.
    """
    connection = op.get_bind()

    # Remove permissions from all system roles. ':delete' is never granted
    # by this migration's upgrade, but is removed defensively in case it was
    # added elsewhere.
    connection.execute(
        sa.text(
            """
            UPDATE search_space_roles
            SET permissions = array_remove(
                array_remove(
                    array_remove(permissions, 'image_generations:create'),
                    'image_generations:read'
                ),
                'image_generations:delete'
            )
            WHERE is_system_role = true
            """
        )
    )

    # Remove image_generation_config_id from searchspaces
    result = connection.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'searchspaces'
                AND column_name = 'image_generation_config_id'
            )
            """
        )
    )
    if result.scalar():
        op.drop_column("searchspaces", "image_generation_config_id")

    # Drop indexes and tables (indexes first, then their tables).
    op.execute("DROP INDEX IF EXISTS ix_image_generations_created_at")
    op.execute("DROP INDEX IF EXISTS ix_image_generations_created_by_id")
    op.execute("DROP INDEX IF EXISTS ix_image_generations_search_space_id")
    op.execute("DROP TABLE IF EXISTS image_generations")
    op.execute("DROP INDEX IF EXISTS ix_image_generation_configs_search_space_id")
    op.execute("DROP INDEX IF EXISTS ix_image_generation_configs_name")
    op.execute("DROP TABLE IF EXISTS image_generation_configs")

    # Drop the imagegenprovider enum type
    op.execute("DROP TYPE IF EXISTS imagegenprovider")

View file

@ -0,0 +1,39 @@
"""Add access_token column to image_generations
Revision ID: 94
Revises: 93
Adds an indexed access_token column to the image_generations table.
This token is stored per-record so that image serving URLs survive
SECRET_KEY rotation.
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "94"  # this migration's ID
down_revision: str | None = "93"  # applied on top of revision 93
branch_labels: str | Sequence[str] | None = None  # no branch labels
depends_on: str | Sequence[str] | None = None  # no cross-branch dependencies
def upgrade() -> None:
    """Apply revision 94: add an indexed ``access_token`` column to
    image_generations.

    The column is nullable so existing rows are unaffected. Both steps are
    guarded (matching revision 93's idempotent style) so a re-run after a
    partial failure — e.g. column added but index creation interrupted —
    does not fail with a duplicate-column error.
    """
    connection = op.get_bind()

    # Add access_token column only if a previous partial run didn't already.
    result = connection.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'image_generations'
                AND column_name = 'access_token'
            )
            """
        )
    )
    if not result.scalar():
        op.add_column(
            "image_generations",
            sa.Column("access_token", sa.String(64), nullable=True),
        )

    # Raw SQL: IF NOT EXISTS keeps the index creation idempotent.
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_image_generations_access_token "
        "ON image_generations (access_token)"
    )
def downgrade() -> None:
    """Revert revision 94: drop the access_token index and column.

    Uses IF EXISTS guards (matching revision 93's downgrade style) so the
    downgrade succeeds even if the upgrade only partially applied.
    """
    op.execute("DROP INDEX IF EXISTS ix_image_generations_access_token")
    op.execute("ALTER TABLE image_generations DROP COLUMN IF EXISTS access_token")

View file

@ -133,6 +133,7 @@ async def create_surfsense_deep_agent(
The agent comes with built-in tools that can be configured:
- search_knowledge_base: Search the user's personal knowledge base
- generate_podcast: Generate audio podcasts from content
- generate_image: Generate images from text descriptions using AI models
- link_preview: Fetch rich previews for URLs
- display_image: Display images in chat
- scrape_webpage: Extract content from webpages

View file

@ -3,15 +3,25 @@ PostgreSQL-based checkpointer for LangGraph agents.
This module provides a persistent checkpointer using AsyncPostgresSaver
that stores conversation state in the PostgreSQL database.
Uses a connection pool (psycopg_pool.AsyncConnectionPool) to handle
connection lifecycle, health checks, and automatic reconnection,
preventing 'the connection is closed' errors in long-running deployments.
"""
import logging
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
from app.config import config
logger = logging.getLogger(__name__)
# Global checkpointer instance (initialized lazily)
_checkpointer: AsyncPostgresSaver | None = None
_checkpointer_context = None # Store the context manager for cleanup
_connection_pool: AsyncConnectionPool | None = None
_checkpointer_initialized: bool = False
@ -38,26 +48,65 @@ def get_postgres_connection_string() -> str:
return db_url
async def _create_checkpointer() -> AsyncPostgresSaver:
"""
Create a new AsyncPostgresSaver backed by a connection pool.
The connection pool automatically handles:
- Connection health checks before use
- Reconnection when connections die (idle timeout, DB restart, etc.)
- Connection lifecycle management (max_lifetime, max_idle)
"""
global _connection_pool
conn_string = get_postgres_connection_string()
_connection_pool = AsyncConnectionPool(
conninfo=conn_string,
min_size=2,
max_size=10,
# Connections are recycled after 30 minutes to avoid stale connections
max_lifetime=1800,
# Idle connections are closed after 5 minutes
max_idle=300,
open=False,
# Connection kwargs required by AsyncPostgresSaver:
# - autocommit: required for .setup() to commit checkpoint tables
# - prepare_threshold: disable prepared statements for compatibility
# - row_factory: checkpointer accesses rows as dicts (row["column"])
kwargs={
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
},
)
await _connection_pool.open(wait=True)
checkpointer = AsyncPostgresSaver(conn=_connection_pool)
logger.info("[Checkpointer] Created AsyncPostgresSaver with connection pool")
return checkpointer
async def get_checkpointer() -> AsyncPostgresSaver:
"""
Get or create the global AsyncPostgresSaver instance.
This function:
1. Creates the checkpointer if it doesn't exist
1. Creates the checkpointer with a connection pool if it doesn't exist
2. Sets up the required database tables on first call
3. Returns the cached instance on subsequent calls
The underlying connection pool handles reconnection automatically,
so a stale/closed connection will not cause OperationalError.
Returns:
AsyncPostgresSaver: The configured checkpointer instance
"""
global _checkpointer, _checkpointer_context, _checkpointer_initialized
global _checkpointer, _checkpointer_initialized
if _checkpointer is None:
conn_string = get_postgres_connection_string()
# from_conn_string returns an async context manager
# We need to enter the context to get the actual checkpointer
_checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string)
_checkpointer = await _checkpointer_context.__aenter__()
_checkpointer = await _create_checkpointer()
_checkpointer_initialized = False
# Setup tables on first call (idempotent)
if not _checkpointer_initialized:
@ -75,20 +124,21 @@ async def setup_checkpointer_tables() -> None:
tables exist before any agent calls.
"""
await get_checkpointer()
print("[Checkpointer] PostgreSQL checkpoint tables ready")
logger.info("[Checkpointer] PostgreSQL checkpoint tables ready")
async def close_checkpointer() -> None:
"""
Close the checkpointer connection.
Close the checkpointer connection pool.
This should be called during application shutdown.
"""
global _checkpointer, _checkpointer_context, _checkpointer_initialized
global _checkpointer, _connection_pool, _checkpointer_initialized
if _checkpointer_context is not None:
await _checkpointer_context.__aexit__(None, None, None)
_checkpointer = None
_checkpointer_context = None
_checkpointer_initialized = False
print("[Checkpointer] PostgreSQL connection closed")
if _connection_pool is not None:
await _connection_pool.close()
logger.info("[Checkpointer] PostgreSQL connection pool closed")
_checkpointer = None
_connection_pool = None
_checkpointer_initialized = False

View file

@ -83,6 +83,7 @@ You have access to the following tools:
* Showing an image from a URL the user explicitly mentioned in their message
* Displaying images found in scraped webpage content (from scrape_webpage tool)
* Showing a publicly accessible diagram or chart from a known URL
* Displaying an AI-generated image after calling the generate_image tool (ALWAYS required)
CRITICAL - NEVER USE THIS TOOL FOR USER-UPLOADED ATTACHMENTS:
When a user uploads/attaches an image file to their message:
@ -100,7 +101,21 @@ You have access to the following tools:
- Returns: An image card with the image, title, and description
- The image will automatically be displayed in the chat.
5. scrape_webpage: Scrape and extract the main content from a webpage.
5. generate_image: Generate images from text descriptions using AI image models.
- Use this when the user asks you to create, generate, draw, design, or make an image.
- Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork"
- Args:
- prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood.
- n: Number of images to generate (1-4, default: 1)
- Returns: A dictionary with the generated image URL in the "src" field, along with metadata.
- CRITICAL: After calling generate_image, you MUST call `display_image` with the returned "src" URL
to actually show the image in the chat. The generate_image tool only generates the image and returns
the URL; it does NOT display anything. You must always follow up with display_image.
- IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim -
expand and improve the prompt with specific details about style, lighting, composition, and mood.
- If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details.
6. scrape_webpage: Scrape and extract the main content from a webpage.
- Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage.
- IMPORTANT: This is different from link_preview:
* link_preview: Only fetches metadata (title, description, thumbnail) for display
@ -123,7 +138,7 @@ You have access to the following tools:
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
6. save_memory: Save facts, preferences, or context about the user for personalized responses.
7. save_memory: Save facts, preferences, or context about the user for personalized responses.
- Use this when the user explicitly or implicitly shares information worth remembering.
- Trigger scenarios:
* User says "remember this", "keep this in mind", "note that", or similar
@ -146,7 +161,7 @@ You have access to the following tools:
- IMPORTANT: Only save information that would be genuinely useful for future conversations.
Don't save trivial or temporary information.
7. recall_memory: Retrieve relevant memories about the user for personalized responses.
8. recall_memory: Retrieve relevant memories about the user for personalized responses.
- Use this to access stored information about the user.
- Trigger scenarios:
* You need user context to give a better, more personalized answer
@ -281,6 +296,22 @@ You have access to the following tools:
- Then, if the content contains useful diagrams/images like `![Neural Network Diagram](https://example.com/nn-diagram.png)`:
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
- Then provide your explanation, referencing the displayed image
- User: "Generate an image of a cat"
- Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")`
- Step 2: Use the returned "src" URL to display it: `display_image(src="<returned_url>", alt="A fluffy orange tabby cat on a windowsill", title="Generated Image")`
- User: "Create a landscape painting of mountains"
- Step 1: `generate_image(prompt="Majestic snow-capped mountain range at sunset, dramatic orange and purple sky, alpine meadow with wildflowers in the foreground, oil painting style with visible brushstrokes, inspired by the Hudson River School art movement")`
- Step 2: `display_image(src="<returned_url>", alt="Mountain landscape painting", title="Generated Image")`
- User: "Draw me a logo for a coffee shop called Bean Dream"
- Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")`
- Step 2: `display_image(src="<returned_url>", alt="Bean Dream coffee shop logo", title="Generated Image")`
- User: "Make a wide banner image for my blog about AI"
- Step 1: `generate_image(prompt="Wide banner illustration for an AI technology blog, featuring abstract neural network patterns, glowing blue and purple connections, modern futuristic aesthetic, digital art style, clean and professional")`
- Step 2: `display_image(src="<returned_url>", alt="AI blog banner", title="Generated Image")`
</tool_call_examples>
"""

View file

@ -8,6 +8,7 @@ Available tools:
- search_knowledge_base: Search the user's personal knowledge base
- search_surfsense_docs: Search Surfsense documentation for usage help
- generate_podcast: Generate audio podcasts from content
- generate_image: Generate images from text descriptions using AI models
- link_preview: Fetch rich previews for URLs
- display_image: Display images in chat
- scrape_webpage: Extract content from webpages
@ -18,6 +19,7 @@ Available tools:
# Registry exports
# Tool factory exports (for direct use)
from .display_image import create_display_image_tool
from .generate_image import create_generate_image_tool
from .knowledge_base import (
CONNECTOR_DESCRIPTIONS,
create_search_knowledge_base_tool,
@ -47,6 +49,7 @@ __all__ = [
"build_tools",
# Tool factories
"create_display_image_tool",
"create_generate_image_tool",
"create_generate_podcast_tool",
"create_link_preview_tool",
"create_recall_memory_tool",

View file

@ -82,14 +82,20 @@ def create_display_image_tool():
domain = extract_domain(src)
# Determine aspect ratio based on common image sources
ratio = "16:9" # Default
if "unsplash.com" in src or "pexels.com" in src:
# Determine aspect ratio based on image source
# AI-generated images should use "auto" to preserve their native ratio
is_generated = "/image-generations/" in src
if is_generated:
ratio = "auto"
domain = "ai-generated"
elif "unsplash.com" in src or "pexels.com" in src:
ratio = "16:9"
elif (
"imgur.com" in src or "github.com" in src or "githubusercontent.com" in src
):
ratio = "auto"
else:
ratio = "auto"
return {
"id": image_id,

View file

@ -0,0 +1,242 @@
"""
Image generation tool for the SurfSense agent.
This module provides a tool that generates images using litellm.aimage_generation()
and returns the result via the existing display_image tool format so the frontend
renders the generated image inline in the chat.
Config resolution:
1. Uses the search space's image_generation_config_id preference
2. Falls back to Auto mode (router load balancing) if available
3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs)
"""
import logging
from typing import Any
from langchain_core.tools import tool
from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__)
# Provider mapping (same as routes)
_PROVIDER_MAP = {
"OPENAI": "openai",
"AZURE_OPENAI": "azure",
"GOOGLE": "gemini",
"VERTEX_AI": "vertex_ai",
"BEDROCK": "bedrock",
"RECRAFT": "recraft",
"OPENROUTER": "openrouter",
"XINFERENCE": "xinference",
"NSCALE": "nscale",
}
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
if custom_provider:
return f"{custom_provider}/{model_name}"
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
return f"{prefix}/{model_name}"
def _get_global_image_gen_config(config_id: int) -> dict | None:
    """Look up a global (YAML-defined) image gen config by its negative ID.

    Returns None when no entry in config.GLOBAL_IMAGE_GEN_CONFIGS matches.
    """
    return next(
        (entry for entry in config.GLOBAL_IMAGE_GEN_CONFIGS if entry.get("id") == config_id),
        None,
    )
def create_generate_image_tool(
    search_space_id: int,
    db_session: AsyncSession,
):
    """
    Factory function to create the generate_image tool.

    Args:
        search_space_id: The search space ID (for config resolution)
        db_session: Async database session used to resolve the image
            generation config and persist the generation record

    Returns:
        The bound ``generate_image`` LangChain tool.
    """

    @tool
    async def generate_image(
        prompt: str,
        n: int = 1,
    ) -> dict[str, Any]:
        """
        Generate an image from a text description using AI image models.

        Use this tool when the user asks you to create, generate, draw, or make an image.
        The generated image will be displayed directly in the chat.

        Args:
            prompt: A detailed text description of the image to generate.
                Be specific about subject, style, colors, composition, and mood.
            n: Number of images to generate (1-4). Default: 1

        Returns:
            A dictionary containing the generated image(s) for display in the chat.
            On failure, a dictionary with an "error" key (never raises).
        """
        try:
            # Clamp n into the documented 1-4 range so an out-of-range value
            # from the LLM doesn't surface as a provider API error.
            n = 1 if n is None else max(1, min(n, 4))

            # Resolve the image generation config from the search space preference
            result = await db_session.execute(
                select(SearchSpace).filter(SearchSpace.id == search_space_id)
            )
            search_space = result.scalars().first()
            if not search_space:
                return {"error": "Search space not found"}

            # A falsy preference (None or 0) falls back to Auto mode.
            config_id = (
                search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
            )

            # Build generation kwargs
            # NOTE: size, quality, and style are intentionally NOT passed.
            # Different models support different values for these params
            # (e.g. DALL-E 3 wants "hd"/"standard" for quality while
            # gpt-image-1 wants "high"/"medium"/"low"; size options also
            # differ). Letting the model use its own defaults avoids errors.
            gen_kwargs: dict[str, Any] = {}
            if n > 1:
                gen_kwargs["n"] = n

            # Call litellm based on config type:
            #  - Auto mode  -> router load-balances over configured models
            #  - negative ID -> global YAML config
            #  - positive ID -> user-created ImageGenerationConfig row
            if is_image_gen_auto_mode(config_id):
                if not ImageGenRouterService.is_initialized():
                    return {
                        "error": "No image generation models configured. "
                        "Please add an image model in Settings > Image Models."
                    }
                response = await ImageGenRouterService.aimage_generation(
                    prompt=prompt, model="auto", **gen_kwargs
                )
            elif config_id < 0:
                cfg = _get_global_image_gen_config(config_id)
                if not cfg:
                    return {"error": f"Image generation config {config_id} not found"}
                model_string = _build_model_string(
                    cfg.get("provider", ""),
                    cfg["model_name"],
                    cfg.get("custom_provider"),
                )
                gen_kwargs["api_key"] = cfg.get("api_key")
                if cfg.get("api_base"):
                    gen_kwargs["api_base"] = cfg["api_base"]
                if cfg.get("api_version"):
                    gen_kwargs["api_version"] = cfg["api_version"]
                if cfg.get("litellm_params"):
                    gen_kwargs.update(cfg["litellm_params"])
                response = await aimage_generation(
                    prompt=prompt, model=model_string, **gen_kwargs
                )
            else:
                # Positive ID = user-created ImageGenerationConfig
                cfg_result = await db_session.execute(
                    select(ImageGenerationConfig).filter(
                        ImageGenerationConfig.id == config_id
                    )
                )
                db_cfg = cfg_result.scalars().first()
                if not db_cfg:
                    return {"error": f"Image generation config {config_id} not found"}
                model_string = _build_model_string(
                    db_cfg.provider.value,
                    db_cfg.model_name,
                    db_cfg.custom_provider,
                )
                gen_kwargs["api_key"] = db_cfg.api_key
                if db_cfg.api_base:
                    gen_kwargs["api_base"] = db_cfg.api_base
                if db_cfg.api_version:
                    gen_kwargs["api_version"] = db_cfg.api_version
                if db_cfg.litellm_params:
                    gen_kwargs.update(db_cfg.litellm_params)
                response = await aimage_generation(
                    prompt=prompt, model=model_string, **gen_kwargs
                )

            # Parse the response and store in DB
            response_dict = (
                response.model_dump()
                if hasattr(response, "model_dump")
                else dict(response)
            )

            # Generate a random access token for this image so the serving
            # URL survives SECRET_KEY rotation.
            access_token = generate_image_token()

            # Save to image_generations table for history.
            # NOTE(review): assumes litellm's _hidden_params is dict-like
            # when present — confirm against the litellm response type.
            db_image_gen = ImageGeneration(
                prompt=prompt,
                model=getattr(response, "_hidden_params", {}).get("model"),
                n=n,
                image_generation_config_id=config_id,
                response_data=response_dict,
                search_space_id=search_space_id,
                access_token=access_token,
            )
            db_session.add(db_image_gen)
            await db_session.commit()
            await db_session.refresh(db_image_gen)

            # Extract image URLs from response
            images = response_dict.get("data", [])
            if not images:
                return {"error": "No images were generated"}

            first_image = images[0]
            revised_prompt = first_image.get("revised_prompt", prompt)

            # Resolve image URL:
            # - If the API returned a URL, use it directly.
            # - If the API returned b64_json (e.g. gpt-image-1), serve the
            #   image through our backend endpoint to avoid bloating the
            #   LLM context with megabytes of base64 data.
            if first_image.get("url"):
                image_url = first_image["url"]
            elif first_image.get("b64_json"):
                backend_url = config.BACKEND_URL or "http://localhost:8000"
                image_url = (
                    f"{backend_url}/api/v1/image-generations/"
                    f"{db_image_gen.id}/image?token={access_token}"
                )
            else:
                return {"error": "No displayable image data in the response"}

            return {
                "src": image_url,
                "alt": revised_prompt or prompt,
                "title": "Generated Image",
                # Only include a description when the provider rewrote the prompt.
                "description": revised_prompt if revised_prompt != prompt else None,
                "generated": True,
                "prompt": prompt,
                "image_count": len(images),
            }
        except Exception as e:
            # Return the error to the agent instead of raising so the LLM can
            # relay a helpful message to the user.
            logger.exception("Image generation failed in tool")
            return {
                "error": f"Image generation failed: {e!s}",
                "prompt": prompt,
            }

    return generate_image

View file

@ -17,6 +17,8 @@ from fake_useragent import UserAgent
from langchain_core.tools import tool
from playwright.async_api import async_playwright
from app.utils.proxy_config import get_playwright_proxy, get_residential_proxy_url
logger = logging.getLogger(__name__)
@ -186,9 +188,15 @@ async def fetch_with_chromium(url: str) -> dict[str, Any] | None:
ua = UserAgent()
user_agent = ua.random
# Use residential proxy if configured
playwright_proxy = get_playwright_proxy()
# Use Playwright to fetch the page
async with async_playwright() as p:
browser = await p.chromium.launch(headless=True)
launch_kwargs: dict = {"headless": True}
if playwright_proxy:
launch_kwargs["proxy"] = playwright_proxy
browser = await p.chromium.launch(**launch_kwargs)
context = await browser.new_context(user_agent=user_agent)
page = await context.new_page()
@ -283,12 +291,16 @@ def create_link_preview_tool():
ua = UserAgent()
user_agent = ua.random
# Use residential proxy if configured
proxy_url = get_residential_proxy_url()
# Use a browser-like User-Agent to fetch Open Graph metadata.
# We're only fetching publicly available metadata (title, description, thumbnail)
# that websites intentionally expose via OG tags for link preview purposes.
async with httpx.AsyncClient(
timeout=10.0,
follow_redirects=True,
proxy=proxy_url,
headers={
"User-Agent": user_agent,
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8",

View file

@ -44,6 +44,7 @@ from typing import Any
from langchain_core.tools import BaseTool
from .display_image import create_display_image_tool
from .generate_image import create_generate_image_tool
from .knowledge_base import create_search_knowledge_base_tool
from .link_preview import create_link_preview_tool
from .mcp_tool import load_mcp_tools
@ -125,6 +126,16 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
factory=lambda deps: create_display_image_tool(),
requires=[],
),
# Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.)
ToolDefinition(
name="generate_image",
description="Generate images from text descriptions using AI image models",
factory=lambda deps: create_generate_image_tool(
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
),
requires=["search_space_id", "db_session"],
),
# Web scraping tool - extracts content from webpages
ToolDefinition(
name="scrape_webpage",

View file

@ -2,17 +2,26 @@
Web scraping tool for the SurfSense agent.
This module provides a tool for scraping and extracting content from webpages
using the existing WebCrawlerConnector. The scraped content can be used by
the agent to answer questions about web pages.
using the existing WebCrawlerConnector. For YouTube URLs, it fetches the
transcript directly via the YouTubeTranscriptApi instead of crawling the page.
"""
import hashlib
import logging
from typing import Any
from urllib.parse import urlparse
import aiohttp
from fake_useragent import UserAgent
from langchain_core.tools import tool
from requests import Session
from youtube_transcript_api import YouTubeTranscriptApi
from app.connectors.webcrawler_connector import WebCrawlerConnector
from app.tasks.document_processors.youtube_processor import get_youtube_video_id
from app.utils.proxy_config import get_requests_proxies
logger = logging.getLogger(__name__)
def extract_domain(url: str) -> str:
@ -57,6 +66,89 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
return truncated + "\n\n[Content truncated...]", True
async def _scrape_youtube_video(
    url: str, video_id: str, max_length: int
) -> dict[str, Any]:
    """
    Fetch YouTube video metadata and transcript via the YouTubeTranscriptApi.

    Returns a result dict in the same shape as the regular scrape_webpage output.
    Metadata comes from the public oEmbed endpoint (best-effort); the transcript
    comes from YouTubeTranscriptApi, routed through the residential proxy when
    one is configured.
    """
    scrape_id = generate_scrape_id(url)
    domain = "youtube.com"

    # --- Video metadata via oEmbed (best-effort; failures fall back below) ---
    residential_proxies = get_requests_proxies()
    # assumes get_requests_proxies() includes an "http" key when non-empty — TODO confirm
    proxy_url = residential_proxies["http"] if residential_proxies else None
    oembed_params = {
        "format": "json",
        "url": f"https://www.youtube.com/watch?v={video_id}",
    }
    try:
        async with aiohttp.ClientSession() as http_session:
            async with http_session.get(
                "https://www.youtube.com/oembed",
                params=oembed_params,
                proxy=proxy_url,
            ) as response:
                video_data = await response.json()
    except Exception:
        video_data = {}

    title = video_data.get("title", "YouTube Video")
    author = video_data.get("author_name", "Unknown")

    # --- Transcript via YouTubeTranscriptApi ---
    try:
        api_session = Session()
        api_session.headers.update({"User-Agent": UserAgent().random})
        if residential_proxies:
            api_session.proxies.update(residential_proxies)
        captions = YouTubeTranscriptApi(http_client=api_session).fetch(video_id)
        # One "[start-end] text" line per caption segment.
        transcript_text = "\n".join(
            f"[{seg.start:.2f}s-{seg.start + seg.duration:.2f}s] {seg.text}"
            for seg in captions
        )
    except Exception as e:
        logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
        transcript_text = f"No captions available for this video. Error: {e!s}"

    # Build combined markdown content, then truncate to the caller's limit.
    content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
    content, was_truncated = truncate_content(content, max_length)

    return {
        "id": scrape_id,
        "assetId": url,
        "kind": "article",
        "href": url,
        "title": title,
        "description": f"YouTube video by {author}",
        "content": content,
        "domain": domain,
        "word_count": len(content.split()),
        "was_truncated": was_truncated,
        "crawler_type": "youtube_transcript",
        "author": author,
    }
def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
"""
Factory function to create the scrape_webpage tool.
@ -79,7 +171,8 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
Use this tool when the user wants you to read, summarize, or answer
questions about a specific webpage's content. This tool actually
fetches and reads the full page content.
fetches and reads the full page content. For YouTube video URLs it
fetches the transcript directly instead of crawling the page.
Common triggers:
- "Read this article and summarize it"
@ -114,6 +207,11 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
url = f"https://{url}"
try:
# Check if this is a YouTube URL and use transcript API instead
video_id = get_youtube_video_id(url)
if video_id:
return await _scrape_youtube_video(url, video_id, max_length)
# Create webcrawler connector
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
@ -184,7 +282,7 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
except Exception as e:
error_message = str(e)
print(f"[scrape_webpage] Error scraping {url}: {error_message}")
logger.error(f"[scrape_webpage] Error scraping {url}: {error_message}")
return {
"id": scrape_id,
"assetId": url,

View file

@ -9,7 +9,7 @@ from app.agents.new_chat.checkpointer import (
close_checkpointer,
setup_checkpointer_tables,
)
from app.config import config, initialize_llm_router
from app.config import config, initialize_image_gen_router, initialize_llm_router
from app.db import User, create_db_and_tables, get_async_session
from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
@ -26,6 +26,8 @@ async def lifespan(app: FastAPI):
await setup_checkpointer_tables()
# Initialize LLM Router for Auto mode load balancing
initialize_llm_router()
# Initialize Image Generation Router for Auto mode load balancing
initialize_image_gen_router()
# Seed Surfsense documentation
await seed_surfsense_docs()
yield

View file

@ -13,14 +13,15 @@ load_dotenv()
@worker_process_init.connect
def init_worker(**kwargs):
    """Initialize the LLM Router and Image Gen Router when a Celery worker process starts.

    This ensures the Auto mode (LiteLLM Router) is available for background tasks
    like document summarization and image generation.

    Args:
        **kwargs: Signal payload supplied by Celery's worker_process_init; unused.
    """
    # Imported lazily so app.config (and its YAML loading) runs inside the
    # worker process, after fork, rather than at module import time.
    from app.config import initialize_image_gen_router, initialize_llm_router

    initialize_llm_router()
    initialize_image_gen_router()
# Get Celery configuration from environment

View file

@ -81,6 +81,56 @@ def load_router_settings():
return default_settings
def load_global_image_gen_configs():
    """
    Load global image generation configurations from YAML file.

    An absent file, an empty YAML document (``safe_load`` returns None), or a
    malformed/non-list section all yield an empty list instead of raising.

    Returns:
        list: List of global image generation config dictionaries, or empty list
    """
    global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
    if not global_config_file.exists():
        return []
    try:
        with open(global_config_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # safe_load returns None for an empty file; treat any non-dict
        # document as "no configs" rather than logging a spurious warning.
        if not isinstance(data, dict):
            return []
        configs = data.get("global_image_generation_configs", [])
        return configs if isinstance(configs, list) else []
    except Exception as e:
        print(f"Warning: Failed to load global image generation configs: {e}")
        return []
def load_image_gen_router_settings():
    """
    Load router settings for image generation Auto mode from YAML file.

    A missing file, an empty YAML document, or an invalid settings section all
    fall back to the defaults below; keys present in the YAML override them.

    Returns:
        dict: Router settings dictionary
    """
    default_settings = {
        "routing_strategy": "usage-based-routing",
        "num_retries": 3,
        "allowed_fails": 3,
        "cooldown_time": 60,
    }
    global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
    if not global_config_file.exists():
        return default_settings
    try:
        with open(global_config_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # safe_load returns None for an empty file; guard before .get().
        if not isinstance(data, dict):
            return default_settings
        settings = data.get("image_generation_router_settings", {})
        if not isinstance(settings, dict):
            return default_settings
        # YAML-provided keys win over defaults.
        return {**default_settings, **settings}
    except Exception as e:
        print(f"Warning: Failed to load image generation router settings: {e}")
        return default_settings
def initialize_llm_router():
"""
Initialize the LLM Router service for Auto mode.
@ -105,6 +155,33 @@ def initialize_llm_router():
print(f"Warning: Failed to initialize LLM Router: {e}")
def initialize_image_gen_router():
    """
    Initialize the Image Generation Router service for Auto mode.

    This should be called during application startup. When no global image
    generation configs exist, Auto mode is simply unavailable (not an error).
    """
    configs = load_global_image_gen_configs()
    settings = load_image_gen_router_settings()

    if not configs:
        print(
            "Info: No global image generation configs found, "
            "Image Generation Auto mode will not be available"
        )
        return

    try:
        # Imported lazily to avoid pulling the router service in when unused.
        from app.services.image_gen_router_service import ImageGenRouterService

        ImageGenRouterService.initialize(configs, settings)
        strategy = settings.get("routing_strategy", "usage-based-routing")
        print(
            f"Info: Image Generation Router initialized with {len(configs)} models "
            f"(strategy: {strategy})"
        )
    except Exception as e:
        print(f"Warning: Failed to initialize Image Generation Router: {e}")
class Config:
# Check if ffmpeg is installed
if not is_ffmpeg_installed():
@ -216,6 +293,12 @@ class Config:
# Router settings for Auto mode (LiteLLM Router load balancing)
ROUTER_SETTINGS = load_router_settings()
# Global Image Generation Configurations (optional)
GLOBAL_IMAGE_GEN_CONFIGS = load_global_image_gen_configs()
# Router settings for Image Generation Auto mode
IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings()
# Chonkie Configuration | Edit this to your needs
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
# Azure OpenAI credentials from environment variables
@ -277,6 +360,14 @@ class Config:
# LlamaCloud API Key
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
# Residential Proxy Configuration (anonymous-proxies.net)
# Used for web crawling and YouTube transcript fetching to avoid IP bans.
RESIDENTIAL_PROXY_USERNAME = os.getenv("RESIDENTIAL_PROXY_USERNAME")
RESIDENTIAL_PROXY_PASSWORD = os.getenv("RESIDENTIAL_PROXY_PASSWORD")
RESIDENTIAL_PROXY_HOSTNAME = os.getenv("RESIDENTIAL_PROXY_HOSTNAME")
RESIDENTIAL_PROXY_LOCATION = os.getenv("RESIDENTIAL_PROXY_LOCATION", "")
RESIDENTIAL_PROXY_TYPE = int(os.getenv("RESIDENTIAL_PROXY_TYPE", "1"))
# Litellm TTS Configuration
TTS_SERVICE = os.getenv("TTS_SERVICE")
TTS_SERVICE_API_BASE = os.getenv("TTS_SERVICE_API_BASE")

View file

@ -183,6 +183,69 @@ global_llm_configs:
use_default_system_instructions: true
citations_enabled: true
# =============================================================================
# Image Generation Configuration
# =============================================================================
# These configurations power the image generation feature using litellm.aimage_generation().
# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock,
# Recraft, OpenRouter, Xinference, Nscale
#
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs.
# Router Settings for Image Generation Auto Mode
image_generation_router_settings:
routing_strategy: "usage-based-routing"
num_retries: 3
allowed_fails: 3
cooldown_time: 60
global_image_generation_configs:
# Example: OpenAI DALL-E 3
- id: -1
name: "Global DALL-E 3"
description: "OpenAI's DALL-E 3 for high-quality image generation"
provider: "OPENAI"
model_name: "dall-e-3"
api_key: "sk-your-openai-api-key-here"
api_base: ""
rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens)
litellm_params: {}
# Example: OpenAI GPT Image 1
- id: -2
name: "Global GPT Image 1"
description: "OpenAI's GPT Image 1 model"
provider: "OPENAI"
model_name: "gpt-image-1"
api_key: "sk-your-openai-api-key-here"
api_base: ""
rpm: 50
litellm_params: {}
# Example: Azure OpenAI DALL-E 3
- id: -3
name: "Global Azure DALL-E 3"
description: "Azure-hosted DALL-E 3 deployment"
provider: "AZURE_OPENAI"
model_name: "azure/dall-e-3-deployment"
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
api_version: "2024-02-15-preview"
rpm: 50
litellm_params:
base_model: "dall-e-3"
# Example: OpenRouter Gemini Image Generation
# - id: -4
# name: "Global Gemini Image Gen"
# description: "Google Gemini image generation via OpenRouter"
# provider: "OPENROUTER"
# model_name: "google/gemini-2.5-flash-image"
# api_key: "your-openrouter-api-key-here"
# api_base: ""
# rpm: 30
# litellm_params: {}
# Notes:
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
@ -195,10 +258,11 @@ global_llm_configs:
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
# These help the router distribute load evenly and avoid rate limit errors
#
# AZURE-SPECIFIC NOTES:
# - Always add 'base_model' in litellm_params for Azure deployments
# - This fixes "Could not identify azure model 'X'" warnings
# - base_model should match the underlying OpenAI model (e.g., gpt-4o, gpt-4-turbo, gpt-3.5-turbo)
# - model_name format: "azure/<your-deployment-name>"
# - api_version: Use a recent Azure API version (e.g., "2024-02-15-preview")
# - See: https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
#
# IMAGE GENERATION NOTES:
# - Image generation configs use the same ID scheme as LLM configs (negative for global)
# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure),
# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter)
# - The router uses litellm.aimage_generation() for async image generation
# - Only RPM (requests per minute) is relevant for image generation rate limiting.
# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token.

View file

@ -108,7 +108,9 @@ class AirtableHistoryConnector:
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
if not final_token or (
isinstance(final_token, str) and not final_token.strip()
):
raise ValueError(
"Airtable access token is invalid or empty. "
"Please reconnect your Airtable account."

View file

@ -128,7 +128,9 @@ class ConfluenceHistoryConnector:
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
if not final_token or (
isinstance(final_token, str) and not final_token.strip()
):
raise ValueError(
"Confluence access token is invalid or empty. "
"Please reconnect your Confluence account."

View file

@ -129,7 +129,9 @@ class JiraHistoryConnector:
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
if not final_token or (
isinstance(final_token, str) and not final_token.strip()
):
raise ValueError(
"Jira access token is invalid or empty. "
"Please reconnect your Jira account."

View file

@ -153,7 +153,9 @@ class LinearConnector:
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
if not final_token or (
isinstance(final_token, str) and not final_token.strip()
):
raise ValueError(
"Linear access token is invalid or empty. "
"Please reconnect your Linear account."

View file

@ -14,6 +14,8 @@ from fake_useragent import UserAgent
from firecrawl import AsyncFirecrawlApp
from playwright.async_api import async_playwright
from app.utils.proxy_config import get_playwright_proxy
logger = logging.getLogger(__name__)
@ -165,9 +167,15 @@ class WebCrawlerConnector:
ua = UserAgent()
user_agent = ua.random
# Use residential proxy if configured
playwright_proxy = get_playwright_proxy()
# Use Playwright to fetch the page
async with async_playwright() as p:
browser = await p.chromium.launch(headless=True)
launch_kwargs: dict = {"headless": True}
if playwright_proxy:
launch_kwargs["proxy"] = playwright_proxy
browser = await p.chromium.launch(**launch_kwargs)
context = await browser.new_context(user_agent=user_agent)
page = await context.new_page()

View file

@ -214,6 +214,24 @@ class LiteLLMProvider(str, Enum):
CUSTOM = "CUSTOM"
class ImageGenProvider(str, Enum):
    """
    Enum for image generation providers supported by LiteLLM.

    This is a subset of LLM providers — only those that support image generation.
    See: https://docs.litellm.ai/docs/image_generation#supported-providers
    """

    OPENAI = "OPENAI"
    AZURE_OPENAI = "AZURE_OPENAI"
    GOOGLE = "GOOGLE"  # Google AI Studio
    VERTEX_AI = "VERTEX_AI"
    BEDROCK = "BEDROCK"  # AWS Bedrock
    RECRAFT = "RECRAFT"
    OPENROUTER = "OPENROUTER"
    XINFERENCE = "XINFERENCE"
    NSCALE = "NSCALE"
class LogLevel(str, Enum):
DEBUG = "DEBUG"
INFO = "INFO"
@ -314,6 +332,11 @@ class Permission(str, Enum):
PODCASTS_UPDATE = "podcasts:update"
PODCASTS_DELETE = "podcasts:delete"
# Image Generations
IMAGE_GENERATIONS_CREATE = "image_generations:create"
IMAGE_GENERATIONS_READ = "image_generations:read"
IMAGE_GENERATIONS_DELETE = "image_generations:delete"
# Connectors
CONNECTORS_CREATE = "connectors:create"
CONNECTORS_READ = "connectors:read"
@ -375,6 +398,9 @@ DEFAULT_ROLE_PERMISSIONS = {
Permission.PODCASTS_CREATE.value,
Permission.PODCASTS_READ.value,
Permission.PODCASTS_UPDATE.value,
# Image Generations (create and read, no delete)
Permission.IMAGE_GENERATIONS_CREATE.value,
Permission.IMAGE_GENERATIONS_READ.value,
# Connectors (no delete)
Permission.CONNECTORS_CREATE.value,
Permission.CONNECTORS_READ.value,
@ -404,6 +430,8 @@ DEFAULT_ROLE_PERMISSIONS = {
Permission.LLM_CONFIGS_READ.value,
# Podcasts (read only)
Permission.PODCASTS_READ.value,
# Image Generations (read only)
Permission.IMAGE_GENERATIONS_READ.value,
# Connectors (read only)
Permission.CONNECTORS_READ.value,
# Logs (read only)
@ -969,6 +997,103 @@ class Podcast(BaseModel, TimestampMixin):
thread = relationship("NewChatThread")
class ImageGenerationConfig(BaseModel, TimestampMixin):
    """
    Dedicated configuration table for image generation models.

    Separate from NewLLMConfig because image generation models don't need
    system_instructions, citations_enabled, or use_default_system_instructions.
    They only need provider credentials and model parameters.
    """

    __tablename__ = "image_generation_configs"

    # Display name shown in the UI; indexed for listing/lookup.
    name = Column(String(100), nullable=False, index=True)
    description = Column(String(500), nullable=True)

    # Provider & model (uses ImageGenProvider, NOT LiteLLMProvider)
    provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False)
    # Optional free-form provider prefix overriding the enum mapping.
    custom_provider = Column(String(100), nullable=True)
    model_name = Column(String(100), nullable=False)

    # Credentials
    # NOTE(review): api_key appears to be stored as a plain column here —
    # confirm whether encryption-at-rest is handled elsewhere.
    api_key = Column(String, nullable=False)
    api_base = Column(String(500), nullable=True)
    api_version = Column(String(50), nullable=True)  # Azure-specific

    # Additional litellm parameters (arbitrary JSON passed through to litellm).
    litellm_params = Column(JSON, nullable=True, default={})

    # Relationships
    # Deleting the search space cascades to its image generation configs.
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )
    search_space = relationship(
        "SearchSpace", back_populates="image_generation_configs"
    )
class ImageGeneration(BaseModel, TimestampMixin):
    """
    Stores image generation requests and results using litellm.aimage_generation().

    Since aimage_generation is a single async call (not a background job),
    there is no status enum. A row with response_data means success;
    a row with error_message means failure.

    Response data is stored as JSONB matching the litellm output format:
    {
        "created": int,
        "data": [{"b64_json": str|None, "revised_prompt": str|None, "url": str|None}],
        "usage": {"prompt_tokens": int, "completion_tokens": int, "total_tokens": int}
    }
    """

    __tablename__ = "image_generations"

    # Request parameters (matching litellm.aimage_generation() params)
    prompt = Column(Text, nullable=False)
    model = Column(String(200), nullable=True)  # e.g., "dall-e-3", "gpt-image-1"
    n = Column(Integer, nullable=True, default=1)  # number of images requested
    quality = Column(
        String(50), nullable=True
    )  # "auto", "high", "medium", "low", "hd", "standard"
    size = Column(
        String(50), nullable=True
    )  # "1024x1024", "1536x1024", "1024x1536", etc.
    style = Column(String(50), nullable=True)  # Model-specific style parameter
    response_format = Column(String(50), nullable=True)  # "url" or "b64_json"

    # Image generation config reference
    # 0 = Auto mode (router), negative IDs = global configs from YAML,
    # positive IDs = ImageGenerationConfig records in DB
    # (intentionally NOT a ForeignKey — zero/negative IDs have no DB row)
    image_generation_config_id = Column(Integer, nullable=True)

    # Response data (full litellm response as JSONB) — present on success
    response_data = Column(JSONB, nullable=True)

    # Error message — present on failure
    error_message = Column(Text, nullable=True)

    # Signed access token for serving images via <img> tags.
    # Stored in DB so it survives SECRET_KEY rotation.
    access_token = Column(String(64), nullable=True, index=True)

    # Foreign keys
    # Deleting the search space deletes its generations; deleting the creator
    # keeps the rows but nulls out created_by_id.
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )
    created_by_id = Column(
        UUID(as_uuid=True),
        ForeignKey("user.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )

    # Relationships
    search_space = relationship("SearchSpace", back_populates="image_generations")
    created_by = relationship("User", back_populates="image_generations")
class SearchSpace(BaseModel, TimestampMixin):
__tablename__ = "searchspaces"
@ -993,6 +1118,9 @@ class SearchSpace(BaseModel, TimestampMixin):
document_summary_llm_id = Column(
Integer, nullable=True, default=0
) # For document summarization, defaults to Auto mode
image_generation_config_id = Column(
Integer, nullable=True, default=0
) # For image generation, defaults to Auto mode
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
@ -1017,6 +1145,12 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="Podcast.id.desc()",
cascade="all, delete-orphan",
)
image_generations = relationship(
"ImageGeneration",
back_populates="search_space",
order_by="ImageGeneration.id.desc()",
cascade="all, delete-orphan",
)
logs = relationship(
"Log",
back_populates="search_space",
@ -1041,6 +1175,12 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="NewLLMConfig.id",
cascade="all, delete-orphan",
)
image_generation_configs = relationship(
"ImageGenerationConfig",
back_populates="search_space",
order_by="ImageGenerationConfig.id",
cascade="all, delete-orphan",
)
# RBAC relationships
roles = relationship(
@ -1421,6 +1561,13 @@ if config.AUTH_TYPE == "GOOGLE":
passive_deletes=True,
)
# Image generations created by this user
image_generations = relationship(
"ImageGeneration",
back_populates="created_by",
passive_deletes=True,
)
# User memories for personalized AI responses
memories = relationship(
"UserMemory",
@ -1493,6 +1640,13 @@ else:
passive_deletes=True,
)
# Image generations created by this user
image_generations = relationship(
"ImageGeneration",
back_populates="created_by",
passive_deletes=True,
)
# User memories for personalized AI responses
memories = relationship(
"UserMemory",

View file

@ -20,6 +20,7 @@ from .google_drive_add_connector_route import (
from .google_gmail_add_connector_route import (
router as google_gmail_add_connector_router,
)
from .image_generation_routes import router as image_generation_router
from .incentive_tasks_routes import router as incentive_tasks_router
from .jira_add_connector_route import router as jira_add_connector_router
from .linear_add_connector_route import router as linear_add_connector_router
@ -49,6 +50,7 @@ router.include_router(notes_router)
router.include_router(new_chat_router) # Chat with assistant-ui persistence
router.include_router(chat_comments_router)
router.include_router(podcasts_router) # Podcast task status and audio
router.include_router(image_generation_router) # Image generation via litellm
router.include_router(search_source_connectors_router)
router.include_router(google_calendar_add_connector_router)
router.include_router(google_gmail_add_connector_router)

View file

@ -0,0 +1,710 @@
"""
Image Generation routes:
- CRUD for ImageGenerationConfig (user-created image model configs)
- Global image gen configs endpoint (from YAML)
- Image generation execution (calls litellm.aimage_generation())
- CRUD for ImageGeneration records (results)
- Image serving endpoint (serves b64_json images from DB, protected by signed tokens)
"""
import base64
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import (
ImageGeneration,
ImageGenerationConfig,
Permission,
SearchSpace,
SearchSpaceMembership,
User,
get_async_session,
)
from app.schemas import (
GlobalImageGenConfigRead,
ImageGenerationConfigCreate,
ImageGenerationConfigRead,
ImageGenerationConfigUpdate,
ImageGenerationCreate,
ImageGenerationListRead,
ImageGenerationRead,
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
router = APIRouter()
logger = logging.getLogger(__name__)
# Provider mapping for building litellm model strings.
# Only includes providers that support image generation.
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
_PROVIDER_MAP = {
"OPENAI": "openai",
"AZURE_OPENAI": "azure",
"GOOGLE": "gemini", # Google AI Studio
"VERTEX_AI": "vertex_ai",
"BEDROCK": "bedrock", # AWS Bedrock
"RECRAFT": "recraft",
"OPENROUTER": "openrouter",
"XINFERENCE": "xinference",
"NSCALE": "nscale",
}
def _get_global_image_gen_config(config_id: int) -> dict | None:
    """Get a global image generation configuration by ID (negative IDs)."""
    if config_id == IMAGE_GEN_AUTO_MODE_ID:
        # Auto mode is synthetic — it has no YAML entry of its own.
        return {
            "id": IMAGE_GEN_AUTO_MODE_ID,
            "name": "Auto (Load Balanced)",
            "provider": "AUTO",
            "model_name": "auto",
            "is_auto_mode": True,
        }
    if config_id > 0:
        # Positive IDs belong to DB-backed configs, never global ones.
        return None
    return next(
        (
            entry
            for entry in config.GLOBAL_IMAGE_GEN_CONFIGS
            if entry.get("id") == config_id
        ),
        None,
    )
def _build_model_string(
    provider: str, model_name: str, custom_provider: str | None
) -> str:
    """Build a litellm model string ("<prefix>/<model>") from provider + model_name."""
    if custom_provider:
        # A custom provider wins over the enum-to-prefix mapping.
        return f"{custom_provider}/{model_name}"
    prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
    return "/".join((prefix, model_name))
async def _execute_image_generation(
    session: AsyncSession,
    image_gen: ImageGeneration,
    search_space: SearchSpace,
) -> None:
    """
    Call litellm.aimage_generation() with the appropriate config.

    Resolution order:
    1. Explicit image_generation_config_id on the request
    2. Search space's image_generation_config_id preference
    3. Falls back to Auto mode if available

    On success, writes the response (and, if the user didn't choose one, the
    resolved model name) onto ``image_gen``; the caller commits the session.

    Raises:
        ValueError: If the referenced config cannot be resolved, or Auto mode
            is requested while the router is uninitialized.
    """
    # Resolve which config to use and persist the resolution on the record,
    # so the stored row reflects what was actually used.
    config_id = image_gen.image_generation_config_id
    if config_id is None:
        config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
        image_gen.image_generation_config_id = config_id

    # Build kwargs — only forward parameters the request actually set, so
    # litellm/provider defaults apply for the rest.
    gen_kwargs = {}
    if image_gen.n is not None:
        gen_kwargs["n"] = image_gen.n
    if image_gen.quality is not None:
        gen_kwargs["quality"] = image_gen.quality
    if image_gen.size is not None:
        gen_kwargs["size"] = image_gen.size
    if image_gen.style is not None:
        gen_kwargs["style"] = image_gen.style
    if image_gen.response_format is not None:
        gen_kwargs["response_format"] = image_gen.response_format

    if is_image_gen_auto_mode(config_id):
        # Auto mode: delegate model selection / load balancing to the router.
        if not ImageGenRouterService.is_initialized():
            raise ValueError(
                "Auto mode requested but Image Generation Router not initialized. "
                "Ensure global_llm_config.yaml has global_image_generation_configs."
            )
        response = await ImageGenRouterService.aimage_generation(
            prompt=image_gen.prompt, model="auto", **gen_kwargs
        )
    elif config_id < 0:
        # Global config from YAML
        cfg = _get_global_image_gen_config(config_id)
        if not cfg:
            raise ValueError(f"Global image generation config {config_id} not found")
        model_string = _build_model_string(
            cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
        )
        # Credentials/extras come from the YAML entry; litellm_params may
        # override any of the kwargs built above.
        gen_kwargs["api_key"] = cfg.get("api_key")
        if cfg.get("api_base"):
            gen_kwargs["api_base"] = cfg["api_base"]
        if cfg.get("api_version"):
            gen_kwargs["api_version"] = cfg["api_version"]
        if cfg.get("litellm_params"):
            gen_kwargs.update(cfg["litellm_params"])
        # User model override
        if image_gen.model:
            model_string = image_gen.model
        response = await aimage_generation(
            prompt=image_gen.prompt, model=model_string, **gen_kwargs
        )
    else:
        # Positive ID = DB ImageGenerationConfig
        result = await session.execute(
            select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
        )
        db_cfg = result.scalars().first()
        if not db_cfg:
            raise ValueError(f"Image generation config {config_id} not found")
        model_string = _build_model_string(
            db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
        )
        # Credentials/extras come from the DB record.
        gen_kwargs["api_key"] = db_cfg.api_key
        if db_cfg.api_base:
            gen_kwargs["api_base"] = db_cfg.api_base
        if db_cfg.api_version:
            gen_kwargs["api_version"] = db_cfg.api_version
        if db_cfg.litellm_params:
            gen_kwargs.update(db_cfg.litellm_params)
        # User model override
        if image_gen.model:
            model_string = image_gen.model
        response = await aimage_generation(
            prompt=image_gen.prompt, model=model_string, **gen_kwargs
        )

    # Store response as a plain dict so it serializes into the JSONB column.
    image_gen.response_data = (
        response.model_dump() if hasattr(response, "model_dump") else dict(response)
    )
    # Backfill the model name from litellm's hidden params when the user
    # didn't pin one (e.g. Auto mode picked the model).
    if not image_gen.model and hasattr(response, "_hidden_params"):
        hidden = response._hidden_params
        if isinstance(hidden, dict) and hidden.get("model"):
            image_gen.model = hidden["model"]
# =============================================================================
# Global Image Generation Configs (from YAML)
# =============================================================================
@router.get(
    "/global-image-generation-configs",
    response_model=list[GlobalImageGenConfigRead],
)
async def get_global_image_gen_configs(
    user: User = Depends(current_active_user),
):
    """Get all global image generation configs. API keys are hidden."""
    try:
        yaml_configs = config.GLOBAL_IMAGE_GEN_CONFIGS
        if not yaml_configs:
            # No YAML configs means no Auto mode entry either.
            return []
        auto_entry = {
            "id": 0,
            "name": "Auto (Load Balanced)",
            "description": "Automatically routes across available image generation providers.",
            "provider": "AUTO",
            "custom_provider": None,
            "model_name": "auto",
            "api_base": None,
            "api_version": None,
            "litellm_params": {},
            "is_global": True,
            "is_auto_mode": True,
        }
        # Re-shape each YAML config, deliberately omitting api_key.
        sanitized = [
            {
                "id": entry.get("id"),
                "name": entry.get("name"),
                "description": entry.get("description"),
                "provider": entry.get("provider"),
                "custom_provider": entry.get("custom_provider"),
                "model_name": entry.get("model_name"),
                "api_base": entry.get("api_base") or None,
                "api_version": entry.get("api_version") or None,
                "litellm_params": entry.get("litellm_params", {}),
                "is_global": True,
            }
            for entry in yaml_configs
        ]
        return [auto_entry, *sanitized]
    except Exception as e:
        logger.exception("Failed to fetch global image generation configs")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch configs: {e!s}"
        ) from e
# =============================================================================
# ImageGenerationConfig CRUD
# =============================================================================
@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead)
async def create_image_gen_config(
    config_data: ImageGenerationConfigCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create a new image generation config for a search space.

    Raises 403 (via check_permission) when the user lacks the create
    permission in the target search space, 500 on unexpected failures.
    """
    try:
        # Permission is checked against the search space named in the payload.
        await check_permission(
            session,
            user,
            config_data.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to create image generation configs in this search space",
        )
        db_config = ImageGenerationConfig(**config_data.model_dump())
        session.add(db_config)
        await session.commit()
        # Refresh to populate server-generated fields (id, timestamps).
        await session.refresh(db_config)
        return db_config
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to create config: {e!s}"
        ) from e
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
async def list_image_gen_configs(
    search_space_id: int,
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List image generation configs for a search space."""
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to view image generation configs in this search space",
        )
        # Newest first, with simple offset/limit pagination.
        stmt = (
            select(ImageGenerationConfig)
            .filter(ImageGenerationConfig.search_space_id == search_space_id)
            .order_by(ImageGenerationConfig.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        rows = await session.execute(stmt)
        return rows.scalars().all()
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to list ImageGenerationConfigs")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch configs: {e!s}"
        ) from e
@router.get(
    "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
)
async def get_image_gen_config(
    config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Get a specific image generation config by ID."""
    try:
        lookup = await session.execute(
            select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
        )
        found = lookup.scalars().first()
        if found is None:
            raise HTTPException(status_code=404, detail="Config not found")
        # Permission is evaluated against the config's own search space.
        await check_permission(
            session,
            user,
            found.search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to view image generation configs in this search space",
        )
        return found
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch config: {e!s}"
        ) from e
@router.put(
    "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
)
async def update_image_gen_config(
    config_id: int,
    update_data: ImageGenerationConfigUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Update an existing image generation config.

    Performs a partial update: only fields the client explicitly supplied
    (exclude_unset) are written. Returns 404 if the config does not exist.
    """
    try:
        result = await session.execute(
            select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
        )
        db_config = result.scalars().first()
        if not db_config:
            raise HTTPException(status_code=404, detail="Config not found")
        # NOTE(review): gated on IMAGE_GENERATIONS_CREATE — there is no
        # dedicated "update" permission in the enum; confirm this is intended.
        await check_permission(
            session,
            user,
            db_config.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to update image generation configs in this search space",
        )
        # Apply only the fields the client set, leaving the rest untouched.
        for key, value in update_data.model_dump(exclude_unset=True).items():
            setattr(db_config, key, value)
        await session.commit()
        await session.refresh(db_config)
        return db_config
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to update ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to update config: {e!s}"
        ) from e
@router.delete("/image-generation-configs/{config_id}", response_model=dict)
async def delete_image_gen_config(
    config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Permanently remove an image generation config.

    Requires IMAGE_GENERATIONS_DELETE on the owning search space; raises 404
    when the config does not exist.
    """
    try:
        lookup = await session.execute(
            select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
        )
        db_config = lookup.scalars().first()
        if db_config is None:
            raise HTTPException(status_code=404, detail="Config not found")
        await check_permission(
            session,
            user,
            db_config.search_space_id,
            Permission.IMAGE_GENERATIONS_DELETE.value,
            "You don't have permission to delete image generation configs in this search space",
        )
        await session.delete(db_config)
        await session.commit()
        return {
            "message": "Image generation config deleted successfully",
            "id": config_id,
        }
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to delete ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to delete config: {e!s}"
        ) from e
# =============================================================================
# Image Generation Execution + Results CRUD
# =============================================================================
@router.post("/image-generations", response_model=ImageGenerationRead)
async def create_image_generation(
    data: ImageGenerationCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create and execute an image generation request.

    Persists an ImageGeneration record first (via flush, so it gets an id),
    then runs the actual provider call. A provider failure is recorded on the
    row's ``error_message`` instead of failing the request, so the record is
    committed either way and returned to the caller.
    """
    try:
        # Authorization: caller must be able to create generations in the space.
        await check_permission(
            session,
            user,
            data.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to create image generations in this search space",
        )
        result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == data.search_space_id)
        )
        search_space = result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")
        db_image_gen = ImageGeneration(
            prompt=data.prompt,
            model=data.model,
            n=data.n,
            quality=data.quality,
            size=data.size,
            style=data.style,
            response_format=data.response_format,
            image_generation_config_id=data.image_generation_config_id,
            search_space_id=data.search_space_id,
            created_by_id=user.id,
        )
        session.add(db_image_gen)
        # Flush (not commit) so the row has a primary key before execution,
        # while keeping everything in one transaction.
        await session.flush()
        try:
            await _execute_image_generation(session, db_image_gen, search_space)
        except Exception as e:
            # Best-effort: record the provider error on the row rather than
            # aborting — the request record itself is still committed below.
            logger.exception("Image generation call failed")
            db_image_gen.error_message = str(e)
        await session.commit()
        await session.refresh(db_image_gen)
        return db_image_gen
    except HTTPException:
        raise
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error during image generation"
        ) from None
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create image generation")
        raise HTTPException(
            status_code=500, detail=f"Image generation failed: {e!s}"
        ) from e
@router.get("/image-generations", response_model=list[ImageGenerationListRead])
async def list_image_generations(
    search_space_id: int | None = None,
    skip: int = 0,
    limit: int = 50,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List image generations.

    With ``search_space_id``: lists generations for that space (requires
    IMAGE_GENERATIONS_READ). Without it: lists generations across all search
    spaces the user is a member of. ``limit`` is capped at 100.
    """
    if skip < 0 or limit < 1:
        raise HTTPException(status_code=400, detail="Invalid pagination parameters")
    if limit > 100:
        limit = 100
    try:
        if search_space_id is not None:
            await check_permission(
                session,
                user,
                search_space_id,
                Permission.IMAGE_GENERATIONS_READ.value,
                "You don't have permission to read image generations in this search space",
            )
            result = await session.execute(
                select(ImageGeneration)
                .filter(ImageGeneration.search_space_id == search_space_id)
                .order_by(ImageGeneration.created_at.desc())
                .offset(skip)
                .limit(limit)
            )
        else:
            # Membership join scopes the listing to spaces the user belongs to.
            result = await session.execute(
                select(ImageGeneration)
                .join(SearchSpace)
                .join(SearchSpaceMembership)
                .filter(SearchSpaceMembership.user_id == user.id)
                .order_by(ImageGeneration.created_at.desc())
                .offset(skip)
                .limit(limit)
            )
        return [
            ImageGenerationListRead.from_orm_with_count(img)
            for img in result.scalars().all()
        ]
    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(
            status_code=500, detail="Database error fetching image generations"
        ) from None
    except Exception as e:
        # Fix: previously only SQLAlchemyError was handled, so any other
        # failure (e.g. while building the response models) escaped as an
        # unwrapped 500. Mirror the sibling endpoints' generic handler.
        logger.exception("Failed to list image generations")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch image generations: {e!s}"
        ) from e
@router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead)
async def get_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Fetch one image generation record by primary key.

    Raises 404 when absent; the permission check enforces read access to the
    record's search space.
    """
    try:
        lookup = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        image_gen = lookup.scalars().first()
        if image_gen is None:
            raise HTTPException(status_code=404, detail="Image generation not found")
        await check_permission(
            session,
            user,
            image_gen.search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to read image generations in this search space",
        )
        return image_gen
    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(
            status_code=500, detail="Database error fetching image generation"
        ) from None
@router.delete("/image-generations/{image_gen_id}", response_model=dict)
async def delete_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Permanently remove an image generation record.

    Requires IMAGE_GENERATIONS_DELETE on the record's search space; raises
    404 when the record does not exist.
    """
    try:
        lookup = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        db_image_gen = lookup.scalars().first()
        if db_image_gen is None:
            raise HTTPException(status_code=404, detail="Image generation not found")
        await check_permission(
            session,
            user,
            db_image_gen.search_space_id,
            Permission.IMAGE_GENERATIONS_DELETE.value,
            "You don't have permission to delete image generations in this search space",
        )
        await session.delete(db_image_gen)
        await session.commit()
        return {"message": "Image generation deleted successfully"}
    except HTTPException:
        raise
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error deleting image generation"
        ) from None
# =============================================================================
# Image Serving (serves generated images from DB, protected by signed tokens)
# =============================================================================
@router.get("/image-generations/{image_gen_id}/image")
async def serve_generated_image(
    image_gen_id: int,
    token: str = Query(..., description="Signed access token"),
    index: int = 0,
    session: AsyncSession = Depends(get_async_session),
):
    """
    Serve a generated image by ID, protected by a signed token.
    The token is generated when the image URL is created by the generate_image
    tool and encodes the image_gen_id, search_space_id, and an expiry timestamp.
    This ensures only users with access to the search space can view images,
    without requiring auth headers (which <img> tags cannot pass).
    Args:
        image_gen_id: The image generation record ID
        token: HMAC-signed access token (included as query parameter)
        index: Which image to serve if multiple were generated (default: 0)
    """
    try:
        result = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        image_gen = result.scalars().first()
        if not image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")
        # Verify the access token against the one stored on the record
        if not verify_image_token(image_gen.access_token, token):
            raise HTTPException(status_code=403, detail="Invalid image access token")
        if not image_gen.response_data:
            raise HTTPException(status_code=404, detail="No image data available")
        images = image_gen.response_data.get("data", [])
        # Fix: reject negative indices explicitly. Previously only
        # `index >= len(images)` was checked, so a user-supplied negative
        # index slipped through and Python negative indexing served
        # images[-1] etc., bypassing the intended bounds validation.
        if not images or index < 0 or index >= len(images):
            raise HTTPException(
                status_code=404, detail="Image not found at the specified index"
            )
        image_entry = images[index]
        # If there's a URL, redirect to it
        if image_entry.get("url"):
            from fastapi.responses import RedirectResponse
            return RedirectResponse(url=image_entry["url"])
        # If there's b64_json data, decode and serve it
        if image_entry.get("b64_json"):
            image_bytes = base64.b64decode(image_entry["b64_json"])
            return Response(
                content=image_bytes,
                media_type="image/png",
                headers={
                    # Tokens already gate access; one-day caching is safe here.
                    "Cache-Control": "public, max-age=86400",
                    "Content-Disposition": f'inline; filename="generated-{image_gen_id}-{index}.png"',
                },
            )
        raise HTTPException(status_code=404, detail="No displayable image data")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to serve generated image")
        raise HTTPException(
            status_code=500, detail=f"Failed to serve image: {e!s}"
        ) from e

View file

@ -7,6 +7,7 @@ from sqlalchemy.future import select
from app.config import config
from app.db import (
ImageGenerationConfig,
NewLLMConfig,
Permission,
SearchSpace,
@ -387,6 +388,69 @@ async def _get_llm_config_by_id(
return None
async def _get_image_gen_config_by_id(
    session: AsyncSession, config_id: int | None
) -> dict | None:
    """
    Get an image generation config by ID as a dictionary.
    Returns Auto mode for ID 0, global config for negative IDs,
    DB ImageGenerationConfig for positive IDs, or None.

    The ID space is partitioned:
      - None  -> no config selected (returns None)
      - 0     -> sentinel for Auto mode (LiteLLM router load balancing)
      - < 0   -> admin-defined global configs from YAML (config.GLOBAL_IMAGE_GEN_CONFIGS)
      - > 0   -> user-created rows in the ImageGenerationConfig table
    API keys are deliberately never included in the returned dicts.
    """
    if config_id is None:
        return None
    if config_id == 0:
        # Synthetic entry for Auto mode — not backed by YAML or the DB.
        return {
            "id": 0,
            "name": "Auto (Load Balanced)",
            "description": "Automatically routes requests across available image generation providers",
            "provider": "AUTO",
            "model_name": "auto",
            "is_global": True,
            "is_auto_mode": True,
        }
    if config_id < 0:
        # Linear scan over the (small) YAML-defined global config list.
        for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
            if cfg.get("id") == config_id:
                return {
                    "id": cfg.get("id"),
                    "name": cfg.get("name"),
                    "description": cfg.get("description"),
                    "provider": cfg.get("provider"),
                    "custom_provider": cfg.get("custom_provider"),
                    "model_name": cfg.get("model_name"),
                    # Normalize empty strings to None for optional endpoints.
                    "api_base": cfg.get("api_base") or None,
                    "api_version": cfg.get("api_version") or None,
                    "litellm_params": cfg.get("litellm_params", {}),
                    "is_global": True,
                }
        return None
    # Positive ID: query ImageGenerationConfig table
    result = await session.execute(
        select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
    )
    db_config = result.scalars().first()
    if db_config:
        return {
            "id": db_config.id,
            "name": db_config.name,
            "description": db_config.description,
            # Enum -> plain string for JSON serialization.
            "provider": db_config.provider.value if db_config.provider else None,
            "custom_provider": db_config.custom_provider,
            "model_name": db_config.model_name,
            "api_base": db_config.api_base,
            "api_version": db_config.api_version,
            "litellm_params": db_config.litellm_params or {},
            "created_at": db_config.created_at.isoformat()
            if db_config.created_at
            else None,
            "search_space_id": db_config.search_space_id,
        }
    return None
@router.get(
"/search-spaces/{search_space_id}/llm-preferences",
response_model=LLMPreferencesRead,
@ -423,12 +487,17 @@ async def get_llm_preferences(
document_summary_llm = await _get_llm_config_by_id(
session, search_space.document_summary_llm_id
)
image_generation_config = await _get_image_gen_config_by_id(
session, search_space.image_generation_config_id
)
return LLMPreferencesRead(
agent_llm_id=search_space.agent_llm_id,
document_summary_llm_id=search_space.document_summary_llm_id,
image_generation_config_id=search_space.image_generation_config_id,
agent_llm=agent_llm,
document_summary_llm=document_summary_llm,
image_generation_config=image_generation_config,
)
except HTTPException:
@ -485,12 +554,17 @@ async def update_llm_preferences(
document_summary_llm = await _get_llm_config_by_id(
session, search_space.document_summary_llm_id
)
image_generation_config = await _get_image_gen_config_by_id(
session, search_space.image_generation_config_id
)
return LLMPreferencesRead(
agent_llm_id=search_space.agent_llm_id,
document_summary_llm_id=search_space.document_summary_llm_id,
image_generation_config_id=search_space.image_generation_config_id,
agent_llm=agent_llm,
document_summary_llm=document_summary_llm,
image_generation_config=image_generation_config,
)
except HTTPException:

View file

@ -21,6 +21,16 @@ from .documents import (
PaginatedResponse,
)
from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest
from .image_generation import (
GlobalImageGenConfigRead,
ImageGenerationConfigCreate,
ImageGenerationConfigPublic,
ImageGenerationConfigRead,
ImageGenerationConfigUpdate,
ImageGenerationCreate,
ImageGenerationListRead,
ImageGenerationRead,
)
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
from .new_chat import (
ChatMessage,
@ -105,11 +115,21 @@ __all__ = [
"DriveItem",
"ExtensionDocumentContent",
"ExtensionDocumentMetadata",
"GlobalImageGenConfigRead",
"GlobalNewLLMConfigRead",
"GoogleDriveIndexRequest",
"GoogleDriveIndexingOptions",
# Base schemas
"IDModel",
# Image Generation Config schemas
"ImageGenerationConfigCreate",
"ImageGenerationConfigPublic",
"ImageGenerationConfigRead",
"ImageGenerationConfigUpdate",
# Image Generation schemas
"ImageGenerationCreate",
"ImageGenerationListRead",
"ImageGenerationRead",
# RBAC schemas
"InviteAcceptRequest",
"InviteAcceptResponse",

View file

@ -0,0 +1,230 @@
"""
Pydantic schemas for Image Generation configs and generation requests.
ImageGenerationConfig: CRUD schemas for user-created image gen model configs.
ImageGeneration: Schemas for the actual image generation requests/results.
GlobalImageGenConfigRead: Schema for admin-configured YAML configs.
"""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import ImageGenProvider
# =============================================================================
# ImageGenerationConfig CRUD Schemas
# =============================================================================
class ImageGenerationConfigBase(BaseModel):
    """Base schema with fields for ImageGenerationConfig.

    Shared by the Create and Read schemas; note that it includes the raw
    ``api_key``, so subclasses that are returned to clients expose it unless
    they override/omit the field (see ImageGenerationConfigPublic).
    """
    name: str = Field(
        ..., max_length=100, description="User-friendly name for the config"
    )
    description: str | None = Field(
        None, max_length=500, description="Optional description"
    )
    provider: ImageGenProvider = Field(
        ...,
        description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)",
    )
    # Free-form provider name for LiteLLM custom routes, used instead of the
    # enum prefix when set.
    custom_provider: str | None = Field(
        None, max_length=100, description="Custom provider name"
    )
    model_name: str = Field(
        ..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)"
    )
    api_key: str = Field(..., description="API key for the provider")
    api_base: str | None = Field(
        None, max_length=500, description="Optional API base URL"
    )
    api_version: str | None = Field(
        None,
        max_length=50,
        description="Azure-specific API version (e.g., '2024-02-15-preview')",
    )
    # Passed through to LiteLLM verbatim (e.g. rpm limits, extra headers).
    litellm_params: dict[str, Any] | None = Field(
        default=None, description="Additional LiteLLM parameters"
    )
class ImageGenerationConfigCreate(ImageGenerationConfigBase):
    """Schema for creating a new ImageGenerationConfig.

    Extends the base fields with the owning search space.
    """
    search_space_id: int = Field(
        ..., description="Search space ID to associate the config with"
    )
class ImageGenerationConfigUpdate(BaseModel):
    """Schema for updating an existing ImageGenerationConfig. All fields optional.

    Used with ``model_dump(exclude_unset=True)`` so only fields explicitly
    sent by the client are written; omitted fields keep their stored values.
    """
    name: str | None = Field(None, max_length=100)
    description: str | None = Field(None, max_length=500)
    provider: ImageGenProvider | None = None
    custom_provider: str | None = Field(None, max_length=100)
    model_name: str | None = Field(None, max_length=100)
    api_key: str | None = None
    api_base: str | None = Field(None, max_length=500)
    api_version: str | None = Field(None, max_length=50)
    litellm_params: dict[str, Any] | None = None
class ImageGenerationConfigRead(ImageGenerationConfigBase):
    """Schema for reading an ImageGenerationConfig (includes id and timestamps).

    NOTE(review): inherits ``api_key`` from the base, so reads return the raw
    key to the client. ImageGenerationConfigPublic exists specifically to hide
    it — confirm that exposing the key on single-item reads is intended.
    """
    id: int
    created_at: datetime
    search_space_id: int
    model_config = ConfigDict(from_attributes=True)
class ImageGenerationConfigPublic(BaseModel):
    """Public schema that hides the API key (for list views).

    Mirrors ImageGenerationConfigRead minus ``api_key``; safe to return to
    any member of the search space.
    """
    id: int
    name: str
    description: str | None = None
    provider: ImageGenProvider
    custom_provider: str | None = None
    model_name: str
    api_base: str | None = None
    api_version: str | None = None
    litellm_params: dict[str, Any] | None = None
    created_at: datetime
    search_space_id: int
    model_config = ConfigDict(from_attributes=True)
# =============================================================================
# ImageGeneration (request/result) Schemas
# =============================================================================
class ImageGenerationCreate(BaseModel):
    """Schema for creating an image generation request.

    Generation parameters (model, n, quality, size, style, response_format)
    are all optional; when omitted the selected config/provider defaults
    apply.
    """
    prompt: str = Field(
        ...,
        min_length=1,
        max_length=4000,
        description="A text description of the desired image(s)",
    )
    model: str | None = Field(
        None,
        max_length=200,
        description="The model to use (e.g., 'dall-e-3', 'gpt-image-1'). Overrides the config model.",
    )
    n: int | None = Field(
        None,
        ge=1,
        le=10,
        description="Number of images to generate (1-10).",
    )
    quality: str | None = Field(None, max_length=50)
    size: str | None = Field(None, max_length=50)
    style: str | None = Field(None, max_length=50)
    response_format: str | None = Field(None, max_length=50)
    search_space_id: int = Field(
        ..., description="Search space ID to associate the generation with"
    )
    # See _get_image_gen_config_by_id for the ID-space convention below.
    image_generation_config_id: int | None = Field(
        None,
        description=(
            "Image generation config ID. "
            "0 = Auto mode (router), negative = global YAML config, positive = DB config. "
            "If not provided, uses the search space's image_generation_config_id preference."
        ),
    )
class ImageGenerationRead(BaseModel):
    """Schema for reading an image generation record.

    ``response_data`` holds the raw provider response (may include image URLs
    or base64 payloads); ``error_message`` is set when the provider call
    failed instead of raising on the create endpoint.
    """
    id: int
    prompt: str
    model: str | None = None
    n: int | None = None
    quality: str | None = None
    size: str | None = None
    style: str | None = None
    response_format: str | None = None
    image_generation_config_id: int | None = None
    response_data: dict[str, Any] | None = None
    error_message: str | None = None
    search_space_id: int
    created_at: datetime
    model_config = ConfigDict(from_attributes=True)
class ImageGenerationListRead(BaseModel):
    """Lightweight schema for listing image generations (without full response_data).

    ``is_success`` and ``image_count`` are derived from the ORM row by
    :meth:`from_orm_with_count` rather than mapped directly.
    """
    id: int
    prompt: str
    model: str | None = None
    n: int | None = None
    quality: str | None = None
    size: str | None = None
    search_space_id: int
    created_at: datetime
    is_success: bool
    image_count: int | None = None
    model_config = ConfigDict(from_attributes=True)
    @classmethod
    def from_orm_with_count(cls, obj: Any) -> "ImageGenerationListRead":
        """Create ImageGenerationListRead with computed fields.

        Args:
            obj: An ORM ImageGeneration row (or any object with the same
                attributes).
        Returns:
            The list-view model; ``is_success`` is True whenever any
            ``response_data`` is present, and ``image_count`` is the length
            of its ``data`` list when that list exists (else None).
        """
        image_count = None
        if obj.response_data and isinstance(obj.response_data, dict):
            data = obj.response_data.get("data")
            if isinstance(data, list):
                image_count = len(data)
        return cls(
            id=obj.id,
            prompt=obj.prompt,
            model=obj.model,
            n=obj.n,
            quality=obj.quality,
            size=obj.size,
            search_space_id=obj.search_space_id,
            created_at=obj.created_at,
            is_success=obj.response_data is not None,
            image_count=image_count,
        )
# =============================================================================
# Global Image Gen Config (from YAML)
# =============================================================================
class GlobalImageGenConfigRead(BaseModel):
    """
    Schema for reading global image generation configs from YAML.
    Global configs have negative IDs. API key is hidden.
    ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
    """
    id: int = Field(
        ...,
        description="Config ID: 0 for Auto mode, negative for global configs",
    )
    name: str
    description: str | None = None
    # Plain string (not the ImageGenProvider enum) so the synthetic "AUTO"
    # provider of Auto mode can be represented.
    provider: str
    custom_provider: str | None = None
    model_name: str
    api_base: str | None = None
    api_version: str | None = None
    litellm_params: dict[str, Any] | None = None
    is_global: bool = True
    is_auto_mode: bool = False

View file

@ -176,12 +176,18 @@ class LLMPreferencesRead(BaseModel):
document_summary_llm_id: int | None = Field(
None, description="ID of the LLM config to use for document summarization"
)
image_generation_config_id: int | None = Field(
None, description="ID of the image generation config to use"
)
agent_llm: dict[str, Any] | None = Field(
None, description="Full config for agent LLM"
)
document_summary_llm: dict[str, Any] | None = Field(
None, description="Full config for document summary LLM"
)
image_generation_config: dict[str, Any] | None = Field(
None, description="Full config for image generation"
)
model_config = ConfigDict(from_attributes=True)
@ -195,3 +201,6 @@ class LLMPreferencesUpdate(BaseModel):
document_summary_llm_id: int | None = Field(
None, description="ID of the LLM config to use for document summarization"
)
image_generation_config_id: int | None = Field(
None, description="ID of the image generation config to use"
)

View file

@ -0,0 +1,278 @@
"""
Image Generation Router Service for Load Balancing
This module provides a singleton LiteLLM Router for automatic load balancing
across multiple image generation deployments. It uses litellm.Router which
natively supports aimage_generation() for async image generation.
The router handles:
- Rate limit management with automatic cooldowns
- Automatic failover and retries
- Usage-based routing to distribute load evenly
Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI,
AWS Bedrock, Recraft, OpenRouter, Xinference, Nscale.
"""
import logging
from typing import Any
from litellm import Router
from litellm.utils import ImageResponse
logger = logging.getLogger(__name__)
# Special ID for Auto mode - uses router for load balancing.
# Shared convention with the API layer: 0 = Auto, negative = global YAML
# configs, positive = user DB configs.
IMAGE_GEN_AUTO_MODE_ID = 0
# Provider mapping for LiteLLM model string construction.
# Only includes providers that support image generation.
# Keys are the ImageGenProvider enum names; values are LiteLLM route prefixes
# (e.g. "azure/<model>", "gemini/<model>").
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
IMAGE_GEN_PROVIDER_MAP = {
    "OPENAI": "openai",
    "AZURE_OPENAI": "azure",
    "GOOGLE": "gemini",  # Google AI Studio
    "VERTEX_AI": "vertex_ai",
    "BEDROCK": "bedrock",  # AWS Bedrock
    "RECRAFT": "recraft",
    "OPENROUTER": "openrouter",
    "XINFERENCE": "xinference",
    "NSCALE": "nscale",
}
class ImageGenRouterService:
    """
    Singleton service for managing LiteLLM Router for image generation.
    The router provides automatic load balancing, failover, and rate limit
    handling across multiple image generation deployments.
    Uses Router.aimage_generation() for async image generation calls.
    """
    # Singleton instance; all state below is class-level and shared.
    _instance = None
    # The configured litellm Router, or None until initialize() succeeds.
    _router: Router | None = None
    # Deployment dicts built from the global YAML configs.
    _model_list: list[dict] = []
    # Caller-supplied router settings, merged over defaults at init time.
    _router_settings: dict = {}
    # Guards against re-initialization.
    _initialized: bool = False
    def __new__(cls):
        # Classic singleton: always hand back the one shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    @classmethod
    def get_instance(cls) -> "ImageGenRouterService":
        """Get the singleton instance of the router service."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
    @classmethod
    def initialize(
        cls,
        global_configs: list[dict],
        router_settings: dict | None = None,
    ) -> None:
        """
        Initialize the router with global image generation configurations.

        Idempotent: a second call after successful initialization is a no-op.
        If no config yields a valid deployment, the router stays uninitialized
        (a warning is logged) and aimage_generation() will raise.
        Args:
            global_configs: List of global image gen config dictionaries from YAML
            router_settings: Optional router settings (routing_strategy, num_retries, etc.)
        """
        instance = cls.get_instance()
        if instance._initialized:
            logger.debug("Image Generation Router already initialized, skipping")
            return
        # Build model list from global configs
        model_list = []
        for config in global_configs:
            deployment = cls._config_to_deployment(config)
            if deployment:
                model_list.append(deployment)
        if not model_list:
            logger.warning(
                "No valid image generation configs found for router initialization"
            )
            return
        instance._model_list = model_list
        instance._router_settings = router_settings or {}
        # Default router settings optimized for rate limit handling
        default_settings = {
            "routing_strategy": "usage-based-routing",
            "num_retries": 3,
            "allowed_fails": 3,
            "cooldown_time": 60,
            "retry_after": 5,
        }
        # Merge with provided settings (caller settings win)
        final_settings = {**default_settings, **instance._router_settings}
        # NOTE(review): "retry_after" is merged into final_settings but never
        # forwarded to Router(...) below — confirm whether it should be passed.
        try:
            instance._router = Router(
                model_list=model_list,
                routing_strategy=final_settings.get(
                    "routing_strategy", "usage-based-routing"
                ),
                num_retries=final_settings.get("num_retries", 3),
                allowed_fails=final_settings.get("allowed_fails", 3),
                cooldown_time=final_settings.get("cooldown_time", 60),
                set_verbose=False,
            )
            instance._initialized = True
            logger.info(
                f"Image Generation Router initialized with {len(model_list)} deployments, "
                f"strategy: {final_settings.get('routing_strategy')}"
            )
        except Exception as e:
            # Leave _initialized False so a later initialize() can retry.
            logger.error(f"Failed to initialize Image Generation Router: {e}")
            instance._router = None
    @classmethod
    def _config_to_deployment(cls, config: dict) -> dict | None:
        """
        Convert a global image gen config to a router deployment entry.

        Invalid configs (missing model_name/api_key, or any error during
        conversion) are skipped by returning None.
        Args:
            config: Global image gen config dictionary
        Returns:
            Router deployment dictionary or None if invalid
        """
        try:
            # Skip if essential fields are missing
            if not config.get("model_name") or not config.get("api_key"):
                return None
            # Build model string: custom_provider takes precedence over the
            # enum-based provider prefix mapping.
            if config.get("custom_provider"):
                model_string = f"{config['custom_provider']}/{config['model_name']}"
            else:
                provider = config.get("provider", "").upper()
                provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
                model_string = f"{provider_prefix}/{config['model_name']}"
            # Build litellm params
            litellm_params: dict[str, Any] = {
                "model": model_string,
                "api_key": config.get("api_key"),
            }
            # Add optional api_base
            if config.get("api_base"):
                litellm_params["api_base"] = config["api_base"]
            # Add api_version (required for Azure)
            if config.get("api_version"):
                litellm_params["api_version"] = config["api_version"]
            # Add any additional litellm parameters
            if config.get("litellm_params"):
                litellm_params.update(config["litellm_params"])
            # All configs use same alias "auto" for unified routing
            deployment: dict[str, Any] = {
                "model_name": "auto",
                "litellm_params": litellm_params,
            }
            # Add RPM rate limit from config if available
            # Note: TPM (tokens per minute) is not applicable for image generation
            # since image APIs are rate-limited by requests, not tokens.
            if config.get("rpm"):
                deployment["rpm"] = config["rpm"]
            return deployment
        except Exception as e:
            logger.warning(f"Failed to convert image gen config to deployment: {e}")
            return None
    @classmethod
    def get_router(cls) -> Router | None:
        """Get the initialized router instance."""
        instance = cls.get_instance()
        return instance._router
    @classmethod
    def is_initialized(cls) -> bool:
        """Check if the router has been initialized."""
        instance = cls.get_instance()
        return instance._initialized and instance._router is not None
    @classmethod
    def get_model_count(cls) -> int:
        """Get the number of models in the router."""
        instance = cls.get_instance()
        return len(instance._model_list)
    @classmethod
    async def aimage_generation(
        cls,
        prompt: str,
        model: str = "auto",
        n: int | None = None,
        timeout: int = 600,
        **kwargs,
    ) -> ImageResponse:
        """
        Generate images using the router for load balancing.
        Uses Router.aimage_generation() which distributes requests
        across configured image generation deployments.
        Parameters like size, quality, style, and response_format are intentionally
        omitted to keep the interface model-agnostic. Providers use their own
        sensible defaults. If needed, pass them via **kwargs.
        Args:
            prompt: Text description of the desired image(s)
            model: Model alias (default "auto" for router routing)
            n: Number of images to generate
            timeout: Request timeout in seconds
            **kwargs: Additional provider-specific params (size, quality, etc.)
        Returns:
            ImageResponse from litellm
        Raises:
            ValueError: If router is not initialized
        """
        instance = cls.get_instance()
        if not instance._router:
            raise ValueError(
                "Image Generation Router not initialized. "
                "Ensure global_llm_config.yaml has global_image_generation_configs."
            )
        # Build kwargs for aimage_generation
        gen_kwargs: dict[str, Any] = {
            "model": model,
            "prompt": prompt,
            "timeout": timeout,
        }
        if n is not None:
            gen_kwargs["n"] = n
        # Caller-provided kwargs win over the defaults assembled above.
        gen_kwargs.update(kwargs)
        return await instance._router.aimage_generation(**gen_kwargs)
def is_image_gen_auto_mode(config_id: int | None) -> bool:
    """Return True when *config_id* denotes Image Generation Auto mode.

    Auto mode is represented by the sentinel ``IMAGE_GEN_AUTO_MODE_ID`` (0);
    ``None`` and every non-zero ID denote a concrete configuration.

    Args:
        config_id: The config ID to check (may be None).
    Returns:
        True for Auto mode, False otherwise.
    """
    return config_id is not None and config_id == IMAGE_GEN_AUTO_MODE_ID

View file

@ -1,6 +1,7 @@
"""Celery tasks for connector indexing."""
import logging
import traceback
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
@ -11,6 +12,36 @@ from app.config import config
logger = logging.getLogger(__name__)
def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> None:
    """
    Handle greenlet_spawn errors with detailed logging for debugging.
    The 'greenlet_spawn has not been called' error occurs when:
    1. SQLAlchemy lazy-loads a relationship outside of an async context
    2. A sync operation is called from an async context (or vice versa)
    3. Session objects are accessed after the session is closed
    This helper logs detailed context to help identify the root cause.
    """
    error_str = str(e)
    is_greenlet_issue = "greenlet_spawn has not been called" in error_str
    if is_greenlet_issue:
        # Emit an extended diagnostic checklist for the async-context bug class.
        logger.error(
            f"GREENLET ERROR in {task_name} for connector {connector_id}: {error_str}\n"
            f"This error typically occurs when SQLAlchemy tries to lazy-load a relationship "
            f"outside of an async context. Check for:\n"
            f"1. Accessing relationship attributes (e.g., document.chunks, connector.search_space) "
            f"without using selectinload() or joinedload()\n"
            f"2. Accessing model attributes after the session is closed\n"
            f"3. Passing ORM objects between different async contexts\n"
            f"Stack trace:\n{traceback.format_exc()}"
        )
        return
    # Any other exception: log it with the task context and full traceback.
    logger.error(
        f"Error in {task_name} for connector {connector_id}: {error_str}\n"
        f"Stack trace:\n{traceback.format_exc()}"
    )
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
@ -46,6 +77,9 @@ def index_slack_messages_task(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_slack_messages", connector_id)
raise
finally:
loop.close()
@ -89,6 +123,9 @@ def index_notion_pages_task(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_notion_pages", connector_id)
raise
finally:
loop.close()
@ -347,6 +384,9 @@ def index_google_calendar_events_task(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
raise
finally:
loop.close()
@ -696,6 +736,9 @@ def index_crawled_urls_task(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
raise
finally:
loop.close()

View file

@ -27,12 +27,12 @@ from app.agents.new_chat.llm_config import (
load_llm_config_from_yaml,
)
from app.db import Document, SurfsenseDocsDocument
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.schemas.new_chat import ChatAttachment
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
)
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
@ -1211,9 +1211,10 @@ async def stream_new_chat(
# Generate LLM title for new chats after first response
# Check if this is the first assistant response by counting existing assistant messages
from app.db import NewChatMessage, NewChatThread
from sqlalchemy import func
from app.db import NewChatMessage, NewChatThread
assistant_count_result = await session.execute(
select(func.count(NewChatMessage.id)).filter(
NewChatMessage.thread_id == chat_id,
@ -1231,10 +1232,12 @@ async def stream_new_chat(
# Truncate inputs to avoid context length issues
truncated_query = user_query[:500]
truncated_response = accumulated_text[:1000]
title_result = await title_chain.ainvoke({
"user_query": truncated_query,
"assistant_response": truncated_response,
})
title_result = await title_chain.ainvoke(
{
"user_query": truncated_query,
"assistant_response": truncated_response,
}
)
# Extract and clean the title
if title_result and hasattr(title_result, "content"):
@ -1242,7 +1245,7 @@ async def stream_new_chat(
# Validate the title (reasonable length)
if raw_title and len(raw_title) <= 100:
# Remove any quotes or extra formatting
generated_title = raw_title.strip('"\'')
generated_title = raw_title.strip("\"'")
except Exception:
generated_title = None

View file

@ -57,6 +57,34 @@ def safe_set_chunks(document: Document, chunks: list) -> None:
set_committed_value(document, "chunks", chunks)
def parse_date_flexible(date_str: str) -> datetime:
    """
    Parse a date string given in any of several common formats.

    Tries the plain-date and second-resolution ISO layouts first (any
    trailing "Z" is stripped, producing a naive datetime), then falls
    back to ``datetime.fromisoformat`` with "Z" rewritten as "+00:00"
    so full ISO-8601 strings (fractional seconds, offsets) also parse.

    Args:
        date_str: Date string to parse.

    Returns:
        The parsed datetime object.

    Raises:
        ValueError: If none of the supported formats match.
    """
    stripped = date_str.rstrip("Z")
    for fmt in ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"):
        try:
            return datetime.strptime(stripped, fmt)
        except ValueError:
            pass
    # Last resort: full ISO-8601 parsing with an explicit UTC offset.
    try:
        return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
    except ValueError as err:
        raise ValueError(f"Unable to parse date: {date_str}") from err
async def check_duplicate_document_by_hash(
session: AsyncSession, content_hash: str
) -> Document | None:
@ -188,6 +216,26 @@ def calculate_date_range(
)
end_date_str = end_date if end_date else calculated_end_date.strftime("%Y-%m-%d")
# FIX: Ensure end_date is at least 1 day after start_date to avoid
# "start_date must be strictly before end_date" errors when dates are the same
# (e.g., when last_indexed_at is today)
if start_date_str == end_date_str:
logger.info(
f"Start date ({start_date_str}) equals end date ({end_date_str}), "
"adjusting end date to next day to ensure valid date range"
)
# Parse end_date and add 1 day
try:
end_dt = parse_date_flexible(end_date_str)
except ValueError:
logger.warning(
f"Could not parse end_date '{end_date_str}', using current date"
)
end_dt = datetime.now()
end_dt = end_dt + timedelta(days=1)
end_date_str = end_dt.strftime("%Y-%m-%d")
logger.info(f"Adjusted end date to {end_date_str}")
return start_date_str, end_date_str

View file

@ -31,6 +31,7 @@ from .base import (
get_connector_by_id,
get_current_timestamp,
logger,
parse_date_flexible,
safe_set_chunks,
update_connector_last_indexed,
)
@ -222,6 +223,26 @@ async def index_google_calendar_events(
start_date_str = start_date
end_date_str = end_date
# FIX: Ensure end_date is at least 1 day after start_date to avoid
# "start_date must be strictly before end_date" errors when dates are the same
# (e.g., when last_indexed_at is today)
if start_date_str == end_date_str:
logger.info(
f"Start date ({start_date_str}) equals end date ({end_date_str}), "
"adjusting end date to next day to ensure valid date range"
)
# Parse end_date and add 1 day
try:
end_dt = parse_date_flexible(end_date_str)
except ValueError:
logger.warning(
f"Could not parse end_date '{end_date_str}', using current date"
)
end_dt = datetime.now()
end_dt = end_dt + timedelta(days=1)
end_date_str = end_dt.strftime("%Y-%m-%d")
logger.info(f"Adjusted end date to {end_date_str}")
await task_logger.log_task_progress(
log_entry,
f"Fetching Google Calendar events from {start_date_str} to {end_date_str}",

View file

@ -202,13 +202,44 @@ async def index_notion_pages(
"Recommend reconnecting with OAuth."
)
except Exception as e:
await task_logger.log_task_failure(
log_entry,
f"Failed to get Notion pages for connector {connector_id}",
str(e),
{"error_type": "PageFetchError"},
error_str = str(e)
# Check if this is an unsupported block type error (transcription, ai_block, etc.)
# These are known Notion API limitations and should be logged as warnings, not errors
unsupported_block_errors = [
"transcription is not supported",
"ai_block is not supported",
"is not supported via the API",
]
is_unsupported_block_error = any(
err in error_str.lower() for err in unsupported_block_errors
)
logger.error(f"Error fetching Notion pages: {e!s}", exc_info=True)
if is_unsupported_block_error:
# Log as warning since this is a known Notion API limitation
logger.warning(
f"Notion API limitation for connector {connector_id}: {error_str}. "
"This is a known issue with Notion AI blocks (transcription, ai_block) "
"that are not accessible via the Notion API."
)
await task_logger.log_task_failure(
log_entry,
"Failed to get Notion pages: Notion API limitation",
f"{error_str} - This page contains Notion AI content (transcription/ai_block) that cannot be accessed via the API.",
{"error_type": "UnsupportedBlockType", "is_known_limitation": True},
)
else:
# Log as error for other failures
logger.error(
f"Error fetching Notion pages for connector {connector_id}: {error_str}",
exc_info=True,
)
await task_logger.log_task_failure(
log_entry,
f"Failed to get Notion pages for connector {connector_id}",
str(e),
{"error_type": "PageFetchError"},
)
await notion_client.close()
return 0, f"Failed to get Notion pages: {e!s}"

View file

@ -117,10 +117,15 @@ async def index_crawled_urls(
api_key = connector.config.get("FIRECRAWL_API_KEY")
# Get URLs from connector config
urls = parse_webcrawler_urls(connector.config.get("INITIAL_URLS"))
raw_initial_urls = connector.config.get("INITIAL_URLS")
urls = parse_webcrawler_urls(raw_initial_urls)
# DEBUG: Log connector config details for troubleshooting empty URL issues
logger.info(
f"Starting crawled web page indexing for connector {connector_id} with {len(urls)} URLs"
f"Starting crawled web page indexing for connector {connector_id} with {len(urls)} URLs. "
f"Connector name: {connector.name}, "
f"INITIAL_URLS type: {type(raw_initial_urls).__name__}, "
f"INITIAL_URLS value: {repr(raw_initial_urls)[:200] if raw_initial_urls else 'None'}"
)
# Initialize webcrawler client
@ -137,11 +142,18 @@ async def index_crawled_urls(
# Validate URLs
if not urls:
# DEBUG: Log detailed connector config for troubleshooting
logger.error(
f"No URLs provided for indexing. Connector ID: {connector_id}, "
f"Connector name: {connector.name}, "
f"Config keys: {list(connector.config.keys()) if connector.config else 'None'}, "
f"INITIAL_URLS raw value: {raw_initial_urls!r}"
)
await task_logger.log_task_failure(
log_entry,
"No URLs provided for indexing",
"Empty URL list",
{"error_type": "ValidationError"},
f"Empty URL list. INITIAL_URLS value: {repr(raw_initial_urls)[:100]}",
{"error_type": "ValidationError", "connector_name": connector.name},
)
return 0, "No URLs provided for indexing"

View file

@ -10,6 +10,8 @@ import logging
from urllib.parse import parse_qs, urlparse
import aiohttp
from fake_useragent import UserAgent
from requests import Session
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from youtube_transcript_api import YouTubeTranscriptApi
@ -23,6 +25,7 @@ from app.utils.document_converters import (
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.proxy_config import get_requests_proxies
from .base import (
check_document_by_unique_identifier,
@ -200,9 +203,16 @@ async def add_youtube_video_document(
}
oembed_url = "https://www.youtube.com/oembed"
# Build residential proxy URL (if configured)
residential_proxies = get_requests_proxies()
async with (
aiohttp.ClientSession() as http_session,
http_session.get(oembed_url, params=params) as response,
http_session.get(
oembed_url,
params=params,
proxy=residential_proxies["http"] if residential_proxies else None,
) as response,
):
video_data = await response.json()
@ -228,7 +238,12 @@ async def add_youtube_video_document(
)
try:
ytt_api = YouTubeTranscriptApi()
ua = UserAgent()
http_client = Session()
http_client.headers.update({"User-Agent": ua.random})
if residential_proxies:
http_client.proxies.update(residential_proxies)
ytt_api = YouTubeTranscriptApi(http_client=http_client)
captions = ytt_api.fetch(video_id)
# Include complete caption information with timestamps
transcript_segments = []

View file

@ -219,7 +219,9 @@ class CustomBearerTransport(BearerTransport):
# Decode JWT to get user_id for refresh token creation
try:
payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False})
payload = jwt.decode(
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
)
user_id = uuid.UUID(payload.get("sub"))
refresh_token = await create_refresh_token(user_id)
except Exception as e:

View file

@ -0,0 +1,86 @@
"""
Residential proxy configuration utility.
Reads proxy credentials from the application Config and provides helper
functions that return proxy configs in the format expected by different
HTTP libraries (requests, httpx, aiohttp, Playwright).
"""
import base64
import json
import logging
from app.config import Config
logger = logging.getLogger(__name__)
def _build_password_b64() -> str | None:
    """
    Encode the anonymous-proxies.net password payload as base64.

    The provider expects the proxy password field to be a base64-encoded
    JSON object carrying the raw password ("p"), location ("l") and proxy
    type ("t") read from the application ``Config``.

    Returns:
        The base64 string, or ``None`` when no password is configured.
    """
    raw_password = Config.RESIDENTIAL_PROXY_PASSWORD
    if not raw_password:
        return None
    payload = json.dumps(
        {
            "p": raw_password,
            "l": Config.RESIDENTIAL_PROXY_LOCATION,
            "t": Config.RESIDENTIAL_PROXY_TYPE,
        }
    )
    return base64.b64encode(payload.encode("utf-8")).decode("utf-8")
def get_residential_proxy_url() -> str | None:
    """
    Build the fully-formed residential proxy URL.

    The URL format is::

        http://<username>:<base64_password>@<hostname>/

    Returns:
        The proxy URL, or ``None`` when any of username, hostname or
        password is missing from the configuration.
    """
    username = Config.RESIDENTIAL_PROXY_USERNAME
    hostname = Config.RESIDENTIAL_PROXY_HOSTNAME
    password_b64 = _build_password_b64()
    if username and hostname and password_b64:
        return f"http://{username}:{password_b64}@{hostname}/"
    return None
def get_requests_proxies() -> dict[str, str] | None:
    """
    Build a proxy mapping for ``requests`` / ``aiohttp``.

    Returns:
        ``{"http": <url>, "https": <url>}`` suitable for
        ``requests.Session.proxies`` or the ``aiohttp`` ``proxy=`` kwarg,
        or ``None`` when the residential proxy is not configured.
    """
    proxy_url = get_residential_proxy_url()
    return None if proxy_url is None else {"http": proxy_url, "https": proxy_url}
def get_playwright_proxy() -> dict[str, str] | None:
    """
    Build a Playwright-compatible proxy dict.

    Shape::

        {"server": "http://host:port", "username": "...", "password": "..."}

    Returns:
        The proxy dict, or ``None`` when any of username, hostname or
        password is missing from the configuration.
    """
    username = Config.RESIDENTIAL_PROXY_USERNAME
    hostname = Config.RESIDENTIAL_PROXY_HOSTNAME
    password_b64 = _build_password_b64()
    if not (username and hostname and password_b64):
        return None
    return {
        "server": f"http://{hostname}",
        "username": username,
        "password": password_b64,
    }

View file

@ -0,0 +1,44 @@
"""
Access token utilities for generated images.
Provides token generation and verification so that generated images can be
served via <img> tags (which cannot pass auth headers) while still
restricting access to authorised users.
Each image generation record stores its own random access token. The token
is verified by comparing the incoming query-parameter value against the
stored value in the database. This approach:
* Survives SECRET_KEY rotation — tokens are random, not derived from a key.
* Allows explicit revocation — just clear the column.
* Is immune to timing attacks — uses ``hmac.compare_digest``.
"""
import hmac
import secrets
def generate_image_token() -> str:
    """
    Create a fresh, cryptographically random access token for an image.

    Returns:
        A 64-character URL-safe hex string (32 random bytes).
    """
    return secrets.token_hex(32)
def verify_image_token(stored_token: str | None, provided_token: str) -> bool:
"""
Constant-time comparison of a stored token against a user-provided one.
Args:
stored_token: The token persisted on the ImageGeneration record.
provided_token: The token from the URL query parameter.
Returns:
True if the tokens match, False otherwise.
"""
if not stored_token or not provided_token:
return False
return hmac.compare_digest(stored_token, provided_token)

View file

@ -7,6 +7,7 @@ import {
ChevronRight,
FileText,
Globe,
ImageIcon,
type LucideIcon,
Menu,
MessageSquare,
@ -19,6 +20,7 @@ import { useTranslations } from "next-intl";
import { useCallback, useEffect, useState } from "react";
import { PublicChatSnapshotsManager } from "@/components/public-chat-snapshots/public-chat-snapshots-manager";
import { GeneralSettingsManager } from "@/components/settings/general-settings-manager";
import { ImageModelManager } from "@/components/settings/image-model-manager";
import { LLMRoleManager } from "@/components/settings/llm-role-manager";
import { ModelConfigManager } from "@/components/settings/model-config-manager";
import { PromptConfigManager } from "@/components/settings/prompt-config-manager";
@ -52,6 +54,12 @@ const settingsNavItems: SettingsNavItem[] = [
descriptionKey: "nav_role_assignments_desc",
icon: Brain,
},
{
id: "image-models",
labelKey: "nav_image_models",
descriptionKey: "nav_image_models_desc",
icon: ImageIcon,
},
{
id: "prompts",
labelKey: "nav_system_instructions",
@ -282,8 +290,11 @@ function SettingsContent({
<GeneralSettingsManager searchSpaceId={searchSpaceId} />
)}
{activeSection === "models" && <ModelConfigManager searchSpaceId={searchSpaceId} />}
{activeSection === "roles" && <LLMRoleManager searchSpaceId={searchSpaceId} />}
{activeSection === "prompts" && <PromptConfigManager searchSpaceId={searchSpaceId} />}
{activeSection === "roles" && <LLMRoleManager searchSpaceId={searchSpaceId} />}
{activeSection === "image-models" && (
<ImageModelManager searchSpaceId={searchSpaceId} />
)}
{activeSection === "prompts" && <PromptConfigManager searchSpaceId={searchSpaceId} />}
{activeSection === "public-links" && (
<PublicChatSnapshotsManager searchSpaceId={searchSpaceId} />
)}

View file

@ -0,0 +1,91 @@
import { atomWithMutation } from "jotai-tanstack-query";
import { toast } from "sonner";
import type {
CreateImageGenConfigRequest,
GetImageGenConfigsResponse,
UpdateImageGenConfigRequest,
UpdateImageGenConfigResponse,
} from "@/contracts/types/new-llm-config.types";
import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { queryClient } from "@/lib/query-client/client";
import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms";
/**
* Mutation atom for creating a new ImageGenerationConfig
*/
/**
 * Mutation atom for creating a new ImageGenerationConfig.
 *
 * On success, invalidates the config list for the active search space so
 * consumers refetch fresh data; failures surface as a toast.
 */
export const createImageGenConfigMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		mutationKey: ["image-gen-configs", "create"],
		// NOTE: `enabled` is a query-only option and is ignored by mutations,
		// so it is intentionally not set here.
		mutationFn: async (request: CreateImageGenConfigRequest) => {
			return imageGenConfigApiService.createConfig(request);
		},
		onSuccess: () => {
			toast.success("Image model configuration created");
			queryClient.invalidateQueries({
				queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)),
			});
		},
		onError: (error: Error) => {
			toast.error(error.message || "Failed to create image model configuration");
		},
	};
});
/**
* Mutation atom for updating an existing ImageGenerationConfig
*/
/**
 * Mutation atom for updating an existing ImageGenerationConfig.
 *
 * On success, invalidates both the per-search-space list and the by-id
 * cache entry for the updated config; failures surface as a toast.
 */
export const updateImageGenConfigMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		mutationKey: ["image-gen-configs", "update"],
		// NOTE: `enabled` is a query-only option and is ignored by mutations,
		// so it is intentionally not set here.
		mutationFn: async (request: UpdateImageGenConfigRequest) => {
			return imageGenConfigApiService.updateConfig(request);
		},
		onSuccess: (_: UpdateImageGenConfigResponse, request: UpdateImageGenConfigRequest) => {
			toast.success("Image model configuration updated");
			queryClient.invalidateQueries({
				queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)),
			});
			queryClient.invalidateQueries({
				queryKey: cacheKeys.imageGenConfigs.byId(request.id),
			});
		},
		onError: (error: Error) => {
			toast.error(error.message || "Failed to update image model configuration");
		},
	};
});
/**
* Mutation atom for deleting an ImageGenerationConfig
*/
/**
 * Mutation atom for deleting an ImageGenerationConfig.
 *
 * On success, removes the deleted config from the cached list in place
 * (no refetch needed); failures surface as a toast.
 */
export const deleteImageGenConfigMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		mutationKey: ["image-gen-configs", "delete"],
		// NOTE: `enabled` is a query-only option and is ignored by mutations,
		// so it is intentionally not set here.
		mutationFn: async (id: number) => {
			return imageGenConfigApiService.deleteConfig(id);
		},
		onSuccess: (_, id: number) => {
			toast.success("Image model configuration deleted");
			// Surgically drop the deleted entry from the cached list instead
			// of invalidating, avoiding an extra network round-trip.
			queryClient.setQueryData(
				cacheKeys.imageGenConfigs.all(Number(searchSpaceId)),
				(oldData: GetImageGenConfigsResponse | undefined) => {
					if (!oldData) return oldData;
					return oldData.filter((config) => config.id !== id);
				}
			);
		},
		onError: (error: Error) => {
			toast.error(error.message || "Failed to delete image model configuration");
		},
	};
});

View file

@ -0,0 +1,33 @@
import { atomWithQuery } from "jotai-tanstack-query";
import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms";
/**
* Query atom for fetching user-created image gen configs for the active search space
*/
/**
 * Query atom exposing the user-created image generation configs for the
 * currently active search space. Disabled until a search space is selected;
 * results are considered fresh for five minutes.
 */
export const imageGenConfigsAtom = atomWithQuery((get) => {
	const activeId = get(activeSearchSpaceIdAtom);
	const numericId = Number(activeId);
	return {
		queryKey: cacheKeys.imageGenConfigs.all(numericId),
		enabled: Boolean(activeId),
		staleTime: 5 * 60 * 1000, // five minutes
		queryFn: () => imageGenConfigApiService.getConfigs(numericId),
	};
});
/**
* Query atom for fetching global image gen configs (from YAML, negative IDs)
*/
/**
 * Query atom exposing the global image generation configs (defined in YAML,
 * identified by negative IDs). Always enabled; cached longer than the
 * per-space configs because global entries rarely change.
 */
export const globalImageGenConfigsAtom = atomWithQuery(() => ({
	queryKey: cacheKeys.imageGenConfigs.global(),
	staleTime: 10 * 60 * 1000, // ten minutes
	queryFn: () => imageGenConfigApiService.getGlobalConfigs(),
}));

View file

@ -4,16 +4,20 @@ import Image from "next/image";
import Link from "next/link";
import { cn } from "@/lib/utils";
export const Logo = ({ className }: { className?: string }) => {
return (
<Link href="/">
<Image
src="/icon-128.svg"
className={cn("dark:invert", className)}
alt="logo"
width={128}
height={128}
/>
</Link>
interface LogoProps {
	className?: string;
	/** Render the bare image without wrapping it in a home-page link. */
	disableLink?: boolean;
}

/** Application logo; optionally wrapped in a link back to the home page. */
export const Logo = ({ className, disableLink = false }: LogoProps) => {
	const image = (
		<Image
			src="/icon-128.svg"
			className={cn("dark:invert", className)}
			alt="logo"
			width={128}
			height={128}
		/>
	);
	return disableLink ? image : <Link href="/">{image}</Link>;
};

View file

@ -351,14 +351,14 @@ export const ComposerAddAttachment: FC = () => {
<PlusIcon className="aui-attachment-add-icon size-5 stroke-[1.5px]" />
</TooltipIconButton>
</DropdownMenuTrigger>
<DropdownMenuContent align="start" className="w-48 bg-background border-border">
<DropdownMenuContent align="start" className="w-72 bg-background border-border">
<DropdownMenuItem onSelect={handleChatAttachment} className="cursor-pointer">
<Paperclip className="size-4" />
<span>Add attachment</span>
<span>Add attachment to this chat</span>
</DropdownMenuItem>
<DropdownMenuItem onClick={handleFileUpload} className="cursor-pointer">
<Upload className="size-4" />
<span>Upload Documents</span>
<span>Upload documents to Search Space</span>
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>

View file

@ -64,7 +64,7 @@ const DesktopNav = ({ navItems, isScrolled }: any) => {
href="/"
className="flex flex-1 flex-row items-center gap-0.5 hover:opacity-80 transition-opacity"
>
<Logo className="h-8 w-8 rounded-md" />
<Logo className="h-8 w-8 rounded-md" disableLink />
<span className="dark:text-white/90 text-gray-800 text-lg font-bold">SurfSense</span>
</Link>
<div className="hidden flex-1 flex-row items-center justify-center space-x-2 text-sm font-medium text-zinc-600 transition duration-200 hover:text-zinc-800 lg:flex lg:space-x-2">
@ -145,7 +145,7 @@ const MobileNav = ({ navItems, isScrolled }: any) => {
href="/"
className="flex flex-row items-center gap-2 hover:opacity-80 transition-opacity"
>
<Logo className="h-8 w-8 rounded-md" />
<Logo className="h-8 w-8 rounded-md" disableLink />
<span className="dark:text-white/90 text-gray-800 text-lg font-bold">SurfSense</span>
</Link>
<button

View file

@ -2,9 +2,13 @@
import { useCallback, useState } from "react";
import type {
GlobalImageGenConfig,
GlobalNewLLMConfig,
ImageGenerationConfig,
NewLLMConfigPublic,
} from "@/contracts/types/new-llm-config.types";
import { ImageConfigSidebar } from "./image-config-sidebar";
import { ImageModelSelector } from "./image-model-selector";
import { ModelConfigSidebar } from "./model-config-sidebar";
import { ModelSelector } from "./model-selector";
@ -13,6 +17,7 @@ interface ChatHeaderProps {
}
export function ChatHeader({ searchSpaceId }: ChatHeaderProps) {
// LLM config sidebar state
const [sidebarOpen, setSidebarOpen] = useState(false);
const [selectedConfig, setSelectedConfig] = useState<
NewLLMConfigPublic | GlobalNewLLMConfig | null
@ -20,6 +25,15 @@ export function ChatHeader({ searchSpaceId }: ChatHeaderProps) {
const [isGlobal, setIsGlobal] = useState(false);
const [sidebarMode, setSidebarMode] = useState<"create" | "edit" | "view">("view");
// Image config sidebar state
const [imageSidebarOpen, setImageSidebarOpen] = useState(false);
const [selectedImageConfig, setSelectedImageConfig] = useState<
ImageGenerationConfig | GlobalImageGenConfig | null
>(null);
const [isImageGlobal, setIsImageGlobal] = useState(false);
const [imageSidebarMode, setImageSidebarMode] = useState<"create" | "edit" | "view">("view");
// LLM handlers
const handleEditConfig = useCallback(
(config: NewLLMConfigPublic | GlobalNewLLMConfig, global: boolean) => {
setSelectedConfig(config);
@ -39,15 +53,36 @@ export function ChatHeader({ searchSpaceId }: ChatHeaderProps) {
const handleSidebarClose = useCallback((open: boolean) => {
setSidebarOpen(open);
if (!open) {
// Reset state when closing
setSelectedConfig(null);
}
if (!open) setSelectedConfig(null);
}, []);
// Image model handlers
const handleAddImageModel = useCallback(() => {
setSelectedImageConfig(null);
setIsImageGlobal(false);
setImageSidebarMode("create");
setImageSidebarOpen(true);
}, []);
const handleEditImageConfig = useCallback(
(config: ImageGenerationConfig | GlobalImageGenConfig, global: boolean) => {
setSelectedImageConfig(config);
setIsImageGlobal(global);
setImageSidebarMode(global ? "view" : "edit");
setImageSidebarOpen(true);
},
[]
);
const handleImageSidebarClose = useCallback((open: boolean) => {
setImageSidebarOpen(open);
if (!open) setSelectedImageConfig(null);
}, []);
return (
<div className="flex items-center gap-2">
<ModelSelector onEdit={handleEditConfig} onAddNew={handleAddNew} />
<ImageModelSelector onEdit={handleEditImageConfig} onAddNew={handleAddImageModel} />
<ModelConfigSidebar
open={sidebarOpen}
onOpenChange={handleSidebarClose}
@ -56,6 +91,14 @@ export function ChatHeader({ searchSpaceId }: ChatHeaderProps) {
searchSpaceId={searchSpaceId}
mode={sidebarMode}
/>
<ImageConfigSidebar
open={imageSidebarOpen}
onOpenChange={handleImageSidebarClose}
config={selectedImageConfig}
isGlobal={isImageGlobal}
searchSpaceId={searchSpaceId}
mode={imageSidebarMode}
/>
</div>
);
}

View file

@ -0,0 +1,522 @@
"use client";
import { useAtomValue } from "jotai";
import {
AlertCircle,
Check,
ChevronsUpDown,
Globe,
ImageIcon,
Key,
Shuffle,
X,
Zap,
} from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { createPortal } from "react-dom";
import { toast } from "sonner";
import {
createImageGenConfigMutationAtom,
updateImageGenConfigMutationAtom,
} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms";
import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms";
import { Alert, AlertDescription } from "@/components/ui/alert";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import {
Command,
CommandEmpty,
CommandGroup,
CommandInput,
CommandItem,
CommandList,
} from "@/components/ui/command";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Separator } from "@/components/ui/separator";
import { Spinner } from "@/components/ui/spinner";
import { IMAGE_GEN_MODELS, IMAGE_GEN_PROVIDERS } from "@/contracts/enums/image-gen-providers";
import type {
GlobalImageGenConfig,
ImageGenerationConfig,
} from "@/contracts/types/new-llm-config.types";
import { cn } from "@/lib/utils";
interface ImageConfigSidebarProps {
open: boolean;
onOpenChange: (open: boolean) => void;
config: ImageGenerationConfig | GlobalImageGenConfig | null;
isGlobal: boolean;
searchSpaceId: number;
mode: "create" | "edit" | "view";
}
const INITIAL_FORM = {
name: "",
description: "",
provider: "",
model_name: "",
api_key: "",
api_base: "",
api_version: "",
};
export function ImageConfigSidebar({
open,
onOpenChange,
config,
isGlobal,
searchSpaceId,
mode,
}: ImageConfigSidebarProps) {
const [isSubmitting, setIsSubmitting] = useState(false);
const [mounted, setMounted] = useState(false);
const [formData, setFormData] = useState(INITIAL_FORM);
const [modelComboboxOpen, setModelComboboxOpen] = useState(false);
useEffect(() => {
setMounted(true);
}, []);
// Reset form when opening
useEffect(() => {
if (open) {
if (mode === "edit" && config && !isGlobal) {
setFormData({
name: config.name || "",
description: config.description || "",
provider: config.provider || "",
model_name: config.model_name || "",
api_key: (config as ImageGenerationConfig).api_key || "",
api_base: config.api_base || "",
api_version: config.api_version || "",
});
} else if (mode === "create") {
setFormData(INITIAL_FORM);
}
}
}, [open, mode, config, isGlobal]);
// Mutations
const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom);
const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom);
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
// Escape key
useEffect(() => {
const handleEscape = (e: KeyboardEvent) => {
if (e.key === "Escape" && open) onOpenChange(false);
};
window.addEventListener("keydown", handleEscape);
return () => window.removeEventListener("keydown", handleEscape);
}, [open, onOpenChange]);
const isAutoMode = config && "is_auto_mode" in config && config.is_auto_mode;
const suggestedModels = useMemo(() => {
if (!formData.provider) return [];
return IMAGE_GEN_MODELS.filter((m) => m.provider === formData.provider);
}, [formData.provider]);
const getTitle = () => {
if (mode === "create") return "Add Image Model";
if (isAutoMode) return "Auto Mode (Load Balanced)";
if (isGlobal) return "View Global Image Model";
return "Edit Image Model";
};
const handleSubmit = useCallback(async () => {
setIsSubmitting(true);
try {
if (mode === "create") {
const result = await createConfig({
name: formData.name,
provider: formData.provider,
model_name: formData.model_name,
api_key: formData.api_key,
api_base: formData.api_base || undefined,
api_version: formData.api_version || undefined,
description: formData.description || undefined,
search_space_id: searchSpaceId,
});
// Set as active image model
if (result?.id) {
await updatePreferences({
search_space_id: searchSpaceId,
data: { image_generation_config_id: result.id },
});
}
toast.success("Image model created and assigned!");
onOpenChange(false);
} else if (!isGlobal && config) {
await updateConfig({
id: config.id,
data: {
name: formData.name,
description: formData.description || undefined,
provider: formData.provider,
model_name: formData.model_name,
api_key: formData.api_key,
api_base: formData.api_base || undefined,
api_version: formData.api_version || undefined,
},
});
toast.success("Image model updated!");
onOpenChange(false);
}
} catch (error) {
console.error("Failed to save image config:", error);
toast.error("Failed to save image model");
} finally {
setIsSubmitting(false);
}
}, [mode, isGlobal, config, formData, searchSpaceId, createConfig, updateConfig, updatePreferences, onOpenChange]);
const handleUseGlobalConfig = useCallback(async () => {
if (!config || !isGlobal) return;
setIsSubmitting(true);
try {
await updatePreferences({
search_space_id: searchSpaceId,
data: { image_generation_config_id: config.id },
});
toast.success(`Now using ${config.name}`);
onOpenChange(false);
} catch (error) {
console.error("Failed to set image model:", error);
toast.error("Failed to set image model");
} finally {
setIsSubmitting(false);
}
}, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]);
const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key;
const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider);
if (!mounted) return null;
// Sidebar UI rendered into a portal: an animated backdrop plus a right-hand
// slide-in panel. The panel body switches between three views depending on
// state: auto-mode summary, read-only global config, or the create/edit form.
const sidebarContent = (
  <AnimatePresence>
    {open && (
      <>
        {/* Backdrop */}
        <motion.div
          initial={{ opacity: 0 }}
          animate={{ opacity: 1 }}
          exit={{ opacity: 0 }}
          transition={{ duration: 0.2 }}
          className="fixed inset-0 z-50 bg-black/20 backdrop-blur-sm"
          onClick={() => onOpenChange(false)}
        />
        {/* Sidebar */}
        <motion.div
          initial={{ x: "100%", opacity: 0 }}
          animate={{ x: 0, opacity: 1 }}
          exit={{ x: "100%", opacity: 0 }}
          transition={{ type: "spring", damping: 30, stiffness: 300 }}
          className={cn(
            "fixed right-0 top-0 z-50 h-full w-full sm:w-[480px] lg:w-[540px]",
            "bg-background border-l border-border/50 shadow-2xl",
            "flex flex-col"
          )}
        >
          {/* Header — violet accent for auto mode, teal otherwise */}
          <div
            className={cn(
              "flex items-center justify-between px-6 py-4 border-b border-border/50",
              isAutoMode
                ? "bg-gradient-to-r from-violet-500/10 to-purple-500/10"
                : "bg-gradient-to-r from-teal-500/10 to-cyan-500/10"
            )}
          >
            <div className="flex items-center gap-3">
              <div
                className={cn(
                  "flex items-center justify-center size-10 rounded-xl",
                  isAutoMode
                    ? "bg-gradient-to-br from-violet-500 to-purple-600"
                    : "bg-gradient-to-br from-teal-500 to-cyan-600"
                )}
              >
                {isAutoMode ? (
                  <Shuffle className="size-5 text-white" />
                ) : (
                  <ImageIcon className="size-5 text-white" />
                )}
              </div>
              <div>
                <h2 className="text-base sm:text-lg font-semibold">{getTitle()}</h2>
                <div className="flex items-center gap-2 mt-0.5">
                  {isAutoMode ? (
                    <Badge
                      variant="secondary"
                      className="gap-1 text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
                    >
                      <Zap className="size-3" />
                      Recommended
                    </Badge>
                  ) : isGlobal ? (
                    <Badge variant="secondary" className="gap-1 text-xs">
                      <Globe className="size-3" />
                      Global
                    </Badge>
                  ) : null}
                  {config && !isAutoMode && (
                    <span className="text-xs text-muted-foreground">{config.model_name}</span>
                  )}
                </div>
              </div>
            </div>
            <Button
              variant="ghost"
              size="icon"
              onClick={() => onOpenChange(false)}
              className="h-8 w-8 rounded-full"
            >
              <X className="h-4 w-4" />
              <span className="sr-only">Close</span>
            </Button>
          </div>
          {/* Content */}
          <div className="flex-1 overflow-y-auto">
            <div className="p-6">
              {/* Auto mode — informational view with a single confirm action */}
              {isAutoMode && (
                <>
                  <Alert className="mb-6 border-violet-500/30 bg-violet-500/5">
                    <Shuffle className="size-4 text-violet-500" />
                    <AlertDescription className="text-sm text-violet-700 dark:text-violet-400">
                      Auto mode distributes image generation requests across all configured providers for optimal performance and rate limit protection.
                    </AlertDescription>
                  </Alert>
                  <div className="flex gap-3 pt-4 border-t border-border/50">
                    <Button variant="outline" className="flex-1" onClick={() => onOpenChange(false)}>
                      Close
                    </Button>
                    <Button
                      className="flex-1 gap-2 bg-gradient-to-r from-violet-500 to-purple-600 hover:from-violet-600 hover:to-purple-700"
                      onClick={handleUseGlobalConfig}
                      disabled={isSubmitting}
                    >
                      {isSubmitting ? "Loading..." : "Use Auto Mode"}
                    </Button>
                  </div>
                </>
              )}
              {/* Global config (read-only) */}
              {isGlobal && !isAutoMode && config && (
                <>
                  <Alert className="mb-6 border-amber-500/30 bg-amber-500/5">
                    <AlertCircle className="size-4 text-amber-500" />
                    <AlertDescription className="text-sm text-amber-700 dark:text-amber-400">
                      Global configurations are read-only. To customize, create a new model.
                    </AlertDescription>
                  </Alert>
                  <div className="space-y-4">
                    <div className="grid gap-4 sm:grid-cols-2">
                      <div className="space-y-1.5">
                        <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">Name</div>
                        <p className="text-sm font-medium">{config.name}</p>
                      </div>
                      {config.description && (
                        <div className="space-y-1.5">
                          <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">Description</div>
                          <p className="text-sm text-muted-foreground">{config.description}</p>
                        </div>
                      )}
                    </div>
                    <Separator />
                    <div className="grid gap-4 sm:grid-cols-2">
                      <div className="space-y-1.5">
                        <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">Provider</div>
                        <p className="text-sm font-medium">{config.provider}</p>
                      </div>
                      <div className="space-y-1.5">
                        <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">Model</div>
                        <p className="text-sm font-medium font-mono">{config.model_name}</p>
                      </div>
                    </div>
                  </div>
                  <div className="flex gap-3 pt-6 border-t border-border/50 mt-6">
                    <Button variant="outline" className="flex-1" onClick={() => onOpenChange(false)}>
                      Close
                    </Button>
                    <Button className="flex-1 gap-2" onClick={handleUseGlobalConfig} disabled={isSubmitting}>
                      {isSubmitting ? "Loading..." : "Use This Model"}
                    </Button>
                  </div>
                </>
              )}
              {/* Create / Edit form — only user-owned configs are editable */}
              {(mode === "create" || (mode === "edit" && !isGlobal)) && (
                <div className="space-y-4">
                  {/* Name */}
                  <div className="space-y-2">
                    <Label className="text-sm font-medium">Name *</Label>
                    <Input
                      placeholder="e.g., My DALL-E 3"
                      value={formData.name}
                      onChange={(e) => setFormData((p) => ({ ...p, name: e.target.value }))}
                    />
                  </div>
                  {/* Description */}
                  <div className="space-y-2">
                    <Label className="text-sm font-medium">Description</Label>
                    <Input
                      placeholder="Optional description"
                      value={formData.description}
                      onChange={(e) => setFormData((p) => ({ ...p, description: e.target.value }))}
                    />
                  </div>
                  <Separator />
                  {/* Provider — changing provider resets the model choice */}
                  <div className="space-y-2">
                    <Label className="text-sm font-medium">Provider *</Label>
                    <Select
                      value={formData.provider}
                      onValueChange={(val) => setFormData((p) => ({ ...p, provider: val, model_name: "" }))}
                    >
                      <SelectTrigger>
                        <SelectValue placeholder="Select a provider" />
                      </SelectTrigger>
                      <SelectContent>
                        {IMAGE_GEN_PROVIDERS.map((p) => (
                          <SelectItem key={p.value} value={p.value}>
                            <div className="flex flex-col">
                              <span className="font-medium">{p.label}</span>
                              <span className="text-xs text-muted-foreground">{p.example}</span>
                            </div>
                          </SelectItem>
                        ))}
                      </SelectContent>
                    </Select>
                  </div>
                  {/* Model Name — combobox with suggestions, or free-text input */}
                  <div className="space-y-2">
                    <Label className="text-sm font-medium">Model Name *</Label>
                    {suggestedModels.length > 0 ? (
                      <Popover open={modelComboboxOpen} onOpenChange={setModelComboboxOpen}>
                        <PopoverTrigger asChild>
                          <Button variant="outline" role="combobox" className="w-full justify-between font-normal">
                            {formData.model_name || "Select or type a model..."}
                            <ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" />
                          </Button>
                        </PopoverTrigger>
                        <PopoverContent className="w-full p-0" align="start">
                          <Command>
                            <CommandInput
                              placeholder="Search or type model..."
                              value={formData.model_name}
                              onValueChange={(val) => setFormData((p) => ({ ...p, model_name: val }))}
                            />
                            <CommandList>
                              <CommandEmpty>
                                <span className="text-xs text-muted-foreground">Type a custom model name</span>
                              </CommandEmpty>
                              <CommandGroup>
                                {suggestedModels.map((m) => (
                                  <CommandItem
                                    key={m.value}
                                    value={m.value}
                                    onSelect={() => {
                                      setFormData((p) => ({ ...p, model_name: m.value }));
                                      setModelComboboxOpen(false);
                                    }}
                                  >
                                    <Check className={cn("mr-2 h-4 w-4", formData.model_name === m.value ? "opacity-100" : "opacity-0")} />
                                    <span className="font-mono text-sm">{m.value}</span>
                                    <span className="ml-2 text-xs text-muted-foreground">{m.label}</span>
                                  </CommandItem>
                                ))}
                              </CommandGroup>
                            </CommandList>
                          </Command>
                        </PopoverContent>
                      </Popover>
                    ) : (
                      <Input
                        placeholder="e.g., dall-e-3"
                        value={formData.model_name}
                        onChange={(e) => setFormData((p) => ({ ...p, model_name: e.target.value }))}
                      />
                    )}
                  </div>
                  {/* API Key */}
                  <div className="space-y-2">
                    <Label className="text-sm font-medium flex items-center gap-1.5">
                      <Key className="h-3.5 w-3.5" /> API Key *
                    </Label>
                    <Input
                      type="password"
                      placeholder="sk-..."
                      value={formData.api_key}
                      onChange={(e) => setFormData((p) => ({ ...p, api_key: e.target.value }))}
                    />
                  </div>
                  {/* API Base — placeholder shows the provider's default base URL */}
                  <div className="space-y-2">
                    <Label className="text-sm font-medium">API Base URL</Label>
                    <Input
                      placeholder={selectedProvider?.apiBase || "Optional"}
                      value={formData.api_base}
                      onChange={(e) => setFormData((p) => ({ ...p, api_base: e.target.value }))}
                    />
                  </div>
                  {/* Azure API Version — only relevant for Azure OpenAI deployments */}
                  {formData.provider === "AZURE_OPENAI" && (
                    <div className="space-y-2">
                      <Label className="text-sm font-medium">API Version (Azure)</Label>
                      <Input
                        placeholder="2024-02-15-preview"
                        value={formData.api_version}
                        onChange={(e) => setFormData((p) => ({ ...p, api_version: e.target.value }))}
                      />
                    </div>
                  )}
                  {/* Actions */}
                  <div className="flex gap-3 pt-4 border-t">
                    <Button variant="outline" className="flex-1" onClick={() => onOpenChange(false)}>
                      Cancel
                    </Button>
                    <Button
                      className="flex-1 gap-2 bg-gradient-to-r from-teal-500 to-cyan-600 hover:from-teal-600 hover:to-cyan-700"
                      onClick={handleSubmit}
                      disabled={isSubmitting || !isFormValid}
                    >
                      {isSubmitting ? <Spinner size="sm" className="mr-2" /> : null}
                      {mode === "edit" ? "Save Changes" : "Create & Use"}
                    </Button>
                  </div>
                </div>
              )}
            </div>
          </div>
        </motion.div>
      </>
    )}
  </AnimatePresence>
);
return typeof document !== "undefined" ? createPortal(sidebarContent, document.body) : null;
}

View file

@ -0,0 +1,364 @@
"use client";
import { useAtomValue } from "jotai";
import {
Check,
ChevronDown,
ChevronRight,
Edit3,
Globe,
ImageIcon,
Plus,
Shuffle,
User,
} from "lucide-react";
import { useCallback, useMemo, useState } from "react";
import { toast } from "sonner";
import {
createImageGenConfigMutationAtom,
updateImageGenConfigMutationAtom,
} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms";
import {
globalImageGenConfigsAtom,
imageGenConfigsAtom,
} from "@/atoms/image-gen-config/image-gen-config-query.atoms";
import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms";
import { llmPreferencesAtom } from "@/atoms/new-llm-config/new-llm-config-query.atoms";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import {
Command,
CommandEmpty,
CommandGroup,
CommandInput,
CommandItem,
CommandList,
CommandSeparator,
} from "@/components/ui/command";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import { Spinner } from "@/components/ui/spinner";
import type {
GlobalImageGenConfig,
ImageGenerationConfig,
} from "@/contracts/types/new-llm-config.types";
import { cn } from "@/lib/utils";
interface ImageModelSelectorProps {
  /** Extra classes merged onto the trigger button. */
  className?: string;
  /** Invoked when the user chooses "Add Image Model". */
  onAddNew?: () => void;
  /** Invoked with the clicked config and whether it is a global (read-only) one. */
  onEdit?: (config: ImageGenerationConfig | GlobalImageGenConfig, isGlobal: boolean) => void;
}
/**
 * Dropdown for picking the active image-generation model of the current
 * search space. Lists global (admin-provided) and user-owned configs,
 * supports client-side search, switching the preference, and hooks for
 * adding/editing configs.
 */
export function ImageModelSelector({ className, onAddNew, onEdit }: ImageModelSelectorProps) {
  const [open, setOpen] = useState(false);
  const [searchQuery, setSearchQuery] = useState("");
  const { data: globalConfigs, isLoading: globalLoading } =
    useAtomValue(globalImageGenConfigsAtom);
  const { data: userConfigs, isLoading: userLoading } = useAtomValue(imageGenConfigsAtom);
  const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom);
  const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
  const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
  const isLoading = globalLoading || userLoading || prefsLoading;
  // Resolve the preferred config id to a config object; global configs win
  // over user configs when ids collide.
  const currentConfig = useMemo(() => {
    if (!preferences) return null;
    const id = preferences.image_generation_config_id;
    if (id === null || id === undefined) return null;
    const globalMatch = globalConfigs?.find((c) => c.id === id);
    if (globalMatch) return globalMatch;
    return userConfigs?.find((c) => c.id === id) ?? null;
  }, [preferences, globalConfigs, userConfigs]);
  // True when the selected config is the special auto-load-balancing entry.
  // (Only global configs carry the is_auto_mode flag.)
  const isCurrentAutoMode = useMemo(() => {
    return currentConfig && "is_auto_mode" in currentConfig && currentConfig.is_auto_mode;
  }, [currentConfig]);
  // Case-insensitive filter over name / model / provider (Command has
  // shouldFilter={false}, so filtering is done here).
  const filteredGlobal = useMemo(() => {
    if (!globalConfigs) return [];
    if (!searchQuery) return globalConfigs;
    const q = searchQuery.toLowerCase();
    return globalConfigs.filter(
      (c) =>
        c.name.toLowerCase().includes(q) ||
        c.model_name.toLowerCase().includes(q) ||
        c.provider.toLowerCase().includes(q)
    );
  }, [globalConfigs, searchQuery]);
  const filteredUser = useMemo(() => {
    if (!userConfigs) return [];
    if (!searchQuery) return userConfigs;
    const q = searchQuery.toLowerCase();
    return userConfigs.filter(
      (c) =>
        c.name.toLowerCase().includes(q) ||
        c.model_name.toLowerCase().includes(q) ||
        c.provider.toLowerCase().includes(q)
    );
  }, [userConfigs, searchQuery]);
  const totalModels = (globalConfigs?.length ?? 0) + (userConfigs?.length ?? 0);
  // Persist the clicked config as the search space's preference; no-op when
  // it is already selected.
  const handleSelect = useCallback(
    async (configId: number) => {
      if (currentConfig?.id === configId) {
        setOpen(false);
        return;
      }
      if (!searchSpaceId) {
        toast.error("No search space selected");
        return;
      }
      try {
        await updatePreferences({
          search_space_id: Number(searchSpaceId),
          data: { image_generation_config_id: configId },
        });
        toast.success("Image model updated");
        setOpen(false);
      } catch {
        toast.error("Failed to switch image model");
      }
    },
    [currentConfig, searchSpaceId, updatePreferences]
  );
  // Don't render if no configs at all
  if (!isLoading && totalModels === 0) {
    return (
      <Button
        variant="outline"
        size="sm"
        onClick={onAddNew}
        className={cn("h-8 gap-2 px-3 text-sm border-border/60", className)}
      >
        <Plus className="size-4 text-teal-600" />
        <span className="hidden md:inline">Add Image Model</span>
      </Button>
    );
  }
  return (
    <Popover open={open} onOpenChange={setOpen}>
      <PopoverTrigger asChild>
        <Button
          variant="outline"
          size="sm"
          role="combobox"
          aria-expanded={open}
          className={cn("h-8 gap-2 px-3 text-sm border-border/60", className)}
        >
          {isLoading ? (
            <Spinner size="sm" className="text-muted-foreground" />
          ) : currentConfig ? (
            <>
              {isCurrentAutoMode ? (
                <Shuffle className="size-4 text-violet-500" />
              ) : (
                <ImageIcon className="size-4 text-teal-500" />
              )}
              <span className="max-w-[100px] md:max-w-[120px] truncate hidden md:inline">
                {currentConfig.name}
              </span>
              {isCurrentAutoMode ? (
                <Badge
                  variant="secondary"
                  className="ml-1 text-[10px] px-1.5 py-0 h-4 bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
                >
                  Auto
                </Badge>
              ) : (
                <Badge
                  variant="secondary"
                  className="ml-1 text-[10px] px-1.5 py-0 h-4 bg-teal-50 text-teal-700 dark:bg-teal-900/30 dark:text-teal-300"
                >
                  Image
                </Badge>
              )}
            </>
          ) : (
            <>
              <ImageIcon className="h-4 w-4 text-muted-foreground" />
              <span className="text-muted-foreground hidden md:inline">Image Model</span>
            </>
          )}
          <ChevronDown
            className={cn(
              "h-3.5 w-3.5 text-muted-foreground ml-1 shrink-0 transition-transform duration-200",
              open && "rotate-180"
            )}
          />
        </Button>
      </PopoverTrigger>
      <PopoverContent
        className="w-[280px] md:w-[360px] p-0 rounded-lg shadow-lg border-border/60"
        align="start"
        sideOffset={8}
      >
        <Command shouldFilter={false} className="rounded-lg">
          {/* Search box only appears once the list is long enough to need it */}
          {totalModels > 3 && (
            <div className="flex items-center gap-1 md:gap-2 px-2 md:px-3 py-1.5 md:py-2">
              <CommandInput
                placeholder="Search image models..."
                value={searchQuery}
                onValueChange={setSearchQuery}
                className="h-7 md:h-8 text-xs md:text-sm border-0 bg-transparent focus:ring-0 placeholder:text-muted-foreground/60"
              />
            </div>
          )}
          <CommandList className="max-h-[300px] md:max-h-[400px] overflow-y-auto">
            <CommandEmpty className="py-8 text-center">
              <div className="flex flex-col items-center gap-2">
                <ImageIcon className="size-8 text-muted-foreground" />
                <p className="text-sm text-muted-foreground">No image models found</p>
              </div>
            </CommandEmpty>
            {/* Global Image Gen Configs */}
            {filteredGlobal.length > 0 && (
              <CommandGroup>
                <div className="flex items-center gap-2 px-3 py-2 text-xs font-semibold text-muted-foreground tracking-wider">
                  <Globe className="size-3.5" />
                  Global Image Models
                </div>
                {filteredGlobal.map((config) => {
                  const isSelected = currentConfig?.id === config.id;
                  const isAuto = "is_auto_mode" in config && config.is_auto_mode;
                  return (
                    <CommandItem
                      key={`g-${config.id}`}
                      value={`g-${config.id}`}
                      onSelect={() => handleSelect(config.id)}
                      className={cn(
                        "mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50",
                        isSelected && "bg-accent/80",
                        isAuto && "border border-violet-200 dark:border-violet-800/50"
                      )}
                    >
                      <div className="flex items-center gap-3 min-w-0 flex-1">
                        <div className="shrink-0">
                          {isAuto ? (
                            <Shuffle className="size-4 text-violet-500" />
                          ) : (
                            <ImageIcon className="size-4 text-teal-500" />
                          )}
                        </div>
                        <div className="min-w-0 flex-1">
                          <div className="flex items-center gap-2">
                            <span className="font-medium truncate">{config.name}</span>
                            {isAuto && (
                              <Badge
                                variant="secondary"
                                className="text-[9px] px-1 py-0 h-3.5 bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-0"
                              >
                                Recommended
                              </Badge>
                            )}
                            {isSelected && <Check className="size-3.5 text-primary shrink-0" />}
                          </div>
                          <span className="text-xs text-muted-foreground truncate block">
                            {isAuto ? "Auto load balancing" : config.model_name}
                          </span>
                        </div>
                        {/* stopPropagation keeps the edit click from also selecting the item */}
                        {onEdit && (
                          <ChevronRight
                            className="size-3.5 text-muted-foreground shrink-0 opacity-0 group-hover:opacity-100 transition-opacity cursor-pointer"
                            onClick={(e) => {
                              e.stopPropagation();
                              setOpen(false);
                              onEdit(config, true);
                            }}
                          />
                        )}
                      </div>
                    </CommandItem>
                  );
                })}
              </CommandGroup>
            )}
            {/* User Image Gen Configs */}
            {filteredUser.length > 0 && (
              <>
                {filteredGlobal.length > 0 && <CommandSeparator className="my-1 bg-border/30" />}
                <CommandGroup>
                  <div className="flex items-center gap-2 px-3 py-2 text-xs font-semibold text-muted-foreground tracking-wider">
                    <User className="size-3.5" />
                    Your Image Models
                  </div>
                  {filteredUser.map((config) => {
                    const isSelected = currentConfig?.id === config.id;
                    return (
                      <CommandItem
                        key={`u-${config.id}`}
                        value={`u-${config.id}`}
                        onSelect={() => handleSelect(config.id)}
                        className={cn(
                          "mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50",
                          isSelected && "bg-accent/80"
                        )}
                      >
                        <div className="flex items-center gap-3 min-w-0 flex-1">
                          <div className="shrink-0">
                            <ImageIcon className="size-4 text-teal-500" />
                          </div>
                          <div className="min-w-0 flex-1">
                            <div className="flex items-center gap-2">
                              <span className="font-medium truncate">{config.name}</span>
                              {isSelected && (
                                <Check className="size-3.5 text-primary shrink-0" />
                              )}
                            </div>
                            <span className="text-xs text-muted-foreground truncate block">
                              {config.model_name}
                            </span>
                          </div>
                          {onEdit && (
                            <Button
                              variant="ghost"
                              size="icon"
                              className="h-7 w-7 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity"
                              onClick={(e) => {
                                e.stopPropagation();
                                setOpen(false);
                                onEdit(config, false);
                              }}
                            >
                              <Edit3 className="size-3.5 text-muted-foreground" />
                            </Button>
                          )}
                        </div>
                      </CommandItem>
                    );
                  })}
                </CommandGroup>
              </>
            )}
            {/* Add New */}
            {onAddNew && (
              <div className="p-2 bg-muted/20">
                <Button
                  variant="ghost"
                  size="sm"
                  className="w-full justify-start gap-2 h-9 rounded-lg hover:bg-accent/50"
                  onClick={() => {
                    setOpen(false);
                    onAddNew();
                  }}
                >
                  <Plus className="size-4 text-teal-600" />
                  <span className="text-sm font-medium">Add Image Model</span>
                </Button>
              </div>
            )}
          </CommandList>
        </Command>
      </PopoverContent>
    </Popover>
  );
}

View file

@ -392,8 +392,8 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
</CommandGroup>
)}
{/* Add New Config Button */}
<div className="p-2 bg-muted/20">
{/* Add New Config Button */}
<div className="p-2 bg-muted/20">
<Button
variant="ghost"
size="sm"

View file

@ -12,11 +12,11 @@ const demoPlans = [
features: [
"Open source on GitHub",
"Upload and chat with 300+ pages of content",
"Connects with 8 popular sources, like Drive and Notion.",
"Connects with 8 popular sources, like Drive and Notion",
"Includes limited access to ChatGPT, Claude, and DeepSeek models",
"Supports 100+ more LLMs, including Gemini, Llama and many more.",
"50+ File extensions supported.",
"Generate podcasts in seconds.",
"Supports 100+ more LLMs, including Gemini, Llama and many more",
"50+ File extensions supported",
"Generate podcasts in seconds",
"Cross-Browser Extension for dynamic webpages including authenticated content",
"Community support on Discord",
],
@ -33,8 +33,8 @@ const demoPlans = [
billingText: "billed annually",
features: [
"Everything in Free",
"Upload and chat with 5,000+ pages of content",
"Connects with 15+ external sources, like Slack and Airtable.",
"Upload and chat with 5,000+ pages of content per user",
"Connects with 15+ external sources, like Slack and Airtable",
"Includes extended access to ChatGPT, Claude, and DeepSeek models",
"Collaboration and commenting features",
"Shared BYOK (Bring Your Own Key)",
@ -42,7 +42,7 @@ const demoPlans = [
"Planned: Centralized billing",
"Priority support",
],
description: "The AIknowledge base for individuals and teams",
description: "The AI knowledge base for individuals and teams",
buttonText: "Upgrade",
href: "/contact",
isPopular: true,

View file

@ -0,0 +1,692 @@
"use client";
import { useAtomValue } from "jotai";
import {
AlertCircle,
Check,
ChevronsUpDown,
Clock,
Edit3,
ImageIcon,
Key,
Plus,
RefreshCw,
Shuffle,
Sparkles,
Trash2,
Wand2,
} from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useCallback, useEffect, useState } from "react";
import { toast } from "sonner";
import {
createImageGenConfigMutationAtom,
deleteImageGenConfigMutationAtom,
updateImageGenConfigMutationAtom,
} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms";
import {
globalImageGenConfigsAtom,
imageGenConfigsAtom,
} from "@/atoms/image-gen-config/image-gen-config-query.atoms";
import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms";
import { llmPreferencesAtom } from "@/atoms/new-llm-config/new-llm-config-query.atoms";
import { Alert, AlertDescription } from "@/components/ui/alert";
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import {
Command,
CommandEmpty,
CommandGroup,
CommandInput,
CommandItem,
CommandList,
} from "@/components/ui/command";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Separator } from "@/components/ui/separator";
import { Spinner } from "@/components/ui/spinner";
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
import {
IMAGE_GEN_PROVIDERS,
getImageGenModelsByProvider,
} from "@/contracts/enums/image-gen-providers";
import type { ImageGenerationConfig } from "@/contracts/types/new-llm-config.types";
import { cn } from "@/lib/utils";
interface ImageModelManagerProps {
  /** Search space whose image-generation settings are being managed. */
  searchSpaceId: number;
}
// Framer-motion variants: parent staggers the entrance of its children.
const container = {
  hidden: { opacity: 0 },
  show: { opacity: 1, transition: { staggerChildren: 0.05 } },
};
// Framer-motion variants for each staggered child (fade in + slide up).
const item = {
  hidden: { opacity: 0, y: 20 },
  show: { opacity: 1, y: 0 },
};
/**
 * Settings panel for image-generation models in a search space: lists
 * global + user configs, lets the user create/edit/delete their own, and
 * saves which config is the active preference.
 */
export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
  // Image gen config atoms
  const { mutateAsync: createConfig, isPending: isCreating, error: createError } =
    useAtomValue(createImageGenConfigMutationAtom);
  const { mutateAsync: updateConfig, isPending: isUpdating, error: updateError } =
    useAtomValue(updateImageGenConfigMutationAtom);
  const { mutateAsync: deleteConfig, isPending: isDeleting, error: deleteError } =
    useAtomValue(deleteImageGenConfigMutationAtom);
  const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
  const { data: userConfigs, isFetching: configsLoading, error: fetchError, refetch: refreshConfigs } =
    useAtomValue(imageGenConfigsAtom);
  const { data: globalConfigs = [], isFetching: globalLoading } =
    useAtomValue(globalImageGenConfigsAtom);
  const { data: preferences = {}, isFetching: prefsLoading } = useAtomValue(llmPreferencesAtom);
  // Local state
  const [isDialogOpen, setIsDialogOpen] = useState(false);
  const [editingConfig, setEditingConfig] = useState<ImageGenerationConfig | null>(null);
  const [configToDelete, setConfigToDelete] = useState<ImageGenerationConfig | null>(null);
  // Preference state — "" is the local sentinel for "unassigned".
  const [selectedPrefId, setSelectedPrefId] = useState<string | number>(
    preferences.image_generation_config_id ?? ""
  );
  const [hasPrefChanges, setHasPrefChanges] = useState(false);
  const [isSavingPref, setIsSavingPref] = useState(false);
  // Re-sync local selection whenever the server-side preference changes.
  useEffect(() => {
    setSelectedPrefId(preferences.image_generation_config_id ?? "");
    setHasPrefChanges(false);
  }, [preferences]);
  const isSubmitting = isCreating || isUpdating;
  const isLoading = configsLoading || globalLoading || prefsLoading;
  // Collect whichever mutation/fetch errors are present for display.
  const errors = [createError, updateError, deleteError, fetchError].filter(Boolean) as Error[];
  // Form state for create/edit dialog
  const [formData, setFormData] = useState({
    name: "",
    description: "",
    provider: "",
    custom_provider: "",
    model_name: "",
    api_key: "",
    api_base: "",
    api_version: "",
  });
  const [modelComboboxOpen, setModelComboboxOpen] = useState(false);
// Restore every field of the create/edit form to its pristine empty value.
const resetForm = () => {
  const pristine = {
    name: "",
    description: "",
    provider: "",
    custom_provider: "",
    model_name: "",
    api_key: "",
    api_base: "",
    api_version: "",
  };
  setFormData(pristine);
};
// Create or update an image-gen config from the dialog form. A newly
// created config is immediately assigned as the search space's active
// image model; edits do not change the active preference.
const handleFormSubmit = useCallback(async () => {
  if (!formData.name || !formData.provider || !formData.model_name || !formData.api_key) {
    toast.error("Please fill in all required fields");
    return;
  }
  try {
    if (editingConfig) {
      await updateConfig({
        id: editingConfig.id,
        data: {
          name: formData.name,
          description: formData.description || undefined,
          // NOTE(review): `as any` bypasses the provider enum type; consider
          // deriving a literal type from IMAGE_GEN_PROVIDERS instead.
          provider: formData.provider as any,
          custom_provider: formData.custom_provider || undefined,
          model_name: formData.model_name,
          api_key: formData.api_key,
          api_base: formData.api_base || undefined,
          api_version: formData.api_version || undefined,
        },
      });
    } else {
      const result = await createConfig({
        name: formData.name,
        description: formData.description || undefined,
        // NOTE(review): same `as any` escape hatch as above.
        provider: formData.provider as any,
        custom_provider: formData.custom_provider || undefined,
        model_name: formData.model_name,
        api_key: formData.api_key,
        api_base: formData.api_base || undefined,
        api_version: formData.api_version || undefined,
        search_space_id: searchSpaceId,
      });
      // Auto-assign newly created config
      if (result?.id) {
        await updatePreferences({
          search_space_id: searchSpaceId,
          data: { image_generation_config_id: result.id },
        });
      }
    }
    setIsDialogOpen(false);
    setEditingConfig(null);
    resetForm();
  } catch {
    // Error handled by mutation
  }
}, [editingConfig, formData, searchSpaceId, createConfig, updateConfig, updatePreferences]);
// Delete the config the user confirmed in the alert dialog, then dismiss it.
const handleDelete = async () => {
  const target = configToDelete;
  if (!target) return;
  try {
    await deleteConfig(target.id);
    setConfigToDelete(null);
  } catch {
    // Error surfaced by the delete mutation atom.
  }
};
// Pre-populate the form from an existing config and open the dialog in edit mode.
const openEditDialog = (config: ImageGenerationConfig) => {
  setEditingConfig(config);
  const prefilled = {
    name: config.name,
    description: config.description || "",
    provider: config.provider,
    custom_provider: config.custom_provider || "",
    model_name: config.model_name,
    api_key: config.api_key,
    api_base: config.api_base || "",
    api_version: config.api_version || "",
  };
  setFormData(prefilled);
  setIsDialogOpen(true);
};
// Open the dialog with a blank form for creating a brand-new config.
const openNewDialog = () => {
  resetForm();
  setEditingConfig(null);
  setIsDialogOpen(true);
};
// Track the <Select> choice locally; "unassigned" maps to the "" sentinel.
// The change is only persisted when the user clicks Save.
const handlePrefChange = (value: string) => {
  // Always pass an explicit radix — bare parseInt can misinterpret prefixed strings.
  const newVal = value === "unassigned" ? "" : parseInt(value, 10);
  setSelectedPrefId(newVal);
  // Dirty only when the choice differs from the persisted preference.
  setHasPrefChanges(newVal !== (preferences.image_generation_config_id ?? ""));
};
// Persist the locally selected image-generation preference for this search space.
const handleSavePref = async () => {
  setIsSavingPref(true);
  try {
    // Normalize the selection: the "" (unassigned) sentinel must be sent as an
    // explicit null — sending `undefined` would be dropped during JSON
    // serialization, so choosing "Unassigned" would silently never persist
    // even though a success toast is shown. (null is a representable stored
    // value: readers treat null/undefined preference ids as "no model".)
    const configId =
      typeof selectedPrefId === "string"
        ? selectedPrefId
          ? parseInt(selectedPrefId, 10)
          : null
        : selectedPrefId;
    await updatePreferences({
      search_space_id: searchSpaceId,
      data: { image_generation_config_id: configId },
    });
    setHasPrefChanges(false);
    toast.success("Image generation model preference saved!");
  } catch {
    toast.error("Failed to save preference");
  } finally {
    setIsSavingPref(false);
  }
};
// Merge global and user configs, tagging each with its origin for rendering.
const allConfigs = [
  ...globalConfigs.map((c) => ({ ...c, _source: "global" as const })),
  ...(userConfigs ?? []).map((c) => ({ ...c, _source: "user" as const })),
];
// Provider metadata and suggested model list for the provider picked in the form.
const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider);
const suggestedModels = getImageGenModelsByProvider(formData.provider);
return (
<div className="space-y-4 md:space-y-6">
{/* Header */}
<div className="flex flex-col space-y-4 sm:flex-row sm:items-center sm:justify-between sm:space-y-0">
<Button
variant="outline"
size="sm"
onClick={() => refreshConfigs()}
disabled={isLoading}
className="flex items-center gap-2 text-xs md:text-sm h-8 md:h-9"
>
<RefreshCw className={cn("h-3 w-3 md:h-4 md:w-4", configsLoading && "animate-spin")} />
Refresh
</Button>
</div>
{/* Errors */}
<AnimatePresence>
{errors.map((err) => (
<motion.div key={err?.message} initial={{ opacity: 0, y: -10 }} animate={{ opacity: 1, y: 0 }} exit={{ opacity: 0, y: -10 }}>
<Alert variant="destructive" className="py-3">
<AlertCircle className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
<AlertDescription className="text-xs md:text-sm">{err?.message}</AlertDescription>
</Alert>
</motion.div>
))}
</AnimatePresence>
{/* Global info */}
{globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && (
<Alert className="border-teal-500/30 bg-teal-500/5 py-3">
<Sparkles className="h-3 w-3 md:h-4 md:w-4 text-teal-600 dark:text-teal-400 shrink-0" />
<AlertDescription className="text-teal-800 dark:text-teal-200 text-xs md:text-sm">
<span className="font-medium">
{globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length} global image model(s)
</span>{" "}
available from your administrator.
</AlertDescription>
</Alert>
)}
{/* Active Preference Card */}
{!isLoading && allConfigs.length > 0 && (
<motion.div initial={{ opacity: 0, y: 10 }} animate={{ opacity: 1, y: 0 }}>
<Card className="border-l-4 border-l-teal-500">
<CardHeader className="pb-2 px-3 md:px-6 pt-3 md:pt-6">
<div className="flex items-center gap-2 md:gap-3">
<div className="p-1.5 md:p-2 rounded-lg bg-teal-100 text-teal-800">
<ImageIcon className="w-4 h-4 md:w-5 md:h-5" />
</div>
<div>
<CardTitle className="text-base md:text-lg">Active Image Model</CardTitle>
<CardDescription className="text-xs md:text-sm">
Select which model to use for image generation
</CardDescription>
</div>
</div>
</CardHeader>
<CardContent className="space-y-3 px-3 md:px-6 pb-3 md:pb-6">
<Select
value={selectedPrefId?.toString() || "unassigned"}
onValueChange={handlePrefChange}
>
<SelectTrigger className="h-9 md:h-10 text-xs md:text-sm">
<SelectValue placeholder="Select an image model" />
</SelectTrigger>
<SelectContent>
<SelectItem value="unassigned">
<span className="text-muted-foreground">Unassigned</span>
</SelectItem>
{globalConfigs.length > 0 && (
<>
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">Global</div>
{globalConfigs.map((c) => {
const isAuto = "is_auto_mode" in c && c.is_auto_mode;
return (
<SelectItem key={`g-${c.id}`} value={c.id.toString()}>
<div className="flex items-center gap-2">
{isAuto ? (
<Badge variant="outline" className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-violet-200">
<Shuffle className="size-3 mr-1" />AUTO
</Badge>
) : (
<Badge variant="outline" className="text-xs bg-teal-50 text-teal-700 dark:bg-teal-900/30 dark:text-teal-300 border-teal-200">
{c.provider}
</Badge>
)}
<span>{c.name}</span>
</div>
</SelectItem>
);
})}
</>
)}
{(userConfigs?.length ?? 0) > 0 && (
<>
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">Your Models</div>
{userConfigs?.map((c) => (
<SelectItem key={`u-${c.id}`} value={c.id.toString()}>
<div className="flex items-center gap-2">
<Badge variant="outline" className="text-xs">{c.provider}</Badge>
<span>{c.name}</span>
<span className="text-muted-foreground">({c.model_name})</span>
</div>
</SelectItem>
))}
</>
)}
</SelectContent>
</Select>
{hasPrefChanges && (
<div className="flex gap-2 pt-1">
<Button size="sm" onClick={handleSavePref} disabled={isSavingPref} className="text-xs h-8">
{isSavingPref ? "Saving..." : "Save"}
</Button>
<Button size="sm" variant="outline" onClick={() => { setSelectedPrefId(preferences.image_generation_config_id ?? ""); setHasPrefChanges(false); }} className="text-xs h-8">
Reset
</Button>
</div>
)}
</CardContent>
</Card>
</motion.div>
)}
{/* Loading */}
{isLoading && (
<Card>
<CardContent className="flex items-center justify-center py-10">
<Spinner size="md" className="text-muted-foreground" />
</CardContent>
</Card>
)}
{/* User Configs */}
{!isLoading && (
<div className="space-y-4 md:space-y-6">
<div className="flex flex-col space-y-4 sm:flex-row sm:items-center sm:justify-between sm:space-y-0">
<h3 className="text-lg md:text-xl font-semibold tracking-tight">Your Image Models</h3>
<Button onClick={openNewDialog} className="flex items-center gap-2 text-xs md:text-sm h-8 md:h-9">
<Plus className="h-3 w-3 md:h-4 md:w-4" />
Add Image Model
</Button>
</div>
{(userConfigs?.length ?? 0) === 0 ? (
<Card className="border-dashed border-2 border-muted-foreground/25">
<CardContent className="flex flex-col items-center justify-center py-10 md:py-16 text-center">
<div className="rounded-full bg-gradient-to-br from-teal-500/10 to-cyan-500/10 p-4 md:p-6 mb-4">
<Wand2 className="h-8 w-8 md:h-12 md:w-12 text-teal-600 dark:text-teal-400" />
</div>
<h3 className="text-lg font-semibold mb-2">No Image Models Yet</h3>
<p className="text-xs md:text-sm text-muted-foreground max-w-sm mb-4">
Add your own image generation model (DALL-E 3, GPT Image 1, etc.)
</p>
<Button onClick={openNewDialog} size="lg" className="gap-2 text-xs md:text-sm">
<Plus className="h-3 w-3 md:h-4 md:w-4" />
Add First Image Model
</Button>
</CardContent>
</Card>
) : (
<motion.div variants={container} initial="hidden" animate="show" className="grid gap-4">
<AnimatePresence mode="popLayout">
{userConfigs?.map((config) => (
<motion.div key={config.id} variants={item} layout exit={{ opacity: 0, scale: 0.95 }}>
<Card className="group overflow-hidden hover:shadow-lg transition-all duration-300 border-muted-foreground/10 hover:border-teal-500/30">
<CardContent className="p-0">
<div className="flex">
<div className="w-1 md:w-1.5 bg-gradient-to-b from-teal-500/50 to-cyan-500/50 group-hover:from-teal-500 group-hover:to-cyan-500 transition-colors" />
<div className="flex-1 p-3 md:p-5">
<div className="flex items-start justify-between gap-2">
<div className="flex items-start gap-2 md:gap-4 flex-1 min-w-0">
<div className="flex h-10 w-10 md:h-12 md:w-12 items-center justify-center rounded-lg md:rounded-xl bg-gradient-to-br from-teal-500/10 to-cyan-500/10 shrink-0">
<ImageIcon className="h-5 w-5 md:h-6 md:w-6 text-teal-600 dark:text-teal-400" />
</div>
<div className="flex-1 min-w-0 space-y-2">
<div className="flex items-center gap-1.5 flex-wrap">
<h4 className="text-sm md:text-base font-semibold truncate">{config.name}</h4>
<Badge variant="secondary" className="text-[9px] md:text-[10px] px-1.5 py-0.5 bg-teal-500/10 text-teal-700 dark:text-teal-300 border-teal-500/20">
{config.provider}
</Badge>
</div>
<code className="text-[10px] md:text-xs font-mono text-muted-foreground bg-muted/50 px-1.5 py-0.5 rounded-md inline-block">
{config.model_name}
</code>
{config.description && (
<p className="text-[10px] md:text-xs text-muted-foreground line-clamp-1">{config.description}</p>
)}
<div className="flex items-center gap-1 text-[10px] md:text-xs text-muted-foreground pt-1">
<Clock className="h-2.5 w-2.5 md:h-3 md:w-3" />
{new Date(config.created_at).toLocaleDateString()}
</div>
</div>
</div>
<div className="flex items-center gap-0.5 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button variant="ghost" size="sm" onClick={() => openEditDialog(config)} className="h-7 w-7 p-0 text-muted-foreground hover:text-foreground">
<Edit3 className="h-3.5 w-3.5" />
</Button>
</TooltipTrigger>
<TooltipContent>Edit</TooltipContent>
</Tooltip>
</TooltipProvider>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button variant="ghost" size="sm" onClick={() => setConfigToDelete(config)} className="h-7 w-7 p-0 text-muted-foreground hover:text-destructive">
<Trash2 className="h-3.5 w-3.5" />
</Button>
</TooltipTrigger>
<TooltipContent>Delete</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
</div>
</div>
</div>
</CardContent>
</Card>
</motion.div>
))}
</AnimatePresence>
</motion.div>
)}
</div>
)}
{/* Create/Edit Dialog */}
<Dialog open={isDialogOpen} onOpenChange={(open) => { if (!open) { setIsDialogOpen(false); setEditingConfig(null); resetForm(); } }}>
<DialogContent className="max-w-lg max-h-[90vh] overflow-y-auto">
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
{editingConfig ? <Edit3 className="w-5 h-5 text-teal-600" /> : <Plus className="w-5 h-5 text-teal-600" />}
{editingConfig ? "Edit Image Model" : "Add Image Model"}
</DialogTitle>
<DialogDescription>
{editingConfig ? "Update your image generation model" : "Configure a new image generation model (DALL-E 3, GPT Image 1, etc.)"}
</DialogDescription>
</DialogHeader>
<div className="space-y-4 pt-2">
{/* Name */}
<div className="space-y-2">
<Label className="text-sm font-medium">Name *</Label>
<Input
placeholder="e.g., My DALL-E 3"
value={formData.name}
onChange={(e) => setFormData((p) => ({ ...p, name: e.target.value }))}
/>
</div>
{/* Description */}
<div className="space-y-2">
<Label className="text-sm font-medium">Description</Label>
<Input
placeholder="Optional description"
value={formData.description}
onChange={(e) => setFormData((p) => ({ ...p, description: e.target.value }))}
/>
</div>
<Separator />
{/* Provider */}
<div className="space-y-2">
<Label className="text-sm font-medium">Provider *</Label>
<Select
value={formData.provider}
onValueChange={(val) => setFormData((p) => ({ ...p, provider: val, model_name: "" }))}
>
<SelectTrigger>
<SelectValue placeholder="Select a provider" />
</SelectTrigger>
<SelectContent>
{IMAGE_GEN_PROVIDERS.map((p) => (
<SelectItem key={p.value} value={p.value}>
<div className="flex flex-col">
<span className="font-medium">{p.label}</span>
<span className="text-xs text-muted-foreground">{p.example}</span>
</div>
</SelectItem>
))}
</SelectContent>
</Select>
</div>
{/* Model Name */}
<div className="space-y-2">
<Label className="text-sm font-medium">Model Name *</Label>
{suggestedModels.length > 0 ? (
<Popover open={modelComboboxOpen} onOpenChange={setModelComboboxOpen}>
<PopoverTrigger asChild>
<Button variant="outline" role="combobox" className="w-full justify-between font-normal">
{formData.model_name || "Select or type a model..."}
<ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" />
</Button>
</PopoverTrigger>
<PopoverContent className="w-full p-0" align="start">
<Command>
<CommandInput
placeholder="Search or type model name..."
value={formData.model_name}
onValueChange={(val) => setFormData((p) => ({ ...p, model_name: val }))}
/>
<CommandList>
<CommandEmpty>
<span className="text-xs text-muted-foreground">Type a custom model name</span>
</CommandEmpty>
<CommandGroup>
{suggestedModels.map((m) => (
<CommandItem
key={m.value}
value={m.value}
onSelect={() => {
setFormData((p) => ({ ...p, model_name: m.value }));
setModelComboboxOpen(false);
}}
>
<Check className={cn("mr-2 h-4 w-4", formData.model_name === m.value ? "opacity-100" : "opacity-0")} />
<span className="font-mono text-sm">{m.value}</span>
<span className="ml-2 text-xs text-muted-foreground">{m.label}</span>
</CommandItem>
))}
</CommandGroup>
</CommandList>
</Command>
</PopoverContent>
</Popover>
) : (
<Input
placeholder="e.g., dall-e-3"
value={formData.model_name}
onChange={(e) => setFormData((p) => ({ ...p, model_name: e.target.value }))}
/>
)}
</div>
{/* API Key */}
<div className="space-y-2">
<Label className="text-sm font-medium flex items-center gap-1.5">
<Key className="h-3.5 w-3.5" /> API Key *
</Label>
<Input
type="password"
placeholder="sk-..."
value={formData.api_key}
onChange={(e) => setFormData((p) => ({ ...p, api_key: e.target.value }))}
/>
</div>
{/* API Base (optional) */}
<div className="space-y-2">
<Label className="text-sm font-medium">API Base URL</Label>
<Input
placeholder={selectedProvider?.apiBase || "Optional"}
value={formData.api_base}
onChange={(e) => setFormData((p) => ({ ...p, api_base: e.target.value }))}
/>
</div>
{/* API Version (Azure) */}
{formData.provider === "AZURE_OPENAI" && (
<div className="space-y-2">
<Label className="text-sm font-medium">API Version (Azure)</Label>
<Input
placeholder="2024-02-15-preview"
value={formData.api_version}
onChange={(e) => setFormData((p) => ({ ...p, api_version: e.target.value }))}
/>
</div>
)}
{/* Actions */}
<div className="flex gap-3 pt-4 border-t">
<Button
variant="outline"
className="flex-1"
onClick={() => { setIsDialogOpen(false); setEditingConfig(null); resetForm(); }}
>
Cancel
</Button>
<Button
className="flex-1"
onClick={handleFormSubmit}
disabled={isSubmitting || !formData.name || !formData.provider || !formData.model_name || !formData.api_key}
>
{isSubmitting ? <Spinner size="sm" className="mr-2" /> : null}
{editingConfig ? "Save Changes" : "Create & Use"}
</Button>
</div>
</div>
</DialogContent>
</Dialog>
{/* Delete Confirmation */}
<AlertDialog open={!!configToDelete} onOpenChange={(open) => !open && setConfigToDelete(null)}>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle className="flex items-center gap-2">
<Trash2 className="h-5 w-5 text-destructive" />
Delete Image Model
</AlertDialogTitle>
<AlertDialogDescription>
Are you sure you want to delete <span className="font-semibold text-foreground">{configToDelete?.name}</span>?
</AlertDialogDescription>
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel disabled={isDeleting}>Cancel</AlertDialogCancel>
<AlertDialogAction onClick={handleDelete} disabled={isDeleting} className="bg-destructive text-destructive-foreground hover:bg-destructive/90">
{isDeleting ? <><Spinner size="sm" className="mr-2" />Deleting</> : <><Trash2 className="mr-2 h-4 w-4" />Delete</>}
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
</div>
);
}

View file

@ -255,15 +255,15 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
</Alert>
)}
{/* Role Assignment Cards */}
{availableConfigs.length > 0 && (
<div className="grid gap-4 md:gap-6">
{Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => {
const IconComponent = role.icon;
const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments];
const assignedConfig = availableConfigs.find(
(config) => config.id === currentAssignment
);
{/* Role Assignment Cards */}
{availableConfigs.length > 0 && (
<div className="grid gap-4 md:gap-6">
{Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => {
const IconComponent = role.icon;
const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments];
const assignedConfig = availableConfigs.find(
(config) => config.id === currentAssignment
);
return (
<motion.div
@ -294,100 +294,100 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
</div>
</CardHeader>
<CardContent className="space-y-3 md:space-y-4 px-3 md:px-6 pb-3 md:pb-6">
<div className="space-y-1.5 md:space-y-2">
<Label className="text-xs md:text-sm font-medium">
Assign LLM Configuration:
</Label>
<Select
value={currentAssignment?.toString() || "unassigned"}
onValueChange={(value) => handleRoleAssignment(`${key}_llm_id`, value)}
>
<SelectTrigger className="h-9 md:h-10 text-xs md:text-sm">
<SelectValue placeholder="Select an LLM configuration" />
</SelectTrigger>
<SelectContent>
<SelectItem value="unassigned">
<span className="text-muted-foreground">Unassigned</span>
</SelectItem>
<div className="space-y-1.5 md:space-y-2">
<Label className="text-xs md:text-sm font-medium">
Assign LLM Configuration:
</Label>
<Select
value={currentAssignment?.toString() || "unassigned"}
onValueChange={(value) => handleRoleAssignment(`${key}_llm_id`, value)}
>
<SelectTrigger className="h-9 md:h-10 text-xs md:text-sm">
<SelectValue placeholder="Select an LLM configuration" />
</SelectTrigger>
<SelectContent>
<SelectItem value="unassigned">
<span className="text-muted-foreground">Unassigned</span>
</SelectItem>
{/* Global Configurations */}
{globalConfigs.length > 0 && (
<>
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
Global Configurations
</div>
{globalConfigs.map((config) => {
const isAutoMode =
"is_auto_mode" in config && config.is_auto_mode;
return (
<SelectItem key={config.id} value={config.id.toString()}>
<div className="flex items-center gap-2">
{isAutoMode ? (
<Badge
variant="outline"
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-violet-200 dark:border-violet-700"
>
<Shuffle className="size-3 mr-1" />
AUTO
</Badge>
) : (
<Badge variant="outline" className="text-xs">
{config.provider}
</Badge>
)}
<span>{config.name}</span>
{!isAutoMode && (
<span className="text-muted-foreground">
({config.model_name})
</span>
)}
{isAutoMode ? (
<Badge
variant="secondary"
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
>
Recommended
</Badge>
) : (
<Badge variant="secondary" className="text-xs">
🌐 Global
</Badge>
)}
</div>
</SelectItem>
);
})}
</>
)}
{/* Global Configurations */}
{globalConfigs.length > 0 && (
<>
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
Global Configurations
</div>
{globalConfigs.map((config) => {
const isAutoMode =
"is_auto_mode" in config && config.is_auto_mode;
return (
<SelectItem key={config.id} value={config.id.toString()}>
<div className="flex items-center gap-2">
{isAutoMode ? (
<Badge
variant="outline"
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-violet-200 dark:border-violet-700"
>
<Shuffle className="size-3 mr-1" />
AUTO
</Badge>
) : (
<Badge variant="outline" className="text-xs">
{config.provider}
</Badge>
)}
<span>{config.name}</span>
{!isAutoMode && (
<span className="text-muted-foreground">
({config.model_name})
</span>
)}
{isAutoMode ? (
<Badge
variant="secondary"
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
>
Recommended
</Badge>
) : (
<Badge variant="secondary" className="text-xs">
🌐 Global
</Badge>
)}
</div>
</SelectItem>
);
})}
</>
)}
{/* Custom Configurations */}
{newLLMConfigs.length > 0 && (
<>
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
Your Configurations
</div>
{newLLMConfigs
.filter(
(config) => config.id && config.id.toString().trim() !== ""
)
.map((config) => (
<SelectItem key={config.id} value={config.id.toString()}>
<div className="flex items-center gap-2">
<Badge variant="outline" className="text-xs">
{config.provider}
</Badge>
<span>{config.name}</span>
<span className="text-muted-foreground">
({config.model_name})
</span>
</div>
</SelectItem>
))}
</>
)}
</SelectContent>
</Select>
</div>
{/* Custom Configurations */}
{newLLMConfigs.length > 0 && (
<>
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
Your Configurations
</div>
{newLLMConfigs
.filter(
(config) => config.id && config.id.toString().trim() !== ""
)
.map((config) => (
<SelectItem key={config.id} value={config.id.toString()}>
<div className="flex items-center gap-2">
<Badge variant="outline" className="text-xs">
{config.provider}
</Badge>
<span>{config.name}</span>
<span className="text-muted-foreground">
({config.model_name})
</span>
</div>
</SelectItem>
))}
</>
)}
</SelectContent>
</Select>
</div>
{assignedConfig && (
<div

View file

@ -88,7 +88,7 @@ function ImageCancelledState({ src }: { src: string }) {
/**
 * Renders a raw tool result as an Image card.
 *
 * Parses the unknown result into a SerializableImage (with fallbacks for
 * partially-valid data) and displays it at the standard 512px max width.
 *
 * Fix: the block contained two consecutive `return` statements (an old
 * 420px variant followed by an unreachable 512px variant — a merge/diff
 * artifact). Only the 512px return is kept, matching the 512px default
 * used by ImageSkeleton and the Image component itself.
 */
function ParsedImage({ result }: { result: unknown }) {
  const image = parseSerializableImage(result);
  return <Image {...image} maxWidth="512px" />;
}
/**

View file

@ -1,6 +1,6 @@
"use client";
import { ExternalLinkIcon, ImageIcon } from "lucide-react";
import { ExternalLinkIcon, ImageIcon, SparklesIcon } from "lucide-react";
import NextImage from "next/image";
import { Component, type ReactNode, useState } from "react";
import { z } from "zod";
@ -25,7 +25,7 @@ const SerializableImageSchema = z.object({
id: z.string(),
assetId: z.string(),
src: z.string(),
alt: z.string().nullish(), // Made optional - will use fallback if missing
alt: z.string().nullish(),
title: z.string().nullish(),
description: z.string().nullish(),
href: z.string().nullish(),
@ -49,7 +49,7 @@ export interface ImageProps {
id: string;
assetId: string;
src: string;
alt?: string; // Optional with default fallback
alt?: string;
title?: string;
description?: string;
href?: string;
@ -71,10 +71,8 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a
if (!parsed.success) {
console.warn("Invalid image data:", parsed.error.issues);
// Try to extract basic info and return a fallback object
const obj = (result && typeof result === "object" ? result : {}) as Record<string, unknown>;
// If we have at least id, assetId, and src, we can still render the image
if (
typeof obj.id === "string" &&
typeof obj.assetId === "string" &&
@ -89,7 +87,7 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a
description: typeof obj.description === "string" ? obj.description : undefined,
href: typeof obj.href === "string" ? obj.href : undefined,
domain: typeof obj.domain === "string" ? obj.domain : undefined,
ratio: undefined, // Use default ratio
ratio: undefined,
source: undefined,
};
}
@ -97,7 +95,6 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a
throw new Error(`Invalid image: ${parsed.error.issues.map((i) => i.message).join(", ")}`);
}
// Provide fallback for alt if it's null/undefined
return {
...parsed.data,
alt: parsed.data.alt ?? "Image",
@ -105,7 +102,7 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a
}
/**
* Get aspect ratio class based on ratio prop
* Get aspect ratio class based on ratio prop (used for fixed-ratio images only)
*/
function getAspectRatioClass(ratio?: AspectRatio): string {
switch (ratio) {
@ -119,7 +116,6 @@ function getAspectRatioClass(ratio?: AspectRatio): string {
return "aspect-[9/16]";
case "21:9":
return "aspect-[21/9]";
case "auto":
default:
return "aspect-[4/3]";
}
@ -150,7 +146,7 @@ export class ImageErrorBoundary extends Component<
if (this.state.hasError) {
return (
<Card className="w-full max-w-md overflow-hidden">
<div className="aspect-[4/3] bg-muted flex items-center justify-center">
<div className="aspect-square bg-muted flex items-center justify-center">
<div className="flex flex-col items-center gap-2 text-muted-foreground">
<ImageIcon className="size-8" />
<p className="text-sm">Failed to load image</p>
@ -167,10 +163,10 @@ export class ImageErrorBoundary extends Component<
/**
* Loading skeleton for Image
*/
export function ImageSkeleton({ maxWidth = "420px" }: { maxWidth?: string }) {
export function ImageSkeleton({ maxWidth = "512px" }: { maxWidth?: string }) {
return (
<Card className="w-full overflow-hidden animate-pulse" style={{ maxWidth }}>
<div className="aspect-[4/3] bg-muted flex items-center justify-center">
<div className="aspect-square bg-muted flex items-center justify-center">
<ImageIcon className="size-12 text-muted-foreground/30" />
</div>
</Card>
@ -183,7 +179,7 @@ export function ImageSkeleton({ maxWidth = "420px" }: { maxWidth?: string }) {
export function ImageLoading({ title = "Loading image..." }: { title?: string }) {
return (
<Card className="w-full max-w-md overflow-hidden">
<div className="aspect-[4/3] bg-muted flex items-center justify-center">
<div className="aspect-square bg-muted flex items-center justify-center">
<div className="flex flex-col items-center gap-3">
<Spinner size="lg" className="text-muted-foreground" />
<p className="text-muted-foreground text-sm">{title}</p>
@ -197,7 +193,9 @@ export function ImageLoading({ title = "Loading image..." }: { title?: string })
* Image Component
*
* Display images with metadata and attribution.
* Features hover overlay with title and source attribution.
* - For "auto" ratio: renders the image at natural dimensions (no cropping)
* - For fixed ratios: uses a fixed aspect container with object-cover
* - Features hover overlay with title, description, and source attribution.
*/
export function Image({
id,
@ -207,16 +205,18 @@ export function Image({
description,
href,
domain,
ratio = "4:3",
ratio = "auto",
fit = "cover",
source,
maxWidth = "420px",
maxWidth = "512px",
className,
}: ImageProps) {
const [isHovered, setIsHovered] = useState(false);
const [imageError, setImageError] = useState(false);
const aspectRatioClass = getAspectRatioClass(ratio);
const [imageLoaded, setImageLoaded] = useState(false);
const displayDomain = domain || source?.label;
const isGenerated = domain === "ai-generated";
const isAutoRatio = !ratio || ratio === "auto";
const handleClick = () => {
const targetUrl = href || source?.url || src;
@ -228,7 +228,7 @@ export function Image({
if (imageError) {
return (
<Card id={id} className={cn("w-full overflow-hidden", className)} style={{ maxWidth }}>
<div className={cn("bg-muted flex items-center justify-center", aspectRatioClass)}>
<div className="aspect-square bg-muted flex items-center justify-center">
<div className="flex flex-col items-center gap-2 text-muted-foreground">
<ImageIcon className="size-8" />
<p className="text-sm">Image not available</p>
@ -243,6 +243,7 @@ export function Image({
id={id}
className={cn(
"group w-full overflow-hidden cursor-pointer transition-shadow duration-200 hover:shadow-lg",
isGenerated && "ring-1 ring-primary/10",
className
)}
style={{ maxWidth }}
@ -258,71 +259,98 @@ export function Image({
role="button"
tabIndex={0}
>
<div className={cn("relative w-full overflow-hidden bg-muted", aspectRatioClass)}>
{/* Image */}
<NextImage
src={src}
alt={alt}
fill
className={cn(
"transition-transform duration-300",
fit === "cover" ? "object-cover" : "object-contain",
isHovered && "scale-105"
)}
unoptimized
onError={() => setImageError(true)}
/>
<div className="relative w-full overflow-hidden bg-muted">
{isAutoRatio ? (
/* Auto ratio: image renders at natural dimensions, no cropping */
<>
{!imageLoaded && (
<div className="aspect-square flex items-center justify-center">
<Spinner size="lg" className="text-muted-foreground" />
</div>
)}
{/* eslint-disable-next-line @next/next/no-img-element */}
<img
src={src}
alt={alt}
className={cn(
"w-full h-auto transition-transform duration-300",
isHovered && "scale-[1.02]",
!imageLoaded && "hidden"
)}
onLoad={() => setImageLoaded(true)}
onError={() => setImageError(true)}
/>
</>
) : (
/* Fixed ratio: constrained aspect container with fill */
<div className={getAspectRatioClass(ratio)}>
<NextImage
src={src}
alt={alt}
fill
className={cn(
"transition-transform duration-300",
fit === "cover" ? "object-cover" : "object-contain",
isHovered && "scale-105"
)}
unoptimized
onError={() => setImageError(true)}
/>
</div>
)}
{/* Hover overlay - appears on hover */}
{/* Hover overlay */}
<div
className={cn(
"absolute inset-0 bg-gradient-to-t from-black/80 via-black/20 to-transparent",
"absolute inset-0 bg-gradient-to-t from-black/70 via-transparent to-transparent",
"transition-opacity duration-200",
isHovered ? "opacity-100" : "opacity-0"
)}
>
{/* Content at bottom */}
<div className="absolute bottom-0 left-0 right-0 p-4">
{/* Title */}
<div className="absolute bottom-0 left-0 right-0 p-3">
{title && (
<h3 className="font-semibold text-white text-base leading-tight line-clamp-2 mb-1">
<h3 className="font-semibold text-white text-sm leading-tight line-clamp-2 mb-0.5">
{title}
</h3>
)}
{/* Description */}
{description && (
<p className="text-white/80 text-sm line-clamp-2 mb-2">{description}</p>
<p className="text-white/80 text-xs line-clamp-2 mb-1.5">{description}</p>
)}
{/* Source attribution */}
{displayDomain && (
<div className="flex items-center gap-1.5">
{source?.iconUrl ? (
{isGenerated ? (
<SparklesIcon className="size-3.5 text-white/70" />
) : source?.iconUrl ? (
<NextImage
src={source.iconUrl}
alt={source.label}
width={16}
height={16}
width={14}
height={14}
className="rounded"
unoptimized
/>
) : (
<ExternalLinkIcon className="size-4 text-white/70" />
<ExternalLinkIcon className="size-3.5 text-white/70" />
)}
<span className="text-white/70 text-sm">{displayDomain}</span>
<span className="text-white/70 text-xs">{displayDomain}</span>
</div>
)}
</div>
</div>
{/* Always visible domain badge (bottom right, shown when NOT hovered) */}
{/* Badge when not hovered */}
{displayDomain && !isHovered && (
<div className="absolute bottom-2 right-2">
<Badge
variant="secondary"
className="bg-black/60 text-white border-0 text-xs backdrop-blur-sm"
className={cn(
"border-0 text-xs backdrop-blur-sm",
isGenerated
? "bg-primary/80 text-primary-foreground"
: "bg-black/60 text-white"
)}
>
{isGenerated && <SparklesIcon className="size-3 mr-1" />}
{displayDomain}
</Badge>
</div>

View file

@ -0,0 +1,105 @@
/**
 * Metadata for one image-generation provider option shown in the
 * provider dropdown of the "Add Image Model" dialog.
 */
export interface ImageGenProvider {
  // Enum value sent to the backend (e.g. "OPENAI", "AZURE_OPENAI").
  value: string;
  // Human-readable provider name shown in the dropdown.
  label: string;
  // Example model names rendered under the label in the select item.
  example: string;
  // One-line description of what the provider offers.
  description: string;
  // Default API base URL used as the input placeholder, when known.
  apiBase?: string;
}
/**
 * Image generation providers supported by LiteLLM.
 * Drives the provider <Select> options in the image-model dialog.
 * See: https://docs.litellm.ai/docs/image_generation#supported-providers
 *
 * NOTE(review): no entry currently sets `apiBase`, so the dialog's API Base
 * placeholder always falls back to "Optional" — confirm whether defaults
 * (e.g. for Azure/Xinference) should be provided here.
 */
export const IMAGE_GEN_PROVIDERS: ImageGenProvider[] = [
  {
    value: "OPENAI",
    label: "OpenAI",
    example: "dall-e-3, gpt-image-1, dall-e-2",
    description: "DALL-E and GPT Image models",
  },
  {
    value: "AZURE_OPENAI",
    label: "Azure OpenAI",
    example: "azure/dall-e-3, azure/gpt-image-1",
    description: "OpenAI image models on Azure",
  },
  {
    value: "GOOGLE",
    label: "Google AI Studio",
    example: "gemini/imagen-3.0-generate-002",
    description: "Google AI Studio image generation",
  },
  {
    value: "VERTEX_AI",
    label: "Google Vertex AI",
    example: "vertex_ai/imagegeneration@006",
    description: "Vertex AI image generation models",
  },
  {
    value: "BEDROCK",
    label: "AWS Bedrock",
    example: "bedrock/stability.stable-diffusion-xl-v0",
    description: "Stable Diffusion on AWS Bedrock",
  },
  {
    value: "RECRAFT",
    label: "Recraft",
    example: "recraft/recraftv3",
    description: "AI-powered design and image generation",
  },
  {
    value: "OPENROUTER",
    label: "OpenRouter",
    example: "openrouter/google/gemini-2.5-flash-image",
    description: "Image generation via OpenRouter",
  },
  {
    value: "XINFERENCE",
    label: "Xinference",
    example: "xinference/stable-diffusion-xl",
    description: "Self-hosted Stable Diffusion models",
  },
  {
    value: "NSCALE",
    label: "Nscale",
    example: "nscale/flux.1-schnell",
    description: "Nscale image generation",
  },
];
/**
 * Known image-generation models, keyed by provider enum value.
 * Used to pre-populate the model-name suggestion combobox; a provider
 * without entries here falls back to a free-text input.
 */
export interface ImageGenModel {
  value: string;
  label: string;
  provider: string;
}

export const IMAGE_GEN_MODELS: ImageGenModel[] = [
  // OpenAI
  { value: "gpt-image-1", label: "GPT Image 1", provider: "OPENAI" },
  { value: "dall-e-3", label: "DALL-E 3", provider: "OPENAI" },
  { value: "dall-e-2", label: "DALL-E 2", provider: "OPENAI" },
  // Azure OpenAI
  { value: "azure/dall-e-3", label: "DALL-E 3 (Azure)", provider: "AZURE_OPENAI" },
  { value: "azure/gpt-image-1", label: "GPT Image 1 (Azure)", provider: "AZURE_OPENAI" },
  // Recraft
  { value: "recraft/recraftv3", label: "Recraft V3", provider: "RECRAFT" },
  // Bedrock
  { value: "bedrock/stability.stable-diffusion-xl-v0", label: "Stable Diffusion XL", provider: "BEDROCK" },
  // Vertex AI
  { value: "vertex_ai/imagegeneration@006", label: "Imagen 3", provider: "VERTEX_AI" },
];

/**
 * Returns the suggested models for the given provider enum value,
 * preserving their declaration order. Unknown providers yield [].
 */
export function getImageGenModelsByProvider(provider: string): ImageGenModel[] {
  const matching: ImageGenModel[] = [];
  for (const model of IMAGE_GEN_MODELS) {
    if (model.provider === provider) {
      matching.push(model);
    }
  }
  return matching;
}

View file

@ -161,19 +161,105 @@ export const globalNewLLMConfig = z.object({
export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig);
// =============================================================================
// Image Generation Config (separate table from NewLLMConfig)
// =============================================================================
/**
 * ImageGenProvider enum - only providers that support image generation.
 * Values mirror the backend `imagegenprovider` Postgres enum, so the two
 * must be kept in sync when adding providers.
 * See: https://docs.litellm.ai/docs/image_generation#supported-providers
 */
export const imageGenProviderEnum = z.enum([
  "OPENAI",
  "AZURE_OPENAI",
  "GOOGLE",
  "VERTEX_AI",
  "BEDROCK",
  "RECRAFT",
  "OPENROUTER",
  "XINFERENCE",
  "NSCALE",
]);

// Union of the literal provider values above.
export type ImageGenProvider = z.infer<typeof imageGenProviderEnum>;
/**
 * ImageGenerationConfig - a user-created image generation model config.
 * Separate from NewLLMConfig: no system_instructions, no citations_enabled.
 * Scoped to a search space via search_space_id.
 */
export const imageGenerationConfig = z.object({
  id: z.number(),
  // Display name shown in the settings UI.
  name: z.string().max(100),
  description: z.string().max(500).nullable().optional(),
  provider: imageGenProviderEnum,
  // Free-form provider string for cases the enum does not cover.
  custom_provider: z.string().max(100).nullable().optional(),
  // LiteLLM model identifier, e.g. "dall-e-3" or "azure/gpt-image-1".
  model_name: z.string(),
  api_key: z.string(),
  api_base: z.string().max(500).nullable().optional(),
  // Provider API version (used by Azure OpenAI).
  api_version: z.string().max(50).nullable().optional(),
  // Extra parameters passed through to LiteLLM as-is.
  litellm_params: z.record(z.string(), z.any()).nullable().optional(),
  created_at: z.string(),
  search_space_id: z.number(),
});
// Create payload: everything except the server-generated id/created_at.
export const createImageGenConfigRequest = imageGenerationConfig.omit({
  id: true,
  created_at: true,
});
// Create response: the full stored config.
export const createImageGenConfigResponse = imageGenerationConfig;
// List response: array of configs.
export const getImageGenConfigsResponse = z.array(imageGenerationConfig);

// Update: id plus a partial patch of the editable fields
// (id, created_at and search_space_id cannot be changed).
export const updateImageGenConfigRequest = z.object({
  id: z.number(),
  data: imageGenerationConfig
    .omit({ id: true, created_at: true, search_space_id: true })
    .partial(),
});
// Update response: the full updated config.
export const updateImageGenConfigResponse = imageGenerationConfig;
// Delete response: confirmation message plus the deleted config's id.
export const deleteImageGenConfigResponse = z.object({
  message: z.string(),
  id: z.number(),
});
/**
 * Global Image Generation Config - defined by the administrator in YAML,
 * distinguished from user configs by negative IDs.
 * ID 0 is reserved for "Auto" mode (LiteLLM Router load balancing).
 * Unlike imageGenerationConfig, no api_key field is exposed to the client.
 */
export const globalImageGenConfig = z.object({
  id: z.number(),
  name: z.string(),
  description: z.string().nullable().optional(),
  // Plain string (not the enum) since global configs come from YAML.
  provider: z.string(),
  custom_provider: z.string().nullable().optional(),
  model_name: z.string(),
  api_base: z.string().nullable().optional(),
  api_version: z.string().nullable().optional(),
  litellm_params: z.record(z.string(), z.any()).nullable().optional(),
  // Discriminator: always true for global configs.
  is_global: z.literal(true),
  // True only for the reserved "Auto" (router) entry.
  is_auto_mode: z.boolean().optional().default(false),
});

export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig);
// =============================================================================
// LLM Preferences (Role Assignments)
// =============================================================================
/**
 * LLM Preferences schemas - role assignments for a search space.
 * The *_id fields carry the assigned config IDs; the agent_llm and
 * document_summary_llm fields contain the full NewLLMConfig objects.
 * Image generation is assigned via image_generation_config_id (an
 * ImageGenerationConfig id, not an llm_id).
 */
export const llmPreferences = z.object({
  agent_llm_id: z.union([z.number(), z.null()]).optional(),
  document_summary_llm_id: z.union([z.number(), z.null()]).optional(),
  image_generation_config_id: z.union([z.number(), z.null()]).optional(),
  // Expanded config objects (null when the role is unassigned).
  agent_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(),
  document_summary_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(),
  image_generation_config: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(),
});
/**
@ -193,6 +279,7 @@ export const updateLLMPreferencesRequest = z.object({
data: llmPreferences.pick({
agent_llm_id: true,
document_summary_llm_id: true,
image_generation_config_id: true,
}),
});
@ -219,6 +306,15 @@ export type GetDefaultSystemInstructionsResponse = z.infer<
>;
export type GlobalNewLLMConfig = z.infer<typeof globalNewLLMConfig>;
export type GetGlobalNewLLMConfigsResponse = z.infer<typeof getGlobalNewLLMConfigsResponse>;
// Inferred TS types for the image generation config schemas above.
export type ImageGenerationConfig = z.infer<typeof imageGenerationConfig>;
export type CreateImageGenConfigRequest = z.infer<typeof createImageGenConfigRequest>;
export type CreateImageGenConfigResponse = z.infer<typeof createImageGenConfigResponse>;
export type GetImageGenConfigsResponse = z.infer<typeof getImageGenConfigsResponse>;
export type UpdateImageGenConfigRequest = z.infer<typeof updateImageGenConfigRequest>;
export type UpdateImageGenConfigResponse = z.infer<typeof updateImageGenConfigResponse>;
export type DeleteImageGenConfigResponse = z.infer<typeof deleteImageGenConfigResponse>;
export type GlobalImageGenConfig = z.infer<typeof globalImageGenConfig>;
export type GetGlobalImageGenConfigsResponse = z.infer<typeof getGlobalImageGenConfigsResponse>;
// Inferred TS types for the LLM role-assignment preference schemas.
export type LLMPreferences = z.infer<typeof llmPreferences>;
export type GetLLMPreferencesRequest = z.infer<typeof getLLMPreferencesRequest>;
export type GetLLMPreferencesResponse = z.infer<typeof getLLMPreferencesResponse>;

View file

@ -0,0 +1,83 @@
import {
type CreateImageGenConfigRequest,
createImageGenConfigRequest,
createImageGenConfigResponse,
type UpdateImageGenConfigRequest,
updateImageGenConfigRequest,
updateImageGenConfigResponse,
deleteImageGenConfigResponse,
getImageGenConfigsResponse,
getGlobalImageGenConfigsResponse,
} from "@/contracts/types/new-llm-config.types";
import { ValidationError } from "../error";
import { baseApiService } from "./base-api.service";
class ImageGenConfigApiService {
/**
* Get all global image generation configs (from YAML, negative IDs)
*/
getGlobalConfigs = async () => {
return baseApiService.get(
`/api/v1/global-image-generation-configs`,
getGlobalImageGenConfigsResponse
);
};
/**
* Create a new image generation config for a search space
*/
createConfig = async (request: CreateImageGenConfigRequest) => {
const parsed = createImageGenConfigRequest.safeParse(request);
if (!parsed.success) {
const msg = parsed.error.issues.map((i) => i.message).join(", ");
throw new ValidationError(`Invalid request: ${msg}`);
}
return baseApiService.post(
`/api/v1/image-generation-configs`,
createImageGenConfigResponse,
{ body: parsed.data }
);
};
/**
* Get image generation configs for a search space
*/
getConfigs = async (searchSpaceId: number) => {
const params = new URLSearchParams({
search_space_id: String(searchSpaceId),
}).toString();
return baseApiService.get(
`/api/v1/image-generation-configs?${params}`,
getImageGenConfigsResponse
);
};
/**
* Update an existing image generation config
*/
updateConfig = async (request: UpdateImageGenConfigRequest) => {
const parsed = updateImageGenConfigRequest.safeParse(request);
if (!parsed.success) {
const msg = parsed.error.issues.map((i) => i.message).join(", ");
throw new ValidationError(`Invalid request: ${msg}`);
}
const { id, data } = parsed.data;
return baseApiService.put(
`/api/v1/image-generation-configs/${id}`,
updateImageGenConfigResponse,
{ body: data }
);
};
/**
* Delete an image generation config
*/
deleteConfig = async (id: number) => {
return baseApiService.delete(
`/api/v1/image-generation-configs/${id}`,
deleteImageGenConfigResponse
);
};
}
export const imageGenConfigApiService = new ImageGenConfigApiService();

View file

@ -34,6 +34,11 @@ export const cacheKeys = {
defaultInstructions: () => ["new-llm-configs", "default-instructions"] as const,
global: () => ["new-llm-configs", "global"] as const,
},
imageGenConfigs: {
all: (searchSpaceId: number) => ["image-gen-configs", searchSpaceId] as const,
byId: (configId: number) => ["image-gen-configs", "detail", configId] as const,
global: () => ["image-gen-configs", "global"] as const,
},
auth: {
user: ["auth", "user"] as const,
},

View file

@ -739,6 +739,8 @@
"nav_agent_configs_desc": "LLM models with prompts & citations",
"nav_role_assignments": "Role Assignments",
"nav_role_assignments_desc": "Assign configs to agent roles",
"nav_image_models": "Image Models",
"nav_image_models_desc": "Configure image generation models",
"nav_system_instructions": "System Instructions",
"nav_system_instructions_desc": "SearchSpace-wide AI instructions",
"nav_public_links": "Public Chat Links",

View file

@ -724,6 +724,8 @@
"nav_agent_configs_desc": "LLM 模型配置提示词和引用",
"nav_role_assignments": "角色分配",
"nav_role_assignments_desc": "为代理角色分配配置",
"nav_image_models": "图像模型",
"nav_image_models_desc": "配置图像生成模型",
"nav_system_instructions": "系统指令",
"nav_system_instructions_desc": "搜索空间级别的 AI 指令",
"nav_public_links": "公开聊天链接",