mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/documents
This commit is contained in:
commit
0fdd194d92
60 changed files with 5086 additions and 248 deletions
|
|
@ -143,6 +143,15 @@ STT_SERVICE=local/base
|
|||
PAGES_LIMIT=500
|
||||
|
||||
|
||||
# Residential Proxy Configuration (anonymous-proxies.net)
|
||||
# Used for web crawling, link previews, and YouTube transcript fetching to avoid IP bans.
|
||||
# Leave commented out to disable proxying.
|
||||
# RESIDENTIAL_PROXY_USERNAME=your_proxy_username
|
||||
# RESIDENTIAL_PROXY_PASSWORD=your_proxy_password
|
||||
# RESIDENTIAL_PROXY_HOSTNAME=rotating.dnsproxifier.com:31230
|
||||
# RESIDENTIAL_PROXY_LOCATION=
|
||||
# RESIDENTIAL_PROXY_TYPE=1
|
||||
|
||||
FIRECRAWL_API_KEY=fcr-01J0000000000000000000000
|
||||
|
||||
# File Parser Service
|
||||
|
|
|
|||
|
|
@ -0,0 +1,299 @@
|
|||
"""Add image generation tables and search space preference
|
||||
|
||||
Revision ID: 93
|
||||
Revises: 92
|
||||
|
||||
Changes:
|
||||
1. Create image_generation_configs table (user-created image model configs)
|
||||
2. Create image_generations table (stores generation requests/results)
|
||||
3. Add image_generation_config_id column to searchspaces table
|
||||
4. Add image generation permissions to existing system roles
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import ENUM as PG_ENUM, JSONB, UUID
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "93"
|
||||
down_revision: str | None = "92"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
connection = op.get_bind()
|
||||
|
||||
# 1. Create imagegenprovider enum type if it doesn't exist
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'imagegenprovider') THEN
|
||||
CREATE TYPE imagegenprovider AS ENUM (
|
||||
'OPENAI', 'AZURE_OPENAI', 'GOOGLE', 'VERTEX_AI', 'BEDROCK',
|
||||
'RECRAFT', 'OPENROUTER', 'XINFERENCE', 'NSCALE'
|
||||
);
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create image_generation_configs table (uses imagegenprovider enum)
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'image_generation_configs')"
|
||||
)
|
||||
)
|
||||
if not result.scalar():
|
||||
op.create_table(
|
||||
"image_generation_configs",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("name", sa.String(100), nullable=False),
|
||||
sa.Column("description", sa.String(500), nullable=True),
|
||||
sa.Column(
|
||||
"provider",
|
||||
PG_ENUM(
|
||||
"OPENAI",
|
||||
"AZURE_OPENAI",
|
||||
"GOOGLE",
|
||||
"VERTEX_AI",
|
||||
"BEDROCK",
|
||||
"RECRAFT",
|
||||
"OPENROUTER",
|
||||
"XINFERENCE",
|
||||
"NSCALE",
|
||||
name="imagegenprovider",
|
||||
create_type=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("custom_provider", sa.String(100), nullable=True),
|
||||
sa.Column("model_name", sa.String(100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(500), nullable=True),
|
||||
sa.Column("api_version", sa.String(50), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_image_generation_configs_name "
|
||||
"ON image_generation_configs (name)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_image_generation_configs_search_space_id "
|
||||
"ON image_generation_configs (search_space_id)"
|
||||
)
|
||||
|
||||
# 3. Create image_generations table
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'image_generations')"
|
||||
)
|
||||
)
|
||||
if not result.scalar():
|
||||
op.create_table(
|
||||
"image_generations",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("prompt", sa.Text(), nullable=False),
|
||||
sa.Column("model", sa.String(200), nullable=True),
|
||||
sa.Column("n", sa.Integer(), nullable=True),
|
||||
sa.Column("quality", sa.String(50), nullable=True),
|
||||
sa.Column("size", sa.String(50), nullable=True),
|
||||
sa.Column("style", sa.String(50), nullable=True),
|
||||
sa.Column("response_format", sa.String(50), nullable=True),
|
||||
sa.Column("image_generation_config_id", sa.Integer(), nullable=True),
|
||||
sa.Column("response_data", JSONB(), nullable=True),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_by_id", UUID(as_uuid=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["created_by_id"], ["user.id"], ondelete="SET NULL"
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_image_generations_search_space_id "
|
||||
"ON image_generations (search_space_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_image_generations_created_by_id "
|
||||
"ON image_generations (created_by_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_image_generations_created_at "
|
||||
"ON image_generations (created_at)"
|
||||
)
|
||||
|
||||
# 4. Add image_generation_config_id column to searchspaces
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces'
|
||||
AND column_name = 'image_generation_config_id'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if not result.scalar():
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column(
|
||||
"image_generation_config_id",
|
||||
sa.Integer(),
|
||||
nullable=True,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
|
||||
# Drop old column name if it exists (from earlier version of this migration)
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces'
|
||||
AND column_name = 'image_generation_llm_id'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if result.scalar():
|
||||
op.drop_column("searchspaces", "image_generation_llm_id")
|
||||
|
||||
# Drop old column name on image_generations if it exists
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'image_generations'
|
||||
AND column_name = 'llm_config_id'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if result.scalar():
|
||||
op.drop_column("image_generations", "llm_config_id")
|
||||
|
||||
# Drop old api_version column on image_generations if it exists
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'image_generations'
|
||||
AND column_name = 'api_version'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if result.scalar():
|
||||
op.drop_column("image_generations", "api_version")
|
||||
|
||||
# 5. Add image generation permissions to existing system roles
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_cat(
|
||||
permissions,
|
||||
ARRAY['image_generations:create', 'image_generations:read']
|
||||
)
|
||||
WHERE is_system_role = true
|
||||
AND name = 'Editor'
|
||||
AND NOT ('image_generations:create' = ANY(permissions))
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_cat(
|
||||
permissions,
|
||||
ARRAY['image_generations:read']
|
||||
)
|
||||
WHERE is_system_role = true
|
||||
AND name = 'Viewer'
|
||||
AND NOT ('image_generations:read' = ANY(permissions))
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
connection = op.get_bind()
|
||||
|
||||
# Remove permissions
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_remove(
|
||||
array_remove(
|
||||
array_remove(permissions, 'image_generations:create'),
|
||||
'image_generations:read'
|
||||
),
|
||||
'image_generations:delete'
|
||||
)
|
||||
WHERE is_system_role = true
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Remove image_generation_config_id from searchspaces
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces'
|
||||
AND column_name = 'image_generation_config_id'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if result.scalar():
|
||||
op.drop_column("searchspaces", "image_generation_config_id")
|
||||
|
||||
# Drop indexes and tables
|
||||
op.execute("DROP INDEX IF EXISTS ix_image_generations_created_at")
|
||||
op.execute("DROP INDEX IF EXISTS ix_image_generations_created_by_id")
|
||||
op.execute("DROP INDEX IF EXISTS ix_image_generations_search_space_id")
|
||||
op.execute("DROP TABLE IF EXISTS image_generations")
|
||||
|
||||
op.execute("DROP INDEX IF EXISTS ix_image_generation_configs_search_space_id")
|
||||
op.execute("DROP INDEX IF EXISTS ix_image_generation_configs_name")
|
||||
op.execute("DROP TABLE IF EXISTS image_generation_configs")
|
||||
|
||||
# Drop the imagegenprovider enum type
|
||||
op.execute("DROP TYPE IF EXISTS imagegenprovider")
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Add access_token column to image_generations
|
||||
|
||||
Revision ID: 94
|
||||
Revises: 93
|
||||
|
||||
Adds an indexed access_token column to the image_generations table.
|
||||
This token is stored per-record so that image serving URLs survive
|
||||
SECRET_KEY rotation.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "94"
|
||||
down_revision: str | None = "93"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add access_token column (nullable so existing rows are unaffected)
|
||||
op.add_column(
|
||||
"image_generations",
|
||||
sa.Column("access_token", sa.String(64), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_image_generations_access_token",
|
||||
"image_generations",
|
||||
["access_token"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_image_generations_access_token", table_name="image_generations")
|
||||
op.drop_column("image_generations", "access_token")
|
||||
|
|
@ -133,6 +133,7 @@ async def create_surfsense_deep_agent(
|
|||
The agent comes with built-in tools that can be configured:
|
||||
- search_knowledge_base: Search the user's personal knowledge base
|
||||
- generate_podcast: Generate audio podcasts from content
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
- link_preview: Fetch rich previews for URLs
|
||||
- display_image: Display images in chat
|
||||
- scrape_webpage: Extract content from webpages
|
||||
|
|
|
|||
|
|
@ -3,15 +3,25 @@ PostgreSQL-based checkpointer for LangGraph agents.
|
|||
|
||||
This module provides a persistent checkpointer using AsyncPostgresSaver
|
||||
that stores conversation state in the PostgreSQL database.
|
||||
|
||||
Uses a connection pool (psycopg_pool.AsyncConnectionPool) to handle
|
||||
connection lifecycle, health checks, and automatic reconnection,
|
||||
preventing 'the connection is closed' errors in long-running deployments.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from psycopg.rows import dict_row
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global checkpointer instance (initialized lazily)
|
||||
_checkpointer: AsyncPostgresSaver | None = None
|
||||
_checkpointer_context = None # Store the context manager for cleanup
|
||||
_connection_pool: AsyncConnectionPool | None = None
|
||||
_checkpointer_initialized: bool = False
|
||||
|
||||
|
||||
|
|
@ -38,26 +48,65 @@ def get_postgres_connection_string() -> str:
|
|||
return db_url
|
||||
|
||||
|
||||
async def _create_checkpointer() -> AsyncPostgresSaver:
|
||||
"""
|
||||
Create a new AsyncPostgresSaver backed by a connection pool.
|
||||
|
||||
The connection pool automatically handles:
|
||||
- Connection health checks before use
|
||||
- Reconnection when connections die (idle timeout, DB restart, etc.)
|
||||
- Connection lifecycle management (max_lifetime, max_idle)
|
||||
"""
|
||||
global _connection_pool
|
||||
|
||||
conn_string = get_postgres_connection_string()
|
||||
|
||||
_connection_pool = AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
# Connections are recycled after 30 minutes to avoid stale connections
|
||||
max_lifetime=1800,
|
||||
# Idle connections are closed after 5 minutes
|
||||
max_idle=300,
|
||||
open=False,
|
||||
# Connection kwargs required by AsyncPostgresSaver:
|
||||
# - autocommit: required for .setup() to commit checkpoint tables
|
||||
# - prepare_threshold: disable prepared statements for compatibility
|
||||
# - row_factory: checkpointer accesses rows as dicts (row["column"])
|
||||
kwargs={
|
||||
"autocommit": True,
|
||||
"prepare_threshold": 0,
|
||||
"row_factory": dict_row,
|
||||
},
|
||||
)
|
||||
await _connection_pool.open(wait=True)
|
||||
|
||||
checkpointer = AsyncPostgresSaver(conn=_connection_pool)
|
||||
logger.info("[Checkpointer] Created AsyncPostgresSaver with connection pool")
|
||||
return checkpointer
|
||||
|
||||
|
||||
async def get_checkpointer() -> AsyncPostgresSaver:
|
||||
"""
|
||||
Get or create the global AsyncPostgresSaver instance.
|
||||
|
||||
This function:
|
||||
1. Creates the checkpointer if it doesn't exist
|
||||
1. Creates the checkpointer with a connection pool if it doesn't exist
|
||||
2. Sets up the required database tables on first call
|
||||
3. Returns the cached instance on subsequent calls
|
||||
|
||||
The underlying connection pool handles reconnection automatically,
|
||||
so a stale/closed connection will not cause OperationalError.
|
||||
|
||||
Returns:
|
||||
AsyncPostgresSaver: The configured checkpointer instance
|
||||
"""
|
||||
global _checkpointer, _checkpointer_context, _checkpointer_initialized
|
||||
global _checkpointer, _checkpointer_initialized
|
||||
|
||||
if _checkpointer is None:
|
||||
conn_string = get_postgres_connection_string()
|
||||
# from_conn_string returns an async context manager
|
||||
# We need to enter the context to get the actual checkpointer
|
||||
_checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string)
|
||||
_checkpointer = await _checkpointer_context.__aenter__()
|
||||
_checkpointer = await _create_checkpointer()
|
||||
_checkpointer_initialized = False
|
||||
|
||||
# Setup tables on first call (idempotent)
|
||||
if not _checkpointer_initialized:
|
||||
|
|
@ -75,20 +124,21 @@ async def setup_checkpointer_tables() -> None:
|
|||
tables exist before any agent calls.
|
||||
"""
|
||||
await get_checkpointer()
|
||||
print("[Checkpointer] PostgreSQL checkpoint tables ready")
|
||||
logger.info("[Checkpointer] PostgreSQL checkpoint tables ready")
|
||||
|
||||
|
||||
async def close_checkpointer() -> None:
|
||||
"""
|
||||
Close the checkpointer connection.
|
||||
Close the checkpointer connection pool.
|
||||
|
||||
This should be called during application shutdown.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_context, _checkpointer_initialized
|
||||
global _checkpointer, _connection_pool, _checkpointer_initialized
|
||||
|
||||
if _checkpointer_context is not None:
|
||||
await _checkpointer_context.__aexit__(None, None, None)
|
||||
_checkpointer = None
|
||||
_checkpointer_context = None
|
||||
_checkpointer_initialized = False
|
||||
print("[Checkpointer] PostgreSQL connection closed")
|
||||
if _connection_pool is not None:
|
||||
await _connection_pool.close()
|
||||
logger.info("[Checkpointer] PostgreSQL connection pool closed")
|
||||
|
||||
_checkpointer = None
|
||||
_connection_pool = None
|
||||
_checkpointer_initialized = False
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ You have access to the following tools:
|
|||
* Showing an image from a URL the user explicitly mentioned in their message
|
||||
* Displaying images found in scraped webpage content (from scrape_webpage tool)
|
||||
* Showing a publicly accessible diagram or chart from a known URL
|
||||
* Displaying an AI-generated image after calling the generate_image tool (ALWAYS required)
|
||||
|
||||
CRITICAL - NEVER USE THIS TOOL FOR USER-UPLOADED ATTACHMENTS:
|
||||
When a user uploads/attaches an image file to their message:
|
||||
|
|
@ -100,7 +101,21 @@ You have access to the following tools:
|
|||
- Returns: An image card with the image, title, and description
|
||||
- The image will automatically be displayed in the chat.
|
||||
|
||||
5. scrape_webpage: Scrape and extract the main content from a webpage.
|
||||
5. generate_image: Generate images from text descriptions using AI image models.
|
||||
- Use this when the user asks you to create, generate, draw, design, or make an image.
|
||||
- Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork"
|
||||
- Args:
|
||||
- prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood.
|
||||
- n: Number of images to generate (1-4, default: 1)
|
||||
- Returns: A dictionary with the generated image URL in the "src" field, along with metadata.
|
||||
- CRITICAL: After calling generate_image, you MUST call `display_image` with the returned "src" URL
|
||||
to actually show the image in the chat. The generate_image tool only generates the image and returns
|
||||
the URL — it does NOT display anything. You must always follow up with display_image.
|
||||
- IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim -
|
||||
expand and improve the prompt with specific details about style, lighting, composition, and mood.
|
||||
- If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details.
|
||||
|
||||
6. scrape_webpage: Scrape and extract the main content from a webpage.
|
||||
- Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage.
|
||||
- IMPORTANT: This is different from link_preview:
|
||||
* link_preview: Only fetches metadata (title, description, thumbnail) for display
|
||||
|
|
@ -123,7 +138,7 @@ You have access to the following tools:
|
|||
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
|
||||
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
||||
|
||||
6. save_memory: Save facts, preferences, or context about the user for personalized responses.
|
||||
7. save_memory: Save facts, preferences, or context about the user for personalized responses.
|
||||
- Use this when the user explicitly or implicitly shares information worth remembering.
|
||||
- Trigger scenarios:
|
||||
* User says "remember this", "keep this in mind", "note that", or similar
|
||||
|
|
@ -146,7 +161,7 @@ You have access to the following tools:
|
|||
- IMPORTANT: Only save information that would be genuinely useful for future conversations.
|
||||
Don't save trivial or temporary information.
|
||||
|
||||
7. recall_memory: Retrieve relevant memories about the user for personalized responses.
|
||||
8. recall_memory: Retrieve relevant memories about the user for personalized responses.
|
||||
- Use this to access stored information about the user.
|
||||
- Trigger scenarios:
|
||||
* You need user context to give a better, more personalized answer
|
||||
|
|
@ -281,6 +296,22 @@ You have access to the following tools:
|
|||
- Then, if the content contains useful diagrams/images like ``:
|
||||
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
|
||||
- Then provide your explanation, referencing the displayed image
|
||||
|
||||
- User: "Generate an image of a cat"
|
||||
- Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")`
|
||||
- Step 2: Use the returned "src" URL to display it: `display_image(src="<returned_url>", alt="A fluffy orange tabby cat on a windowsill", title="Generated Image")`
|
||||
|
||||
- User: "Create a landscape painting of mountains"
|
||||
- Step 1: `generate_image(prompt="Majestic snow-capped mountain range at sunset, dramatic orange and purple sky, alpine meadow with wildflowers in the foreground, oil painting style with visible brushstrokes, inspired by the Hudson River School art movement")`
|
||||
- Step 2: `display_image(src="<returned_url>", alt="Mountain landscape painting", title="Generated Image")`
|
||||
|
||||
- User: "Draw me a logo for a coffee shop called Bean Dream"
|
||||
- Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")`
|
||||
- Step 2: `display_image(src="<returned_url>", alt="Bean Dream coffee shop logo", title="Generated Image")`
|
||||
|
||||
- User: "Make a wide banner image for my blog about AI"
|
||||
- Step 1: `generate_image(prompt="Wide banner illustration for an AI technology blog, featuring abstract neural network patterns, glowing blue and purple connections, modern futuristic aesthetic, digital art style, clean and professional")`
|
||||
- Step 2: `display_image(src="<returned_url>", alt="AI blog banner", title="Generated Image")`
|
||||
</tool_call_examples>
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ Available tools:
|
|||
- search_knowledge_base: Search the user's personal knowledge base
|
||||
- search_surfsense_docs: Search Surfsense documentation for usage help
|
||||
- generate_podcast: Generate audio podcasts from content
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
- link_preview: Fetch rich previews for URLs
|
||||
- display_image: Display images in chat
|
||||
- scrape_webpage: Extract content from webpages
|
||||
|
|
@ -18,6 +19,7 @@ Available tools:
|
|||
# Registry exports
|
||||
# Tool factory exports (for direct use)
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .knowledge_base import (
|
||||
CONNECTOR_DESCRIPTIONS,
|
||||
create_search_knowledge_base_tool,
|
||||
|
|
@ -47,6 +49,7 @@ __all__ = [
|
|||
"build_tools",
|
||||
# Tool factories
|
||||
"create_display_image_tool",
|
||||
"create_generate_image_tool",
|
||||
"create_generate_podcast_tool",
|
||||
"create_link_preview_tool",
|
||||
"create_recall_memory_tool",
|
||||
|
|
|
|||
|
|
@ -82,14 +82,20 @@ def create_display_image_tool():
|
|||
|
||||
domain = extract_domain(src)
|
||||
|
||||
# Determine aspect ratio based on common image sources
|
||||
ratio = "16:9" # Default
|
||||
if "unsplash.com" in src or "pexels.com" in src:
|
||||
# Determine aspect ratio based on image source
|
||||
# AI-generated images should use "auto" to preserve their native ratio
|
||||
is_generated = "/image-generations/" in src
|
||||
if is_generated:
|
||||
ratio = "auto"
|
||||
domain = "ai-generated"
|
||||
elif "unsplash.com" in src or "pexels.com" in src:
|
||||
ratio = "16:9"
|
||||
elif (
|
||||
"imgur.com" in src or "github.com" in src or "githubusercontent.com" in src
|
||||
):
|
||||
ratio = "auto"
|
||||
else:
|
||||
ratio = "auto"
|
||||
|
||||
return {
|
||||
"id": image_id,
|
||||
|
|
|
|||
242
surfsense_backend/app/agents/new_chat/tools/generate_image.py
Normal file
242
surfsense_backend/app/agents/new_chat/tools/generate_image.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""
|
||||
Image generation tool for the SurfSense agent.
|
||||
|
||||
This module provides a tool that generates images using litellm.aimage_generation()
|
||||
and returns the result via the existing display_image tool format so the frontend
|
||||
renders the generated image inline in the chat.
|
||||
|
||||
Config resolution:
|
||||
1. Uses the search space's image_generation_config_id preference
|
||||
2. Falls back to Auto mode (router load balancing) if available
|
||||
3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from litellm import aimage_generation
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping (same as routes)
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image gen config by negative ID."""
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
|
||||
|
||||
def create_generate_image_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the generate_image tool.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID (for config resolution)
|
||||
db_session: Async database session
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def generate_image(
|
||||
prompt: str,
|
||||
n: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate an image from a text description using AI image models.
|
||||
|
||||
Use this tool when the user asks you to create, generate, draw, or make an image.
|
||||
The generated image will be displayed directly in the chat.
|
||||
|
||||
Args:
|
||||
prompt: A detailed text description of the image to generate.
|
||||
Be specific about subject, style, colors, composition, and mood.
|
||||
n: Number of images to generate (1-4). Default: 1
|
||||
|
||||
Returns:
|
||||
A dictionary containing the generated image(s) for display in the chat.
|
||||
"""
|
||||
try:
|
||||
# Resolve the image generation config from the search space preference
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return {"error": "Search space not found"}
|
||||
|
||||
config_id = (
|
||||
search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
|
||||
# Build generation kwargs
|
||||
# NOTE: size, quality, and style are intentionally NOT passed.
|
||||
# Different models support different values for these params
|
||||
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
|
||||
# gpt-image-1 wants "high"/"medium"/"low"; size options also
|
||||
# differ). Letting the model use its own defaults avoids errors.
|
||||
gen_kwargs: dict[str, Any] = {}
|
||||
if n is not None and n > 1:
|
||||
gen_kwargs["n"] = n
|
||||
|
||||
# Call litellm based on config type
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
return {
|
||||
"error": "No image generation models configured. "
|
||||
"Please add an image model in Settings > Image Models."
|
||||
}
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=prompt, model="auto", **gen_kwargs
|
||||
)
|
||||
elif config_id < 0:
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
return {"error": f"Image generation config {config_id} not found"}
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg["model_name"],
|
||||
cfg.get("custom_provider"),
|
||||
)
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = user-created ImageGenerationConfig
|
||||
cfg_result = await db_session.execute(
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id
|
||||
)
|
||||
)
|
||||
db_cfg = cfg_result.scalars().first()
|
||||
if not db_cfg:
|
||||
return {"error": f"Image generation config {config_id} not found"}
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value,
|
||||
db_cfg.model_name,
|
||||
db_cfg.custom_provider,
|
||||
)
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
|
||||
# Parse the response and store in DB
|
||||
response_dict = (
|
||||
response.model_dump()
|
||||
if hasattr(response, "model_dump")
|
||||
else dict(response)
|
||||
)
|
||||
|
||||
# Generate a random access token for this image
|
||||
access_token = generate_image_token()
|
||||
|
||||
# Save to image_generations table for history
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=prompt,
|
||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||
n=n,
|
||||
image_generation_config_id=config_id,
|
||||
response_data=response_dict,
|
||||
search_space_id=search_space_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
db_session.add(db_image_gen)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(db_image_gen)
|
||||
|
||||
# Extract image URLs from response
|
||||
images = response_dict.get("data", [])
|
||||
if not images:
|
||||
return {"error": "No images were generated"}
|
||||
|
||||
first_image = images[0]
|
||||
revised_prompt = first_image.get("revised_prompt", prompt)
|
||||
|
||||
# Resolve image URL:
|
||||
# - If the API returned a URL, use it directly.
|
||||
# - If the API returned b64_json (e.g. gpt-image-1), serve the
|
||||
# image through our backend endpoint to avoid bloating the
|
||||
# LLM context with megabytes of base64 data.
|
||||
if first_image.get("url"):
|
||||
image_url = first_image["url"]
|
||||
elif first_image.get("b64_json"):
|
||||
backend_url = config.BACKEND_URL or "http://localhost:8000"
|
||||
image_url = (
|
||||
f"{backend_url}/api/v1/image-generations/"
|
||||
f"{db_image_gen.id}/image?token={access_token}"
|
||||
)
|
||||
else:
|
||||
return {"error": "No displayable image data in the response"}
|
||||
|
||||
return {
|
||||
"src": image_url,
|
||||
"alt": revised_prompt or prompt,
|
||||
"title": "Generated Image",
|
||||
"description": revised_prompt if revised_prompt != prompt else None,
|
||||
"generated": True,
|
||||
"prompt": prompt,
|
||||
"image_count": len(images),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Image generation failed in tool")
|
||||
return {
|
||||
"error": f"Image generation failed: {e!s}",
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
return generate_image
|
||||
|
|
@ -17,6 +17,8 @@ from fake_useragent import UserAgent
|
|||
from langchain_core.tools import tool
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from app.utils.proxy_config import get_playwright_proxy, get_residential_proxy_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -186,9 +188,15 @@ async def fetch_with_chromium(url: str) -> dict[str, Any] | None:
|
|||
ua = UserAgent()
|
||||
user_agent = ua.random
|
||||
|
||||
# Use residential proxy if configured
|
||||
playwright_proxy = get_playwright_proxy()
|
||||
|
||||
# Use Playwright to fetch the page
|
||||
async with async_playwright() as p:
|
||||
browser = await p.chromium.launch(headless=True)
|
||||
launch_kwargs: dict = {"headless": True}
|
||||
if playwright_proxy:
|
||||
launch_kwargs["proxy"] = playwright_proxy
|
||||
browser = await p.chromium.launch(**launch_kwargs)
|
||||
context = await browser.new_context(user_agent=user_agent)
|
||||
page = await context.new_page()
|
||||
|
||||
|
|
@ -283,12 +291,16 @@ def create_link_preview_tool():
|
|||
ua = UserAgent()
|
||||
user_agent = ua.random
|
||||
|
||||
# Use residential proxy if configured
|
||||
proxy_url = get_residential_proxy_url()
|
||||
|
||||
# Use a browser-like User-Agent to fetch Open Graph metadata.
|
||||
# We're only fetching publicly available metadata (title, description, thumbnail)
|
||||
# that websites intentionally expose via OG tags for link preview purposes.
|
||||
async with httpx.AsyncClient(
|
||||
timeout=10.0,
|
||||
follow_redirects=True,
|
||||
proxy=proxy_url,
|
||||
headers={
|
||||
"User-Agent": user_agent,
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8",
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ from typing import Any
|
|||
from langchain_core.tools import BaseTool
|
||||
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
from .link_preview import create_link_preview_tool
|
||||
from .mcp_tool import load_mcp_tools
|
||||
|
|
@ -125,6 +126,16 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
factory=lambda deps: create_display_image_tool(),
|
||||
requires=[],
|
||||
),
|
||||
# Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.)
|
||||
ToolDefinition(
|
||||
name="generate_image",
|
||||
description="Generate images from text descriptions using AI image models",
|
||||
factory=lambda deps: create_generate_image_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
),
|
||||
requires=["search_space_id", "db_session"],
|
||||
),
|
||||
# Web scraping tool - extracts content from webpages
|
||||
ToolDefinition(
|
||||
name="scrape_webpage",
|
||||
|
|
|
|||
|
|
@ -2,17 +2,26 @@
|
|||
Web scraping tool for the SurfSense agent.
|
||||
|
||||
This module provides a tool for scraping and extracting content from webpages
|
||||
using the existing WebCrawlerConnector. The scraped content can be used by
|
||||
the agent to answer questions about web pages.
|
||||
using the existing WebCrawlerConnector. For YouTube URLs, it fetches the
|
||||
transcript directly via the YouTubeTranscriptApi instead of crawling the page.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from fake_useragent import UserAgent
|
||||
from langchain_core.tools import tool
|
||||
from requests import Session
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
|
||||
from app.connectors.webcrawler_connector import WebCrawlerConnector
|
||||
from app.tasks.document_processors.youtube_processor import get_youtube_video_id
|
||||
from app.utils.proxy_config import get_requests_proxies
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_domain(url: str) -> str:
|
||||
|
|
@ -57,6 +66,89 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
|
|||
return truncated + "\n\n[Content truncated...]", True
|
||||
|
||||
|
||||
async def _scrape_youtube_video(
|
||||
url: str, video_id: str, max_length: int
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Fetch YouTube video metadata and transcript via the YouTubeTranscriptApi.
|
||||
|
||||
Returns a result dict in the same shape as the regular scrape_webpage output.
|
||||
"""
|
||||
scrape_id = generate_scrape_id(url)
|
||||
domain = "youtube.com"
|
||||
|
||||
# --- Video metadata via oEmbed ---
|
||||
residential_proxies = get_requests_proxies()
|
||||
|
||||
params = {
|
||||
"format": "json",
|
||||
"url": f"https://www.youtube.com/watch?v={video_id}",
|
||||
}
|
||||
oembed_url = "https://www.youtube.com/oembed"
|
||||
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as http_session,
|
||||
http_session.get(
|
||||
oembed_url,
|
||||
params=params,
|
||||
proxy=residential_proxies["http"] if residential_proxies else None,
|
||||
) as response,
|
||||
):
|
||||
video_data = await response.json()
|
||||
except Exception:
|
||||
video_data = {}
|
||||
|
||||
title = video_data.get("title", "YouTube Video")
|
||||
author = video_data.get("author_name", "Unknown")
|
||||
|
||||
# --- Transcript via YouTubeTranscriptApi ---
|
||||
try:
|
||||
ua = UserAgent()
|
||||
http_client = Session()
|
||||
http_client.headers.update({"User-Agent": ua.random})
|
||||
if residential_proxies:
|
||||
http_client.proxies.update(residential_proxies)
|
||||
ytt_api = YouTubeTranscriptApi(http_client=http_client)
|
||||
captions = ytt_api.fetch(video_id)
|
||||
|
||||
transcript_segments = []
|
||||
for line in captions:
|
||||
start_time = line.start
|
||||
duration = line.duration
|
||||
text = line.text
|
||||
timestamp = f"[{start_time:.2f}s-{start_time + duration:.2f}s]"
|
||||
transcript_segments.append(f"{timestamp} {text}")
|
||||
transcript_text = "\n".join(transcript_segments)
|
||||
except Exception as e:
|
||||
logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
|
||||
transcript_text = f"No captions available for this video. Error: {e!s}"
|
||||
|
||||
# Build combined content
|
||||
content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
|
||||
|
||||
# Truncate if needed
|
||||
content, was_truncated = truncate_content(content, max_length)
|
||||
word_count = len(content.split())
|
||||
|
||||
description = f"YouTube video by {author}"
|
||||
|
||||
return {
|
||||
"id": scrape_id,
|
||||
"assetId": url,
|
||||
"kind": "article",
|
||||
"href": url,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"content": content,
|
||||
"domain": domain,
|
||||
"word_count": word_count,
|
||||
"was_truncated": was_truncated,
|
||||
"crawler_type": "youtube_transcript",
|
||||
"author": author,
|
||||
}
|
||||
|
||||
|
||||
def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
||||
"""
|
||||
Factory function to create the scrape_webpage tool.
|
||||
|
|
@ -79,7 +171,8 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
|||
|
||||
Use this tool when the user wants you to read, summarize, or answer
|
||||
questions about a specific webpage's content. This tool actually
|
||||
fetches and reads the full page content.
|
||||
fetches and reads the full page content. For YouTube video URLs it
|
||||
fetches the transcript directly instead of crawling the page.
|
||||
|
||||
Common triggers:
|
||||
- "Read this article and summarize it"
|
||||
|
|
@ -114,6 +207,11 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
|||
url = f"https://{url}"
|
||||
|
||||
try:
|
||||
# Check if this is a YouTube URL and use transcript API instead
|
||||
video_id = get_youtube_video_id(url)
|
||||
if video_id:
|
||||
return await _scrape_youtube_video(url, video_id, max_length)
|
||||
|
||||
# Create webcrawler connector
|
||||
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
|
||||
|
||||
|
|
@ -184,7 +282,7 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
|||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
print(f"[scrape_webpage] Error scraping {url}: {error_message}")
|
||||
logger.error(f"[scrape_webpage] Error scraping {url}: {error_message}")
|
||||
return {
|
||||
"id": scrape_id,
|
||||
"assetId": url,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from app.agents.new_chat.checkpointer import (
|
|||
close_checkpointer,
|
||||
setup_checkpointer_tables,
|
||||
)
|
||||
from app.config import config, initialize_llm_router
|
||||
from app.config import config, initialize_image_gen_router, initialize_llm_router
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.routes import router as crud_router
|
||||
from app.routes.auth_routes import router as auth_router
|
||||
|
|
@ -26,6 +26,8 @@ async def lifespan(app: FastAPI):
|
|||
await setup_checkpointer_tables()
|
||||
# Initialize LLM Router for Auto mode load balancing
|
||||
initialize_llm_router()
|
||||
# Initialize Image Generation Router for Auto mode load balancing
|
||||
initialize_image_gen_router()
|
||||
# Seed Surfsense documentation
|
||||
await seed_surfsense_docs()
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -13,14 +13,15 @@ load_dotenv()
|
|||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs):
|
||||
"""Initialize the LLM Router when a Celery worker process starts.
|
||||
"""Initialize the LLM Router and Image Gen Router when a Celery worker process starts.
|
||||
|
||||
This ensures the Auto mode (LiteLLM Router) is available for background tasks
|
||||
like document summarization.
|
||||
like document summarization and image generation.
|
||||
"""
|
||||
from app.config import initialize_llm_router
|
||||
from app.config import initialize_image_gen_router, initialize_llm_router
|
||||
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
|
||||
|
||||
# Get Celery configuration from environment
|
||||
|
|
|
|||
|
|
@ -81,6 +81,56 @@ def load_router_settings():
|
|||
return default_settings
|
||||
|
||||
|
||||
def load_global_image_gen_configs():
|
||||
"""
|
||||
Load global image generation configurations from YAML file.
|
||||
|
||||
Returns:
|
||||
list: List of global image generation config dictionaries, or empty list
|
||||
"""
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data.get("global_image_generation_configs", [])
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global image generation configs: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def load_image_gen_router_settings():
|
||||
"""
|
||||
Load router settings for image generation Auto mode from YAML file.
|
||||
|
||||
Returns:
|
||||
dict: Router settings dictionary
|
||||
"""
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
}
|
||||
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
return default_settings
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
settings = data.get("image_generation_router_settings", {})
|
||||
return {**default_settings, **settings}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load image generation router settings: {e}")
|
||||
return default_settings
|
||||
|
||||
|
||||
def initialize_llm_router():
|
||||
"""
|
||||
Initialize the LLM Router service for Auto mode.
|
||||
|
|
@ -105,6 +155,33 @@ def initialize_llm_router():
|
|||
print(f"Warning: Failed to initialize LLM Router: {e}")
|
||||
|
||||
|
||||
def initialize_image_gen_router():
|
||||
"""
|
||||
Initialize the Image Generation Router service for Auto mode.
|
||||
This should be called during application startup.
|
||||
"""
|
||||
image_gen_configs = load_global_image_gen_configs()
|
||||
router_settings = load_image_gen_router_settings()
|
||||
|
||||
if not image_gen_configs:
|
||||
print(
|
||||
"Info: No global image generation configs found, "
|
||||
"Image Generation Auto mode will not be available"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from app.services.image_gen_router_service import ImageGenRouterService
|
||||
|
||||
ImageGenRouterService.initialize(image_gen_configs, router_settings)
|
||||
print(
|
||||
f"Info: Image Generation Router initialized with {len(image_gen_configs)} models "
|
||||
f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize Image Generation Router: {e}")
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
|
|
@ -216,6 +293,12 @@ class Config:
|
|||
# Router settings for Auto mode (LiteLLM Router load balancing)
|
||||
ROUTER_SETTINGS = load_router_settings()
|
||||
|
||||
# Global Image Generation Configurations (optional)
|
||||
GLOBAL_IMAGE_GEN_CONFIGS = load_global_image_gen_configs()
|
||||
|
||||
# Router settings for Image Generation Auto mode
|
||||
IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings()
|
||||
|
||||
# Chonkie Configuration | Edit this to your needs
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||
# Azure OpenAI credentials from environment variables
|
||||
|
|
@ -277,6 +360,14 @@ class Config:
|
|||
# LlamaCloud API Key
|
||||
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
|
||||
# Residential Proxy Configuration (anonymous-proxies.net)
|
||||
# Used for web crawling and YouTube transcript fetching to avoid IP bans.
|
||||
RESIDENTIAL_PROXY_USERNAME = os.getenv("RESIDENTIAL_PROXY_USERNAME")
|
||||
RESIDENTIAL_PROXY_PASSWORD = os.getenv("RESIDENTIAL_PROXY_PASSWORD")
|
||||
RESIDENTIAL_PROXY_HOSTNAME = os.getenv("RESIDENTIAL_PROXY_HOSTNAME")
|
||||
RESIDENTIAL_PROXY_LOCATION = os.getenv("RESIDENTIAL_PROXY_LOCATION", "")
|
||||
RESIDENTIAL_PROXY_TYPE = int(os.getenv("RESIDENTIAL_PROXY_TYPE", "1"))
|
||||
|
||||
# Litellm TTS Configuration
|
||||
TTS_SERVICE = os.getenv("TTS_SERVICE")
|
||||
TTS_SERVICE_API_BASE = os.getenv("TTS_SERVICE_API_BASE")
|
||||
|
|
|
|||
|
|
@ -183,6 +183,69 @@ global_llm_configs:
|
|||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# =============================================================================
|
||||
# Image Generation Configuration
|
||||
# =============================================================================
|
||||
# These configurations power the image generation feature using litellm.aimage_generation().
|
||||
# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock,
|
||||
# Recraft, OpenRouter, Xinference, Nscale
|
||||
#
|
||||
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs.
|
||||
|
||||
# Router Settings for Image Generation Auto Mode
|
||||
image_generation_router_settings:
|
||||
routing_strategy: "usage-based-routing"
|
||||
num_retries: 3
|
||||
allowed_fails: 3
|
||||
cooldown_time: 60
|
||||
|
||||
global_image_generation_configs:
|
||||
# Example: OpenAI DALL-E 3
|
||||
- id: -1
|
||||
name: "Global DALL-E 3"
|
||||
description: "OpenAI's DALL-E 3 for high-quality image generation"
|
||||
provider: "OPENAI"
|
||||
model_name: "dall-e-3"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens)
|
||||
litellm_params: {}
|
||||
|
||||
# Example: OpenAI GPT Image 1
|
||||
- id: -2
|
||||
name: "Global GPT Image 1"
|
||||
description: "OpenAI's GPT Image 1 model"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-image-1"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50
|
||||
litellm_params: {}
|
||||
|
||||
# Example: Azure OpenAI DALL-E 3
|
||||
- id: -3
|
||||
name: "Global Azure DALL-E 3"
|
||||
description: "Azure-hosted DALL-E 3 deployment"
|
||||
provider: "AZURE_OPENAI"
|
||||
model_name: "azure/dall-e-3-deployment"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview"
|
||||
rpm: 50
|
||||
litellm_params:
|
||||
base_model: "dall-e-3"
|
||||
|
||||
# Example: OpenRouter Gemini Image Generation
|
||||
# - id: -4
|
||||
# name: "Global Gemini Image Gen"
|
||||
# description: "Google Gemini image generation via OpenRouter"
|
||||
# provider: "OPENROUTER"
|
||||
# model_name: "google/gemini-2.5-flash-image"
|
||||
# api_key: "your-openrouter-api-key-here"
|
||||
# api_base: ""
|
||||
# rpm: 30
|
||||
# litellm_params: {}
|
||||
|
||||
# Notes:
|
||||
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
|
||||
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
|
||||
|
|
@ -195,10 +258,11 @@ global_llm_configs:
|
|||
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
|
||||
# These help the router distribute load evenly and avoid rate limit errors
|
||||
#
|
||||
# AZURE-SPECIFIC NOTES:
|
||||
# - Always add 'base_model' in litellm_params for Azure deployments
|
||||
# - This fixes "Could not identify azure model 'X'" warnings
|
||||
# - base_model should match the underlying OpenAI model (e.g., gpt-4o, gpt-4-turbo, gpt-3.5-turbo)
|
||||
# - model_name format: "azure/<your-deployment-name>"
|
||||
# - api_version: Use a recent Azure API version (e.g., "2024-02-15-preview")
|
||||
# - See: https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
|
||||
#
|
||||
# IMAGE GENERATION NOTES:
|
||||
# - Image generation configs use the same ID scheme as LLM configs (negative for global)
|
||||
# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure),
|
||||
# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter)
|
||||
# - The router uses litellm.aimage_generation() for async image generation
|
||||
# - Only RPM (requests per minute) is relevant for image generation rate limiting.
|
||||
# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token.
|
||||
|
|
|
|||
|
|
@ -108,7 +108,9 @@ class AirtableHistoryConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Airtable access token is invalid or empty. "
|
||||
"Please reconnect your Airtable account."
|
||||
|
|
|
|||
|
|
@ -128,7 +128,9 @@ class ConfluenceHistoryConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Confluence access token is invalid or empty. "
|
||||
"Please reconnect your Confluence account."
|
||||
|
|
|
|||
|
|
@ -129,7 +129,9 @@ class JiraHistoryConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Jira access token is invalid or empty. "
|
||||
"Please reconnect your Jira account."
|
||||
|
|
|
|||
|
|
@ -153,7 +153,9 @@ class LinearConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Linear access token is invalid or empty. "
|
||||
"Please reconnect your Linear account."
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from fake_useragent import UserAgent
|
|||
from firecrawl import AsyncFirecrawlApp
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from app.utils.proxy_config import get_playwright_proxy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -165,9 +167,15 @@ class WebCrawlerConnector:
|
|||
ua = UserAgent()
|
||||
user_agent = ua.random
|
||||
|
||||
# Use residential proxy if configured
|
||||
playwright_proxy = get_playwright_proxy()
|
||||
|
||||
# Use Playwright to fetch the page
|
||||
async with async_playwright() as p:
|
||||
browser = await p.chromium.launch(headless=True)
|
||||
launch_kwargs: dict = {"headless": True}
|
||||
if playwright_proxy:
|
||||
launch_kwargs["proxy"] = playwright_proxy
|
||||
browser = await p.chromium.launch(**launch_kwargs)
|
||||
context = await browser.new_context(user_agent=user_agent)
|
||||
page = await context.new_page()
|
||||
|
||||
|
|
|
|||
|
|
@ -214,6 +214,24 @@ class LiteLLMProvider(str, Enum):
|
|||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class ImageGenProvider(str, Enum):
|
||||
"""
|
||||
Enum for image generation providers supported by LiteLLM.
|
||||
This is a subset of LLM providers — only those that support image generation.
|
||||
See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
"""
|
||||
|
||||
OPENAI = "OPENAI"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
GOOGLE = "GOOGLE" # Google AI Studio
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK" # AWS Bedrock
|
||||
RECRAFT = "RECRAFT"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
XINFERENCE = "XINFERENCE"
|
||||
NSCALE = "NSCALE"
|
||||
|
||||
|
||||
class LogLevel(str, Enum):
|
||||
DEBUG = "DEBUG"
|
||||
INFO = "INFO"
|
||||
|
|
@ -314,6 +332,11 @@ class Permission(str, Enum):
|
|||
PODCASTS_UPDATE = "podcasts:update"
|
||||
PODCASTS_DELETE = "podcasts:delete"
|
||||
|
||||
# Image Generations
|
||||
IMAGE_GENERATIONS_CREATE = "image_generations:create"
|
||||
IMAGE_GENERATIONS_READ = "image_generations:read"
|
||||
IMAGE_GENERATIONS_DELETE = "image_generations:delete"
|
||||
|
||||
# Connectors
|
||||
CONNECTORS_CREATE = "connectors:create"
|
||||
CONNECTORS_READ = "connectors:read"
|
||||
|
|
@ -375,6 +398,9 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.PODCASTS_CREATE.value,
|
||||
Permission.PODCASTS_READ.value,
|
||||
Permission.PODCASTS_UPDATE.value,
|
||||
# Image Generations (create and read, no delete)
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
# Connectors (no delete)
|
||||
Permission.CONNECTORS_CREATE.value,
|
||||
Permission.CONNECTORS_READ.value,
|
||||
|
|
@ -404,6 +430,8 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.LLM_CONFIGS_READ.value,
|
||||
# Podcasts (read only)
|
||||
Permission.PODCASTS_READ.value,
|
||||
# Image Generations (read only)
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
# Connectors (read only)
|
||||
Permission.CONNECTORS_READ.value,
|
||||
# Logs (read only)
|
||||
|
|
@ -969,6 +997,103 @@ class Podcast(BaseModel, TimestampMixin):
|
|||
thread = relationship("NewChatThread")
|
||||
|
||||
|
||||
class ImageGenerationConfig(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Dedicated configuration table for image generation models.
|
||||
|
||||
Separate from NewLLMConfig because image generation models don't need
|
||||
system_instructions, citations_enabled, or use_default_system_instructions.
|
||||
They only need provider credentials and model parameters.
|
||||
"""
|
||||
|
||||
__tablename__ = "image_generation_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# Provider & model (uses ImageGenProvider, NOT LiteLLMProvider)
|
||||
provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False)
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
model_name = Column(String(100), nullable=False)
|
||||
|
||||
# Credentials
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
api_version = Column(String(50), nullable=True) # Azure-specific
|
||||
|
||||
# Additional litellm parameters
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
# Relationships
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship(
|
||||
"SearchSpace", back_populates="image_generation_configs"
|
||||
)
|
||||
|
||||
|
||||
class ImageGeneration(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Stores image generation requests and results using litellm.aimage_generation().
|
||||
|
||||
Since aimage_generation is a single async call (not a background job),
|
||||
there is no status enum. A row with response_data means success;
|
||||
a row with error_message means failure.
|
||||
|
||||
Response data is stored as JSONB matching the litellm output format:
|
||||
{
|
||||
"created": int,
|
||||
"data": [{"b64_json": str|None, "revised_prompt": str|None, "url": str|None}],
|
||||
"usage": {"prompt_tokens": int, "completion_tokens": int, "total_tokens": int}
|
||||
}
|
||||
"""
|
||||
|
||||
__tablename__ = "image_generations"
|
||||
|
||||
# Request parameters (matching litellm.aimage_generation() params)
|
||||
prompt = Column(Text, nullable=False)
|
||||
model = Column(String(200), nullable=True) # e.g., "dall-e-3", "gpt-image-1"
|
||||
n = Column(Integer, nullable=True, default=1)
|
||||
quality = Column(
|
||||
String(50), nullable=True
|
||||
) # "auto", "high", "medium", "low", "hd", "standard"
|
||||
size = Column(
|
||||
String(50), nullable=True
|
||||
) # "1024x1024", "1536x1024", "1024x1536", etc.
|
||||
style = Column(String(50), nullable=True) # Model-specific style parameter
|
||||
response_format = Column(String(50), nullable=True) # "url" or "b64_json"
|
||||
|
||||
# Image generation config reference
|
||||
# 0 = Auto mode (router), negative IDs = global configs from YAML,
|
||||
# positive IDs = ImageGenerationConfig records in DB
|
||||
image_generation_config_id = Column(Integer, nullable=True)
|
||||
|
||||
# Response data (full litellm response as JSONB) — present on success
|
||||
response_data = Column(JSONB, nullable=True)
|
||||
# Error message — present on failure
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# Signed access token for serving images via <img> tags.
|
||||
# Stored in DB so it survives SECRET_KEY rotation.
|
||||
access_token = Column(String(64), nullable=True, index=True)
|
||||
|
||||
# Foreign keys
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
created_by_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
search_space = relationship("SearchSpace", back_populates="image_generations")
|
||||
created_by = relationship("User", back_populates="image_generations")
|
||||
|
||||
|
||||
class SearchSpace(BaseModel, TimestampMixin):
|
||||
__tablename__ = "searchspaces"
|
||||
|
||||
|
|
@ -993,6 +1118,9 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
document_summary_llm_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For document summarization, defaults to Auto mode
|
||||
image_generation_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For image generation, defaults to Auto mode
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
|
|
@ -1017,6 +1145,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="Podcast.id.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
image_generations = relationship(
|
||||
"ImageGeneration",
|
||||
back_populates="search_space",
|
||||
order_by="ImageGeneration.id.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
logs = relationship(
|
||||
"Log",
|
||||
back_populates="search_space",
|
||||
|
|
@ -1041,6 +1175,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="NewLLMConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="search_space",
|
||||
order_by="ImageGenerationConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# RBAC relationships
|
||||
roles = relationship(
|
||||
|
|
@ -1421,6 +1561,13 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generations created by this user
|
||||
image_generations = relationship(
|
||||
"ImageGeneration",
|
||||
back_populates="created_by",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# User memories for personalized AI responses
|
||||
memories = relationship(
|
||||
"UserMemory",
|
||||
|
|
@ -1493,6 +1640,13 @@ else:
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generations created by this user
|
||||
image_generations = relationship(
|
||||
"ImageGeneration",
|
||||
back_populates="created_by",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# User memories for personalized AI responses
|
||||
memories = relationship(
|
||||
"UserMemory",
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from .google_drive_add_connector_route import (
|
|||
from .google_gmail_add_connector_route import (
|
||||
router as google_gmail_add_connector_router,
|
||||
)
|
||||
from .image_generation_routes import router as image_generation_router
|
||||
from .incentive_tasks_routes import router as incentive_tasks_router
|
||||
from .jira_add_connector_route import router as jira_add_connector_router
|
||||
from .linear_add_connector_route import router as linear_add_connector_router
|
||||
|
|
@ -49,6 +50,7 @@ router.include_router(notes_router)
|
|||
router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
||||
router.include_router(chat_comments_router)
|
||||
router.include_router(podcasts_router) # Podcast task status and audio
|
||||
router.include_router(image_generation_router) # Image generation via litellm
|
||||
router.include_router(search_source_connectors_router)
|
||||
router.include_router(google_calendar_add_connector_router)
|
||||
router.include_router(google_gmail_add_connector_router)
|
||||
|
|
|
|||
710
surfsense_backend/app/routes/image_generation_routes.py
Normal file
710
surfsense_backend/app/routes/image_generation_routes.py
Normal file
|
|
@ -0,0 +1,710 @@
|
|||
"""
|
||||
Image Generation routes:
|
||||
- CRUD for ImageGenerationConfig (user-created image model configs)
|
||||
- Global image gen configs endpoint (from YAML)
|
||||
- Image generation execution (calls litellm.aimage_generation())
|
||||
- CRUD for ImageGeneration records (results)
|
||||
- Image serving endpoint (serves b64_json images from DB, protected by signed tokens)
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import Response
|
||||
from litellm import aimage_generation
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
ImageGenerationConfig,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
GlobalImageGenConfigRead,
|
||||
ImageGenerationConfigCreate,
|
||||
ImageGenerationConfigRead,
|
||||
ImageGenerationConfigUpdate,
|
||||
ImageGenerationCreate,
|
||||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping for building litellm model strings.
|
||||
# Only includes providers that support image generation.
|
||||
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image generation configuration by ID (negative IDs)."""
|
||||
if config_id == IMAGE_GEN_AUTO_MODE_ID:
|
||||
return {
|
||||
"id": IMAGE_GEN_AUTO_MODE_ID,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
if config_id > 0:
|
||||
return None
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
"""Build a litellm model string from provider + model_name."""
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
async def _execute_image_generation(
|
||||
session: AsyncSession,
|
||||
image_gen: ImageGeneration,
|
||||
search_space: SearchSpace,
|
||||
) -> None:
|
||||
"""
|
||||
Call litellm.aimage_generation() with the appropriate config.
|
||||
|
||||
Resolution order:
|
||||
1. Explicit image_generation_config_id on the request
|
||||
2. Search space's image_generation_config_id preference
|
||||
3. Falls back to Auto mode if available
|
||||
"""
|
||||
config_id = image_gen.image_generation_config_id
|
||||
if config_id is None:
|
||||
config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
image_gen.image_generation_config_id = config_id
|
||||
|
||||
# Build kwargs
|
||||
gen_kwargs = {}
|
||||
if image_gen.n is not None:
|
||||
gen_kwargs["n"] = image_gen.n
|
||||
if image_gen.quality is not None:
|
||||
gen_kwargs["quality"] = image_gen.quality
|
||||
if image_gen.size is not None:
|
||||
gen_kwargs["size"] = image_gen.size
|
||||
if image_gen.style is not None:
|
||||
gen_kwargs["style"] = image_gen.style
|
||||
if image_gen.response_format is not None:
|
||||
gen_kwargs["response_format"] = image_gen.response_format
|
||||
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
raise ValueError(
|
||||
"Auto mode requested but Image Generation Router not initialized. "
|
||||
"Ensure global_llm_config.yaml has global_image_generation_configs."
|
||||
)
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=image_gen.prompt, model="auto", **gen_kwargs
|
||||
)
|
||||
elif config_id < 0:
|
||||
# Global config from YAML
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
raise ValueError(f"Global image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
|
||||
)
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
|
||||
# User model override
|
||||
if image_gen.model:
|
||||
model_string = image_gen.model
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=image_gen.prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = DB ImageGenerationConfig
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_cfg = result.scalars().first()
|
||||
if not db_cfg:
|
||||
raise ValueError(f"Image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
|
||||
)
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
|
||||
# User model override
|
||||
if image_gen.model:
|
||||
model_string = image_gen.model
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=image_gen.prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
|
||||
# Store response
|
||||
image_gen.response_data = (
|
||||
response.model_dump() if hasattr(response, "model_dump") else dict(response)
|
||||
)
|
||||
if not image_gen.model and hasattr(response, "_hidden_params"):
|
||||
hidden = response._hidden_params
|
||||
if isinstance(hidden, dict) and hidden.get("model"):
|
||||
image_gen.model = hidden["model"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Image Generation Configs (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/global-image-generation-configs",
|
||||
response_model=list[GlobalImageGenConfigRead],
|
||||
)
|
||||
async def get_global_image_gen_configs(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get all global image generation configs. API keys are hidden."""
|
||||
try:
|
||||
global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||
safe_configs = []
|
||||
|
||||
if global_configs and len(global_configs) > 0:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"description": "Automatically routes across available image generation providers.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
}
|
||||
)
|
||||
|
||||
return safe_configs
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch global image generation configs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ImageGenerationConfig CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead)
|
||||
async def create_image_gen_config(
|
||||
config_data: ImageGenerationConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create a new image generation config for a search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to create image generation configs in this search space",
|
||||
)
|
||||
|
||||
db_config = ImageGenerationConfig(**config_data.model_dump())
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
|
||||
async def list_image_gen_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List image generation configs for a search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to view image generation configs in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig)
|
||||
.filter(ImageGenerationConfig.search_space_id == search_space_id)
|
||||
.order_by(ImageGenerationConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list ImageGenerationConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
|
||||
)
|
||||
async def get_image_gen_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific image generation config by ID."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to view image generation configs in this search space",
|
||||
)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put(
|
||||
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
|
||||
)
|
||||
async def update_image_gen_config(
|
||||
config_id: int,
|
||||
update_data: ImageGenerationConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update an existing image generation config."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to update image generation configs in this search space",
|
||||
)
|
||||
|
||||
for key, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(db_config, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/image-generation-configs/{config_id}", response_model=dict)
|
||||
async def delete_image_gen_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete an image generation config."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_DELETE.value,
|
||||
"You don't have permission to delete image generation configs in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_config)
|
||||
await session.commit()
|
||||
return {
|
||||
"message": "Image generation config deleted successfully",
|
||||
"id": config_id,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to delete ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image Generation Execution + Results CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/image-generations", response_model=ImageGenerationRead)
|
||||
async def create_image_generation(
|
||||
data: ImageGenerationCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create and execute an image generation request."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
data.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to create image generations in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == data.search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error during image generation"
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create image generation")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Image generation failed: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/image-generations", response_model=list[ImageGenerationListRead])
|
||||
async def list_image_generations(
|
||||
search_space_id: int | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List image generations."""
|
||||
if skip < 0 or limit < 1:
|
||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||
if limit > 100:
|
||||
limit = 100
|
||||
|
||||
try:
|
||||
if search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to read image generations in this search space",
|
||||
)
|
||||
result = await session.execute(
|
||||
select(ImageGeneration)
|
||||
.filter(ImageGeneration.search_space_id == search_space_id)
|
||||
.order_by(ImageGeneration.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
result = await session.execute(
|
||||
select(ImageGeneration)
|
||||
.join(SearchSpace)
|
||||
.join(SearchSpaceMembership)
|
||||
.filter(SearchSpaceMembership.user_id == user.id)
|
||||
.order_by(ImageGeneration.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
return [
|
||||
ImageGenerationListRead.from_orm_with_count(img)
|
||||
for img in result.scalars().all()
|
||||
]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error fetching image generations"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead)
|
||||
async def get_image_generation(
|
||||
image_gen_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific image generation by ID."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
|
||||
)
|
||||
image_gen = result.scalars().first()
|
||||
if not image_gen:
|
||||
raise HTTPException(status_code=404, detail="Image generation not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
image_gen.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to read image generations in this search space",
|
||||
)
|
||||
return image_gen
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error fetching image generation"
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/image-generations/{image_gen_id}", response_model=dict)
|
||||
async def delete_image_generation(
|
||||
image_gen_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete an image generation record."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
|
||||
)
|
||||
db_image_gen = result.scalars().first()
|
||||
if not db_image_gen:
|
||||
raise HTTPException(status_code=404, detail="Image generation not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_image_gen.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_DELETE.value,
|
||||
"You don't have permission to delete image generations in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_image_gen)
|
||||
await session.commit()
|
||||
return {"message": "Image generation deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error deleting image generation"
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image Serving (serves generated images from DB, protected by signed tokens)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/image-generations/{image_gen_id}/image")
|
||||
async def serve_generated_image(
|
||||
image_gen_id: int,
|
||||
token: str = Query(..., description="Signed access token"),
|
||||
index: int = 0,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Serve a generated image by ID, protected by a signed token.
|
||||
|
||||
The token is generated when the image URL is created by the generate_image
|
||||
tool and encodes the image_gen_id, search_space_id, and an expiry timestamp.
|
||||
This ensures only users with access to the search space can view images,
|
||||
without requiring auth headers (which <img> tags cannot pass).
|
||||
|
||||
Args:
|
||||
image_gen_id: The image generation record ID
|
||||
token: HMAC-signed access token (included as query parameter)
|
||||
index: Which image to serve if multiple were generated (default: 0)
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
|
||||
)
|
||||
image_gen = result.scalars().first()
|
||||
if not image_gen:
|
||||
raise HTTPException(status_code=404, detail="Image generation not found")
|
||||
|
||||
# Verify the access token against the one stored on the record
|
||||
if not verify_image_token(image_gen.access_token, token):
|
||||
raise HTTPException(status_code=403, detail="Invalid image access token")
|
||||
|
||||
if not image_gen.response_data:
|
||||
raise HTTPException(status_code=404, detail="No image data available")
|
||||
|
||||
images = image_gen.response_data.get("data", [])
|
||||
if not images or index >= len(images):
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Image not found at the specified index"
|
||||
)
|
||||
|
||||
image_entry = images[index]
|
||||
|
||||
# If there's a URL, redirect to it
|
||||
if image_entry.get("url"):
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
return RedirectResponse(url=image_entry["url"])
|
||||
|
||||
# If there's b64_json data, decode and serve it
|
||||
if image_entry.get("b64_json"):
|
||||
image_bytes = base64.b64decode(image_entry["b64_json"])
|
||||
return Response(
|
||||
content=image_bytes,
|
||||
media_type="image/png",
|
||||
headers={
|
||||
"Cache-Control": "public, max-age=86400",
|
||||
"Content-Disposition": f'inline; filename="generated-{image_gen_id}-{index}.png"',
|
||||
},
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=404, detail="No displayable image data")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to serve generated image")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to serve image: {e!s}"
|
||||
) from e
|
||||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGenerationConfig,
|
||||
NewLLMConfig,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
|
|
@ -387,6 +388,69 @@ async def _get_llm_config_by_id(
|
|||
return None
|
||||
|
||||
|
||||
async def _get_image_gen_config_by_id(
|
||||
session: AsyncSession, config_id: int | None
|
||||
) -> dict | None:
|
||||
"""
|
||||
Get an image generation config by ID as a dictionary.
|
||||
Returns Auto mode for ID 0, global config for negative IDs,
|
||||
DB ImageGenerationConfig for positive IDs, or None.
|
||||
"""
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
if config_id == 0:
|
||||
return {
|
||||
"id": 0,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"description": "Automatically routes requests across available image generation providers",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
}
|
||||
return None
|
||||
|
||||
# Positive ID: query ImageGenerationConfig table
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if db_config:
|
||||
return {
|
||||
"id": db_config.id,
|
||||
"name": db_config.name,
|
||||
"description": db_config.description,
|
||||
"provider": db_config.provider.value if db_config.provider else None,
|
||||
"custom_provider": db_config.custom_provider,
|
||||
"model_name": db_config.model_name,
|
||||
"api_base": db_config.api_base,
|
||||
"api_version": db_config.api_version,
|
||||
"litellm_params": db_config.litellm_params or {},
|
||||
"created_at": db_config.created_at.isoformat()
|
||||
if db_config.created_at
|
||||
else None,
|
||||
"search_space_id": db_config.search_space_id,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search-spaces/{search_space_id}/llm-preferences",
|
||||
response_model=LLMPreferencesRead,
|
||||
|
|
@ -423,12 +487,17 @@ async def get_llm_preferences(
|
|||
document_summary_llm = await _get_llm_config_by_id(
|
||||
session, search_space.document_summary_llm_id
|
||||
)
|
||||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
document_summary_llm_id=search_space.document_summary_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
agent_llm=agent_llm,
|
||||
document_summary_llm=document_summary_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -485,12 +554,17 @@ async def update_llm_preferences(
|
|||
document_summary_llm = await _get_llm_config_by_id(
|
||||
session, search_space.document_summary_llm_id
|
||||
)
|
||||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
document_summary_llm_id=search_space.document_summary_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
agent_llm=agent_llm,
|
||||
document_summary_llm=document_summary_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -21,6 +21,16 @@ from .documents import (
|
|||
PaginatedResponse,
|
||||
)
|
||||
from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest
|
||||
from .image_generation import (
|
||||
GlobalImageGenConfigRead,
|
||||
ImageGenerationConfigCreate,
|
||||
ImageGenerationConfigPublic,
|
||||
ImageGenerationConfigRead,
|
||||
ImageGenerationConfigUpdate,
|
||||
ImageGenerationCreate,
|
||||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||
from .new_chat import (
|
||||
ChatMessage,
|
||||
|
|
@ -105,11 +115,21 @@ __all__ = [
|
|||
"DriveItem",
|
||||
"ExtensionDocumentContent",
|
||||
"ExtensionDocumentMetadata",
|
||||
"GlobalImageGenConfigRead",
|
||||
"GlobalNewLLMConfigRead",
|
||||
"GoogleDriveIndexRequest",
|
||||
"GoogleDriveIndexingOptions",
|
||||
# Base schemas
|
||||
"IDModel",
|
||||
# Image Generation Config schemas
|
||||
"ImageGenerationConfigCreate",
|
||||
"ImageGenerationConfigPublic",
|
||||
"ImageGenerationConfigRead",
|
||||
"ImageGenerationConfigUpdate",
|
||||
# Image Generation schemas
|
||||
"ImageGenerationCreate",
|
||||
"ImageGenerationListRead",
|
||||
"ImageGenerationRead",
|
||||
# RBAC schemas
|
||||
"InviteAcceptRequest",
|
||||
"InviteAcceptResponse",
|
||||
|
|
|
|||
230
surfsense_backend/app/schemas/image_generation.py
Normal file
230
surfsense_backend/app/schemas/image_generation.py
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
"""
|
||||
Pydantic schemas for Image Generation configs and generation requests.
|
||||
|
||||
ImageGenerationConfig: CRUD schemas for user-created image gen model configs.
|
||||
ImageGeneration: Schemas for the actual image generation requests/results.
|
||||
GlobalImageGenConfigRead: Schema for admin-configured YAML configs.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import ImageGenProvider
|
||||
|
||||
# =============================================================================
|
||||
# ImageGenerationConfig CRUD Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ImageGenerationConfigBase(BaseModel):
|
||||
"""Base schema with fields for ImageGenerationConfig."""
|
||||
|
||||
name: str = Field(
|
||||
..., max_length=100, description="User-friendly name for the config"
|
||||
)
|
||||
description: str | None = Field(
|
||||
None, max_length=500, description="Optional description"
|
||||
)
|
||||
provider: ImageGenProvider = Field(
|
||||
...,
|
||||
description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)",
|
||||
)
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name"
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)"
|
||||
)
|
||||
api_key: str = Field(..., description="API key for the provider")
|
||||
api_base: str | None = Field(
|
||||
None, max_length=500, description="Optional API base URL"
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
None,
|
||||
max_length=50,
|
||||
description="Azure-specific API version (e.g., '2024-02-15-preview')",
|
||||
)
|
||||
litellm_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional LiteLLM parameters"
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationConfigCreate(ImageGenerationConfigBase):
|
||||
"""Schema for creating a new ImageGenerationConfig."""
|
||||
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the config with"
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationConfigUpdate(BaseModel):
|
||||
"""Schema for updating an existing ImageGenerationConfig. All fields optional."""
|
||||
|
||||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
provider: ImageGenProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
api_key: str | None = None
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ImageGenerationConfigRead(ImageGenerationConfigBase):
|
||||
"""Schema for reading an ImageGenerationConfig (includes id and timestamps)."""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ImageGenerationConfigPublic(BaseModel):
|
||||
"""Public schema that hides the API key (for list views)."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: ImageGenProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ImageGeneration (request/result) Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ImageGenerationCreate(BaseModel):
|
||||
"""Schema for creating an image generation request."""
|
||||
|
||||
prompt: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=4000,
|
||||
description="A text description of the desired image(s)",
|
||||
)
|
||||
model: str | None = Field(
|
||||
None,
|
||||
max_length=200,
|
||||
description="The model to use (e.g., 'dall-e-3', 'gpt-image-1'). Overrides the config model.",
|
||||
)
|
||||
n: int | None = Field(
|
||||
None,
|
||||
ge=1,
|
||||
le=10,
|
||||
description="Number of images to generate (1-10).",
|
||||
)
|
||||
quality: str | None = Field(None, max_length=50)
|
||||
size: str | None = Field(None, max_length=50)
|
||||
style: str | None = Field(None, max_length=50)
|
||||
response_format: str | None = Field(None, max_length=50)
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the generation with"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Image generation config ID. "
|
||||
"0 = Auto mode (router), negative = global YAML config, positive = DB config. "
|
||||
"If not provided, uses the search space's image_generation_config_id preference."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationRead(BaseModel):
|
||||
"""Schema for reading an image generation record."""
|
||||
|
||||
id: int
|
||||
prompt: str
|
||||
model: str | None = None
|
||||
n: int | None = None
|
||||
quality: str | None = None
|
||||
size: str | None = None
|
||||
style: str | None = None
|
||||
response_format: str | None = None
|
||||
image_generation_config_id: int | None = None
|
||||
response_data: dict[str, Any] | None = None
|
||||
error_message: str | None = None
|
||||
search_space_id: int
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ImageGenerationListRead(BaseModel):
|
||||
"""Lightweight schema for listing image generations (without full response_data)."""
|
||||
|
||||
id: int
|
||||
prompt: str
|
||||
model: str | None = None
|
||||
n: int | None = None
|
||||
quality: str | None = None
|
||||
size: str | None = None
|
||||
search_space_id: int
|
||||
created_at: datetime
|
||||
is_success: bool
|
||||
image_count: int | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@classmethod
|
||||
def from_orm_with_count(cls, obj: Any) -> "ImageGenerationListRead":
|
||||
"""Create ImageGenerationListRead with computed fields."""
|
||||
image_count = None
|
||||
if obj.response_data and isinstance(obj.response_data, dict):
|
||||
data = obj.response_data.get("data")
|
||||
if isinstance(data, list):
|
||||
image_count = len(data)
|
||||
|
||||
return cls(
|
||||
id=obj.id,
|
||||
prompt=obj.prompt,
|
||||
model=obj.model,
|
||||
n=obj.n,
|
||||
quality=obj.quality,
|
||||
size=obj.size,
|
||||
search_space_id=obj.search_space_id,
|
||||
created_at=obj.created_at,
|
||||
is_success=obj.response_data is not None,
|
||||
image_count=image_count,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Image Gen Config (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GlobalImageGenConfigRead(BaseModel):
|
||||
"""
|
||||
Schema for reading global image generation configs from YAML.
|
||||
Global configs have negative IDs. API key is hidden.
|
||||
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
|
||||
"""
|
||||
|
||||
id: int = Field(
|
||||
...,
|
||||
description="Config ID: 0 for Auto mode, negative for global configs",
|
||||
)
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: str
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
|
||||
|
|
@ -176,12 +176,18 @@ class LLMPreferencesRead(BaseModel):
|
|||
document_summary_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for document summarization"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
agent_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for agent LLM"
|
||||
)
|
||||
document_summary_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for document summary LLM"
|
||||
)
|
||||
image_generation_config: dict[str, Any] | None = Field(
|
||||
None, description="Full config for image generation"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -195,3 +201,6 @@ class LLMPreferencesUpdate(BaseModel):
|
|||
document_summary_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for document summarization"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
|
|
|
|||
278
surfsense_backend/app/services/image_gen_router_service.py
Normal file
278
surfsense_backend/app/services/image_gen_router_service.py
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
"""
|
||||
Image Generation Router Service for Load Balancing
|
||||
|
||||
This module provides a singleton LiteLLM Router for automatic load balancing
|
||||
across multiple image generation deployments. It uses litellm.Router which
|
||||
natively supports aimage_generation() for async image generation.
|
||||
|
||||
The router handles:
|
||||
- Rate limit management with automatic cooldowns
|
||||
- Automatic failover and retries
|
||||
- Usage-based routing to distribute load evenly
|
||||
|
||||
Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI,
|
||||
AWS Bedrock, Recraft, OpenRouter, Xinference, Nscale.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from litellm import Router
|
||||
from litellm.utils import ImageResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
IMAGE_GEN_AUTO_MODE_ID = 0
|
||||
|
||||
# Provider mapping for LiteLLM model string construction.
|
||||
# Only includes providers that support image generation.
|
||||
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
IMAGE_GEN_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
class ImageGenRouterService:
|
||||
"""
|
||||
Singleton service for managing LiteLLM Router for image generation.
|
||||
|
||||
The router provides automatic load balancing, failover, and rate limit
|
||||
handling across multiple image generation deployments.
|
||||
Uses Router.aimage_generation() for async image generation calls.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_router: Router | None = None
|
||||
_model_list: list[dict] = []
|
||||
_router_settings: dict = {}
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ImageGenRouterService":
|
||||
"""Get the singleton instance of the router service."""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def initialize(
|
||||
cls,
|
||||
global_configs: list[dict],
|
||||
router_settings: dict | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the router with global image generation configurations.
|
||||
|
||||
Args:
|
||||
global_configs: List of global image gen config dictionaries from YAML
|
||||
router_settings: Optional router settings (routing_strategy, num_retries, etc.)
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
|
||||
if instance._initialized:
|
||||
logger.debug("Image Generation Router already initialized, skipping")
|
||||
return
|
||||
|
||||
# Build model list from global configs
|
||||
model_list = []
|
||||
for config in global_configs:
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
|
||||
if not model_list:
|
||||
logger.warning(
|
||||
"No valid image generation configs found for router initialization"
|
||||
)
|
||||
return
|
||||
|
||||
instance._model_list = model_list
|
||||
instance._router_settings = router_settings or {}
|
||||
|
||||
# Default router settings optimized for rate limit handling
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
"retry_after": 5,
|
||||
}
|
||||
|
||||
# Merge with provided settings
|
||||
final_settings = {**default_settings, **instance._router_settings}
|
||||
|
||||
try:
|
||||
instance._router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy=final_settings.get(
|
||||
"routing_strategy", "usage-based-routing"
|
||||
),
|
||||
num_retries=final_settings.get("num_retries", 3),
|
||||
allowed_fails=final_settings.get("allowed_fails", 3),
|
||||
cooldown_time=final_settings.get("cooldown_time", 60),
|
||||
set_verbose=False,
|
||||
)
|
||||
instance._initialized = True
|
||||
logger.info(
|
||||
f"Image Generation Router initialized with {len(model_list)} deployments, "
|
||||
f"strategy: {final_settings.get('routing_strategy')}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Image Generation Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
|
||||
def _config_to_deployment(cls, config: dict) -> dict | None:
|
||||
"""
|
||||
Convert a global image gen config to a router deployment entry.
|
||||
|
||||
Args:
|
||||
config: Global image gen config dictionary
|
||||
|
||||
Returns:
|
||||
Router deployment dictionary or None if invalid
|
||||
"""
|
||||
try:
|
||||
# Skip if essential fields are missing
|
||||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
# Build model string
|
||||
if config.get("custom_provider"):
|
||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||
else:
|
||||
provider = config.get("provider", "").upper()
|
||||
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
# Build litellm params
|
||||
litellm_params: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
# Add optional api_base
|
||||
if config.get("api_base"):
|
||||
litellm_params["api_base"] = config["api_base"]
|
||||
|
||||
# Add api_version (required for Azure)
|
||||
if config.get("api_version"):
|
||||
litellm_params["api_version"] = config["api_version"]
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if config.get("litellm_params"):
|
||||
litellm_params.update(config["litellm_params"])
|
||||
|
||||
# All configs use same alias "auto" for unified routing
|
||||
deployment: dict[str, Any] = {
|
||||
"model_name": "auto",
|
||||
"litellm_params": litellm_params,
|
||||
}
|
||||
|
||||
# Add RPM rate limit from config if available
|
||||
# Note: TPM (tokens per minute) is not applicable for image generation
|
||||
# since image APIs are rate-limited by requests, not tokens.
|
||||
if config.get("rpm"):
|
||||
deployment["rpm"] = config["rpm"]
|
||||
|
||||
return deployment
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert image gen config to deployment: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_router(cls) -> Router | None:
|
||||
"""Get the initialized router instance."""
|
||||
instance = cls.get_instance()
|
||||
return instance._router
|
||||
|
||||
@classmethod
|
||||
def is_initialized(cls) -> bool:
|
||||
"""Check if the router has been initialized."""
|
||||
instance = cls.get_instance()
|
||||
return instance._initialized and instance._router is not None
|
||||
|
||||
@classmethod
|
||||
def get_model_count(cls) -> int:
|
||||
"""Get the number of models in the router."""
|
||||
instance = cls.get_instance()
|
||||
return len(instance._model_list)
|
||||
|
||||
@classmethod
|
||||
async def aimage_generation(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: str = "auto",
|
||||
n: int | None = None,
|
||||
timeout: int = 600,
|
||||
**kwargs,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Generate images using the router for load balancing.
|
||||
|
||||
Uses Router.aimage_generation() which distributes requests
|
||||
across configured image generation deployments.
|
||||
|
||||
Parameters like size, quality, style, and response_format are intentionally
|
||||
omitted to keep the interface model-agnostic. Providers use their own
|
||||
sensible defaults. If needed, pass them via **kwargs.
|
||||
|
||||
Args:
|
||||
prompt: Text description of the desired image(s)
|
||||
model: Model alias (default "auto" for router routing)
|
||||
n: Number of images to generate
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional provider-specific params (size, quality, etc.)
|
||||
|
||||
Returns:
|
||||
ImageResponse from litellm
|
||||
|
||||
Raises:
|
||||
ValueError: If router is not initialized
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
if not instance._router:
|
||||
raise ValueError(
|
||||
"Image Generation Router not initialized. "
|
||||
"Ensure global_llm_config.yaml has global_image_generation_configs."
|
||||
)
|
||||
|
||||
# Build kwargs for aimage_generation
|
||||
gen_kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"timeout": timeout,
|
||||
}
|
||||
if n is not None:
|
||||
gen_kwargs["n"] = n
|
||||
gen_kwargs.update(kwargs)
|
||||
|
||||
return await instance._router.aimage_generation(**gen_kwargs)
|
||||
|
||||
|
||||
def is_image_gen_auto_mode(config_id: int | None) -> bool:
|
||||
"""
|
||||
Check if the given config ID represents Image Generation Auto mode.
|
||||
|
||||
Args:
|
||||
config_id: The config ID to check
|
||||
|
||||
Returns:
|
||||
True if this is Auto mode, False otherwise
|
||||
"""
|
||||
return config_id == IMAGE_GEN_AUTO_MODE_ID
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
"""Celery tasks for connector indexing."""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
|
@ -11,6 +12,36 @@ from app.config import config
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> None:
|
||||
"""
|
||||
Handle greenlet_spawn errors with detailed logging for debugging.
|
||||
|
||||
The 'greenlet_spawn has not been called' error occurs when:
|
||||
1. SQLAlchemy lazy-loads a relationship outside of an async context
|
||||
2. A sync operation is called from an async context (or vice versa)
|
||||
3. Session objects are accessed after the session is closed
|
||||
|
||||
This helper logs detailed context to help identify the root cause.
|
||||
"""
|
||||
error_str = str(e)
|
||||
if "greenlet_spawn has not been called" in error_str:
|
||||
logger.error(
|
||||
f"GREENLET ERROR in {task_name} for connector {connector_id}: {error_str}\n"
|
||||
f"This error typically occurs when SQLAlchemy tries to lazy-load a relationship "
|
||||
f"outside of an async context. Check for:\n"
|
||||
f"1. Accessing relationship attributes (e.g., document.chunks, connector.search_space) "
|
||||
f"without using selectinload() or joinedload()\n"
|
||||
f"2. Accessing model attributes after the session is closed\n"
|
||||
f"3. Passing ORM objects between different async contexts\n"
|
||||
f"Stack trace:\n{traceback.format_exc()}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Error in {task_name} for connector {connector_id}: {error_str}\n"
|
||||
f"Stack trace:\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""
|
||||
Create a new async session maker for Celery tasks.
|
||||
|
|
@ -46,6 +77,9 @@ def index_slack_messages_task(
|
|||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_slack_messages", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
|
@ -89,6 +123,9 @@ def index_notion_pages_task(
|
|||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_notion_pages", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
|
@ -347,6 +384,9 @@ def index_google_calendar_events_task(
|
|||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
|
@ -696,6 +736,9 @@ def index_crawled_urls_task(
|
|||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -27,12 +27,12 @@ from app.agents.new_chat.llm_config import (
|
|||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.db import Document, SurfsenseDocsDocument
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.schemas.new_chat import ChatAttachment
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
|
|
@ -1211,9 +1211,10 @@ async def stream_new_chat(
|
|||
|
||||
# Generate LLM title for new chats after first response
|
||||
# Check if this is the first assistant response by counting existing assistant messages
|
||||
from app.db import NewChatMessage, NewChatThread
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.db import NewChatMessage, NewChatThread
|
||||
|
||||
assistant_count_result = await session.execute(
|
||||
select(func.count(NewChatMessage.id)).filter(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
|
|
@ -1231,10 +1232,12 @@ async def stream_new_chat(
|
|||
# Truncate inputs to avoid context length issues
|
||||
truncated_query = user_query[:500]
|
||||
truncated_response = accumulated_text[:1000]
|
||||
title_result = await title_chain.ainvoke({
|
||||
"user_query": truncated_query,
|
||||
"assistant_response": truncated_response,
|
||||
})
|
||||
title_result = await title_chain.ainvoke(
|
||||
{
|
||||
"user_query": truncated_query,
|
||||
"assistant_response": truncated_response,
|
||||
}
|
||||
)
|
||||
|
||||
# Extract and clean the title
|
||||
if title_result and hasattr(title_result, "content"):
|
||||
|
|
@ -1242,7 +1245,7 @@ async def stream_new_chat(
|
|||
# Validate the title (reasonable length)
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
# Remove any quotes or extra formatting
|
||||
generated_title = raw_title.strip('"\'')
|
||||
generated_title = raw_title.strip("\"'")
|
||||
except Exception:
|
||||
generated_title = None
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,34 @@ def safe_set_chunks(document: Document, chunks: list) -> None:
|
|||
set_committed_value(document, "chunks", chunks)
|
||||
|
||||
|
||||
def parse_date_flexible(date_str: str) -> datetime:
|
||||
"""
|
||||
Parse date from multiple common formats.
|
||||
|
||||
Args:
|
||||
date_str: Date string to parse
|
||||
|
||||
Returns:
|
||||
Parsed datetime object
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to parse the date string
|
||||
"""
|
||||
formats = ["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]
|
||||
|
||||
for fmt in formats:
|
||||
try:
|
||||
return datetime.strptime(date_str.rstrip("Z"), fmt)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Try ISO format as fallback
|
||||
try:
|
||||
return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||
except ValueError as err:
|
||||
raise ValueError(f"Unable to parse date: {date_str}") from err
|
||||
|
||||
|
||||
async def check_duplicate_document_by_hash(
|
||||
session: AsyncSession, content_hash: str
|
||||
) -> Document | None:
|
||||
|
|
@ -188,6 +216,26 @@ def calculate_date_range(
|
|||
)
|
||||
end_date_str = end_date if end_date else calculated_end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# FIX: Ensure end_date is at least 1 day after start_date to avoid
|
||||
# "start_date must be strictly before end_date" errors when dates are the same
|
||||
# (e.g., when last_indexed_at is today)
|
||||
if start_date_str == end_date_str:
|
||||
logger.info(
|
||||
f"Start date ({start_date_str}) equals end date ({end_date_str}), "
|
||||
"adjusting end date to next day to ensure valid date range"
|
||||
)
|
||||
# Parse end_date and add 1 day
|
||||
try:
|
||||
end_dt = parse_date_flexible(end_date_str)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Could not parse end_date '{end_date_str}', using current date"
|
||||
)
|
||||
end_dt = datetime.now()
|
||||
end_dt = end_dt + timedelta(days=1)
|
||||
end_date_str = end_dt.strftime("%Y-%m-%d")
|
||||
logger.info(f"Adjusted end date to {end_date_str}")
|
||||
|
||||
return start_date_str, end_date_str
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from .base import (
|
|||
get_connector_by_id,
|
||||
get_current_timestamp,
|
||||
logger,
|
||||
parse_date_flexible,
|
||||
safe_set_chunks,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
|
@ -222,6 +223,26 @@ async def index_google_calendar_events(
|
|||
start_date_str = start_date
|
||||
end_date_str = end_date
|
||||
|
||||
# FIX: Ensure end_date is at least 1 day after start_date to avoid
|
||||
# "start_date must be strictly before end_date" errors when dates are the same
|
||||
# (e.g., when last_indexed_at is today)
|
||||
if start_date_str == end_date_str:
|
||||
logger.info(
|
||||
f"Start date ({start_date_str}) equals end date ({end_date_str}), "
|
||||
"adjusting end date to next day to ensure valid date range"
|
||||
)
|
||||
# Parse end_date and add 1 day
|
||||
try:
|
||||
end_dt = parse_date_flexible(end_date_str)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Could not parse end_date '{end_date_str}', using current date"
|
||||
)
|
||||
end_dt = datetime.now()
|
||||
end_dt = end_dt + timedelta(days=1)
|
||||
end_date_str = end_dt.strftime("%Y-%m-%d")
|
||||
logger.info(f"Adjusted end date to {end_date_str}")
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Fetching Google Calendar events from {start_date_str} to {end_date_str}",
|
||||
|
|
|
|||
|
|
@ -202,13 +202,44 @@ async def index_notion_pages(
|
|||
"Recommend reconnecting with OAuth."
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to get Notion pages for connector {connector_id}",
|
||||
str(e),
|
||||
{"error_type": "PageFetchError"},
|
||||
error_str = str(e)
|
||||
# Check if this is an unsupported block type error (transcription, ai_block, etc.)
|
||||
# These are known Notion API limitations and should be logged as warnings, not errors
|
||||
unsupported_block_errors = [
|
||||
"transcription is not supported",
|
||||
"ai_block is not supported",
|
||||
"is not supported via the API",
|
||||
]
|
||||
is_unsupported_block_error = any(
|
||||
err in error_str.lower() for err in unsupported_block_errors
|
||||
)
|
||||
logger.error(f"Error fetching Notion pages: {e!s}", exc_info=True)
|
||||
|
||||
if is_unsupported_block_error:
|
||||
# Log as warning since this is a known Notion API limitation
|
||||
logger.warning(
|
||||
f"Notion API limitation for connector {connector_id}: {error_str}. "
|
||||
"This is a known issue with Notion AI blocks (transcription, ai_block) "
|
||||
"that are not accessible via the Notion API."
|
||||
)
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
"Failed to get Notion pages: Notion API limitation",
|
||||
f"{error_str} - This page contains Notion AI content (transcription/ai_block) that cannot be accessed via the API.",
|
||||
{"error_type": "UnsupportedBlockType", "is_known_limitation": True},
|
||||
)
|
||||
else:
|
||||
# Log as error for other failures
|
||||
logger.error(
|
||||
f"Error fetching Notion pages for connector {connector_id}: {error_str}",
|
||||
exc_info=True,
|
||||
)
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to get Notion pages for connector {connector_id}",
|
||||
str(e),
|
||||
{"error_type": "PageFetchError"},
|
||||
)
|
||||
|
||||
await notion_client.close()
|
||||
return 0, f"Failed to get Notion pages: {e!s}"
|
||||
|
||||
|
|
|
|||
|
|
@ -117,10 +117,15 @@ async def index_crawled_urls(
|
|||
api_key = connector.config.get("FIRECRAWL_API_KEY")
|
||||
|
||||
# Get URLs from connector config
|
||||
urls = parse_webcrawler_urls(connector.config.get("INITIAL_URLS"))
|
||||
raw_initial_urls = connector.config.get("INITIAL_URLS")
|
||||
urls = parse_webcrawler_urls(raw_initial_urls)
|
||||
|
||||
# DEBUG: Log connector config details for troubleshooting empty URL issues
|
||||
logger.info(
|
||||
f"Starting crawled web page indexing for connector {connector_id} with {len(urls)} URLs"
|
||||
f"Starting crawled web page indexing for connector {connector_id} with {len(urls)} URLs. "
|
||||
f"Connector name: {connector.name}, "
|
||||
f"INITIAL_URLS type: {type(raw_initial_urls).__name__}, "
|
||||
f"INITIAL_URLS value: {repr(raw_initial_urls)[:200] if raw_initial_urls else 'None'}"
|
||||
)
|
||||
|
||||
# Initialize webcrawler client
|
||||
|
|
@ -137,11 +142,18 @@ async def index_crawled_urls(
|
|||
|
||||
# Validate URLs
|
||||
if not urls:
|
||||
# DEBUG: Log detailed connector config for troubleshooting
|
||||
logger.error(
|
||||
f"No URLs provided for indexing. Connector ID: {connector_id}, "
|
||||
f"Connector name: {connector.name}, "
|
||||
f"Config keys: {list(connector.config.keys()) if connector.config else 'None'}, "
|
||||
f"INITIAL_URLS raw value: {raw_initial_urls!r}"
|
||||
)
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
"No URLs provided for indexing",
|
||||
"Empty URL list",
|
||||
{"error_type": "ValidationError"},
|
||||
f"Empty URL list. INITIAL_URLS value: {repr(raw_initial_urls)[:100]}",
|
||||
{"error_type": "ValidationError", "connector_name": connector.name},
|
||||
)
|
||||
return 0, "No URLs provided for indexing"
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import logging
|
|||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import aiohttp
|
||||
from fake_useragent import UserAgent
|
||||
from requests import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
|
|
@ -23,6 +25,7 @@ from app.utils.document_converters import (
|
|||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.proxy_config import get_requests_proxies
|
||||
|
||||
from .base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -200,9 +203,16 @@ async def add_youtube_video_document(
|
|||
}
|
||||
oembed_url = "https://www.youtube.com/oembed"
|
||||
|
||||
# Build residential proxy URL (if configured)
|
||||
residential_proxies = get_requests_proxies()
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as http_session,
|
||||
http_session.get(oembed_url, params=params) as response,
|
||||
http_session.get(
|
||||
oembed_url,
|
||||
params=params,
|
||||
proxy=residential_proxies["http"] if residential_proxies else None,
|
||||
) as response,
|
||||
):
|
||||
video_data = await response.json()
|
||||
|
||||
|
|
@ -228,7 +238,12 @@ async def add_youtube_video_document(
|
|||
)
|
||||
|
||||
try:
|
||||
ytt_api = YouTubeTranscriptApi()
|
||||
ua = UserAgent()
|
||||
http_client = Session()
|
||||
http_client.headers.update({"User-Agent": ua.random})
|
||||
if residential_proxies:
|
||||
http_client.proxies.update(residential_proxies)
|
||||
ytt_api = YouTubeTranscriptApi(http_client=http_client)
|
||||
captions = ytt_api.fetch(video_id)
|
||||
# Include complete caption information with timestamps
|
||||
transcript_segments = []
|
||||
|
|
|
|||
|
|
@ -219,7 +219,9 @@ class CustomBearerTransport(BearerTransport):
|
|||
|
||||
# Decode JWT to get user_id for refresh token creation
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False})
|
||||
payload = jwt.decode(
|
||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||
)
|
||||
user_id = uuid.UUID(payload.get("sub"))
|
||||
refresh_token = await create_refresh_token(user_id)
|
||||
except Exception as e:
|
||||
|
|
|
|||
86
surfsense_backend/app/utils/proxy_config.py
Normal file
86
surfsense_backend/app/utils/proxy_config.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""
|
||||
Residential proxy configuration utility.
|
||||
|
||||
Reads proxy credentials from the application Config and provides helper
|
||||
functions that return proxy configs in the format expected by different
|
||||
HTTP libraries (requests, httpx, aiohttp, Playwright).
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from app.config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_password_b64() -> str | None:
|
||||
"""
|
||||
Build the base64-encoded password dict required by anonymous-proxies.net.
|
||||
|
||||
Returns ``None`` when the required config values are not set.
|
||||
"""
|
||||
password = Config.RESIDENTIAL_PROXY_PASSWORD
|
||||
if not password:
|
||||
return None
|
||||
|
||||
password_dict = {
|
||||
"p": password,
|
||||
"l": Config.RESIDENTIAL_PROXY_LOCATION,
|
||||
"t": Config.RESIDENTIAL_PROXY_TYPE,
|
||||
}
|
||||
return base64.b64encode(json.dumps(password_dict).encode("utf-8")).decode("utf-8")
|
||||
|
||||
|
||||
def get_residential_proxy_url() -> str | None:
|
||||
"""
|
||||
Return the fully-formed residential proxy URL, or ``None`` when not
|
||||
configured.
|
||||
|
||||
The URL format is::
|
||||
|
||||
http://<username>:<base64_password>@<hostname>/
|
||||
"""
|
||||
username = Config.RESIDENTIAL_PROXY_USERNAME
|
||||
hostname = Config.RESIDENTIAL_PROXY_HOSTNAME
|
||||
password_b64 = _build_password_b64()
|
||||
|
||||
if not all([username, hostname, password_b64]):
|
||||
return None
|
||||
|
||||
return f"http://{username}:{password_b64}@{hostname}/"
|
||||
|
||||
|
||||
def get_requests_proxies() -> dict[str, str] | None:
|
||||
"""
|
||||
Return a ``{"http": …, "https": …}`` dict suitable for
|
||||
``requests.Session.proxies`` and ``aiohttp`` ``proxy=`` kwarg,
|
||||
or ``None`` when not configured.
|
||||
"""
|
||||
proxy_url = get_residential_proxy_url()
|
||||
if proxy_url is None:
|
||||
return None
|
||||
return {"http": proxy_url, "https": proxy_url}
|
||||
|
||||
|
||||
def get_playwright_proxy() -> dict[str, str] | None:
|
||||
"""
|
||||
Return a Playwright-compatible proxy dict::
|
||||
|
||||
{"server": "http://host:port", "username": "…", "password": "…"}
|
||||
|
||||
or ``None`` when not configured.
|
||||
"""
|
||||
username = Config.RESIDENTIAL_PROXY_USERNAME
|
||||
hostname = Config.RESIDENTIAL_PROXY_HOSTNAME
|
||||
password_b64 = _build_password_b64()
|
||||
|
||||
if not all([username, hostname, password_b64]):
|
||||
return None
|
||||
|
||||
return {
|
||||
"server": f"http://{hostname}",
|
||||
"username": username,
|
||||
"password": password_b64,
|
||||
}
|
||||
44
surfsense_backend/app/utils/signed_image_urls.py
Normal file
44
surfsense_backend/app/utils/signed_image_urls.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""
|
||||
Access token utilities for generated images.
|
||||
|
||||
Provides token generation and verification so that generated images can be
|
||||
served via <img> tags (which cannot pass auth headers) while still
|
||||
restricting access to authorised users.
|
||||
|
||||
Each image generation record stores its own random access token. The token
|
||||
is verified by comparing the incoming query-parameter value against the
|
||||
stored value in the database. This approach:
|
||||
|
||||
* Survives SECRET_KEY rotation — tokens are random, not derived from a key.
|
||||
* Allows explicit revocation — just clear the column.
|
||||
* Is immune to timing attacks — uses ``hmac.compare_digest``.
|
||||
"""
|
||||
|
||||
import hmac
|
||||
import secrets
|
||||
|
||||
|
||||
def generate_image_token() -> str:
|
||||
"""
|
||||
Generate a cryptographically random access token for an image.
|
||||
|
||||
Returns:
|
||||
A 64-character URL-safe hex string.
|
||||
"""
|
||||
return secrets.token_hex(32)
|
||||
|
||||
|
||||
def verify_image_token(stored_token: str | None, provided_token: str) -> bool:
|
||||
"""
|
||||
Constant-time comparison of a stored token against a user-provided one.
|
||||
|
||||
Args:
|
||||
stored_token: The token persisted on the ImageGeneration record.
|
||||
provided_token: The token from the URL query parameter.
|
||||
|
||||
Returns:
|
||||
True if the tokens match, False otherwise.
|
||||
"""
|
||||
if not stored_token or not provided_token:
|
||||
return False
|
||||
return hmac.compare_digest(stored_token, provided_token)
|
||||
|
|
@ -7,6 +7,7 @@ import {
|
|||
ChevronRight,
|
||||
FileText,
|
||||
Globe,
|
||||
ImageIcon,
|
||||
type LucideIcon,
|
||||
Menu,
|
||||
MessageSquare,
|
||||
|
|
@ -19,6 +20,7 @@ import { useTranslations } from "next-intl";
|
|||
import { useCallback, useEffect, useState } from "react";
|
||||
import { PublicChatSnapshotsManager } from "@/components/public-chat-snapshots/public-chat-snapshots-manager";
|
||||
import { GeneralSettingsManager } from "@/components/settings/general-settings-manager";
|
||||
import { ImageModelManager } from "@/components/settings/image-model-manager";
|
||||
import { LLMRoleManager } from "@/components/settings/llm-role-manager";
|
||||
import { ModelConfigManager } from "@/components/settings/model-config-manager";
|
||||
import { PromptConfigManager } from "@/components/settings/prompt-config-manager";
|
||||
|
|
@ -52,6 +54,12 @@ const settingsNavItems: SettingsNavItem[] = [
|
|||
descriptionKey: "nav_role_assignments_desc",
|
||||
icon: Brain,
|
||||
},
|
||||
{
|
||||
id: "image-models",
|
||||
labelKey: "nav_image_models",
|
||||
descriptionKey: "nav_image_models_desc",
|
||||
icon: ImageIcon,
|
||||
},
|
||||
{
|
||||
id: "prompts",
|
||||
labelKey: "nav_system_instructions",
|
||||
|
|
@ -282,8 +290,11 @@ function SettingsContent({
|
|||
<GeneralSettingsManager searchSpaceId={searchSpaceId} />
|
||||
)}
|
||||
{activeSection === "models" && <ModelConfigManager searchSpaceId={searchSpaceId} />}
|
||||
{activeSection === "roles" && <LLMRoleManager searchSpaceId={searchSpaceId} />}
|
||||
{activeSection === "prompts" && <PromptConfigManager searchSpaceId={searchSpaceId} />}
|
||||
{activeSection === "roles" && <LLMRoleManager searchSpaceId={searchSpaceId} />}
|
||||
{activeSection === "image-models" && (
|
||||
<ImageModelManager searchSpaceId={searchSpaceId} />
|
||||
)}
|
||||
{activeSection === "prompts" && <PromptConfigManager searchSpaceId={searchSpaceId} />}
|
||||
{activeSection === "public-links" && (
|
||||
<PublicChatSnapshotsManager searchSpaceId={searchSpaceId} />
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,91 @@
|
|||
import { atomWithMutation } from "jotai-tanstack-query";
|
||||
import { toast } from "sonner";
|
||||
import type {
|
||||
CreateImageGenConfigRequest,
|
||||
GetImageGenConfigsResponse,
|
||||
UpdateImageGenConfigRequest,
|
||||
UpdateImageGenConfigResponse,
|
||||
} from "@/contracts/types/new-llm-config.types";
|
||||
import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
import { queryClient } from "@/lib/query-client/client";
|
||||
import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms";
|
||||
|
||||
/**
|
||||
* Mutation atom for creating a new ImageGenerationConfig
|
||||
*/
|
||||
export const createImageGenConfigMutationAtom = atomWithMutation((get) => {
|
||||
const searchSpaceId = get(activeSearchSpaceIdAtom);
|
||||
|
||||
return {
|
||||
mutationKey: ["image-gen-configs", "create"],
|
||||
enabled: !!searchSpaceId,
|
||||
mutationFn: async (request: CreateImageGenConfigRequest) => {
|
||||
return imageGenConfigApiService.createConfig(request);
|
||||
},
|
||||
onSuccess: () => {
|
||||
toast.success("Image model configuration created");
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)),
|
||||
});
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
toast.error(error.message || "Failed to create image model configuration");
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
/**
|
||||
* Mutation atom for updating an existing ImageGenerationConfig
|
||||
*/
|
||||
export const updateImageGenConfigMutationAtom = atomWithMutation((get) => {
|
||||
const searchSpaceId = get(activeSearchSpaceIdAtom);
|
||||
|
||||
return {
|
||||
mutationKey: ["image-gen-configs", "update"],
|
||||
enabled: !!searchSpaceId,
|
||||
mutationFn: async (request: UpdateImageGenConfigRequest) => {
|
||||
return imageGenConfigApiService.updateConfig(request);
|
||||
},
|
||||
onSuccess: (_: UpdateImageGenConfigResponse, request: UpdateImageGenConfigRequest) => {
|
||||
toast.success("Image model configuration updated");
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)),
|
||||
});
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.imageGenConfigs.byId(request.id),
|
||||
});
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
toast.error(error.message || "Failed to update image model configuration");
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
/**
|
||||
* Mutation atom for deleting an ImageGenerationConfig
|
||||
*/
|
||||
export const deleteImageGenConfigMutationAtom = atomWithMutation((get) => {
|
||||
const searchSpaceId = get(activeSearchSpaceIdAtom);
|
||||
|
||||
return {
|
||||
mutationKey: ["image-gen-configs", "delete"],
|
||||
enabled: !!searchSpaceId,
|
||||
mutationFn: async (id: number) => {
|
||||
return imageGenConfigApiService.deleteConfig(id);
|
||||
},
|
||||
onSuccess: (_, id: number) => {
|
||||
toast.success("Image model configuration deleted");
|
||||
queryClient.setQueryData(
|
||||
cacheKeys.imageGenConfigs.all(Number(searchSpaceId)),
|
||||
(oldData: GetImageGenConfigsResponse | undefined) => {
|
||||
if (!oldData) return oldData;
|
||||
return oldData.filter((config) => config.id !== id);
|
||||
}
|
||||
);
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
toast.error(error.message || "Failed to delete image model configuration");
|
||||
},
|
||||
};
|
||||
});
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
import { atomWithQuery } from "jotai-tanstack-query";
|
||||
import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms";
|
||||
|
||||
/**
|
||||
* Query atom for fetching user-created image gen configs for the active search space
|
||||
*/
|
||||
export const imageGenConfigsAtom = atomWithQuery((get) => {
|
||||
const searchSpaceId = get(activeSearchSpaceIdAtom);
|
||||
|
||||
return {
|
||||
queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)),
|
||||
enabled: !!searchSpaceId,
|
||||
staleTime: 5 * 60 * 1000, // 5 minutes
|
||||
queryFn: async () => {
|
||||
return imageGenConfigApiService.getConfigs(Number(searchSpaceId));
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
/**
|
||||
* Query atom for fetching global image gen configs (from YAML, negative IDs)
|
||||
*/
|
||||
export const globalImageGenConfigsAtom = atomWithQuery(() => {
|
||||
return {
|
||||
queryKey: cacheKeys.imageGenConfigs.global(),
|
||||
staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change
|
||||
queryFn: async () => {
|
||||
return imageGenConfigApiService.getGlobalConfigs();
|
||||
},
|
||||
};
|
||||
});
|
||||
|
|
@ -4,16 +4,20 @@ import Image from "next/image";
|
|||
import Link from "next/link";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export const Logo = ({ className }: { className?: string }) => {
|
||||
return (
|
||||
<Link href="/">
|
||||
<Image
|
||||
src="/icon-128.svg"
|
||||
className={cn("dark:invert", className)}
|
||||
alt="logo"
|
||||
width={128}
|
||||
height={128}
|
||||
/>
|
||||
</Link>
|
||||
export const Logo = ({ className, disableLink = false }: { className?: string; disableLink?: boolean }) => {
|
||||
const image = (
|
||||
<Image
|
||||
src="/icon-128.svg"
|
||||
className={cn("dark:invert", className)}
|
||||
alt="logo"
|
||||
width={128}
|
||||
height={128}
|
||||
/>
|
||||
);
|
||||
|
||||
if (disableLink) {
|
||||
return image;
|
||||
}
|
||||
|
||||
return <Link href="/">{image}</Link>;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -351,14 +351,14 @@ export const ComposerAddAttachment: FC = () => {
|
|||
<PlusIcon className="aui-attachment-add-icon size-5 stroke-[1.5px]" />
|
||||
</TooltipIconButton>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="start" className="w-48 bg-background border-border">
|
||||
<DropdownMenuContent align="start" className="w-72 bg-background border-border">
|
||||
<DropdownMenuItem onSelect={handleChatAttachment} className="cursor-pointer">
|
||||
<Paperclip className="size-4" />
|
||||
<span>Add attachment</span>
|
||||
<span>Add attachment to this chat</span>
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={handleFileUpload} className="cursor-pointer">
|
||||
<Upload className="size-4" />
|
||||
<span>Upload Documents</span>
|
||||
<span>Upload documents to Search Space</span>
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ const DesktopNav = ({ navItems, isScrolled }: any) => {
|
|||
href="/"
|
||||
className="flex flex-1 flex-row items-center gap-0.5 hover:opacity-80 transition-opacity"
|
||||
>
|
||||
<Logo className="h-8 w-8 rounded-md" />
|
||||
<Logo className="h-8 w-8 rounded-md" disableLink />
|
||||
<span className="dark:text-white/90 text-gray-800 text-lg font-bold">SurfSense</span>
|
||||
</Link>
|
||||
<div className="hidden flex-1 flex-row items-center justify-center space-x-2 text-sm font-medium text-zinc-600 transition duration-200 hover:text-zinc-800 lg:flex lg:space-x-2">
|
||||
|
|
@ -145,7 +145,7 @@ const MobileNav = ({ navItems, isScrolled }: any) => {
|
|||
href="/"
|
||||
className="flex flex-row items-center gap-2 hover:opacity-80 transition-opacity"
|
||||
>
|
||||
<Logo className="h-8 w-8 rounded-md" />
|
||||
<Logo className="h-8 w-8 rounded-md" disableLink />
|
||||
<span className="dark:text-white/90 text-gray-800 text-lg font-bold">SurfSense</span>
|
||||
</Link>
|
||||
<button
|
||||
|
|
|
|||
|
|
@ -2,9 +2,13 @@
|
|||
|
||||
import { useCallback, useState } from "react";
|
||||
import type {
|
||||
GlobalImageGenConfig,
|
||||
GlobalNewLLMConfig,
|
||||
ImageGenerationConfig,
|
||||
NewLLMConfigPublic,
|
||||
} from "@/contracts/types/new-llm-config.types";
|
||||
import { ImageConfigSidebar } from "./image-config-sidebar";
|
||||
import { ImageModelSelector } from "./image-model-selector";
|
||||
import { ModelConfigSidebar } from "./model-config-sidebar";
|
||||
import { ModelSelector } from "./model-selector";
|
||||
|
||||
|
|
@ -13,6 +17,7 @@ interface ChatHeaderProps {
|
|||
}
|
||||
|
||||
export function ChatHeader({ searchSpaceId }: ChatHeaderProps) {
|
||||
// LLM config sidebar state
|
||||
const [sidebarOpen, setSidebarOpen] = useState(false);
|
||||
const [selectedConfig, setSelectedConfig] = useState<
|
||||
NewLLMConfigPublic | GlobalNewLLMConfig | null
|
||||
|
|
@ -20,6 +25,15 @@ export function ChatHeader({ searchSpaceId }: ChatHeaderProps) {
|
|||
const [isGlobal, setIsGlobal] = useState(false);
|
||||
const [sidebarMode, setSidebarMode] = useState<"create" | "edit" | "view">("view");
|
||||
|
||||
// Image config sidebar state
|
||||
const [imageSidebarOpen, setImageSidebarOpen] = useState(false);
|
||||
const [selectedImageConfig, setSelectedImageConfig] = useState<
|
||||
ImageGenerationConfig | GlobalImageGenConfig | null
|
||||
>(null);
|
||||
const [isImageGlobal, setIsImageGlobal] = useState(false);
|
||||
const [imageSidebarMode, setImageSidebarMode] = useState<"create" | "edit" | "view">("view");
|
||||
|
||||
// LLM handlers
|
||||
const handleEditConfig = useCallback(
|
||||
(config: NewLLMConfigPublic | GlobalNewLLMConfig, global: boolean) => {
|
||||
setSelectedConfig(config);
|
||||
|
|
@ -39,15 +53,36 @@ export function ChatHeader({ searchSpaceId }: ChatHeaderProps) {
|
|||
|
||||
const handleSidebarClose = useCallback((open: boolean) => {
|
||||
setSidebarOpen(open);
|
||||
if (!open) {
|
||||
// Reset state when closing
|
||||
setSelectedConfig(null);
|
||||
}
|
||||
if (!open) setSelectedConfig(null);
|
||||
}, []);
|
||||
|
||||
// Image model handlers
|
||||
const handleAddImageModel = useCallback(() => {
|
||||
setSelectedImageConfig(null);
|
||||
setIsImageGlobal(false);
|
||||
setImageSidebarMode("create");
|
||||
setImageSidebarOpen(true);
|
||||
}, []);
|
||||
|
||||
const handleEditImageConfig = useCallback(
|
||||
(config: ImageGenerationConfig | GlobalImageGenConfig, global: boolean) => {
|
||||
setSelectedImageConfig(config);
|
||||
setIsImageGlobal(global);
|
||||
setImageSidebarMode(global ? "view" : "edit");
|
||||
setImageSidebarOpen(true);
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const handleImageSidebarClose = useCallback((open: boolean) => {
|
||||
setImageSidebarOpen(open);
|
||||
if (!open) setSelectedImageConfig(null);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-2">
|
||||
<ModelSelector onEdit={handleEditConfig} onAddNew={handleAddNew} />
|
||||
<ImageModelSelector onEdit={handleEditImageConfig} onAddNew={handleAddImageModel} />
|
||||
<ModelConfigSidebar
|
||||
open={sidebarOpen}
|
||||
onOpenChange={handleSidebarClose}
|
||||
|
|
@ -56,6 +91,14 @@ export function ChatHeader({ searchSpaceId }: ChatHeaderProps) {
|
|||
searchSpaceId={searchSpaceId}
|
||||
mode={sidebarMode}
|
||||
/>
|
||||
<ImageConfigSidebar
|
||||
open={imageSidebarOpen}
|
||||
onOpenChange={handleImageSidebarClose}
|
||||
config={selectedImageConfig}
|
||||
isGlobal={isImageGlobal}
|
||||
searchSpaceId={searchSpaceId}
|
||||
mode={imageSidebarMode}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
522
surfsense_web/components/new-chat/image-config-sidebar.tsx
Normal file
522
surfsense_web/components/new-chat/image-config-sidebar.tsx
Normal file
|
|
@ -0,0 +1,522 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import {
|
||||
AlertCircle,
|
||||
Check,
|
||||
ChevronsUpDown,
|
||||
Globe,
|
||||
ImageIcon,
|
||||
Key,
|
||||
Shuffle,
|
||||
X,
|
||||
Zap,
|
||||
} from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { createPortal } from "react-dom";
|
||||
import { toast } from "sonner";
|
||||
import {
|
||||
createImageGenConfigMutationAtom,
|
||||
updateImageGenConfigMutationAtom,
|
||||
} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms";
|
||||
import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms";
|
||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "@/components/ui/command";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { IMAGE_GEN_MODELS, IMAGE_GEN_PROVIDERS } from "@/contracts/enums/image-gen-providers";
|
||||
import type {
|
||||
GlobalImageGenConfig,
|
||||
ImageGenerationConfig,
|
||||
} from "@/contracts/types/new-llm-config.types";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface ImageConfigSidebarProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
config: ImageGenerationConfig | GlobalImageGenConfig | null;
|
||||
isGlobal: boolean;
|
||||
searchSpaceId: number;
|
||||
mode: "create" | "edit" | "view";
|
||||
}
|
||||
|
||||
const INITIAL_FORM = {
|
||||
name: "",
|
||||
description: "",
|
||||
provider: "",
|
||||
model_name: "",
|
||||
api_key: "",
|
||||
api_base: "",
|
||||
api_version: "",
|
||||
};
|
||||
|
||||
export function ImageConfigSidebar({
|
||||
open,
|
||||
onOpenChange,
|
||||
config,
|
||||
isGlobal,
|
||||
searchSpaceId,
|
||||
mode,
|
||||
}: ImageConfigSidebarProps) {
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [mounted, setMounted] = useState(false);
|
||||
const [formData, setFormData] = useState(INITIAL_FORM);
|
||||
const [modelComboboxOpen, setModelComboboxOpen] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
setMounted(true);
|
||||
}, []);
|
||||
|
||||
// Reset form when opening
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
if (mode === "edit" && config && !isGlobal) {
|
||||
setFormData({
|
||||
name: config.name || "",
|
||||
description: config.description || "",
|
||||
provider: config.provider || "",
|
||||
model_name: config.model_name || "",
|
||||
api_key: (config as ImageGenerationConfig).api_key || "",
|
||||
api_base: config.api_base || "",
|
||||
api_version: config.api_version || "",
|
||||
});
|
||||
} else if (mode === "create") {
|
||||
setFormData(INITIAL_FORM);
|
||||
}
|
||||
}
|
||||
}, [open, mode, config, isGlobal]);
|
||||
|
||||
// Mutations
|
||||
const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom);
|
||||
const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom);
|
||||
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
|
||||
|
||||
// Escape key
|
||||
useEffect(() => {
|
||||
const handleEscape = (e: KeyboardEvent) => {
|
||||
if (e.key === "Escape" && open) onOpenChange(false);
|
||||
};
|
||||
window.addEventListener("keydown", handleEscape);
|
||||
return () => window.removeEventListener("keydown", handleEscape);
|
||||
}, [open, onOpenChange]);
|
||||
|
||||
const isAutoMode = config && "is_auto_mode" in config && config.is_auto_mode;
|
||||
|
||||
const suggestedModels = useMemo(() => {
|
||||
if (!formData.provider) return [];
|
||||
return IMAGE_GEN_MODELS.filter((m) => m.provider === formData.provider);
|
||||
}, [formData.provider]);
|
||||
|
||||
const getTitle = () => {
|
||||
if (mode === "create") return "Add Image Model";
|
||||
if (isAutoMode) return "Auto Mode (Load Balanced)";
|
||||
if (isGlobal) return "View Global Image Model";
|
||||
return "Edit Image Model";
|
||||
};
|
||||
|
||||
const handleSubmit = useCallback(async () => {
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
if (mode === "create") {
|
||||
const result = await createConfig({
|
||||
name: formData.name,
|
||||
provider: formData.provider,
|
||||
model_name: formData.model_name,
|
||||
api_key: formData.api_key,
|
||||
api_base: formData.api_base || undefined,
|
||||
api_version: formData.api_version || undefined,
|
||||
description: formData.description || undefined,
|
||||
search_space_id: searchSpaceId,
|
||||
});
|
||||
// Set as active image model
|
||||
if (result?.id) {
|
||||
await updatePreferences({
|
||||
search_space_id: searchSpaceId,
|
||||
data: { image_generation_config_id: result.id },
|
||||
});
|
||||
}
|
||||
toast.success("Image model created and assigned!");
|
||||
onOpenChange(false);
|
||||
} else if (!isGlobal && config) {
|
||||
await updateConfig({
|
||||
id: config.id,
|
||||
data: {
|
||||
name: formData.name,
|
||||
description: formData.description || undefined,
|
||||
provider: formData.provider,
|
||||
model_name: formData.model_name,
|
||||
api_key: formData.api_key,
|
||||
api_base: formData.api_base || undefined,
|
||||
api_version: formData.api_version || undefined,
|
||||
},
|
||||
});
|
||||
toast.success("Image model updated!");
|
||||
onOpenChange(false);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to save image config:", error);
|
||||
toast.error("Failed to save image model");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}, [mode, isGlobal, config, formData, searchSpaceId, createConfig, updateConfig, updatePreferences, onOpenChange]);
|
||||
|
||||
const handleUseGlobalConfig = useCallback(async () => {
|
||||
if (!config || !isGlobal) return;
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
await updatePreferences({
|
||||
search_space_id: searchSpaceId,
|
||||
data: { image_generation_config_id: config.id },
|
||||
});
|
||||
toast.success(`Now using ${config.name}`);
|
||||
onOpenChange(false);
|
||||
} catch (error) {
|
||||
console.error("Failed to set image model:", error);
|
||||
toast.error("Failed to set image model");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]);
|
||||
|
||||
const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key;
|
||||
const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider);
|
||||
|
||||
if (!mounted) return null;
|
||||
|
||||
const sidebarContent = (
|
||||
<AnimatePresence>
|
||||
{open && (
|
||||
<>
|
||||
{/* Backdrop */}
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="fixed inset-0 z-50 bg-black/20 backdrop-blur-sm"
|
||||
onClick={() => onOpenChange(false)}
|
||||
/>
|
||||
|
||||
{/* Sidebar */}
|
||||
<motion.div
|
||||
initial={{ x: "100%", opacity: 0 }}
|
||||
animate={{ x: 0, opacity: 1 }}
|
||||
exit={{ x: "100%", opacity: 0 }}
|
||||
transition={{ type: "spring", damping: 30, stiffness: 300 }}
|
||||
className={cn(
|
||||
"fixed right-0 top-0 z-50 h-full w-full sm:w-[480px] lg:w-[540px]",
|
||||
"bg-background border-l border-border/50 shadow-2xl",
|
||||
"flex flex-col"
|
||||
)}
|
||||
>
|
||||
{/* Header */}
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-between px-6 py-4 border-b border-border/50",
|
||||
isAutoMode
|
||||
? "bg-gradient-to-r from-violet-500/10 to-purple-500/10"
|
||||
: "bg-gradient-to-r from-teal-500/10 to-cyan-500/10"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center size-10 rounded-xl",
|
||||
isAutoMode
|
||||
? "bg-gradient-to-br from-violet-500 to-purple-600"
|
||||
: "bg-gradient-to-br from-teal-500 to-cyan-600"
|
||||
)}
|
||||
>
|
||||
{isAutoMode ? (
|
||||
<Shuffle className="size-5 text-white" />
|
||||
) : (
|
||||
<ImageIcon className="size-5 text-white" />
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
<h2 className="text-base sm:text-lg font-semibold">{getTitle()}</h2>
|
||||
<div className="flex items-center gap-2 mt-0.5">
|
||||
{isAutoMode ? (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="gap-1 text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||
>
|
||||
<Zap className="size-3" />
|
||||
Recommended
|
||||
</Badge>
|
||||
) : isGlobal ? (
|
||||
<Badge variant="secondary" className="gap-1 text-xs">
|
||||
<Globe className="size-3" />
|
||||
Global
|
||||
</Badge>
|
||||
) : null}
|
||||
{config && !isAutoMode && (
|
||||
<span className="text-xs text-muted-foreground">{config.model_name}</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => onOpenChange(false)}
|
||||
className="h-8 w-8 rounded-full"
|
||||
>
|
||||
<X className="h-4 w-4" />
|
||||
<span className="sr-only">Close</span>
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
<div className="p-6">
|
||||
{/* Auto mode */}
|
||||
{isAutoMode && (
|
||||
<>
|
||||
<Alert className="mb-6 border-violet-500/30 bg-violet-500/5">
|
||||
<Shuffle className="size-4 text-violet-500" />
|
||||
<AlertDescription className="text-sm text-violet-700 dark:text-violet-400">
|
||||
Auto mode distributes image generation requests across all configured providers for optimal performance and rate limit protection.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
<div className="flex gap-3 pt-4 border-t border-border/50">
|
||||
<Button variant="outline" className="flex-1" onClick={() => onOpenChange(false)}>
|
||||
Close
|
||||
</Button>
|
||||
<Button
|
||||
className="flex-1 gap-2 bg-gradient-to-r from-violet-500 to-purple-600 hover:from-violet-600 hover:to-purple-700"
|
||||
onClick={handleUseGlobalConfig}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Loading..." : "Use Auto Mode"}
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Global config (read-only) */}
|
||||
{isGlobal && !isAutoMode && config && (
|
||||
<>
|
||||
<Alert className="mb-6 border-amber-500/30 bg-amber-500/5">
|
||||
<AlertCircle className="size-4 text-amber-500" />
|
||||
<AlertDescription className="text-sm text-amber-700 dark:text-amber-400">
|
||||
Global configurations are read-only. To customize, create a new model.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
<div className="space-y-4">
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-1.5">
|
||||
<div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">Name</div>
|
||||
<p className="text-sm font-medium">{config.name}</p>
|
||||
</div>
|
||||
{config.description && (
|
||||
<div className="space-y-1.5">
|
||||
<div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">Description</div>
|
||||
<p className="text-sm text-muted-foreground">{config.description}</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<Separator />
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-1.5">
|
||||
<div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">Provider</div>
|
||||
<p className="text-sm font-medium">{config.provider}</p>
|
||||
</div>
|
||||
<div className="space-y-1.5">
|
||||
<div className="text-xs font-medium text-muted-foreground uppercase tracking-wider">Model</div>
|
||||
<p className="text-sm font-medium font-mono">{config.model_name}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-3 pt-6 border-t border-border/50 mt-6">
|
||||
<Button variant="outline" className="flex-1" onClick={() => onOpenChange(false)}>
|
||||
Close
|
||||
</Button>
|
||||
<Button className="flex-1 gap-2" onClick={handleUseGlobalConfig} disabled={isSubmitting}>
|
||||
{isSubmitting ? "Loading..." : "Use This Model"}
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Create / Edit form */}
|
||||
{(mode === "create" || (mode === "edit" && !isGlobal)) && (
|
||||
<div className="space-y-4">
|
||||
{/* Name */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">Name *</Label>
|
||||
<Input
|
||||
placeholder="e.g., My DALL-E 3"
|
||||
value={formData.name}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, name: e.target.value }))}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Description */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">Description</Label>
|
||||
<Input
|
||||
placeholder="Optional description"
|
||||
value={formData.description}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, description: e.target.value }))}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* Provider */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">Provider *</Label>
|
||||
<Select
|
||||
value={formData.provider}
|
||||
onValueChange={(val) => setFormData((p) => ({ ...p, provider: val, model_name: "" }))}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select a provider" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{IMAGE_GEN_PROVIDERS.map((p) => (
|
||||
<SelectItem key={p.value} value={p.value}>
|
||||
<div className="flex flex-col">
|
||||
<span className="font-medium">{p.label}</span>
|
||||
<span className="text-xs text-muted-foreground">{p.example}</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
{/* Model Name */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">Model Name *</Label>
|
||||
{suggestedModels.length > 0 ? (
|
||||
<Popover open={modelComboboxOpen} onOpenChange={setModelComboboxOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="outline" role="combobox" className="w-full justify-between font-normal">
|
||||
{formData.model_name || "Select or type a model..."}
|
||||
<ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-full p-0" align="start">
|
||||
<Command>
|
||||
<CommandInput
|
||||
placeholder="Search or type model..."
|
||||
value={formData.model_name}
|
||||
onValueChange={(val) => setFormData((p) => ({ ...p, model_name: val }))}
|
||||
/>
|
||||
<CommandList>
|
||||
<CommandEmpty>
|
||||
<span className="text-xs text-muted-foreground">Type a custom model name</span>
|
||||
</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{suggestedModels.map((m) => (
|
||||
<CommandItem
|
||||
key={m.value}
|
||||
value={m.value}
|
||||
onSelect={() => {
|
||||
setFormData((p) => ({ ...p, model_name: m.value }));
|
||||
setModelComboboxOpen(false);
|
||||
}}
|
||||
>
|
||||
<Check className={cn("mr-2 h-4 w-4", formData.model_name === m.value ? "opacity-100" : "opacity-0")} />
|
||||
<span className="font-mono text-sm">{m.value}</span>
|
||||
<span className="ml-2 text-xs text-muted-foreground">{m.label}</span>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
) : (
|
||||
<Input
|
||||
placeholder="e.g., dall-e-3"
|
||||
value={formData.model_name}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, model_name: e.target.value }))}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* API Key */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium flex items-center gap-1.5">
|
||||
<Key className="h-3.5 w-3.5" /> API Key *
|
||||
</Label>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="sk-..."
|
||||
value={formData.api_key}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, api_key: e.target.value }))}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* API Base */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">API Base URL</Label>
|
||||
<Input
|
||||
placeholder={selectedProvider?.apiBase || "Optional"}
|
||||
value={formData.api_base}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, api_base: e.target.value }))}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Azure API Version */}
|
||||
{formData.provider === "AZURE_OPENAI" && (
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">API Version (Azure)</Label>
|
||||
<Input
|
||||
placeholder="2024-02-15-preview"
|
||||
value={formData.api_version}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, api_version: e.target.value }))}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Actions */}
|
||||
<div className="flex gap-3 pt-4 border-t">
|
||||
<Button variant="outline" className="flex-1" onClick={() => onOpenChange(false)}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
className="flex-1 gap-2 bg-gradient-to-r from-teal-500 to-cyan-600 hover:from-teal-600 hover:to-cyan-700"
|
||||
onClick={handleSubmit}
|
||||
disabled={isSubmitting || !isFormValid}
|
||||
>
|
||||
{isSubmitting ? <Spinner size="sm" className="mr-2" /> : null}
|
||||
{mode === "edit" ? "Save Changes" : "Create & Use"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
</>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
);
|
||||
|
||||
return typeof document !== "undefined" ? createPortal(sidebarContent, document.body) : null;
|
||||
}
|
||||
364
surfsense_web/components/new-chat/image-model-selector.tsx
Normal file
364
surfsense_web/components/new-chat/image-model-selector.tsx
Normal file
|
|
@ -0,0 +1,364 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import {
|
||||
Check,
|
||||
ChevronDown,
|
||||
ChevronRight,
|
||||
Edit3,
|
||||
Globe,
|
||||
ImageIcon,
|
||||
Plus,
|
||||
Shuffle,
|
||||
User,
|
||||
} from "lucide-react";
|
||||
import { useCallback, useMemo, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import {
|
||||
createImageGenConfigMutationAtom,
|
||||
updateImageGenConfigMutationAtom,
|
||||
} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms";
|
||||
import {
|
||||
globalImageGenConfigsAtom,
|
||||
imageGenConfigsAtom,
|
||||
} from "@/atoms/image-gen-config/image-gen-config-query.atoms";
|
||||
import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms";
|
||||
import { llmPreferencesAtom } from "@/atoms/new-llm-config/new-llm-config-query.atoms";
|
||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
CommandSeparator,
|
||||
} from "@/components/ui/command";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import type {
|
||||
GlobalImageGenConfig,
|
||||
ImageGenerationConfig,
|
||||
} from "@/contracts/types/new-llm-config.types";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface ImageModelSelectorProps {
|
||||
className?: string;
|
||||
onAddNew?: () => void;
|
||||
onEdit?: (config: ImageGenerationConfig | GlobalImageGenConfig, isGlobal: boolean) => void;
|
||||
}
|
||||
|
||||
export function ImageModelSelector({ className, onAddNew, onEdit }: ImageModelSelectorProps) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [searchQuery, setSearchQuery] = useState("");
|
||||
|
||||
const { data: globalConfigs, isLoading: globalLoading } =
|
||||
useAtomValue(globalImageGenConfigsAtom);
|
||||
const { data: userConfigs, isLoading: userLoading } = useAtomValue(imageGenConfigsAtom);
|
||||
const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom);
|
||||
const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
||||
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
|
||||
|
||||
const isLoading = globalLoading || userLoading || prefsLoading;
|
||||
|
||||
const currentConfig = useMemo(() => {
|
||||
if (!preferences) return null;
|
||||
const id = preferences.image_generation_config_id;
|
||||
if (id === null || id === undefined) return null;
|
||||
const globalMatch = globalConfigs?.find((c) => c.id === id);
|
||||
if (globalMatch) return globalMatch;
|
||||
return userConfigs?.find((c) => c.id === id) ?? null;
|
||||
}, [preferences, globalConfigs, userConfigs]);
|
||||
|
||||
const isCurrentAutoMode = useMemo(() => {
|
||||
return currentConfig && "is_auto_mode" in currentConfig && currentConfig.is_auto_mode;
|
||||
}, [currentConfig]);
|
||||
|
||||
const filteredGlobal = useMemo(() => {
|
||||
if (!globalConfigs) return [];
|
||||
if (!searchQuery) return globalConfigs;
|
||||
const q = searchQuery.toLowerCase();
|
||||
return globalConfigs.filter(
|
||||
(c) =>
|
||||
c.name.toLowerCase().includes(q) ||
|
||||
c.model_name.toLowerCase().includes(q) ||
|
||||
c.provider.toLowerCase().includes(q)
|
||||
);
|
||||
}, [globalConfigs, searchQuery]);
|
||||
|
||||
const filteredUser = useMemo(() => {
|
||||
if (!userConfigs) return [];
|
||||
if (!searchQuery) return userConfigs;
|
||||
const q = searchQuery.toLowerCase();
|
||||
return userConfigs.filter(
|
||||
(c) =>
|
||||
c.name.toLowerCase().includes(q) ||
|
||||
c.model_name.toLowerCase().includes(q) ||
|
||||
c.provider.toLowerCase().includes(q)
|
||||
);
|
||||
}, [userConfigs, searchQuery]);
|
||||
|
||||
const totalModels = (globalConfigs?.length ?? 0) + (userConfigs?.length ?? 0);
|
||||
|
||||
const handleSelect = useCallback(
|
||||
async (configId: number) => {
|
||||
if (currentConfig?.id === configId) {
|
||||
setOpen(false);
|
||||
return;
|
||||
}
|
||||
if (!searchSpaceId) {
|
||||
toast.error("No search space selected");
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await updatePreferences({
|
||||
search_space_id: Number(searchSpaceId),
|
||||
data: { image_generation_config_id: configId },
|
||||
});
|
||||
toast.success("Image model updated");
|
||||
setOpen(false);
|
||||
} catch {
|
||||
toast.error("Failed to switch image model");
|
||||
}
|
||||
},
|
||||
[currentConfig, searchSpaceId, updatePreferences]
|
||||
);
|
||||
|
||||
// Don't render if no configs at all
|
||||
if (!isLoading && totalModels === 0) {
|
||||
return (
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={onAddNew}
|
||||
className={cn("h-8 gap-2 px-3 text-sm border-border/60", className)}
|
||||
>
|
||||
<Plus className="size-4 text-teal-600" />
|
||||
<span className="hidden md:inline">Add Image Model</span>
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
role="combobox"
|
||||
aria-expanded={open}
|
||||
className={cn("h-8 gap-2 px-3 text-sm border-border/60", className)}
|
||||
>
|
||||
{isLoading ? (
|
||||
<Spinner size="sm" className="text-muted-foreground" />
|
||||
) : currentConfig ? (
|
||||
<>
|
||||
{isCurrentAutoMode ? (
|
||||
<Shuffle className="size-4 text-violet-500" />
|
||||
) : (
|
||||
<ImageIcon className="size-4 text-teal-500" />
|
||||
)}
|
||||
<span className="max-w-[100px] md:max-w-[120px] truncate hidden md:inline">
|
||||
{currentConfig.name}
|
||||
</span>
|
||||
{isCurrentAutoMode ? (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="ml-1 text-[10px] px-1.5 py-0 h-4 bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||
>
|
||||
Auto
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="ml-1 text-[10px] px-1.5 py-0 h-4 bg-teal-50 text-teal-700 dark:bg-teal-900/30 dark:text-teal-300"
|
||||
>
|
||||
Image
|
||||
</Badge>
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<ImageIcon className="h-4 w-4 text-muted-foreground" />
|
||||
<span className="text-muted-foreground hidden md:inline">Image Model</span>
|
||||
</>
|
||||
)}
|
||||
<ChevronDown
|
||||
className={cn(
|
||||
"h-3.5 w-3.5 text-muted-foreground ml-1 shrink-0 transition-transform duration-200",
|
||||
open && "rotate-180"
|
||||
)}
|
||||
/>
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
|
||||
<PopoverContent
|
||||
className="w-[280px] md:w-[360px] p-0 rounded-lg shadow-lg border-border/60"
|
||||
align="start"
|
||||
sideOffset={8}
|
||||
>
|
||||
<Command shouldFilter={false} className="rounded-lg">
|
||||
{totalModels > 3 && (
|
||||
<div className="flex items-center gap-1 md:gap-2 px-2 md:px-3 py-1.5 md:py-2">
|
||||
<CommandInput
|
||||
placeholder="Search image models..."
|
||||
value={searchQuery}
|
||||
onValueChange={setSearchQuery}
|
||||
className="h-7 md:h-8 text-xs md:text-sm border-0 bg-transparent focus:ring-0 placeholder:text-muted-foreground/60"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<CommandList className="max-h-[300px] md:max-h-[400px] overflow-y-auto">
|
||||
<CommandEmpty className="py-8 text-center">
|
||||
<div className="flex flex-col items-center gap-2">
|
||||
<ImageIcon className="size-8 text-muted-foreground" />
|
||||
<p className="text-sm text-muted-foreground">No image models found</p>
|
||||
</div>
|
||||
</CommandEmpty>
|
||||
|
||||
{/* Global Image Gen Configs */}
|
||||
{filteredGlobal.length > 0 && (
|
||||
<CommandGroup>
|
||||
<div className="flex items-center gap-2 px-3 py-2 text-xs font-semibold text-muted-foreground tracking-wider">
|
||||
<Globe className="size-3.5" />
|
||||
Global Image Models
|
||||
</div>
|
||||
{filteredGlobal.map((config) => {
|
||||
const isSelected = currentConfig?.id === config.id;
|
||||
const isAuto = "is_auto_mode" in config && config.is_auto_mode;
|
||||
return (
|
||||
<CommandItem
|
||||
key={`g-${config.id}`}
|
||||
value={`g-${config.id}`}
|
||||
onSelect={() => handleSelect(config.id)}
|
||||
className={cn(
|
||||
"mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50",
|
||||
isSelected && "bg-accent/80",
|
||||
isAuto && "border border-violet-200 dark:border-violet-800/50"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-3 min-w-0 flex-1">
|
||||
<div className="shrink-0">
|
||||
{isAuto ? (
|
||||
<Shuffle className="size-4 text-violet-500" />
|
||||
) : (
|
||||
<ImageIcon className="size-4 text-teal-500" />
|
||||
)}
|
||||
</div>
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium truncate">{config.name}</span>
|
||||
{isAuto && (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-[9px] px-1 py-0 h-3.5 bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-0"
|
||||
>
|
||||
Recommended
|
||||
</Badge>
|
||||
)}
|
||||
{isSelected && <Check className="size-3.5 text-primary shrink-0" />}
|
||||
</div>
|
||||
<span className="text-xs text-muted-foreground truncate block">
|
||||
{isAuto ? "Auto load balancing" : config.model_name}
|
||||
</span>
|
||||
</div>
|
||||
{onEdit && (
|
||||
<ChevronRight
|
||||
className="size-3.5 text-muted-foreground shrink-0 opacity-0 group-hover:opacity-100 transition-opacity cursor-pointer"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setOpen(false);
|
||||
onEdit(config, true);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</CommandItem>
|
||||
);
|
||||
})}
|
||||
</CommandGroup>
|
||||
)}
|
||||
|
||||
{/* User Image Gen Configs */}
|
||||
{filteredUser.length > 0 && (
|
||||
<>
|
||||
{filteredGlobal.length > 0 && <CommandSeparator className="my-1 bg-border/30" />}
|
||||
<CommandGroup>
|
||||
<div className="flex items-center gap-2 px-3 py-2 text-xs font-semibold text-muted-foreground tracking-wider">
|
||||
<User className="size-3.5" />
|
||||
Your Image Models
|
||||
</div>
|
||||
{filteredUser.map((config) => {
|
||||
const isSelected = currentConfig?.id === config.id;
|
||||
return (
|
||||
<CommandItem
|
||||
key={`u-${config.id}`}
|
||||
value={`u-${config.id}`}
|
||||
onSelect={() => handleSelect(config.id)}
|
||||
className={cn(
|
||||
"mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50",
|
||||
isSelected && "bg-accent/80"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-3 min-w-0 flex-1">
|
||||
<div className="shrink-0">
|
||||
<ImageIcon className="size-4 text-teal-500" />
|
||||
</div>
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium truncate">{config.name}</span>
|
||||
{isSelected && (
|
||||
<Check className="size-3.5 text-primary shrink-0" />
|
||||
)}
|
||||
</div>
|
||||
<span className="text-xs text-muted-foreground truncate block">
|
||||
{config.model_name}
|
||||
</span>
|
||||
</div>
|
||||
{onEdit && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-7 w-7 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setOpen(false);
|
||||
onEdit(config, false);
|
||||
}}
|
||||
>
|
||||
<Edit3 className="size-3.5 text-muted-foreground" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</CommandItem>
|
||||
);
|
||||
})}
|
||||
</CommandGroup>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Add New */}
|
||||
{onAddNew && (
|
||||
<div className="p-2 bg-muted/20">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="w-full justify-start gap-2 h-9 rounded-lg hover:bg-accent/50"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
onAddNew();
|
||||
}}
|
||||
>
|
||||
<Plus className="size-4 text-teal-600" />
|
||||
<span className="text-sm font-medium">Add Image Model</span>
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
|
|
@ -392,8 +392,8 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp
|
|||
</CommandGroup>
|
||||
)}
|
||||
|
||||
{/* Add New Config Button */}
|
||||
<div className="p-2 bg-muted/20">
|
||||
{/* Add New Config Button */}
|
||||
<div className="p-2 bg-muted/20">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
|
|
|
|||
|
|
@ -12,11 +12,11 @@ const demoPlans = [
|
|||
features: [
|
||||
"Open source on GitHub",
|
||||
"Upload and chat with 300+ pages of content",
|
||||
"Connects with 8 popular sources, like Drive and Notion.",
|
||||
"Connects with 8 popular sources, like Drive and Notion",
|
||||
"Includes limited access to ChatGPT, Claude, and DeepSeek models",
|
||||
"Supports 100+ more LLMs, including Gemini, Llama and many more.",
|
||||
"50+ File extensions supported.",
|
||||
"Generate podcasts in seconds.",
|
||||
"Supports 100+ more LLMs, including Gemini, Llama and many more",
|
||||
"50+ File extensions supported",
|
||||
"Generate podcasts in seconds",
|
||||
"Cross-Browser Extension for dynamic webpages including authenticated content",
|
||||
"Community support on Discord",
|
||||
],
|
||||
|
|
@ -33,8 +33,8 @@ const demoPlans = [
|
|||
billingText: "billed annually",
|
||||
features: [
|
||||
"Everything in Free",
|
||||
"Upload and chat with 5,000+ pages of content",
|
||||
"Connects with 15+ external sources, like Slack and Airtable.",
|
||||
"Upload and chat with 5,000+ pages of content per user",
|
||||
"Connects with 15+ external sources, like Slack and Airtable",
|
||||
"Includes extended access to ChatGPT, Claude, and DeepSeek models",
|
||||
"Collaboration and commenting features",
|
||||
"Shared BYOK (Bring Your Own Key)",
|
||||
|
|
@ -42,7 +42,7 @@ const demoPlans = [
|
|||
"Planned: Centralized billing",
|
||||
"Priority support",
|
||||
],
|
||||
description: "The AIknowledge base for individuals and teams",
|
||||
description: "The AI knowledge base for individuals and teams",
|
||||
buttonText: "Upgrade",
|
||||
href: "/contact",
|
||||
isPopular: true,
|
||||
|
|
|
|||
692
surfsense_web/components/settings/image-model-manager.tsx
Normal file
692
surfsense_web/components/settings/image-model-manager.tsx
Normal file
|
|
@ -0,0 +1,692 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import {
|
||||
AlertCircle,
|
||||
Check,
|
||||
ChevronsUpDown,
|
||||
Clock,
|
||||
Edit3,
|
||||
ImageIcon,
|
||||
Key,
|
||||
Plus,
|
||||
RefreshCw,
|
||||
Shuffle,
|
||||
Sparkles,
|
||||
Trash2,
|
||||
Wand2,
|
||||
} from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import {
|
||||
createImageGenConfigMutationAtom,
|
||||
deleteImageGenConfigMutationAtom,
|
||||
updateImageGenConfigMutationAtom,
|
||||
} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms";
|
||||
import {
|
||||
globalImageGenConfigsAtom,
|
||||
imageGenConfigsAtom,
|
||||
} from "@/atoms/image-gen-config/image-gen-config-query.atoms";
|
||||
import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms";
|
||||
import { llmPreferencesAtom } from "@/atoms/new-llm-config/new-llm-config-query.atoms";
|
||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "@/components/ui/command";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import {
|
||||
IMAGE_GEN_PROVIDERS,
|
||||
getImageGenModelsByProvider,
|
||||
} from "@/contracts/enums/image-gen-providers";
|
||||
import type { ImageGenerationConfig } from "@/contracts/types/new-llm-config.types";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface ImageModelManagerProps {
|
||||
searchSpaceId: number;
|
||||
}
|
||||
|
||||
const container = {
|
||||
hidden: { opacity: 0 },
|
||||
show: { opacity: 1, transition: { staggerChildren: 0.05 } },
|
||||
};
|
||||
|
||||
const item = {
|
||||
hidden: { opacity: 0, y: 20 },
|
||||
show: { opacity: 1, y: 0 },
|
||||
};
|
||||
|
||||
export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
||||
// Image gen config atoms
|
||||
const { mutateAsync: createConfig, isPending: isCreating, error: createError } =
|
||||
useAtomValue(createImageGenConfigMutationAtom);
|
||||
const { mutateAsync: updateConfig, isPending: isUpdating, error: updateError } =
|
||||
useAtomValue(updateImageGenConfigMutationAtom);
|
||||
const { mutateAsync: deleteConfig, isPending: isDeleting, error: deleteError } =
|
||||
useAtomValue(deleteImageGenConfigMutationAtom);
|
||||
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
|
||||
|
||||
const { data: userConfigs, isFetching: configsLoading, error: fetchError, refetch: refreshConfigs } =
|
||||
useAtomValue(imageGenConfigsAtom);
|
||||
const { data: globalConfigs = [], isFetching: globalLoading } =
|
||||
useAtomValue(globalImageGenConfigsAtom);
|
||||
const { data: preferences = {}, isFetching: prefsLoading } = useAtomValue(llmPreferencesAtom);
|
||||
|
||||
// Local state
|
||||
const [isDialogOpen, setIsDialogOpen] = useState(false);
|
||||
const [editingConfig, setEditingConfig] = useState<ImageGenerationConfig | null>(null);
|
||||
const [configToDelete, setConfigToDelete] = useState<ImageGenerationConfig | null>(null);
|
||||
|
||||
// Preference state
|
||||
const [selectedPrefId, setSelectedPrefId] = useState<string | number>(
|
||||
preferences.image_generation_config_id ?? ""
|
||||
);
|
||||
const [hasPrefChanges, setHasPrefChanges] = useState(false);
|
||||
const [isSavingPref, setIsSavingPref] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
setSelectedPrefId(preferences.image_generation_config_id ?? "");
|
||||
setHasPrefChanges(false);
|
||||
}, [preferences]);
|
||||
|
||||
const isSubmitting = isCreating || isUpdating;
|
||||
const isLoading = configsLoading || globalLoading || prefsLoading;
|
||||
const errors = [createError, updateError, deleteError, fetchError].filter(Boolean) as Error[];
|
||||
|
||||
// Form state for create/edit dialog
|
||||
const [formData, setFormData] = useState({
|
||||
name: "",
|
||||
description: "",
|
||||
provider: "",
|
||||
custom_provider: "",
|
||||
model_name: "",
|
||||
api_key: "",
|
||||
api_base: "",
|
||||
api_version: "",
|
||||
});
|
||||
const [modelComboboxOpen, setModelComboboxOpen] = useState(false);
|
||||
|
||||
const resetForm = () => {
|
||||
setFormData({
|
||||
name: "",
|
||||
description: "",
|
||||
provider: "",
|
||||
custom_provider: "",
|
||||
model_name: "",
|
||||
api_key: "",
|
||||
api_base: "",
|
||||
api_version: "",
|
||||
});
|
||||
};
|
||||
|
||||
const handleFormSubmit = useCallback(async () => {
|
||||
if (!formData.name || !formData.provider || !formData.model_name || !formData.api_key) {
|
||||
toast.error("Please fill in all required fields");
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (editingConfig) {
|
||||
await updateConfig({
|
||||
id: editingConfig.id,
|
||||
data: {
|
||||
name: formData.name,
|
||||
description: formData.description || undefined,
|
||||
provider: formData.provider as any,
|
||||
custom_provider: formData.custom_provider || undefined,
|
||||
model_name: formData.model_name,
|
||||
api_key: formData.api_key,
|
||||
api_base: formData.api_base || undefined,
|
||||
api_version: formData.api_version || undefined,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
const result = await createConfig({
|
||||
name: formData.name,
|
||||
description: formData.description || undefined,
|
||||
provider: formData.provider as any,
|
||||
custom_provider: formData.custom_provider || undefined,
|
||||
model_name: formData.model_name,
|
||||
api_key: formData.api_key,
|
||||
api_base: formData.api_base || undefined,
|
||||
api_version: formData.api_version || undefined,
|
||||
search_space_id: searchSpaceId,
|
||||
});
|
||||
// Auto-assign newly created config
|
||||
if (result?.id) {
|
||||
await updatePreferences({
|
||||
search_space_id: searchSpaceId,
|
||||
data: { image_generation_config_id: result.id },
|
||||
});
|
||||
}
|
||||
}
|
||||
setIsDialogOpen(false);
|
||||
setEditingConfig(null);
|
||||
resetForm();
|
||||
} catch {
|
||||
// Error handled by mutation
|
||||
}
|
||||
}, [editingConfig, formData, searchSpaceId, createConfig, updateConfig, updatePreferences]);
|
||||
|
||||
const handleDelete = async () => {
|
||||
if (!configToDelete) return;
|
||||
try {
|
||||
await deleteConfig(configToDelete.id);
|
||||
setConfigToDelete(null);
|
||||
} catch {
|
||||
// Error handled by mutation
|
||||
}
|
||||
};
|
||||
|
||||
const openEditDialog = (config: ImageGenerationConfig) => {
|
||||
setEditingConfig(config);
|
||||
setFormData({
|
||||
name: config.name,
|
||||
description: config.description || "",
|
||||
provider: config.provider,
|
||||
custom_provider: config.custom_provider || "",
|
||||
model_name: config.model_name,
|
||||
api_key: config.api_key,
|
||||
api_base: config.api_base || "",
|
||||
api_version: config.api_version || "",
|
||||
});
|
||||
setIsDialogOpen(true);
|
||||
};
|
||||
|
||||
const openNewDialog = () => {
|
||||
setEditingConfig(null);
|
||||
resetForm();
|
||||
setIsDialogOpen(true);
|
||||
};
|
||||
|
||||
const handlePrefChange = (value: string) => {
|
||||
const newVal = value === "unassigned" ? "" : parseInt(value);
|
||||
setSelectedPrefId(newVal);
|
||||
setHasPrefChanges(newVal !== (preferences.image_generation_config_id ?? ""));
|
||||
};
|
||||
|
||||
const handleSavePref = async () => {
|
||||
setIsSavingPref(true);
|
||||
try {
|
||||
await updatePreferences({
|
||||
search_space_id: searchSpaceId,
|
||||
data: {
|
||||
image_generation_config_id:
|
||||
typeof selectedPrefId === "string"
|
||||
? selectedPrefId ? parseInt(selectedPrefId) : undefined
|
||||
: selectedPrefId,
|
||||
},
|
||||
});
|
||||
setHasPrefChanges(false);
|
||||
toast.success("Image generation model preference saved!");
|
||||
} catch {
|
||||
toast.error("Failed to save preference");
|
||||
} finally {
|
||||
setIsSavingPref(false);
|
||||
}
|
||||
};
|
||||
|
||||
const allConfigs = [
|
||||
...globalConfigs.map((c) => ({ ...c, _source: "global" as const })),
|
||||
...(userConfigs ?? []).map((c) => ({ ...c, _source: "user" as const })),
|
||||
];
|
||||
|
||||
const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider);
|
||||
const suggestedModels = getImageGenModelsByProvider(formData.provider);
|
||||
|
||||
return (
|
||||
<div className="space-y-4 md:space-y-6">
|
||||
{/* Header */}
|
||||
<div className="flex flex-col space-y-4 sm:flex-row sm:items-center sm:justify-between sm:space-y-0">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => refreshConfigs()}
|
||||
disabled={isLoading}
|
||||
className="flex items-center gap-2 text-xs md:text-sm h-8 md:h-9"
|
||||
>
|
||||
<RefreshCw className={cn("h-3 w-3 md:h-4 md:w-4", configsLoading && "animate-spin")} />
|
||||
Refresh
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Errors */}
|
||||
<AnimatePresence>
|
||||
{errors.map((err) => (
|
||||
<motion.div key={err?.message} initial={{ opacity: 0, y: -10 }} animate={{ opacity: 1, y: 0 }} exit={{ opacity: 0, y: -10 }}>
|
||||
<Alert variant="destructive" className="py-3">
|
||||
<AlertCircle className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
|
||||
<AlertDescription className="text-xs md:text-sm">{err?.message}</AlertDescription>
|
||||
</Alert>
|
||||
</motion.div>
|
||||
))}
|
||||
</AnimatePresence>
|
||||
|
||||
{/* Global info */}
|
||||
{globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && (
|
||||
<Alert className="border-teal-500/30 bg-teal-500/5 py-3">
|
||||
<Sparkles className="h-3 w-3 md:h-4 md:w-4 text-teal-600 dark:text-teal-400 shrink-0" />
|
||||
<AlertDescription className="text-teal-800 dark:text-teal-200 text-xs md:text-sm">
|
||||
<span className="font-medium">
|
||||
{globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length} global image model(s)
|
||||
</span>{" "}
|
||||
available from your administrator.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{/* Active Preference Card */}
|
||||
{!isLoading && allConfigs.length > 0 && (
|
||||
<motion.div initial={{ opacity: 0, y: 10 }} animate={{ opacity: 1, y: 0 }}>
|
||||
<Card className="border-l-4 border-l-teal-500">
|
||||
<CardHeader className="pb-2 px-3 md:px-6 pt-3 md:pt-6">
|
||||
<div className="flex items-center gap-2 md:gap-3">
|
||||
<div className="p-1.5 md:p-2 rounded-lg bg-teal-100 text-teal-800">
|
||||
<ImageIcon className="w-4 h-4 md:w-5 md:h-5" />
|
||||
</div>
|
||||
<div>
|
||||
<CardTitle className="text-base md:text-lg">Active Image Model</CardTitle>
|
||||
<CardDescription className="text-xs md:text-sm">
|
||||
Select which model to use for image generation
|
||||
</CardDescription>
|
||||
</div>
|
||||
</div>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-3 px-3 md:px-6 pb-3 md:pb-6">
|
||||
<Select
|
||||
value={selectedPrefId?.toString() || "unassigned"}
|
||||
onValueChange={handlePrefChange}
|
||||
>
|
||||
<SelectTrigger className="h-9 md:h-10 text-xs md:text-sm">
|
||||
<SelectValue placeholder="Select an image model" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="unassigned">
|
||||
<span className="text-muted-foreground">Unassigned</span>
|
||||
</SelectItem>
|
||||
{globalConfigs.length > 0 && (
|
||||
<>
|
||||
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">Global</div>
|
||||
{globalConfigs.map((c) => {
|
||||
const isAuto = "is_auto_mode" in c && c.is_auto_mode;
|
||||
return (
|
||||
<SelectItem key={`g-${c.id}`} value={c.id.toString()}>
|
||||
<div className="flex items-center gap-2">
|
||||
{isAuto ? (
|
||||
<Badge variant="outline" className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-violet-200">
|
||||
<Shuffle className="size-3 mr-1" />AUTO
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="outline" className="text-xs bg-teal-50 text-teal-700 dark:bg-teal-900/30 dark:text-teal-300 border-teal-200">
|
||||
{c.provider}
|
||||
</Badge>
|
||||
)}
|
||||
<span>{c.name}</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
{(userConfigs?.length ?? 0) > 0 && (
|
||||
<>
|
||||
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">Your Models</div>
|
||||
{userConfigs?.map((c) => (
|
||||
<SelectItem key={`u-${c.id}`} value={c.id.toString()}>
|
||||
<div className="flex items-center gap-2">
|
||||
<Badge variant="outline" className="text-xs">{c.provider}</Badge>
|
||||
<span>{c.name}</span>
|
||||
<span className="text-muted-foreground">({c.model_name})</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
{hasPrefChanges && (
|
||||
<div className="flex gap-2 pt-1">
|
||||
<Button size="sm" onClick={handleSavePref} disabled={isSavingPref} className="text-xs h-8">
|
||||
{isSavingPref ? "Saving..." : "Save"}
|
||||
</Button>
|
||||
<Button size="sm" variant="outline" onClick={() => { setSelectedPrefId(preferences.image_generation_config_id ?? ""); setHasPrefChanges(false); }} className="text-xs h-8">
|
||||
Reset
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</motion.div>
|
||||
)}
|
||||
|
||||
{/* Loading */}
|
||||
{isLoading && (
|
||||
<Card>
|
||||
<CardContent className="flex items-center justify-center py-10">
|
||||
<Spinner size="md" className="text-muted-foreground" />
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{/* User Configs */}
|
||||
{!isLoading && (
|
||||
<div className="space-y-4 md:space-y-6">
|
||||
<div className="flex flex-col space-y-4 sm:flex-row sm:items-center sm:justify-between sm:space-y-0">
|
||||
<h3 className="text-lg md:text-xl font-semibold tracking-tight">Your Image Models</h3>
|
||||
<Button onClick={openNewDialog} className="flex items-center gap-2 text-xs md:text-sm h-8 md:h-9">
|
||||
<Plus className="h-3 w-3 md:h-4 md:w-4" />
|
||||
Add Image Model
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{(userConfigs?.length ?? 0) === 0 ? (
|
||||
<Card className="border-dashed border-2 border-muted-foreground/25">
|
||||
<CardContent className="flex flex-col items-center justify-center py-10 md:py-16 text-center">
|
||||
<div className="rounded-full bg-gradient-to-br from-teal-500/10 to-cyan-500/10 p-4 md:p-6 mb-4">
|
||||
<Wand2 className="h-8 w-8 md:h-12 md:w-12 text-teal-600 dark:text-teal-400" />
|
||||
</div>
|
||||
<h3 className="text-lg font-semibold mb-2">No Image Models Yet</h3>
|
||||
<p className="text-xs md:text-sm text-muted-foreground max-w-sm mb-4">
|
||||
Add your own image generation model (DALL-E 3, GPT Image 1, etc.)
|
||||
</p>
|
||||
<Button onClick={openNewDialog} size="lg" className="gap-2 text-xs md:text-sm">
|
||||
<Plus className="h-3 w-3 md:h-4 md:w-4" />
|
||||
Add First Image Model
|
||||
</Button>
|
||||
</CardContent>
|
||||
</Card>
|
||||
) : (
|
||||
<motion.div variants={container} initial="hidden" animate="show" className="grid gap-4">
|
||||
<AnimatePresence mode="popLayout">
|
||||
{userConfigs?.map((config) => (
|
||||
<motion.div key={config.id} variants={item} layout exit={{ opacity: 0, scale: 0.95 }}>
|
||||
<Card className="group overflow-hidden hover:shadow-lg transition-all duration-300 border-muted-foreground/10 hover:border-teal-500/30">
|
||||
<CardContent className="p-0">
|
||||
<div className="flex">
|
||||
<div className="w-1 md:w-1.5 bg-gradient-to-b from-teal-500/50 to-cyan-500/50 group-hover:from-teal-500 group-hover:to-cyan-500 transition-colors" />
|
||||
<div className="flex-1 p-3 md:p-5">
|
||||
<div className="flex items-start justify-between gap-2">
|
||||
<div className="flex items-start gap-2 md:gap-4 flex-1 min-w-0">
|
||||
<div className="flex h-10 w-10 md:h-12 md:w-12 items-center justify-center rounded-lg md:rounded-xl bg-gradient-to-br from-teal-500/10 to-cyan-500/10 shrink-0">
|
||||
<ImageIcon className="h-5 w-5 md:h-6 md:w-6 text-teal-600 dark:text-teal-400" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0 space-y-2">
|
||||
<div className="flex items-center gap-1.5 flex-wrap">
|
||||
<h4 className="text-sm md:text-base font-semibold truncate">{config.name}</h4>
|
||||
<Badge variant="secondary" className="text-[9px] md:text-[10px] px-1.5 py-0.5 bg-teal-500/10 text-teal-700 dark:text-teal-300 border-teal-500/20">
|
||||
{config.provider}
|
||||
</Badge>
|
||||
</div>
|
||||
<code className="text-[10px] md:text-xs font-mono text-muted-foreground bg-muted/50 px-1.5 py-0.5 rounded-md inline-block">
|
||||
{config.model_name}
|
||||
</code>
|
||||
{config.description && (
|
||||
<p className="text-[10px] md:text-xs text-muted-foreground line-clamp-1">{config.description}</p>
|
||||
)}
|
||||
<div className="flex items-center gap-1 text-[10px] md:text-xs text-muted-foreground pt-1">
|
||||
<Clock className="h-2.5 w-2.5 md:h-3 md:w-3" />
|
||||
{new Date(config.created_at).toLocaleDateString()}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-0.5 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity">
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button variant="ghost" size="sm" onClick={() => openEditDialog(config)} className="h-7 w-7 p-0 text-muted-foreground hover:text-foreground">
|
||||
<Edit3 className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>Edit</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button variant="ghost" size="sm" onClick={() => setConfigToDelete(config)} className="h-7 w-7 p-0 text-muted-foreground hover:text-destructive">
|
||||
<Trash2 className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>Delete</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</motion.div>
|
||||
))}
|
||||
</AnimatePresence>
|
||||
</motion.div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Create/Edit Dialog */}
|
||||
<Dialog open={isDialogOpen} onOpenChange={(open) => { if (!open) { setIsDialogOpen(false); setEditingConfig(null); resetForm(); } }}>
|
||||
<DialogContent className="max-w-lg max-h-[90vh] overflow-y-auto">
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
{editingConfig ? <Edit3 className="w-5 h-5 text-teal-600" /> : <Plus className="w-5 h-5 text-teal-600" />}
|
||||
{editingConfig ? "Edit Image Model" : "Add Image Model"}
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
{editingConfig ? "Update your image generation model" : "Configure a new image generation model (DALL-E 3, GPT Image 1, etc.)"}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="space-y-4 pt-2">
|
||||
{/* Name */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">Name *</Label>
|
||||
<Input
|
||||
placeholder="e.g., My DALL-E 3"
|
||||
value={formData.name}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, name: e.target.value }))}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Description */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">Description</Label>
|
||||
<Input
|
||||
placeholder="Optional description"
|
||||
value={formData.description}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, description: e.target.value }))}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* Provider */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">Provider *</Label>
|
||||
<Select
|
||||
value={formData.provider}
|
||||
onValueChange={(val) => setFormData((p) => ({ ...p, provider: val, model_name: "" }))}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select a provider" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{IMAGE_GEN_PROVIDERS.map((p) => (
|
||||
<SelectItem key={p.value} value={p.value}>
|
||||
<div className="flex flex-col">
|
||||
<span className="font-medium">{p.label}</span>
|
||||
<span className="text-xs text-muted-foreground">{p.example}</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
{/* Model Name */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">Model Name *</Label>
|
||||
{suggestedModels.length > 0 ? (
|
||||
<Popover open={modelComboboxOpen} onOpenChange={setModelComboboxOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="outline" role="combobox" className="w-full justify-between font-normal">
|
||||
{formData.model_name || "Select or type a model..."}
|
||||
<ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-full p-0" align="start">
|
||||
<Command>
|
||||
<CommandInput
|
||||
placeholder="Search or type model name..."
|
||||
value={formData.model_name}
|
||||
onValueChange={(val) => setFormData((p) => ({ ...p, model_name: val }))}
|
||||
/>
|
||||
<CommandList>
|
||||
<CommandEmpty>
|
||||
<span className="text-xs text-muted-foreground">Type a custom model name</span>
|
||||
</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{suggestedModels.map((m) => (
|
||||
<CommandItem
|
||||
key={m.value}
|
||||
value={m.value}
|
||||
onSelect={() => {
|
||||
setFormData((p) => ({ ...p, model_name: m.value }));
|
||||
setModelComboboxOpen(false);
|
||||
}}
|
||||
>
|
||||
<Check className={cn("mr-2 h-4 w-4", formData.model_name === m.value ? "opacity-100" : "opacity-0")} />
|
||||
<span className="font-mono text-sm">{m.value}</span>
|
||||
<span className="ml-2 text-xs text-muted-foreground">{m.label}</span>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
) : (
|
||||
<Input
|
||||
placeholder="e.g., dall-e-3"
|
||||
value={formData.model_name}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, model_name: e.target.value }))}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* API Key */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium flex items-center gap-1.5">
|
||||
<Key className="h-3.5 w-3.5" /> API Key *
|
||||
</Label>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="sk-..."
|
||||
value={formData.api_key}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, api_key: e.target.value }))}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* API Base (optional) */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">API Base URL</Label>
|
||||
<Input
|
||||
placeholder={selectedProvider?.apiBase || "Optional"}
|
||||
value={formData.api_base}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, api_base: e.target.value }))}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* API Version (Azure) */}
|
||||
{formData.provider === "AZURE_OPENAI" && (
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">API Version (Azure)</Label>
|
||||
<Input
|
||||
placeholder="2024-02-15-preview"
|
||||
value={formData.api_version}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, api_version: e.target.value }))}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Actions */}
|
||||
<div className="flex gap-3 pt-4 border-t">
|
||||
<Button
|
||||
variant="outline"
|
||||
className="flex-1"
|
||||
onClick={() => { setIsDialogOpen(false); setEditingConfig(null); resetForm(); }}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
className="flex-1"
|
||||
onClick={handleFormSubmit}
|
||||
disabled={isSubmitting || !formData.name || !formData.provider || !formData.model_name || !formData.api_key}
|
||||
>
|
||||
{isSubmitting ? <Spinner size="sm" className="mr-2" /> : null}
|
||||
{editingConfig ? "Save Changes" : "Create & Use"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
{/* Delete Confirmation */}
|
||||
<AlertDialog open={!!configToDelete} onOpenChange={(open) => !open && setConfigToDelete(null)}>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle className="flex items-center gap-2">
|
||||
<Trash2 className="h-5 w-5 text-destructive" />
|
||||
Delete Image Model
|
||||
</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
Are you sure you want to delete <span className="font-semibold text-foreground">{configToDelete?.name}</span>?
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel disabled={isDeleting}>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction onClick={handleDelete} disabled={isDeleting} className="bg-destructive text-destructive-foreground hover:bg-destructive/90">
|
||||
{isDeleting ? <><Spinner size="sm" className="mr-2" />Deleting</> : <><Trash2 className="mr-2 h-4 w-4" />Delete</>}
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -255,15 +255,15 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
</Alert>
|
||||
)}
|
||||
|
||||
{/* Role Assignment Cards */}
|
||||
{availableConfigs.length > 0 && (
|
||||
<div className="grid gap-4 md:gap-6">
|
||||
{Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => {
|
||||
const IconComponent = role.icon;
|
||||
const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments];
|
||||
const assignedConfig = availableConfigs.find(
|
||||
(config) => config.id === currentAssignment
|
||||
);
|
||||
{/* Role Assignment Cards */}
|
||||
{availableConfigs.length > 0 && (
|
||||
<div className="grid gap-4 md:gap-6">
|
||||
{Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => {
|
||||
const IconComponent = role.icon;
|
||||
const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments];
|
||||
const assignedConfig = availableConfigs.find(
|
||||
(config) => config.id === currentAssignment
|
||||
);
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
|
|
@ -294,100 +294,100 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
</div>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-3 md:space-y-4 px-3 md:px-6 pb-3 md:pb-6">
|
||||
<div className="space-y-1.5 md:space-y-2">
|
||||
<Label className="text-xs md:text-sm font-medium">
|
||||
Assign LLM Configuration:
|
||||
</Label>
|
||||
<Select
|
||||
value={currentAssignment?.toString() || "unassigned"}
|
||||
onValueChange={(value) => handleRoleAssignment(`${key}_llm_id`, value)}
|
||||
>
|
||||
<SelectTrigger className="h-9 md:h-10 text-xs md:text-sm">
|
||||
<SelectValue placeholder="Select an LLM configuration" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="unassigned">
|
||||
<span className="text-muted-foreground">Unassigned</span>
|
||||
</SelectItem>
|
||||
<div className="space-y-1.5 md:space-y-2">
|
||||
<Label className="text-xs md:text-sm font-medium">
|
||||
Assign LLM Configuration:
|
||||
</Label>
|
||||
<Select
|
||||
value={currentAssignment?.toString() || "unassigned"}
|
||||
onValueChange={(value) => handleRoleAssignment(`${key}_llm_id`, value)}
|
||||
>
|
||||
<SelectTrigger className="h-9 md:h-10 text-xs md:text-sm">
|
||||
<SelectValue placeholder="Select an LLM configuration" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="unassigned">
|
||||
<span className="text-muted-foreground">Unassigned</span>
|
||||
</SelectItem>
|
||||
|
||||
{/* Global Configurations */}
|
||||
{globalConfigs.length > 0 && (
|
||||
<>
|
||||
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
|
||||
Global Configurations
|
||||
</div>
|
||||
{globalConfigs.map((config) => {
|
||||
const isAutoMode =
|
||||
"is_auto_mode" in config && config.is_auto_mode;
|
||||
return (
|
||||
<SelectItem key={config.id} value={config.id.toString()}>
|
||||
<div className="flex items-center gap-2">
|
||||
{isAutoMode ? (
|
||||
<Badge
|
||||
variant="outline"
|
||||
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-violet-200 dark:border-violet-700"
|
||||
>
|
||||
<Shuffle className="size-3 mr-1" />
|
||||
AUTO
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{config.provider}
|
||||
</Badge>
|
||||
)}
|
||||
<span>{config.name}</span>
|
||||
{!isAutoMode && (
|
||||
<span className="text-muted-foreground">
|
||||
({config.model_name})
|
||||
</span>
|
||||
)}
|
||||
{isAutoMode ? (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||
>
|
||||
Recommended
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="secondary" className="text-xs">
|
||||
🌐 Global
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
</SelectItem>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
{/* Global Configurations */}
|
||||
{globalConfigs.length > 0 && (
|
||||
<>
|
||||
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
|
||||
Global Configurations
|
||||
</div>
|
||||
{globalConfigs.map((config) => {
|
||||
const isAutoMode =
|
||||
"is_auto_mode" in config && config.is_auto_mode;
|
||||
return (
|
||||
<SelectItem key={config.id} value={config.id.toString()}>
|
||||
<div className="flex items-center gap-2">
|
||||
{isAutoMode ? (
|
||||
<Badge
|
||||
variant="outline"
|
||||
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300 border-violet-200 dark:border-violet-700"
|
||||
>
|
||||
<Shuffle className="size-3 mr-1" />
|
||||
AUTO
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{config.provider}
|
||||
</Badge>
|
||||
)}
|
||||
<span>{config.name}</span>
|
||||
{!isAutoMode && (
|
||||
<span className="text-muted-foreground">
|
||||
({config.model_name})
|
||||
</span>
|
||||
)}
|
||||
{isAutoMode ? (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-xs bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-300"
|
||||
>
|
||||
Recommended
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="secondary" className="text-xs">
|
||||
🌐 Global
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
</SelectItem>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Custom Configurations */}
|
||||
{newLLMConfigs.length > 0 && (
|
||||
<>
|
||||
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
|
||||
Your Configurations
|
||||
</div>
|
||||
{newLLMConfigs
|
||||
.filter(
|
||||
(config) => config.id && config.id.toString().trim() !== ""
|
||||
)
|
||||
.map((config) => (
|
||||
<SelectItem key={config.id} value={config.id.toString()}>
|
||||
<div className="flex items-center gap-2">
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{config.provider}
|
||||
</Badge>
|
||||
<span>{config.name}</span>
|
||||
<span className="text-muted-foreground">
|
||||
({config.model_name})
|
||||
</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
{/* Custom Configurations */}
|
||||
{newLLMConfigs.length > 0 && (
|
||||
<>
|
||||
<div className="px-2 py-1.5 text-xs font-semibold text-muted-foreground">
|
||||
Your Configurations
|
||||
</div>
|
||||
{newLLMConfigs
|
||||
.filter(
|
||||
(config) => config.id && config.id.toString().trim() !== ""
|
||||
)
|
||||
.map((config) => (
|
||||
<SelectItem key={config.id} value={config.id.toString()}>
|
||||
<div className="flex items-center gap-2">
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{config.provider}
|
||||
</Badge>
|
||||
<span>{config.name}</span>
|
||||
<span className="text-muted-foreground">
|
||||
({config.model_name})
|
||||
</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
{assignedConfig && (
|
||||
<div
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ function ImageCancelledState({ src }: { src: string }) {
|
|||
function ParsedImage({ result }: { result: unknown }) {
|
||||
const image = parseSerializableImage(result);
|
||||
|
||||
return <Image {...image} maxWidth="420px" />;
|
||||
return <Image {...image} maxWidth="512px" />;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { ExternalLinkIcon, ImageIcon } from "lucide-react";
|
||||
import { ExternalLinkIcon, ImageIcon, SparklesIcon } from "lucide-react";
|
||||
import NextImage from "next/image";
|
||||
import { Component, type ReactNode, useState } from "react";
|
||||
import { z } from "zod";
|
||||
|
|
@ -25,7 +25,7 @@ const SerializableImageSchema = z.object({
|
|||
id: z.string(),
|
||||
assetId: z.string(),
|
||||
src: z.string(),
|
||||
alt: z.string().nullish(), // Made optional - will use fallback if missing
|
||||
alt: z.string().nullish(),
|
||||
title: z.string().nullish(),
|
||||
description: z.string().nullish(),
|
||||
href: z.string().nullish(),
|
||||
|
|
@ -49,7 +49,7 @@ export interface ImageProps {
|
|||
id: string;
|
||||
assetId: string;
|
||||
src: string;
|
||||
alt?: string; // Optional with default fallback
|
||||
alt?: string;
|
||||
title?: string;
|
||||
description?: string;
|
||||
href?: string;
|
||||
|
|
@ -71,10 +71,8 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a
|
|||
if (!parsed.success) {
|
||||
console.warn("Invalid image data:", parsed.error.issues);
|
||||
|
||||
// Try to extract basic info and return a fallback object
|
||||
const obj = (result && typeof result === "object" ? result : {}) as Record<string, unknown>;
|
||||
|
||||
// If we have at least id, assetId, and src, we can still render the image
|
||||
if (
|
||||
typeof obj.id === "string" &&
|
||||
typeof obj.assetId === "string" &&
|
||||
|
|
@ -89,7 +87,7 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a
|
|||
description: typeof obj.description === "string" ? obj.description : undefined,
|
||||
href: typeof obj.href === "string" ? obj.href : undefined,
|
||||
domain: typeof obj.domain === "string" ? obj.domain : undefined,
|
||||
ratio: undefined, // Use default ratio
|
||||
ratio: undefined,
|
||||
source: undefined,
|
||||
};
|
||||
}
|
||||
|
|
@ -97,7 +95,6 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a
|
|||
throw new Error(`Invalid image: ${parsed.error.issues.map((i) => i.message).join(", ")}`);
|
||||
}
|
||||
|
||||
// Provide fallback for alt if it's null/undefined
|
||||
return {
|
||||
...parsed.data,
|
||||
alt: parsed.data.alt ?? "Image",
|
||||
|
|
@ -105,7 +102,7 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a
|
|||
}
|
||||
|
||||
/**
|
||||
* Get aspect ratio class based on ratio prop
|
||||
* Get aspect ratio class based on ratio prop (used for fixed-ratio images only)
|
||||
*/
|
||||
function getAspectRatioClass(ratio?: AspectRatio): string {
|
||||
switch (ratio) {
|
||||
|
|
@ -119,7 +116,6 @@ function getAspectRatioClass(ratio?: AspectRatio): string {
|
|||
return "aspect-[9/16]";
|
||||
case "21:9":
|
||||
return "aspect-[21/9]";
|
||||
case "auto":
|
||||
default:
|
||||
return "aspect-[4/3]";
|
||||
}
|
||||
|
|
@ -150,7 +146,7 @@ export class ImageErrorBoundary extends Component<
|
|||
if (this.state.hasError) {
|
||||
return (
|
||||
<Card className="w-full max-w-md overflow-hidden">
|
||||
<div className="aspect-[4/3] bg-muted flex items-center justify-center">
|
||||
<div className="aspect-square bg-muted flex items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-2 text-muted-foreground">
|
||||
<ImageIcon className="size-8" />
|
||||
<p className="text-sm">Failed to load image</p>
|
||||
|
|
@ -167,10 +163,10 @@ export class ImageErrorBoundary extends Component<
|
|||
/**
|
||||
* Loading skeleton for Image
|
||||
*/
|
||||
export function ImageSkeleton({ maxWidth = "420px" }: { maxWidth?: string }) {
|
||||
export function ImageSkeleton({ maxWidth = "512px" }: { maxWidth?: string }) {
|
||||
return (
|
||||
<Card className="w-full overflow-hidden animate-pulse" style={{ maxWidth }}>
|
||||
<div className="aspect-[4/3] bg-muted flex items-center justify-center">
|
||||
<div className="aspect-square bg-muted flex items-center justify-center">
|
||||
<ImageIcon className="size-12 text-muted-foreground/30" />
|
||||
</div>
|
||||
</Card>
|
||||
|
|
@ -183,7 +179,7 @@ export function ImageSkeleton({ maxWidth = "420px" }: { maxWidth?: string }) {
|
|||
export function ImageLoading({ title = "Loading image..." }: { title?: string }) {
|
||||
return (
|
||||
<Card className="w-full max-w-md overflow-hidden">
|
||||
<div className="aspect-[4/3] bg-muted flex items-center justify-center">
|
||||
<div className="aspect-square bg-muted flex items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<Spinner size="lg" className="text-muted-foreground" />
|
||||
<p className="text-muted-foreground text-sm">{title}</p>
|
||||
|
|
@ -197,7 +193,9 @@ export function ImageLoading({ title = "Loading image..." }: { title?: string })
|
|||
* Image Component
|
||||
*
|
||||
* Display images with metadata and attribution.
|
||||
* Features hover overlay with title and source attribution.
|
||||
* - For "auto" ratio: renders the image at natural dimensions (no cropping)
|
||||
* - For fixed ratios: uses a fixed aspect container with object-cover
|
||||
* - Features hover overlay with title, description, and source attribution.
|
||||
*/
|
||||
export function Image({
|
||||
id,
|
||||
|
|
@ -207,16 +205,18 @@ export function Image({
|
|||
description,
|
||||
href,
|
||||
domain,
|
||||
ratio = "4:3",
|
||||
ratio = "auto",
|
||||
fit = "cover",
|
||||
source,
|
||||
maxWidth = "420px",
|
||||
maxWidth = "512px",
|
||||
className,
|
||||
}: ImageProps) {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const [imageError, setImageError] = useState(false);
|
||||
const aspectRatioClass = getAspectRatioClass(ratio);
|
||||
const [imageLoaded, setImageLoaded] = useState(false);
|
||||
const displayDomain = domain || source?.label;
|
||||
const isGenerated = domain === "ai-generated";
|
||||
const isAutoRatio = !ratio || ratio === "auto";
|
||||
|
||||
const handleClick = () => {
|
||||
const targetUrl = href || source?.url || src;
|
||||
|
|
@ -228,7 +228,7 @@ export function Image({
|
|||
if (imageError) {
|
||||
return (
|
||||
<Card id={id} className={cn("w-full overflow-hidden", className)} style={{ maxWidth }}>
|
||||
<div className={cn("bg-muted flex items-center justify-center", aspectRatioClass)}>
|
||||
<div className="aspect-square bg-muted flex items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-2 text-muted-foreground">
|
||||
<ImageIcon className="size-8" />
|
||||
<p className="text-sm">Image not available</p>
|
||||
|
|
@ -243,6 +243,7 @@ export function Image({
|
|||
id={id}
|
||||
className={cn(
|
||||
"group w-full overflow-hidden cursor-pointer transition-shadow duration-200 hover:shadow-lg",
|
||||
isGenerated && "ring-1 ring-primary/10",
|
||||
className
|
||||
)}
|
||||
style={{ maxWidth }}
|
||||
|
|
@ -258,71 +259,98 @@ export function Image({
|
|||
role="button"
|
||||
tabIndex={0}
|
||||
>
|
||||
<div className={cn("relative w-full overflow-hidden bg-muted", aspectRatioClass)}>
|
||||
{/* Image */}
|
||||
<NextImage
|
||||
src={src}
|
||||
alt={alt}
|
||||
fill
|
||||
className={cn(
|
||||
"transition-transform duration-300",
|
||||
fit === "cover" ? "object-cover" : "object-contain",
|
||||
isHovered && "scale-105"
|
||||
)}
|
||||
unoptimized
|
||||
onError={() => setImageError(true)}
|
||||
/>
|
||||
<div className="relative w-full overflow-hidden bg-muted">
|
||||
{isAutoRatio ? (
|
||||
/* Auto ratio: image renders at natural dimensions, no cropping */
|
||||
<>
|
||||
{!imageLoaded && (
|
||||
<div className="aspect-square flex items-center justify-center">
|
||||
<Spinner size="lg" className="text-muted-foreground" />
|
||||
</div>
|
||||
)}
|
||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||
<img
|
||||
src={src}
|
||||
alt={alt}
|
||||
className={cn(
|
||||
"w-full h-auto transition-transform duration-300",
|
||||
isHovered && "scale-[1.02]",
|
||||
!imageLoaded && "hidden"
|
||||
)}
|
||||
onLoad={() => setImageLoaded(true)}
|
||||
onError={() => setImageError(true)}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
/* Fixed ratio: constrained aspect container with fill */
|
||||
<div className={getAspectRatioClass(ratio)}>
|
||||
<NextImage
|
||||
src={src}
|
||||
alt={alt}
|
||||
fill
|
||||
className={cn(
|
||||
"transition-transform duration-300",
|
||||
fit === "cover" ? "object-cover" : "object-contain",
|
||||
isHovered && "scale-105"
|
||||
)}
|
||||
unoptimized
|
||||
onError={() => setImageError(true)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Hover overlay - appears on hover */}
|
||||
{/* Hover overlay */}
|
||||
<div
|
||||
className={cn(
|
||||
"absolute inset-0 bg-gradient-to-t from-black/80 via-black/20 to-transparent",
|
||||
"absolute inset-0 bg-gradient-to-t from-black/70 via-transparent to-transparent",
|
||||
"transition-opacity duration-200",
|
||||
isHovered ? "opacity-100" : "opacity-0"
|
||||
)}
|
||||
>
|
||||
{/* Content at bottom */}
|
||||
<div className="absolute bottom-0 left-0 right-0 p-4">
|
||||
{/* Title */}
|
||||
<div className="absolute bottom-0 left-0 right-0 p-3">
|
||||
{title && (
|
||||
<h3 className="font-semibold text-white text-base leading-tight line-clamp-2 mb-1">
|
||||
<h3 className="font-semibold text-white text-sm leading-tight line-clamp-2 mb-0.5">
|
||||
{title}
|
||||
</h3>
|
||||
)}
|
||||
|
||||
{/* Description */}
|
||||
{description && (
|
||||
<p className="text-white/80 text-sm line-clamp-2 mb-2">{description}</p>
|
||||
<p className="text-white/80 text-xs line-clamp-2 mb-1.5">{description}</p>
|
||||
)}
|
||||
|
||||
{/* Source attribution */}
|
||||
{displayDomain && (
|
||||
<div className="flex items-center gap-1.5">
|
||||
{source?.iconUrl ? (
|
||||
{isGenerated ? (
|
||||
<SparklesIcon className="size-3.5 text-white/70" />
|
||||
) : source?.iconUrl ? (
|
||||
<NextImage
|
||||
src={source.iconUrl}
|
||||
alt={source.label}
|
||||
width={16}
|
||||
height={16}
|
||||
width={14}
|
||||
height={14}
|
||||
className="rounded"
|
||||
unoptimized
|
||||
/>
|
||||
) : (
|
||||
<ExternalLinkIcon className="size-4 text-white/70" />
|
||||
<ExternalLinkIcon className="size-3.5 text-white/70" />
|
||||
)}
|
||||
<span className="text-white/70 text-sm">{displayDomain}</span>
|
||||
<span className="text-white/70 text-xs">{displayDomain}</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Always visible domain badge (bottom right, shown when NOT hovered) */}
|
||||
{/* Badge when not hovered */}
|
||||
{displayDomain && !isHovered && (
|
||||
<div className="absolute bottom-2 right-2">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="bg-black/60 text-white border-0 text-xs backdrop-blur-sm"
|
||||
className={cn(
|
||||
"border-0 text-xs backdrop-blur-sm",
|
||||
isGenerated
|
||||
? "bg-primary/80 text-primary-foreground"
|
||||
: "bg-black/60 text-white"
|
||||
)}
|
||||
>
|
||||
{isGenerated && <SparklesIcon className="size-3 mr-1" />}
|
||||
{displayDomain}
|
||||
</Badge>
|
||||
</div>
|
||||
|
|
|
|||
105
surfsense_web/contracts/enums/image-gen-providers.ts
Normal file
105
surfsense_web/contracts/enums/image-gen-providers.ts
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
export interface ImageGenProvider {
|
||||
value: string;
|
||||
label: string;
|
||||
example: string;
|
||||
description: string;
|
||||
apiBase?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image generation providers supported by LiteLLM.
|
||||
* See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
*/
|
||||
export const IMAGE_GEN_PROVIDERS: ImageGenProvider[] = [
|
||||
{
|
||||
value: "OPENAI",
|
||||
label: "OpenAI",
|
||||
example: "dall-e-3, gpt-image-1, dall-e-2",
|
||||
description: "DALL-E and GPT Image models",
|
||||
},
|
||||
{
|
||||
value: "AZURE_OPENAI",
|
||||
label: "Azure OpenAI",
|
||||
example: "azure/dall-e-3, azure/gpt-image-1",
|
||||
description: "OpenAI image models on Azure",
|
||||
},
|
||||
{
|
||||
value: "GOOGLE",
|
||||
label: "Google AI Studio",
|
||||
example: "gemini/imagen-3.0-generate-002",
|
||||
description: "Google AI Studio image generation",
|
||||
},
|
||||
{
|
||||
value: "VERTEX_AI",
|
||||
label: "Google Vertex AI",
|
||||
example: "vertex_ai/imagegeneration@006",
|
||||
description: "Vertex AI image generation models",
|
||||
},
|
||||
{
|
||||
value: "BEDROCK",
|
||||
label: "AWS Bedrock",
|
||||
example: "bedrock/stability.stable-diffusion-xl-v0",
|
||||
description: "Stable Diffusion on AWS Bedrock",
|
||||
},
|
||||
{
|
||||
value: "RECRAFT",
|
||||
label: "Recraft",
|
||||
example: "recraft/recraftv3",
|
||||
description: "AI-powered design and image generation",
|
||||
},
|
||||
{
|
||||
value: "OPENROUTER",
|
||||
label: "OpenRouter",
|
||||
example: "openrouter/google/gemini-2.5-flash-image",
|
||||
description: "Image generation via OpenRouter",
|
||||
},
|
||||
{
|
||||
value: "XINFERENCE",
|
||||
label: "Xinference",
|
||||
example: "xinference/stable-diffusion-xl",
|
||||
description: "Self-hosted Stable Diffusion models",
|
||||
},
|
||||
{
|
||||
value: "NSCALE",
|
||||
label: "Nscale",
|
||||
example: "nscale/flux.1-schnell",
|
||||
description: "Nscale image generation",
|
||||
},
|
||||
];
|
||||
|
||||
/**
|
||||
* Image generation models organized by provider.
|
||||
*/
|
||||
export interface ImageGenModel {
|
||||
value: string;
|
||||
label: string;
|
||||
provider: string;
|
||||
}
|
||||
|
||||
export const IMAGE_GEN_MODELS: ImageGenModel[] = [
|
||||
// OpenAI
|
||||
{ value: "gpt-image-1", label: "GPT Image 1", provider: "OPENAI" },
|
||||
{ value: "dall-e-3", label: "DALL-E 3", provider: "OPENAI" },
|
||||
{ value: "dall-e-2", label: "DALL-E 2", provider: "OPENAI" },
|
||||
// Azure OpenAI
|
||||
{ value: "azure/dall-e-3", label: "DALL-E 3 (Azure)", provider: "AZURE_OPENAI" },
|
||||
{ value: "azure/gpt-image-1", label: "GPT Image 1 (Azure)", provider: "AZURE_OPENAI" },
|
||||
// Recraft
|
||||
{ value: "recraft/recraftv3", label: "Recraft V3", provider: "RECRAFT" },
|
||||
// Bedrock
|
||||
{
|
||||
value: "bedrock/stability.stable-diffusion-xl-v0",
|
||||
label: "Stable Diffusion XL",
|
||||
provider: "BEDROCK",
|
||||
},
|
||||
// Vertex AI
|
||||
{
|
||||
value: "vertex_ai/imagegeneration@006",
|
||||
label: "Imagen 3",
|
||||
provider: "VERTEX_AI",
|
||||
},
|
||||
];
|
||||
|
||||
export function getImageGenModelsByProvider(provider: string): ImageGenModel[] {
|
||||
return IMAGE_GEN_MODELS.filter((m) => m.provider === provider);
|
||||
}
|
||||
|
|
@ -161,19 +161,105 @@ export const globalNewLLMConfig = z.object({
|
|||
|
||||
export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig);
|
||||
|
||||
// =============================================================================
|
||||
// Image Generation Config (separate table from NewLLMConfig)
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* ImageGenProvider enum - only providers that support image generation
|
||||
* See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
*/
|
||||
export const imageGenProviderEnum = z.enum([
|
||||
"OPENAI",
|
||||
"AZURE_OPENAI",
|
||||
"GOOGLE",
|
||||
"VERTEX_AI",
|
||||
"BEDROCK",
|
||||
"RECRAFT",
|
||||
"OPENROUTER",
|
||||
"XINFERENCE",
|
||||
"NSCALE",
|
||||
]);
|
||||
|
||||
export type ImageGenProvider = z.infer<typeof imageGenProviderEnum>;
|
||||
|
||||
/**
|
||||
* ImageGenerationConfig - user-created image gen model configs
|
||||
* Separate from NewLLMConfig: no system_instructions, no citations_enabled.
|
||||
*/
|
||||
export const imageGenerationConfig = z.object({
|
||||
id: z.number(),
|
||||
name: z.string().max(100),
|
||||
description: z.string().max(500).nullable().optional(),
|
||||
provider: imageGenProviderEnum,
|
||||
custom_provider: z.string().max(100).nullable().optional(),
|
||||
model_name: z.string().max(100),
|
||||
api_key: z.string(),
|
||||
api_base: z.string().max(500).nullable().optional(),
|
||||
api_version: z.string().max(50).nullable().optional(),
|
||||
litellm_params: z.record(z.string(), z.any()).nullable().optional(),
|
||||
created_at: z.string(),
|
||||
search_space_id: z.number(),
|
||||
});
|
||||
|
||||
export const createImageGenConfigRequest = imageGenerationConfig.omit({
|
||||
id: true,
|
||||
created_at: true,
|
||||
});
|
||||
|
||||
export const createImageGenConfigResponse = imageGenerationConfig;
|
||||
|
||||
export const getImageGenConfigsResponse = z.array(imageGenerationConfig);
|
||||
|
||||
export const updateImageGenConfigRequest = z.object({
|
||||
id: z.number(),
|
||||
data: imageGenerationConfig
|
||||
.omit({ id: true, created_at: true, search_space_id: true })
|
||||
.partial(),
|
||||
});
|
||||
|
||||
export const updateImageGenConfigResponse = imageGenerationConfig;
|
||||
|
||||
export const deleteImageGenConfigResponse = z.object({
|
||||
message: z.string(),
|
||||
id: z.number(),
|
||||
});
|
||||
|
||||
/**
|
||||
* Global Image Generation Config - from YAML, has negative IDs
|
||||
* ID 0 is reserved for "Auto" mode (LiteLLM Router load balancing)
|
||||
*/
|
||||
export const globalImageGenConfig = z.object({
|
||||
id: z.number(),
|
||||
name: z.string(),
|
||||
description: z.string().nullable().optional(),
|
||||
provider: z.string(),
|
||||
custom_provider: z.string().nullable().optional(),
|
||||
model_name: z.string(),
|
||||
api_base: z.string().nullable().optional(),
|
||||
api_version: z.string().nullable().optional(),
|
||||
litellm_params: z.record(z.string(), z.any()).nullable().optional(),
|
||||
is_global: z.literal(true),
|
||||
is_auto_mode: z.boolean().optional().default(false),
|
||||
});
|
||||
|
||||
export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig);
|
||||
|
||||
// =============================================================================
|
||||
// LLM Preferences (Role Assignments)
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* LLM Preferences schemas - for role assignments
|
||||
* The agent_llm and document_summary_llm fields contain the full NewLLMConfig objects
|
||||
* image_generation uses image_generation_config_id (not llm_id)
|
||||
*/
|
||||
export const llmPreferences = z.object({
|
||||
agent_llm_id: z.union([z.number(), z.null()]).optional(),
|
||||
document_summary_llm_id: z.union([z.number(), z.null()]).optional(),
|
||||
image_generation_config_id: z.union([z.number(), z.null()]).optional(),
|
||||
agent_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(),
|
||||
document_summary_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(),
|
||||
image_generation_config: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(),
|
||||
});
|
||||
|
||||
/**
|
||||
|
|
@ -193,6 +279,7 @@ export const updateLLMPreferencesRequest = z.object({
|
|||
data: llmPreferences.pick({
|
||||
agent_llm_id: true,
|
||||
document_summary_llm_id: true,
|
||||
image_generation_config_id: true,
|
||||
}),
|
||||
});
|
||||
|
||||
|
|
@ -219,6 +306,15 @@ export type GetDefaultSystemInstructionsResponse = z.infer<
|
|||
>;
|
||||
export type GlobalNewLLMConfig = z.infer<typeof globalNewLLMConfig>;
|
||||
export type GetGlobalNewLLMConfigsResponse = z.infer<typeof getGlobalNewLLMConfigsResponse>;
|
||||
export type ImageGenerationConfig = z.infer<typeof imageGenerationConfig>;
|
||||
export type CreateImageGenConfigRequest = z.infer<typeof createImageGenConfigRequest>;
|
||||
export type CreateImageGenConfigResponse = z.infer<typeof createImageGenConfigResponse>;
|
||||
export type GetImageGenConfigsResponse = z.infer<typeof getImageGenConfigsResponse>;
|
||||
export type UpdateImageGenConfigRequest = z.infer<typeof updateImageGenConfigRequest>;
|
||||
export type UpdateImageGenConfigResponse = z.infer<typeof updateImageGenConfigResponse>;
|
||||
export type DeleteImageGenConfigResponse = z.infer<typeof deleteImageGenConfigResponse>;
|
||||
export type GlobalImageGenConfig = z.infer<typeof globalImageGenConfig>;
|
||||
export type GetGlobalImageGenConfigsResponse = z.infer<typeof getGlobalImageGenConfigsResponse>;
|
||||
export type LLMPreferences = z.infer<typeof llmPreferences>;
|
||||
export type GetLLMPreferencesRequest = z.infer<typeof getLLMPreferencesRequest>;
|
||||
export type GetLLMPreferencesResponse = z.infer<typeof getLLMPreferencesResponse>;
|
||||
|
|
|
|||
83
surfsense_web/lib/apis/image-gen-config-api.service.ts
Normal file
83
surfsense_web/lib/apis/image-gen-config-api.service.ts
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import {
|
||||
type CreateImageGenConfigRequest,
|
||||
createImageGenConfigRequest,
|
||||
createImageGenConfigResponse,
|
||||
type UpdateImageGenConfigRequest,
|
||||
updateImageGenConfigRequest,
|
||||
updateImageGenConfigResponse,
|
||||
deleteImageGenConfigResponse,
|
||||
getImageGenConfigsResponse,
|
||||
getGlobalImageGenConfigsResponse,
|
||||
} from "@/contracts/types/new-llm-config.types";
|
||||
import { ValidationError } from "../error";
|
||||
import { baseApiService } from "./base-api.service";
|
||||
|
||||
class ImageGenConfigApiService {
|
||||
/**
|
||||
* Get all global image generation configs (from YAML, negative IDs)
|
||||
*/
|
||||
getGlobalConfigs = async () => {
|
||||
return baseApiService.get(
|
||||
`/api/v1/global-image-generation-configs`,
|
||||
getGlobalImageGenConfigsResponse
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Create a new image generation config for a search space
|
||||
*/
|
||||
createConfig = async (request: CreateImageGenConfigRequest) => {
|
||||
const parsed = createImageGenConfigRequest.safeParse(request);
|
||||
if (!parsed.success) {
|
||||
const msg = parsed.error.issues.map((i) => i.message).join(", ");
|
||||
throw new ValidationError(`Invalid request: ${msg}`);
|
||||
}
|
||||
return baseApiService.post(
|
||||
`/api/v1/image-generation-configs`,
|
||||
createImageGenConfigResponse,
|
||||
{ body: parsed.data }
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Get image generation configs for a search space
|
||||
*/
|
||||
getConfigs = async (searchSpaceId: number) => {
|
||||
const params = new URLSearchParams({
|
||||
search_space_id: String(searchSpaceId),
|
||||
}).toString();
|
||||
return baseApiService.get(
|
||||
`/api/v1/image-generation-configs?${params}`,
|
||||
getImageGenConfigsResponse
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Update an existing image generation config
|
||||
*/
|
||||
updateConfig = async (request: UpdateImageGenConfigRequest) => {
|
||||
const parsed = updateImageGenConfigRequest.safeParse(request);
|
||||
if (!parsed.success) {
|
||||
const msg = parsed.error.issues.map((i) => i.message).join(", ");
|
||||
throw new ValidationError(`Invalid request: ${msg}`);
|
||||
}
|
||||
const { id, data } = parsed.data;
|
||||
return baseApiService.put(
|
||||
`/api/v1/image-generation-configs/${id}`,
|
||||
updateImageGenConfigResponse,
|
||||
{ body: data }
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Delete an image generation config
|
||||
*/
|
||||
deleteConfig = async (id: number) => {
|
||||
return baseApiService.delete(
|
||||
`/api/v1/image-generation-configs/${id}`,
|
||||
deleteImageGenConfigResponse
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
export const imageGenConfigApiService = new ImageGenConfigApiService();
|
||||
|
|
@ -34,6 +34,11 @@ export const cacheKeys = {
|
|||
defaultInstructions: () => ["new-llm-configs", "default-instructions"] as const,
|
||||
global: () => ["new-llm-configs", "global"] as const,
|
||||
},
|
||||
imageGenConfigs: {
|
||||
all: (searchSpaceId: number) => ["image-gen-configs", searchSpaceId] as const,
|
||||
byId: (configId: number) => ["image-gen-configs", "detail", configId] as const,
|
||||
global: () => ["image-gen-configs", "global"] as const,
|
||||
},
|
||||
auth: {
|
||||
user: ["auth", "user"] as const,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -739,6 +739,8 @@
|
|||
"nav_agent_configs_desc": "LLM models with prompts & citations",
|
||||
"nav_role_assignments": "Role Assignments",
|
||||
"nav_role_assignments_desc": "Assign configs to agent roles",
|
||||
"nav_image_models": "Image Models",
|
||||
"nav_image_models_desc": "Configure image generation models",
|
||||
"nav_system_instructions": "System Instructions",
|
||||
"nav_system_instructions_desc": "SearchSpace-wide AI instructions",
|
||||
"nav_public_links": "Public Chat Links",
|
||||
|
|
|
|||
|
|
@ -724,6 +724,8 @@
|
|||
"nav_agent_configs_desc": "LLM 模型配置提示词和引用",
|
||||
"nav_role_assignments": "角色分配",
|
||||
"nav_role_assignments_desc": "为代理角色分配配置",
|
||||
"nav_image_models": "图像模型",
|
||||
"nav_image_models_desc": "配置图像生成模型",
|
||||
"nav_system_instructions": "系统指令",
|
||||
"nav_system_instructions_desc": "搜索空间级别的 AI 指令",
|
||||
"nav_public_links": "公开聊天链接",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue