diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 2e10f4e36..628329917 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -32,6 +32,11 @@ ELECTRIC_DB_PASSWORD=electric_password SCHEDULE_CHECKER_INTERVAL=5m SECRET_KEY=SECRET + +# JWT Token Lifetimes (optional, defaults shown) +# ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day +# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks + NEXT_FRONTEND_URL=http://localhost:3000 # Backend URL for OAuth callbacks (optional, set when behind reverse proxy with HTTPS) diff --git a/surfsense_backend/alembic/versions/92_add_refresh_tokens_table.py b/surfsense_backend/alembic/versions/92_add_refresh_tokens_table.py new file mode 100644 index 000000000..c7e133ae9 --- /dev/null +++ b/surfsense_backend/alembic/versions/92_add_refresh_tokens_table.py @@ -0,0 +1,92 @@ +"""Add refresh_tokens table for user session management + +Revision ID: 92 +Revises: 91 + +Changes: +1. Create refresh_tokens table with columns: + - id (primary key) + - user_id (foreign key to user) + - token_hash (unique, indexed) + - expires_at (indexed) + - is_revoked + - family_id (indexed, for token rotation tracking) + - created_at, updated_at (timestamps) +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "92" +down_revision: str | None = "91" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Create refresh_tokens table (idempotent).""" + # Check if table already exists + connection = op.get_bind() + result = connection.execute( + sa.text( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'refresh_tokens')" + ) + ) + table_exists = result.scalar() + + if not table_exists: + op.create_table( + "refresh_tokens", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("user_id", UUID(as_uuid=True), nullable=False), + sa.Column("token_hash", sa.String(256), nullable=False), + sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("is_revoked", sa.Boolean(), nullable=False, default=False), + sa.Column("family_id", UUID(as_uuid=True), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ondelete="CASCADE", + ), + ) + + # Create indexes if they don't exist + op.execute( + "CREATE INDEX IF NOT EXISTS ix_refresh_tokens_user_id ON refresh_tokens (user_id)" + ) + op.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS ix_refresh_tokens_token_hash ON refresh_tokens (token_hash)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_refresh_tokens_expires_at ON refresh_tokens (expires_at)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_refresh_tokens_family_id ON refresh_tokens (family_id)" + ) + + +def downgrade() -> None: + """Drop refresh_tokens table (idempotent).""" + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_family_id") + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_expires_at") + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_token_hash") + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_user_id") + op.execute("DROP TABLE IF EXISTS refresh_tokens") diff --git a/surfsense_backend/alembic/versions/93_add_image_generations_table.py b/surfsense_backend/alembic/versions/93_add_image_generations_table.py new file mode 100644 index 000000000..eba9d7c86 --- /dev/null +++ b/surfsense_backend/alembic/versions/93_add_image_generations_table.py @@ -0,0 +1,300 @@ +"""Add image generation tables and search space preference + +Revision ID: 93 +Revises: 92 + +Changes: +1. Create image_generation_configs table (user-created image model configs) +2. Create image_generations table (stores generation requests/results) +3. Add image_generation_config_id column to searchspaces table +4. Add image generation permissions to existing system roles +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import ENUM as PG_ENUM +from sqlalchemy.dialects.postgresql import JSONB, UUID + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "93" +down_revision: str | None = "92" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + connection = op.get_bind() + + # 1. Create imagegenprovider enum type if it doesn't exist + connection.execute( + sa.text( + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'imagegenprovider') THEN + CREATE TYPE imagegenprovider AS ENUM ( + 'OPENAI', 'AZURE_OPENAI', 'GOOGLE', 'VERTEX_AI', 'BEDROCK', + 'RECRAFT', 'OPENROUTER', 'XINFERENCE', 'NSCALE' + ); + END IF; + END + $$; + """ + ) + ) + + # 2. Create image_generation_configs table (uses imagegenprovider enum) + result = connection.execute( + sa.text( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'image_generation_configs')" + ) + ) + if not result.scalar(): + op.create_table( + "image_generation_configs", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("name", sa.String(100), nullable=False), + sa.Column("description", sa.String(500), nullable=True), + sa.Column( + "provider", + PG_ENUM( + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + "XINFERENCE", + "NSCALE", + name="imagegenprovider", + create_type=False, + ), + nullable=False, + ), + sa.Column("custom_provider", sa.String(100), nullable=True), + sa.Column("model_name", sa.String(100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(500), nullable=True), + sa.Column("api_version", sa.String(50), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_image_generation_configs_name " + "ON image_generation_configs (name)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_image_generation_configs_search_space_id " + "ON image_generation_configs (search_space_id)" + ) + + # 3. Create image_generations table + result = connection.execute( + sa.text( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'image_generations')" + ) + ) + if not result.scalar(): + op.create_table( + "image_generations", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("prompt", sa.Text(), nullable=False), + sa.Column("model", sa.String(200), nullable=True), + sa.Column("n", sa.Integer(), nullable=True), + sa.Column("quality", sa.String(50), nullable=True), + sa.Column("size", sa.String(50), nullable=True), + sa.Column("style", sa.String(50), nullable=True), + sa.Column("response_format", sa.String(50), nullable=True), + sa.Column("image_generation_config_id", sa.Integer(), nullable=True), + sa.Column("response_data", JSONB(), nullable=True), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("created_by_id", UUID(as_uuid=True), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["created_by_id"], ["user.id"], ondelete="SET NULL" + ), + ) + + op.execute( + "CREATE INDEX IF NOT EXISTS ix_image_generations_search_space_id " + "ON image_generations (search_space_id)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_image_generations_created_by_id " + "ON image_generations (created_by_id)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_image_generations_created_at " + "ON image_generations (created_at)" + ) + + # 4. Add image_generation_config_id column to searchspaces + result = connection.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' + AND column_name = 'image_generation_config_id' + ) + """ + ) + ) + if not result.scalar(): + op.add_column( + "searchspaces", + sa.Column( + "image_generation_config_id", + sa.Integer(), + nullable=True, + server_default="0", + ), + ) + + # Drop old column name if it exists (from earlier version of this migration) + result = connection.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' + AND column_name = 'image_generation_llm_id' + ) + """ + ) + ) + if result.scalar(): + op.drop_column("searchspaces", "image_generation_llm_id") + + # Drop old column name on image_generations if it exists + result = connection.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'image_generations' + AND column_name = 'llm_config_id' + ) + """ + ) + ) + if result.scalar(): + op.drop_column("image_generations", "llm_config_id") + + # Drop old api_version column on image_generations if it exists + result = connection.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'image_generations' + AND column_name = 'api_version' + ) + """ + ) + ) + if result.scalar(): + op.drop_column("image_generations", "api_version") + + # 5. Add image generation permissions to existing system roles + connection.execute( + sa.text( + """ + UPDATE search_space_roles + SET permissions = array_cat( + permissions, + ARRAY['image_generations:create', 'image_generations:read'] + ) + WHERE is_system_role = true + AND name = 'Editor' + AND NOT ('image_generations:create' = ANY(permissions)) + """ + ) + ) + connection.execute( + sa.text( + """ + UPDATE search_space_roles + SET permissions = array_cat( + permissions, + ARRAY['image_generations:read'] + ) + WHERE is_system_role = true + AND name = 'Viewer' + AND NOT ('image_generations:read' = ANY(permissions)) + """ + ) + ) + + +def downgrade() -> None: + connection = op.get_bind() + + # Remove permissions + connection.execute( + sa.text( + """ + UPDATE search_space_roles + SET permissions = array_remove( + array_remove( + array_remove(permissions, 'image_generations:create'), + 'image_generations:read' + ), + 'image_generations:delete' + ) + WHERE is_system_role = true + """ + ) + ) + + # Remove image_generation_config_id from searchspaces + result = connection.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' + AND column_name = 'image_generation_config_id' + ) + """ + ) + ) + if result.scalar(): + op.drop_column("searchspaces", "image_generation_config_id") + + # Drop indexes and tables + op.execute("DROP INDEX IF EXISTS ix_image_generations_created_at") + op.execute("DROP INDEX IF EXISTS ix_image_generations_created_by_id") + op.execute("DROP INDEX IF EXISTS ix_image_generations_search_space_id") + op.execute("DROP TABLE IF EXISTS image_generations") + + op.execute("DROP INDEX IF EXISTS ix_image_generation_configs_search_space_id") + op.execute("DROP INDEX IF EXISTS ix_image_generation_configs_name") + op.execute("DROP TABLE IF EXISTS image_generation_configs") + + # Drop the imagegenprovider enum type + op.execute("DROP TYPE IF EXISTS imagegenprovider") diff --git a/surfsense_backend/alembic/versions/94_add_access_token_to_image_generations.py b/surfsense_backend/alembic/versions/94_add_access_token_to_image_generations.py new file mode 100644 index 000000000..09bea2c19 --- /dev/null +++ b/surfsense_backend/alembic/versions/94_add_access_token_to_image_generations.py @@ -0,0 +1,39 @@ +"""Add access_token column to image_generations + +Revision ID: 94 +Revises: 93 + +Adds an indexed access_token column to the image_generations table. +This token is stored per-record so that image serving URLs survive +SECRET_KEY rotation. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "94" +down_revision: str | None = "93" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # Add access_token column (nullable so existing rows are unaffected) + op.add_column( + "image_generations", + sa.Column("access_token", sa.String(64), nullable=True), + ) + op.create_index( + "ix_image_generations_access_token", + "image_generations", + ["access_token"], + ) + + +def downgrade() -> None: + op.drop_index("ix_image_generations_access_token", table_name="image_generations") + op.drop_column("image_generations", "access_token") diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 9c383c308..9da6ea3c2 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -133,6 +133,7 @@ async def create_surfsense_deep_agent( The agent comes with built-in tools that can be configured: - search_knowledge_base: Search the user's personal knowledge base - generate_podcast: Generate audio podcasts from content + - generate_image: Generate images from text descriptions using AI models - link_preview: Fetch rich previews for URLs - display_image: Display images in chat - scrape_webpage: Extract content from webpages diff --git a/surfsense_backend/app/agents/new_chat/checkpointer.py b/surfsense_backend/app/agents/new_chat/checkpointer.py index 637b2926f..04ecfbdea 100644 --- a/surfsense_backend/app/agents/new_chat/checkpointer.py +++ b/surfsense_backend/app/agents/new_chat/checkpointer.py @@ -3,15 +3,25 @@ PostgreSQL-based checkpointer for LangGraph agents. This module provides a persistent checkpointer using AsyncPostgresSaver that stores conversation state in the PostgreSQL database. + +Uses a connection pool (psycopg_pool.AsyncConnectionPool) to handle +connection lifecycle, health checks, and automatic reconnection, +preventing 'the connection is closed' errors in long-running deployments. """ +import logging + from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from psycopg.rows import dict_row +from psycopg_pool import AsyncConnectionPool from app.config import config +logger = logging.getLogger(__name__) + # Global checkpointer instance (initialized lazily) _checkpointer: AsyncPostgresSaver | None = None -_checkpointer_context = None # Store the context manager for cleanup +_connection_pool: AsyncConnectionPool | None = None _checkpointer_initialized: bool = False @@ -38,26 +48,65 @@ def get_postgres_connection_string() -> str: return db_url +async def _create_checkpointer() -> AsyncPostgresSaver: + """ + Create a new AsyncPostgresSaver backed by a connection pool. + + The connection pool automatically handles: + - Connection health checks before use + - Reconnection when connections die (idle timeout, DB restart, etc.) + - Connection lifecycle management (max_lifetime, max_idle) + """ + global _connection_pool + + conn_string = get_postgres_connection_string() + + _connection_pool = AsyncConnectionPool( + conninfo=conn_string, + min_size=2, + max_size=10, + # Connections are recycled after 30 minutes to avoid stale connections + max_lifetime=1800, + # Idle connections are closed after 5 minutes + max_idle=300, + open=False, + # Connection kwargs required by AsyncPostgresSaver: + # - autocommit: required for .setup() to commit checkpoint tables + # - prepare_threshold: disable prepared statements for compatibility + # - row_factory: checkpointer accesses rows as dicts (row["column"]) + kwargs={ + "autocommit": True, + "prepare_threshold": 0, + "row_factory": dict_row, + }, + ) + await _connection_pool.open(wait=True) + + checkpointer = AsyncPostgresSaver(conn=_connection_pool) + logger.info("[Checkpointer] Created AsyncPostgresSaver with connection pool") + return checkpointer + + async def get_checkpointer() -> AsyncPostgresSaver: """ Get or create the global AsyncPostgresSaver instance. This function: - 1. Creates the checkpointer if it doesn't exist + 1. Creates the checkpointer with a connection pool if it doesn't exist 2. Sets up the required database tables on first call 3. Returns the cached instance on subsequent calls + The underlying connection pool handles reconnection automatically, + so a stale/closed connection will not cause OperationalError. + Returns: AsyncPostgresSaver: The configured checkpointer instance """ - global _checkpointer, _checkpointer_context, _checkpointer_initialized + global _checkpointer, _checkpointer_initialized if _checkpointer is None: - conn_string = get_postgres_connection_string() - # from_conn_string returns an async context manager - # We need to enter the context to get the actual checkpointer - _checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string) - _checkpointer = await _checkpointer_context.__aenter__() + _checkpointer = await _create_checkpointer() + _checkpointer_initialized = False # Setup tables on first call (idempotent) if not _checkpointer_initialized: @@ -75,20 +124,21 @@ async def setup_checkpointer_tables() -> None: tables exist before any agent calls. """ await get_checkpointer() - print("[Checkpointer] PostgreSQL checkpoint tables ready") + logger.info("[Checkpointer] PostgreSQL checkpoint tables ready") async def close_checkpointer() -> None: """ - Close the checkpointer connection. + Close the checkpointer connection pool. This should be called during application shutdown. """ - global _checkpointer, _checkpointer_context, _checkpointer_initialized + global _checkpointer, _connection_pool, _checkpointer_initialized - if _checkpointer_context is not None: - await _checkpointer_context.__aexit__(None, None, None) - _checkpointer = None - _checkpointer_context = None - _checkpointer_initialized = False - print("[Checkpointer] PostgreSQL connection closed") + if _connection_pool is not None: + await _connection_pool.close() + logger.info("[Checkpointer] PostgreSQL connection pool closed") + + _checkpointer = None + _connection_pool = None + _checkpointer_initialized = False diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index fdb80acb9..01c762197 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -83,6 +83,7 @@ You have access to the following tools: * Showing an image from a URL the user explicitly mentioned in their message * Displaying images found in scraped webpage content (from scrape_webpage tool) * Showing a publicly accessible diagram or chart from a known URL + * Displaying an AI-generated image after calling the generate_image tool (ALWAYS required) CRITICAL - NEVER USE THIS TOOL FOR USER-UPLOADED ATTACHMENTS: When a user uploads/attaches an image file to their message: @@ -100,7 +101,21 @@ You have access to the following tools: - Returns: An image card with the image, title, and description - The image will automatically be displayed in the chat. -5. scrape_webpage: Scrape and extract the main content from a webpage. +5. generate_image: Generate images from text descriptions using AI image models. + - Use this when the user asks you to create, generate, draw, design, or make an image. + - Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork" + - Args: + - prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood. + - n: Number of images to generate (1-4, default: 1) + - Returns: A dictionary with the generated image URL in the "src" field, along with metadata. + - CRITICAL: After calling generate_image, you MUST call `display_image` with the returned "src" URL + to actually show the image in the chat. The generate_image tool only generates the image and returns + the URL — it does NOT display anything. You must always follow up with display_image. + - IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim - + expand and improve the prompt with specific details about style, lighting, composition, and mood. + - If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details. + +6. scrape_webpage: Scrape and extract the main content from a webpage. - Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage. - IMPORTANT: This is different from link_preview: * link_preview: Only fetches metadata (title, description, thumbnail) for display @@ -123,7 +138,7 @@ You have access to the following tools: * Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content. * Don't show every image - just the most relevant 1-3 images that enhance understanding. -6. save_memory: Save facts, preferences, or context about the user for personalized responses. +7. save_memory: Save facts, preferences, or context about the user for personalized responses. - Use this when the user explicitly or implicitly shares information worth remembering. - Trigger scenarios: * User says "remember this", "keep this in mind", "note that", or similar @@ -146,7 +161,7 @@ You have access to the following tools: - IMPORTANT: Only save information that would be genuinely useful for future conversations. Don't save trivial or temporary information. -7. recall_memory: Retrieve relevant memories about the user for personalized responses. +8. recall_memory: Retrieve relevant memories about the user for personalized responses. - Use this to access stored information about the user. - Trigger scenarios: * You need user context to give a better, more personalized answer @@ -281,6 +296,22 @@ You have access to the following tools: - Then, if the content contains useful diagrams/images like `![Neural Network Diagram](https://example.com/nn-diagram.png)`: - Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")` - Then provide your explanation, referencing the displayed image + +- User: "Generate an image of a cat" + - Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")` + - Step 2: Use the returned "src" URL to display it: `display_image(src="", alt="A fluffy orange tabby cat on a windowsill", title="Generated Image")` + +- User: "Create a landscape painting of mountains" + - Step 1: `generate_image(prompt="Majestic snow-capped mountain range at sunset, dramatic orange and purple sky, alpine meadow with wildflowers in the foreground, oil painting style with visible brushstrokes, inspired by the Hudson River School art movement")` + - Step 2: `display_image(src="", alt="Mountain landscape painting", title="Generated Image")` + +- User: "Draw me a logo for a coffee shop called Bean Dream" + - Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")` + - Step 2: `display_image(src="", alt="Bean Dream coffee shop logo", title="Generated Image")` + +- User: "Make a wide banner image for my blog about AI" + - Step 1: `generate_image(prompt="Wide banner illustration for an AI technology blog, featuring abstract neural network patterns, glowing blue and purple connections, modern futuristic aesthetic, digital art style, clean and professional")` + - Step 2: `display_image(src="", alt="AI blog banner", title="Generated Image")` """ diff --git a/surfsense_backend/app/agents/new_chat/tools/__init__.py b/surfsense_backend/app/agents/new_chat/tools/__init__.py index 9e1a4f19c..0a11951f0 100644 --- a/surfsense_backend/app/agents/new_chat/tools/__init__.py +++ b/surfsense_backend/app/agents/new_chat/tools/__init__.py @@ -8,6 +8,7 @@ Available tools: - search_knowledge_base: Search the user's personal knowledge base - search_surfsense_docs: Search Surfsense documentation for usage help - generate_podcast: Generate audio podcasts from content +- generate_image: Generate images from text descriptions using AI models - link_preview: Fetch rich previews for URLs - display_image: Display images in chat - scrape_webpage: Extract content from webpages @@ -18,6 +19,7 @@ Available tools: # Registry exports # Tool factory exports (for direct use) from .display_image import create_display_image_tool +from .generate_image import create_generate_image_tool from .knowledge_base import ( CONNECTOR_DESCRIPTIONS, create_search_knowledge_base_tool, @@ -47,6 +49,7 @@ __all__ = [ "build_tools", # Tool factories "create_display_image_tool", + "create_generate_image_tool", "create_generate_podcast_tool", "create_link_preview_tool", "create_recall_memory_tool", diff --git a/surfsense_backend/app/agents/new_chat/tools/display_image.py b/surfsense_backend/app/agents/new_chat/tools/display_image.py index 5eb846063..4424cc0d3 100644 --- a/surfsense_backend/app/agents/new_chat/tools/display_image.py +++ b/surfsense_backend/app/agents/new_chat/tools/display_image.py @@ -82,14 +82,20 @@ def create_display_image_tool(): domain = extract_domain(src) - # Determine aspect ratio based on common image sources - ratio = "16:9" # Default - if "unsplash.com" in src or "pexels.com" in src: + # Determine aspect ratio based on image source + # AI-generated images should use "auto" to preserve their native ratio + is_generated = "/image-generations/" in src + if is_generated: + ratio = "auto" + domain = "ai-generated" + elif "unsplash.com" in src or "pexels.com" in src: ratio = "16:9" elif ( "imgur.com" in src or "github.com" in src or "githubusercontent.com" in src ): ratio = "auto" + else: + ratio = "auto" return { "id": image_id, diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py new file mode 100644 index 000000000..8ffa4ecde --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py @@ -0,0 +1,242 @@ +""" +Image generation tool for the SurfSense agent. + +This module provides a tool that generates images using litellm.aimage_generation() +and returns the result via the existing display_image tool format so the frontend +renders the generated image inline in the chat. + +Config resolution: +1. Uses the search space's image_generation_config_id preference +2. Falls back to Auto mode (router load balancing) if available +3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs) +""" + +import logging +from typing import Any + +from langchain_core.tools import tool +from litellm import aimage_generation +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace +from app.services.image_gen_router_service import ( + IMAGE_GEN_AUTO_MODE_ID, + ImageGenRouterService, + is_image_gen_auto_mode, +) +from app.utils.signed_image_urls import generate_image_token + +logger = logging.getLogger(__name__) + +# Provider mapping (same as routes) +_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} + + +def _build_model_string( + provider: str, model_name: str, custom_provider: str | None +) -> str: + if custom_provider: + return f"{custom_provider}/{model_name}" + prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) + return f"{prefix}/{model_name}" + + +def _get_global_image_gen_config(config_id: int) -> dict | None: + """Get a global image gen config by negative ID.""" + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return cfg + return None + + +def create_generate_image_tool( + search_space_id: int, + db_session: AsyncSession, +): + """ + Factory function to create the generate_image tool. + + Args: + search_space_id: The search space ID (for config resolution) + db_session: Async database session + """ + + @tool + async def generate_image( + prompt: str, + n: int = 1, + ) -> dict[str, Any]: + """ + Generate an image from a text description using AI image models. + + Use this tool when the user asks you to create, generate, draw, or make an image. + The generated image will be displayed directly in the chat. + + Args: + prompt: A detailed text description of the image to generate. + Be specific about subject, style, colors, composition, and mood. + n: Number of images to generate (1-4). Default: 1 + + Returns: + A dictionary containing the generated image(s) for display in the chat. + """ + try: + # Resolve the image generation config from the search space preference + result = await db_session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if not search_space: + return {"error": "Search space not found"} + + config_id = ( + search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + ) + + # Build generation kwargs + # NOTE: size, quality, and style are intentionally NOT passed. + # Different models support different values for these params + # (e.g. DALL-E 3 wants "hd"/"standard" for quality while + # gpt-image-1 wants "high"/"medium"/"low"; size options also + # differ). Letting the model use its own defaults avoids errors. + gen_kwargs: dict[str, Any] = {} + if n is not None and n > 1: + gen_kwargs["n"] = n + + # Call litellm based on config type + if is_image_gen_auto_mode(config_id): + if not ImageGenRouterService.is_initialized(): + return { + "error": "No image generation models configured. " + "Please add an image model in Settings > Image Models." + } + response = await ImageGenRouterService.aimage_generation( + prompt=prompt, model="auto", **gen_kwargs + ) + elif config_id < 0: + cfg = _get_global_image_gen_config(config_id) + if not cfg: + return {"error": f"Image generation config {config_id} not found"} + + model_string = _build_model_string( + cfg.get("provider", ""), + cfg["model_name"], + cfg.get("custom_provider"), + ) + gen_kwargs["api_key"] = cfg.get("api_key") + if cfg.get("api_base"): + gen_kwargs["api_base"] = cfg["api_base"] + if cfg.get("api_version"): + gen_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + gen_kwargs.update(cfg["litellm_params"]) + + response = await aimage_generation( + prompt=prompt, model=model_string, **gen_kwargs + ) + else: + # Positive ID = user-created ImageGenerationConfig + cfg_result = await db_session.execute( + select(ImageGenerationConfig).filter( + ImageGenerationConfig.id == config_id + ) + ) + db_cfg = cfg_result.scalars().first() + if not db_cfg: + return {"error": f"Image generation config {config_id} not found"} + + model_string = _build_model_string( + db_cfg.provider.value, + db_cfg.model_name, + db_cfg.custom_provider, + ) + gen_kwargs["api_key"] = db_cfg.api_key + if db_cfg.api_base: + gen_kwargs["api_base"] = db_cfg.api_base + if db_cfg.api_version: + gen_kwargs["api_version"] = db_cfg.api_version + if db_cfg.litellm_params: + gen_kwargs.update(db_cfg.litellm_params) + + response = await aimage_generation( + prompt=prompt, model=model_string, **gen_kwargs + ) + + # Parse the response and store in DB + response_dict = ( + response.model_dump() + if hasattr(response, "model_dump") + else dict(response) + ) + + # Generate a random access token for this image + access_token = generate_image_token() + + # Save to image_generations table for history + db_image_gen = ImageGeneration( + prompt=prompt, + model=getattr(response, "_hidden_params", {}).get("model"), + n=n, + image_generation_config_id=config_id, + response_data=response_dict, + search_space_id=search_space_id, + access_token=access_token, + ) + db_session.add(db_image_gen) + await db_session.commit() + await db_session.refresh(db_image_gen) + + # Extract image URLs from response + images = response_dict.get("data", []) + if not images: + return {"error": "No images were generated"} + + first_image = images[0] + revised_prompt = first_image.get("revised_prompt", prompt) + + # Resolve image URL: + # - If the API returned a URL, use it directly. + # - If the API returned b64_json (e.g. gpt-image-1), serve the + # image through our backend endpoint to avoid bloating the + # LLM context with megabytes of base64 data. + if first_image.get("url"): + image_url = first_image["url"] + elif first_image.get("b64_json"): + backend_url = config.BACKEND_URL or "http://localhost:8000" + image_url = ( + f"{backend_url}/api/v1/image-generations/" + f"{db_image_gen.id}/image?token={access_token}" + ) + else: + return {"error": "No displayable image data in the response"} + + return { + "src": image_url, + "alt": revised_prompt or prompt, + "title": "Generated Image", + "description": revised_prompt if revised_prompt != prompt else None, + "generated": True, + "prompt": prompt, + "image_count": len(images), + } + + except Exception as e: + logger.exception("Image generation failed in tool") + return { + "error": f"Image generation failed: {e!s}", + "prompt": prompt, + } + + return generate_image diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index c65445419..2cf43c973 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -44,6 +44,7 @@ from typing import Any from langchain_core.tools import BaseTool from .display_image import create_display_image_tool +from .generate_image import create_generate_image_tool from .knowledge_base import create_search_knowledge_base_tool from .link_preview import create_link_preview_tool from .mcp_tool import load_mcp_tools @@ -125,6 +126,16 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ factory=lambda deps: create_display_image_tool(), requires=[], ), + # Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.) + ToolDefinition( + name="generate_image", + description="Generate images from text descriptions using AI image models", + factory=lambda deps: create_generate_image_tool( + search_space_id=deps["search_space_id"], + db_session=deps["db_session"], + ), + requires=["search_space_id", "db_session"], + ), # Web scraping tool - extracts content from webpages ToolDefinition( name="scrape_webpage", diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 01dd0da3d..0f619097e 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -9,9 +9,10 @@ from app.agents.new_chat.checkpointer import ( close_checkpointer, setup_checkpointer_tables, ) -from app.config import config, initialize_llm_router +from app.config import config, initialize_image_gen_router, initialize_llm_router from app.db import User, create_db_and_tables, get_async_session from app.routes import router as crud_router +from app.routes.auth_routes import router as auth_router from app.schemas import UserCreate, UserRead, UserUpdate from app.tasks.surfsense_docs_indexer import seed_surfsense_docs from app.users import SECRET, auth_backend, current_active_user, fastapi_users @@ -25,6 +26,8 @@ async def lifespan(app: FastAPI): await setup_checkpointer_tables() # Initialize LLM Router for Auto mode load balancing initialize_llm_router() + # Initialize Image Generation Router for Auto mode load balancing + initialize_image_gen_router() # Seed Surfsense documentation await seed_surfsense_docs() yield @@ -111,6 +114,9 @@ app.include_router( tags=["users"], ) +# Include custom auth routes (refresh token, logout) +app.include_router(auth_router) + if config.AUTH_TYPE == "GOOGLE": from fastapi.responses import RedirectResponse diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 74b21fbf0..af406eab7 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -13,14 +13,15 @@ load_dotenv() @worker_process_init.connect def init_worker(**kwargs): - """Initialize the LLM Router when a Celery worker process starts. + """Initialize the LLM Router and Image Gen Router when a Celery worker process starts. This ensures the Auto mode (LiteLLM Router) is available for background tasks - like document summarization. + like document summarization and image generation. """ - from app.config import initialize_llm_router + from app.config import initialize_image_gen_router, initialize_llm_router initialize_llm_router() + initialize_image_gen_router() # Get Celery configuration from environment diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 149fedd39..bb299e583 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -81,6 +81,56 @@ def load_router_settings(): return default_settings +def load_global_image_gen_configs(): + """ + Load global image generation configurations from YAML file. + + Returns: + list: List of global image generation config dictionaries, or empty list + """ + global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml" + + if not global_config_file.exists(): + return [] + + try: + with open(global_config_file, encoding="utf-8") as f: + data = yaml.safe_load(f) + return data.get("global_image_generation_configs", []) + except Exception as e: + print(f"Warning: Failed to load global image generation configs: {e}") + return [] + + +def load_image_gen_router_settings(): + """ + Load router settings for image generation Auto mode from YAML file. + + Returns: + dict: Router settings dictionary + """ + default_settings = { + "routing_strategy": "usage-based-routing", + "num_retries": 3, + "allowed_fails": 3, + "cooldown_time": 60, + } + + global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml" + + if not global_config_file.exists(): + return default_settings + + try: + with open(global_config_file, encoding="utf-8") as f: + data = yaml.safe_load(f) + settings = data.get("image_generation_router_settings", {}) + return {**default_settings, **settings} + except Exception as e: + print(f"Warning: Failed to load image generation router settings: {e}") + return default_settings + + def initialize_llm_router(): """ Initialize the LLM Router service for Auto mode. @@ -105,6 +155,33 @@ def initialize_llm_router(): print(f"Warning: Failed to initialize LLM Router: {e}") +def initialize_image_gen_router(): + """ + Initialize the Image Generation Router service for Auto mode. + This should be called during application startup. + """ + image_gen_configs = load_global_image_gen_configs() + router_settings = load_image_gen_router_settings() + + if not image_gen_configs: + print( + "Info: No global image generation configs found, " + "Image Generation Auto mode will not be available" + ) + return + + try: + from app.services.image_gen_router_service import ImageGenRouterService + + ImageGenRouterService.initialize(image_gen_configs, router_settings) + print( + f"Info: Image Generation Router initialized with {len(image_gen_configs)} models " + f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})" + ) + except Exception as e: + print(f"Warning: Failed to initialize Image Generation Router: {e}") + + class Config: # Check if ffmpeg is installed if not is_ffmpeg_installed(): @@ -216,6 +293,12 @@ class Config: # Router settings for Auto mode (LiteLLM Router load balancing) ROUTER_SETTINGS = load_router_settings() + # Global Image Generation Configurations (optional) + GLOBAL_IMAGE_GEN_CONFIGS = load_global_image_gen_configs() + + # Router settings for Image Generation Auto mode + IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings() + # Chonkie Configuration | Edit this to your needs EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL") # Azure OpenAI credentials from environment variables @@ -255,6 +338,14 @@ class Config: # OAuth JWT SECRET_KEY = os.getenv("SECRET_KEY") + # JWT Token Lifetimes + ACCESS_TOKEN_LIFETIME_SECONDS = int( + os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(24 * 60 * 60)) # 1 day + ) + REFRESH_TOKEN_LIFETIME_SECONDS = int( + os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks + ) + # ETL Service ETL_SERVICE = os.getenv("ETL_SERVICE") diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 9b213aafe..0bb00c398 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -183,6 +183,69 @@ global_llm_configs: use_default_system_instructions: true citations_enabled: true +# ============================================================================= +# Image Generation Configuration +# ============================================================================= +# These configurations power the image generation feature using litellm.aimage_generation(). +# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock, +# Recraft, OpenRouter, Xinference, Nscale +# +# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs. + +# Router Settings for Image Generation Auto Mode +image_generation_router_settings: + routing_strategy: "usage-based-routing" + num_retries: 3 + allowed_fails: 3 + cooldown_time: 60 + +global_image_generation_configs: + # Example: OpenAI DALL-E 3 + - id: -1 + name: "Global DALL-E 3" + description: "OpenAI's DALL-E 3 for high-quality image generation" + provider: "OPENAI" + model_name: "dall-e-3" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens) + litellm_params: {} + + # Example: OpenAI GPT Image 1 + - id: -2 + name: "Global GPT Image 1" + description: "OpenAI's GPT Image 1 model" + provider: "OPENAI" + model_name: "gpt-image-1" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 50 + litellm_params: {} + + # Example: Azure OpenAI DALL-E 3 + - id: -3 + name: "Global Azure DALL-E 3" + description: "Azure-hosted DALL-E 3 deployment" + provider: "AZURE_OPENAI" + model_name: "azure/dall-e-3-deployment" + api_key: "your-azure-api-key-here" + api_base: "https://your-resource.openai.azure.com" + api_version: "2024-02-15-preview" + rpm: 50 + litellm_params: + base_model: "dall-e-3" + + # Example: OpenRouter Gemini Image Generation + # - id: -4 + # name: "Global Gemini Image Gen" + # description: "Google Gemini image generation via OpenRouter" + # provider: "OPENROUTER" + # model_name: "google/gemini-2.5-flash-image" + # api_key: "your-openrouter-api-key-here" + # api_base: "" + # rpm: 30 + # litellm_params: {} + # Notes: # - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing # - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB) @@ -195,10 +258,11 @@ global_llm_configs: # - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute) # These help the router distribute load evenly and avoid rate limit errors # -# AZURE-SPECIFIC NOTES: -# - Always add 'base_model' in litellm_params for Azure deployments -# - This fixes "Could not identify azure model 'X'" warnings -# - base_model should match the underlying OpenAI model (e.g., gpt-4o, gpt-4-turbo, gpt-3.5-turbo) -# - model_name format: "azure/" -# - api_version: Use a recent Azure API version (e.g., "2024-02-15-preview") -# - See: https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models +# +# IMAGE GENERATION NOTES: +# - Image generation configs use the same ID scheme as LLM configs (negative for global) +# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure), +# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter) +# - The router uses litellm.aimage_generation() for async image generation +# - Only RPM (requests per minute) is relevant for image generation rate limiting. +# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token. diff --git a/surfsense_backend/app/connectors/airtable_history.py b/surfsense_backend/app/connectors/airtable_history.py index 64f6465fe..49c2fcbdd 100644 --- a/surfsense_backend/app/connectors/airtable_history.py +++ b/surfsense_backend/app/connectors/airtable_history.py @@ -71,6 +71,14 @@ class AirtableHistoryConnector: config_data = connector.config.copy() + # Check if access_token exists before processing + raw_access_token = config_data.get("access_token") + if not raw_access_token: + raise ValueError( + "Airtable access token not found. " + "Please reconnect your Airtable account." + ) + # Decrypt credentials if they are encrypted token_encrypted = config_data.get("_token_encrypted", False) if token_encrypted and config.SECRET_KEY: @@ -98,6 +106,16 @@ class AirtableHistoryConnector: f"Failed to decrypt Airtable credentials: {e!s}" ) from e + # Final validation after decryption + final_token = config_data.get("access_token") + if not final_token or ( + isinstance(final_token, str) and not final_token.strip() + ): + raise ValueError( + "Airtable access token is invalid or empty. " + "Please reconnect your Airtable account." + ) + try: self._credentials = AirtableAuthCredentialsBase.from_dict(config_data) except Exception as e: diff --git a/surfsense_backend/app/connectors/confluence_history.py b/surfsense_backend/app/connectors/confluence_history.py index 9e10ffcf1..5d19edc54 100644 --- a/surfsense_backend/app/connectors/confluence_history.py +++ b/surfsense_backend/app/connectors/confluence_history.py @@ -87,6 +87,14 @@ class ConfluenceHistoryConnector: if is_oauth: # OAuth 2.0 authentication + # Check if access_token exists before processing + raw_access_token = config_data.get("access_token") + if not raw_access_token: + raise ValueError( + "Confluence access token not found. " + "Please reconnect your Confluence account." + ) + # Decrypt credentials if they are encrypted token_encrypted = config_data.get("_token_encrypted", False) if token_encrypted and config.SECRET_KEY: @@ -118,6 +126,16 @@ class ConfluenceHistoryConnector: f"Failed to decrypt Confluence credentials: {e!s}" ) from e + # Final validation after decryption + final_token = config_data.get("access_token") + if not final_token or ( + isinstance(final_token, str) and not final_token.strip() + ): + raise ValueError( + "Confluence access token is invalid or empty. " + "Please reconnect your Confluence account." + ) + try: self._credentials = AtlassianAuthCredentialsBase.from_dict( config_data diff --git a/surfsense_backend/app/connectors/jira_history.py b/surfsense_backend/app/connectors/jira_history.py index 6e04ec2a4..e9f28a2c4 100644 --- a/surfsense_backend/app/connectors/jira_history.py +++ b/surfsense_backend/app/connectors/jira_history.py @@ -86,6 +86,14 @@ class JiraHistoryConnector: if is_oauth: # OAuth 2.0 authentication + # Check if access_token exists before processing + raw_access_token = config_data.get("access_token") + if not raw_access_token: + raise ValueError( + "Jira access token not found. " + "Please reconnect your Jira account." + ) + if not config.SECRET_KEY: raise ValueError( "SECRET_KEY not configured but tokens are marked as encrypted" @@ -119,6 +127,16 @@ class JiraHistoryConnector: f"Failed to decrypt Jira credentials: {e!s}" ) from e + # Final validation after decryption + final_token = config_data.get("access_token") + if not final_token or ( + isinstance(final_token, str) and not final_token.strip() + ): + raise ValueError( + "Jira access token is invalid or empty. " + "Please reconnect your Jira account." + ) + try: self._credentials = AtlassianAuthCredentialsBase.from_dict( config_data diff --git a/surfsense_backend/app/connectors/linear_connector.py b/surfsense_backend/app/connectors/linear_connector.py index b8206a40d..534d70b89 100644 --- a/surfsense_backend/app/connectors/linear_connector.py +++ b/surfsense_backend/app/connectors/linear_connector.py @@ -116,6 +116,14 @@ class LinearConnector: config_data = connector.config.copy() + # Check if access_token exists before processing + raw_access_token = config_data.get("access_token") + if not raw_access_token: + raise ValueError( + "Linear access token not found. " + "Please reconnect your Linear account." + ) + # Decrypt credentials if they are encrypted token_encrypted = config_data.get("_token_encrypted", False) if token_encrypted and config.SECRET_KEY: @@ -143,6 +151,16 @@ class LinearConnector: f"Failed to decrypt Linear credentials: {e!s}" ) from e + # Final validation after decryption + final_token = config_data.get("access_token") + if not final_token or ( + isinstance(final_token, str) and not final_token.strip() + ): + raise ValueError( + "Linear access token is invalid or empty. " + "Please reconnect your Linear account." + ) + try: self._credentials = LinearAuthCredentialsBase.from_dict(config_data) except Exception as e: diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 5cdb712db..a82c18470 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -137,6 +137,24 @@ class LiteLLMProvider(str, Enum): CUSTOM = "CUSTOM" +class ImageGenProvider(str, Enum): + """ + Enum for image generation providers supported by LiteLLM. + This is a subset of LLM providers — only those that support image generation. + See: https://docs.litellm.ai/docs/image_generation#supported-providers + """ + + OPENAI = "OPENAI" + AZURE_OPENAI = "AZURE_OPENAI" + GOOGLE = "GOOGLE" # Google AI Studio + VERTEX_AI = "VERTEX_AI" + BEDROCK = "BEDROCK" # AWS Bedrock + RECRAFT = "RECRAFT" + OPENROUTER = "OPENROUTER" + XINFERENCE = "XINFERENCE" + NSCALE = "NSCALE" + + class LogLevel(str, Enum): DEBUG = "DEBUG" INFO = "INFO" @@ -237,6 +255,11 @@ class Permission(str, Enum): PODCASTS_UPDATE = "podcasts:update" PODCASTS_DELETE = "podcasts:delete" + # Image Generations + IMAGE_GENERATIONS_CREATE = "image_generations:create" + IMAGE_GENERATIONS_READ = "image_generations:read" + IMAGE_GENERATIONS_DELETE = "image_generations:delete" + # Connectors CONNECTORS_CREATE = "connectors:create" CONNECTORS_READ = "connectors:read" @@ -298,6 +321,9 @@ DEFAULT_ROLE_PERMISSIONS = { Permission.PODCASTS_CREATE.value, Permission.PODCASTS_READ.value, Permission.PODCASTS_UPDATE.value, + # Image Generations (create and read, no delete) + Permission.IMAGE_GENERATIONS_CREATE.value, + Permission.IMAGE_GENERATIONS_READ.value, # Connectors (no delete) Permission.CONNECTORS_CREATE.value, Permission.CONNECTORS_READ.value, @@ -327,6 +353,8 @@ DEFAULT_ROLE_PERMISSIONS = { Permission.LLM_CONFIGS_READ.value, # Podcasts (read only) Permission.PODCASTS_READ.value, + # Image Generations (read only) + Permission.IMAGE_GENERATIONS_READ.value, # Connectors (read only) Permission.CONNECTORS_READ.value, # Logs (read only) @@ -881,6 +909,103 @@ class Podcast(BaseModel, TimestampMixin): thread = relationship("NewChatThread") +class ImageGenerationConfig(BaseModel, TimestampMixin): + """ + Dedicated configuration table for image generation models. + + Separate from NewLLMConfig because image generation models don't need + system_instructions, citations_enabled, or use_default_system_instructions. + They only need provider credentials and model parameters. + """ + + __tablename__ = "image_generation_configs" + + name = Column(String(100), nullable=False, index=True) + description = Column(String(500), nullable=True) + + # Provider & model (uses ImageGenProvider, NOT LiteLLMProvider) + provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False) + custom_provider = Column(String(100), nullable=True) + model_name = Column(String(100), nullable=False) + + # Credentials + api_key = Column(String, nullable=False) + api_base = Column(String(500), nullable=True) + api_version = Column(String(50), nullable=True) # Azure-specific + + # Additional litellm parameters + litellm_params = Column(JSON, nullable=True, default={}) + + # Relationships + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship( + "SearchSpace", back_populates="image_generation_configs" + ) + + +class ImageGeneration(BaseModel, TimestampMixin): + """ + Stores image generation requests and results using litellm.aimage_generation(). + + Since aimage_generation is a single async call (not a background job), + there is no status enum. A row with response_data means success; + a row with error_message means failure. + + Response data is stored as JSONB matching the litellm output format: + { + "created": int, + "data": [{"b64_json": str|None, "revised_prompt": str|None, "url": str|None}], + "usage": {"prompt_tokens": int, "completion_tokens": int, "total_tokens": int} + } + """ + + __tablename__ = "image_generations" + + # Request parameters (matching litellm.aimage_generation() params) + prompt = Column(Text, nullable=False) + model = Column(String(200), nullable=True) # e.g., "dall-e-3", "gpt-image-1" + n = Column(Integer, nullable=True, default=1) + quality = Column( + String(50), nullable=True + ) # "auto", "high", "medium", "low", "hd", "standard" + size = Column( + String(50), nullable=True + ) # "1024x1024", "1536x1024", "1024x1536", etc. + style = Column(String(50), nullable=True) # Model-specific style parameter + response_format = Column(String(50), nullable=True) # "url" or "b64_json" + + # Image generation config reference + # 0 = Auto mode (router), negative IDs = global configs from YAML, + # positive IDs = ImageGenerationConfig records in DB + image_generation_config_id = Column(Integer, nullable=True) + + # Response data (full litellm response as JSONB) — present on success + response_data = Column(JSONB, nullable=True) + # Error message — present on failure + error_message = Column(Text, nullable=True) + + # Signed access token for serving images via tags. + # Stored in DB so it survives SECRET_KEY rotation. + access_token = Column(String(64), nullable=True, index=True) + + # Foreign keys + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + created_by_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + + # Relationships + search_space = relationship("SearchSpace", back_populates="image_generations") + created_by = relationship("User", back_populates="image_generations") + + class SearchSpace(BaseModel, TimestampMixin): __tablename__ = "searchspaces" @@ -905,6 +1030,9 @@ class SearchSpace(BaseModel, TimestampMixin): document_summary_llm_id = Column( Integer, nullable=True, default=0 ) # For document summarization, defaults to Auto mode + image_generation_config_id = Column( + Integer, nullable=True, default=0 + ) # For image generation, defaults to Auto mode user_id = Column( UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False @@ -929,6 +1057,12 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="Podcast.id.desc()", cascade="all, delete-orphan", ) + image_generations = relationship( + "ImageGeneration", + back_populates="search_space", + order_by="ImageGeneration.id.desc()", + cascade="all, delete-orphan", + ) logs = relationship( "Log", back_populates="search_space", @@ -953,6 +1087,12 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="NewLLMConfig.id", cascade="all, delete-orphan", ) + image_generation_configs = relationship( + "ImageGenerationConfig", + back_populates="search_space", + order_by="ImageGenerationConfig.id", + cascade="all, delete-orphan", + ) # RBAC relationships roles = relationship( @@ -1333,6 +1473,13 @@ if config.AUTH_TYPE == "GOOGLE": passive_deletes=True, ) + # Image generations created by this user + image_generations = relationship( + "ImageGeneration", + back_populates="created_by", + passive_deletes=True, + ) + # User memories for personalized AI responses memories = relationship( "UserMemory", @@ -1361,6 +1508,13 @@ if config.AUTH_TYPE == "GOOGLE": display_name = Column(String, nullable=True) avatar_url = Column(String, nullable=True) + # Refresh tokens for this user + refresh_tokens = relationship( + "RefreshToken", + back_populates="user", + cascade="all, delete-orphan", + ) + else: class User(SQLAlchemyBaseUserTableUUID, Base): @@ -1398,6 +1552,13 @@ else: passive_deletes=True, ) + # Image generations created by this user + image_generations = relationship( + "ImageGeneration", + back_populates="created_by", + passive_deletes=True, + ) + # User memories for personalized AI responses memories = relationship( "UserMemory", @@ -1426,6 +1587,43 @@ else: display_name = Column(String, nullable=True) avatar_url = Column(String, nullable=True) + # Refresh tokens for this user + refresh_tokens = relationship( + "RefreshToken", + back_populates="user", + cascade="all, delete-orphan", + ) + + +class RefreshToken(Base, TimestampMixin): + """ + Stores refresh tokens for user session management. + Each row represents one device/session. + """ + + __tablename__ = "refresh_tokens" + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + user = relationship("User", back_populates="refresh_tokens") + token_hash = Column(String(256), unique=True, nullable=False, index=True) + expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True) + is_revoked = Column(Boolean, default=False, nullable=False) + family_id = Column(UUID(as_uuid=True), nullable=False, index=True) + + @property + def is_expired(self) -> bool: + return datetime.now(UTC) >= self.expires_at + + @property + def is_valid(self) -> bool: + return not self.is_expired and not self.is_revoked + engine = create_async_engine(DATABASE_URL) async_session_maker = async_sessionmaker(engine, expire_on_commit=False) diff --git a/surfsense_backend/app/prompts/__init__.py b/surfsense_backend/app/prompts/__init__.py index 3b21cb9e1..efa31d612 100644 --- a/surfsense_backend/app/prompts/__init__.py +++ b/surfsense_backend/app/prompts/__init__.py @@ -104,3 +104,33 @@ SUMMARY_PROMPT = ( SUMMARY_PROMPT_TEMPLATE = PromptTemplate( input_variables=["document"], template=SUMMARY_PROMPT ) + +# ============================================================================= +# Chat Title Generation Prompt +# ============================================================================= + +TITLE_GENERATION_PROMPT = """Generate a concise, descriptive title for the following conversation. + + +- The title MUST be between 1 and 6 words +- The title MUST be on a single line +- Capture the main topic or intent of the conversation +- Do NOT use quotes, punctuation, or formatting +- Do NOT include words like "Chat about" or "Discussion of" +- Return ONLY the title, nothing else + + + +{user_query} + + + +{assistant_response} + + +Title:""" + +TITLE_GENERATION_PROMPT_TEMPLATE = PromptTemplate( + input_variables=["user_query", "assistant_response"], + template=TITLE_GENERATION_PROMPT, +) diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 746c18c6d..d9353284c 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -20,6 +20,7 @@ from .google_drive_add_connector_route import ( from .google_gmail_add_connector_route import ( router as google_gmail_add_connector_router, ) +from .image_generation_routes import router as image_generation_router from .incentive_tasks_routes import router as incentive_tasks_router from .jira_add_connector_route import router as jira_add_connector_router from .linear_add_connector_route import router as linear_add_connector_router @@ -49,6 +50,7 @@ router.include_router(notes_router) router.include_router(new_chat_router) # Chat with assistant-ui persistence router.include_router(chat_comments_router) router.include_router(podcasts_router) # Podcast task status and audio +router.include_router(image_generation_router) # Image generation via litellm router.include_router(search_source_connectors_router) router.include_router(google_calendar_add_connector_router) router.include_router(google_gmail_add_connector_router) diff --git a/surfsense_backend/app/routes/auth_routes.py b/surfsense_backend/app/routes/auth_routes.py new file mode 100644 index 000000000..b1cbaf2a5 --- /dev/null +++ b/surfsense_backend/app/routes/auth_routes.py @@ -0,0 +1,93 @@ +"""Authentication routes for refresh token management.""" + +import logging + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import select + +from app.db import User, async_session_maker +from app.schemas.auth import ( + LogoutAllResponse, + LogoutRequest, + LogoutResponse, + RefreshTokenRequest, + RefreshTokenResponse, +) +from app.users import current_active_user, get_jwt_strategy +from app.utils.refresh_tokens import ( + revoke_all_user_tokens, + revoke_refresh_token, + rotate_refresh_token, + validate_refresh_token, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/auth/jwt", tags=["auth"]) + + +@router.post("/refresh", response_model=RefreshTokenResponse) +async def refresh_access_token(request: RefreshTokenRequest): + """ + Exchange a valid refresh token for a new access token and refresh token. + Implements token rotation for security. + """ + token_record = await validate_refresh_token(request.refresh_token) + + if not token_record: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired refresh token", + ) + + # Get user from token record + async with async_session_maker() as session: + result = await session.execute( + select(User).where(User.id == token_record.user_id) + ) + user = result.scalars().first() + + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found", + ) + + # Generate new access token + strategy = get_jwt_strategy() + access_token = await strategy.write_token(user) + + # Rotate refresh token + new_refresh_token = await rotate_refresh_token(token_record) + + logger.info(f"Refreshed token for user {user.id}") + + return RefreshTokenResponse( + access_token=access_token, + refresh_token=new_refresh_token, + ) + + +@router.post("/revoke", response_model=LogoutResponse) +async def revoke_token(request: LogoutRequest): + """ + Logout current device by revoking the provided refresh token. + Does not require authentication - just the refresh token. + """ + revoked = await revoke_refresh_token(request.refresh_token) + if revoked: + logger.info("User logged out from current device - token revoked") + else: + logger.warning("Logout called but no matching token found to revoke") + return LogoutResponse() + + +@router.post("/logout-all", response_model=LogoutAllResponse) +async def logout_all_devices(user: User = Depends(current_active_user)): + """ + Logout from all devices by revoking all refresh tokens for the user. + Requires valid access token. + """ + await revoke_all_user_tokens(user.id) + logger.info(f"User {user.id} logged out from all devices") + return LogoutAllResponse() diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py new file mode 100644 index 000000000..9406867c6 --- /dev/null +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -0,0 +1,710 @@ +""" +Image Generation routes: +- CRUD for ImageGenerationConfig (user-created image model configs) +- Global image gen configs endpoint (from YAML) +- Image generation execution (calls litellm.aimage_generation()) +- CRUD for ImageGeneration records (results) +- Image serving endpoint (serves b64_json images from DB, protected by signed tokens) +""" + +import base64 +import logging + +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import Response +from litellm import aimage_generation +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import ( + ImageGeneration, + ImageGenerationConfig, + Permission, + SearchSpace, + SearchSpaceMembership, + User, + get_async_session, +) +from app.schemas import ( + GlobalImageGenConfigRead, + ImageGenerationConfigCreate, + ImageGenerationConfigRead, + ImageGenerationConfigUpdate, + ImageGenerationCreate, + ImageGenerationListRead, + ImageGenerationRead, +) +from app.services.image_gen_router_service import ( + IMAGE_GEN_AUTO_MODE_ID, + ImageGenRouterService, + is_image_gen_auto_mode, +) +from app.users import current_active_user +from app.utils.rbac import check_permission +from app.utils.signed_image_urls import verify_image_token + +router = APIRouter() +logger = logging.getLogger(__name__) + +# Provider mapping for building litellm model strings. +# Only includes providers that support image generation. +# See: https://docs.litellm.ai/docs/image_generation#supported-providers +_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", # Google AI Studio + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", # AWS Bedrock + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} + + +def _get_global_image_gen_config(config_id: int) -> dict | None: + """Get a global image generation configuration by ID (negative IDs).""" + if config_id == IMAGE_GEN_AUTO_MODE_ID: + return { + "id": IMAGE_GEN_AUTO_MODE_ID, + "name": "Auto (Load Balanced)", + "provider": "AUTO", + "model_name": "auto", + "is_auto_mode": True, + } + if config_id > 0: + return None + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return cfg + return None + + +def _build_model_string( + provider: str, model_name: str, custom_provider: str | None +) -> str: + """Build a litellm model string from provider + model_name.""" + if custom_provider: + return f"{custom_provider}/{model_name}" + prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) + return f"{prefix}/{model_name}" + + +async def _execute_image_generation( + session: AsyncSession, + image_gen: ImageGeneration, + search_space: SearchSpace, +) -> None: + """ + Call litellm.aimage_generation() with the appropriate config. + + Resolution order: + 1. Explicit image_generation_config_id on the request + 2. Search space's image_generation_config_id preference + 3. Falls back to Auto mode if available + """ + config_id = image_gen.image_generation_config_id + if config_id is None: + config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + image_gen.image_generation_config_id = config_id + + # Build kwargs + gen_kwargs = {} + if image_gen.n is not None: + gen_kwargs["n"] = image_gen.n + if image_gen.quality is not None: + gen_kwargs["quality"] = image_gen.quality + if image_gen.size is not None: + gen_kwargs["size"] = image_gen.size + if image_gen.style is not None: + gen_kwargs["style"] = image_gen.style + if image_gen.response_format is not None: + gen_kwargs["response_format"] = image_gen.response_format + + if is_image_gen_auto_mode(config_id): + if not ImageGenRouterService.is_initialized(): + raise ValueError( + "Auto mode requested but Image Generation Router not initialized. " + "Ensure global_llm_config.yaml has global_image_generation_configs." + ) + response = await ImageGenRouterService.aimage_generation( + prompt=image_gen.prompt, model="auto", **gen_kwargs + ) + elif config_id < 0: + # Global config from YAML + cfg = _get_global_image_gen_config(config_id) + if not cfg: + raise ValueError(f"Global image generation config {config_id} not found") + + model_string = _build_model_string( + cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider") + ) + gen_kwargs["api_key"] = cfg.get("api_key") + if cfg.get("api_base"): + gen_kwargs["api_base"] = cfg["api_base"] + if cfg.get("api_version"): + gen_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + gen_kwargs.update(cfg["litellm_params"]) + + # User model override + if image_gen.model: + model_string = image_gen.model + + response = await aimage_generation( + prompt=image_gen.prompt, model=model_string, **gen_kwargs + ) + else: + # Positive ID = DB ImageGenerationConfig + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_cfg = result.scalars().first() + if not db_cfg: + raise ValueError(f"Image generation config {config_id} not found") + + model_string = _build_model_string( + db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider + ) + gen_kwargs["api_key"] = db_cfg.api_key + if db_cfg.api_base: + gen_kwargs["api_base"] = db_cfg.api_base + if db_cfg.api_version: + gen_kwargs["api_version"] = db_cfg.api_version + if db_cfg.litellm_params: + gen_kwargs.update(db_cfg.litellm_params) + + # User model override + if image_gen.model: + model_string = image_gen.model + + response = await aimage_generation( + prompt=image_gen.prompt, model=model_string, **gen_kwargs + ) + + # Store response + image_gen.response_data = ( + response.model_dump() if hasattr(response, "model_dump") else dict(response) + ) + if not image_gen.model and hasattr(response, "_hidden_params"): + hidden = response._hidden_params + if isinstance(hidden, dict) and hidden.get("model"): + image_gen.model = hidden["model"] + + +# ============================================================================= +# Global Image Generation Configs (from YAML) +# ============================================================================= + + +@router.get( + "/global-image-generation-configs", + response_model=list[GlobalImageGenConfigRead], +) +async def get_global_image_gen_configs( + user: User = Depends(current_active_user), +): + """Get all global image generation configs. API keys are hidden.""" + try: + global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS + safe_configs = [] + + if global_configs and len(global_configs) > 0: + safe_configs.append( + { + "id": 0, + "name": "Auto (Load Balanced)", + "description": "Automatically routes across available image generation providers.", + "provider": "AUTO", + "custom_provider": None, + "model_name": "auto", + "api_base": None, + "api_version": None, + "litellm_params": {}, + "is_global": True, + "is_auto_mode": True, + } + ) + + for cfg in global_configs: + safe_configs.append( + { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + } + ) + + return safe_configs + except Exception as e: + logger.exception("Failed to fetch global image generation configs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e + + +# ============================================================================= +# ImageGenerationConfig CRUD +# ============================================================================= + + +@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead) +async def create_image_gen_config( + config_data: ImageGenerationConfigCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Create a new image generation config for a search space.""" + try: + await check_permission( + session, + user, + config_data.search_space_id, + Permission.IMAGE_GENERATIONS_CREATE.value, + "You don't have permission to create image generation configs in this search space", + ) + + db_config = ImageGenerationConfig(**config_data.model_dump()) + session.add(db_config) + await session.commit() + await session.refresh(db_config) + return db_config + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to create ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to create config: {e!s}" + ) from e + + +@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead]) +async def list_image_gen_configs( + search_space_id: int, + skip: int = 0, + limit: int = 100, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """List image generation configs for a search space.""" + try: + await check_permission( + session, + user, + search_space_id, + Permission.IMAGE_GENERATIONS_READ.value, + "You don't have permission to view image generation configs in this search space", + ) + + result = await session.execute( + select(ImageGenerationConfig) + .filter(ImageGenerationConfig.search_space_id == search_space_id) + .order_by(ImageGenerationConfig.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to list ImageGenerationConfigs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e + + +@router.get( + "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead +) +async def get_image_gen_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Get a specific image generation config by ID.""" + try: + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.IMAGE_GENERATIONS_READ.value, + "You don't have permission to view image generation configs in this search space", + ) + return db_config + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to get ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to fetch config: {e!s}" + ) from e + + +@router.put( + "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead +) +async def update_image_gen_config( + config_id: int, + update_data: ImageGenerationConfigUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Update an existing image generation config.""" + try: + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.IMAGE_GENERATIONS_CREATE.value, + "You don't have permission to update image generation configs in this search space", + ) + + for key, value in update_data.model_dump(exclude_unset=True).items(): + setattr(db_config, key, value) + + await session.commit() + await session.refresh(db_config) + return db_config + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to update ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to update config: {e!s}" + ) from e + + +@router.delete("/image-generation-configs/{config_id}", response_model=dict) +async def delete_image_gen_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Delete an image generation config.""" + try: + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.IMAGE_GENERATIONS_DELETE.value, + "You don't have permission to delete image generation configs in this search space", + ) + + await session.delete(db_config) + await session.commit() + return { + "message": "Image generation config deleted successfully", + "id": config_id, + } + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to delete ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to delete config: {e!s}" + ) from e + + +# ============================================================================= +# Image Generation Execution + Results CRUD +# ============================================================================= + + +@router.post("/image-generations", response_model=ImageGenerationRead) +async def create_image_generation( + data: ImageGenerationCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Create and execute an image generation request.""" + try: + await check_permission( + session, + user, + data.search_space_id, + Permission.IMAGE_GENERATIONS_CREATE.value, + "You don't have permission to create image generations in this search space", + ) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == data.search_space_id) + ) + search_space = result.scalars().first() + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + + db_image_gen = ImageGeneration( + prompt=data.prompt, + model=data.model, + n=data.n, + quality=data.quality, + size=data.size, + style=data.style, + response_format=data.response_format, + image_generation_config_id=data.image_generation_config_id, + search_space_id=data.search_space_id, + created_by_id=user.id, + ) + session.add(db_image_gen) + await session.flush() + + try: + await _execute_image_generation(session, db_image_gen, search_space) + except Exception as e: + logger.exception("Image generation call failed") + db_image_gen.error_message = str(e) + + await session.commit() + await session.refresh(db_image_gen) + return db_image_gen + + except HTTPException: + raise + except SQLAlchemyError: + await session.rollback() + raise HTTPException( + status_code=500, detail="Database error during image generation" + ) from None + except Exception as e: + await session.rollback() + logger.exception("Failed to create image generation") + raise HTTPException( + status_code=500, detail=f"Image generation failed: {e!s}" + ) from e + + +@router.get("/image-generations", response_model=list[ImageGenerationListRead]) +async def list_image_generations( + search_space_id: int | None = None, + skip: int = 0, + limit: int = 50, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """List image generations.""" + if skip < 0 or limit < 1: + raise HTTPException(status_code=400, detail="Invalid pagination parameters") + if limit > 100: + limit = 100 + + try: + if search_space_id is not None: + await check_permission( + session, + user, + search_space_id, + Permission.IMAGE_GENERATIONS_READ.value, + "You don't have permission to read image generations in this search space", + ) + result = await session.execute( + select(ImageGeneration) + .filter(ImageGeneration.search_space_id == search_space_id) + .order_by(ImageGeneration.created_at.desc()) + .offset(skip) + .limit(limit) + ) + else: + result = await session.execute( + select(ImageGeneration) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + .order_by(ImageGeneration.created_at.desc()) + .offset(skip) + .limit(limit) + ) + + return [ + ImageGenerationListRead.from_orm_with_count(img) + for img in result.scalars().all() + ] + + except HTTPException: + raise + except SQLAlchemyError: + raise HTTPException( + status_code=500, detail="Database error fetching image generations" + ) from None + + +@router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead) +async def get_image_generation( + image_gen_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Get a specific image generation by ID.""" + try: + result = await session.execute( + select(ImageGeneration).filter(ImageGeneration.id == image_gen_id) + ) + image_gen = result.scalars().first() + if not image_gen: + raise HTTPException(status_code=404, detail="Image generation not found") + + await check_permission( + session, + user, + image_gen.search_space_id, + Permission.IMAGE_GENERATIONS_READ.value, + "You don't have permission to read image generations in this search space", + ) + return image_gen + + except HTTPException: + raise + except SQLAlchemyError: + raise HTTPException( + status_code=500, detail="Database error fetching image generation" + ) from None + + +@router.delete("/image-generations/{image_gen_id}", response_model=dict) +async def delete_image_generation( + image_gen_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Delete an image generation record.""" + try: + result = await session.execute( + select(ImageGeneration).filter(ImageGeneration.id == image_gen_id) + ) + db_image_gen = result.scalars().first() + if not db_image_gen: + raise HTTPException(status_code=404, detail="Image generation not found") + + await check_permission( + session, + user, + db_image_gen.search_space_id, + Permission.IMAGE_GENERATIONS_DELETE.value, + "You don't have permission to delete image generations in this search space", + ) + + await session.delete(db_image_gen) + await session.commit() + return {"message": "Image generation deleted successfully"} + + except HTTPException: + raise + except SQLAlchemyError: + await session.rollback() + raise HTTPException( + status_code=500, detail="Database error deleting image generation" + ) from None + + +# ============================================================================= +# Image Serving (serves generated images from DB, protected by signed tokens) +# ============================================================================= + + +@router.get("/image-generations/{image_gen_id}/image") +async def serve_generated_image( + image_gen_id: int, + token: str = Query(..., description="Signed access token"), + index: int = 0, + session: AsyncSession = Depends(get_async_session), +): + """ + Serve a generated image by ID, protected by a signed token. + + The token is generated when the image URL is created by the generate_image + tool and encodes the image_gen_id, search_space_id, and an expiry timestamp. + This ensures only users with access to the search space can view images, + without requiring auth headers (which tags cannot pass). + + Args: + image_gen_id: The image generation record ID + token: HMAC-signed access token (included as query parameter) + index: Which image to serve if multiple were generated (default: 0) + """ + try: + result = await session.execute( + select(ImageGeneration).filter(ImageGeneration.id == image_gen_id) + ) + image_gen = result.scalars().first() + if not image_gen: + raise HTTPException(status_code=404, detail="Image generation not found") + + # Verify the access token against the one stored on the record + if not verify_image_token(image_gen.access_token, token): + raise HTTPException(status_code=403, detail="Invalid image access token") + + if not image_gen.response_data: + raise HTTPException(status_code=404, detail="No image data available") + + images = image_gen.response_data.get("data", []) + if not images or index >= len(images): + raise HTTPException( + status_code=404, detail="Image not found at the specified index" + ) + + image_entry = images[index] + + # If there's a URL, redirect to it + if image_entry.get("url"): + from fastapi.responses import RedirectResponse + + return RedirectResponse(url=image_entry["url"]) + + # If there's b64_json data, decode and serve it + if image_entry.get("b64_json"): + image_bytes = base64.b64decode(image_entry["b64_json"]) + return Response( + content=image_bytes, + media_type="image/png", + headers={ + "Cache-Control": "public, max-age=86400", + "Content-Disposition": f'inline; filename="generated-{image_gen_id}-{index}.png"', + }, + ) + + raise HTTPException(status_code=404, detail="No displayable image data") + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to serve generated image") + raise HTTPException( + status_code=500, detail=f"Failed to serve image: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 42b8a821b..06e929997 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -886,30 +886,8 @@ async def append_message( # Update thread's updated_at timestamp thread.updated_at = datetime.now(UTC) - # Auto-generate title from first user message if title is still default - if thread.title == "New Chat" and role_str == "user": - # Extract text content for title - content = message.content - if isinstance(content, str): - title_text = content - elif isinstance(content, list): - # Find first text content - title_text = "" - for part in content: - if isinstance(part, dict) and part.get("type") == "text": - title_text = part.get("text", "") - break - elif isinstance(part, str): - title_text = part - break - else: - title_text = str(content) - - # Truncate title - if title_text: - thread.title = title_text[:100] + ( - "..." if len(title_text) > 100 else "" - ) + # Note: Title generation now happens in stream_new_chat.py after the first response + # using LLM to generate a descriptive title (with truncation as fallback) await session.commit() await session.refresh(db_message) diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index a8916f2ea..fd84c0f45 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -7,6 +7,7 @@ from sqlalchemy.future import select from app.config import config from app.db import ( + ImageGenerationConfig, NewLLMConfig, Permission, SearchSpace, @@ -387,6 +388,69 @@ async def _get_llm_config_by_id( return None +async def _get_image_gen_config_by_id( + session: AsyncSession, config_id: int | None +) -> dict | None: + """ + Get an image generation config by ID as a dictionary. + Returns Auto mode for ID 0, global config for negative IDs, + DB ImageGenerationConfig for positive IDs, or None. + """ + if config_id is None: + return None + + if config_id == 0: + return { + "id": 0, + "name": "Auto (Load Balanced)", + "description": "Automatically routes requests across available image generation providers", + "provider": "AUTO", + "model_name": "auto", + "is_global": True, + "is_auto_mode": True, + } + + if config_id < 0: + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + } + return None + + # Positive ID: query ImageGenerationConfig table + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if db_config: + return { + "id": db_config.id, + "name": db_config.name, + "description": db_config.description, + "provider": db_config.provider.value if db_config.provider else None, + "custom_provider": db_config.custom_provider, + "model_name": db_config.model_name, + "api_base": db_config.api_base, + "api_version": db_config.api_version, + "litellm_params": db_config.litellm_params or {}, + "created_at": db_config.created_at.isoformat() + if db_config.created_at + else None, + "search_space_id": db_config.search_space_id, + } + return None + + @router.get( "/search-spaces/{search_space_id}/llm-preferences", response_model=LLMPreferencesRead, @@ -423,12 +487,17 @@ async def get_llm_preferences( document_summary_llm = await _get_llm_config_by_id( session, search_space.document_summary_llm_id ) + image_generation_config = await _get_image_gen_config_by_id( + session, search_space.image_generation_config_id + ) return LLMPreferencesRead( agent_llm_id=search_space.agent_llm_id, document_summary_llm_id=search_space.document_summary_llm_id, + image_generation_config_id=search_space.image_generation_config_id, agent_llm=agent_llm, document_summary_llm=document_summary_llm, + image_generation_config=image_generation_config, ) except HTTPException: @@ -485,12 +554,17 @@ async def update_llm_preferences( document_summary_llm = await _get_llm_config_by_id( session, search_space.document_summary_llm_id ) + image_generation_config = await _get_image_gen_config_by_id( + session, search_space.image_generation_config_id + ) return LLMPreferencesRead( agent_llm_id=search_space.agent_llm_id, document_summary_llm_id=search_space.document_summary_llm_id, + image_generation_config_id=search_space.image_generation_config_id, agent_llm=agent_llm, document_summary_llm=document_summary_llm, + image_generation_config=image_generation_config, ) except HTTPException: diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 6c9577c46..ad5abf777 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -1,3 +1,10 @@ +from .auth import ( + LogoutAllResponse, + LogoutRequest, + LogoutResponse, + RefreshTokenRequest, + RefreshTokenResponse, +) from .base import IDModel, TimestampModel from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate from .documents import ( @@ -13,6 +20,16 @@ from .documents import ( PaginatedResponse, ) from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest +from .image_generation import ( + GlobalImageGenConfigRead, + ImageGenerationConfigCreate, + ImageGenerationConfigPublic, + ImageGenerationConfigRead, + ImageGenerationConfigUpdate, + ImageGenerationCreate, + ImageGenerationListRead, + ImageGenerationRead, +) from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate from .new_chat import ( ChatMessage, @@ -96,11 +113,21 @@ __all__ = [ "DriveItem", "ExtensionDocumentContent", "ExtensionDocumentMetadata", + "GlobalImageGenConfigRead", "GlobalNewLLMConfigRead", "GoogleDriveIndexRequest", "GoogleDriveIndexingOptions", # Base schemas "IDModel", + # Image Generation Config schemas + "ImageGenerationConfigCreate", + "ImageGenerationConfigPublic", + "ImageGenerationConfigRead", + "ImageGenerationConfigUpdate", + # Image Generation schemas + "ImageGenerationCreate", + "ImageGenerationListRead", + "ImageGenerationRead", # RBAC schemas "InviteAcceptRequest", "InviteAcceptResponse", @@ -117,6 +144,10 @@ __all__ = [ "LogFilter", "LogRead", "LogUpdate", + # Auth schemas + "LogoutAllResponse", + "LogoutRequest", + "LogoutResponse", # Search source connector schemas "MCPConnectorCreate", "MCPConnectorRead", @@ -146,6 +177,8 @@ __all__ = [ "PodcastCreate", "PodcastRead", "PodcastUpdate", + "RefreshTokenRequest", + "RefreshTokenResponse", "RoleCreate", "RoleRead", "RoleUpdate", diff --git a/surfsense_backend/app/schemas/auth.py b/surfsense_backend/app/schemas/auth.py new file mode 100644 index 000000000..0d958a6d2 --- /dev/null +++ b/surfsense_backend/app/schemas/auth.py @@ -0,0 +1,35 @@ +"""Authentication schemas for refresh token endpoints.""" + +from pydantic import BaseModel + + +class RefreshTokenRequest(BaseModel): + """Request body for token refresh endpoint.""" + + refresh_token: str + + +class RefreshTokenResponse(BaseModel): + """Response from token refresh endpoint.""" + + access_token: str + refresh_token: str + token_type: str = "bearer" + + +class LogoutRequest(BaseModel): + """Request body for logout endpoint (current device).""" + + refresh_token: str + + +class LogoutResponse(BaseModel): + """Response from logout endpoint (current device).""" + + detail: str = "Successfully logged out" + + +class LogoutAllResponse(BaseModel): + """Response from logout all devices endpoint.""" + + detail: str = "Successfully logged out from all devices" diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py new file mode 100644 index 000000000..6ef4feff8 --- /dev/null +++ b/surfsense_backend/app/schemas/image_generation.py @@ -0,0 +1,230 @@ +""" +Pydantic schemas for Image Generation configs and generation requests. + +ImageGenerationConfig: CRUD schemas for user-created image gen model configs. +ImageGeneration: Schemas for the actual image generation requests/results. +GlobalImageGenConfigRead: Schema for admin-configured YAML configs. +""" + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from app.db import ImageGenProvider + +# ============================================================================= +# ImageGenerationConfig CRUD Schemas +# ============================================================================= + + +class ImageGenerationConfigBase(BaseModel): + """Base schema with fields for ImageGenerationConfig.""" + + name: str = Field( + ..., max_length=100, description="User-friendly name for the config" + ) + description: str | None = Field( + None, max_length=500, description="Optional description" + ) + provider: ImageGenProvider = Field( + ..., + description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)", + ) + custom_provider: str | None = Field( + None, max_length=100, description="Custom provider name" + ) + model_name: str = Field( + ..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)" + ) + api_key: str = Field(..., description="API key for the provider") + api_base: str | None = Field( + None, max_length=500, description="Optional API base URL" + ) + api_version: str | None = Field( + None, + max_length=50, + description="Azure-specific API version (e.g., '2024-02-15-preview')", + ) + litellm_params: dict[str, Any] | None = Field( + default=None, description="Additional LiteLLM parameters" + ) + + +class ImageGenerationConfigCreate(ImageGenerationConfigBase): + """Schema for creating a new ImageGenerationConfig.""" + + search_space_id: int = Field( + ..., description="Search space ID to associate the config with" + ) + + +class ImageGenerationConfigUpdate(BaseModel): + """Schema for updating an existing ImageGenerationConfig. All fields optional.""" + + name: str | None = Field(None, max_length=100) + description: str | None = Field(None, max_length=500) + provider: ImageGenProvider | None = None + custom_provider: str | None = Field(None, max_length=100) + model_name: str | None = Field(None, max_length=100) + api_key: str | None = None + api_base: str | None = Field(None, max_length=500) + api_version: str | None = Field(None, max_length=50) + litellm_params: dict[str, Any] | None = None + + +class ImageGenerationConfigRead(ImageGenerationConfigBase): + """Schema for reading an ImageGenerationConfig (includes id and timestamps).""" + + id: int + created_at: datetime + search_space_id: int + + model_config = ConfigDict(from_attributes=True) + + +class ImageGenerationConfigPublic(BaseModel): + """Public schema that hides the API key (for list views).""" + + id: int + name: str + description: str | None = None + provider: ImageGenProvider + custom_provider: str | None = None + model_name: str + api_base: str | None = None + api_version: str | None = None + litellm_params: dict[str, Any] | None = None + created_at: datetime + search_space_id: int + + model_config = ConfigDict(from_attributes=True) + + +# ============================================================================= +# ImageGeneration (request/result) Schemas +# ============================================================================= + + +class ImageGenerationCreate(BaseModel): + """Schema for creating an image generation request.""" + + prompt: str = Field( + ..., + min_length=1, + max_length=4000, + description="A text description of the desired image(s)", + ) + model: str | None = Field( + None, + max_length=200, + description="The model to use (e.g., 'dall-e-3', 'gpt-image-1'). Overrides the config model.", + ) + n: int | None = Field( + None, + ge=1, + le=10, + description="Number of images to generate (1-10).", + ) + quality: str | None = Field(None, max_length=50) + size: str | None = Field(None, max_length=50) + style: str | None = Field(None, max_length=50) + response_format: str | None = Field(None, max_length=50) + search_space_id: int = Field( + ..., description="Search space ID to associate the generation with" + ) + image_generation_config_id: int | None = Field( + None, + description=( + "Image generation config ID. " + "0 = Auto mode (router), negative = global YAML config, positive = DB config. " + "If not provided, uses the search space's image_generation_config_id preference." + ), + ) + + +class ImageGenerationRead(BaseModel): + """Schema for reading an image generation record.""" + + id: int + prompt: str + model: str | None = None + n: int | None = None + quality: str | None = None + size: str | None = None + style: str | None = None + response_format: str | None = None + image_generation_config_id: int | None = None + response_data: dict[str, Any] | None = None + error_message: str | None = None + search_space_id: int + created_at: datetime + + model_config = ConfigDict(from_attributes=True) + + +class ImageGenerationListRead(BaseModel): + """Lightweight schema for listing image generations (without full response_data).""" + + id: int + prompt: str + model: str | None = None + n: int | None = None + quality: str | None = None + size: str | None = None + search_space_id: int + created_at: datetime + is_success: bool + image_count: int | None = None + + model_config = ConfigDict(from_attributes=True) + + @classmethod + def from_orm_with_count(cls, obj: Any) -> "ImageGenerationListRead": + """Create ImageGenerationListRead with computed fields.""" + image_count = None + if obj.response_data and isinstance(obj.response_data, dict): + data = obj.response_data.get("data") + if isinstance(data, list): + image_count = len(data) + + return cls( + id=obj.id, + prompt=obj.prompt, + model=obj.model, + n=obj.n, + quality=obj.quality, + size=obj.size, + search_space_id=obj.search_space_id, + created_at=obj.created_at, + is_success=obj.response_data is not None, + image_count=image_count, + ) + + +# ============================================================================= +# Global Image Gen Config (from YAML) +# ============================================================================= + + +class GlobalImageGenConfigRead(BaseModel): + """ + Schema for reading global image generation configs from YAML. + Global configs have negative IDs. API key is hidden. + ID 0 is reserved for Auto mode (LiteLLM Router load balancing). + """ + + id: int = Field( + ..., + description="Config ID: 0 for Auto mode, negative for global configs", + ) + name: str + description: str | None = None + provider: str + custom_provider: str | None = None + model_name: str + api_base: str | None = None + api_version: str | None = None + litellm_params: dict[str, Any] | None = None + is_global: bool = True + is_auto_mode: bool = False diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py index 286c07843..a6294fba2 100644 --- a/surfsense_backend/app/schemas/new_llm_config.py +++ b/surfsense_backend/app/schemas/new_llm_config.py @@ -176,12 +176,18 @@ class LLMPreferencesRead(BaseModel): document_summary_llm_id: int | None = Field( None, description="ID of the LLM config to use for document summarization" ) + image_generation_config_id: int | None = Field( + None, description="ID of the image generation config to use" + ) agent_llm: dict[str, Any] | None = Field( None, description="Full config for agent LLM" ) document_summary_llm: dict[str, Any] | None = Field( None, description="Full config for document summary LLM" ) + image_generation_config: dict[str, Any] | None = Field( + None, description="Full config for image generation" + ) model_config = ConfigDict(from_attributes=True) @@ -195,3 +201,6 @@ class LLMPreferencesUpdate(BaseModel): document_summary_llm_id: int | None = Field( None, description="ID of the LLM config to use for document summarization" ) + image_generation_config_id: int | None = Field( + None, description="ID of the image generation config to use" + ) diff --git a/surfsense_backend/app/services/chat_comments_service.py b/surfsense_backend/app/services/chat_comments_service.py index dc3b51238..c9ca920f6 100644 --- a/surfsense_backend/app/services/chat_comments_service.py +++ b/surfsense_backend/app/services/chat_comments_service.py @@ -5,7 +5,7 @@ Service layer for chat comments and mentions. from uuid import UUID from fastapi import HTTPException -from sqlalchemy import delete, select +from sqlalchemy import delete, or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -103,6 +103,37 @@ async def process_mentions( return mentions_map +async def get_comment_thread_participants( + session: AsyncSession, + parent_comment_id: int, + exclude_user_ids: set[UUID], +) -> list[UUID]: + """ + Get all unique authors in a comment thread (parent + replies), excluding specified users. + + Args: + session: Database session + parent_comment_id: ID of the parent comment + exclude_user_ids: Set of user IDs to exclude (e.g., replier, mentioned users) + + Returns: + List of user UUIDs who have participated in the thread + """ + query = select(ChatComment.author_id).where( + or_( + ChatComment.id == parent_comment_id, + ChatComment.parent_id == parent_comment_id, + ), + ChatComment.author_id.isnot(None), + ) + + if exclude_user_ids: + query = query.where(ChatComment.author_id.notin_(list(exclude_user_ids))) + + result = await session.execute(query.distinct()) + return [row[0] for row in result.fetchall()] + + async def get_comments_for_message( session: AsyncSession, message_id: int, @@ -436,6 +467,31 @@ async def create_reply( search_space_id=search_space_id, ) + # Notify thread participants (excluding replier and mentioned users) + mentioned_user_ids = set(mentions_map.keys()) + exclude_ids = {user.id} | mentioned_user_ids + participants = await get_comment_thread_participants( + session, comment_id, exclude_ids + ) + for participant_id in participants: + if participant_id in mentioned_user_ids: + continue + await NotificationService.comment_reply.notify_comment_reply( + session=session, + user_id=participant_id, + reply_id=reply.id, + parent_comment_id=comment_id, + message_id=parent_comment.message_id, + thread_id=thread.id, + thread_title=thread.title or "Untitled thread", + author_id=str(user.id), + author_name=author_name, + author_avatar_url=user.avatar_url, + author_email=user.email, + content_preview=content_preview[:200], + search_space_id=search_space_id, + ) + author = AuthorResponse( id=user.id, display_name=user.display_name, diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py new file mode 100644 index 000000000..f45a6ab63 --- /dev/null +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -0,0 +1,278 @@ +""" +Image Generation Router Service for Load Balancing + +This module provides a singleton LiteLLM Router for automatic load balancing +across multiple image generation deployments. It uses litellm.Router which +natively supports aimage_generation() for async image generation. + +The router handles: +- Rate limit management with automatic cooldowns +- Automatic failover and retries +- Usage-based routing to distribute load evenly + +Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, +AWS Bedrock, Recraft, OpenRouter, Xinference, Nscale. +""" + +import logging +from typing import Any + +from litellm import Router +from litellm.utils import ImageResponse + +logger = logging.getLogger(__name__) + +# Special ID for Auto mode - uses router for load balancing +IMAGE_GEN_AUTO_MODE_ID = 0 + +# Provider mapping for LiteLLM model string construction. +# Only includes providers that support image generation. +# See: https://docs.litellm.ai/docs/image_generation#supported-providers +IMAGE_GEN_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", # Google AI Studio + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", # AWS Bedrock + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} + + +class ImageGenRouterService: + """ + Singleton service for managing LiteLLM Router for image generation. + + The router provides automatic load balancing, failover, and rate limit + handling across multiple image generation deployments. + Uses Router.aimage_generation() for async image generation calls. + """ + + _instance = None + _router: Router | None = None + _model_list: list[dict] = [] + _router_settings: dict = {} + _initialized: bool = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @classmethod + def get_instance(cls) -> "ImageGenRouterService": + """Get the singleton instance of the router service.""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def initialize( + cls, + global_configs: list[dict], + router_settings: dict | None = None, + ) -> None: + """ + Initialize the router with global image generation configurations. + + Args: + global_configs: List of global image gen config dictionaries from YAML + router_settings: Optional router settings (routing_strategy, num_retries, etc.) + """ + instance = cls.get_instance() + + if instance._initialized: + logger.debug("Image Generation Router already initialized, skipping") + return + + # Build model list from global configs + model_list = [] + for config in global_configs: + deployment = cls._config_to_deployment(config) + if deployment: + model_list.append(deployment) + + if not model_list: + logger.warning( + "No valid image generation configs found for router initialization" + ) + return + + instance._model_list = model_list + instance._router_settings = router_settings or {} + + # Default router settings optimized for rate limit handling + default_settings = { + "routing_strategy": "usage-based-routing", + "num_retries": 3, + "allowed_fails": 3, + "cooldown_time": 60, + "retry_after": 5, + } + + # Merge with provided settings + final_settings = {**default_settings, **instance._router_settings} + + try: + instance._router = Router( + model_list=model_list, + routing_strategy=final_settings.get( + "routing_strategy", "usage-based-routing" + ), + num_retries=final_settings.get("num_retries", 3), + allowed_fails=final_settings.get("allowed_fails", 3), + cooldown_time=final_settings.get("cooldown_time", 60), + set_verbose=False, + ) + instance._initialized = True + logger.info( + f"Image Generation Router initialized with {len(model_list)} deployments, " + f"strategy: {final_settings.get('routing_strategy')}" + ) + except Exception as e: + logger.error(f"Failed to initialize Image Generation Router: {e}") + instance._router = None + + @classmethod + def _config_to_deployment(cls, config: dict) -> dict | None: + """ + Convert a global image gen config to a router deployment entry. + + Args: + config: Global image gen config dictionary + + Returns: + Router deployment dictionary or None if invalid + """ + try: + # Skip if essential fields are missing + if not config.get("model_name") or not config.get("api_key"): + return None + + # Build model string + if config.get("custom_provider"): + model_string = f"{config['custom_provider']}/{config['model_name']}" + else: + provider = config.get("provider", "").upper() + provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{config['model_name']}" + + # Build litellm params + litellm_params: dict[str, Any] = { + "model": model_string, + "api_key": config.get("api_key"), + } + + # Add optional api_base + if config.get("api_base"): + litellm_params["api_base"] = config["api_base"] + + # Add api_version (required for Azure) + if config.get("api_version"): + litellm_params["api_version"] = config["api_version"] + + # Add any additional litellm parameters + if config.get("litellm_params"): + litellm_params.update(config["litellm_params"]) + + # All configs use same alias "auto" for unified routing + deployment: dict[str, Any] = { + "model_name": "auto", + "litellm_params": litellm_params, + } + + # Add RPM rate limit from config if available + # Note: TPM (tokens per minute) is not applicable for image generation + # since image APIs are rate-limited by requests, not tokens. + if config.get("rpm"): + deployment["rpm"] = config["rpm"] + + return deployment + + except Exception as e: + logger.warning(f"Failed to convert image gen config to deployment: {e}") + return None + + @classmethod + def get_router(cls) -> Router | None: + """Get the initialized router instance.""" + instance = cls.get_instance() + return instance._router + + @classmethod + def is_initialized(cls) -> bool: + """Check if the router has been initialized.""" + instance = cls.get_instance() + return instance._initialized and instance._router is not None + + @classmethod + def get_model_count(cls) -> int: + """Get the number of models in the router.""" + instance = cls.get_instance() + return len(instance._model_list) + + @classmethod + async def aimage_generation( + cls, + prompt: str, + model: str = "auto", + n: int | None = None, + timeout: int = 600, + **kwargs, + ) -> ImageResponse: + """ + Generate images using the router for load balancing. + + Uses Router.aimage_generation() which distributes requests + across configured image generation deployments. + + Parameters like size, quality, style, and response_format are intentionally + omitted to keep the interface model-agnostic. Providers use their own + sensible defaults. If needed, pass them via **kwargs. + + Args: + prompt: Text description of the desired image(s) + model: Model alias (default "auto" for router routing) + n: Number of images to generate + timeout: Request timeout in seconds + **kwargs: Additional provider-specific params (size, quality, etc.) + + Returns: + ImageResponse from litellm + + Raises: + ValueError: If router is not initialized + """ + instance = cls.get_instance() + if not instance._router: + raise ValueError( + "Image Generation Router not initialized. " + "Ensure global_llm_config.yaml has global_image_generation_configs." + ) + + # Build kwargs for aimage_generation + gen_kwargs: dict[str, Any] = { + "model": model, + "prompt": prompt, + "timeout": timeout, + } + if n is not None: + gen_kwargs["n"] = n + gen_kwargs.update(kwargs) + + return await instance._router.aimage_generation(**gen_kwargs) + + +def is_image_gen_auto_mode(config_id: int | None) -> bool: + """ + Check if the given config ID represents Image Generation Auto mode. + + Args: + config_id: The config ID to check + + Returns: + True if this is Auto mode, False otherwise + """ + return config_id == IMAGE_GEN_AUTO_MODE_ID diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 05dd2d4dd..57fbc9663 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -479,6 +479,31 @@ class VercelStreamingService: }, ) + def format_thread_title_update(self, thread_id: int, title: str) -> str: + """ + Format a thread title update notification (SurfSense specific). + + This is sent after the first response in a thread to update the + auto-generated title based on the conversation content. + + Args: + thread_id: The ID of the thread being updated + title: The new title for the thread + + Returns: + str: SSE formatted thread title update data part + + Example output: + data: {"type":"data-thread-title-update","data":{"threadId":123,"title":"New Title"}} + """ + return self.format_data( + "thread-title-update", + { + "threadId": thread_id, + "title": title, + }, + ) + # ========================================================================= # Error Part # ========================================================================= diff --git a/surfsense_backend/app/services/notification_service.py b/surfsense_backend/app/services/notification_service.py index 1788d05e1..a759f3536 100644 --- a/surfsense_backend/app/services/notification_service.py +++ b/surfsense_backend/app/services/notification_service.py @@ -861,6 +861,98 @@ class MentionNotificationHandler(BaseNotificationHandler): raise +class CommentReplyNotificationHandler(BaseNotificationHandler): + """Handler for comment reply notifications.""" + + def __init__(self): + super().__init__("comment_reply") + + async def find_notification_by_reply( + self, + session: AsyncSession, + reply_id: int, + user_id: UUID, + ) -> Notification | None: + query = select(Notification).where( + Notification.type == self.notification_type, + Notification.user_id == user_id, + Notification.notification_metadata["reply_id"].astext == str(reply_id), + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def notify_comment_reply( + self, + session: AsyncSession, + user_id: UUID, + reply_id: int, + parent_comment_id: int, + message_id: int, + thread_id: int, + thread_title: str, + author_id: str, + author_name: str, + author_avatar_url: str | None, + author_email: str, + content_preview: str, + search_space_id: int, + ) -> Notification: + existing = await self.find_notification_by_reply(session, reply_id, user_id) + if existing: + logger.info( + f"Notification already exists for reply {reply_id} to user {user_id}" + ) + return existing + + title = f"{author_name} replied in a thread" + message = content_preview[:100] + ("..." if len(content_preview) > 100 else "") + + metadata = { + "reply_id": reply_id, + "parent_comment_id": parent_comment_id, + "message_id": message_id, + "thread_id": thread_id, + "thread_title": thread_title, + "author_id": author_id, + "author_name": author_name, + "author_avatar_url": author_avatar_url, + "author_email": author_email, + "content_preview": content_preview[:200], + } + + try: + notification = Notification( + user_id=user_id, + search_space_id=search_space_id, + type=self.notification_type, + title=title, + message=message, + notification_metadata=metadata, + ) + session.add(notification) + await session.commit() + await session.refresh(notification) + logger.info( + f"Created comment_reply notification {notification.id} for user {user_id}" + ) + return notification + except Exception as e: + await session.rollback() + if ( + "duplicate key" in str(e).lower() + or "unique constraint" in str(e).lower() + ): + logger.warning( + f"Duplicate notification for reply {reply_id} to user {user_id}" + ) + existing = await self.find_notification_by_reply( + session, reply_id, user_id + ) + if existing: + return existing + raise + + class PageLimitNotificationHandler(BaseNotificationHandler): """Handler for page limit exceeded notifications.""" @@ -959,6 +1051,7 @@ class NotificationService: connector_indexing = ConnectorIndexingNotificationHandler() document_processing = DocumentProcessingNotificationHandler() mention = MentionNotificationHandler() + comment_reply = CommentReplyNotificationHandler() page_limit = PageLimitNotificationHandler() @staticmethod diff --git a/surfsense_backend/app/services/public_chat_service.py b/surfsense_backend/app/services/public_chat_service.py index 2125dd8ce..4da316240 100644 --- a/surfsense_backend/app/services/public_chat_service.py +++ b/surfsense_backend/app/services/public_chat_service.py @@ -366,11 +366,14 @@ async def list_snapshots_for_thread( if not thread: raise HTTPException(status_code=404, detail="Thread not found") - if thread.created_by_id != user.id: - raise HTTPException( - status_code=403, - detail="Only the creator can view snapshots", - ) + # Check permission to view public share links + await check_permission( + session, + user, + thread.search_space_id, + Permission.PUBLIC_SHARING_VIEW.value, + "You don't have permission to view public share links", + ) result = await session.execute( select(PublicChatSnapshot) diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py index d0710d246..760651589 100644 --- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py @@ -1,6 +1,7 @@ """Celery tasks for connector indexing.""" import logging +import traceback from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.pool import NullPool @@ -11,6 +12,36 @@ from app.config import config logger = logging.getLogger(__name__) +def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> None: + """ + Handle greenlet_spawn errors with detailed logging for debugging. + + The 'greenlet_spawn has not been called' error occurs when: + 1. SQLAlchemy lazy-loads a relationship outside of an async context + 2. A sync operation is called from an async context (or vice versa) + 3. Session objects are accessed after the session is closed + + This helper logs detailed context to help identify the root cause. + """ + error_str = str(e) + if "greenlet_spawn has not been called" in error_str: + logger.error( + f"GREENLET ERROR in {task_name} for connector {connector_id}: {error_str}\n" + f"This error typically occurs when SQLAlchemy tries to lazy-load a relationship " + f"outside of an async context. Check for:\n" + f"1. Accessing relationship attributes (e.g., document.chunks, connector.search_space) " + f"without using selectinload() or joinedload()\n" + f"2. Accessing model attributes after the session is closed\n" + f"3. Passing ORM objects between different async contexts\n" + f"Stack trace:\n{traceback.format_exc()}" + ) + else: + logger.error( + f"Error in {task_name} for connector {connector_id}: {error_str}\n" + f"Stack trace:\n{traceback.format_exc()}" + ) + + def get_celery_session_maker(): """ Create a new async session maker for Celery tasks. @@ -46,6 +77,9 @@ def index_slack_messages_task( connector_id, search_space_id, user_id, start_date, end_date ) ) + except Exception as e: + _handle_greenlet_error(e, "index_slack_messages", connector_id) + raise finally: loop.close() @@ -89,6 +123,9 @@ def index_notion_pages_task( connector_id, search_space_id, user_id, start_date, end_date ) ) + except Exception as e: + _handle_greenlet_error(e, "index_notion_pages", connector_id) + raise finally: loop.close() @@ -347,6 +384,9 @@ def index_google_calendar_events_task( connector_id, search_space_id, user_id, start_date, end_date ) ) + except Exception as e: + _handle_greenlet_error(e, "index_google_calendar_events", connector_id) + raise finally: loop.close() @@ -696,6 +736,9 @@ def index_crawled_urls_task( connector_id, search_space_id, user_id, start_date, end_date ) ) + except Exception as e: + _handle_greenlet_error(e, "index_crawled_urls", connector_id) + raise finally: loop.close() diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 688777203..685f77e39 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -27,6 +27,7 @@ from app.agents.new_chat.llm_config import ( load_llm_config_from_yaml, ) from app.db import Document, SurfsenseDocsDocument +from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE from app.schemas.new_chat import ChatAttachment from app.services.chat_session_state_service import ( clear_ai_responding, @@ -1208,6 +1209,62 @@ async def stream_new_chat( if completion_event: yield completion_event + # Generate LLM title for new chats after first response + # Check if this is the first assistant response by counting existing assistant messages + from sqlalchemy import func + + from app.db import NewChatMessage, NewChatThread + + assistant_count_result = await session.execute( + select(func.count(NewChatMessage.id)).filter( + NewChatMessage.thread_id == chat_id, + NewChatMessage.role == "assistant", + ) + ) + assistant_message_count = assistant_count_result.scalar() or 0 + + # Only generate title on the first response (no prior assistant messages) + if assistant_message_count == 0: + generated_title = None + try: + # Generate title using the same LLM + title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm + # Truncate inputs to avoid context length issues + truncated_query = user_query[:500] + truncated_response = accumulated_text[:1000] + title_result = await title_chain.ainvoke( + { + "user_query": truncated_query, + "assistant_response": truncated_response, + } + ) + + # Extract and clean the title + if title_result and hasattr(title_result, "content"): + raw_title = title_result.content.strip() + # Validate the title (reasonable length) + if raw_title and len(raw_title) <= 100: + # Remove any quotes or extra formatting + generated_title = raw_title.strip("\"'") + except Exception: + generated_title = None + + # Only update if LLM succeeded (keep truncated prompt title as fallback) + if generated_title: + # Fetch thread and update title + thread_result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == chat_id) + ) + thread = thread_result.scalars().first() + if thread: + thread.title = generated_title + await session.commit() + + # Notify frontend of the title update + yield streaming_service.format_thread_title_update( + chat_id, generated_title + ) + # Finish the step and message yield streaming_service.format_finish_step() yield streaming_service.format_finish() diff --git a/surfsense_backend/app/tasks/connector_indexers/base.py b/surfsense_backend/app/tasks/connector_indexers/base.py index b390937f0..b801b67d6 100644 --- a/surfsense_backend/app/tasks/connector_indexers/base.py +++ b/surfsense_backend/app/tasks/connector_indexers/base.py @@ -28,6 +28,34 @@ def get_current_timestamp() -> datetime: return datetime.now(UTC) +def parse_date_flexible(date_str: str) -> datetime: + """ + Parse date from multiple common formats. + + Args: + date_str: Date string to parse + + Returns: + Parsed datetime object + + Raises: + ValueError: If unable to parse the date string + """ + formats = ["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"] + + for fmt in formats: + try: + return datetime.strptime(date_str.rstrip("Z"), fmt) + except ValueError: + continue + + # Try ISO format as fallback + try: + return datetime.fromisoformat(date_str.replace("Z", "+00:00")) + except ValueError: + raise ValueError(f"Unable to parse date: {date_str}") + + async def check_duplicate_document_by_hash( session: AsyncSession, content_hash: str ) -> Document | None: @@ -159,6 +187,26 @@ def calculate_date_range( ) end_date_str = end_date if end_date else calculated_end_date.strftime("%Y-%m-%d") + # FIX: Ensure end_date is at least 1 day after start_date to avoid + # "start_date must be strictly before end_date" errors when dates are the same + # (e.g., when last_indexed_at is today) + if start_date_str == end_date_str: + logger.info( + f"Start date ({start_date_str}) equals end date ({end_date_str}), " + "adjusting end date to next day to ensure valid date range" + ) + # Parse end_date and add 1 day + try: + end_dt = parse_date_flexible(end_date_str) + except ValueError: + logger.warning( + f"Could not parse end_date '{end_date_str}', using current date" + ) + end_dt = datetime.now() + end_dt = end_dt + timedelta(days=1) + end_date_str = end_dt.strftime("%Y-%m-%d") + logger.info(f"Adjusted end date to {end_date_str}") + return start_date_str, end_date_str diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py index 386c9de43..0b773025f 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py @@ -27,6 +27,7 @@ from .base import ( get_connector_by_id, get_current_timestamp, logger, + parse_date_flexible, update_connector_last_indexed, ) @@ -217,6 +218,26 @@ async def index_google_calendar_events( start_date_str = start_date end_date_str = end_date + # FIX: Ensure end_date is at least 1 day after start_date to avoid + # "start_date must be strictly before end_date" errors when dates are the same + # (e.g., when last_indexed_at is today) + if start_date_str == end_date_str: + logger.info( + f"Start date ({start_date_str}) equals end date ({end_date_str}), " + "adjusting end date to next day to ensure valid date range" + ) + # Parse end_date and add 1 day + try: + end_dt = parse_date_flexible(end_date_str) + except ValueError: + logger.warning( + f"Could not parse end_date '{end_date_str}', using current date" + ) + end_dt = datetime.now() + end_dt = end_dt + timedelta(days=1) + end_date_str = end_dt.strftime("%Y-%m-%d") + logger.info(f"Adjusted end date to {end_date_str}") + await task_logger.log_task_progress( log_entry, f"Fetching Google Calendar events from {start_date_str} to {end_date_str}", diff --git a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py index b1adeb035..ed300898c 100644 --- a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py @@ -196,13 +196,44 @@ async def index_notion_pages( "Recommend reconnecting with OAuth." ) except Exception as e: - await task_logger.log_task_failure( - log_entry, - f"Failed to get Notion pages for connector {connector_id}", - str(e), - {"error_type": "PageFetchError"}, + error_str = str(e) + # Check if this is an unsupported block type error (transcription, ai_block, etc.) + # These are known Notion API limitations and should be logged as warnings, not errors + unsupported_block_errors = [ + "transcription is not supported", + "ai_block is not supported", + "is not supported via the API", + ] + is_unsupported_block_error = any( + err in error_str.lower() for err in unsupported_block_errors ) - logger.error(f"Error fetching Notion pages: {e!s}", exc_info=True) + + if is_unsupported_block_error: + # Log as warning since this is a known Notion API limitation + logger.warning( + f"Notion API limitation for connector {connector_id}: {error_str}. " + "This is a known issue with Notion AI blocks (transcription, ai_block) " + "that are not accessible via the Notion API." + ) + await task_logger.log_task_failure( + log_entry, + f"Failed to get Notion pages: Notion API limitation", + f"{error_str} - This page contains Notion AI content (transcription/ai_block) that cannot be accessed via the API.", + {"error_type": "UnsupportedBlockType", "is_known_limitation": True}, + ) + else: + # Log as error for other failures + logger.error( + f"Error fetching Notion pages for connector {connector_id}: {error_str}", + exc_info=True, + ) + await task_logger.log_task_failure( + log_entry, + f"Failed to get Notion pages for connector {connector_id}", + str(e), + {"error_type": "PageFetchError"}, + ) + await notion_client.close() return 0, f"Failed to get Notion pages: {e!s}" diff --git a/surfsense_backend/app/tasks/connector_indexers/webcrawler_indexer.py b/surfsense_backend/app/tasks/connector_indexers/webcrawler_indexer.py index cb11a6ec2..a2f0898ba 100644 --- a/surfsense_backend/app/tasks/connector_indexers/webcrawler_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/webcrawler_indexer.py @@ -108,10 +108,15 @@ async def index_crawled_urls( api_key = connector.config.get("FIRECRAWL_API_KEY") # Get URLs from connector config - urls = parse_webcrawler_urls(connector.config.get("INITIAL_URLS")) + raw_initial_urls = connector.config.get("INITIAL_URLS") + urls = parse_webcrawler_urls(raw_initial_urls) + # DEBUG: Log connector config details for troubleshooting empty URL issues logger.info( - f"Starting crawled web page indexing for connector {connector_id} with {len(urls)} URLs" + f"Starting crawled web page indexing for connector {connector_id} with {len(urls)} URLs. " + f"Connector name: {connector.name}, " + f"INITIAL_URLS type: {type(raw_initial_urls).__name__}, " + f"INITIAL_URLS value: {repr(raw_initial_urls)[:200] if raw_initial_urls else 'None'}" ) # Initialize webcrawler client @@ -128,11 +133,18 @@ async def index_crawled_urls( # Validate URLs if not urls: + # DEBUG: Log detailed connector config for troubleshooting + logger.error( + f"No URLs provided for indexing. Connector ID: {connector_id}, " + f"Connector name: {connector.name}, " + f"Config keys: {list(connector.config.keys()) if connector.config else 'None'}, " + f"INITIAL_URLS raw value: {repr(raw_initial_urls)}" + ) await task_logger.log_task_failure( log_entry, "No URLs provided for indexing", - "Empty URL list", - {"error_type": "ValidationError"}, + f"Empty URL list. INITIAL_URLS value: {repr(raw_initial_urls)[:100]}", + {"error_type": "ValidationError", "connector_name": connector.name}, ) return 0, "No URLs provided for indexing" diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index 4be2fe525..ee07ba88f 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -23,17 +23,20 @@ from app.db import ( get_default_roles_config, get_user_db, ) +from app.utils.refresh_tokens import create_refresh_token logger = logging.getLogger(__name__) class BearerResponse(BaseModel): access_token: str + refresh_token: str token_type: str SECRET = config.SECRET_KEY + if config.AUTH_TYPE == "GOOGLE": from httpx_oauth.clients.google import GoogleOAuth2 @@ -183,7 +186,10 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: - return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24) + return JWTStrategy( + secret=SECRET, + lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS, + ) # # COOKIE AUTH | Uncomment if you want to use cookie auth. @@ -209,9 +215,32 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: # BEARER AUTH CODE. class CustomBearerTransport(BearerTransport): async def get_login_response(self, token: str) -> Response: - bearer_response = BearerResponse(access_token=token, token_type="bearer") - redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}" + import jwt + + # Decode JWT to get user_id for refresh token creation + try: + payload = jwt.decode( + token, SECRET, algorithms=["HS256"], options={"verify_aud": False} + ) + user_id = uuid.UUID(payload.get("sub")) + refresh_token = await create_refresh_token(user_id) + except Exception as e: + logger.error(f"Failed to create refresh token: {e}") + # Fall back to response without refresh token + refresh_token = "" + + bearer_response = BearerResponse( + access_token=token, + refresh_token=refresh_token, + token_type="bearer", + ) + if config.AUTH_TYPE == "GOOGLE": + redirect_url = ( + f"{config.NEXT_FRONTEND_URL}/auth/callback" + f"?token={bearer_response.access_token}" + f"&refresh_token={bearer_response.refresh_token}" + ) return RedirectResponse(redirect_url, status_code=302) else: return JSONResponse(bearer_response.model_dump()) diff --git a/surfsense_backend/app/utils/refresh_tokens.py b/surfsense_backend/app/utils/refresh_tokens.py new file mode 100644 index 000000000..8c0312ba8 --- /dev/null +++ b/surfsense_backend/app/utils/refresh_tokens.py @@ -0,0 +1,153 @@ +"""Utilities for managing refresh tokens.""" + +import hashlib +import logging +import secrets +import uuid +from datetime import UTC, datetime, timedelta + +from sqlalchemy import select, update + +from app.config import config +from app.db import RefreshToken, async_session_maker + +logger = logging.getLogger(__name__) + + +def generate_refresh_token() -> str: + """Generate a cryptographically secure refresh token.""" + return secrets.token_urlsafe(32) + + +def hash_token(token: str) -> str: + """Hash a token for secure storage.""" + return hashlib.sha256(token.encode()).hexdigest() + + +async def create_refresh_token( + user_id: uuid.UUID, + family_id: uuid.UUID | None = None, +) -> str: + """ + Create and store a new refresh token for a user. + + Args: + user_id: The user's ID + family_id: Optional family ID for token rotation + + Returns: + The plaintext refresh token + """ + token = generate_refresh_token() + token_hash = hash_token(token) + expires_at = datetime.now(UTC) + timedelta( + seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS + ) + + if family_id is None: + family_id = uuid.uuid4() + + async with async_session_maker() as session: + refresh_token = RefreshToken( + user_id=user_id, + token_hash=token_hash, + expires_at=expires_at, + family_id=family_id, + ) + session.add(refresh_token) + await session.commit() + + return token + + +async def validate_refresh_token(token: str) -> RefreshToken | None: + """ + Validate a refresh token. Handles reuse detection. + + Args: + token: The plaintext refresh token + + Returns: + RefreshToken if valid, None otherwise + """ + token_hash = hash_token(token) + + async with async_session_maker() as session: + result = await session.execute( + select(RefreshToken).where(RefreshToken.token_hash == token_hash) + ) + refresh_token = result.scalars().first() + + if not refresh_token: + return None + + # Reuse detection: revoked token used while family has active tokens + if refresh_token.is_revoked: + active = await session.execute( + select(RefreshToken).where( + RefreshToken.family_id == refresh_token.family_id, + RefreshToken.is_revoked == False, # noqa: E712 + RefreshToken.expires_at > datetime.now(UTC), + ) + ) + if active.scalars().first(): + # Revoke entire family + await session.execute( + update(RefreshToken) + .where(RefreshToken.family_id == refresh_token.family_id) + .values(is_revoked=True) + ) + await session.commit() + logger.warning(f"Token reuse detected for user {refresh_token.user_id}") + return None + + if refresh_token.is_expired: + return None + + return refresh_token + + +async def rotate_refresh_token(old_token: RefreshToken) -> str: + """Revoke old token and create new one in same family.""" + async with async_session_maker() as session: + await session.execute( + update(RefreshToken) + .where(RefreshToken.id == old_token.id) + .values(is_revoked=True) + ) + await session.commit() + + return await create_refresh_token(old_token.user_id, old_token.family_id) + + +async def revoke_refresh_token(token: str) -> bool: + """ + Revoke a single refresh token by its plaintext value. + + Args: + token: The plaintext refresh token + + Returns: + True if token was found and revoked, False otherwise + """ + token_hash = hash_token(token) + + async with async_session_maker() as session: + result = await session.execute( + update(RefreshToken) + .where(RefreshToken.token_hash == token_hash) + .values(is_revoked=True) + ) + await session.commit() + return result.rowcount > 0 + + +async def revoke_all_user_tokens(user_id: uuid.UUID) -> None: + """Revoke all refresh tokens for a user (logout all devices).""" + async with async_session_maker() as session: + await session.execute( + update(RefreshToken) + .where(RefreshToken.user_id == user_id) + .values(is_revoked=True) + ) + await session.commit() diff --git a/surfsense_backend/app/utils/signed_image_urls.py b/surfsense_backend/app/utils/signed_image_urls.py new file mode 100644 index 000000000..d8d0bb57e --- /dev/null +++ b/surfsense_backend/app/utils/signed_image_urls.py @@ -0,0 +1,44 @@ +""" +Access token utilities for generated images. + +Provides token generation and verification so that generated images can be +served via tags (which cannot pass auth headers) while still +restricting access to authorised users. + +Each image generation record stores its own random access token. The token +is verified by comparing the incoming query-parameter value against the +stored value in the database. This approach: + +* Survives SECRET_KEY rotation — tokens are random, not derived from a key. +* Allows explicit revocation — just clear the column. +* Is immune to timing attacks — uses ``hmac.compare_digest``. +""" + +import hmac +import secrets + + +def generate_image_token() -> str: + """ + Generate a cryptographically random access token for an image. + + Returns: + A 64-character URL-safe hex string. + """ + return secrets.token_hex(32) + + +def verify_image_token(stored_token: str | None, provided_token: str) -> bool: + """ + Constant-time comparison of a stored token against a user-provided one. + + Args: + stored_token: The token persisted on the ImageGeneration record. + provided_token: The token from the URL query parameter. + + Returns: + True if the tokens match, False otherwise. + """ + if not stored_token or not provided_token: + return False + return hmac.compare_digest(stored_token, provided_token) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 9b462fcbc..1a535539d 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -437,7 +437,10 @@ export default function NewChatPage() { let isNewThread = false; if (!currentThreadId) { try { - const newThread = await createThread(searchSpaceId, "New Chat"); + // Create thread with truncated prompt as initial title + const initialTitle = + userQuery.trim().slice(0, 100) + (userQuery.trim().length > 100 ? "..." : ""); + const newThread = await createThread(searchSpaceId, initialTitle); currentThreadId = newThread.id; setThreadId(currentThreadId); // Set currentThread so ChatHeader can show share button immediately @@ -827,6 +830,26 @@ export default function NewChatPage() { break; } + case "data-thread-title-update": { + // Handle thread title update from LLM-generated title + const titleData = parsed.data as { threadId: number; title: string }; + if (titleData?.title && titleData?.threadId === currentThreadId) { + // Update current thread state with new title + setCurrentThread((prev) => + prev ? { ...prev, title: titleData.title } : prev + ); + // Invalidate thread list to refresh sidebar + queryClient.invalidateQueries({ + queryKey: ["threads", String(searchSpaceId)], + }); + // Invalidate thread detail for breadcrumb update + queryClient.invalidateQueries({ + queryKey: ["threads", String(searchSpaceId), "detail", String(titleData.threadId)], + }); + } + break; + } + case "error": throw new Error(parsed.errorText || "Server error"); } diff --git a/surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx index 1a727f1b6..e6c973ac6 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx @@ -7,6 +7,7 @@ import { ChevronRight, FileText, Globe, + ImageIcon, type LucideIcon, Menu, MessageSquare, @@ -19,6 +20,7 @@ import { useTranslations } from "next-intl"; import { useCallback, useEffect, useState } from "react"; import { PublicChatSnapshotsManager } from "@/components/public-chat-snapshots/public-chat-snapshots-manager"; import { GeneralSettingsManager } from "@/components/settings/general-settings-manager"; +import { ImageModelManager } from "@/components/settings/image-model-manager"; import { LLMRoleManager } from "@/components/settings/llm-role-manager"; import { ModelConfigManager } from "@/components/settings/model-config-manager"; import { PromptConfigManager } from "@/components/settings/prompt-config-manager"; @@ -52,6 +54,12 @@ const settingsNavItems: SettingsNavItem[] = [ descriptionKey: "nav_role_assignments_desc", icon: Brain, }, + { + id: "image-models", + labelKey: "nav_image_models", + descriptionKey: "nav_image_models_desc", + icon: ImageIcon, + }, { id: "prompts", labelKey: "nav_system_instructions", @@ -282,8 +290,11 @@ function SettingsContent({ )} {activeSection === "models" && } - {activeSection === "roles" && } - {activeSection === "prompts" && } + {activeSection === "roles" && } + {activeSection === "image-models" && ( + + )} + {activeSection === "prompts" && } {activeSection === "public-links" && ( )} diff --git a/surfsense_web/app/dashboard/user/settings/components/UserSettingsSidebar.tsx b/surfsense_web/app/dashboard/user/settings/components/UserSettingsSidebar.tsx index b7040b4e3..3424113a9 100644 --- a/surfsense_web/app/dashboard/user/settings/components/UserSettingsSidebar.tsx +++ b/surfsense_web/app/dashboard/user/settings/components/UserSettingsSidebar.tsx @@ -5,6 +5,7 @@ import { ArrowLeft, ChevronRight, X } from "lucide-react"; import { AnimatePresence, motion } from "motion/react"; import { useTranslations } from "next-intl"; import { Button } from "@/components/ui/button"; +import { APP_VERSION } from "@/lib/env-config"; import { cn } from "@/lib/utils"; export interface SettingsNavItem { @@ -148,6 +149,11 @@ export function UserSettingsSidebar({ ); })} + + {/* Version display */} +
+

v{APP_VERSION}

+
); diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts new file mode 100644 index 000000000..dbaf441d0 --- /dev/null +++ b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts @@ -0,0 +1,91 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + CreateImageGenConfigRequest, + GetImageGenConfigsResponse, + UpdateImageGenConfigRequest, + UpdateImageGenConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Mutation atom for creating a new ImageGenerationConfig + */ +export const createImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "create"], + enabled: !!searchSpaceId, + mutationFn: async (request: CreateImageGenConfigRequest) => { + return imageGenConfigApiService.createConfig(request); + }, + onSuccess: () => { + toast.success("Image model configuration created"); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to create image model configuration"); + }, + }; +}); + +/** + * Mutation atom for updating an existing ImageGenerationConfig + */ +export const updateImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "update"], + enabled: !!searchSpaceId, + mutationFn: async (request: UpdateImageGenConfigRequest) => { + return imageGenConfigApiService.updateConfig(request); + }, + onSuccess: (_: UpdateImageGenConfigResponse, request: UpdateImageGenConfigRequest) => { + toast.success("Image model configuration updated"); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.byId(request.id), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to update image model configuration"); + }, + }; +}); + +/** + * Mutation atom for deleting an ImageGenerationConfig + */ +export const deleteImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "delete"], + enabled: !!searchSpaceId, + mutationFn: async (id: number) => { + return imageGenConfigApiService.deleteConfig(id); + }, + onSuccess: (_, id: number) => { + toast.success("Image model configuration deleted"); + queryClient.setQueryData( + cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + (oldData: GetImageGenConfigsResponse | undefined) => { + if (!oldData) return oldData; + return oldData.filter((config) => config.id !== id); + } + ); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to delete image model configuration"); + }, + }; +}); diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts new file mode 100644 index 000000000..a45e69a03 --- /dev/null +++ b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts @@ -0,0 +1,33 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Query atom for fetching user-created image gen configs for the active search space + */ +export const imageGenConfigsAtom = atomWithQuery((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, // 5 minutes + queryFn: async () => { + return imageGenConfigApiService.getConfigs(Number(searchSpaceId)); + }, + }; +}); + +/** + * Query atom for fetching global image gen configs (from YAML, negative IDs) + */ +export const globalImageGenConfigsAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.imageGenConfigs.global(), + staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change + queryFn: async () => { + return imageGenConfigApiService.getGlobalConfigs(); + }, + }; +}); diff --git a/surfsense_web/components/Logo.tsx b/surfsense_web/components/Logo.tsx index 58f8d1c9f..9f5915777 100644 --- a/surfsense_web/components/Logo.tsx +++ b/surfsense_web/components/Logo.tsx @@ -4,16 +4,20 @@ import Image from "next/image"; import Link from "next/link"; import { cn } from "@/lib/utils"; -export const Logo = ({ className }: { className?: string }) => { - return ( - - logo - +export const Logo = ({ className, disableLink = false }: { className?: string; disableLink?: boolean }) => { + const image = ( + logo ); + + if (disableLink) { + return image; + } + + return {image}; }; diff --git a/surfsense_web/components/TokenHandler.tsx b/surfsense_web/components/TokenHandler.tsx index e3295df7c..230cda81a 100644 --- a/surfsense_web/components/TokenHandler.tsx +++ b/surfsense_web/components/TokenHandler.tsx @@ -3,7 +3,7 @@ import { useSearchParams } from "next/navigation"; import { useEffect } from "react"; import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; -import { getAndClearRedirectPath, setBearerToken } from "@/lib/auth-utils"; +import { getAndClearRedirectPath, setBearerToken, setRefreshToken } from "@/lib/auth-utils"; import { trackLoginSuccess } from "@/lib/posthog/events"; interface TokenHandlerProps { @@ -35,8 +35,9 @@ const TokenHandler = ({ // Only run on client-side if (typeof window === "undefined") return; - // Get token from URL parameters + // Get tokens from URL parameters const token = searchParams.get(tokenParamName); + const refreshToken = searchParams.get("refresh_token"); if (token) { try { @@ -50,10 +51,15 @@ const TokenHandler = ({ // Clear the flag for future logins sessionStorage.removeItem("login_success_tracked"); - // Store token in localStorage using both methods for compatibility + // Store access token in localStorage using both methods for compatibility localStorage.setItem(storageKey, token); setBearerToken(token); + // Store refresh token if provided + if (refreshToken) { + setRefreshToken(refreshToken); + } + // Check if there's a saved redirect path from before the auth flow const savedRedirectPath = getAndClearRedirectPath(); diff --git a/surfsense_web/components/UserDropdown.tsx b/surfsense_web/components/UserDropdown.tsx index 3dac745cf..233a41a1f 100644 --- a/surfsense_web/components/UserDropdown.tsx +++ b/surfsense_web/components/UserDropdown.tsx @@ -1,7 +1,8 @@ "use client"; -import { BadgeCheck, LogOut } from "lucide-react"; +import { BadgeCheck, Loader2, LogOut } from "lucide-react"; import { useRouter } from "next/navigation"; +import { useState } from "react"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; import { Button } from "@/components/ui/button"; import { @@ -13,6 +14,7 @@ import { DropdownMenuSeparator, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; +import { logout } from "@/lib/auth-utils"; import { cleanupElectric } from "@/lib/electric/client"; import { resetUser, trackLogout } from "@/lib/posthog/events"; @@ -26,8 +28,11 @@ export function UserDropdown({ }; }) { const router = useRouter(); + const [isLoggingOut, setIsLoggingOut] = useState(false); const handleLogout = async () => { + if (isLoggingOut) return; + setIsLoggingOut(true); try { // Track logout event and reset PostHog identity trackLogout(); @@ -41,15 +46,17 @@ export function UserDropdown({ console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err); } + // Revoke refresh token on server and clear all tokens from localStorage + await logout(); + if (typeof window !== "undefined") { - localStorage.removeItem("surfsense_bearer_token"); window.location.href = "/"; } } catch (error) { console.error("Error during logout:", error); - // Optionally, provide user feedback + // Even if there's an error, try to clear tokens and redirect + await logout(); if (typeof window !== "undefined") { - localStorage.removeItem("surfsense_bearer_token"); window.location.href = "/"; } } @@ -85,9 +92,17 @@ export function UserDropdown({ - - - Log out + + {isLoggingOut ? ( + + ) : ( + + )} + {isLoggingOut ? "Logging out..." : "Log out"} diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 4fd2446c3..5cdd287de 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -4,20 +4,19 @@ import { ErrorPrimitive, MessagePrimitive, useAssistantState, + useMessage, } from "@assistant-ui/react"; -import { useAtom, useAtomValue, useSetAtom } from "jotai"; +import { useAtom, useAtomValue } from "jotai"; import { CheckIcon, CopyIcon, DownloadIcon, MessageSquare, RefreshCwIcon } from "lucide-react"; import type { FC } from "react"; import { useContext, useEffect, useMemo, useRef, useState } from "react"; import { addingCommentToMessageIdAtom, - clearTargetCommentIdAtom, commentsCollapsedAtom, commentsEnabledAtom, targetCommentIdAtom, } from "@/atoms/chat/current-thread.atom"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; -import { BranchPicker } from "@/components/assistant-ui/branch-picker"; import { MarkdownText } from "@/components/assistant-ui/markdown-text"; import { ThinkingStepsContext, @@ -84,7 +83,6 @@ const AssistantMessageInner: FC = () => {
-
@@ -126,7 +124,6 @@ export const AssistantMessage: FC = () => { // Target comment navigation - read target from global atom const targetCommentId = useAtomValue(targetCommentIdAtom); - const clearTargetCommentId = useSetAtom(clearTargetCommentIdAtom); // Check if target comment belongs to this message (including replies) const hasTargetComment = useMemo(() => { @@ -263,6 +260,8 @@ export const AssistantMessage: FC = () => { }; const AssistantActionBar: FC = () => { + const { isLast } = useMessage(); + return ( { - - - - - + {/* Only allow regenerating the last assistant message */} + {isLast && ( + + + + + + )} ); }; diff --git a/surfsense_web/components/assistant-ui/branch-picker.tsx b/surfsense_web/components/assistant-ui/branch-picker.tsx deleted file mode 100644 index ee4addd2a..000000000 --- a/surfsense_web/components/assistant-ui/branch-picker.tsx +++ /dev/null @@ -1,32 +0,0 @@ -import { BranchPickerPrimitive } from "@assistant-ui/react"; -import { ChevronLeftIcon, ChevronRightIcon } from "lucide-react"; -import type { FC } from "react"; -import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; -import { cn } from "@/lib/utils"; - -export const BranchPicker: FC = ({ className, ...rest }) => { - return ( - - - - - - - - / - - - - - - - - ); -}; diff --git a/surfsense_web/components/assistant-ui/user-message.tsx b/surfsense_web/components/assistant-ui/user-message.tsx index 896b8c748..1ae8aef3c 100644 --- a/surfsense_web/components/assistant-ui/user-message.tsx +++ b/surfsense_web/components/assistant-ui/user-message.tsx @@ -4,7 +4,6 @@ import { FileText, PencilIcon } from "lucide-react"; import { type FC, useState } from "react"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; import { UserMessageAttachments } from "@/components/assistant-ui/attachment"; -import { BranchPicker } from "@/components/assistant-ui/branch-picker"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; interface AuthorMetadata { @@ -95,24 +94,47 @@ export const UserMessage: FC = () => { )} - - ); }; const UserActionBar: FC = () => { + const isThreadRunning = useAssistantState(({ thread }) => thread.isRunning); + + // Get current message ID + const currentMessageId = useAssistantState(({ message }) => message?.id); + + // Find the last user message ID in the thread (computed once, memoized by selector) + const lastUserMessageId = useAssistantState(({ thread }) => { + const messages = thread.messages; + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === "user") { + return messages[i].id; + } + } + return null; + }); + + // Simple comparison - no iteration needed per message + const isLastUserMessage = currentMessageId === lastUserMessageId; + + // Show edit button only on the last user message and when thread is not running + const canEdit = isLastUserMessage && !isThreadRunning; + return ( - - - - - + {/* Only allow editing the last user message */} + {canEdit && ( + + + + + + )} ); }; diff --git a/surfsense_web/components/dashboard-breadcrumb.tsx b/surfsense_web/components/dashboard-breadcrumb.tsx index 96bd0ef30..5c6399ce0 100644 --- a/surfsense_web/components/dashboard-breadcrumb.tsx +++ b/surfsense_web/components/dashboard-breadcrumb.tsx @@ -14,6 +14,7 @@ import { } from "@/components/ui/breadcrumb"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { authenticatedFetch, getBearerToken } from "@/lib/auth-utils"; +import { getThreadFull } from "@/lib/chat/thread-persistence"; import { cacheKeys } from "@/lib/query-client/cache-keys"; interface BreadcrumbItemInterface { @@ -34,6 +35,16 @@ export function DashboardBreadcrumb() { enabled: !!searchSpaceId, }); + // Extract chat thread ID from pathname for chat pages + const chatThreadId = segments[2] === "new-chat" && segments[3] ? segments[3] : null; + + // Fetch thread details when on a chat page with a thread ID + const { data: threadData } = useQuery({ + queryKey: ["threads", searchSpaceId, "detail", chatThreadId], + queryFn: () => getThreadFull(Number(chatThreadId)), + enabled: !!chatThreadId && !!searchSpaceId, + }); + // State to store document title for editor breadcrumb const [documentTitle, setDocumentTitle] = useState(null); @@ -144,10 +155,11 @@ export function DashboardBreadcrumb() { } // Handle new-chat sub-sections (thread IDs) - // Don't show thread ID in breadcrumb - users identify chats by content, not by ID + // Show the chat title if available, otherwise fall back to "Chat" if (section === "new-chat") { + const chatLabel = threadData?.title || t("chat") || "Chat"; breadcrumbs.push({ - label: t("chat") || "Chat", + label: chatLabel, }); return breadcrumbs; } diff --git a/surfsense_web/components/homepage/navbar.tsx b/surfsense_web/components/homepage/navbar.tsx index 4abd4031b..670e3c810 100644 --- a/surfsense_web/components/homepage/navbar.tsx +++ b/surfsense_web/components/homepage/navbar.tsx @@ -64,7 +64,7 @@ const DesktopNav = ({ navItems, isScrolled }: any) => { href="/" className="flex flex-1 flex-row items-center gap-0.5 hover:opacity-80 transition-opacity" > - + SurfSense
@@ -145,7 +145,7 @@ const MobileNav = ({ navItems, isScrolled }: any) => { href="/" className="flex flex-row items-center gap-2 hover:opacity-80 transition-opacity" > - + SurfSense + + + + + {/* Delete Search Space Dialog */} diff --git a/surfsense_web/components/layout/ui/shell/LayoutShell.tsx b/surfsense_web/components/layout/ui/shell/LayoutShell.tsx index 3624c90a3..8eae99b03 100644 --- a/surfsense_web/components/layout/ui/shell/LayoutShell.tsx +++ b/surfsense_web/components/layout/ui/shell/LayoutShell.tsx @@ -54,6 +54,7 @@ interface LayoutShellProps { activeChatId?: number | null; onNewChat: () => void; onChatSelect: (chat: ChatItem) => void; + onChatRename?: (chat: ChatItem) => void; onChatDelete?: (chat: ChatItem) => void; onChatArchive?: (chat: ChatItem) => void; onViewAllSharedChats?: () => void; @@ -90,6 +91,7 @@ export function LayoutShell({ activeChatId, onNewChat, onChatSelect, + onChatRename, onChatDelete, onChatArchive, onViewAllSharedChats, @@ -147,6 +149,7 @@ export function LayoutShell({ activeChatId={activeChatId} onNewChat={onNewChat} onChatSelect={onChatSelect} + onChatRename={onChatRename} onChatDelete={onChatDelete} onChatArchive={onChatArchive} onViewAllSharedChats={onViewAllSharedChats} @@ -215,6 +218,7 @@ export function LayoutShell({ activeChatId={activeChatId} onNewChat={onNewChat} onChatSelect={onChatSelect} + onChatRename={onChatRename} onChatDelete={onChatDelete} onChatArchive={onChatArchive} onViewAllSharedChats={onViewAllSharedChats} diff --git a/surfsense_web/components/layout/ui/sidebar/ChatListItem.tsx b/surfsense_web/components/layout/ui/sidebar/ChatListItem.tsx index 6db6782d0..ba2989145 100644 --- a/surfsense_web/components/layout/ui/sidebar/ChatListItem.tsx +++ b/surfsense_web/components/layout/ui/sidebar/ChatListItem.tsx @@ -1,6 +1,6 @@ "use client"; -import { ArchiveIcon, MessageSquare, MoreHorizontal, RotateCcwIcon, Trash2 } from "lucide-react"; +import { ArchiveIcon, MessageSquare, MoreHorizontal, PencilIcon, RotateCcwIcon, Trash2 } from "lucide-react"; import { useTranslations } from "next-intl"; import { Button } from "@/components/ui/button"; import { @@ -17,6 +17,7 @@ interface ChatListItemProps { isActive?: boolean; archived?: boolean; onClick?: () => void; + onRename?: () => void; onArchive?: () => void; onDelete?: () => void; } @@ -26,6 +27,7 @@ export function ChatListItem({ isActive, archived, onClick, + onRename, onArchive, onDelete, }: ChatListItemProps) { @@ -57,15 +59,26 @@ export function ChatListItem({ {t("more_options")} - - {onArchive && ( - { - e.stopPropagation(); - onArchive(); - }} - > - {archived ? ( + + {onRename && ( + { + e.stopPropagation(); + onRename(); + }} + > + + {t("rename") || "Rename"} + + )} + {onArchive && ( + { + e.stopPropagation(); + onArchive(); + }} + > + {archived ? ( <> {t("unarchive") || "Restore"} diff --git a/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx index 9ef49c0d8..f313dd6f9 100644 --- a/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx @@ -4,7 +4,6 @@ import { useAtom } from "jotai"; import { AlertCircle, AlertTriangle, - AtSign, BellDot, Check, CheckCheck, @@ -15,6 +14,7 @@ import { Inbox, LayoutGrid, ListFilter, + MessageSquare, Search, X, } from "lucide-react"; @@ -46,6 +46,7 @@ import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import { + isCommentReplyMetadata, isConnectorIndexingMetadata, isNewMentionMetadata, isPageLimitExceededMetadata, @@ -133,7 +134,7 @@ function getConnectorTypeDisplayName(connectorType: string): string { ); } -type InboxTab = "mentions" | "status"; +type InboxTab = "comments" | "status"; type InboxFilter = "all" | "unread"; // Tab-specific data source with independent pagination @@ -186,7 +187,7 @@ export function InboxSidebar({ const [, setTargetCommentId] = useAtom(setTargetCommentIdAtom); const [searchQuery, setSearchQuery] = useState(""); - const [activeTab, setActiveTab] = useState("mentions"); + const [activeTab, setActiveTab] = useState("comments"); const [activeFilter, setActiveFilter] = useState("all"); const [selectedConnector, setSelectedConnector] = useState(null); const [mounted, setMounted] = useState(false); @@ -233,12 +234,17 @@ export function InboxSidebar({ } }, [activeTab]); - // Get current tab's data source - each tab has independent data and pagination - const currentDataSource = activeTab === "mentions" ? mentions : status; - const { loading, loadingMore = false, hasMore = false, loadMore } = currentDataSource; + // Both tabs now derive items from status (all types), so use status for pagination + const { loading, loadingMore = false, hasMore = false, loadMore } = status; - // Status tab includes: connector indexing, document processing, page limit exceeded, connector deletion - // Filter to only show status notification types + // Comments tab: mentions and comment replies + const commentsItems = useMemo( + () => + status.items.filter((item) => item.type === "new_mention" || item.type === "comment_reply"), + [status.items] + ); + + // Status tab: connector indexing, document processing, page limit exceeded, connector deletion const statusItems = useMemo( () => status.items.filter( @@ -270,8 +276,8 @@ export function InboxSidebar({ })); }, [statusItems]); - // Get items for current tab - mentions use their source directly, status uses filtered items - const displayItems = activeTab === "mentions" ? mentions.items : statusItems; + // Get items for current tab + const displayItems = activeTab === "comments" ? commentsItems : statusItems; // Filter items based on filter type, connector filter, and search query const filteredItems = useMemo(() => { @@ -334,9 +340,15 @@ export function InboxSidebar({ return () => observer.disconnect(); }, [loadMore, hasMore, loadingMore, open, searchQuery]); - // Use unread counts from data sources (more accurate than client-side counting) - const unreadMentionsCount = mentions.unreadCount; - const unreadStatusCount = status.unreadCount; + // Unread counts derived from filtered items + const unreadCommentsCount = useMemo( + () => commentsItems.filter((item) => !item.read).length, + [commentsItems] + ); + const unreadStatusCount = useMemo( + () => statusItems.filter((item) => !item.read).length, + [statusItems] + ); const handleItemClick = useCallback( async (item: InboxItem) => { @@ -347,19 +359,15 @@ export function InboxSidebar({ } if (item.type === "new_mention") { - // Use type guard for safe metadata access if (isNewMentionMetadata(item.metadata)) { const searchSpaceId = item.search_space_id; const threadId = item.metadata.thread_id; const commentId = item.metadata.comment_id; if (searchSpaceId && threadId) { - // Pre-set target comment ID before navigation - // This also ensures comments panel is not collapsed if (commentId) { setTargetCommentId(commentId); } - const url = commentId ? `/dashboard/${searchSpaceId}/new-chat/${threadId}?commentId=${commentId}` : `/dashboard/${searchSpaceId}/new-chat/${threadId}`; @@ -368,6 +376,24 @@ export function InboxSidebar({ router.push(url); } } + } else if (item.type === "comment_reply") { + if (isCommentReplyMetadata(item.metadata)) { + const searchSpaceId = item.search_space_id; + const threadId = item.metadata.thread_id; + const replyId = item.metadata.reply_id; + + if (searchSpaceId && threadId) { + if (replyId) { + setTargetCommentId(replyId); + } + const url = replyId + ? `/dashboard/${searchSpaceId}/new-chat/${threadId}?commentId=${replyId}` + : `/dashboard/${searchSpaceId}/new-chat/${threadId}`; + onOpenChange(false); + onCloseMobileSidebar?.(); + router.push(url); + } + } } else if (item.type === "page_limit_exceeded") { // Navigate to the upgrade/more-pages page if (isPageLimitExceededMetadata(item.metadata)) { @@ -411,24 +437,29 @@ export function InboxSidebar({ }; const getStatusIcon = (item: InboxItem) => { - // For mentions, show the author's avatar with initials fallback - if (item.type === "new_mention") { - // Use type guard for safe metadata access - if (isNewMentionMetadata(item.metadata)) { - const authorName = item.metadata.author_name; - const avatarUrl = item.metadata.author_avatar_url; - const authorEmail = item.metadata.author_email; + // For mentions and comment replies, show the author's avatar + if (item.type === "new_mention" || item.type === "comment_reply") { + const metadata = + item.type === "new_mention" + ? isNewMentionMetadata(item.metadata) + ? item.metadata + : null + : isCommentReplyMetadata(item.metadata) + ? item.metadata + : null; + if (metadata) { return ( - {avatarUrl && } + {metadata.author_avatar_url && ( + + )} - {getInitials(authorName, authorEmail)} + {getInitials(metadata.author_name, metadata.author_email)} ); } - // Fallback for invalid metadata return ( @@ -481,10 +512,10 @@ export function InboxSidebar({ }; const getEmptyStateMessage = () => { - if (activeTab === "mentions") { + if (activeTab === "comments") { return { - title: t("no_mentions") || "No mentions", - hint: t("no_mentions_hint") || "You'll see mentions from others here", + title: t("no_comments") || "No comments", + hint: t("no_comments_hint") || "You'll see mentions and replies here", }; } return { @@ -823,14 +854,14 @@ export function InboxSidebar({ > - - {t("mentions") || "Mentions"} + + {t("comments") || "Comments"} - {formatInboxCount(unreadMentionsCount)} + {formatInboxCount(unreadCommentsCount)} @@ -932,8 +963,8 @@ export function InboxSidebar({
) : (
- {activeTab === "mentions" ? ( - + {activeTab === "comments" ? ( + ) : ( )} diff --git a/surfsense_web/components/layout/ui/sidebar/MobileSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/MobileSidebar.tsx index 85f907611..3ed2f9cca 100644 --- a/surfsense_web/components/layout/ui/sidebar/MobileSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/MobileSidebar.tsx @@ -24,6 +24,7 @@ interface MobileSidebarProps { activeChatId?: number | null; onNewChat: () => void; onChatSelect: (chat: ChatItem) => void; + onChatRename?: (chat: ChatItem) => void; onChatDelete?: (chat: ChatItem) => void; onChatArchive?: (chat: ChatItem) => void; onViewAllSharedChats?: () => void; @@ -64,6 +65,7 @@ export function MobileSidebar({ activeChatId, onNewChat, onChatSelect, + onChatRename, onChatDelete, onChatArchive, onViewAllSharedChats, @@ -142,6 +144,7 @@ export function MobileSidebar({ onOpenChange(false); }} onChatSelect={handleChatSelect} + onChatRename={onChatRename} onChatDelete={onChatDelete} onChatArchive={onChatArchive} onViewAllSharedChats={onViewAllSharedChats} diff --git a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx index db04bf6dc..8763056ed 100644 --- a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx @@ -25,6 +25,7 @@ interface SidebarProps { activeChatId?: number | null; onNewChat: () => void; onChatSelect: (chat: ChatItem) => void; + onChatRename?: (chat: ChatItem) => void; onChatDelete?: (chat: ChatItem) => void; onChatArchive?: (chat: ChatItem) => void; onViewAllSharedChats?: () => void; @@ -51,6 +52,7 @@ export function Sidebar({ activeChatId, onNewChat, onChatSelect, + onChatRename, onChatDelete, onChatArchive, onViewAllSharedChats, @@ -163,6 +165,7 @@ export function Sidebar({ isActive={chat.id === activeChatId} archived={chat.archived} onClick={() => onChatSelect(chat)} + onRename={() => onChatRename?.(chat)} onArchive={() => onChatArchive?.(chat)} onDelete={() => onChatDelete?.(chat)} /> @@ -215,6 +218,7 @@ export function Sidebar({ isActive={chat.id === activeChatId} archived={chat.archived} onClick={() => onChatSelect(chat)} + onRename={() => onChatRename?.(chat)} onArchive={() => onChatArchive?.(chat)} onDelete={() => onChatDelete?.(chat)} /> diff --git a/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx b/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx index 7c96b1dcb..38b3028d2 100644 --- a/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx +++ b/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx @@ -1,7 +1,8 @@ "use client"; -import { Check, ChevronUp, Languages, Laptop, LogOut, Moon, Settings, Sun } from "lucide-react"; +import { Check, ChevronUp, Languages, Laptop, Loader2, LogOut, Moon, Settings, Sun } from "lucide-react"; import { useTranslations } from "next-intl"; +import { useState } from "react"; import { DropdownMenu, DropdownMenuContent, @@ -124,6 +125,7 @@ export function SidebarUserProfile({ }: SidebarUserProfileProps) { const t = useTranslations("sidebar"); const { locale, setLocale } = useLocaleContext(); + const [isLoggingOut, setIsLoggingOut] = useState(false); const bgColor = stringToColor(user.email); const initials = getInitials(user.email); const displayName = user.name || user.email.split("@")[0]; @@ -136,6 +138,16 @@ export function SidebarUserProfile({ setTheme?.(newTheme); }; + const handleLogout = async () => { + if (isLoggingOut || !onLogout) return; + setIsLoggingOut(true); + try { + await onLogout(); + } finally { + setIsLoggingOut(false); + } + }; + // Collapsed view - just show avatar with dropdown if (isCollapsed) { return ( @@ -242,9 +254,13 @@ export function SidebarUserProfile({ - - - {t("logout")} + + {isLoggingOut ? ( + + ) : ( + + )} + {isLoggingOut ? t("loggingOut") : t("logout")} @@ -360,9 +376,13 @@ export function SidebarUserProfile({ - - - {t("logout")} + + {isLoggingOut ? ( + + ) : ( + + )} + {isLoggingOut ? t("loggingOut") : t("logout")} diff --git a/surfsense_web/components/new-chat/chat-header.tsx b/surfsense_web/components/new-chat/chat-header.tsx index a6cf8df3a..8a8fa11a0 100644 --- a/surfsense_web/components/new-chat/chat-header.tsx +++ b/surfsense_web/components/new-chat/chat-header.tsx @@ -2,9 +2,13 @@ import { useCallback, useState } from "react"; import type { + GlobalImageGenConfig, GlobalNewLLMConfig, + ImageGenerationConfig, NewLLMConfigPublic, } from "@/contracts/types/new-llm-config.types"; +import { ImageConfigSidebar } from "./image-config-sidebar"; +import { ImageModelSelector } from "./image-model-selector"; import { ModelConfigSidebar } from "./model-config-sidebar"; import { ModelSelector } from "./model-selector"; @@ -13,6 +17,7 @@ interface ChatHeaderProps { } export function ChatHeader({ searchSpaceId }: ChatHeaderProps) { + // LLM config sidebar state const [sidebarOpen, setSidebarOpen] = useState(false); const [selectedConfig, setSelectedConfig] = useState< NewLLMConfigPublic | GlobalNewLLMConfig | null @@ -20,6 +25,15 @@ export function ChatHeader({ searchSpaceId }: ChatHeaderProps) { const [isGlobal, setIsGlobal] = useState(false); const [sidebarMode, setSidebarMode] = useState<"create" | "edit" | "view">("view"); + // Image config sidebar state + const [imageSidebarOpen, setImageSidebarOpen] = useState(false); + const [selectedImageConfig, setSelectedImageConfig] = useState< + ImageGenerationConfig | GlobalImageGenConfig | null + >(null); + const [isImageGlobal, setIsImageGlobal] = useState(false); + const [imageSidebarMode, setImageSidebarMode] = useState<"create" | "edit" | "view">("view"); + + // LLM handlers const handleEditConfig = useCallback( (config: NewLLMConfigPublic | GlobalNewLLMConfig, global: boolean) => { setSelectedConfig(config); @@ -39,15 +53,36 @@ export function ChatHeader({ searchSpaceId }: ChatHeaderProps) { const handleSidebarClose = useCallback((open: boolean) => { setSidebarOpen(open); - if (!open) { - // Reset state when closing - setSelectedConfig(null); - } + if (!open) setSelectedConfig(null); + }, []); + + // Image model handlers + const handleAddImageModel = useCallback(() => { + setSelectedImageConfig(null); + setIsImageGlobal(false); + setImageSidebarMode("create"); + setImageSidebarOpen(true); + }, []); + + const handleEditImageConfig = useCallback( + (config: ImageGenerationConfig | GlobalImageGenConfig, global: boolean) => { + setSelectedImageConfig(config); + setIsImageGlobal(global); + setImageSidebarMode(global ? "view" : "edit"); + setImageSidebarOpen(true); + }, + [] + ); + + const handleImageSidebarClose = useCallback((open: boolean) => { + setImageSidebarOpen(open); + if (!open) setSelectedImageConfig(null); }, []); return (
+ +
); } diff --git a/surfsense_web/components/new-chat/chat-share-button.tsx b/surfsense_web/components/new-chat/chat-share-button.tsx index fa05f44c1..2e04fa3ba 100644 --- a/surfsense_web/components/new-chat/chat-share-button.tsx +++ b/surfsense_web/components/new-chat/chat-share-button.tsx @@ -1,8 +1,9 @@ "use client"; -import { useQueryClient } from "@tanstack/react-query"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; import { useAtomValue, useSetAtom } from "jotai"; import { Globe, User, Users } from "lucide-react"; +import { useParams, useRouter } from "next/navigation"; import { useCallback, useMemo, useState } from "react"; import { toast } from "sonner"; import { currentThreadAtom, setThreadVisibilityAtom } from "@/atoms/chat/current-thread.atom"; @@ -11,6 +12,7 @@ import { createPublicChatSnapshotMutationAtom } from "@/atoms/public-chat-snapsh import { Button } from "@/components/ui/button"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { chatThreadsApiService } from "@/lib/apis/chat-threads-api.service"; import { type ChatVisibility, type ThreadRecord, @@ -46,6 +48,8 @@ const visibilityOptions: { export function ChatShareButton({ thread, onVisibilityChange, className }: ChatShareButtonProps) { const queryClient = useQueryClient(); + const router = useRouter(); + const params = useParams(); const [open, setOpen] = useState(false); // Use Jotai atom for visibility (single source of truth) @@ -65,6 +69,16 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS return access.permissions?.includes("public_sharing:create") ?? false; }, [access]); + // Query to check if thread has public snapshots + const { data: snapshotsData } = useQuery({ + queryKey: ["thread-snapshots", thread?.id], + queryFn: () => chatThreadsApiService.listPublicChatSnapshots({ thread_id: thread!.id }), + enabled: !!thread?.id, + staleTime: 30000, // Cache for 30 seconds + }); + const hasPublicSnapshots = (snapshotsData?.snapshots?.length ?? 0) > 0; + const snapshotCount = snapshotsData?.snapshots?.length ?? 0; + // Use Jotai visibility if available (synced from chat page), otherwise fall back to thread prop const currentVisibility = currentThreadState.visibility ?? thread?.visibility ?? "PRIVATE"; @@ -106,11 +120,13 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS try { await createSnapshot({ thread_id: thread.id }); + // Refetch snapshots to show the globe indicator + await queryClient.invalidateQueries({ queryKey: ["thread-snapshots", thread.id] }); setOpen(false); } catch (error) { console.error("Failed to create public link:", error); } - }, [thread, createSnapshot]); + }, [thread, createSnapshot, queryClient]); // Don't show if no thread (new chat that hasn't been created yet) if (!thread) { @@ -121,112 +137,131 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS const buttonLabel = currentVisibility === "PRIVATE" ? "Private" : "Shared"; return ( - - - - - - - - Share settings - - - e.preventDefault()} - > -
- {/* Visibility Options */} - {visibilityOptions.map((option) => { - const isSelected = currentVisibility === option.value; - const Icon = option.icon; - - return ( - + + + Share settings + + + e.preventDefault()} + > +
+ {/* Visibility Options */} + {visibilityOptions.map((option) => { + const isSelected = currentVisibility === option.value; + const Icon = option.icon; + + return ( +
-
-
- - {option.label} - + > +
-

- {option.description} -

-
- - ); - })} - - {canCreatePublicLink && ( - <> - {/* Divider */} -
- - {/* Public Link Option */} - - - )} -
-
- + + ); + })} + + {canCreatePublicLink && ( + <> + {/* Divider */} +
+ + {/* Public Link Option */} + + + )} +
+ + + + {/* Globe indicator when public snapshots exist - clicks to settings */} + {hasPublicSnapshots && ( + + + + + + {snapshotCount === 1 + ? "This chat has a public link" + : `This chat has ${snapshotCount} public links`} + + + )} +
); } diff --git a/surfsense_web/components/new-chat/image-config-sidebar.tsx b/surfsense_web/components/new-chat/image-config-sidebar.tsx new file mode 100644 index 000000000..18f98acb7 --- /dev/null +++ b/surfsense_web/components/new-chat/image-config-sidebar.tsx @@ -0,0 +1,522 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { + AlertCircle, + Check, + ChevronsUpDown, + Globe, + ImageIcon, + Key, + Shuffle, + X, + Zap, +} from "lucide-react"; +import { AnimatePresence, motion } from "motion/react"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { createPortal } from "react-dom"; +import { toast } from "sonner"; +import { + createImageGenConfigMutationAtom, + updateImageGenConfigMutationAtom, +} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Separator } from "@/components/ui/separator"; +import { Spinner } from "@/components/ui/spinner"; +import { IMAGE_GEN_MODELS, IMAGE_GEN_PROVIDERS } from "@/contracts/enums/image-gen-providers"; +import type { + GlobalImageGenConfig, + ImageGenerationConfig, +} from "@/contracts/types/new-llm-config.types"; +import { cn } from "@/lib/utils"; + +interface ImageConfigSidebarProps { + open: boolean; + onOpenChange: (open: boolean) => void; + config: ImageGenerationConfig | GlobalImageGenConfig | null; + isGlobal: boolean; + searchSpaceId: number; + mode: "create" | "edit" | "view"; +} + +const INITIAL_FORM = { + name: "", + description: "", + provider: "", + model_name: "", + api_key: "", + api_base: "", + api_version: "", +}; + +export function ImageConfigSidebar({ + open, + onOpenChange, + config, + isGlobal, + searchSpaceId, + mode, +}: ImageConfigSidebarProps) { + const [isSubmitting, setIsSubmitting] = useState(false); + const [mounted, setMounted] = useState(false); + const [formData, setFormData] = useState(INITIAL_FORM); + const [modelComboboxOpen, setModelComboboxOpen] = useState(false); + + useEffect(() => { + setMounted(true); + }, []); + + // Reset form when opening + useEffect(() => { + if (open) { + if (mode === "edit" && config && !isGlobal) { + setFormData({ + name: config.name || "", + description: config.description || "", + provider: config.provider || "", + model_name: config.model_name || "", + api_key: (config as ImageGenerationConfig).api_key || "", + api_base: config.api_base || "", + api_version: config.api_version || "", + }); + } else if (mode === "create") { + setFormData(INITIAL_FORM); + } + } + }, [open, mode, config, isGlobal]); + + // Mutations + const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom); + const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + + // Escape key + useEffect(() => { + const handleEscape = (e: KeyboardEvent) => { + if (e.key === "Escape" && open) onOpenChange(false); + }; + window.addEventListener("keydown", handleEscape); + return () => window.removeEventListener("keydown", handleEscape); + }, [open, onOpenChange]); + + const isAutoMode = config && "is_auto_mode" in config && config.is_auto_mode; + + const suggestedModels = useMemo(() => { + if (!formData.provider) return []; + return IMAGE_GEN_MODELS.filter((m) => m.provider === formData.provider); + }, [formData.provider]); + + const getTitle = () => { + if (mode === "create") return "Add Image Model"; + if (isAutoMode) return "Auto Mode (Load Balanced)"; + if (isGlobal) return "View Global Image Model"; + return "Edit Image Model"; + }; + + const handleSubmit = useCallback(async () => { + setIsSubmitting(true); + try { + if (mode === "create") { + const result = await createConfig({ + name: formData.name, + provider: formData.provider, + model_name: formData.model_name, + api_key: formData.api_key, + api_base: formData.api_base || undefined, + api_version: formData.api_version || undefined, + description: formData.description || undefined, + search_space_id: searchSpaceId, + }); + // Set as active image model + if (result?.id) { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { image_generation_config_id: result.id }, + }); + } + toast.success("Image model created and assigned!"); + onOpenChange(false); + } else if (!isGlobal && config) { + await updateConfig({ + id: config.id, + data: { + name: formData.name, + description: formData.description || undefined, + provider: formData.provider, + model_name: formData.model_name, + api_key: formData.api_key, + api_base: formData.api_base || undefined, + api_version: formData.api_version || undefined, + }, + }); + toast.success("Image model updated!"); + onOpenChange(false); + } + } catch (error) { + console.error("Failed to save image config:", error); + toast.error("Failed to save image model"); + } finally { + setIsSubmitting(false); + } + }, [mode, isGlobal, config, formData, searchSpaceId, createConfig, updateConfig, updatePreferences, onOpenChange]); + + const handleUseGlobalConfig = useCallback(async () => { + if (!config || !isGlobal) return; + setIsSubmitting(true); + try { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { image_generation_config_id: config.id }, + }); + toast.success(`Now using ${config.name}`); + onOpenChange(false); + } catch (error) { + console.error("Failed to set image model:", error); + toast.error("Failed to set image model"); + } finally { + setIsSubmitting(false); + } + }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); + + const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key; + const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider); + + if (!mounted) return null; + + const sidebarContent = ( + + {open && ( + <> + {/* Backdrop */} + onOpenChange(false)} + /> + + {/* Sidebar */} + + {/* Header */} +
+
+
+ {isAutoMode ? ( + + ) : ( + + )} +
+
+

{getTitle()}

+
+ {isAutoMode ? ( + + + Recommended + + ) : isGlobal ? ( + + + Global + + ) : null} + {config && !isAutoMode && ( + {config.model_name} + )} +
+
+
+ +
+ + {/* Content */} +
+
+ {/* Auto mode */} + {isAutoMode && ( + <> + + + + Auto mode distributes image generation requests across all configured providers for optimal performance and rate limit protection. + + +
+ + +
+ + )} + + {/* Global config (read-only) */} + {isGlobal && !isAutoMode && config && ( + <> + + + + Global configurations are read-only. To customize, create a new model. + + +
+
+
+
Name
+

{config.name}

+
+ {config.description && ( +
+
Description
+

{config.description}

+
+ )} +
+ +
+
+
Provider
+

{config.provider}

+
+
+
Model
+

{config.model_name}

+
+
+
+
+ + +
+ + )} + + {/* Create / Edit form */} + {(mode === "create" || (mode === "edit" && !isGlobal)) && ( +
+ {/* Name */} +
+ + setFormData((p) => ({ ...p, name: e.target.value }))} + /> +
+ + {/* Description */} +
+ + setFormData((p) => ({ ...p, description: e.target.value }))} + /> +
+ + + + {/* Provider */} +
+ + +
+ + {/* Model Name */} +
+ + {suggestedModels.length > 0 ? ( + + + + + + + setFormData((p) => ({ ...p, model_name: val }))} + /> + + + Type a custom model name + + + {suggestedModels.map((m) => ( + { + setFormData((p) => ({ ...p, model_name: m.value })); + setModelComboboxOpen(false); + }} + > + + {m.value} + {m.label} + + ))} + + + + + + ) : ( + setFormData((p) => ({ ...p, model_name: e.target.value }))} + /> + )} +
+ + {/* API Key */} +
+ + setFormData((p) => ({ ...p, api_key: e.target.value }))} + /> +
+ + {/* API Base */} +
+ + setFormData((p) => ({ ...p, api_base: e.target.value }))} + /> +
+ + {/* Azure API Version */} + {formData.provider === "AZURE_OPENAI" && ( +
+ + setFormData((p) => ({ ...p, api_version: e.target.value }))} + /> +
+ )} + + {/* Actions */} +
+ + +
+
+ )} +
+
+
+ + )} +
+ ); + + return typeof document !== "undefined" ? createPortal(sidebarContent, document.body) : null; +} diff --git a/surfsense_web/components/new-chat/image-model-selector.tsx b/surfsense_web/components/new-chat/image-model-selector.tsx new file mode 100644 index 000000000..b3422b264 --- /dev/null +++ b/surfsense_web/components/new-chat/image-model-selector.tsx @@ -0,0 +1,364 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { + Check, + ChevronDown, + ChevronRight, + Edit3, + Globe, + ImageIcon, + Plus, + Shuffle, + User, +} from "lucide-react"; +import { useCallback, useMemo, useState } from "react"; +import { toast } from "sonner"; +import { + createImageGenConfigMutationAtom, + updateImageGenConfigMutationAtom, +} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; +import { + globalImageGenConfigsAtom, + imageGenConfigsAtom, +} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { llmPreferencesAtom } from "@/atoms/new-llm-config/new-llm-config-query.atoms"; +import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, + CommandSeparator, +} from "@/components/ui/command"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { Spinner } from "@/components/ui/spinner"; +import type { + GlobalImageGenConfig, + ImageGenerationConfig, +} from "@/contracts/types/new-llm-config.types"; +import { cn } from "@/lib/utils"; + +interface ImageModelSelectorProps { + className?: string; + onAddNew?: () => void; + onEdit?: (config: ImageGenerationConfig | GlobalImageGenConfig, isGlobal: boolean) => void; +} + +export function ImageModelSelector({ className, onAddNew, onEdit }: ImageModelSelectorProps) { + const [open, setOpen] = useState(false); + const [searchQuery, setSearchQuery] = useState(""); + + const { data: globalConfigs, isLoading: globalLoading } = + useAtomValue(globalImageGenConfigsAtom); + const { data: userConfigs, isLoading: userLoading } = useAtomValue(imageGenConfigsAtom); + const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom); + const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + + const isLoading = globalLoading || userLoading || prefsLoading; + + const currentConfig = useMemo(() => { + if (!preferences) return null; + const id = preferences.image_generation_config_id; + if (id === null || id === undefined) return null; + const globalMatch = globalConfigs?.find((c) => c.id === id); + if (globalMatch) return globalMatch; + return userConfigs?.find((c) => c.id === id) ?? null; + }, [preferences, globalConfigs, userConfigs]); + + const isCurrentAutoMode = useMemo(() => { + return currentConfig && "is_auto_mode" in currentConfig && currentConfig.is_auto_mode; + }, [currentConfig]); + + const filteredGlobal = useMemo(() => { + if (!globalConfigs) return []; + if (!searchQuery) return globalConfigs; + const q = searchQuery.toLowerCase(); + return globalConfigs.filter( + (c) => + c.name.toLowerCase().includes(q) || + c.model_name.toLowerCase().includes(q) || + c.provider.toLowerCase().includes(q) + ); + }, [globalConfigs, searchQuery]); + + const filteredUser = useMemo(() => { + if (!userConfigs) return []; + if (!searchQuery) return userConfigs; + const q = searchQuery.toLowerCase(); + return userConfigs.filter( + (c) => + c.name.toLowerCase().includes(q) || + c.model_name.toLowerCase().includes(q) || + c.provider.toLowerCase().includes(q) + ); + }, [userConfigs, searchQuery]); + + const totalModels = (globalConfigs?.length ?? 0) + (userConfigs?.length ?? 0); + + const handleSelect = useCallback( + async (configId: number) => { + if (currentConfig?.id === configId) { + setOpen(false); + return; + } + if (!searchSpaceId) { + toast.error("No search space selected"); + return; + } + try { + await updatePreferences({ + search_space_id: Number(searchSpaceId), + data: { image_generation_config_id: configId }, + }); + toast.success("Image model updated"); + setOpen(false); + } catch { + toast.error("Failed to switch image model"); + } + }, + [currentConfig, searchSpaceId, updatePreferences] + ); + + // Don't render if no configs at all + if (!isLoading && totalModels === 0) { + return ( + + ); + } + + return ( + + + + + + + + {totalModels > 3 && ( +
+ +
+ )} + + +
+ +

No image models found

+
+
+ + {/* Global Image Gen Configs */} + {filteredGlobal.length > 0 && ( + +
+ + Global Image Models +
+ {filteredGlobal.map((config) => { + const isSelected = currentConfig?.id === config.id; + const isAuto = "is_auto_mode" in config && config.is_auto_mode; + return ( + handleSelect(config.id)} + className={cn( + "mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50", + isSelected && "bg-accent/80", + isAuto && "border border-violet-200 dark:border-violet-800/50" + )} + > +
+
+ {isAuto ? ( + + ) : ( + + )} +
+
+
+ {config.name} + {isAuto && ( + + Recommended + + )} + {isSelected && } +
+ + {isAuto ? "Auto load balancing" : config.model_name} + +
+ {onEdit && ( + { + e.stopPropagation(); + setOpen(false); + onEdit(config, true); + }} + /> + )} +
+
+ ); + })} +
+ )} + + {/* User Image Gen Configs */} + {filteredUser.length > 0 && ( + <> + {filteredGlobal.length > 0 && } + +
+ + Your Image Models +
+ {filteredUser.map((config) => { + const isSelected = currentConfig?.id === config.id; + return ( + handleSelect(config.id)} + className={cn( + "mx-2 rounded-lg mb-1 cursor-pointer group transition-all hover:bg-accent/50", + isSelected && "bg-accent/80" + )} + > +
+
+ +
+
+
+ {config.name} + {isSelected && ( + + )} +
+ + {config.model_name} + +
+ {onEdit && ( + + )} +
+
+ ); + })} +
+ + )} + + {/* Add New */} + {onAddNew && ( +
+ +
+ )} +
+
+
+
+ ); +} diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index ec1143e04..148028df2 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -392,8 +392,8 @@ export function ModelSelector({ onEdit, onAddNew, className }: ModelSelectorProp )} - {/* Add New Config Button */} -
+ {/* Add New Config Button */} +

- {isNaN(Number(plan.price)) ? "" : isMonthly ? "billed monthly" : "billed annually"} + {plan.billingText ?? (isNaN(Number(plan.price)) ? "" : isMonthly ? "billed monthly" : "billed annually")}

    diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index fdad0796a..117be15ec 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -4,44 +4,47 @@ import { Pricing } from "@/components/pricing"; const demoPlans = [ { - name: "COMMUNITY", + name: "FREE", price: "0", yearlyPrice: "0", - period: "forever", + period: "", + billingText: "Includes 30 day PRO trial", features: [ - "Community support", - "Supports 100+ LLMs", - "Supports OpenAI spec and LiteLLM", - "Supports local vLLM or Ollama setups", - "6000+ embedding models", + "Open source on GitHub", + "Upload and chat with 300+ pages of content", + "Connects with 8 popular sources, like Drive and Notion.", + "Includes limited access to ChatGPT, Claude, and DeepSeek models", + "Supports 100+ more LLMs, including Gemini, Llama and many more.", "50+ File extensions supported.", - "Podcasts support with local TTS providers.", - "Connects with 15+ external sources, like Drive and Notion.", + "Generate podcasts in seconds.", "Cross-Browser Extension for dynamic webpages including authenticated content", - "Role-based access control (RBAC)", - "Collaboration and team features", + "Community support on Discord", ], - description: "Open source version with powerful features", - buttonText: "Dive In", - href: "/docs", + description: "Powerful features with some limitations", + buttonText: "Get Started", + href: "/", isPopular: false, }, { - name: "CLOUD", - price: "0", - yearlyPrice: "0", - period: "in beta", + name: "PRO", + price: "10", + yearlyPrice: "10", + period: "user / month", + billingText: "billed annually", features: [ - "Everything in Community", - "Email support", - "Get started in seconds", - "Instant access to new features", - "Easy access from anywhere", - "Remote team management and collaboration", + "Everything in Free", + "Upload and chat with 5,000+ pages of content", + "Connects with 15+ external sources, like Slack and Airtable.", + "Includes extended access to ChatGPT, Claude, and DeepSeek models", + "Collaboration and commenting features", + "Shared BYOK (Bring Your Own Key)", + "Team and role management", + "Planned: Centralized billing", + "Priority support", ], - description: "Instant access for individuals and teams", - buttonText: "Get Started", - href: "/", + description: "The AIknowledge base for individuals and teams", + buttonText: "Upgrade", + href: "/contact", isPopular: true, }, { @@ -49,18 +52,21 @@ const demoPlans = [ price: "Contact Us", yearlyPrice: "Contact Us", period: "", + billingText: "", features: [ - "Everything in Community", - "Priority support", + "Everything in Pro", + "Connect and chat with virtually unlimited pages of content", + "Limit models and/or providers", + "On-prem or VPC deployment", + "Planned: Audit logs and compliance", + "Planned: SSO, OIDC & SAML", + "Planned: Role-based access control (RBAC)", "White-glove setup and deployment", "Monthly managed updates and maintenance", - "On-prem or VPC deployment", - "Audit logs and compliance", - "SSO, OIDC & SAML", - "SLA guarantee", - "Uptime guarantee on VPC", + "SLA commitments", + "Dedicated support", ], - description: "Professional, customized setup for large organizations", + description: "Customized setup for large organizations", buttonText: "Contact Sales", href: "/contact", isPopular: false, diff --git a/surfsense_web/components/public-chat-snapshots/public-chat-snapshot-row.tsx b/surfsense_web/components/public-chat-snapshots/public-chat-snapshot-row.tsx index 696d32466..5f0048100 100644 --- a/surfsense_web/components/public-chat-snapshots/public-chat-snapshot-row.tsx +++ b/surfsense_web/components/public-chat-snapshots/public-chat-snapshot-row.tsx @@ -38,6 +38,13 @@ export function PublicChatSnapshotRow({ {snapshot.message_count}
+ (e.target as HTMLInputElement).select()} + />
+
+ + {/* Errors */} + + {errors.map((err) => ( + + + + {err?.message} + + + ))} + + + {/* Global info */} + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( + + + + + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length} global image model(s) + {" "} + available from your administrator. + + + )} + + {/* Active Preference Card */} + {!isLoading && allConfigs.length > 0 && ( + + + +
+
+ +
+
+ Active Image Model + + Select which model to use for image generation + +
+
+
+ + + {hasPrefChanges && ( +
+ + +
+ )} +
+
+
+ )} + + {/* Loading */} + {isLoading && ( + + + + + + )} + + {/* User Configs */} + {!isLoading && ( +
+
+

Your Image Models

+ +
+ + {(userConfigs?.length ?? 0) === 0 ? ( + + +
+ +
+

No Image Models Yet

+

+ Add your own image generation model (DALL-E 3, GPT Image 1, etc.) +

+ +
+
+ ) : ( + + + {userConfigs?.map((config) => ( + + + +
+
+
+
+
+
+ +
+
+
+

{config.name}

+ + {config.provider} + +
+ + {config.model_name} + + {config.description && ( +

{config.description}

+ )} +
+ + {new Date(config.created_at).toLocaleDateString()} +
+
+
+
+ + + + + + Edit + + + + + + + + Delete + + +
+
+
+
+ + + + ))} + + + )} +
+ )} + + {/* Create/Edit Dialog */} + { if (!open) { setIsDialogOpen(false); setEditingConfig(null); resetForm(); } }}> + + + + {editingConfig ? : } + {editingConfig ? "Edit Image Model" : "Add Image Model"} + + + {editingConfig ? "Update your image generation model" : "Configure a new image generation model (DALL-E 3, GPT Image 1, etc.)"} + + + +
+ {/* Name */} +
+ + setFormData((p) => ({ ...p, name: e.target.value }))} + /> +
+ + {/* Description */} +
+ + setFormData((p) => ({ ...p, description: e.target.value }))} + /> +
+ + + + {/* Provider */} +
+ + +
+ + {/* Model Name */} +
+ + {suggestedModels.length > 0 ? ( + + + + + + + setFormData((p) => ({ ...p, model_name: val }))} + /> + + + Type a custom model name + + + {suggestedModels.map((m) => ( + { + setFormData((p) => ({ ...p, model_name: m.value })); + setModelComboboxOpen(false); + }} + > + + {m.value} + {m.label} + + ))} + + + + + + ) : ( + setFormData((p) => ({ ...p, model_name: e.target.value }))} + /> + )} +
+ + {/* API Key */} +
+ + setFormData((p) => ({ ...p, api_key: e.target.value }))} + /> +
+ + {/* API Base (optional) */} +
+ + setFormData((p) => ({ ...p, api_base: e.target.value }))} + /> +
+ + {/* API Version (Azure) */} + {formData.provider === "AZURE_OPENAI" && ( +
+ + setFormData((p) => ({ ...p, api_version: e.target.value }))} + /> +
+ )} + + {/* Actions */} +
+ + +
+
+
+
+ + {/* Delete Confirmation */} + !open && setConfigToDelete(null)}> + + + + + Delete Image Model + + + Are you sure you want to delete {configToDelete?.name}? + + + + Cancel + + {isDeleting ? <>Deleting : <>Delete} + + + + +
+ ); +} diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx index dac68a358..22e3d8e08 100644 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -255,15 +255,15 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { )} - {/* Role Assignment Cards */} - {availableConfigs.length > 0 && ( -
- {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { - const IconComponent = role.icon; - const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments]; - const assignedConfig = availableConfigs.find( - (config) => config.id === currentAssignment - ); + {/* Role Assignment Cards */} + {availableConfigs.length > 0 && ( +
+ {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { + const IconComponent = role.icon; + const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments]; + const assignedConfig = availableConfigs.find( + (config) => config.id === currentAssignment + ); return ( -
- - handleRoleAssignment(`${key}_llm_id`, value)} + > + + + + + + Unassigned + - {/* Global Configurations */} - {globalConfigs.length > 0 && ( - <> -
- Global Configurations -
- {globalConfigs.map((config) => { - const isAutoMode = - "is_auto_mode" in config && config.is_auto_mode; - return ( - -
- {isAutoMode ? ( - - - AUTO - - ) : ( - - {config.provider} - - )} - {config.name} - {!isAutoMode && ( - - ({config.model_name}) - - )} - {isAutoMode ? ( - - Recommended - - ) : ( - - 🌐 Global - - )} -
-
- ); - })} - - )} + {/* Global Configurations */} + {globalConfigs.length > 0 && ( + <> +
+ Global Configurations +
+ {globalConfigs.map((config) => { + const isAutoMode = + "is_auto_mode" in config && config.is_auto_mode; + return ( + +
+ {isAutoMode ? ( + + + AUTO + + ) : ( + + {config.provider} + + )} + {config.name} + {!isAutoMode && ( + + ({config.model_name}) + + )} + {isAutoMode ? ( + + Recommended + + ) : ( + + 🌐 Global + + )} +
+
+ ); + })} + + )} - {/* Custom Configurations */} - {newLLMConfigs.length > 0 && ( - <> -
- Your Configurations -
- {newLLMConfigs - .filter( - (config) => config.id && config.id.toString().trim() !== "" - ) - .map((config) => ( - -
- - {config.provider} - - {config.name} - - ({config.model_name}) - -
-
- ))} - - )} -
- -
+ {/* Custom Configurations */} + {newLLMConfigs.length > 0 && ( + <> +
+ Your Configurations +
+ {newLLMConfigs + .filter( + (config) => config.id && config.id.toString().trim() !== "" + ) + .map((config) => ( + +
+ + {config.provider} + + {config.name} + + ({config.model_name}) + +
+
+ ))} + + )} + + +
{assignedConfig && (
{ - switch (position) { - case "top-left": - return { cx: "0", cy: "0" }; - case "top-right": - return { cx: "40", cy: "0" }; - case "bottom-left": - return { cx: "0", cy: "40" }; - case "bottom-right": - return { cx: "40", cy: "40" }; - case "top-center": - return { cx: "20", cy: "0" }; - case "bottom-center": - return { cx: "20", cy: "40" }; - case "bottom-up": - case "top-down": - case "left-right": - case "right-left": - return { cx: "20", cy: "20" }; - } + switch (position) { + case "top-left": + return { cx: "0", cy: "0" }; + case "top-right": + return { cx: "40", cy: "0" }; + case "bottom-left": + return { cx: "0", cy: "40" }; + case "bottom-right": + return { cx: "40", cy: "40" }; + case "top-center": + return { cx: "20", cy: "0" }; + case "bottom-center": + return { cx: "20", cy: "40" }; + case "bottom-up": + case "top-down": + case "left-right": + case "right-left": + return { cx: "20", cy: "20" }; + } }; const generateSVG = (variant: AnimationVariant, start: AnimationStart) => { - if (variant === "circle-blur") { - if (start === "center") { - return `data:image/svg+xml,`; - } - const positionCoords = getPositionCoords(start); - if (!positionCoords) { - throw new Error(`Invalid start position: ${start}`); - } - const { cx, cy } = positionCoords; - return `data:image/svg+xml,`; - } + if (variant === "circle-blur") { + if (start === "center") { + return `data:image/svg+xml,`; + } + const positionCoords = getPositionCoords(start); + if (!positionCoords) { + throw new Error(`Invalid start position: ${start}`); + } + const { cx, cy } = positionCoords; + return `data:image/svg+xml,`; + } - if (start === "center") return; + if (start === "center") return; - if (variant === "rectangle") return ""; + if (variant === "rectangle") return ""; - const positionCoords = getPositionCoords(start); - if (!positionCoords) { - throw new Error(`Invalid start position: ${start}`); - } - const { cx, cy } = positionCoords; + const positionCoords = getPositionCoords(start); + if (!positionCoords) { + throw new Error(`Invalid start position: ${start}`); + } + const { cx, cy } = positionCoords; - if (variant === "circle") { - return `data:image/svg+xml,`; - } + if (variant === "circle") { + return `data:image/svg+xml,`; + } - return ""; + return ""; }; const getTransformOrigin = (start: AnimationStart) => { - switch (start) { - case "top-left": - return "top left"; - case "top-right": - return "top right"; - case "bottom-left": - return "bottom left"; - case "bottom-right": - return "bottom right"; - case "top-center": - return "top center"; - case "bottom-center": - return "bottom center"; - case "bottom-up": - case "top-down": - case "left-right": - case "right-left": - return "center"; - } + switch (start) { + case "top-left": + return "top left"; + case "top-right": + return "top right"; + case "bottom-left": + return "bottom left"; + case "bottom-right": + return "bottom right"; + case "top-center": + return "top center"; + case "bottom-center": + return "bottom center"; + case "bottom-up": + case "top-down": + case "left-right": + case "right-left": + return "center"; + } }; export const createAnimation = ( - variant: AnimationVariant, - start: AnimationStart = "center", - blur = false, - url?: string, + variant: AnimationVariant, + start: AnimationStart = "center", + blur = false, + url?: string ): Animation => { - const svg = generateSVG(variant, start); - const transformOrigin = getTransformOrigin(start); + const svg = generateSVG(variant, start); + const transformOrigin = getTransformOrigin(start); - if (variant === "rectangle") { - const getClipPath = (direction: AnimationStart) => { - switch (direction) { - case "bottom-up": - return { - from: "polygon(0% 100%, 100% 100%, 100% 100%, 0% 100%)", - to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", - }; - case "top-down": - return { - from: "polygon(0% 0%, 100% 0%, 100% 0%, 0% 0%)", - to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", - }; - case "left-right": - return { - from: "polygon(0% 0%, 0% 0%, 0% 100%, 0% 100%)", - to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", - }; - case "right-left": - return { - from: "polygon(100% 0%, 100% 0%, 100% 100%, 100% 100%)", - to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", - }; - case "top-left": - return { - from: "polygon(0% 0%, 0% 0%, 0% 0%, 0% 0%)", - to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", - }; - case "top-right": - return { - from: "polygon(100% 0%, 100% 0%, 100% 0%, 100% 0%)", - to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", - }; - case "bottom-left": - return { - from: "polygon(0% 100%, 0% 100%, 0% 100%, 0% 100%)", - to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", - }; - case "bottom-right": - return { - from: "polygon(100% 100%, 100% 100%, 100% 100%, 100% 100%)", - to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", - }; - default: - return { - from: "polygon(0% 100%, 100% 100%, 100% 100%, 0% 100%)", - to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", - }; - } - }; + if (variant === "rectangle") { + const getClipPath = (direction: AnimationStart) => { + switch (direction) { + case "bottom-up": + return { + from: "polygon(0% 100%, 100% 100%, 100% 100%, 0% 100%)", + to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", + }; + case "top-down": + return { + from: "polygon(0% 0%, 100% 0%, 100% 0%, 0% 0%)", + to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", + }; + case "left-right": + return { + from: "polygon(0% 0%, 0% 0%, 0% 100%, 0% 100%)", + to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", + }; + case "right-left": + return { + from: "polygon(100% 0%, 100% 0%, 100% 100%, 100% 100%)", + to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", + }; + case "top-left": + return { + from: "polygon(0% 0%, 0% 0%, 0% 0%, 0% 0%)", + to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", + }; + case "top-right": + return { + from: "polygon(100% 0%, 100% 0%, 100% 0%, 100% 0%)", + to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", + }; + case "bottom-left": + return { + from: "polygon(0% 100%, 0% 100%, 0% 100%, 0% 100%)", + to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", + }; + case "bottom-right": + return { + from: "polygon(100% 100%, 100% 100%, 100% 100%, 100% 100%)", + to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", + }; + default: + return { + from: "polygon(0% 100%, 100% 100%, 100% 100%, 0% 100%)", + to: "polygon(0% 0%, 100% 0%, 100% 100%, 0% 100%)", + }; + } + }; - const clipPath = getClipPath(start); + const clipPath = getClipPath(start); - return { - name: `${variant}-${start}${blur ? "-blur" : ""}`, - css: ` + return { + name: `${variant}-${start}${blur ? "-blur" : ""}`, + css: ` ::view-transition-group(root) { animation-duration: 0.7s; animation-timing-function: var(--expo-out); @@ -218,12 +213,12 @@ export const createAnimation = ( } } `, - }; - } - if (variant === "circle" && start == "center") { - return { - name: `${variant}-${start}${blur ? "-blur" : ""}`, - css: ` + }; + } + if (variant === "circle" && start == "center") { + return { + name: `${variant}-${start}${blur ? "-blur" : ""}`, + css: ` ::view-transition-group(root) { animation-duration: 0.7s; animation-timing-function: var(--expo-out); @@ -268,12 +263,12 @@ export const createAnimation = ( } } `, - }; - } - if (variant === "gif") { - return { - name: `${variant}-${start}`, - css: ` + }; + } + if (variant === "gif") { + return { + name: `${variant}-${start}`, + css: ` ::view-transition-group(root) { animation-timing-function: var(--expo-in); } @@ -302,14 +297,14 @@ export const createAnimation = ( mask-size: 2000vmax; } }`, - }; - } + }; + } - if (variant === "circle-blur") { - if (start === "center") { - return { - name: `${variant}-${start}`, - css: ` + if (variant === "circle-blur") { + if (start === "center") { + return { + name: `${variant}-${start}`, + css: ` ::view-transition-group(root) { animation-timing-function: var(--expo-out); } @@ -334,12 +329,12 @@ export const createAnimation = ( } } `, - }; - } + }; + } - return { - name: `${variant}-${start}`, - css: ` + return { + name: `${variant}-${start}`, + css: ` ::view-transition-group(root) { animation-timing-function: var(--expo-out); } @@ -364,41 +359,41 @@ export const createAnimation = ( } } `, - }; - } + }; + } - if (variant === "polygon") { - const getPolygonClipPaths = (position: AnimationStart) => { - switch (position) { - case "top-left": - return { - darkFrom: "polygon(50% -71%, -50% 71%, -50% 71%, 50% -71%)", - darkTo: "polygon(50% -71%, -50% 71%, 50% 171%, 171% 50%)", - lightFrom: "polygon(171% 50%, 50% 171%, 50% 171%, 171% 50%)", - lightTo: "polygon(171% 50%, 50% 171%, -50% 71%, 50% -71%)", - }; - case "top-right": - return { - darkFrom: "polygon(150% -71%, 250% 71%, 250% 71%, 150% -71%)", - darkTo: "polygon(150% -71%, 250% 71%, 50% 171%, -71% 50%)", - lightFrom: "polygon(-71% 50%, 50% 171%, 50% 171%, -71% 50%)", - lightTo: "polygon(-71% 50%, 50% 171%, 250% 71%, 150% -71%)", - }; - default: - return { - darkFrom: "polygon(50% -71%, -50% 71%, -50% 71%, 50% -71%)", - darkTo: "polygon(50% -71%, -50% 71%, 50% 171%, 171% 50%)", - lightFrom: "polygon(171% 50%, 50% 171%, 50% 171%, 171% 50%)", - lightTo: "polygon(171% 50%, 50% 171%, -50% 71%, 50% -71%)", - }; - } - }; + if (variant === "polygon") { + const getPolygonClipPaths = (position: AnimationStart) => { + switch (position) { + case "top-left": + return { + darkFrom: "polygon(50% -71%, -50% 71%, -50% 71%, 50% -71%)", + darkTo: "polygon(50% -71%, -50% 71%, 50% 171%, 171% 50%)", + lightFrom: "polygon(171% 50%, 50% 171%, 50% 171%, 171% 50%)", + lightTo: "polygon(171% 50%, 50% 171%, -50% 71%, 50% -71%)", + }; + case "top-right": + return { + darkFrom: "polygon(150% -71%, 250% 71%, 250% 71%, 150% -71%)", + darkTo: "polygon(150% -71%, 250% 71%, 50% 171%, -71% 50%)", + lightFrom: "polygon(-71% 50%, 50% 171%, 50% 171%, -71% 50%)", + lightTo: "polygon(-71% 50%, 50% 171%, 250% 71%, 150% -71%)", + }; + default: + return { + darkFrom: "polygon(50% -71%, -50% 71%, -50% 71%, 50% -71%)", + darkTo: "polygon(50% -71%, -50% 71%, 50% 171%, 171% 50%)", + lightFrom: "polygon(171% 50%, 50% 171%, 50% 171%, 171% 50%)", + lightTo: "polygon(171% 50%, 50% 171%, -50% 71%, 50% -71%)", + }; + } + }; - const clipPaths = getPolygonClipPaths(start); + const clipPaths = getPolygonClipPaths(start); - return { - name: `${variant}-${start}${blur ? "-blur" : ""}`, - css: ` + return { + name: `${variant}-${start}${blur ? "-blur" : ""}`, + css: ` ::view-transition-group(root) { animation-duration: 0.7s; animation-timing-function: var(--expo-out); @@ -443,35 +438,35 @@ export const createAnimation = ( } } `, - }; - } + }; + } - // Handle circle variants with start positions using clip-path - if (variant === "circle" && start !== "center") { - const getClipPathPosition = (position: AnimationStart) => { - switch (position) { - case "top-left": - return "0% 0%"; - case "top-right": - return "100% 0%"; - case "bottom-left": - return "0% 100%"; - case "bottom-right": - return "100% 100%"; - case "top-center": - return "50% 0%"; - case "bottom-center": - return "50% 100%"; - default: - return "50% 50%"; - } - }; + // Handle circle variants with start positions using clip-path + if (variant === "circle" && start !== "center") { + const getClipPathPosition = (position: AnimationStart) => { + switch (position) { + case "top-left": + return "0% 0%"; + case "top-right": + return "100% 0%"; + case "bottom-left": + return "0% 100%"; + case "bottom-right": + return "100% 100%"; + case "top-center": + return "50% 0%"; + case "bottom-center": + return "50% 100%"; + default: + return "50% 50%"; + } + }; - const clipPosition = getClipPathPosition(start); + const clipPosition = getClipPathPosition(start); - return { - name: `${variant}-${start}${blur ? "-blur" : ""}`, - css: ` + return { + name: `${variant}-${start}${blur ? "-blur" : ""}`, + css: ` ::view-transition-group(root) { animation-duration: 1s; animation-timing-function: var(--expo-out); @@ -516,12 +511,12 @@ export const createAnimation = ( } } `, - }; - } + }; + } - return { - name: `${variant}-${start}${blur ? "-blur" : ""}`, - css: ` + return { + name: `${variant}-${start}${blur ? "-blur" : ""}`, + css: ` ::view-transition-group(root) { animation-timing-function: var(--expo-in); } @@ -549,237 +544,229 @@ export const createAnimation = ( } } `, - }; + }; }; // /////////////////////////////////////////////////////////////////////////// // Custom hook for theme toggle functionality export const useThemeToggle = ({ - variant = "circle", - start = "center", - blur = false, - gifUrl = "", + variant = "circle", + start = "center", + blur = false, + gifUrl = "", }: { - variant?: AnimationVariant; - start?: AnimationStart; - blur?: boolean; - gifUrl?: string; + variant?: AnimationVariant; + start?: AnimationStart; + blur?: boolean; + gifUrl?: string; } = {}) => { - const { theme, setTheme, resolvedTheme } = useTheme(); + const { theme, setTheme, resolvedTheme } = useTheme(); - const [isDark, setIsDark] = useState(false); + const [isDark, setIsDark] = useState(false); - // Sync isDark state with resolved theme after hydration - useEffect(() => { - setIsDark(resolvedTheme === "dark"); - }, [resolvedTheme]); + // Sync isDark state with resolved theme after hydration + useEffect(() => { + setIsDark(resolvedTheme === "dark"); + }, [resolvedTheme]); - const styleId = "theme-transition-styles"; + const styleId = "theme-transition-styles"; - const updateStyles = useCallback((css: string) => { - if (typeof window === "undefined") return; + const updateStyles = useCallback((css: string) => { + if (typeof window === "undefined") return; - let styleElement = document.getElementById(styleId) as HTMLStyleElement; + let styleElement = document.getElementById(styleId) as HTMLStyleElement; - if (!styleElement) { - styleElement = document.createElement("style"); - styleElement.id = styleId; - document.head.appendChild(styleElement); - } + if (!styleElement) { + styleElement = document.createElement("style"); + styleElement.id = styleId; + document.head.appendChild(styleElement); + } - styleElement.textContent = css; - }, []); + styleElement.textContent = css; + }, []); - const toggleTheme = useCallback(() => { - setIsDark(!isDark); + const toggleTheme = useCallback(() => { + setIsDark(!isDark); - const animation = createAnimation(variant, start, blur, gifUrl); + const animation = createAnimation(variant, start, blur, gifUrl); - updateStyles(animation.css); + updateStyles(animation.css); - if (typeof window === "undefined") return; + if (typeof window === "undefined") return; - const switchTheme = () => { - setTheme(theme === "light" ? "dark" : "light"); - }; + const switchTheme = () => { + setTheme(theme === "light" ? "dark" : "light"); + }; - if (!document.startViewTransition) { - switchTheme(); - return; - } + if (!document.startViewTransition) { + switchTheme(); + return; + } - document.startViewTransition(switchTheme); - }, [theme, setTheme, variant, start, blur, gifUrl, updateStyles, isDark]); + document.startViewTransition(switchTheme); + }, [theme, setTheme, variant, start, blur, gifUrl, updateStyles, isDark]); - const setCrazyLightTheme = useCallback(() => { - setIsDark(false); + const setCrazyLightTheme = useCallback(() => { + setIsDark(false); - const animation = createAnimation(variant, start, blur, gifUrl); + const animation = createAnimation(variant, start, blur, gifUrl); - updateStyles(animation.css); + updateStyles(animation.css); - if (typeof window === "undefined") return; + if (typeof window === "undefined") return; - const switchTheme = () => { - setTheme("light"); - }; + const switchTheme = () => { + setTheme("light"); + }; - if (!document.startViewTransition) { - switchTheme(); - return; - } + if (!document.startViewTransition) { + switchTheme(); + return; + } - document.startViewTransition(switchTheme); - }, [setTheme, variant, start, blur, gifUrl, updateStyles]); + document.startViewTransition(switchTheme); + }, [setTheme, variant, start, blur, gifUrl, updateStyles]); - const setCrazyDarkTheme = useCallback(() => { - setIsDark(true); + const setCrazyDarkTheme = useCallback(() => { + setIsDark(true); - const animation = createAnimation(variant, start, blur, gifUrl); + const animation = createAnimation(variant, start, blur, gifUrl); - updateStyles(animation.css); + updateStyles(animation.css); - if (typeof window === "undefined") return; + if (typeof window === "undefined") return; - const switchTheme = () => { - setTheme("dark"); - }; + const switchTheme = () => { + setTheme("dark"); + }; - if (!document.startViewTransition) { - switchTheme(); - return; - } + if (!document.startViewTransition) { + switchTheme(); + return; + } - document.startViewTransition(switchTheme); - }, [setTheme, variant, start, blur, gifUrl, updateStyles]); + document.startViewTransition(switchTheme); + }, [setTheme, variant, start, blur, gifUrl, updateStyles]); - const setCrazySystemTheme = useCallback(() => { - if (typeof window === "undefined") return; + const setCrazySystemTheme = useCallback(() => { + if (typeof window === "undefined") return; - const prefersDark = window.matchMedia( - "(prefers-color-scheme: dark)", - ).matches; - setIsDark(prefersDark); + const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches; + setIsDark(prefersDark); - const animation = createAnimation(variant, start, blur, gifUrl); + const animation = createAnimation(variant, start, blur, gifUrl); - updateStyles(animation.css); + updateStyles(animation.css); - const switchTheme = () => { - setTheme("system"); - }; + const switchTheme = () => { + setTheme("system"); + }; - if (!document.startViewTransition) { - switchTheme(); - return; - } + if (!document.startViewTransition) { + switchTheme(); + return; + } - document.startViewTransition(switchTheme); - }, [setTheme, variant, start, blur, gifUrl, updateStyles]); + document.startViewTransition(switchTheme); + }, [setTheme, variant, start, blur, gifUrl, updateStyles]); - return { - isDark, - setIsDark, - toggleTheme, - setCrazyLightTheme, - setCrazyDarkTheme, - setCrazySystemTheme, - }; + return { + isDark, + setIsDark, + toggleTheme, + setCrazyLightTheme, + setCrazyDarkTheme, + setCrazySystemTheme, + }; }; // /////////////////////////////////////////////////////////////////////////// // Theme Toggle Button Component (Sun/Moon Style) export const ThemeToggleButton = ({ - className = "", - variant = "circle", - start = "center", - blur = false, - gifUrl = "", + className = "", + variant = "circle", + start = "center", + blur = false, + gifUrl = "", }: { - className?: string; - variant?: AnimationVariant; - start?: AnimationStart; - blur?: boolean; - gifUrl?: string; + className?: string; + variant?: AnimationVariant; + start?: AnimationStart; + blur?: boolean; + gifUrl?: string; }) => { - const { isDark, toggleTheme } = useThemeToggle({ - variant, - start, - blur, - gifUrl, - }); - const clipId = useId(); - const clipPathId = `theme-toggle-clip-${clipId}`; + const { isDark, toggleTheme } = useThemeToggle({ + variant, + start, + blur, + gifUrl, + }); + const clipId = useId(); + const clipPathId = `theme-toggle-clip-${clipId}`; - return ( - - ); + return ( + + ); }; // /////////////////////////////////////////////////////////////////////////// // Backwards compatible export (alias for ThemeToggleButton with default settings) export function ThemeTogglerComponent() { - return ( - - ); + return ; } /** diff --git a/surfsense_web/components/tool-ui/display-image.tsx b/surfsense_web/components/tool-ui/display-image.tsx index 660e95bca..b5fccbc78 100644 --- a/surfsense_web/components/tool-ui/display-image.tsx +++ b/surfsense_web/components/tool-ui/display-image.tsx @@ -88,7 +88,7 @@ function ImageCancelledState({ src }: { src: string }) { function ParsedImage({ result }: { result: unknown }) { const image = parseSerializableImage(result); - return ; + return ; } /** diff --git a/surfsense_web/components/tool-ui/image/index.tsx b/surfsense_web/components/tool-ui/image/index.tsx index 42725d258..9ecfe4cfa 100644 --- a/surfsense_web/components/tool-ui/image/index.tsx +++ b/surfsense_web/components/tool-ui/image/index.tsx @@ -1,6 +1,6 @@ "use client"; -import { ExternalLinkIcon, ImageIcon } from "lucide-react"; +import { ExternalLinkIcon, ImageIcon, SparklesIcon } from "lucide-react"; import NextImage from "next/image"; import { Component, type ReactNode, useState } from "react"; import { z } from "zod"; @@ -25,7 +25,7 @@ const SerializableImageSchema = z.object({ id: z.string(), assetId: z.string(), src: z.string(), - alt: z.string().nullish(), // Made optional - will use fallback if missing + alt: z.string().nullish(), title: z.string().nullish(), description: z.string().nullish(), href: z.string().nullish(), @@ -49,7 +49,7 @@ export interface ImageProps { id: string; assetId: string; src: string; - alt?: string; // Optional with default fallback + alt?: string; title?: string; description?: string; href?: string; @@ -71,10 +71,8 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a if (!parsed.success) { console.warn("Invalid image data:", parsed.error.issues); - // Try to extract basic info and return a fallback object const obj = (result && typeof result === "object" ? result : {}) as Record; - // If we have at least id, assetId, and src, we can still render the image if ( typeof obj.id === "string" && typeof obj.assetId === "string" && @@ -89,7 +87,7 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a description: typeof obj.description === "string" ? obj.description : undefined, href: typeof obj.href === "string" ? obj.href : undefined, domain: typeof obj.domain === "string" ? obj.domain : undefined, - ratio: undefined, // Use default ratio + ratio: undefined, source: undefined, }; } @@ -97,7 +95,6 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a throw new Error(`Invalid image: ${parsed.error.issues.map((i) => i.message).join(", ")}`); } - // Provide fallback for alt if it's null/undefined return { ...parsed.data, alt: parsed.data.alt ?? "Image", @@ -105,7 +102,7 @@ export function parseSerializableImage(result: unknown): SerializableImage & { a } /** - * Get aspect ratio class based on ratio prop + * Get aspect ratio class based on ratio prop (used for fixed-ratio images only) */ function getAspectRatioClass(ratio?: AspectRatio): string { switch (ratio) { @@ -119,7 +116,6 @@ function getAspectRatioClass(ratio?: AspectRatio): string { return "aspect-[9/16]"; case "21:9": return "aspect-[21/9]"; - case "auto": default: return "aspect-[4/3]"; } @@ -150,7 +146,7 @@ export class ImageErrorBoundary extends Component< if (this.state.hasError) { return ( -
+

Failed to load image

@@ -167,10 +163,10 @@ export class ImageErrorBoundary extends Component< /** * Loading skeleton for Image */ -export function ImageSkeleton({ maxWidth = "420px" }: { maxWidth?: string }) { +export function ImageSkeleton({ maxWidth = "512px" }: { maxWidth?: string }) { return ( -
+
@@ -183,7 +179,7 @@ export function ImageSkeleton({ maxWidth = "420px" }: { maxWidth?: string }) { export function ImageLoading({ title = "Loading image..." }: { title?: string }) { return ( -
+

{title}

@@ -197,7 +193,9 @@ export function ImageLoading({ title = "Loading image..." }: { title?: string }) * Image Component * * Display images with metadata and attribution. - * Features hover overlay with title and source attribution. + * - For "auto" ratio: renders the image at natural dimensions (no cropping) + * - For fixed ratios: uses a fixed aspect container with object-cover + * - Features hover overlay with title, description, and source attribution. */ export function Image({ id, @@ -207,16 +205,18 @@ export function Image({ description, href, domain, - ratio = "4:3", + ratio = "auto", fit = "cover", source, - maxWidth = "420px", + maxWidth = "512px", className, }: ImageProps) { const [isHovered, setIsHovered] = useState(false); const [imageError, setImageError] = useState(false); - const aspectRatioClass = getAspectRatioClass(ratio); + const [imageLoaded, setImageLoaded] = useState(false); const displayDomain = domain || source?.label; + const isGenerated = domain === "ai-generated"; + const isAutoRatio = !ratio || ratio === "auto"; const handleClick = () => { const targetUrl = href || source?.url || src; @@ -228,7 +228,7 @@ export function Image({ if (imageError) { return ( -
+

Image not available

@@ -243,6 +243,7 @@ export function Image({ id={id} className={cn( "group w-full overflow-hidden cursor-pointer transition-shadow duration-200 hover:shadow-lg", + isGenerated && "ring-1 ring-primary/10", className )} style={{ maxWidth }} @@ -258,71 +259,98 @@ export function Image({ role="button" tabIndex={0} > -
- {/* Image */} - setImageError(true)} - /> +
+ {isAutoRatio ? ( + /* Auto ratio: image renders at natural dimensions, no cropping */ + <> + {!imageLoaded && ( +
+ +
+ )} + {/* eslint-disable-next-line @next/next/no-img-element */} + {alt} setImageLoaded(true)} + onError={() => setImageError(true)} + /> + + ) : ( + /* Fixed ratio: constrained aspect container with fill */ +
+ setImageError(true)} + /> +
+ )} - {/* Hover overlay - appears on hover */} + {/* Hover overlay */}
- {/* Content at bottom */} -
- {/* Title */} +
{title && ( -

+

{title}

)} - - {/* Description */} {description && ( -

{description}

+

{description}

)} - - {/* Source attribution */} {displayDomain && (
- {source?.iconUrl ? ( + {isGenerated ? ( + + ) : source?.iconUrl ? ( ) : ( - + )} - {displayDomain} + {displayDomain}
)}
- {/* Always visible domain badge (bottom right, shown when NOT hovered) */} + {/* Badge when not hovered */} {displayDomain && !isHovered && (
+ {isGenerated && } {displayDomain}
diff --git a/surfsense_web/contracts/enums/image-gen-providers.ts b/surfsense_web/contracts/enums/image-gen-providers.ts new file mode 100644 index 000000000..8410aeb4b --- /dev/null +++ b/surfsense_web/contracts/enums/image-gen-providers.ts @@ -0,0 +1,105 @@ +export interface ImageGenProvider { + value: string; + label: string; + example: string; + description: string; + apiBase?: string; +} + +/** + * Image generation providers supported by LiteLLM. + * See: https://docs.litellm.ai/docs/image_generation#supported-providers + */ +export const IMAGE_GEN_PROVIDERS: ImageGenProvider[] = [ + { + value: "OPENAI", + label: "OpenAI", + example: "dall-e-3, gpt-image-1, dall-e-2", + description: "DALL-E and GPT Image models", + }, + { + value: "AZURE_OPENAI", + label: "Azure OpenAI", + example: "azure/dall-e-3, azure/gpt-image-1", + description: "OpenAI image models on Azure", + }, + { + value: "GOOGLE", + label: "Google AI Studio", + example: "gemini/imagen-3.0-generate-002", + description: "Google AI Studio image generation", + }, + { + value: "VERTEX_AI", + label: "Google Vertex AI", + example: "vertex_ai/imagegeneration@006", + description: "Vertex AI image generation models", + }, + { + value: "BEDROCK", + label: "AWS Bedrock", + example: "bedrock/stability.stable-diffusion-xl-v0", + description: "Stable Diffusion on AWS Bedrock", + }, + { + value: "RECRAFT", + label: "Recraft", + example: "recraft/recraftv3", + description: "AI-powered design and image generation", + }, + { + value: "OPENROUTER", + label: "OpenRouter", + example: "openrouter/google/gemini-2.5-flash-image", + description: "Image generation via OpenRouter", + }, + { + value: "XINFERENCE", + label: "Xinference", + example: "xinference/stable-diffusion-xl", + description: "Self-hosted Stable Diffusion models", + }, + { + value: "NSCALE", + label: "Nscale", + example: "nscale/flux.1-schnell", + description: "Nscale image generation", + }, +]; + +/** + * Image generation models organized by provider. + */ +export interface ImageGenModel { + value: string; + label: string; + provider: string; +} + +export const IMAGE_GEN_MODELS: ImageGenModel[] = [ + // OpenAI + { value: "gpt-image-1", label: "GPT Image 1", provider: "OPENAI" }, + { value: "dall-e-3", label: "DALL-E 3", provider: "OPENAI" }, + { value: "dall-e-2", label: "DALL-E 2", provider: "OPENAI" }, + // Azure OpenAI + { value: "azure/dall-e-3", label: "DALL-E 3 (Azure)", provider: "AZURE_OPENAI" }, + { value: "azure/gpt-image-1", label: "GPT Image 1 (Azure)", provider: "AZURE_OPENAI" }, + // Recraft + { value: "recraft/recraftv3", label: "Recraft V3", provider: "RECRAFT" }, + // Bedrock + { + value: "bedrock/stability.stable-diffusion-xl-v0", + label: "Stable Diffusion XL", + provider: "BEDROCK", + }, + // Vertex AI + { + value: "vertex_ai/imagegeneration@006", + label: "Imagen 3", + provider: "VERTEX_AI", + }, +]; + +export function getImageGenModelsByProvider(provider: string): ImageGenModel[] { + return IMAGE_GEN_MODELS.filter((m) => m.provider === provider); +} diff --git a/surfsense_web/contracts/enums/llm-models.ts b/surfsense_web/contracts/enums/llm-models.ts index c62b2a9d6..5ff15c3df 100644 --- a/surfsense_web/contracts/enums/llm-models.ts +++ b/surfsense_web/contracts/enums/llm-models.ts @@ -178,6 +178,18 @@ export const LLM_MODELS: LLMModel[] = [ }, // Google (Gemini) + { + value: "gemini-3-flash-preview", + label: "Gemini 3 Flash", + provider: "GOOGLE", + contextWindow: "1M", + }, + { + value: "gemini-3-pro-preview", + label: "Gemini 3 Pro", + provider: "GOOGLE", + contextWindow: "1M", + }, { value: "gemini-2.5-flash", label: "Gemini 2.5 Flash", diff --git a/surfsense_web/contracts/types/inbox.types.ts b/surfsense_web/contracts/types/inbox.types.ts index 8e4b9ae86..ebf1889a1 100644 --- a/surfsense_web/contracts/types/inbox.types.ts +++ b/surfsense_web/contracts/types/inbox.types.ts @@ -10,6 +10,7 @@ export const inboxItemTypeEnum = z.enum([ "connector_deletion", "document_processing", "new_mention", + "comment_reply", "page_limit_exceeded", ]); @@ -101,6 +102,19 @@ export const newMentionMetadata = z.object({ content_preview: z.string(), }); +export const commentReplyMetadata = z.object({ + reply_id: z.number(), + parent_comment_id: z.number(), + message_id: z.number(), + thread_id: z.number(), + thread_title: z.string(), + author_id: z.string(), + author_name: z.string(), + author_avatar_url: z.string().nullable().optional(), + author_email: z.string().optional(), + content_preview: z.string(), +}); + /** * Page limit exceeded metadata schema */ @@ -125,6 +139,7 @@ export const inboxItemMetadata = z.union([ connectorDeletionMetadata, documentProcessingMetadata, newMentionMetadata, + commentReplyMetadata, pageLimitExceededMetadata, baseInboxItemMetadata, ]); @@ -168,6 +183,11 @@ export const newMentionInboxItem = inboxItem.extend({ metadata: newMentionMetadata, }); +export const commentReplyInboxItem = inboxItem.extend({ + type: z.literal("comment_reply"), + metadata: commentReplyMetadata, +}); + export const pageLimitExceededInboxItem = inboxItem.extend({ type: z.literal("page_limit_exceeded"), metadata: pageLimitExceededMetadata, @@ -278,6 +298,10 @@ export function isNewMentionMetadata(metadata: unknown): metadata is NewMentionM return newMentionMetadata.safeParse(metadata).success; } +export function isCommentReplyMetadata(metadata: unknown): metadata is CommentReplyMetadata { + return commentReplyMetadata.safeParse(metadata).success; +} + /** * Type guard for PageLimitExceededMetadata */ @@ -298,6 +322,7 @@ export function parseInboxItemMetadata( | ConnectorDeletionMetadata | DocumentProcessingMetadata | NewMentionMetadata + | CommentReplyMetadata | PageLimitExceededMetadata | null { switch (type) { @@ -317,6 +342,10 @@ export function parseInboxItemMetadata( const result = newMentionMetadata.safeParse(metadata); return result.success ? result.data : null; } + case "comment_reply": { + const result = commentReplyMetadata.safeParse(metadata); + return result.success ? result.data : null; + } case "page_limit_exceeded": { const result = pageLimitExceededMetadata.safeParse(metadata); return result.success ? result.data : null; @@ -338,6 +367,7 @@ export type ConnectorIndexingMetadata = z.infer; export type DocumentProcessingMetadata = z.infer; export type NewMentionMetadata = z.infer; +export type CommentReplyMetadata = z.infer; export type PageLimitExceededMetadata = z.infer; export type InboxItemMetadata = z.infer; export type InboxItem = z.infer; @@ -345,6 +375,7 @@ export type ConnectorIndexingInboxItem = z.infer; export type DocumentProcessingInboxItem = z.infer; export type NewMentionInboxItem = z.infer; +export type CommentReplyInboxItem = z.infer; export type PageLimitExceededInboxItem = z.infer; // API Request/Response types diff --git a/surfsense_web/contracts/types/new-llm-config.types.ts b/surfsense_web/contracts/types/new-llm-config.types.ts index f397d4f08..3f0d39e5a 100644 --- a/surfsense_web/contracts/types/new-llm-config.types.ts +++ b/surfsense_web/contracts/types/new-llm-config.types.ts @@ -161,19 +161,105 @@ export const globalNewLLMConfig = z.object({ export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig); +// ============================================================================= +// Image Generation Config (separate table from NewLLMConfig) +// ============================================================================= + +/** + * ImageGenProvider enum - only providers that support image generation + * See: https://docs.litellm.ai/docs/image_generation#supported-providers + */ +export const imageGenProviderEnum = z.enum([ + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + "XINFERENCE", + "NSCALE", +]); + +export type ImageGenProvider = z.infer; + +/** + * ImageGenerationConfig - user-created image gen model configs + * Separate from NewLLMConfig: no system_instructions, no citations_enabled. + */ +export const imageGenerationConfig = z.object({ + id: z.number(), + name: z.string().max(100), + description: z.string().max(500).nullable().optional(), + provider: imageGenProviderEnum, + custom_provider: z.string().max(100).nullable().optional(), + model_name: z.string().max(100), + api_key: z.string(), + api_base: z.string().max(500).nullable().optional(), + api_version: z.string().max(50).nullable().optional(), + litellm_params: z.record(z.string(), z.any()).nullable().optional(), + created_at: z.string(), + search_space_id: z.number(), +}); + +export const createImageGenConfigRequest = imageGenerationConfig.omit({ + id: true, + created_at: true, +}); + +export const createImageGenConfigResponse = imageGenerationConfig; + +export const getImageGenConfigsResponse = z.array(imageGenerationConfig); + +export const updateImageGenConfigRequest = z.object({ + id: z.number(), + data: imageGenerationConfig + .omit({ id: true, created_at: true, search_space_id: true }) + .partial(), +}); + +export const updateImageGenConfigResponse = imageGenerationConfig; + +export const deleteImageGenConfigResponse = z.object({ + message: z.string(), + id: z.number(), +}); + +/** + * Global Image Generation Config - from YAML, has negative IDs + * ID 0 is reserved for "Auto" mode (LiteLLM Router load balancing) + */ +export const globalImageGenConfig = z.object({ + id: z.number(), + name: z.string(), + description: z.string().nullable().optional(), + provider: z.string(), + custom_provider: z.string().nullable().optional(), + model_name: z.string(), + api_base: z.string().nullable().optional(), + api_version: z.string().nullable().optional(), + litellm_params: z.record(z.string(), z.any()).nullable().optional(), + is_global: z.literal(true), + is_auto_mode: z.boolean().optional().default(false), +}); + +export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig); + // ============================================================================= // LLM Preferences (Role Assignments) // ============================================================================= /** * LLM Preferences schemas - for role assignments - * The agent_llm and document_summary_llm fields contain the full NewLLMConfig objects + * image_generation uses image_generation_config_id (not llm_id) */ export const llmPreferences = z.object({ agent_llm_id: z.union([z.number(), z.null()]).optional(), document_summary_llm_id: z.union([z.number(), z.null()]).optional(), + image_generation_config_id: z.union([z.number(), z.null()]).optional(), agent_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(), document_summary_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(), + image_generation_config: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(), }); /** @@ -193,6 +279,7 @@ export const updateLLMPreferencesRequest = z.object({ data: llmPreferences.pick({ agent_llm_id: true, document_summary_llm_id: true, + image_generation_config_id: true, }), }); @@ -219,6 +306,15 @@ export type GetDefaultSystemInstructionsResponse = z.infer< >; export type GlobalNewLLMConfig = z.infer; export type GetGlobalNewLLMConfigsResponse = z.infer; +export type ImageGenerationConfig = z.infer; +export type CreateImageGenConfigRequest = z.infer; +export type CreateImageGenConfigResponse = z.infer; +export type GetImageGenConfigsResponse = z.infer; +export type UpdateImageGenConfigRequest = z.infer; +export type UpdateImageGenConfigResponse = z.infer; +export type DeleteImageGenConfigResponse = z.infer; +export type GlobalImageGenConfig = z.infer; +export type GetGlobalImageGenConfigsResponse = z.infer; export type LLMPreferences = z.infer; export type GetLLMPreferencesRequest = z.infer; export type GetLLMPreferencesResponse = z.infer; diff --git a/surfsense_web/hooks/use-api-key.ts b/surfsense_web/hooks/use-api-key.ts index a5f24d4c6..0c595b420 100644 --- a/surfsense_web/hooks/use-api-key.ts +++ b/surfsense_web/hooks/use-api-key.ts @@ -1,6 +1,7 @@ import { useCallback, useEffect, useState } from "react"; import { toast } from "sonner"; import { getBearerToken } from "@/lib/auth-utils"; +import { copyToClipboard as copyToClipboardUtil } from "@/lib/utils"; interface UseApiKeyReturn { apiKey: string | null; @@ -33,60 +34,17 @@ export function useApiKey(): UseApiKeyReturn { return () => clearTimeout(timer); }, []); - const fallbackCopyTextToClipboard = (text: string) => { - const textArea = document.createElement("textarea"); - textArea.value = text; - - // Avoid scrolling to bottom - textArea.style.top = "0"; - textArea.style.left = "0"; - textArea.style.position = "fixed"; - textArea.style.opacity = "0"; - - document.body.appendChild(textArea); - textArea.focus(); - textArea.select(); - - try { - const successful = document.execCommand("copy"); - document.body.removeChild(textArea); - - if (successful) { - setCopied(true); - toast.success("API key copied to clipboard"); - - setTimeout(() => { - setCopied(false); - }, 2000); - } else { - toast.error("Failed to copy API key"); - } - } catch (err) { - console.error("Fallback: Oops, unable to copy", err); - document.body.removeChild(textArea); - toast.error("Failed to copy API key"); - } - }; - const copyToClipboard = useCallback(async () => { if (!apiKey) return; - try { - if (navigator.clipboard && window.isSecureContext) { - // Use Clipboard API if available and in secure context - await navigator.clipboard.writeText(apiKey); - setCopied(true); - toast.success("API key copied to clipboard"); - - setTimeout(() => { - setCopied(false); - }, 2000); - } else { - // Fallback for non-secure contexts or browsers without clipboard API - fallbackCopyTextToClipboard(apiKey); - } - } catch (err) { - console.error("Failed to copy:", err); + const success = await copyToClipboardUtil(apiKey); + if (success) { + setCopied(true); + toast.success("API key copied to clipboard"); + setTimeout(() => { + setCopied(false); + }, 2000); + } else { toast.error("Failed to copy API key"); } }, [apiKey]); diff --git a/surfsense_web/lib/apis/base-api.service.ts b/surfsense_web/lib/apis/base-api.service.ts index a87d4deaf..933e54656 100644 --- a/surfsense_web/lib/apis/base-api.service.ts +++ b/surfsense_web/lib/apis/base-api.service.ts @@ -1,5 +1,5 @@ import type { ZodType } from "zod"; -import { getBearerToken, handleUnauthorized } from "../auth-utils"; +import { getBearerToken, handleUnauthorized, refreshAccessToken } from "../auth-utils"; import { AppError, AuthenticationError, AuthorizationError, NotFoundError } from "../error"; enum ResponseType { @@ -17,6 +17,7 @@ export type RequestOptions = { signal?: AbortSignal; body?: any; responseType?: ResponseType; + _isRetry?: boolean; // Internal flag to prevent infinite retry loops // Add more options as needed }; @@ -135,8 +136,23 @@ class BaseApiService { throw new AppError("Failed to parse response", response.status, response.statusText); } - // Handle 401 first before other error handling - ensures token is cleared and user redirected + // Handle 401 - try to refresh token first (only once) if (response.status === 401) { + if (!options?._isRetry) { + const newToken = await refreshAccessToken(); + if (newToken) { + // Retry the request with the new token + return this.request(url, responseSchema, { + ...mergedOptions, + headers: { + ...mergedOptions.headers, + Authorization: `Bearer ${newToken}`, + }, + _isRetry: true, + } as RequestOptions & { responseType?: R }); + } + } + // Refresh failed or retry failed, redirect to login handleUnauthorized(); throw new AuthenticationError( typeof data === "object" && "detail" in data diff --git a/surfsense_web/lib/apis/image-gen-config-api.service.ts b/surfsense_web/lib/apis/image-gen-config-api.service.ts new file mode 100644 index 000000000..84aeed3d8 --- /dev/null +++ b/surfsense_web/lib/apis/image-gen-config-api.service.ts @@ -0,0 +1,83 @@ +import { + type CreateImageGenConfigRequest, + createImageGenConfigRequest, + createImageGenConfigResponse, + type UpdateImageGenConfigRequest, + updateImageGenConfigRequest, + updateImageGenConfigResponse, + deleteImageGenConfigResponse, + getImageGenConfigsResponse, + getGlobalImageGenConfigsResponse, +} from "@/contracts/types/new-llm-config.types"; +import { ValidationError } from "../error"; +import { baseApiService } from "./base-api.service"; + +class ImageGenConfigApiService { + /** + * Get all global image generation configs (from YAML, negative IDs) + */ + getGlobalConfigs = async () => { + return baseApiService.get( + `/api/v1/global-image-generation-configs`, + getGlobalImageGenConfigsResponse + ); + }; + + /** + * Create a new image generation config for a search space + */ + createConfig = async (request: CreateImageGenConfigRequest) => { + const parsed = createImageGenConfigRequest.safeParse(request); + if (!parsed.success) { + const msg = parsed.error.issues.map((i) => i.message).join(", "); + throw new ValidationError(`Invalid request: ${msg}`); + } + return baseApiService.post( + `/api/v1/image-generation-configs`, + createImageGenConfigResponse, + { body: parsed.data } + ); + }; + + /** + * Get image generation configs for a search space + */ + getConfigs = async (searchSpaceId: number) => { + const params = new URLSearchParams({ + search_space_id: String(searchSpaceId), + }).toString(); + return baseApiService.get( + `/api/v1/image-generation-configs?${params}`, + getImageGenConfigsResponse + ); + }; + + /** + * Update an existing image generation config + */ + updateConfig = async (request: UpdateImageGenConfigRequest) => { + const parsed = updateImageGenConfigRequest.safeParse(request); + if (!parsed.success) { + const msg = parsed.error.issues.map((i) => i.message).join(", "); + throw new ValidationError(`Invalid request: ${msg}`); + } + const { id, data } = parsed.data; + return baseApiService.put( + `/api/v1/image-generation-configs/${id}`, + updateImageGenConfigResponse, + { body: data } + ); + }; + + /** + * Delete an image generation config + */ + deleteConfig = async (id: number) => { + return baseApiService.delete( + `/api/v1/image-generation-configs/${id}`, + deleteImageGenConfigResponse + ); + }; +} + +export const imageGenConfigApiService = new ImageGenConfigApiService(); diff --git a/surfsense_web/lib/auth-utils.ts b/surfsense_web/lib/auth-utils.ts index 604843292..8c067a4b7 100644 --- a/surfsense_web/lib/auth-utils.ts +++ b/surfsense_web/lib/auth-utils.ts @@ -4,6 +4,11 @@ const REDIRECT_PATH_KEY = "surfsense_redirect_path"; const BEARER_TOKEN_KEY = "surfsense_bearer_token"; +const REFRESH_TOKEN_KEY = "surfsense_refresh_token"; + +// Flag to prevent multiple simultaneous refresh attempts +let isRefreshing = false; +let refreshPromise: Promise | null = null; /** * Saves the current path and redirects to login page @@ -21,8 +26,9 @@ export function handleUnauthorized(): void { localStorage.setItem(REDIRECT_PATH_KEY, currentPath); } - // Clear the token + // Clear both tokens localStorage.removeItem(BEARER_TOKEN_KEY); + localStorage.removeItem(REFRESH_TOKEN_KEY); // Redirect to home page (which has login options) window.location.href = "/login"; @@ -66,6 +72,71 @@ export function clearBearerToken(): void { localStorage.removeItem(BEARER_TOKEN_KEY); } +/** + * Gets the refresh token from localStorage + */ +export function getRefreshToken(): string | null { + if (typeof window === "undefined") return null; + return localStorage.getItem(REFRESH_TOKEN_KEY); +} + +/** + * Sets the refresh token in localStorage + */ +export function setRefreshToken(token: string): void { + if (typeof window === "undefined") return; + localStorage.setItem(REFRESH_TOKEN_KEY, token); +} + +/** + * Clears the refresh token from localStorage + */ +export function clearRefreshToken(): void { + if (typeof window === "undefined") return; + localStorage.removeItem(REFRESH_TOKEN_KEY); +} + +/** + * Clears all auth tokens from localStorage + */ +export function clearAllTokens(): void { + clearBearerToken(); + clearRefreshToken(); +} + +/** + * Logout the current user by revoking the refresh token and clearing localStorage. + * Returns true if logout was successful (or tokens were cleared), false otherwise. + */ +export async function logout(): Promise { + const refreshToken = getRefreshToken(); + + // Call backend to revoke the refresh token + if (refreshToken) { + try { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + const response = await fetch(`${backendUrl}/auth/jwt/revoke`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ refresh_token: refreshToken }), + }); + + if (!response.ok) { + console.warn("Failed to revoke refresh token:", response.status, await response.text()); + } + } catch (error) { + console.warn("Failed to revoke refresh token on server:", error); + // Continue to clear local tokens even if server call fails + } + } + + // Clear all tokens from localStorage + clearAllTokens(); + return true; +} + /** * Checks if the user is authenticated (has a token) */ @@ -106,14 +177,67 @@ export function getAuthHeaders(additionalHeaders?: Record): Reco } /** - * Authenticated fetch wrapper that handles 401 responses uniformly - * Automatically redirects to login on 401 and saves the current path + * Attempts to refresh the access token using the stored refresh token. + * Returns the new access token if successful, null otherwise. + * Exported for use by API services. + */ +export async function refreshAccessToken(): Promise { + // If already refreshing, wait for that request to complete + if (isRefreshing && refreshPromise) { + return refreshPromise; + } + + const currentRefreshToken = getRefreshToken(); + if (!currentRefreshToken) { + return null; + } + + isRefreshing = true; + refreshPromise = (async () => { + try { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + const response = await fetch(`${backendUrl}/auth/jwt/refresh`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ refresh_token: currentRefreshToken }), + }); + + if (!response.ok) { + // Refresh failed, clear tokens + clearAllTokens(); + return null; + } + + const data = await response.json(); + if (data.access_token && data.refresh_token) { + setBearerToken(data.access_token); + setRefreshToken(data.refresh_token); + return data.access_token; + } + return null; + } catch { + return null; + } finally { + isRefreshing = false; + refreshPromise = null; + } + })(); + + return refreshPromise; +} + +/** + * Authenticated fetch wrapper that handles 401 responses uniformly. + * On 401, attempts to refresh the token and retry the request. + * If refresh fails, redirects to login and saves the current path. */ export async function authenticatedFetch( url: string, - options?: RequestInit & { skipAuthRedirect?: boolean } + options?: RequestInit & { skipAuthRedirect?: boolean; skipRefresh?: boolean } ): Promise { - const { skipAuthRedirect = false, ...fetchOptions } = options || {}; + const { skipAuthRedirect = false, skipRefresh = false, ...fetchOptions } = options || {}; const headers = getAuthHeaders(fetchOptions.headers as Record); @@ -124,6 +248,23 @@ export async function authenticatedFetch( // Handle 401 Unauthorized if (response.status === 401 && !skipAuthRedirect) { + // Try to refresh the token (unless skipRefresh is set to prevent infinite loops) + if (!skipRefresh) { + const newToken = await refreshAccessToken(); + if (newToken) { + // Retry the original request with the new token + const retryHeaders = { + ...(fetchOptions.headers as Record), + Authorization: `Bearer ${newToken}`, + }; + return fetch(url, { + ...fetchOptions, + headers: retryHeaders, + }); + } + } + + // Refresh failed or was skipped, redirect to login handleUnauthorized(); throw new Error("Unauthorized: Redirecting to login page"); } diff --git a/surfsense_web/lib/env-config.ts b/surfsense_web/lib/env-config.ts index 2f9e92357..e36aff10a 100644 --- a/surfsense_web/lib/env-config.ts +++ b/surfsense_web/lib/env-config.ts @@ -9,6 +9,8 @@ * as it may prevent the sed replacement from working correctly. */ +import packageJson from "../package.json"; + // Auth type: "LOCAL" for email/password, "GOOGLE" for OAuth // Placeholder: __NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__ export const AUTH_TYPE = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE || "GOOGLE"; @@ -28,6 +30,10 @@ export const ETL_SERVICE = process.env.NEXT_PUBLIC_ETL_SERVICE || "DOCLING"; // Placeholder: __NEXT_PUBLIC_DEPLOYMENT_MODE__ export const DEPLOYMENT_MODE = process.env.NEXT_PUBLIC_DEPLOYMENT_MODE || "self-hosted"; +// App version - defaults to package.json version +// Can be overridden at build time with NEXT_PUBLIC_APP_VERSION for full git tag version +export const APP_VERSION = process.env.NEXT_PUBLIC_APP_VERSION || packageJson.version; + // Helper to check if local auth is enabled export const isLocalAuth = () => AUTH_TYPE === "LOCAL"; diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 4d220a62a..c6981b28a 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -34,6 +34,11 @@ export const cacheKeys = { defaultInstructions: () => ["new-llm-configs", "default-instructions"] as const, global: () => ["new-llm-configs", "global"] as const, }, + imageGenConfigs: { + all: (searchSpaceId: number) => ["image-gen-configs", searchSpaceId] as const, + byId: (configId: number) => ["image-gen-configs", "detail", configId] as const, + global: () => ["image-gen-configs", "global"] as const, + }, auth: { user: ["auth", "user"] as const, }, diff --git a/surfsense_web/lib/utils.ts b/surfsense_web/lib/utils.ts index 212ff1259..e7bf8bdbe 100644 --- a/surfsense_web/lib/utils.ts +++ b/surfsense_web/lib/utils.ts @@ -12,3 +12,44 @@ export const formatDate = (date: Date): string => { day: "numeric", }); }; + +/** + * Copy text to clipboard with fallback for older browsers and non-secure contexts. + * Returns true if successful, false otherwise. + */ +export async function copyToClipboard(text: string): Promise { + // Use modern Clipboard API if available and in secure context + if (navigator.clipboard && window.isSecureContext) { + try { + await navigator.clipboard.writeText(text); + return true; + } catch (err) { + console.error("Clipboard API failed:", err); + return false; + } + } + + // Fallback for non-secure contexts or browsers without Clipboard API + const textArea = document.createElement("textarea"); + textArea.value = text; + + // Avoid scrolling to bottom + textArea.style.top = "0"; + textArea.style.left = "0"; + textArea.style.position = "fixed"; + textArea.style.opacity = "0"; + + document.body.appendChild(textArea); + textArea.focus(); + textArea.select(); + + try { + const successful = document.execCommand("copy"); + document.body.removeChild(textArea); + return successful; + } catch (err) { + console.error("Fallback copy failed:", err); + document.body.removeChild(textArea); + return false; + } +} diff --git a/surfsense_web/messages/en.json b/surfsense_web/messages/en.json index 75b186420..5a18f80c3 100644 --- a/surfsense_web/messages/en.json +++ b/surfsense_web/messages/en.json @@ -676,6 +676,13 @@ "unarchive": "Restore", "chat_archived": "Chat archived", "chat_unarchived": "Chat restored", + "chat_renamed": "Chat renamed", + "error_renaming_chat": "Failed to rename chat", + "rename": "Rename", + "rename_chat": "Rename Chat", + "rename_chat_description": "Enter a new name for this conversation.", + "chat_title_placeholder": "Chat title", + "renaming": "Renaming...", "no_archived_chats": "No archived chats", "error_archiving_chat": "Failed to archive chat", "new_chat": "New chat", @@ -693,15 +700,19 @@ "dark": "Dark", "system": "System", "logout": "Logout", + "loggingOut": "Logging out...", "inbox": "Inbox", "search_inbox": "Search inbox", "mark_all_read": "Mark all as read", "mark_as_read": "Mark as read", "mentions": "Mentions", + "comments": "Comments", "status": "Status", "no_results_found": "No results found", "no_mentions": "No mentions", "no_mentions_hint": "You'll see mentions from others here", + "no_comments": "No comments", + "no_comments_hint": "You'll see mentions and replies here", "no_status_updates": "No status updates", "no_status_updates_hint": "Document and connector updates will appear here", "filter": "Filter", @@ -729,6 +740,8 @@ "nav_agent_configs_desc": "LLM models with prompts & citations", "nav_role_assignments": "Role Assignments", "nav_role_assignments_desc": "Assign configs to agent roles", + "nav_image_models": "Image Models", + "nav_image_models_desc": "Configure image generation models", "nav_system_instructions": "System Instructions", "nav_system_instructions_desc": "SearchSpace-wide AI instructions", "nav_public_links": "Public Chat Links", diff --git a/surfsense_web/messages/zh.json b/surfsense_web/messages/zh.json index 81121ef3e..1046b7296 100644 --- a/surfsense_web/messages/zh.json +++ b/surfsense_web/messages/zh.json @@ -661,6 +661,13 @@ "unarchive": "恢复", "chat_archived": "对话已归档", "chat_unarchived": "对话已恢复", + "chat_renamed": "对话已重命名", + "error_renaming_chat": "重命名对话失败", + "rename": "重命名", + "rename_chat": "重命名对话", + "rename_chat_description": "为此对话输入新名称。", + "chat_title_placeholder": "对话标题", + "renaming": "重命名中...", "no_archived_chats": "暂无已归档对话", "error_archiving_chat": "归档对话失败", "new_chat": "新对话", @@ -678,15 +685,19 @@ "dark": "深色", "system": "系统", "logout": "退出登录", + "loggingOut": "正在退出...", "inbox": "收件箱", "search_inbox": "搜索收件箱", "mark_all_read": "全部标记为已读", "mark_as_read": "标记为已读", "mentions": "提及", + "comments": "评论", "status": "状态", "no_results_found": "未找到结果", "no_mentions": "没有提及", "no_mentions_hint": "您会在这里看到他人的提及", + "no_comments": "没有评论", + "no_comments_hint": "您会在这里看到提及和回复", "no_status_updates": "没有状态更新", "no_status_updates_hint": "文档和连接器更新将显示在这里", "filter": "筛选", @@ -714,6 +725,8 @@ "nav_agent_configs_desc": "LLM 模型配置提示词和引用", "nav_role_assignments": "角色分配", "nav_role_assignments_desc": "为代理角色分配配置", + "nav_image_models": "图像模型", + "nav_image_models_desc": "配置图像生成模型", "nav_system_instructions": "系统指令", "nav_system_instructions_desc": "搜索空间级别的 AI 指令", "nav_public_links": "公开聊天链接",