mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/documents
This commit is contained in:
commit
0fdd194d92
60 changed files with 5086 additions and 248 deletions
|
|
@ -143,6 +143,15 @@ STT_SERVICE=local/base
|
|||
PAGES_LIMIT=500
|
||||
|
||||
|
||||
# Residential Proxy Configuration (anonymous-proxies.net)
|
||||
# Used for web crawling, link previews, and YouTube transcript fetching to avoid IP bans.
|
||||
# Leave commented out to disable proxying.
|
||||
# RESIDENTIAL_PROXY_USERNAME=your_proxy_username
|
||||
# RESIDENTIAL_PROXY_PASSWORD=your_proxy_password
|
||||
# RESIDENTIAL_PROXY_HOSTNAME=rotating.dnsproxifier.com:31230
|
||||
# RESIDENTIAL_PROXY_LOCATION=
|
||||
# RESIDENTIAL_PROXY_TYPE=1
|
||||
|
||||
FIRECRAWL_API_KEY=fcr-01J0000000000000000000000
|
||||
|
||||
# File Parser Service
|
||||
|
|
|
|||
|
|
@ -0,0 +1,299 @@
|
|||
"""Add image generation tables and search space preference
|
||||
|
||||
Revision ID: 93
|
||||
Revises: 92
|
||||
|
||||
Changes:
|
||||
1. Create image_generation_configs table (user-created image model configs)
|
||||
2. Create image_generations table (stores generation requests/results)
|
||||
3. Add image_generation_config_id column to searchspaces table
|
||||
4. Add image generation permissions to existing system roles
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import ENUM as PG_ENUM, JSONB, UUID
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "93"
|
||||
down_revision: str | None = "92"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
connection = op.get_bind()
|
||||
|
||||
# 1. Create imagegenprovider enum type if it doesn't exist
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'imagegenprovider') THEN
|
||||
CREATE TYPE imagegenprovider AS ENUM (
|
||||
'OPENAI', 'AZURE_OPENAI', 'GOOGLE', 'VERTEX_AI', 'BEDROCK',
|
||||
'RECRAFT', 'OPENROUTER', 'XINFERENCE', 'NSCALE'
|
||||
);
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create image_generation_configs table (uses imagegenprovider enum)
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'image_generation_configs')"
|
||||
)
|
||||
)
|
||||
if not result.scalar():
|
||||
op.create_table(
|
||||
"image_generation_configs",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("name", sa.String(100), nullable=False),
|
||||
sa.Column("description", sa.String(500), nullable=True),
|
||||
sa.Column(
|
||||
"provider",
|
||||
PG_ENUM(
|
||||
"OPENAI",
|
||||
"AZURE_OPENAI",
|
||||
"GOOGLE",
|
||||
"VERTEX_AI",
|
||||
"BEDROCK",
|
||||
"RECRAFT",
|
||||
"OPENROUTER",
|
||||
"XINFERENCE",
|
||||
"NSCALE",
|
||||
name="imagegenprovider",
|
||||
create_type=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("custom_provider", sa.String(100), nullable=True),
|
||||
sa.Column("model_name", sa.String(100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(500), nullable=True),
|
||||
sa.Column("api_version", sa.String(50), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_image_generation_configs_name "
|
||||
"ON image_generation_configs (name)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_image_generation_configs_search_space_id "
|
||||
"ON image_generation_configs (search_space_id)"
|
||||
)
|
||||
|
||||
# 3. Create image_generations table
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'image_generations')"
|
||||
)
|
||||
)
|
||||
if not result.scalar():
|
||||
op.create_table(
|
||||
"image_generations",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("prompt", sa.Text(), nullable=False),
|
||||
sa.Column("model", sa.String(200), nullable=True),
|
||||
sa.Column("n", sa.Integer(), nullable=True),
|
||||
sa.Column("quality", sa.String(50), nullable=True),
|
||||
sa.Column("size", sa.String(50), nullable=True),
|
||||
sa.Column("style", sa.String(50), nullable=True),
|
||||
sa.Column("response_format", sa.String(50), nullable=True),
|
||||
sa.Column("image_generation_config_id", sa.Integer(), nullable=True),
|
||||
sa.Column("response_data", JSONB(), nullable=True),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_by_id", UUID(as_uuid=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["created_by_id"], ["user.id"], ondelete="SET NULL"
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_image_generations_search_space_id "
|
||||
"ON image_generations (search_space_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_image_generations_created_by_id "
|
||||
"ON image_generations (created_by_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_image_generations_created_at "
|
||||
"ON image_generations (created_at)"
|
||||
)
|
||||
|
||||
# 4. Add image_generation_config_id column to searchspaces
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces'
|
||||
AND column_name = 'image_generation_config_id'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if not result.scalar():
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column(
|
||||
"image_generation_config_id",
|
||||
sa.Integer(),
|
||||
nullable=True,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
|
||||
# Drop old column name if it exists (from earlier version of this migration)
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces'
|
||||
AND column_name = 'image_generation_llm_id'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if result.scalar():
|
||||
op.drop_column("searchspaces", "image_generation_llm_id")
|
||||
|
||||
# Drop old column name on image_generations if it exists
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'image_generations'
|
||||
AND column_name = 'llm_config_id'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if result.scalar():
|
||||
op.drop_column("image_generations", "llm_config_id")
|
||||
|
||||
# Drop old api_version column on image_generations if it exists
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'image_generations'
|
||||
AND column_name = 'api_version'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if result.scalar():
|
||||
op.drop_column("image_generations", "api_version")
|
||||
|
||||
# 5. Add image generation permissions to existing system roles
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_cat(
|
||||
permissions,
|
||||
ARRAY['image_generations:create', 'image_generations:read']
|
||||
)
|
||||
WHERE is_system_role = true
|
||||
AND name = 'Editor'
|
||||
AND NOT ('image_generations:create' = ANY(permissions))
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_cat(
|
||||
permissions,
|
||||
ARRAY['image_generations:read']
|
||||
)
|
||||
WHERE is_system_role = true
|
||||
AND name = 'Viewer'
|
||||
AND NOT ('image_generations:read' = ANY(permissions))
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
connection = op.get_bind()
|
||||
|
||||
# Remove permissions
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_remove(
|
||||
array_remove(
|
||||
array_remove(permissions, 'image_generations:create'),
|
||||
'image_generations:read'
|
||||
),
|
||||
'image_generations:delete'
|
||||
)
|
||||
WHERE is_system_role = true
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Remove image_generation_config_id from searchspaces
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces'
|
||||
AND column_name = 'image_generation_config_id'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if result.scalar():
|
||||
op.drop_column("searchspaces", "image_generation_config_id")
|
||||
|
||||
# Drop indexes and tables
|
||||
op.execute("DROP INDEX IF EXISTS ix_image_generations_created_at")
|
||||
op.execute("DROP INDEX IF EXISTS ix_image_generations_created_by_id")
|
||||
op.execute("DROP INDEX IF EXISTS ix_image_generations_search_space_id")
|
||||
op.execute("DROP TABLE IF EXISTS image_generations")
|
||||
|
||||
op.execute("DROP INDEX IF EXISTS ix_image_generation_configs_search_space_id")
|
||||
op.execute("DROP INDEX IF EXISTS ix_image_generation_configs_name")
|
||||
op.execute("DROP TABLE IF EXISTS image_generation_configs")
|
||||
|
||||
# Drop the imagegenprovider enum type
|
||||
op.execute("DROP TYPE IF EXISTS imagegenprovider")
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Add access_token column to image_generations
|
||||
|
||||
Revision ID: 94
|
||||
Revises: 93
|
||||
|
||||
Adds an indexed access_token column to the image_generations table.
|
||||
This token is stored per-record so that image serving URLs survive
|
||||
SECRET_KEY rotation.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "94"
|
||||
down_revision: str | None = "93"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add access_token column (nullable so existing rows are unaffected)
|
||||
op.add_column(
|
||||
"image_generations",
|
||||
sa.Column("access_token", sa.String(64), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_image_generations_access_token",
|
||||
"image_generations",
|
||||
["access_token"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_image_generations_access_token", table_name="image_generations")
|
||||
op.drop_column("image_generations", "access_token")
|
||||
|
|
@ -133,6 +133,7 @@ async def create_surfsense_deep_agent(
|
|||
The agent comes with built-in tools that can be configured:
|
||||
- search_knowledge_base: Search the user's personal knowledge base
|
||||
- generate_podcast: Generate audio podcasts from content
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
- link_preview: Fetch rich previews for URLs
|
||||
- display_image: Display images in chat
|
||||
- scrape_webpage: Extract content from webpages
|
||||
|
|
|
|||
|
|
@ -3,15 +3,25 @@ PostgreSQL-based checkpointer for LangGraph agents.
|
|||
|
||||
This module provides a persistent checkpointer using AsyncPostgresSaver
|
||||
that stores conversation state in the PostgreSQL database.
|
||||
|
||||
Uses a connection pool (psycopg_pool.AsyncConnectionPool) to handle
|
||||
connection lifecycle, health checks, and automatic reconnection,
|
||||
preventing 'the connection is closed' errors in long-running deployments.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from psycopg.rows import dict_row
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global checkpointer instance (initialized lazily)
|
||||
_checkpointer: AsyncPostgresSaver | None = None
|
||||
_checkpointer_context = None # Store the context manager for cleanup
|
||||
_connection_pool: AsyncConnectionPool | None = None
|
||||
_checkpointer_initialized: bool = False
|
||||
|
||||
|
||||
|
|
@ -38,26 +48,65 @@ def get_postgres_connection_string() -> str:
|
|||
return db_url
|
||||
|
||||
|
||||
async def _create_checkpointer() -> AsyncPostgresSaver:
|
||||
"""
|
||||
Create a new AsyncPostgresSaver backed by a connection pool.
|
||||
|
||||
The connection pool automatically handles:
|
||||
- Connection health checks before use
|
||||
- Reconnection when connections die (idle timeout, DB restart, etc.)
|
||||
- Connection lifecycle management (max_lifetime, max_idle)
|
||||
"""
|
||||
global _connection_pool
|
||||
|
||||
conn_string = get_postgres_connection_string()
|
||||
|
||||
_connection_pool = AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
# Connections are recycled after 30 minutes to avoid stale connections
|
||||
max_lifetime=1800,
|
||||
# Idle connections are closed after 5 minutes
|
||||
max_idle=300,
|
||||
open=False,
|
||||
# Connection kwargs required by AsyncPostgresSaver:
|
||||
# - autocommit: required for .setup() to commit checkpoint tables
|
||||
# - prepare_threshold: disable prepared statements for compatibility
|
||||
# - row_factory: checkpointer accesses rows as dicts (row["column"])
|
||||
kwargs={
|
||||
"autocommit": True,
|
||||
"prepare_threshold": 0,
|
||||
"row_factory": dict_row,
|
||||
},
|
||||
)
|
||||
await _connection_pool.open(wait=True)
|
||||
|
||||
checkpointer = AsyncPostgresSaver(conn=_connection_pool)
|
||||
logger.info("[Checkpointer] Created AsyncPostgresSaver with connection pool")
|
||||
return checkpointer
|
||||
|
||||
|
||||
async def get_checkpointer() -> AsyncPostgresSaver:
|
||||
"""
|
||||
Get or create the global AsyncPostgresSaver instance.
|
||||
|
||||
This function:
|
||||
1. Creates the checkpointer if it doesn't exist
|
||||
1. Creates the checkpointer with a connection pool if it doesn't exist
|
||||
2. Sets up the required database tables on first call
|
||||
3. Returns the cached instance on subsequent calls
|
||||
|
||||
The underlying connection pool handles reconnection automatically,
|
||||
so a stale/closed connection will not cause OperationalError.
|
||||
|
||||
Returns:
|
||||
AsyncPostgresSaver: The configured checkpointer instance
|
||||
"""
|
||||
global _checkpointer, _checkpointer_context, _checkpointer_initialized
|
||||
global _checkpointer, _checkpointer_initialized
|
||||
|
||||
if _checkpointer is None:
|
||||
conn_string = get_postgres_connection_string()
|
||||
# from_conn_string returns an async context manager
|
||||
# We need to enter the context to get the actual checkpointer
|
||||
_checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string)
|
||||
_checkpointer = await _checkpointer_context.__aenter__()
|
||||
_checkpointer = await _create_checkpointer()
|
||||
_checkpointer_initialized = False
|
||||
|
||||
# Setup tables on first call (idempotent)
|
||||
if not _checkpointer_initialized:
|
||||
|
|
@ -75,20 +124,21 @@ async def setup_checkpointer_tables() -> None:
|
|||
tables exist before any agent calls.
|
||||
"""
|
||||
await get_checkpointer()
|
||||
print("[Checkpointer] PostgreSQL checkpoint tables ready")
|
||||
logger.info("[Checkpointer] PostgreSQL checkpoint tables ready")
|
||||
|
||||
|
||||
async def close_checkpointer() -> None:
|
||||
"""
|
||||
Close the checkpointer connection.
|
||||
Close the checkpointer connection pool.
|
||||
|
||||
This should be called during application shutdown.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_context, _checkpointer_initialized
|
||||
global _checkpointer, _connection_pool, _checkpointer_initialized
|
||||
|
||||
if _checkpointer_context is not None:
|
||||
await _checkpointer_context.__aexit__(None, None, None)
|
||||
_checkpointer = None
|
||||
_checkpointer_context = None
|
||||
_checkpointer_initialized = False
|
||||
print("[Checkpointer] PostgreSQL connection closed")
|
||||
if _connection_pool is not None:
|
||||
await _connection_pool.close()
|
||||
logger.info("[Checkpointer] PostgreSQL connection pool closed")
|
||||
|
||||
_checkpointer = None
|
||||
_connection_pool = None
|
||||
_checkpointer_initialized = False
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ You have access to the following tools:
|
|||
* Showing an image from a URL the user explicitly mentioned in their message
|
||||
* Displaying images found in scraped webpage content (from scrape_webpage tool)
|
||||
* Showing a publicly accessible diagram or chart from a known URL
|
||||
* Displaying an AI-generated image after calling the generate_image tool (ALWAYS required)
|
||||
|
||||
CRITICAL - NEVER USE THIS TOOL FOR USER-UPLOADED ATTACHMENTS:
|
||||
When a user uploads/attaches an image file to their message:
|
||||
|
|
@ -100,7 +101,21 @@ You have access to the following tools:
|
|||
- Returns: An image card with the image, title, and description
|
||||
- The image will automatically be displayed in the chat.
|
||||
|
||||
5. scrape_webpage: Scrape and extract the main content from a webpage.
|
||||
5. generate_image: Generate images from text descriptions using AI image models.
|
||||
- Use this when the user asks you to create, generate, draw, design, or make an image.
|
||||
- Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork"
|
||||
- Args:
|
||||
- prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood.
|
||||
- n: Number of images to generate (1-4, default: 1)
|
||||
- Returns: A dictionary with the generated image URL in the "src" field, along with metadata.
|
||||
- CRITICAL: After calling generate_image, you MUST call `display_image` with the returned "src" URL
|
||||
to actually show the image in the chat. The generate_image tool only generates the image and returns
|
||||
the URL — it does NOT display anything. You must always follow up with display_image.
|
||||
- IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim -
|
||||
expand and improve the prompt with specific details about style, lighting, composition, and mood.
|
||||
- If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details.
|
||||
|
||||
6. scrape_webpage: Scrape and extract the main content from a webpage.
|
||||
- Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage.
|
||||
- IMPORTANT: This is different from link_preview:
|
||||
* link_preview: Only fetches metadata (title, description, thumbnail) for display
|
||||
|
|
@ -123,7 +138,7 @@ You have access to the following tools:
|
|||
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
|
||||
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
||||
|
||||
6. save_memory: Save facts, preferences, or context about the user for personalized responses.
|
||||
7. save_memory: Save facts, preferences, or context about the user for personalized responses.
|
||||
- Use this when the user explicitly or implicitly shares information worth remembering.
|
||||
- Trigger scenarios:
|
||||
* User says "remember this", "keep this in mind", "note that", or similar
|
||||
|
|
@ -146,7 +161,7 @@ You have access to the following tools:
|
|||
- IMPORTANT: Only save information that would be genuinely useful for future conversations.
|
||||
Don't save trivial or temporary information.
|
||||
|
||||
7. recall_memory: Retrieve relevant memories about the user for personalized responses.
|
||||
8. recall_memory: Retrieve relevant memories about the user for personalized responses.
|
||||
- Use this to access stored information about the user.
|
||||
- Trigger scenarios:
|
||||
* You need user context to give a better, more personalized answer
|
||||
|
|
@ -281,6 +296,22 @@ You have access to the following tools:
|
|||
- Then, if the content contains useful diagrams/images like ``:
|
||||
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
|
||||
- Then provide your explanation, referencing the displayed image
|
||||
|
||||
- User: "Generate an image of a cat"
|
||||
- Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")`
|
||||
- Step 2: Use the returned "src" URL to display it: `display_image(src="<returned_url>", alt="A fluffy orange tabby cat on a windowsill", title="Generated Image")`
|
||||
|
||||
- User: "Create a landscape painting of mountains"
|
||||
- Step 1: `generate_image(prompt="Majestic snow-capped mountain range at sunset, dramatic orange and purple sky, alpine meadow with wildflowers in the foreground, oil painting style with visible brushstrokes, inspired by the Hudson River School art movement")`
|
||||
- Step 2: `display_image(src="<returned_url>", alt="Mountain landscape painting", title="Generated Image")`
|
||||
|
||||
- User: "Draw me a logo for a coffee shop called Bean Dream"
|
||||
- Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")`
|
||||
- Step 2: `display_image(src="<returned_url>", alt="Bean Dream coffee shop logo", title="Generated Image")`
|
||||
|
||||
- User: "Make a wide banner image for my blog about AI"
|
||||
- Step 1: `generate_image(prompt="Wide banner illustration for an AI technology blog, featuring abstract neural network patterns, glowing blue and purple connections, modern futuristic aesthetic, digital art style, clean and professional")`
|
||||
- Step 2: `display_image(src="<returned_url>", alt="AI blog banner", title="Generated Image")`
|
||||
</tool_call_examples>
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ Available tools:
|
|||
- search_knowledge_base: Search the user's personal knowledge base
|
||||
- search_surfsense_docs: Search Surfsense documentation for usage help
|
||||
- generate_podcast: Generate audio podcasts from content
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
- link_preview: Fetch rich previews for URLs
|
||||
- display_image: Display images in chat
|
||||
- scrape_webpage: Extract content from webpages
|
||||
|
|
@ -18,6 +19,7 @@ Available tools:
|
|||
# Registry exports
|
||||
# Tool factory exports (for direct use)
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .knowledge_base import (
|
||||
CONNECTOR_DESCRIPTIONS,
|
||||
create_search_knowledge_base_tool,
|
||||
|
|
@ -47,6 +49,7 @@ __all__ = [
|
|||
"build_tools",
|
||||
# Tool factories
|
||||
"create_display_image_tool",
|
||||
"create_generate_image_tool",
|
||||
"create_generate_podcast_tool",
|
||||
"create_link_preview_tool",
|
||||
"create_recall_memory_tool",
|
||||
|
|
|
|||
|
|
@ -82,14 +82,20 @@ def create_display_image_tool():
|
|||
|
||||
domain = extract_domain(src)
|
||||
|
||||
# Determine aspect ratio based on common image sources
|
||||
ratio = "16:9" # Default
|
||||
if "unsplash.com" in src or "pexels.com" in src:
|
||||
# Determine aspect ratio based on image source
|
||||
# AI-generated images should use "auto" to preserve their native ratio
|
||||
is_generated = "/image-generations/" in src
|
||||
if is_generated:
|
||||
ratio = "auto"
|
||||
domain = "ai-generated"
|
||||
elif "unsplash.com" in src or "pexels.com" in src:
|
||||
ratio = "16:9"
|
||||
elif (
|
||||
"imgur.com" in src or "github.com" in src or "githubusercontent.com" in src
|
||||
):
|
||||
ratio = "auto"
|
||||
else:
|
||||
ratio = "auto"
|
||||
|
||||
return {
|
||||
"id": image_id,
|
||||
|
|
|
|||
242
surfsense_backend/app/agents/new_chat/tools/generate_image.py
Normal file
242
surfsense_backend/app/agents/new_chat/tools/generate_image.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""
|
||||
Image generation tool for the SurfSense agent.
|
||||
|
||||
This module provides a tool that generates images using litellm.aimage_generation()
|
||||
and returns the result via the existing display_image tool format so the frontend
|
||||
renders the generated image inline in the chat.
|
||||
|
||||
Config resolution:
|
||||
1. Uses the search space's image_generation_config_id preference
|
||||
2. Falls back to Auto mode (router load balancing) if available
|
||||
3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from litellm import aimage_generation
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping (same as routes)
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image gen config by negative ID."""
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
|
||||
|
||||
def create_generate_image_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the generate_image tool.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID (for config resolution)
|
||||
db_session: Async database session
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def generate_image(
|
||||
prompt: str,
|
||||
n: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate an image from a text description using AI image models.
|
||||
|
||||
Use this tool when the user asks you to create, generate, draw, or make an image.
|
||||
The generated image will be displayed directly in the chat.
|
||||
|
||||
Args:
|
||||
prompt: A detailed text description of the image to generate.
|
||||
Be specific about subject, style, colors, composition, and mood.
|
||||
n: Number of images to generate (1-4). Default: 1
|
||||
|
||||
Returns:
|
||||
A dictionary containing the generated image(s) for display in the chat.
|
||||
"""
|
||||
try:
|
||||
# Resolve the image generation config from the search space preference
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return {"error": "Search space not found"}
|
||||
|
||||
config_id = (
|
||||
search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
|
||||
# Build generation kwargs
|
||||
# NOTE: size, quality, and style are intentionally NOT passed.
|
||||
# Different models support different values for these params
|
||||
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
|
||||
# gpt-image-1 wants "high"/"medium"/"low"; size options also
|
||||
# differ). Letting the model use its own defaults avoids errors.
|
||||
gen_kwargs: dict[str, Any] = {}
|
||||
if n is not None and n > 1:
|
||||
gen_kwargs["n"] = n
|
||||
|
||||
# Call litellm based on config type
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
return {
|
||||
"error": "No image generation models configured. "
|
||||
"Please add an image model in Settings > Image Models."
|
||||
}
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=prompt, model="auto", **gen_kwargs
|
||||
)
|
||||
elif config_id < 0:
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
return {"error": f"Image generation config {config_id} not found"}
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg["model_name"],
|
||||
cfg.get("custom_provider"),
|
||||
)
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = user-created ImageGenerationConfig
|
||||
cfg_result = await db_session.execute(
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id
|
||||
)
|
||||
)
|
||||
db_cfg = cfg_result.scalars().first()
|
||||
if not db_cfg:
|
||||
return {"error": f"Image generation config {config_id} not found"}
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value,
|
||||
db_cfg.model_name,
|
||||
db_cfg.custom_provider,
|
||||
)
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
|
||||
# Parse the response and store in DB
|
||||
response_dict = (
|
||||
response.model_dump()
|
||||
if hasattr(response, "model_dump")
|
||||
else dict(response)
|
||||
)
|
||||
|
||||
# Generate a random access token for this image
|
||||
access_token = generate_image_token()
|
||||
|
||||
# Save to image_generations table for history
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=prompt,
|
||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||
n=n,
|
||||
image_generation_config_id=config_id,
|
||||
response_data=response_dict,
|
||||
search_space_id=search_space_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
db_session.add(db_image_gen)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(db_image_gen)
|
||||
|
||||
# Extract image URLs from response
|
||||
images = response_dict.get("data", [])
|
||||
if not images:
|
||||
return {"error": "No images were generated"}
|
||||
|
||||
first_image = images[0]
|
||||
revised_prompt = first_image.get("revised_prompt", prompt)
|
||||
|
||||
# Resolve image URL:
|
||||
# - If the API returned a URL, use it directly.
|
||||
# - If the API returned b64_json (e.g. gpt-image-1), serve the
|
||||
# image through our backend endpoint to avoid bloating the
|
||||
# LLM context with megabytes of base64 data.
|
||||
if first_image.get("url"):
|
||||
image_url = first_image["url"]
|
||||
elif first_image.get("b64_json"):
|
||||
backend_url = config.BACKEND_URL or "http://localhost:8000"
|
||||
image_url = (
|
||||
f"{backend_url}/api/v1/image-generations/"
|
||||
f"{db_image_gen.id}/image?token={access_token}"
|
||||
)
|
||||
else:
|
||||
return {"error": "No displayable image data in the response"}
|
||||
|
||||
return {
|
||||
"src": image_url,
|
||||
"alt": revised_prompt or prompt,
|
||||
"title": "Generated Image",
|
||||
"description": revised_prompt if revised_prompt != prompt else None,
|
||||
"generated": True,
|
||||
"prompt": prompt,
|
||||
"image_count": len(images),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Image generation failed in tool")
|
||||
return {
|
||||
"error": f"Image generation failed: {e!s}",
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
return generate_image
|
||||
|
|
@ -17,6 +17,8 @@ from fake_useragent import UserAgent
|
|||
from langchain_core.tools import tool
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from app.utils.proxy_config import get_playwright_proxy, get_residential_proxy_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -186,9 +188,15 @@ async def fetch_with_chromium(url: str) -> dict[str, Any] | None:
|
|||
ua = UserAgent()
|
||||
user_agent = ua.random
|
||||
|
||||
# Use residential proxy if configured
|
||||
playwright_proxy = get_playwright_proxy()
|
||||
|
||||
# Use Playwright to fetch the page
|
||||
async with async_playwright() as p:
|
||||
browser = await p.chromium.launch(headless=True)
|
||||
launch_kwargs: dict = {"headless": True}
|
||||
if playwright_proxy:
|
||||
launch_kwargs["proxy"] = playwright_proxy
|
||||
browser = await p.chromium.launch(**launch_kwargs)
|
||||
context = await browser.new_context(user_agent=user_agent)
|
||||
page = await context.new_page()
|
||||
|
||||
|
|
@ -283,12 +291,16 @@ def create_link_preview_tool():
|
|||
ua = UserAgent()
|
||||
user_agent = ua.random
|
||||
|
||||
# Use residential proxy if configured
|
||||
proxy_url = get_residential_proxy_url()
|
||||
|
||||
# Use a browser-like User-Agent to fetch Open Graph metadata.
|
||||
# We're only fetching publicly available metadata (title, description, thumbnail)
|
||||
# that websites intentionally expose via OG tags for link preview purposes.
|
||||
async with httpx.AsyncClient(
|
||||
timeout=10.0,
|
||||
follow_redirects=True,
|
||||
proxy=proxy_url,
|
||||
headers={
|
||||
"User-Agent": user_agent,
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8",
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ from typing import Any
|
|||
from langchain_core.tools import BaseTool
|
||||
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
from .link_preview import create_link_preview_tool
|
||||
from .mcp_tool import load_mcp_tools
|
||||
|
|
@ -125,6 +126,16 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
factory=lambda deps: create_display_image_tool(),
|
||||
requires=[],
|
||||
),
|
||||
# Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.)
|
||||
ToolDefinition(
|
||||
name="generate_image",
|
||||
description="Generate images from text descriptions using AI image models",
|
||||
factory=lambda deps: create_generate_image_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
),
|
||||
requires=["search_space_id", "db_session"],
|
||||
),
|
||||
# Web scraping tool - extracts content from webpages
|
||||
ToolDefinition(
|
||||
name="scrape_webpage",
|
||||
|
|
|
|||
|
|
@ -2,17 +2,26 @@
|
|||
Web scraping tool for the SurfSense agent.
|
||||
|
||||
This module provides a tool for scraping and extracting content from webpages
|
||||
using the existing WebCrawlerConnector. The scraped content can be used by
|
||||
the agent to answer questions about web pages.
|
||||
using the existing WebCrawlerConnector. For YouTube URLs, it fetches the
|
||||
transcript directly via the YouTubeTranscriptApi instead of crawling the page.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from fake_useragent import UserAgent
|
||||
from langchain_core.tools import tool
|
||||
from requests import Session
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
|
||||
from app.connectors.webcrawler_connector import WebCrawlerConnector
|
||||
from app.tasks.document_processors.youtube_processor import get_youtube_video_id
|
||||
from app.utils.proxy_config import get_requests_proxies
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_domain(url: str) -> str:
|
||||
|
|
@ -57,6 +66,89 @@ def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
|
|||
return truncated + "\n\n[Content truncated...]", True
|
||||
|
||||
|
||||
async def _scrape_youtube_video(
|
||||
url: str, video_id: str, max_length: int
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Fetch YouTube video metadata and transcript via the YouTubeTranscriptApi.
|
||||
|
||||
Returns a result dict in the same shape as the regular scrape_webpage output.
|
||||
"""
|
||||
scrape_id = generate_scrape_id(url)
|
||||
domain = "youtube.com"
|
||||
|
||||
# --- Video metadata via oEmbed ---
|
||||
residential_proxies = get_requests_proxies()
|
||||
|
||||
params = {
|
||||
"format": "json",
|
||||
"url": f"https://www.youtube.com/watch?v={video_id}",
|
||||
}
|
||||
oembed_url = "https://www.youtube.com/oembed"
|
||||
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as http_session,
|
||||
http_session.get(
|
||||
oembed_url,
|
||||
params=params,
|
||||
proxy=residential_proxies["http"] if residential_proxies else None,
|
||||
) as response,
|
||||
):
|
||||
video_data = await response.json()
|
||||
except Exception:
|
||||
video_data = {}
|
||||
|
||||
title = video_data.get("title", "YouTube Video")
|
||||
author = video_data.get("author_name", "Unknown")
|
||||
|
||||
# --- Transcript via YouTubeTranscriptApi ---
|
||||
try:
|
||||
ua = UserAgent()
|
||||
http_client = Session()
|
||||
http_client.headers.update({"User-Agent": ua.random})
|
||||
if residential_proxies:
|
||||
http_client.proxies.update(residential_proxies)
|
||||
ytt_api = YouTubeTranscriptApi(http_client=http_client)
|
||||
captions = ytt_api.fetch(video_id)
|
||||
|
||||
transcript_segments = []
|
||||
for line in captions:
|
||||
start_time = line.start
|
||||
duration = line.duration
|
||||
text = line.text
|
||||
timestamp = f"[{start_time:.2f}s-{start_time + duration:.2f}s]"
|
||||
transcript_segments.append(f"{timestamp} {text}")
|
||||
transcript_text = "\n".join(transcript_segments)
|
||||
except Exception as e:
|
||||
logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
|
||||
transcript_text = f"No captions available for this video. Error: {e!s}"
|
||||
|
||||
# Build combined content
|
||||
content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
|
||||
|
||||
# Truncate if needed
|
||||
content, was_truncated = truncate_content(content, max_length)
|
||||
word_count = len(content.split())
|
||||
|
||||
description = f"YouTube video by {author}"
|
||||
|
||||
return {
|
||||
"id": scrape_id,
|
||||
"assetId": url,
|
||||
"kind": "article",
|
||||
"href": url,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"content": content,
|
||||
"domain": domain,
|
||||
"word_count": word_count,
|
||||
"was_truncated": was_truncated,
|
||||
"crawler_type": "youtube_transcript",
|
||||
"author": author,
|
||||
}
|
||||
|
||||
|
||||
def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
||||
"""
|
||||
Factory function to create the scrape_webpage tool.
|
||||
|
|
@ -79,7 +171,8 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
|||
|
||||
Use this tool when the user wants you to read, summarize, or answer
|
||||
questions about a specific webpage's content. This tool actually
|
||||
fetches and reads the full page content.
|
||||
fetches and reads the full page content. For YouTube video URLs it
|
||||
fetches the transcript directly instead of crawling the page.
|
||||
|
||||
Common triggers:
|
||||
- "Read this article and summarize it"
|
||||
|
|
@ -114,6 +207,11 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
|||
url = f"https://{url}"
|
||||
|
||||
try:
|
||||
# Check if this is a YouTube URL and use transcript API instead
|
||||
video_id = get_youtube_video_id(url)
|
||||
if video_id:
|
||||
return await _scrape_youtube_video(url, video_id, max_length)
|
||||
|
||||
# Create webcrawler connector
|
||||
connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
|
||||
|
||||
|
|
@ -184,7 +282,7 @@ def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
|
|||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
print(f"[scrape_webpage] Error scraping {url}: {error_message}")
|
||||
logger.error(f"[scrape_webpage] Error scraping {url}: {error_message}")
|
||||
return {
|
||||
"id": scrape_id,
|
||||
"assetId": url,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from app.agents.new_chat.checkpointer import (
|
|||
close_checkpointer,
|
||||
setup_checkpointer_tables,
|
||||
)
|
||||
from app.config import config, initialize_llm_router
|
||||
from app.config import config, initialize_image_gen_router, initialize_llm_router
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.routes import router as crud_router
|
||||
from app.routes.auth_routes import router as auth_router
|
||||
|
|
@ -26,6 +26,8 @@ async def lifespan(app: FastAPI):
|
|||
await setup_checkpointer_tables()
|
||||
# Initialize LLM Router for Auto mode load balancing
|
||||
initialize_llm_router()
|
||||
# Initialize Image Generation Router for Auto mode load balancing
|
||||
initialize_image_gen_router()
|
||||
# Seed Surfsense documentation
|
||||
await seed_surfsense_docs()
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -13,14 +13,15 @@ load_dotenv()
|
|||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs):
|
||||
"""Initialize the LLM Router when a Celery worker process starts.
|
||||
"""Initialize the LLM Router and Image Gen Router when a Celery worker process starts.
|
||||
|
||||
This ensures the Auto mode (LiteLLM Router) is available for background tasks
|
||||
like document summarization.
|
||||
like document summarization and image generation.
|
||||
"""
|
||||
from app.config import initialize_llm_router
|
||||
from app.config import initialize_image_gen_router, initialize_llm_router
|
||||
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
|
||||
|
||||
# Get Celery configuration from environment
|
||||
|
|
|
|||
|
|
@ -81,6 +81,56 @@ def load_router_settings():
|
|||
return default_settings
|
||||
|
||||
|
||||
def load_global_image_gen_configs():
|
||||
"""
|
||||
Load global image generation configurations from YAML file.
|
||||
|
||||
Returns:
|
||||
list: List of global image generation config dictionaries, or empty list
|
||||
"""
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data.get("global_image_generation_configs", [])
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global image generation configs: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def load_image_gen_router_settings():
|
||||
"""
|
||||
Load router settings for image generation Auto mode from YAML file.
|
||||
|
||||
Returns:
|
||||
dict: Router settings dictionary
|
||||
"""
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
}
|
||||
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
return default_settings
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
settings = data.get("image_generation_router_settings", {})
|
||||
return {**default_settings, **settings}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load image generation router settings: {e}")
|
||||
return default_settings
|
||||
|
||||
|
||||
def initialize_llm_router():
|
||||
"""
|
||||
Initialize the LLM Router service for Auto mode.
|
||||
|
|
@ -105,6 +155,33 @@ def initialize_llm_router():
|
|||
print(f"Warning: Failed to initialize LLM Router: {e}")
|
||||
|
||||
|
||||
def initialize_image_gen_router():
|
||||
"""
|
||||
Initialize the Image Generation Router service for Auto mode.
|
||||
This should be called during application startup.
|
||||
"""
|
||||
image_gen_configs = load_global_image_gen_configs()
|
||||
router_settings = load_image_gen_router_settings()
|
||||
|
||||
if not image_gen_configs:
|
||||
print(
|
||||
"Info: No global image generation configs found, "
|
||||
"Image Generation Auto mode will not be available"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from app.services.image_gen_router_service import ImageGenRouterService
|
||||
|
||||
ImageGenRouterService.initialize(image_gen_configs, router_settings)
|
||||
print(
|
||||
f"Info: Image Generation Router initialized with {len(image_gen_configs)} models "
|
||||
f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize Image Generation Router: {e}")
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
|
|
@ -216,6 +293,12 @@ class Config:
|
|||
# Router settings for Auto mode (LiteLLM Router load balancing)
|
||||
ROUTER_SETTINGS = load_router_settings()
|
||||
|
||||
# Global Image Generation Configurations (optional)
|
||||
GLOBAL_IMAGE_GEN_CONFIGS = load_global_image_gen_configs()
|
||||
|
||||
# Router settings for Image Generation Auto mode
|
||||
IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings()
|
||||
|
||||
# Chonkie Configuration | Edit this to your needs
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||
# Azure OpenAI credentials from environment variables
|
||||
|
|
@ -277,6 +360,14 @@ class Config:
|
|||
# LlamaCloud API Key
|
||||
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
|
||||
# Residential Proxy Configuration (anonymous-proxies.net)
|
||||
# Used for web crawling and YouTube transcript fetching to avoid IP bans.
|
||||
RESIDENTIAL_PROXY_USERNAME = os.getenv("RESIDENTIAL_PROXY_USERNAME")
|
||||
RESIDENTIAL_PROXY_PASSWORD = os.getenv("RESIDENTIAL_PROXY_PASSWORD")
|
||||
RESIDENTIAL_PROXY_HOSTNAME = os.getenv("RESIDENTIAL_PROXY_HOSTNAME")
|
||||
RESIDENTIAL_PROXY_LOCATION = os.getenv("RESIDENTIAL_PROXY_LOCATION", "")
|
||||
RESIDENTIAL_PROXY_TYPE = int(os.getenv("RESIDENTIAL_PROXY_TYPE", "1"))
|
||||
|
||||
# Litellm TTS Configuration
|
||||
TTS_SERVICE = os.getenv("TTS_SERVICE")
|
||||
TTS_SERVICE_API_BASE = os.getenv("TTS_SERVICE_API_BASE")
|
||||
|
|
|
|||
|
|
@ -183,6 +183,69 @@ global_llm_configs:
|
|||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# =============================================================================
|
||||
# Image Generation Configuration
|
||||
# =============================================================================
|
||||
# These configurations power the image generation feature using litellm.aimage_generation().
|
||||
# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock,
|
||||
# Recraft, OpenRouter, Xinference, Nscale
|
||||
#
|
||||
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs.
|
||||
|
||||
# Router Settings for Image Generation Auto Mode
|
||||
image_generation_router_settings:
|
||||
routing_strategy: "usage-based-routing"
|
||||
num_retries: 3
|
||||
allowed_fails: 3
|
||||
cooldown_time: 60
|
||||
|
||||
global_image_generation_configs:
|
||||
# Example: OpenAI DALL-E 3
|
||||
- id: -1
|
||||
name: "Global DALL-E 3"
|
||||
description: "OpenAI's DALL-E 3 for high-quality image generation"
|
||||
provider: "OPENAI"
|
||||
model_name: "dall-e-3"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens)
|
||||
litellm_params: {}
|
||||
|
||||
# Example: OpenAI GPT Image 1
|
||||
- id: -2
|
||||
name: "Global GPT Image 1"
|
||||
description: "OpenAI's GPT Image 1 model"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-image-1"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50
|
||||
litellm_params: {}
|
||||
|
||||
# Example: Azure OpenAI DALL-E 3
|
||||
- id: -3
|
||||
name: "Global Azure DALL-E 3"
|
||||
description: "Azure-hosted DALL-E 3 deployment"
|
||||
provider: "AZURE_OPENAI"
|
||||
model_name: "azure/dall-e-3-deployment"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview"
|
||||
rpm: 50
|
||||
litellm_params:
|
||||
base_model: "dall-e-3"
|
||||
|
||||
# Example: OpenRouter Gemini Image Generation
|
||||
# - id: -4
|
||||
# name: "Global Gemini Image Gen"
|
||||
# description: "Google Gemini image generation via OpenRouter"
|
||||
# provider: "OPENROUTER"
|
||||
# model_name: "google/gemini-2.5-flash-image"
|
||||
# api_key: "your-openrouter-api-key-here"
|
||||
# api_base: ""
|
||||
# rpm: 30
|
||||
# litellm_params: {}
|
||||
|
||||
# Notes:
|
||||
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
|
||||
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
|
||||
|
|
@ -195,10 +258,11 @@ global_llm_configs:
|
|||
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
|
||||
# These help the router distribute load evenly and avoid rate limit errors
|
||||
#
|
||||
# AZURE-SPECIFIC NOTES:
|
||||
# - Always add 'base_model' in litellm_params for Azure deployments
|
||||
# - This fixes "Could not identify azure model 'X'" warnings
|
||||
# - base_model should match the underlying OpenAI model (e.g., gpt-4o, gpt-4-turbo, gpt-3.5-turbo)
|
||||
# - model_name format: "azure/<your-deployment-name>"
|
||||
# - api_version: Use a recent Azure API version (e.g., "2024-02-15-preview")
|
||||
# - See: https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
|
||||
#
|
||||
# IMAGE GENERATION NOTES:
|
||||
# - Image generation configs use the same ID scheme as LLM configs (negative for global)
|
||||
# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure),
|
||||
# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter)
|
||||
# - The router uses litellm.aimage_generation() for async image generation
|
||||
# - Only RPM (requests per minute) is relevant for image generation rate limiting.
|
||||
# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token.
|
||||
|
|
|
|||
|
|
@ -108,7 +108,9 @@ class AirtableHistoryConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Airtable access token is invalid or empty. "
|
||||
"Please reconnect your Airtable account."
|
||||
|
|
|
|||
|
|
@ -128,7 +128,9 @@ class ConfluenceHistoryConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Confluence access token is invalid or empty. "
|
||||
"Please reconnect your Confluence account."
|
||||
|
|
|
|||
|
|
@ -129,7 +129,9 @@ class JiraHistoryConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Jira access token is invalid or empty. "
|
||||
"Please reconnect your Jira account."
|
||||
|
|
|
|||
|
|
@ -153,7 +153,9 @@ class LinearConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Linear access token is invalid or empty. "
|
||||
"Please reconnect your Linear account."
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from fake_useragent import UserAgent
|
|||
from firecrawl import AsyncFirecrawlApp
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from app.utils.proxy_config import get_playwright_proxy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -165,9 +167,15 @@ class WebCrawlerConnector:
|
|||
ua = UserAgent()
|
||||
user_agent = ua.random
|
||||
|
||||
# Use residential proxy if configured
|
||||
playwright_proxy = get_playwright_proxy()
|
||||
|
||||
# Use Playwright to fetch the page
|
||||
async with async_playwright() as p:
|
||||
browser = await p.chromium.launch(headless=True)
|
||||
launch_kwargs: dict = {"headless": True}
|
||||
if playwright_proxy:
|
||||
launch_kwargs["proxy"] = playwright_proxy
|
||||
browser = await p.chromium.launch(**launch_kwargs)
|
||||
context = await browser.new_context(user_agent=user_agent)
|
||||
page = await context.new_page()
|
||||
|
||||
|
|
|
|||
|
|
@ -214,6 +214,24 @@ class LiteLLMProvider(str, Enum):
|
|||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class ImageGenProvider(str, Enum):
|
||||
"""
|
||||
Enum for image generation providers supported by LiteLLM.
|
||||
This is a subset of LLM providers — only those that support image generation.
|
||||
See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
"""
|
||||
|
||||
OPENAI = "OPENAI"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
GOOGLE = "GOOGLE" # Google AI Studio
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK" # AWS Bedrock
|
||||
RECRAFT = "RECRAFT"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
XINFERENCE = "XINFERENCE"
|
||||
NSCALE = "NSCALE"
|
||||
|
||||
|
||||
class LogLevel(str, Enum):
|
||||
DEBUG = "DEBUG"
|
||||
INFO = "INFO"
|
||||
|
|
@ -314,6 +332,11 @@ class Permission(str, Enum):
|
|||
PODCASTS_UPDATE = "podcasts:update"
|
||||
PODCASTS_DELETE = "podcasts:delete"
|
||||
|
||||
# Image Generations
|
||||
IMAGE_GENERATIONS_CREATE = "image_generations:create"
|
||||
IMAGE_GENERATIONS_READ = "image_generations:read"
|
||||
IMAGE_GENERATIONS_DELETE = "image_generations:delete"
|
||||
|
||||
# Connectors
|
||||
CONNECTORS_CREATE = "connectors:create"
|
||||
CONNECTORS_READ = "connectors:read"
|
||||
|
|
@ -375,6 +398,9 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.PODCASTS_CREATE.value,
|
||||
Permission.PODCASTS_READ.value,
|
||||
Permission.PODCASTS_UPDATE.value,
|
||||
# Image Generations (create and read, no delete)
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
# Connectors (no delete)
|
||||
Permission.CONNECTORS_CREATE.value,
|
||||
Permission.CONNECTORS_READ.value,
|
||||
|
|
@ -404,6 +430,8 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.LLM_CONFIGS_READ.value,
|
||||
# Podcasts (read only)
|
||||
Permission.PODCASTS_READ.value,
|
||||
# Image Generations (read only)
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
# Connectors (read only)
|
||||
Permission.CONNECTORS_READ.value,
|
||||
# Logs (read only)
|
||||
|
|
@ -969,6 +997,103 @@ class Podcast(BaseModel, TimestampMixin):
|
|||
thread = relationship("NewChatThread")
|
||||
|
||||
|
||||
class ImageGenerationConfig(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Dedicated configuration table for image generation models.
|
||||
|
||||
Separate from NewLLMConfig because image generation models don't need
|
||||
system_instructions, citations_enabled, or use_default_system_instructions.
|
||||
They only need provider credentials and model parameters.
|
||||
"""
|
||||
|
||||
__tablename__ = "image_generation_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# Provider & model (uses ImageGenProvider, NOT LiteLLMProvider)
|
||||
provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False)
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
model_name = Column(String(100), nullable=False)
|
||||
|
||||
# Credentials
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
api_version = Column(String(50), nullable=True) # Azure-specific
|
||||
|
||||
# Additional litellm parameters
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
# Relationships
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship(
|
||||
"SearchSpace", back_populates="image_generation_configs"
|
||||
)
|
||||
|
||||
|
||||
class ImageGeneration(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Stores image generation requests and results using litellm.aimage_generation().
|
||||
|
||||
Since aimage_generation is a single async call (not a background job),
|
||||
there is no status enum. A row with response_data means success;
|
||||
a row with error_message means failure.
|
||||
|
||||
Response data is stored as JSONB matching the litellm output format:
|
||||
{
|
||||
"created": int,
|
||||
"data": [{"b64_json": str|None, "revised_prompt": str|None, "url": str|None}],
|
||||
"usage": {"prompt_tokens": int, "completion_tokens": int, "total_tokens": int}
|
||||
}
|
||||
"""
|
||||
|
||||
__tablename__ = "image_generations"
|
||||
|
||||
# Request parameters (matching litellm.aimage_generation() params)
|
||||
prompt = Column(Text, nullable=False)
|
||||
model = Column(String(200), nullable=True) # e.g., "dall-e-3", "gpt-image-1"
|
||||
n = Column(Integer, nullable=True, default=1)
|
||||
quality = Column(
|
||||
String(50), nullable=True
|
||||
) # "auto", "high", "medium", "low", "hd", "standard"
|
||||
size = Column(
|
||||
String(50), nullable=True
|
||||
) # "1024x1024", "1536x1024", "1024x1536", etc.
|
||||
style = Column(String(50), nullable=True) # Model-specific style parameter
|
||||
response_format = Column(String(50), nullable=True) # "url" or "b64_json"
|
||||
|
||||
# Image generation config reference
|
||||
# 0 = Auto mode (router), negative IDs = global configs from YAML,
|
||||
# positive IDs = ImageGenerationConfig records in DB
|
||||
image_generation_config_id = Column(Integer, nullable=True)
|
||||
|
||||
# Response data (full litellm response as JSONB) — present on success
|
||||
response_data = Column(JSONB, nullable=True)
|
||||
# Error message — present on failure
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# Signed access token for serving images via <img> tags.
|
||||
# Stored in DB so it survives SECRET_KEY rotation.
|
||||
access_token = Column(String(64), nullable=True, index=True)
|
||||
|
||||
# Foreign keys
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
created_by_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
search_space = relationship("SearchSpace", back_populates="image_generations")
|
||||
created_by = relationship("User", back_populates="image_generations")
|
||||
|
||||
|
||||
class SearchSpace(BaseModel, TimestampMixin):
|
||||
__tablename__ = "searchspaces"
|
||||
|
||||
|
|
@ -993,6 +1118,9 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
document_summary_llm_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For document summarization, defaults to Auto mode
|
||||
image_generation_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For image generation, defaults to Auto mode
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
|
|
@ -1017,6 +1145,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="Podcast.id.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
image_generations = relationship(
|
||||
"ImageGeneration",
|
||||
back_populates="search_space",
|
||||
order_by="ImageGeneration.id.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
logs = relationship(
|
||||
"Log",
|
||||
back_populates="search_space",
|
||||
|
|
@ -1041,6 +1175,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="NewLLMConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="search_space",
|
||||
order_by="ImageGenerationConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# RBAC relationships
|
||||
roles = relationship(
|
||||
|
|
@ -1421,6 +1561,13 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generations created by this user
|
||||
image_generations = relationship(
|
||||
"ImageGeneration",
|
||||
back_populates="created_by",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# User memories for personalized AI responses
|
||||
memories = relationship(
|
||||
"UserMemory",
|
||||
|
|
@ -1493,6 +1640,13 @@ else:
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generations created by this user
|
||||
image_generations = relationship(
|
||||
"ImageGeneration",
|
||||
back_populates="created_by",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# User memories for personalized AI responses
|
||||
memories = relationship(
|
||||
"UserMemory",
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from .google_drive_add_connector_route import (
|
|||
from .google_gmail_add_connector_route import (
|
||||
router as google_gmail_add_connector_router,
|
||||
)
|
||||
from .image_generation_routes import router as image_generation_router
|
||||
from .incentive_tasks_routes import router as incentive_tasks_router
|
||||
from .jira_add_connector_route import router as jira_add_connector_router
|
||||
from .linear_add_connector_route import router as linear_add_connector_router
|
||||
|
|
@ -49,6 +50,7 @@ router.include_router(notes_router)
|
|||
router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
||||
router.include_router(chat_comments_router)
|
||||
router.include_router(podcasts_router) # Podcast task status and audio
|
||||
router.include_router(image_generation_router) # Image generation via litellm
|
||||
router.include_router(search_source_connectors_router)
|
||||
router.include_router(google_calendar_add_connector_router)
|
||||
router.include_router(google_gmail_add_connector_router)
|
||||
|
|
|
|||
710
surfsense_backend/app/routes/image_generation_routes.py
Normal file
710
surfsense_backend/app/routes/image_generation_routes.py
Normal file
|
|
@ -0,0 +1,710 @@
|
|||
"""
|
||||
Image Generation routes:
|
||||
- CRUD for ImageGenerationConfig (user-created image model configs)
|
||||
- Global image gen configs endpoint (from YAML)
|
||||
- Image generation execution (calls litellm.aimage_generation())
|
||||
- CRUD for ImageGeneration records (results)
|
||||
- Image serving endpoint (serves b64_json images from DB, protected by signed tokens)
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import Response
|
||||
from litellm import aimage_generation
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
ImageGenerationConfig,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
GlobalImageGenConfigRead,
|
||||
ImageGenerationConfigCreate,
|
||||
ImageGenerationConfigRead,
|
||||
ImageGenerationConfigUpdate,
|
||||
ImageGenerationCreate,
|
||||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping for building litellm model strings.
|
||||
# Only includes providers that support image generation.
|
||||
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image generation configuration by ID (negative IDs)."""
|
||||
if config_id == IMAGE_GEN_AUTO_MODE_ID:
|
||||
return {
|
||||
"id": IMAGE_GEN_AUTO_MODE_ID,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
if config_id > 0:
|
||||
return None
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
"""Build a litellm model string from provider + model_name."""
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
async def _execute_image_generation(
|
||||
session: AsyncSession,
|
||||
image_gen: ImageGeneration,
|
||||
search_space: SearchSpace,
|
||||
) -> None:
|
||||
"""
|
||||
Call litellm.aimage_generation() with the appropriate config.
|
||||
|
||||
Resolution order:
|
||||
1. Explicit image_generation_config_id on the request
|
||||
2. Search space's image_generation_config_id preference
|
||||
3. Falls back to Auto mode if available
|
||||
"""
|
||||
config_id = image_gen.image_generation_config_id
|
||||
if config_id is None:
|
||||
config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
image_gen.image_generation_config_id = config_id
|
||||
|
||||
# Build kwargs
|
||||
gen_kwargs = {}
|
||||
if image_gen.n is not None:
|
||||
gen_kwargs["n"] = image_gen.n
|
||||
if image_gen.quality is not None:
|
||||
gen_kwargs["quality"] = image_gen.quality
|
||||
if image_gen.size is not None:
|
||||
gen_kwargs["size"] = image_gen.size
|
||||
if image_gen.style is not None:
|
||||
gen_kwargs["style"] = image_gen.style
|
||||
if image_gen.response_format is not None:
|
||||
gen_kwargs["response_format"] = image_gen.response_format
|
||||
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
raise ValueError(
|
||||
"Auto mode requested but Image Generation Router not initialized. "
|
||||
"Ensure global_llm_config.yaml has global_image_generation_configs."
|
||||
)
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=image_gen.prompt, model="auto", **gen_kwargs
|
||||
)
|
||||
elif config_id < 0:
|
||||
# Global config from YAML
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
raise ValueError(f"Global image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
|
||||
)
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
|
||||
# User model override
|
||||
if image_gen.model:
|
||||
model_string = image_gen.model
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=image_gen.prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = DB ImageGenerationConfig
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_cfg = result.scalars().first()
|
||||
if not db_cfg:
|
||||
raise ValueError(f"Image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
|
||||
)
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
|
||||
# User model override
|
||||
if image_gen.model:
|
||||
model_string = image_gen.model
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=image_gen.prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
|
||||
# Store response
|
||||
image_gen.response_data = (
|
||||
response.model_dump() if hasattr(response, "model_dump") else dict(response)
|
||||
)
|
||||
if not image_gen.model and hasattr(response, "_hidden_params"):
|
||||
hidden = response._hidden_params
|
||||
if isinstance(hidden, dict) and hidden.get("model"):
|
||||
image_gen.model = hidden["model"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Image Generation Configs (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/global-image-generation-configs",
|
||||
response_model=list[GlobalImageGenConfigRead],
|
||||
)
|
||||
async def get_global_image_gen_configs(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get all global image generation configs. API keys are hidden."""
|
||||
try:
|
||||
global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||
safe_configs = []
|
||||
|
||||
if global_configs and len(global_configs) > 0:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"description": "Automatically routes across available image generation providers.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
}
|
||||
)
|
||||
|
||||
return safe_configs
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch global image generation configs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ImageGenerationConfig CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead)
|
||||
async def create_image_gen_config(
|
||||
config_data: ImageGenerationConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create a new image generation config for a search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to create image generation configs in this search space",
|
||||
)
|
||||
|
||||
db_config = ImageGenerationConfig(**config_data.model_dump())
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
|
||||
async def list_image_gen_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List image generation configs for a search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to view image generation configs in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig)
|
||||
.filter(ImageGenerationConfig.search_space_id == search_space_id)
|
||||
.order_by(ImageGenerationConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list ImageGenerationConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
|
||||
)
|
||||
async def get_image_gen_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific image generation config by ID."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to view image generation configs in this search space",
|
||||
)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put(
|
||||
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
|
||||
)
|
||||
async def update_image_gen_config(
|
||||
config_id: int,
|
||||
update_data: ImageGenerationConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update an existing image generation config."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to update image generation configs in this search space",
|
||||
)
|
||||
|
||||
for key, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(db_config, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/image-generation-configs/{config_id}", response_model=dict)
|
||||
async def delete_image_gen_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete an image generation config."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_DELETE.value,
|
||||
"You don't have permission to delete image generation configs in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_config)
|
||||
await session.commit()
|
||||
return {
|
||||
"message": "Image generation config deleted successfully",
|
||||
"id": config_id,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to delete ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image Generation Execution + Results CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/image-generations", response_model=ImageGenerationRead)
|
||||
async def create_image_generation(
|
||||
data: ImageGenerationCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create and execute an image generation request."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
data.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to create image generations in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == data.search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error during image generation"
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create image generation")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Image generation failed: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/image-generations", response_model=list[ImageGenerationListRead])
|
||||
async def list_image_generations(
|
||||
search_space_id: int | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List image generations."""
|
||||
if skip < 0 or limit < 1:
|
||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||
if limit > 100:
|
||||
limit = 100
|
||||
|
||||
try:
|
||||
if search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to read image generations in this search space",
|
||||
)
|
||||
result = await session.execute(
|
||||
select(ImageGeneration)
|
||||
.filter(ImageGeneration.search_space_id == search_space_id)
|
||||
.order_by(ImageGeneration.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
result = await session.execute(
|
||||
select(ImageGeneration)
|
||||
.join(SearchSpace)
|
||||
.join(SearchSpaceMembership)
|
||||
.filter(SearchSpaceMembership.user_id == user.id)
|
||||
.order_by(ImageGeneration.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
return [
|
||||
ImageGenerationListRead.from_orm_with_count(img)
|
||||
for img in result.scalars().all()
|
||||
]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error fetching image generations"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead)
|
||||
async def get_image_generation(
|
||||
image_gen_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific image generation by ID."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
|
||||
)
|
||||
image_gen = result.scalars().first()
|
||||
if not image_gen:
|
||||
raise HTTPException(status_code=404, detail="Image generation not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
image_gen.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to read image generations in this search space",
|
||||
)
|
||||
return image_gen
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error fetching image generation"
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/image-generations/{image_gen_id}", response_model=dict)
|
||||
async def delete_image_generation(
|
||||
image_gen_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete an image generation record."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
|
||||
)
|
||||
db_image_gen = result.scalars().first()
|
||||
if not db_image_gen:
|
||||
raise HTTPException(status_code=404, detail="Image generation not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_image_gen.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_DELETE.value,
|
||||
"You don't have permission to delete image generations in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_image_gen)
|
||||
await session.commit()
|
||||
return {"message": "Image generation deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error deleting image generation"
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image Serving (serves generated images from DB, protected by signed tokens)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/image-generations/{image_gen_id}/image")
|
||||
async def serve_generated_image(
|
||||
image_gen_id: int,
|
||||
token: str = Query(..., description="Signed access token"),
|
||||
index: int = 0,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Serve a generated image by ID, protected by a signed token.
|
||||
|
||||
The token is generated when the image URL is created by the generate_image
|
||||
tool and encodes the image_gen_id, search_space_id, and an expiry timestamp.
|
||||
This ensures only users with access to the search space can view images,
|
||||
without requiring auth headers (which <img> tags cannot pass).
|
||||
|
||||
Args:
|
||||
image_gen_id: The image generation record ID
|
||||
token: HMAC-signed access token (included as query parameter)
|
||||
index: Which image to serve if multiple were generated (default: 0)
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
|
||||
)
|
||||
image_gen = result.scalars().first()
|
||||
if not image_gen:
|
||||
raise HTTPException(status_code=404, detail="Image generation not found")
|
||||
|
||||
# Verify the access token against the one stored on the record
|
||||
if not verify_image_token(image_gen.access_token, token):
|
||||
raise HTTPException(status_code=403, detail="Invalid image access token")
|
||||
|
||||
if not image_gen.response_data:
|
||||
raise HTTPException(status_code=404, detail="No image data available")
|
||||
|
||||
images = image_gen.response_data.get("data", [])
|
||||
if not images or index >= len(images):
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Image not found at the specified index"
|
||||
)
|
||||
|
||||
image_entry = images[index]
|
||||
|
||||
# If there's a URL, redirect to it
|
||||
if image_entry.get("url"):
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
return RedirectResponse(url=image_entry["url"])
|
||||
|
||||
# If there's b64_json data, decode and serve it
|
||||
if image_entry.get("b64_json"):
|
||||
image_bytes = base64.b64decode(image_entry["b64_json"])
|
||||
return Response(
|
||||
content=image_bytes,
|
||||
media_type="image/png",
|
||||
headers={
|
||||
"Cache-Control": "public, max-age=86400",
|
||||
"Content-Disposition": f'inline; filename="generated-{image_gen_id}-{index}.png"',
|
||||
},
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=404, detail="No displayable image data")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to serve generated image")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to serve image: {e!s}"
|
||||
) from e
|
||||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGenerationConfig,
|
||||
NewLLMConfig,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
|
|
@ -387,6 +388,69 @@ async def _get_llm_config_by_id(
|
|||
return None
|
||||
|
||||
|
||||
async def _get_image_gen_config_by_id(
|
||||
session: AsyncSession, config_id: int | None
|
||||
) -> dict | None:
|
||||
"""
|
||||
Get an image generation config by ID as a dictionary.
|
||||
Returns Auto mode for ID 0, global config for negative IDs,
|
||||
DB ImageGenerationConfig for positive IDs, or None.
|
||||
"""
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
if config_id == 0:
|
||||
return {
|
||||
"id": 0,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"description": "Automatically routes requests across available image generation providers",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
}
|
||||
return None
|
||||
|
||||
# Positive ID: query ImageGenerationConfig table
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if db_config:
|
||||
return {
|
||||
"id": db_config.id,
|
||||
"name": db_config.name,
|
||||
"description": db_config.description,
|
||||
"provider": db_config.provider.value if db_config.provider else None,
|
||||
"custom_provider": db_config.custom_provider,
|
||||
"model_name": db_config.model_name,
|
||||
"api_base": db_config.api_base,
|
||||
"api_version": db_config.api_version,
|
||||
"litellm_params": db_config.litellm_params or {},
|
||||
"created_at": db_config.created_at.isoformat()
|
||||
if db_config.created_at
|
||||
else None,
|
||||
"search_space_id": db_config.search_space_id,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search-spaces/{search_space_id}/llm-preferences",
|
||||
response_model=LLMPreferencesRead,
|
||||
|
|
@ -423,12 +487,17 @@ async def get_llm_preferences(
|
|||
document_summary_llm = await _get_llm_config_by_id(
|
||||
session, search_space.document_summary_llm_id
|
||||
)
|
||||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
document_summary_llm_id=search_space.document_summary_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
agent_llm=agent_llm,
|
||||
document_summary_llm=document_summary_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -485,12 +554,17 @@ async def update_llm_preferences(
|
|||
document_summary_llm = await _get_llm_config_by_id(
|
||||
session, search_space.document_summary_llm_id
|
||||
)
|
||||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
document_summary_llm_id=search_space.document_summary_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
agent_llm=agent_llm,
|
||||
document_summary_llm=document_summary_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -21,6 +21,16 @@ from .documents import (
|
|||
PaginatedResponse,
|
||||
)
|
||||
from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest
|
||||
from .image_generation import (
|
||||
GlobalImageGenConfigRead,
|
||||
ImageGenerationConfigCreate,
|
||||
ImageGenerationConfigPublic,
|
||||
ImageGenerationConfigRead,
|
||||
ImageGenerationConfigUpdate,
|
||||
ImageGenerationCreate,
|
||||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||
from .new_chat import (
|
||||
ChatMessage,
|
||||
|
|
@ -105,11 +115,21 @@ __all__ = [
|
|||
"DriveItem",
|
||||
"ExtensionDocumentContent",
|
||||
"ExtensionDocumentMetadata",
|
||||
"GlobalImageGenConfigRead",
|
||||
"GlobalNewLLMConfigRead",
|
||||
"GoogleDriveIndexRequest",
|
||||
"GoogleDriveIndexingOptions",
|
||||
# Base schemas
|
||||
"IDModel",
|
||||
# Image Generation Config schemas
|
||||
"ImageGenerationConfigCreate",
|
||||
"ImageGenerationConfigPublic",
|
||||
"ImageGenerationConfigRead",
|
||||
"ImageGenerationConfigUpdate",
|
||||
# Image Generation schemas
|
||||
"ImageGenerationCreate",
|
||||
"ImageGenerationListRead",
|
||||
"ImageGenerationRead",
|
||||
# RBAC schemas
|
||||
"InviteAcceptRequest",
|
||||
"InviteAcceptResponse",
|
||||
|
|
|
|||
230
surfsense_backend/app/schemas/image_generation.py
Normal file
230
surfsense_backend/app/schemas/image_generation.py
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
"""
|
||||
Pydantic schemas for Image Generation configs and generation requests.
|
||||
|
||||
ImageGenerationConfig: CRUD schemas for user-created image gen model configs.
|
||||
ImageGeneration: Schemas for the actual image generation requests/results.
|
||||
GlobalImageGenConfigRead: Schema for admin-configured YAML configs.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import ImageGenProvider
|
||||
|
||||
# =============================================================================
|
||||
# ImageGenerationConfig CRUD Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ImageGenerationConfigBase(BaseModel):
|
||||
"""Base schema with fields for ImageGenerationConfig."""
|
||||
|
||||
name: str = Field(
|
||||
..., max_length=100, description="User-friendly name for the config"
|
||||
)
|
||||
description: str | None = Field(
|
||||
None, max_length=500, description="Optional description"
|
||||
)
|
||||
provider: ImageGenProvider = Field(
|
||||
...,
|
||||
description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)",
|
||||
)
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name"
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)"
|
||||
)
|
||||
api_key: str = Field(..., description="API key for the provider")
|
||||
api_base: str | None = Field(
|
||||
None, max_length=500, description="Optional API base URL"
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
None,
|
||||
max_length=50,
|
||||
description="Azure-specific API version (e.g., '2024-02-15-preview')",
|
||||
)
|
||||
litellm_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional LiteLLM parameters"
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationConfigCreate(ImageGenerationConfigBase):
|
||||
"""Schema for creating a new ImageGenerationConfig."""
|
||||
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the config with"
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationConfigUpdate(BaseModel):
|
||||
"""Schema for updating an existing ImageGenerationConfig. All fields optional."""
|
||||
|
||||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
provider: ImageGenProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
api_key: str | None = None
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ImageGenerationConfigRead(ImageGenerationConfigBase):
|
||||
"""Schema for reading an ImageGenerationConfig (includes id and timestamps)."""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ImageGenerationConfigPublic(BaseModel):
|
||||
"""Public schema that hides the API key (for list views)."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: ImageGenProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ImageGeneration (request/result) Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ImageGenerationCreate(BaseModel):
|
||||
"""Schema for creating an image generation request."""
|
||||
|
||||
prompt: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=4000,
|
||||
description="A text description of the desired image(s)",
|
||||
)
|
||||
model: str | None = Field(
|
||||
None,
|
||||
max_length=200,
|
||||
description="The model to use (e.g., 'dall-e-3', 'gpt-image-1'). Overrides the config model.",
|
||||
)
|
||||
n: int | None = Field(
|
||||
None,
|
||||
ge=1,
|
||||
le=10,
|
||||
description="Number of images to generate (1-10).",
|
||||
)
|
||||
quality: str | None = Field(None, max_length=50)
|
||||
size: str | None = Field(None, max_length=50)
|
||||
style: str | None = Field(None, max_length=50)
|
||||
response_format: str | None = Field(None, max_length=50)
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the generation with"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Image generation config ID. "
|
||||
"0 = Auto mode (router), negative = global YAML config, positive = DB config. "
|
||||
"If not provided, uses the search space's image_generation_config_id preference."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationRead(BaseModel):
|
||||
"""Schema for reading an image generation record."""
|
||||
|
||||
id: int
|
||||
prompt: str
|
||||
model: str | None = None
|
||||
n: int | None = None
|
||||
quality: str | None = None
|
||||
size: str | None = None
|
||||
style: str | None = None
|
||||
response_format: str | None = None
|
||||
image_generation_config_id: int | None = None
|
||||
response_data: dict[str, Any] | None = None
|
||||
error_message: str | None = None
|
||||
search_space_id: int
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ImageGenerationListRead(BaseModel):
|
||||
"""Lightweight schema for listing image generations (without full response_data)."""
|
||||
|
||||
id: int
|
||||
prompt: str
|
||||
model: str | None = None
|
||||
n: int | None = None
|
||||
quality: str | None = None
|
||||
size: str | None = None
|
||||
search_space_id: int
|
||||
created_at: datetime
|
||||
is_success: bool
|
||||
image_count: int | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@classmethod
|
||||
def from_orm_with_count(cls, obj: Any) -> "ImageGenerationListRead":
|
||||
"""Create ImageGenerationListRead with computed fields."""
|
||||
image_count = None
|
||||
if obj.response_data and isinstance(obj.response_data, dict):
|
||||
data = obj.response_data.get("data")
|
||||
if isinstance(data, list):
|
||||
image_count = len(data)
|
||||
|
||||
return cls(
|
||||
id=obj.id,
|
||||
prompt=obj.prompt,
|
||||
model=obj.model,
|
||||
n=obj.n,
|
||||
quality=obj.quality,
|
||||
size=obj.size,
|
||||
search_space_id=obj.search_space_id,
|
||||
created_at=obj.created_at,
|
||||
is_success=obj.response_data is not None,
|
||||
image_count=image_count,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Image Gen Config (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GlobalImageGenConfigRead(BaseModel):
|
||||
"""
|
||||
Schema for reading global image generation configs from YAML.
|
||||
Global configs have negative IDs. API key is hidden.
|
||||
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
|
||||
"""
|
||||
|
||||
id: int = Field(
|
||||
...,
|
||||
description="Config ID: 0 for Auto mode, negative for global configs",
|
||||
)
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: str
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
|
||||
|
|
@ -176,12 +176,18 @@ class LLMPreferencesRead(BaseModel):
|
|||
document_summary_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for document summarization"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
agent_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for agent LLM"
|
||||
)
|
||||
document_summary_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for document summary LLM"
|
||||
)
|
||||
image_generation_config: dict[str, Any] | None = Field(
|
||||
None, description="Full config for image generation"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -195,3 +201,6 @@ class LLMPreferencesUpdate(BaseModel):
|
|||
document_summary_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for document summarization"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
|
|
|
|||
278
surfsense_backend/app/services/image_gen_router_service.py
Normal file
278
surfsense_backend/app/services/image_gen_router_service.py
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
"""
|
||||
Image Generation Router Service for Load Balancing
|
||||
|
||||
This module provides a singleton LiteLLM Router for automatic load balancing
|
||||
across multiple image generation deployments. It uses litellm.Router which
|
||||
natively supports aimage_generation() for async image generation.
|
||||
|
||||
The router handles:
|
||||
- Rate limit management with automatic cooldowns
|
||||
- Automatic failover and retries
|
||||
- Usage-based routing to distribute load evenly
|
||||
|
||||
Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI,
|
||||
AWS Bedrock, Recraft, OpenRouter, Xinference, Nscale.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from litellm import Router
|
||||
from litellm.utils import ImageResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
IMAGE_GEN_AUTO_MODE_ID = 0
|
||||
|
||||
# Provider mapping for LiteLLM model string construction.
|
||||
# Only includes providers that support image generation.
|
||||
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
IMAGE_GEN_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
class ImageGenRouterService:
|
||||
"""
|
||||
Singleton service for managing LiteLLM Router for image generation.
|
||||
|
||||
The router provides automatic load balancing, failover, and rate limit
|
||||
handling across multiple image generation deployments.
|
||||
Uses Router.aimage_generation() for async image generation calls.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_router: Router | None = None
|
||||
_model_list: list[dict] = []
|
||||
_router_settings: dict = {}
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ImageGenRouterService":
|
||||
"""Get the singleton instance of the router service."""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def initialize(
|
||||
cls,
|
||||
global_configs: list[dict],
|
||||
router_settings: dict | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the router with global image generation configurations.
|
||||
|
||||
Args:
|
||||
global_configs: List of global image gen config dictionaries from YAML
|
||||
router_settings: Optional router settings (routing_strategy, num_retries, etc.)
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
|
||||
if instance._initialized:
|
||||
logger.debug("Image Generation Router already initialized, skipping")
|
||||
return
|
||||
|
||||
# Build model list from global configs
|
||||
model_list = []
|
||||
for config in global_configs:
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
|
||||
if not model_list:
|
||||
logger.warning(
|
||||
"No valid image generation configs found for router initialization"
|
||||
)
|
||||
return
|
||||
|
||||
instance._model_list = model_list
|
||||
instance._router_settings = router_settings or {}
|
||||
|
||||
# Default router settings optimized for rate limit handling
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
"retry_after": 5,
|
||||
}
|
||||
|
||||
# Merge with provided settings
|
||||
final_settings = {**default_settings, **instance._router_settings}
|
||||
|
||||
try:
|
||||
instance._router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy=final_settings.get(
|
||||
"routing_strategy", "usage-based-routing"
|
||||
),
|
||||
num_retries=final_settings.get("num_retries", 3),
|
||||
allowed_fails=final_settings.get("allowed_fails", 3),
|
||||
cooldown_time=final_settings.get("cooldown_time", 60),
|
||||
set_verbose=False,
|
||||
)
|
||||
instance._initialized = True
|
||||
logger.info(
|
||||
f"Image Generation Router initialized with {len(model_list)} deployments, "
|
||||
f"strategy: {final_settings.get('routing_strategy')}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Image Generation Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
|
||||
def _config_to_deployment(cls, config: dict) -> dict | None:
|
||||
"""
|
||||
Convert a global image gen config to a router deployment entry.
|
||||
|
||||
Args:
|
||||
config: Global image gen config dictionary
|
||||
|
||||
Returns:
|
||||
Router deployment dictionary or None if invalid
|
||||
"""
|
||||
try:
|
||||
# Skip if essential fields are missing
|
||||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
# Build model string
|
||||
if config.get("custom_provider"):
|
||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||
else:
|
||||
provider = config.get("provider", "").upper()
|
||||
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
# Build litellm params
|
||||
litellm_params: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
# Add optional api_base
|
||||
if config.get("api_base"):
|
||||
litellm_params["api_base"] = config["api_base"]
|
||||
|
||||
# Add api_version (required for Azure)
|
||||
if config.get("api_version"):
|
||||
litellm_params["api_version"] = config["api_version"]
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if config.get("litellm_params"):
|
||||
litellm_params.update(config["litellm_params"])
|
||||
|
||||
# All configs use same alias "auto" for unified routing
|
||||
deployment: dict[str, Any] = {
|
||||
"model_name": "auto",
|
||||
"litellm_params": litellm_params,
|
||||
}
|
||||
|
||||
# Add RPM rate limit from config if available
|
||||
# Note: TPM (tokens per minute) is not applicable for image generation
|
||||
# since image APIs are rate-limited by requests, not tokens.
|
||||
if config.get("rpm"):
|
||||
deployment["rpm"] = config["rpm"]
|
||||
|
||||
return deployment
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert image gen config to deployment: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_router(cls) -> Router | None:
|
||||
"""Get the initialized router instance."""
|
||||
instance = cls.get_instance()
|
||||
return instance._router
|
||||
|
||||
@classmethod
|
||||
def is_initialized(cls) -> bool:
|
||||
"""Check if the router has been initialized."""
|
||||
instance = cls.get_instance()
|
||||
return instance._initialized and instance._router is not None
|
||||
|
||||
@classmethod
|
||||
def get_model_count(cls) -> int:
|
||||
"""Get the number of models in the router."""
|
||||
instance = cls.get_instance()
|
||||
return len(instance._model_list)
|
||||
|
||||
@classmethod
|
||||
async def aimage_generation(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: str = "auto",
|
||||
n: int | None = None,
|
||||
timeout: int = 600,
|
||||
**kwargs,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Generate images using the router for load balancing.
|
||||
|
||||
Uses Router.aimage_generation() which distributes requests
|
||||
across configured image generation deployments.
|
||||
|
||||
Parameters like size, quality, style, and response_format are intentionally
|
||||
omitted to keep the interface model-agnostic. Providers use their own
|
||||
sensible defaults. If needed, pass them via **kwargs.
|
||||
|
||||
Args:
|
||||
prompt: Text description of the desired image(s)
|
||||
model: Model alias (default "auto" for router routing)
|
||||
n: Number of images to generate
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional provider-specific params (size, quality, etc.)
|
||||
|
||||
Returns:
|
||||
ImageResponse from litellm
|
||||
|
||||
Raises:
|
||||
ValueError: If router is not initialized
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
if not instance._router:
|
||||
raise ValueError(
|
||||
"Image Generation Router not initialized. "
|
||||
"Ensure global_llm_config.yaml has global_image_generation_configs."
|
||||
)
|
||||
|
||||
# Build kwargs for aimage_generation
|
||||
gen_kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"timeout": timeout,
|
||||
}
|
||||
if n is not None:
|
||||
gen_kwargs["n"] = n
|
||||
gen_kwargs.update(kwargs)
|
||||
|
||||
return await instance._router.aimage_generation(**gen_kwargs)
|
||||
|
||||
|
||||
def is_image_gen_auto_mode(config_id: int | None) -> bool:
|
||||
"""
|
||||
Check if the given config ID represents Image Generation Auto mode.
|
||||
|
||||
Args:
|
||||
config_id: The config ID to check
|
||||
|
||||
Returns:
|
||||
True if this is Auto mode, False otherwise
|
||||
"""
|
||||
return config_id == IMAGE_GEN_AUTO_MODE_ID
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
"""Celery tasks for connector indexing."""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
|
@ -11,6 +12,36 @@ from app.config import config
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> None:
|
||||
"""
|
||||
Handle greenlet_spawn errors with detailed logging for debugging.
|
||||
|
||||
The 'greenlet_spawn has not been called' error occurs when:
|
||||
1. SQLAlchemy lazy-loads a relationship outside of an async context
|
||||
2. A sync operation is called from an async context (or vice versa)
|
||||
3. Session objects are accessed after the session is closed
|
||||
|
||||
This helper logs detailed context to help identify the root cause.
|
||||
"""
|
||||
error_str = str(e)
|
||||
if "greenlet_spawn has not been called" in error_str:
|
||||
logger.error(
|
||||
f"GREENLET ERROR in {task_name} for connector {connector_id}: {error_str}\n"
|
||||
f"This error typically occurs when SQLAlchemy tries to lazy-load a relationship "
|
||||
f"outside of an async context. Check for:\n"
|
||||
f"1. Accessing relationship attributes (e.g., document.chunks, connector.search_space) "
|
||||
f"without using selectinload() or joinedload()\n"
|
||||
f"2. Accessing model attributes after the session is closed\n"
|
||||
f"3. Passing ORM objects between different async contexts\n"
|
||||
f"Stack trace:\n{traceback.format_exc()}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Error in {task_name} for connector {connector_id}: {error_str}\n"
|
||||
f"Stack trace:\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""
|
||||
Create a new async session maker for Celery tasks.
|
||||
|
|
@ -46,6 +77,9 @@ def index_slack_messages_task(
|
|||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_slack_messages", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
|
@ -89,6 +123,9 @@ def index_notion_pages_task(
|
|||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_notion_pages", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
|
@ -347,6 +384,9 @@ def index_google_calendar_events_task(
|
|||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
|
@ -696,6 +736,9 @@ def index_crawled_urls_task(
|
|||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -27,12 +27,12 @@ from app.agents.new_chat.llm_config import (
|
|||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.db import Document, SurfsenseDocsDocument
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.schemas.new_chat import ChatAttachment
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
|
|
@ -1211,9 +1211,10 @@ async def stream_new_chat(
|
|||
|
||||
# Generate LLM title for new chats after first response
|
||||
# Check if this is the first assistant response by counting existing assistant messages
|
||||
from app.db import NewChatMessage, NewChatThread
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.db import NewChatMessage, NewChatThread
|
||||
|
||||
assistant_count_result = await session.execute(
|
||||
select(func.count(NewChatMessage.id)).filter(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
|
|
@ -1231,10 +1232,12 @@ async def stream_new_chat(
|
|||
# Truncate inputs to avoid context length issues
|
||||
truncated_query = user_query[:500]
|
||||
truncated_response = accumulated_text[:1000]
|
||||
title_result = await title_chain.ainvoke({
|
||||
"user_query": truncated_query,
|
||||
"assistant_response": truncated_response,
|
||||
})
|
||||
title_result = await title_chain.ainvoke(
|
||||
{
|
||||
"user_query": truncated_query,
|
||||
"assistant_response": truncated_response,
|
||||
}
|
||||
)
|
||||
|
||||
# Extract and clean the title
|
||||
if title_result and hasattr(title_result, "content"):
|
||||
|
|
@ -1242,7 +1245,7 @@ async def stream_new_chat(
|
|||
# Validate the title (reasonable length)
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
# Remove any quotes or extra formatting
|
||||
generated_title = raw_title.strip('"\'')
|
||||
generated_title = raw_title.strip("\"'")
|
||||
except Exception:
|
||||
generated_title = None
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,34 @@ def safe_set_chunks(document: Document, chunks: list) -> None:
|
|||
set_committed_value(document, "chunks", chunks)
|
||||
|
||||
|
||||
def parse_date_flexible(date_str: str) -> datetime:
|
||||
"""
|
||||
Parse date from multiple common formats.
|
||||
|
||||
Args:
|
||||
date_str: Date string to parse
|
||||
|
||||
Returns:
|
||||
Parsed datetime object
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to parse the date string
|
||||
"""
|
||||
formats = ["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]
|
||||
|
||||
for fmt in formats:
|
||||
try:
|
||||
return datetime.strptime(date_str.rstrip("Z"), fmt)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Try ISO format as fallback
|
||||
try:
|
||||
return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||
except ValueError as err:
|
||||
raise ValueError(f"Unable to parse date: {date_str}") from err
|
||||
|
||||
|
||||
async def check_duplicate_document_by_hash(
|
||||
session: AsyncSession, content_hash: str
|
||||
) -> Document | None:
|
||||
|
|
@ -188,6 +216,26 @@ def calculate_date_range(
|
|||
)
|
||||
end_date_str = end_date if end_date else calculated_end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# FIX: Ensure end_date is at least 1 day after start_date to avoid
|
||||
# "start_date must be strictly before end_date" errors when dates are the same
|
||||
# (e.g., when last_indexed_at is today)
|
||||
if start_date_str == end_date_str:
|
||||
logger.info(
|
||||
f"Start date ({start_date_str}) equals end date ({end_date_str}), "
|
||||
"adjusting end date to next day to ensure valid date range"
|
||||
)
|
||||
# Parse end_date and add 1 day
|
||||
try:
|
||||
end_dt = parse_date_flexible(end_date_str)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Could not parse end_date '{end_date_str}', using current date"
|
||||
)
|
||||
end_dt = datetime.now()
|
||||
end_dt = end_dt + timedelta(days=1)
|
||||
end_date_str = end_dt.strftime("%Y-%m-%d")
|
||||
logger.info(f"Adjusted end date to {end_date_str}")
|
||||
|
||||
return start_date_str, end_date_str
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from .base import (
|
|||
get_connector_by_id,
|
||||
get_current_timestamp,
|
||||
logger,
|
||||
parse_date_flexible,
|
||||
safe_set_chunks,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
|
@ -222,6 +223,26 @@ async def index_google_calendar_events(
|
|||
start_date_str = start_date
|
||||
end_date_str = end_date
|
||||
|
||||
# FIX: Ensure end_date is at least 1 day after start_date to avoid
|
||||
# "start_date must be strictly before end_date" errors when dates are the same
|
||||
# (e.g., when last_indexed_at is today)
|
||||
if start_date_str == end_date_str:
|
||||
logger.info(
|
||||
f"Start date ({start_date_str}) equals end date ({end_date_str}), "
|
||||
"adjusting end date to next day to ensure valid date range"
|
||||
)
|
||||
# Parse end_date and add 1 day
|
||||
try:
|
||||
end_dt = parse_date_flexible(end_date_str)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Could not parse end_date '{end_date_str}', using current date"
|
||||
)
|
||||
end_dt = datetime.now()
|
||||
end_dt = end_dt + timedelta(days=1)
|
||||
end_date_str = end_dt.strftime("%Y-%m-%d")
|
||||
logger.info(f"Adjusted end date to {end_date_str}")
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Fetching Google Calendar events from {start_date_str} to {end_date_str}",
|
||||
|
|
|
|||
|
|
@ -202,13 +202,44 @@ async def index_notion_pages(
|
|||
"Recommend reconnecting with OAuth."
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to get Notion pages for connector {connector_id}",
|
||||
str(e),
|
||||
{"error_type": "PageFetchError"},
|
||||
error_str = str(e)
|
||||
# Check if this is an unsupported block type error (transcription, ai_block, etc.)
|
||||
# These are known Notion API limitations and should be logged as warnings, not errors
|
||||
unsupported_block_errors = [
|
||||
"transcription is not supported",
|
||||
"ai_block is not supported",
|
||||
"is not supported via the API",
|
||||
]
|
||||
is_unsupported_block_error = any(
|
||||
err in error_str.lower() for err in unsupported_block_errors
|
||||
)
|
||||
logger.error(f"Error fetching Notion pages: {e!s}", exc_info=True)
|
||||
|
||||
if is_unsupported_block_error:
|
||||
# Log as warning since this is a known Notion API limitation
|
||||
logger.warning(
|
||||
f"Notion API limitation for connector {connector_id}: {error_str}. "
|
||||
"This is a known issue with Notion AI blocks (transcription, ai_block) "
|
||||
"that are not accessible via the Notion API."
|
||||
)
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
"Failed to get Notion pages: Notion API limitation",
|
||||
f"{error_str} - This page contains Notion AI content (transcription/ai_block) that cannot be accessed via the API.",
|
||||
{"error_type": "UnsupportedBlockType", "is_known_limitation": True},
|
||||
)
|
||||
else:
|
||||
# Log as error for other failures
|
||||
logger.error(
|
||||
f"Error fetching Notion pages for connector {connector_id}: {error_str}",
|
||||
exc_info=True,
|
||||
)
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to get Notion pages for connector {connector_id}",
|
||||
str(e),
|
||||
{"error_type": "PageFetchError"},
|
||||
)
|
||||
|
||||
await notion_client.close()
|
||||
return 0, f"Failed to get Notion pages: {e!s}"
|
||||
|
||||
|
|
|
|||
|
|
@ -117,10 +117,15 @@ async def index_crawled_urls(
|
|||
api_key = connector.config.get("FIRECRAWL_API_KEY")
|
||||
|
||||
# Get URLs from connector config
|
||||
urls = parse_webcrawler_urls(connector.config.get("INITIAL_URLS"))
|
||||
raw_initial_urls = connector.config.get("INITIAL_URLS")
|
||||
urls = parse_webcrawler_urls(raw_initial_urls)
|
||||
|
||||
# DEBUG: Log connector config details for troubleshooting empty URL issues
|
||||
logger.info(
|
||||
f"Starting crawled web page indexing for connector {connector_id} with {len(urls)} URLs"
|
||||
f"Starting crawled web page indexing for connector {connector_id} with {len(urls)} URLs. "
|
||||
f"Connector name: {connector.name}, "
|
||||
f"INITIAL_URLS type: {type(raw_initial_urls).__name__}, "
|
||||
f"INITIAL_URLS value: {repr(raw_initial_urls)[:200] if raw_initial_urls else 'None'}"
|
||||
)
|
||||
|
||||
# Initialize webcrawler client
|
||||
|
|
@ -137,11 +142,18 @@ async def index_crawled_urls(
|
|||
|
||||
# Validate URLs
|
||||
if not urls:
|
||||
# DEBUG: Log detailed connector config for troubleshooting
|
||||
logger.error(
|
||||
f"No URLs provided for indexing. Connector ID: {connector_id}, "
|
||||
f"Connector name: {connector.name}, "
|
||||
f"Config keys: {list(connector.config.keys()) if connector.config else 'None'}, "
|
||||
f"INITIAL_URLS raw value: {raw_initial_urls!r}"
|
||||
)
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
"No URLs provided for indexing",
|
||||
"Empty URL list",
|
||||
{"error_type": "ValidationError"},
|
||||
f"Empty URL list. INITIAL_URLS value: {repr(raw_initial_urls)[:100]}",
|
||||
{"error_type": "ValidationError", "connector_name": connector.name},
|
||||
)
|
||||
return 0, "No URLs provided for indexing"
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import logging
|
|||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import aiohttp
|
||||
from fake_useragent import UserAgent
|
||||
from requests import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
|
|
@ -23,6 +25,7 @@ from app.utils.document_converters import (
|
|||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.proxy_config import get_requests_proxies
|
||||
|
||||
from .base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -200,9 +203,16 @@ async def add_youtube_video_document(
|
|||
}
|
||||
oembed_url = "https://www.youtube.com/oembed"
|
||||
|
||||
# Build residential proxy URL (if configured)
|
||||
residential_proxies = get_requests_proxies()
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as http_session,
|
||||
http_session.get(oembed_url, params=params) as response,
|
||||
http_session.get(
|
||||
oembed_url,
|
||||
params=params,
|
||||
proxy=residential_proxies["http"] if residential_proxies else None,
|
||||
) as response,
|
||||
):
|
||||
video_data = await response.json()
|
||||
|
||||
|
|
@ -228,7 +238,12 @@ async def add_youtube_video_document(
|
|||
)
|
||||
|
||||
try:
|
||||
ytt_api = YouTubeTranscriptApi()
|
||||
ua = UserAgent()
|
||||
http_client = Session()
|
||||
http_client.headers.update({"User-Agent": ua.random})
|
||||
if residential_proxies:
|
||||
http_client.proxies.update(residential_proxies)
|
||||
ytt_api = YouTubeTranscriptApi(http_client=http_client)
|
||||
captions = ytt_api.fetch(video_id)
|
||||
# Include complete caption information with timestamps
|
||||
transcript_segments = []
|
||||
|
|
|
|||
|
|
@ -219,7 +219,9 @@ class CustomBearerTransport(BearerTransport):
|
|||
|
||||
# Decode JWT to get user_id for refresh token creation
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False})
|
||||
payload = jwt.decode(
|
||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||
)
|
||||
user_id = uuid.UUID(payload.get("sub"))
|
||||
refresh_token = await create_refresh_token(user_id)
|
||||
except Exception as e:
|
||||
|
|
|
|||
86
surfsense_backend/app/utils/proxy_config.py
Normal file
86
surfsense_backend/app/utils/proxy_config.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""
|
||||
Residential proxy configuration utility.
|
||||
|
||||
Reads proxy credentials from the application Config and provides helper
|
||||
functions that return proxy configs in the format expected by different
|
||||
HTTP libraries (requests, httpx, aiohttp, Playwright).
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from app.config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_password_b64() -> str | None:
|
||||
"""
|
||||
Build the base64-encoded password dict required by anonymous-proxies.net.
|
||||
|
||||
Returns ``None`` when the required config values are not set.
|
||||
"""
|
||||
password = Config.RESIDENTIAL_PROXY_PASSWORD
|
||||
if not password:
|
||||
return None
|
||||
|
||||
password_dict = {
|
||||
"p": password,
|
||||
"l": Config.RESIDENTIAL_PROXY_LOCATION,
|
||||
"t": Config.RESIDENTIAL_PROXY_TYPE,
|
||||
}
|
||||
return base64.b64encode(json.dumps(password_dict).encode("utf-8")).decode("utf-8")
|
||||
|
||||
|
||||
def get_residential_proxy_url() -> str | None:
|
||||
"""
|
||||
Return the fully-formed residential proxy URL, or ``None`` when not
|
||||
configured.
|
||||
|
||||
The URL format is::
|
||||
|
||||
http://<username>:<base64_password>@<hostname>/
|
||||
"""
|
||||
username = Config.RESIDENTIAL_PROXY_USERNAME
|
||||
hostname = Config.RESIDENTIAL_PROXY_HOSTNAME
|
||||
password_b64 = _build_password_b64()
|
||||
|
||||
if not all([username, hostname, password_b64]):
|
||||
return None
|
||||
|
||||
return f"http://{username}:{password_b64}@{hostname}/"
|
||||
|
||||
|
||||
def get_requests_proxies() -> dict[str, str] | None:
|
||||
"""
|
||||
Return a ``{"http": …, "https": …}`` dict suitable for
|
||||
``requests.Session.proxies`` and ``aiohttp`` ``proxy=`` kwarg,
|
||||
or ``None`` when not configured.
|
||||
"""
|
||||
proxy_url = get_residential_proxy_url()
|
||||
if proxy_url is None:
|
||||
return None
|
||||
return {"http": proxy_url, "https": proxy_url}
|
||||
|
||||
|
||||
def get_playwright_proxy() -> dict[str, str] | None:
|
||||
"""
|
||||
Return a Playwright-compatible proxy dict::
|
||||
|
||||
{"server": "http://host:port", "username": "…", "password": "…"}
|
||||
|
||||
or ``None`` when not configured.
|
||||
"""
|
||||
username = Config.RESIDENTIAL_PROXY_USERNAME
|
||||
hostname = Config.RESIDENTIAL_PROXY_HOSTNAME
|
||||
password_b64 = _build_password_b64()
|
||||
|
||||
if not all([username, hostname, password_b64]):
|
||||
return None
|
||||
|
||||
return {
|
||||
"server": f"http://{hostname}",
|
||||
"username": username,
|
||||
"password": password_b64,
|
||||
}
|
||||
44
surfsense_backend/app/utils/signed_image_urls.py
Normal file
44
surfsense_backend/app/utils/signed_image_urls.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""
|
||||
Access token utilities for generated images.
|
||||
|
||||
Provides token generation and verification so that generated images can be
|
||||
served via <img> tags (which cannot pass auth headers) while still
|
||||
restricting access to authorised users.
|
||||
|
||||
Each image generation record stores its own random access token. The token
|
||||
is verified by comparing the incoming query-parameter value against the
|
||||
stored value in the database. This approach:
|
||||
|
||||
* Survives SECRET_KEY rotation — tokens are random, not derived from a key.
|
||||
* Allows explicit revocation — just clear the column.
|
||||
* Is immune to timing attacks — uses ``hmac.compare_digest``.
|
||||
"""
|
||||
|
||||
import hmac
|
||||
import secrets
|
||||
|
||||
|
||||
def generate_image_token() -> str:
|
||||
"""
|
||||
Generate a cryptographically random access token for an image.
|
||||
|
||||
Returns:
|
||||
A 64-character URL-safe hex string.
|
||||
"""
|
||||
return secrets.token_hex(32)
|
||||
|
||||
|
||||
def verify_image_token(stored_token: str | None, provided_token: str) -> bool:
|
||||
"""
|
||||
Constant-time comparison of a stored token against a user-provided one.
|
||||
|
||||
Args:
|
||||
stored_token: The token persisted on the ImageGeneration record.
|
||||
provided_token: The token from the URL query parameter.
|
||||
|
||||
Returns:
|
||||
True if the tokens match, False otherwise.
|
||||
"""
|
||||
if not stored_token or not provided_token:
|
||||
return False
|
||||
return hmac.compare_digest(stored_token, provided_token)
|
||||
Loading…
Add table
Add a link
Reference in a new issue