"""
Image Generation routes:
- Image generation execution (calls litellm.aimage_generation())
- CRUD for ImageGeneration records (results)
- Image serving endpoint (serves b64_json images from DB, protected by signed tokens)
"""

import base64
import logging
from urllib.parse import urlparse

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import RedirectResponse, Response
from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.auth.context import AuthContext
from app.config import config
from app.db import (
    ImageGeneration,
    Model,
    Permission,
    SearchSpace,
    SearchSpaceMembership,
    User,
    get_async_session,
)
from app.schemas import (
    ImageGenerationCreate,
    ImageGenerationListRead,
    ImageGenerationRead,
)
from app.services.auto_model_pin_service import (
    auto_model_candidates,
    choose_auto_model_candidate,
)
from app.services.billable_calls import (
    DEFAULT_IMAGE_RESERVE_MICROS,
    QuotaInsufficientError,
    billable_call,
)
from app.services.image_gen_router_service import (
    IMAGE_GEN_AUTO_MODE_ID,
    is_image_gen_auto_mode,
)
from app.services.model_capabilities import has_capability
from app.services.model_resolver import to_litellm
from app.users import get_auth_context
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token

router = APIRouter()
logger = logging.getLogger(__name__)


def _get_global_model(model_id: int) -> dict | None:
    """Return the ``config.GLOBAL_MODELS`` entry whose ``id`` is *model_id*, or ``None``."""
    return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)


def _get_global_connection(connection_id: int) -> dict | None:
    """Return the ``config.GLOBAL_CONNECTIONS`` entry whose ``id`` is *connection_id*, or ``None``."""
    return next(
        (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
        None,
    )


def _requested_config_id(config_id: int | None, search_space: SearchSpace) -> int:
    """Apply the request -> search-space preference -> Auto fallback to a config ID.

    An explicit *config_id* (even a falsy one) wins. Otherwise the search
    space's ``image_gen_model_id`` is used, and a missing/falsy preference
    falls back to ``IMAGE_GEN_AUTO_MODE_ID``.
    """
    if config_id is not None:
        return config_id
    return search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID


async def _pick_auto_image_gen_model_id(
    session: AsyncSession,
    search_space: SearchSpace,
) -> int | None:
    """Resolve Auto mode to one concrete image-gen model ID.

    Both billing resolution and execution call this helper, so the billed
    model and the executed model are chosen by the same candidate query and
    the same chooser.

    Returns:
        The chosen model ID, or ``None`` when there are no image-gen
        candidates for this search space / owner.
    """
    candidates = await auto_model_candidates(
        session,
        search_space_id=search_space.id,
        user_id=search_space.user_id,
        capability="image_gen",
    )
    if not candidates:
        return None
    return int(choose_auto_model_candidate(candidates, search_space.id)["id"])


async def _resolve_billing_for_image_gen(
    session: AsyncSession,
    config_id: int | None,
    search_space: SearchSpace,
) -> tuple[str, str, int]:
    """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.

    The resolution mirrors ``_execute_image_generation``'s lookup tree but only
    extracts the fields needed for billing — we do this *before*
    ``billable_call`` so the reservation is correctly sized for the config that
    will actually run, and so we don't open an ``ImageGeneration`` row for a
    request that's about to 402.

    User-owned (positive ID) BYOK models are always free — they cost the user
    nothing on our side. Auto mode resolves to one concrete global or BYOK
    model before billing is calculated.
    """
    resolved_id = _requested_config_id(config_id, search_space)

    if is_image_gen_auto_mode(resolved_id):
        auto_id = await _pick_auto_image_gen_model_id(session, search_space)
        if auto_id is None:
            # No candidates: execution will raise and record an error message,
            # so bill this as a free call with the default reservation.
            return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
        resolved_id = auto_id

    if resolved_id < 0:
        # Negative IDs are global (operator-provided) models. Missing entries
        # degrade to free/default here; execution reports the real error.
        global_model = _get_global_model(resolved_id) or {}
        global_connection = _get_global_connection(global_model.get("connection_id", 0))
        billing_tier = str(global_model.get("billing_tier", "free")).lower()
        if global_connection and global_model.get("model_id"):
            base_model, _ = to_litellm(global_connection, global_model["model_id"])
        else:
            base_model = "global_image_model"
        catalog = global_model.get("catalog") or {}
        reserve_micros = int(
            catalog.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
        )
        return (billing_tier, base_model, reserve_micros)

    # Positive ID = user-owned BYOK image-gen model — always free.
    return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)


def _resolve_global_image_model(config_id: int) -> tuple[str, dict]:
    """Map a global (negative) image-gen config ID to ``(model_string, litellm_kwargs)``.

    Raises:
        ValueError: if the model is unknown, lacks the ``image_gen``
            capability, or its connection is missing.
    """
    global_model = _get_global_model(config_id)
    if not global_model or not has_capability(global_model, "image_gen"):
        raise ValueError(f"Global image generation model {config_id} not found")

    global_connection = _get_global_connection(global_model["connection_id"])
    if not global_connection:
        raise ValueError(f"Global connection for image model {config_id} not found")

    model_string, resolved_kwargs = to_litellm(
        global_connection,
        global_model["model_id"],
    )
    return model_string, resolved_kwargs


async def _resolve_user_image_model(
    session: AsyncSession,
    config_id: int,
    search_space: SearchSpace,
) -> tuple[str, dict]:
    """Map a user-owned (positive) ``Model`` ID to ``(model_string, litellm_kwargs)``.

    The model and its connection must be enabled. A connection scoped to a
    search space or user must match this search space and its owner.

    Raises:
        ValueError: if the model is missing/disabled, belongs to another
            search space or user, or is not image-generation capable.
    """
    result = await session.execute(
        select(Model)
        .options(selectinload(Model.connection))
        .filter(Model.id == config_id, Model.enabled.is_(True))
    )
    db_model = result.scalars().first()
    if not db_model or not db_model.connection or not db_model.connection.enabled:
        raise ValueError(f"Image generation model {config_id} not found")

    conn = db_model.connection
    # A scope mismatch is reported with the same "not found" message as a
    # missing model.
    foreign_space = (
        conn.search_space_id is not None and conn.search_space_id != search_space.id
    )
    foreign_user = conn.user_id is not None and conn.user_id != search_space.user_id
    if foreign_space or foreign_user:
        raise ValueError(f"Image generation model {config_id} not found")

    if not has_capability(db_model, "image_gen"):
        raise ValueError(f"Model {config_id} is not image-generation capable")

    model_string, resolved_kwargs = to_litellm(
        db_model.connection,
        db_model.model_id,
    )
    return model_string, resolved_kwargs


def _absolutize_image_urls(images: list, provider_base_url: str | None) -> None:
    """Prefix root-relative image URLs (``/...``) with the provider's origin, in place.

    Nothing is rewritten when *provider_base_url* is empty.
    """
    if not provider_base_url:
        return
    parsed = urlparse(provider_base_url)
    origin = f"{parsed.scheme}://{parsed.netloc}"
    for image in images:
        raw_url = image.get("url")
        if raw_url and raw_url.startswith("/"):
            image["url"] = f"{origin}{raw_url}"


async def _execute_image_generation(
    session: AsyncSession,
    image_gen: ImageGeneration,
    search_space: SearchSpace,
) -> None:
    """
    Call litellm.aimage_generation() with the appropriate config.

    Resolution order:
    1. Explicit image_gen_model_id on the request
    2. Search space's image_gen_model_id preference
    3. Falls back to Auto mode if available

    Mutates *image_gen*: ``image_gen_model_id`` is set to the resolved ID,
    ``model`` may be filled from the provider response, and ``response_data``
    receives the serialized response.

    Raises:
        ValueError: when no usable image-generation model can be resolved.
        Any exception raised by ``aimage_generation`` propagates unchanged.
    """
    config_id = _requested_config_id(image_gen.image_gen_model_id, search_space)
    image_gen.image_gen_model_id = config_id

    # Forward only the generation options that were set (non-None).
    gen_kwargs = {
        name: value
        for name, value in (
            ("n", image_gen.n),
            ("quality", image_gen.quality),
            ("size", image_gen.size),
            ("style", image_gen.style),
            ("response_format", image_gen.response_format),
        )
        if value is not None
    }

    if is_image_gen_auto_mode(config_id):
        auto_id = await _pick_auto_image_gen_model_id(session, search_space)
        if auto_id is None:
            raise ValueError("No image-generation models are available for Auto mode")
        config_id = auto_id
        image_gen.image_gen_model_id = config_id

    if config_id < 0:
        model_string, resolved_kwargs = _resolve_global_image_model(config_id)
    else:
        # Positive ID = Model + Connection
        model_string, resolved_kwargs = await _resolve_user_image_model(
            session, config_id, search_space
        )
    gen_kwargs.update(resolved_kwargs)

    # User model override.
    # NOTE(review): ``image_gen.model`` comes from the request and replaces the
    # configured model string while ``resolved_kwargs`` (presumably including
    # the connection's credentials) are still applied. Billing is sized from
    # the *configured* model in ``_resolve_billing_for_image_gen``, so an
    # override can run a different model than the one billed — confirm intended.
    if image_gen.model:
        model_string = image_gen.model

    response = await aimage_generation(
        prompt=image_gen.prompt, model=model_string, **gen_kwargs
    )

    # Store response
    response_dict = (
        response.model_dump() if hasattr(response, "model_dump") else dict(response)
    )
    if not image_gen.model and hasattr(response, "_hidden_params"):
        hidden = response._hidden_params
        if isinstance(hidden, dict) and hidden.get("model"):
            image_gen.model = hidden["model"]

    # Fix relative URLs in response data (for the serving endpoint)
    _absolutize_image_urls(
        response_dict.get("data") or [], resolved_kwargs.get("api_base")
    )
    image_gen.response_data = response_dict


# =============================================================================
# Image Generation Execution + Results CRUD
# =============================================================================


@router.post("/image-generations", response_model=ImageGenerationRead)
async def create_image_generation(
    data: ImageGenerationCreate,
    session: AsyncSession = Depends(get_async_session),
    auth: AuthContext = Depends(get_auth_context),
):
    """Create and execute an image generation request.

    Premium configs are gated by the user's shared premium credit pool. The
    flow is:

    1. Permission check + load the search space (cheap, no provider call).
    2. Resolve which config will run so we know its billing tier and the
       worst-case reservation size *before* opening any DB rows.
    3. Wrap the entire ImageGeneration row insert + provider call in
       ``billable_call``. If quota is denied, ``billable_call`` raises
       ``QuotaInsufficientError`` *before* we flush a row, which we translate
       to HTTP 402 (no orphaned rows on the user's account, no inserted error
       rows for "you ran out of credit").
    4. On success, the actual ``response_cost`` flows through the LiteLLM
       callback into the accumulator, and ``billable_call`` finalizes the
       debit at exit.

    Inner ``try/except`` still catches provider errors and stores them on
    ``error_message`` (HTTP 200 with ``error_message`` set is preserved for
    failed-but-not-quota scenarios — clients already know how to surface
    those).
    """
    user = auth.user
    try:
        await check_permission(
            session,
            auth,
            data.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to create image generations in this search space",
        )

        result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == data.search_space_id)
        )
        search_space = result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")

        billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
            session, data.image_gen_model_id, search_space
        )

        # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
        # propagates to the outer ``except QuotaInsufficientError`` handler
        # below as HTTP 402 — it is intentionally NOT swallowed into
        # ``error_message`` because that would (1) imply a successful row
        # exists when none does, and (2) return HTTP 200 to a client
        # whose request was actively *denied* (issue K).
        async with billable_call(
            user_id=search_space.user_id,
            search_space_id=data.search_space_id,
            billing_tier=billing_tier,
            base_model=base_model,
            quota_reserve_micros_override=reserve_micros,
            usage_type="image_generation",
            call_details={"model": base_model, "prompt": data.prompt[:100]},
        ):
            db_image_gen = ImageGeneration(
                prompt=data.prompt,
                model=data.model,
                n=data.n,
                quality=data.quality,
                size=data.size,
                style=data.style,
                response_format=data.response_format,
                image_gen_model_id=data.image_gen_model_id,
                search_space_id=data.search_space_id,
                created_by_id=user.id,
            )
            session.add(db_image_gen)
            await session.flush()

            try:
                await _execute_image_generation(session, db_image_gen, search_space)
            except Exception as e:
                logger.exception("Image generation call failed")
                db_image_gen.error_message = str(e)

            await session.commit()
            await session.refresh(db_image_gen)
            return db_image_gen

    except HTTPException:
        raise
    except QuotaInsufficientError as exc:
        # The user's premium credit pool is empty. No DB row is created
        # because ``billable_call`` denies before yielding (issue K).
        await session.rollback()
        raise HTTPException(
            status_code=402,
            detail={
                "error_code": "premium_quota_exhausted",
                "usage_type": exc.usage_type,
                "balance_micros": exc.balance_micros,
                "remaining_micros": exc.remaining_micros,
                "message": (
                    "Out of credits for image generation. "
                    "Purchase additional credits or switch to a free model."
                ),
            },
        ) from exc
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error during image generation"
        ) from None
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create image generation")
        raise HTTPException(
            status_code=500, detail=f"Image generation failed: {e!s}"
        ) from e


@router.get("/image-generations", response_model=list[ImageGenerationListRead])
async def list_image_generations(
    search_space_id: int | None = None,
    skip: int = 0,
    limit: int = 50,
    session: AsyncSession = Depends(get_async_session),
    auth: AuthContext = Depends(get_auth_context),
):
    """List image generations.

    With ``search_space_id`` the caller needs IMAGE_GENERATIONS_READ on that
    space; without it, generations from every search space the user is a
    member of are returned. Newest first; ``limit`` is capped at 100.
    """
    user = auth.user
    if skip < 0 or limit < 1:
        raise HTTPException(status_code=400, detail="Invalid pagination parameters")
    limit = min(limit, 100)

    try:
        if search_space_id is not None:
            await check_permission(
                session,
                auth,
                search_space_id,
                Permission.IMAGE_GENERATIONS_READ.value,
                "You don't have permission to read image generations in this search space",
            )
            result = await session.execute(
                select(ImageGeneration)
                .filter(ImageGeneration.search_space_id == search_space_id)
                .order_by(ImageGeneration.created_at.desc())
                .offset(skip)
                .limit(limit)
            )
        else:
            result = await session.execute(
                select(ImageGeneration)
                .join(SearchSpace)
                .join(SearchSpaceMembership)
                .filter(SearchSpaceMembership.user_id == user.id)
                .order_by(ImageGeneration.created_at.desc())
                .offset(skip)
                .limit(limit)
            )

        return [
            ImageGenerationListRead.from_orm_with_count(img)
            for img in result.scalars().all()
        ]
    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(
            status_code=500, detail="Database error fetching image generations"
        ) from None


@router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead)
async def get_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    auth: AuthContext = Depends(get_auth_context),
):
    """Get a specific image generation by ID."""
    try:
        result = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        image_gen = result.scalars().first()
        if not image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")

        await check_permission(
            session,
            auth,
            image_gen.search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to read image generations in this search space",
        )
        return image_gen
    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(
            status_code=500, detail="Database error fetching image generation"
        ) from None


@router.delete("/image-generations/{image_gen_id}", response_model=dict)
async def delete_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    auth: AuthContext = Depends(get_auth_context),
):
    """Delete an image generation record."""
    try:
        result = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        db_image_gen = result.scalars().first()
        if not db_image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")

        await check_permission(
            session,
            auth,
            db_image_gen.search_space_id,
            Permission.IMAGE_GENERATIONS_DELETE.value,
            "You don't have permission to delete image generations in this search space",
        )

        await session.delete(db_image_gen)
        await session.commit()
        return {"message": "Image generation deleted successfully"}
    except HTTPException:
        raise
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error deleting image generation"
        ) from None


# =============================================================================
# Image Serving (serves generated images from DB, protected by signed tokens)
# =============================================================================


@router.get("/image-generations/{image_gen_id}/image")
async def serve_generated_image(
    image_gen_id: int,
    token: str = Query(..., description="Signed access token"),
    index: int = 0,
    session: AsyncSession = Depends(get_async_session),
):
    """
    Serve a generated image by ID, protected by a signed token.

    The token is generated when the image URL is created by the generate_image
    tool and is checked against the ``access_token`` stored on the record. This
    lets the image be embedded without auth headers (which HTML ``<img>`` tags
    cannot send).

    URL-style entries are answered with a redirect; ``b64_json`` entries are
    decoded and returned as ``image/png``.

    Args:
        image_gen_id: The image generation record ID
        token: HMAC-signed access token (included as query parameter)
        index: Which image to serve if multiple were generated (default: 0);
            negative or out-of-range values yield 404
    """
    try:
        result = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        image_gen = result.scalars().first()
        if not image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")

        # Verify the access token against the one stored on the record
        if not verify_image_token(image_gen.access_token, token):
            raise HTTPException(status_code=403, detail="Invalid image access token")

        if not image_gen.response_data:
            raise HTTPException(status_code=404, detail="No image data available")

        images = image_gen.response_data.get("data", [])
        # Reject negative indices explicitly: Python would otherwise wrap them
        # (serving a different image) or raise IndexError (surfacing as a 500).
        if not images or not 0 <= index < len(images):
            raise HTTPException(
                status_code=404, detail="Image not found at the specified index"
            )

        image_entry = images[index]

        # If there's a URL, redirect to it
        if image_entry.get("url"):
            return RedirectResponse(url=image_entry["url"])

        # If there's b64_json data, decode and serve it
        if image_entry.get("b64_json"):
            image_bytes = base64.b64decode(image_entry["b64_json"])
            return Response(
                content=image_bytes,
                media_type="image/png",
                headers={
                    "Cache-Control": "public, max-age=86400",
                    "Content-Disposition": f'inline; filename="generated-{image_gen_id}-{index}.png"',
                },
            )

        raise HTTPException(status_code=404, detail="No displayable image data")

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to serve generated image")
        raise HTTPException(
            status_code=500, detail=f"Failed to serve image: {e!s}"
        ) from e