SurfSense/surfsense_backend/app/routes/image_generation_routes.py

545 lines
20 KiB
Python
Raw Normal View History

2026-02-05 16:43:48 -08:00
"""
Image Generation routes:
- Image generation execution (calls litellm.aimage_generation())
- CRUD for ImageGeneration records (results)
- Image serving endpoint (serves b64_json images from DB, protected by signed tokens)
"""
import base64
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
2026-02-05 16:43:48 -08:00
from app.config import config
from app.db import (
ImageGeneration,
Model,
2026-02-05 16:43:48 -08:00
Permission,
SearchSpace,
SearchSpaceMembership,
User,
get_async_session,
)
from app.schemas import (
ImageGenerationCreate,
ImageGenerationListRead,
ImageGenerationRead,
)
from app.services.auto_model_pin_service import (
auto_model_candidates,
choose_auto_model_candidate,
)
from app.services.billable_calls import (
DEFAULT_IMAGE_RESERVE_MICROS,
QuotaInsufficientError,
billable_call,
)
2026-02-05 16:43:48 -08:00
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
is_image_gen_auto_mode,
)
from app.services.model_capabilities import has_capability
from app.services.model_resolver import to_litellm
2026-02-05 16:43:48 -08:00
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
# Router that the image-generation endpoints in this module register on.
router = APIRouter()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
2026-06-13 21:59:35 +05:30
def _get_global_model(model_id: int) -> dict | None:
    """Look up a global model entry by id.

    Scans ``config.GLOBAL_MODELS`` in order and returns the first entry whose
    ``"id"`` key equals ``model_id``; returns ``None`` when nothing matches.
    """
    for entry in config.GLOBAL_MODELS:
        if entry.get("id") == model_id:
            return entry
    return None
def _get_global_connection(connection_id: int) -> dict | None:
    """Look up a global connection entry by id.

    Scans ``config.GLOBAL_CONNECTIONS`` in order and returns the first entry
    whose ``"id"`` key equals ``connection_id``; returns ``None`` when nothing
    matches.
    """
    for entry in config.GLOBAL_CONNECTIONS:
        if entry.get("id") == connection_id:
            return entry
    return None
2026-02-05 16:43:48 -08:00
async def _resolve_billing_for_image_gen(
    session: AsyncSession,
    config_id: int | None,
    search_space: SearchSpace,
) -> tuple[str, str, int]:
    """Work out ``(billing_tier, base_model, reserve_micros)`` for a request.

    The lookup follows the same tree as ``_execute_image_generation`` but only
    extracts the fields needed for billing. It is meant to run *before*
    ``billable_call`` so the reservation is sized for the config that will
    actually run, and so no ``ImageGeneration`` row is opened for a request
    that is about to be rejected with a 402.

    Resolution:
    - ``config_id`` of ``None`` falls back to the search space's
      ``image_gen_model_id``, then to ``IMAGE_GEN_AUTO_MODE_ID``.
    - Auto mode is narrowed to one concrete candidate first; with no
      candidates the result is ``("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)``.
    - Negative ids are global models: tier, base model and reserve come from
      the global model/connection config, with defaults for missing fields.
    - Any other id is treated as a user-owned (BYOK) model, which is always
      free.
    """
    model_ref = config_id
    if model_ref is None:
        model_ref = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID

    if is_image_gen_auto_mode(model_ref):
        pool = await auto_model_candidates(
            session,
            search_space_id=search_space.id,
            user_id=search_space.user_id,
            capability="image_gen",
        )
        if not pool:
            return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
        picked = choose_auto_model_candidate(pool, search_space.id)
        model_ref = int(picked["id"])

    # Non-negative ID = user-owned BYOK image-gen model — always free.
    if model_ref >= 0:
        return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)

    # Negative ID = global model; an unknown id degrades to an empty config
    # so every field below falls back to its default.
    model_cfg = _get_global_model(model_ref) or {}
    connection_cfg = _get_global_connection(model_cfg.get("connection_id", 0))
    tier = str(model_cfg.get("billing_tier", "free")).lower()

    if connection_cfg and model_cfg.get("model_id"):
        litellm_model, _ = to_litellm(connection_cfg, model_cfg["model_id"])
    else:
        litellm_model = "global_image_model"

    catalog_cfg = model_cfg.get("catalog") or {}
    reserve = int(
        catalog_cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
    )
    return (tier, litellm_model, reserve)
2026-02-05 16:43:48 -08:00
async def _execute_image_generation(
    session: AsyncSession,
    image_gen: ImageGeneration,
    search_space: SearchSpace,
) -> None:
    """
    Call litellm.aimage_generation() with the appropriate config.

    Resolution order:
    1. Explicit image_gen_model_id on the request
    2. Search space's image_gen_model_id preference
    3. Falls back to Auto mode if available

    Negative ids are resolved against the global model/connection config;
    any other id is loaded as an enabled ``Model`` row together with its
    connection.

    Side effects on ``image_gen``:
    - ``image_gen_model_id`` is set to the config that was actually used
      (including the concrete model picked in Auto mode).
    - ``response_data`` receives the provider response as a dict.
    - ``model`` is filled from the response's ``_hidden_params["model"]``
      when the request did not specify one.

    Raises:
        ValueError: when no Auto-mode candidate exists, the model or its
            connection cannot be found, is disabled, belongs to a different
            search space / user, or lacks the ``image_gen`` capability.
    """
    config_id = image_gen.image_gen_model_id
    if config_id is None:
        config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
        image_gen.image_gen_model_id = config_id

    # Forward only the optional generation parameters the request set.
    gen_kwargs = {}
    for option in ("n", "quality", "size", "style", "response_format"):
        value = getattr(image_gen, option)
        if value is not None:
            gen_kwargs[option] = value

    if is_image_gen_auto_mode(config_id):
        candidates = await auto_model_candidates(
            session,
            search_space_id=search_space.id,
            user_id=search_space.user_id,
            capability="image_gen",
        )
        if not candidates:
            raise ValueError("No image-generation models are available for Auto mode")
        config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"])
        # Persist the concrete choice so the record shows what actually ran.
        image_gen.image_gen_model_id = config_id

    if config_id < 0:
        # Negative ID = global model + global connection from config.
        global_model = _get_global_model(config_id)
        if not global_model or not has_capability(global_model, "image_gen"):
            raise ValueError(f"Global image generation model {config_id} not found")
        # ``.get(..., 0)`` mirrors _resolve_billing_for_image_gen: an entry
        # without a connection_id reports "connection not found" instead of
        # surfacing a bare KeyError.
        global_connection = _get_global_connection(
            global_model.get("connection_id", 0)
        )
        if not global_connection:
            raise ValueError(f"Global connection for image model {config_id} not found")
        model_string, resolved_kwargs = to_litellm(
            global_connection,
            global_model["model_id"],
        )
    else:
        # Non-negative ID = Model + Connection rows from the database.
        result = await session.execute(
            select(Model)
            .options(selectinload(Model.connection))
            .filter(Model.id == config_id, Model.enabled.is_(True))
        )
        db_model = result.scalars().first()
        if not db_model or not db_model.connection or not db_model.connection.enabled:
            raise ValueError(f"Image generation model {config_id} not found")
        conn = db_model.connection
        # A connection scoped to another search space or user is reported as
        # "not found" — the same message as a missing model.
        if conn.search_space_id is not None and conn.search_space_id != search_space.id:
            raise ValueError(f"Image generation model {config_id} not found")
        if conn.user_id is not None and conn.user_id != search_space.user_id:
            raise ValueError(f"Image generation model {config_id} not found")
        if not has_capability(db_model, "image_gen"):
            raise ValueError(f"Model {config_id} is not image-generation capable")
        model_string, resolved_kwargs = to_litellm(
            db_model.connection,
            db_model.model_id,
        )

    # Connection-derived kwargs win over request-level ones on key clashes.
    gen_kwargs.update(resolved_kwargs)

    # User model override.
    # NOTE(review): this replaces the model string after billing was sized
    # for the configured model — confirm that is intended for global models.
    if image_gen.model:
        model_string = image_gen.model

    response = await aimage_generation(
        prompt=image_gen.prompt, model=model_string, **gen_kwargs
    )

    # Store response
    image_gen.response_data = (
        response.model_dump() if hasattr(response, "model_dump") else dict(response)
    )
    if not image_gen.model and hasattr(response, "_hidden_params"):
        hidden = response._hidden_params
        if isinstance(hidden, dict) and hidden.get("model"):
            image_gen.model = hidden["model"]
# =============================================================================
# Image Generation Execution + Results CRUD
# =============================================================================
@router.post("/image-generations", response_model=ImageGenerationRead)
async def create_image_generation(
    data: ImageGenerationCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create and execute an image generation request.

    Premium configs are gated by the user's shared premium credit pool.
    The flow is:

    1. Permission check + load the search space (cheap, no provider call).
    2. Resolve which config will run so we know its billing tier and the
       worst-case reservation size *before* opening any DB rows.
    3. Wrap the entire ImageGeneration row insert + provider call in
       ``billable_call``. If quota is denied, ``billable_call`` raises
       ``QuotaInsufficientError`` *before* we flush a row, which we
       translate to HTTP 402 (no orphaned rows on the user's account,
       no inserted error rows for "you ran out of credit").
    4. On success, the actual ``response_cost`` flows through the
       LiteLLM callback into the accumulator, and ``billable_call``
       finalizes the debit at exit. Inner ``try/except`` still catches
       provider errors and stores them on ``error_message`` (HTTP 200
       with ``error_message`` set is preserved for failed-but-not-quota
       scenarios — clients already know how to surface those).

    Raises:
        HTTPException: 404 when the search space does not exist, 402 when
            the quota is insufficient, 500 on database or unexpected
            errors; any ``HTTPException`` raised inside the ``try`` body
            is re-raised unchanged.
    """
    try:
        await check_permission(
            session,
            user,
            data.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to create image generations in this search space",
        )
        result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == data.search_space_id)
        )
        search_space = result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")
        # Size the reservation for the config that will actually run.
        billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
            session, data.image_gen_model_id, search_space
        )
        # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
        # propagates to the outer ``except QuotaInsufficientError`` handler
        # below as HTTP 402 — it is intentionally NOT swallowed into
        # ``error_message`` because that would (1) imply a successful row
        # exists when none does, and (2) return HTTP 200 to a client
        # whose request was actively *denied* (issue K).
        async with billable_call(
            # Billed to the search space owner, not necessarily the caller.
            user_id=search_space.user_id,
            search_space_id=data.search_space_id,
            billing_tier=billing_tier,
            base_model=base_model,
            quota_reserve_micros_override=reserve_micros,
            usage_type="image_generation",
            # Only the first 100 characters of the prompt are recorded.
            call_details={"model": base_model, "prompt": data.prompt[:100]},
        ):
            db_image_gen = ImageGeneration(
                prompt=data.prompt,
                model=data.model,
                n=data.n,
                quality=data.quality,
                size=data.size,
                style=data.style,
                response_format=data.response_format,
                image_gen_model_id=data.image_gen_model_id,
                search_space_id=data.search_space_id,
                created_by_id=user.id,
            )
            session.add(db_image_gen)
            # Flush so the row is written before the provider call.
            await session.flush()
            try:
                await _execute_image_generation(session, db_image_gen, search_space)
            except Exception as e:
                # Provider/config failures are recorded on the row and the
                # request still returns normally with ``error_message`` set.
                logger.exception("Image generation call failed")
                db_image_gen.error_message = str(e)
            await session.commit()
            await session.refresh(db_image_gen)
            return db_image_gen
    except HTTPException:
        raise
    except QuotaInsufficientError as exc:
        # The user's premium credit pool is empty. No DB row is created
        # because ``billable_call`` denies before yielding (issue K).
        await session.rollback()
        raise HTTPException(
            status_code=402,
            detail={
                "error_code": "premium_quota_exhausted",
                "usage_type": exc.usage_type,
                "balance_micros": exc.balance_micros,
                "remaining_micros": exc.remaining_micros,
                "message": (
                    "Out of credits for image generation. "
                    "Purchase additional credits or switch to a free model."
                ),
            },
        ) from exc
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error during image generation"
        ) from None
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create image generation")
        raise HTTPException(
            status_code=500, detail=f"Image generation failed: {e!s}"
        ) from e
2026-02-05 16:43:48 -08:00
@router.get("/image-generations", response_model=list[ImageGenerationListRead])
async def list_image_generations(
    search_space_id: int | None = None,
    skip: int = 0,
    limit: int = 50,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List image generations, newest first.

    With ``search_space_id`` the caller needs read permission on that search
    space and only its records are returned. Without it, records from every
    search space the caller has a membership in are returned.

    ``skip`` must be >= 0 and ``limit`` >= 1 (otherwise HTTP 400); ``limit``
    is capped at 100.
    """
    if skip < 0 or limit < 1:
        raise HTTPException(status_code=400, detail="Invalid pagination parameters")
    limit = min(limit, 100)

    try:
        if search_space_id is None:
            # No explicit space: restrict to spaces the user is a member of.
            stmt = (
                select(ImageGeneration)
                .join(SearchSpace)
                .join(SearchSpaceMembership)
                .filter(SearchSpaceMembership.user_id == user.id)
            )
        else:
            await check_permission(
                session,
                user,
                search_space_id,
                Permission.IMAGE_GENERATIONS_READ.value,
                "You don't have permission to read image generations in this search space",
            )
            stmt = select(ImageGeneration).filter(
                ImageGeneration.search_space_id == search_space_id
            )

        # Ordering and pagination are identical for both query shapes.
        stmt = (
            stmt.order_by(ImageGeneration.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        rows = await session.execute(stmt)
        return [
            ImageGenerationListRead.from_orm_with_count(record)
            for record in rows.scalars().all()
        ]
    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(
            status_code=500, detail="Database error fetching image generations"
        ) from None
2026-02-05 16:43:48 -08:00
@router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead)
async def get_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Fetch a single image generation record by its ID.

    Responds with 404 when the record does not exist. The read-permission
    check runs against the record's own search space, after the lookup.
    """
    try:
        lookup = select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        record = (await session.execute(lookup)).scalars().first()
        if not record:
            raise HTTPException(status_code=404, detail="Image generation not found")

        await check_permission(
            session,
            user,
            record.search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to read image generations in this search space",
        )
        return record
    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(
            status_code=500, detail="Database error fetching image generation"
        ) from None
2026-02-05 16:43:48 -08:00
@router.delete("/image-generations/{image_gen_id}", response_model=dict)
async def delete_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Remove an image generation record.

    Responds with 404 when the record does not exist. The delete-permission
    check runs against the record's own search space. On a database error
    the session is rolled back and HTTP 500 is returned.
    """
    try:
        lookup = select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        record = (await session.execute(lookup)).scalars().first()
        if not record:
            raise HTTPException(status_code=404, detail="Image generation not found")

        await check_permission(
            session,
            user,
            record.search_space_id,
            Permission.IMAGE_GENERATIONS_DELETE.value,
            "You don't have permission to delete image generations in this search space",
        )

        await session.delete(record)
        await session.commit()
        return {"message": "Image generation deleted successfully"}
    except HTTPException:
        raise
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error deleting image generation"
        ) from None
2026-02-05 16:43:48 -08:00
# =============================================================================
# Image Serving (serves generated images from DB, protected by signed tokens)
# =============================================================================
2026-02-05 16:43:48 -08:00
@router.get("/image-generations/{image_gen_id}/image")
async def serve_generated_image(
    image_gen_id: int,
    token: str = Query(..., description="Signed access token"),
    index: int = 0,
    session: AsyncSession = Depends(get_async_session),
):
    """
    Serve a generated image by ID, protected by a signed token.

    Access is granted when ``verify_image_token`` accepts the ``token`` query
    parameter against the ``access_token`` stored on the record. Passing the
    token in the URL avoids auth headers (which <img> tags cannot pass).

    Args:
        image_gen_id: The image generation record ID
        token: Signed access token (included as query parameter)
        index: Which image to serve if multiple were generated (default: 0).
            Must satisfy ``0 <= index < number of images``.

    Returns:
        A redirect to the image URL when the entry has a ``url``; otherwise
        the decoded ``b64_json`` payload served as ``image/png``.

    Raises:
        HTTPException: 404 when the record, its response data, or the
            requested index does not exist or holds nothing displayable;
            403 when the token is rejected; 500 on any unexpected error.
    """
    try:
        result = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        image_gen = result.scalars().first()
        if not image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")

        # Verify the access token against the one stored on the record
        if not verify_image_token(image_gen.access_token, token):
            raise HTTPException(status_code=403, detail="Invalid image access token")

        if not image_gen.response_data:
            raise HTTPException(status_code=404, detail="No image data available")

        images = image_gen.response_data.get("data", [])
        # Negative indexes must be rejected explicitly: Python would treat
        # -1 as "last image" and serve the wrong entry, and anything below
        # -len(images) would raise IndexError and surface as HTTP 500.
        if not images or index < 0 or index >= len(images):
            raise HTTPException(
                status_code=404, detail="Image not found at the specified index"
            )

        image_entry = images[index]

        # If there's a URL, redirect to it
        if image_entry.get("url"):
            from fastapi.responses import RedirectResponse

            return RedirectResponse(url=image_entry["url"])

        # If there's b64_json data, decode and serve it
        if image_entry.get("b64_json"):
            image_bytes = base64.b64decode(image_entry["b64_json"])
            return Response(
                content=image_bytes,
                media_type="image/png",
                headers={
                    "Cache-Control": "public, max-age=86400",
                    "Content-Disposition": f'inline; filename="generated-{image_gen_id}-{index}.png"',
                },
            )

        raise HTTPException(status_code=404, detail="No displayable image data")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to serve generated image")
        raise HTTPException(
            status_code=500, detail=f"Failed to serve image: {e!s}"
        ) from e