SurfSense/surfsense_backend/app/routes/image_generation_routes.py

810 lines
30 KiB
Python

"""
Image Generation routes:
- CRUD for ImageGenerationConfig (user-created image model configs)
- Global image gen configs endpoint (from YAML)
- Image generation execution (calls litellm.aimage_generation())
- CRUD for ImageGeneration records (results)
- Image serving endpoint (serves b64_json images from DB, protected by signed tokens)
"""
import base64
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import config
from app.db import (
ImageGeneration,
ImageGenerationConfig,
Model,
Permission,
SearchSpace,
SearchSpaceMembership,
User,
get_async_session,
)
from app.schemas import (
GlobalImageGenConfigRead,
ImageGenerationConfigCreate,
ImageGenerationConfigRead,
ImageGenerationConfigUpdate,
ImageGenerationCreate,
ImageGenerationListRead,
ImageGenerationRead,
)
from app.services.billable_calls import (
DEFAULT_IMAGE_RESERVE_MICROS,
QuotaInsufficientError,
billable_call,
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
is_image_gen_auto_mode,
)
from app.services.auto_model_pin_service import (
auto_model_candidates,
choose_auto_model_candidate,
)
from app.services.model_resolver import to_litellm
from app.services.model_capabilities import has_capability
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
# Module logger; handlers below use logger.exception() inside except blocks.
logger = logging.getLogger(__name__)
# Every endpoint in this module is registered on this router.
router = APIRouter()
def _get_global_model(model_id: int) -> dict | None:
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
def _get_global_connection(connection_id: int) -> dict | None:
return next(
(c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
None,
)
async def _resolve_billing_for_image_gen(
    session: AsyncSession,
    config_id: int | None,
    search_space: SearchSpace,
) -> tuple[str, str, int]:
    """Work out ``(billing_tier, base_model, reserve_micros)`` for a request.

    This follows the same lookup tree as ``_execute_image_generation``, but
    only pulls out what billing needs. It runs *before* ``billable_call`` so
    that:

    - the reservation is sized for the config that will actually run, and
    - no ``ImageGeneration`` row is opened for a request that will be denied.

    Resolution rules:

    - A missing *config_id* falls back to the search space's
      ``image_gen_model_id``, then to Auto.
    - Auto is narrowed to one concrete candidate first. With no candidates,
      it yields a free "auto" placeholder.
    - Negative IDs are global models. Their tier comes from
      ``billing_tier`` (default "free"), and the reserve from
      ``catalog.quota_reserve_micros`` (default
      ``DEFAULT_IMAGE_RESERVE_MICROS``).
    - Non-negative IDs are user-owned BYOK models and are always free.
    """
    if config_id is not None:
        target_id = config_id
    else:
        target_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID

    if is_image_gen_auto_mode(target_id):
        pool = await auto_model_candidates(
            session,
            search_space_id=search_space.id,
            user_id=search_space.user_id,
            capability="image_gen",
        )
        if not pool:
            return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
        target_id = int(choose_auto_model_candidate(pool, search_space.id)["id"])

    if target_id >= 0:
        # BYOK: the user's own credentials, so nothing is billed on our side.
        return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)

    model_entry = _get_global_model(target_id) or {}
    connection_entry = _get_global_connection(model_entry.get("connection_id", 0))
    tier = str(model_entry.get("billing_tier", "free")).lower()

    provider_model_id = model_entry.get("model_id")
    if connection_entry and provider_model_id:
        base_model = to_litellm(connection_entry, provider_model_id)[0]
    else:
        base_model = "global_image_model"

    model_catalog = model_entry.get("catalog") or {}
    reserve = int(
        model_catalog.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
    )
    return (tier, base_model, reserve)
def _build_generation_kwargs(image_gen: ImageGeneration) -> dict:
    """Collect the optional generation parameters that are set on *image_gen*.

    Only non-``None`` values are included, so unset fields are simply not
    passed to ``aimage_generation``.
    """
    gen_kwargs = {}
    for field_name in ("n", "quality", "size", "style", "response_format"):
        value = getattr(image_gen, field_name)
        if value is not None:
            gen_kwargs[field_name] = value
    return gen_kwargs


def _resolve_global_model(config_id: int) -> tuple[str, dict]:
    """Return ``(litellm_model, litellm_kwargs)`` for a YAML global image model.

    Raises:
        ValueError: if the model is unknown, is not image-gen capable, or its
            connection cannot be found.
    """
    global_model = _get_global_model(config_id)
    if not global_model or not has_capability(global_model, "image_gen"):
        raise ValueError(f"Global image generation model {config_id} not found")
    # .get(..., 0) mirrors _resolve_billing_for_image_gen, so a malformed entry
    # surfaces as a clear "connection not found" rather than a KeyError.
    global_connection = _get_global_connection(global_model.get("connection_id", 0))
    if not global_connection:
        raise ValueError(f"Global connection for image model {config_id} not found")
    return to_litellm(global_connection, global_model["model_id"])


async def _resolve_byok_model(
    session: AsyncSession,
    config_id: int,
    search_space: SearchSpace,
) -> tuple[str, dict]:
    """Return ``(litellm_model, litellm_kwargs)`` for a DB ``Model`` (positive ID).

    The model and its connection must both be enabled. A connection scoped
    to a search space or user must match *search_space* / its owner. The
    model must have the ``image_gen`` capability.

    Raises:
        ValueError: if any of those checks fail.
    """
    result = await session.execute(
        select(Model)
        .options(selectinload(Model.connection))
        .filter(Model.id == config_id, Model.enabled.is_(True))
    )
    db_model = result.scalars().first()
    if not db_model or not db_model.connection or not db_model.connection.enabled:
        raise ValueError(f"Image generation model {config_id} not found")
    conn = db_model.connection
    # Foreign connections report the same "not found" message as missing
    # models (presumably so out-of-scope IDs cannot be probed).
    if conn.search_space_id is not None and conn.search_space_id != search_space.id:
        raise ValueError(f"Image generation model {config_id} not found")
    if conn.user_id is not None and conn.user_id != search_space.user_id:
        raise ValueError(f"Image generation model {config_id} not found")
    if not has_capability(db_model, "image_gen"):
        raise ValueError(f"Model {config_id} is not image-generation capable")
    return to_litellm(conn, db_model.model_id)


async def _execute_image_generation(
    session: AsyncSession,
    image_gen: ImageGeneration,
    search_space: SearchSpace,
) -> None:
    """Call ``litellm.aimage_generation()`` for *image_gen* and store the result.

    Resolution order for the model that runs:

    1. Explicit ``image_generation_config_id`` on the request.
    2. The search space's ``image_gen_model_id`` preference.
    3. Auto mode (``IMAGE_GEN_AUTO_MODE_ID``), which picks one concrete
       candidate.

    Negative IDs are YAML global models that run on operator credentials.
    Positive IDs are DB-backed (BYOK) models.

    The request's ``model`` override is honoured only for DB-backed models.
    For global models it is ignored and cleared. Billing is resolved from the
    configured global model (``_resolve_billing_for_image_gen``), so an
    override would run arbitrary models on operator credentials at that
    model's billing tier.

    Side effects:

    - Sets ``image_gen.image_generation_config_id`` to the concrete ID used.
    - Sets ``image_gen.response_data`` from the provider response.
    - Fills ``image_gen.model`` from the response's hidden params when it is
      unset.

    Raises:
        ValueError: when no usable model can be resolved.
        Any exception from ``aimage_generation`` propagates unchanged.
    """
    config_id = image_gen.image_generation_config_id
    if config_id is None:
        config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
        image_gen.image_generation_config_id = config_id

    gen_kwargs = _build_generation_kwargs(image_gen)

    if is_image_gen_auto_mode(config_id):
        candidates = await auto_model_candidates(
            session,
            search_space_id=search_space.id,
            user_id=search_space.user_id,
            capability="image_gen",
        )
        if not candidates:
            raise ValueError("No image-generation models are available for Auto mode")
        config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"])
        # Record the concrete model that ran rather than the Auto sentinel.
        image_gen.image_generation_config_id = config_id

    if config_id < 0:
        model_string, resolved_kwargs = _resolve_global_model(config_id)
        if image_gen.model and image_gen.model != model_string:
            logger.warning(
                "Ignoring model override %r for global image model %s",
                image_gen.model,
                config_id,
            )
            # Clear it so the stored record is filled from the model that ran.
            image_gen.model = None
    else:
        model_string, resolved_kwargs = await _resolve_byok_model(
            session, config_id, search_space
        )
        # BYOK: the user's own credentials, so an explicit model override is allowed.
        if image_gen.model:
            model_string = image_gen.model

    gen_kwargs.update(resolved_kwargs)
    response = await aimage_generation(
        prompt=image_gen.prompt, model=model_string, **gen_kwargs
    )

    # Store response
    image_gen.response_data = (
        response.model_dump() if hasattr(response, "model_dump") else dict(response)
    )
    if not image_gen.model and hasattr(response, "_hidden_params"):
        hidden = response._hidden_params
        if isinstance(hidden, dict) and hidden.get("model"):
            image_gen.model = hidden["model"]
# =============================================================================
# Global Image Generation Configs (from YAML)
# =============================================================================
# litellm_params keys that look like credentials. They are stripped before
# the params are returned to clients, matching the "API keys are hidden"
# contract below.
_SECRET_PARAM_SUFFIXES = (
    "api_key",
    "secret",
    "secret_key",
    "access_key",
    "password",
    "token",
    "credentials",
)


def _redact_litellm_params(params: dict | None) -> dict:
    """Return a copy of *params* without credential-looking keys.

    ``None``/empty input yields ``{}``.
    """
    return {
        key: value
        for key, value in (params or {}).items()
        if not str(key).lower().endswith(_SECRET_PARAM_SUFFIXES)
    }


@router.get(
    "/global-image-generation-configs",
    response_model=list[GlobalImageGenConfigRead],
)
async def get_global_image_gen_configs(
    user: User = Depends(current_active_user),
):
    """Get all global image generation configs. API keys are hidden.

    Only the whitelisted fields below are copied from each entry of
    ``config.GLOBAL_IMAGE_GEN_CONFIGS``. Credential-looking keys are also
    removed from ``litellm_params``.

    When at least one global config exists, a synthetic Auto entry
    (``id`` 0) is prepended.

    Raises:
        HTTPException: 500 if the configs cannot be read or shaped.
    """
    try:
        global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS or []
        safe_configs = []
        if global_configs:
            safe_configs.append(
                {
                    "id": 0,
                    "name": "Auto (Fastest)",
                    "description": "Automatically routes across available image generation providers.",
                    "provider": "AUTO",
                    "custom_provider": None,
                    "model_name": "auto",
                    "api_base": None,
                    "api_version": None,
                    "litellm_params": {},
                    "is_global": True,
                    "is_auto_mode": True,
                    # Advertised as free. Actual billing is resolved per
                    # request from the concrete model Auto picks (see
                    # _resolve_billing_for_image_gen), so a premium candidate
                    # is still billed as premium.
                    "billing_tier": "free",
                    "is_premium": False,
                }
            )
        for cfg in global_configs:
            billing_tier = str(cfg.get("billing_tier", "free")).lower()
            safe_configs.append(
                {
                    "id": cfg.get("id"),
                    "name": cfg.get("name"),
                    "description": cfg.get("description"),
                    "provider": cfg.get("provider") or cfg.get("litellm_provider"),
                    "custom_provider": cfg.get("custom_provider"),
                    "model_name": cfg.get("model_name"),
                    "api_base": cfg.get("api_base") or None,
                    "api_version": cfg.get("api_version") or None,
                    # An explicit null in YAML becomes {}, and secrets are removed.
                    "litellm_params": _redact_litellm_params(cfg.get("litellm_params")),
                    "is_global": True,
                    "billing_tier": billing_tier,
                    # Mirror chat (``new_llm_config_routes``) so the new-chat
                    # selector's premium badge logic keys off the same
                    # field across chat / image / vision tabs.
                    "is_premium": billing_tier == "premium",
                    "quota_reserve_micros": cfg.get("quota_reserve_micros"),
                }
            )
        return safe_configs
    except Exception as e:
        logger.exception("Failed to fetch global image generation configs")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch configs: {e!s}"
        ) from e
# =============================================================================
# ImageGenerationConfig CRUD
# =============================================================================
@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead)
async def create_image_gen_config(
    config_data: ImageGenerationConfigCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Persist a new user-created image generation config in a search space.

    Requires ``IMAGE_GENERATIONS_CREATE`` on ``config_data.search_space_id``.
    The new row records the calling user as ``user_id``. Permission errors
    propagate as-is; any other failure rolls back and becomes a 500.
    """
    try:
        await check_permission(
            session,
            user,
            config_data.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to create image generation configs in this search space",
        )
        new_config = ImageGenerationConfig(
            **config_data.model_dump(),
            user_id=user.id,
        )
        session.add(new_config)
        await session.commit()
        await session.refresh(new_config)
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to create config: {e!s}"
        ) from e
    return new_config
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
async def list_image_gen_configs(
    search_space_id: int,
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List image generation configs for a search space, newest first.

    Args:
        search_space_id: Search space whose configs are listed.
        skip: Number of rows to skip; must be >= 0.
        limit: Maximum number of rows to return; must be >= 1.

    Raises:
        HTTPException: 400 for invalid pagination, 500 on unexpected errors.
            Permission errors from ``check_permission`` propagate unchanged.
    """
    # Same pagination rule as list_image_generations. Without it, a negative
    # OFFSET/LIMIT reaches the database and surfaces as a generic 500.
    if skip < 0 or limit < 1:
        raise HTTPException(status_code=400, detail="Invalid pagination parameters")
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to view image generation configs in this search space",
        )
        result = await session.execute(
            select(ImageGenerationConfig)
            .filter(ImageGenerationConfig.search_space_id == search_space_id)
            .order_by(ImageGenerationConfig.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        return result.scalars().all()
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to list ImageGenerationConfigs")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch configs: {e!s}"
        ) from e
@router.get(
    "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
)
async def get_image_gen_config(
    config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Fetch a single image generation config, subject to an RBAC read check.

    Raises:
        HTTPException: 404 if no config has *config_id*, 500 on unexpected
            errors. Permission errors from ``check_permission`` propagate.
    """
    try:
        lookup = await session.execute(
            select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
        )
        found = lookup.scalars().first()
        if not found:
            raise HTTPException(status_code=404, detail="Config not found")
        # Authorise against the config's own search space.
        await check_permission(
            session,
            user,
            found.search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to view image generation configs in this search space",
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch config: {e!s}"
        ) from e
    return found
@router.put(
    "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
)
async def update_image_gen_config(
    config_id: int,
    update_data: ImageGenerationConfigUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Apply a partial update to an existing image generation config.

    Only fields explicitly set in *update_data* (``exclude_unset=True``) are
    written. Requires ``IMAGE_GENERATIONS_CREATE`` on the config's search
    space.

    Raises:
        HTTPException: 404 if the config is missing, 500 (after rollback) on
            unexpected errors. Permission errors propagate.
    """
    try:
        lookup = await session.execute(
            select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
        )
        target = lookup.scalars().first()
        if not target:
            raise HTTPException(status_code=404, detail="Config not found")
        await check_permission(
            session,
            user,
            target.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to update image generation configs in this search space",
        )
        # NOTE(review): if ImageGenerationConfigUpdate exposes search_space_id,
        # this moves the config without a permission check on the destination
        # space — confirm against the schema.
        changes = update_data.model_dump(exclude_unset=True)
        for field_name, new_value in changes.items():
            setattr(target, field_name, new_value)
        await session.commit()
        await session.refresh(target)
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to update ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to update config: {e!s}"
        ) from e
    else:
        return target
@router.delete("/image-generation-configs/{config_id}", response_model=dict)
async def delete_image_gen_config(
    config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Remove an image generation config.

    Requires ``IMAGE_GENERATIONS_DELETE`` on the config's search space.
    Returns a confirmation message together with the deleted ID.

    Raises:
        HTTPException: 404 if the config is missing, 500 (after rollback) on
            unexpected errors. Permission errors propagate.
    """
    try:
        lookup = await session.execute(
            select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
        )
        doomed = lookup.scalars().first()
        if not doomed:
            raise HTTPException(status_code=404, detail="Config not found")
        await check_permission(
            session,
            user,
            doomed.search_space_id,
            Permission.IMAGE_GENERATIONS_DELETE.value,
            "You don't have permission to delete image generation configs in this search space",
        )
        await session.delete(doomed)
        await session.commit()
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to delete ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to delete config: {e!s}"
        ) from e
    return {
        "message": "Image generation config deleted successfully",
        "id": config_id,
    }
# =============================================================================
# Image Generation Execution + Results CRUD
# =============================================================================
@router.post("/image-generations", response_model=ImageGenerationRead)
async def create_image_generation(
    data: ImageGenerationCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create and execute an image generation request.

    Premium configs are gated by the shared premium credit pool of the
    search-space owner: ``billable_call`` is keyed on ``search_space.user_id``,
    not on the requesting user.

    The flow is:

    1. Permission check + load the search space (cheap, no provider call).
    2. Resolve which config will run so we know its billing tier and the
       worst-case reservation size *before* opening any DB rows.
    3. Wrap the ImageGeneration row insert + provider call in
       ``billable_call``. If quota is denied, ``billable_call`` raises
       ``QuotaInsufficientError`` *before* we flush a row, which we
       translate to HTTP 402. That leaves no orphaned rows and no inserted
       error rows for "you ran out of credit".
    4. On success, the actual ``response_cost`` flows through the LiteLLM
       callback into the accumulator, and ``billable_call`` finalizes the
       debit at exit. The inner ``try/except`` still catches provider errors
       and stores them on ``error_message``. HTTP 200 with ``error_message``
       set is preserved for failed-but-not-quota scenarios, since clients
       already know how to surface those.

    Raises:
        HTTPException:
            - 402 when the premium pool is exhausted.
            - 404 when the search space does not exist.
            - 500 on database or unexpected errors.
            Permission errors from ``check_permission`` propagate.
    """
    try:
        await check_permission(
            session,
            user,
            data.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to create image generations in this search space",
        )
        result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == data.search_space_id)
        )
        search_space = result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")
        billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
            session, data.image_generation_config_id, search_space
        )
        # billable_call runs OUTSIDE the inner try/except, so
        # QuotaInsufficientError propagates to the outer handler below as
        # HTTP 402. It is intentionally NOT swallowed into ``error_message``,
        # because that would (1) imply a successful row exists when none
        # does, and (2) return HTTP 200 to a client whose request was
        # actively *denied* (issue K).
        async with billable_call(
            user_id=search_space.user_id,
            search_space_id=data.search_space_id,
            billing_tier=billing_tier,
            base_model=base_model,
            quota_reserve_micros_override=reserve_micros,
            usage_type="image_generation",
            call_details={"model": base_model, "prompt": data.prompt[:100]},
        ):
            db_image_gen = ImageGeneration(
                prompt=data.prompt,
                model=data.model,
                n=data.n,
                quality=data.quality,
                size=data.size,
                style=data.style,
                response_format=data.response_format,
                image_generation_config_id=data.image_generation_config_id,
                search_space_id=data.search_space_id,
                created_by_id=user.id,
            )
            session.add(db_image_gen)
            await session.flush()
            try:
                await _execute_image_generation(session, db_image_gen, search_space)
            except Exception as e:
                # Provider/resolution failures are recorded on the row and
                # returned with HTTP 200 (see docstring, step 4).
                logger.exception("Image generation call failed")
                db_image_gen.error_message = str(e)
            await session.commit()
            await session.refresh(db_image_gen)
            return db_image_gen
    except HTTPException:
        raise
    except QuotaInsufficientError as exc:
        # The user's premium credit pool is empty. No DB row is created
        # because ``billable_call`` denies before yielding (issue K).
        await session.rollback()
        raise HTTPException(
            status_code=402,
            detail={
                "error_code": "premium_quota_exhausted",
                "usage_type": exc.usage_type,
                "balance_micros": exc.balance_micros,
                "remaining_micros": exc.remaining_micros,
                "message": (
                    "Out of credits for image generation. "
                    "Purchase additional credits or switch to a free model."
                ),
            },
        ) from exc
    except SQLAlchemyError:
        # Log the traceback server-side; the client only sees a generic detail.
        logger.exception("Database error during image generation")
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error during image generation"
        ) from None
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create image generation")
        raise HTTPException(
            status_code=500, detail=f"Image generation failed: {e!s}"
        ) from e
@router.get("/image-generations", response_model=list[ImageGenerationListRead])
async def list_image_generations(
    search_space_id: int | None = None,
    skip: int = 0,
    limit: int = 50,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List image generations, newest first.

    With *search_space_id*, lists that space after an
    ``IMAGE_GENERATIONS_READ`` check. Without it, lists generations from every
    search space the user is a member of. *limit* is capped at 100.

    Raises:
        HTTPException: 400 for invalid pagination, 500 on database errors.
            Permission errors propagate.
    """
    if skip < 0 or limit < 1:
        raise HTTPException(status_code=400, detail="Invalid pagination parameters")
    limit = min(limit, 100)
    try:
        stmt = select(ImageGeneration)
        if search_space_id is not None:
            await check_permission(
                session,
                user,
                search_space_id,
                Permission.IMAGE_GENERATIONS_READ.value,
                "You don't have permission to read image generations in this search space",
            )
            stmt = stmt.filter(ImageGeneration.search_space_id == search_space_id)
        else:
            # NOTE(review): this branch filters on membership only and does not
            # check IMAGE_GENERATIONS_READ per space, unlike the scoped branch.
            # Confirm that membership alone should grant read access here.
            stmt = (
                stmt.join(SearchSpace)
                .join(SearchSpaceMembership)
                .filter(SearchSpaceMembership.user_id == user.id)
            )
        result = await session.execute(
            stmt.order_by(ImageGeneration.created_at.desc()).offset(skip).limit(limit)
        )
        return [
            ImageGenerationListRead.from_orm_with_count(img)
            for img in result.scalars().all()
        ]
    except HTTPException:
        raise
    except SQLAlchemyError:
        logger.exception("Database error fetching image generations")
        raise HTTPException(
            status_code=500, detail="Database error fetching image generations"
        ) from None
@router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead)
async def get_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Get a specific image generation by ID.

    Requires ``IMAGE_GENERATIONS_READ`` on the record's search space.

    Raises:
        HTTPException: 404 if the record is missing, 500 on database errors.
            Permission errors propagate.
    """
    try:
        result = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        image_gen = result.scalars().first()
        if not image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")
        await check_permission(
            session,
            user,
            image_gen.search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to read image generations in this search space",
        )
        return image_gen
    except HTTPException:
        raise
    except SQLAlchemyError:
        # Keep the traceback in server logs; the client gets a generic detail.
        logger.exception("Database error fetching image generation")
        raise HTTPException(
            status_code=500, detail="Database error fetching image generation"
        ) from None
@router.delete("/image-generations/{image_gen_id}", response_model=dict)
async def delete_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete an image generation record.

    Requires ``IMAGE_GENERATIONS_DELETE`` on the record's search space.

    Raises:
        HTTPException: 404 if the record is missing, 500 (after rollback) on
            database errors. Permission errors propagate.
    """
    try:
        result = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        db_image_gen = result.scalars().first()
        if not db_image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")
        await check_permission(
            session,
            user,
            db_image_gen.search_space_id,
            Permission.IMAGE_GENERATIONS_DELETE.value,
            "You don't have permission to delete image generations in this search space",
        )
        await session.delete(db_image_gen)
        await session.commit()
        return {"message": "Image generation deleted successfully"}
    except HTTPException:
        raise
    except SQLAlchemyError:
        logger.exception("Database error deleting image generation")
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error deleting image generation"
        ) from None
# =============================================================================
# Image Serving (serves generated images from DB, protected by signed tokens)
# =============================================================================
# (magic-byte prefix, media type, file extension) for formats that providers
# may return as b64_json. Anything unrecognised is served as PNG, as before.
_IMAGE_SIGNATURES = (
    (b"\x89PNG\r\n\x1a\n", "image/png", "png"),
    (b"\xff\xd8\xff", "image/jpeg", "jpg"),
    (b"GIF87a", "image/gif", "gif"),
    (b"GIF89a", "image/gif", "gif"),
)


def _sniff_image_type(image_bytes: bytes) -> tuple[str, str]:
    """Return ``(media_type, extension)`` for *image_bytes*, defaulting to PNG."""
    for signature, media_type, extension in _IMAGE_SIGNATURES:
        if image_bytes.startswith(signature):
            return media_type, extension
    # WEBP container: b"RIFF" <4-byte length> b"WEBP"
    if image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
        return "image/webp", "webp"
    return "image/png", "png"


@router.get("/image-generations/{image_gen_id}/image")
async def serve_generated_image(
    image_gen_id: int,
    token: str = Query(..., description="Signed access token"),
    index: int = 0,
    session: AsyncSession = Depends(get_async_session),
):
    """
    Serve a generated image by ID, protected by a signed token.

    The token is checked with ``verify_image_token`` against the
    ``access_token`` stored on the record. Access therefore does not depend
    on auth headers, which ``<img>`` tags cannot send.

    Args:
        image_gen_id: The image generation record ID.
        token: Signed access token, passed as a query parameter.
        index: Which image to serve if several were generated (default 0).
            Must be within ``[0, len(images))``; negative values are rejected
            rather than indexing from the end.

    Returns:
        A redirect when the entry holds a ``url``. Otherwise the decoded
        ``b64_json`` bytes, with a media type sniffed from their magic
        bytes (PNG when unrecognised).

    Raises:
        HTTPException: 404 for a missing record/data/index, 403 for a bad
            token, 500 on unexpected errors.
    """
    try:
        result = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        image_gen = result.scalars().first()
        if not image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")
        # Verify the access token against the one stored on the record
        if not verify_image_token(image_gen.access_token, token):
            raise HTTPException(status_code=403, detail="Invalid image access token")
        if not image_gen.response_data:
            raise HTTPException(status_code=404, detail="No image data available")
        images = image_gen.response_data.get("data", [])
        # Reject negative indexes explicitly. Python would otherwise serve
        # images from the end, or raise IndexError and turn into a 500.
        if not images or index < 0 or index >= len(images):
            raise HTTPException(
                status_code=404, detail="Image not found at the specified index"
            )
        image_entry = images[index]
        # If there's a URL, redirect to it
        if image_entry.get("url"):
            from fastapi.responses import RedirectResponse

            return RedirectResponse(url=image_entry["url"])
        # If there's b64_json data, decode and serve it
        if image_entry.get("b64_json"):
            image_bytes = base64.b64decode(image_entry["b64_json"])
            media_type, extension = _sniff_image_type(image_bytes)
            return Response(
                content=image_bytes,
                media_type=media_type,
                headers={
                    "Cache-Control": "public, max-age=86400",
                    "Content-Disposition": f'inline; filename="generated-{image_gen_id}-{index}.{extension}"',
                },
            )
        raise HTTPException(status_code=404, detail="No displayable image data")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to serve generated image")
        raise HTTPException(
            status_code=500, detail=f"Failed to serve image: {e!s}"
        ) from e